Skip to content

Commit 82ddb58

Browse files
bsboddenowaisGitHub CopilotCopilot
authored
feat(cache): add warnings when using sync methods with async-only Redis client (#391)
Co-authored-by: owais <[email protected]> Co-authored-by: GitHub Copilot <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 76c74a0 commit 82ddb58

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

redisvl/extensions/cache/embeddings/embeddings.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
class EmbeddingsCache(BaseCache):
1515
"""Embeddings Cache for storing embedding vectors with exact key matching."""
1616

17+
_warning_shown: bool = False # Class-level flag to prevent warning spam
18+
1719
def __init__(
1820
self,
1921
name: str = "embedcache",
@@ -124,6 +126,14 @@ def _process_cache_data(
124126
cache_hit = CacheEntry(**convert_bytes(data))
125127
return cache_hit.model_dump(exclude_none=True)
126128

129+
def _should_warn_for_async_only(self) -> bool:
130+
"""Check if only async client is available (no sync client).
131+
132+
Returns:
133+
bool: True if only async client is available (no sync client).
134+
"""
135+
return self._owns_redis_client is False and self._redis_client is None
136+
127137
def get(
128138
self,
129139
text: str,
@@ -167,6 +177,14 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:
167177
168178
embedding_data = cache.get_by_key("embedcache:1234567890abcdef")
169179
"""
180+
if self._should_warn_for_async_only():
181+
if not EmbeddingsCache._warning_shown:
182+
logger.warning(
183+
"EmbeddingsCache initialized with async_redis_client only. "
184+
"Use async methods (aget_by_key) instead of sync methods (get_by_key)."
185+
)
186+
EmbeddingsCache._warning_shown = True
187+
170188
client = self._get_redis_client()
171189

172190
# Get all fields
@@ -202,6 +220,14 @@ def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
202220
if not keys:
203221
return []
204222

223+
if self._should_warn_for_async_only():
224+
if not EmbeddingsCache._warning_shown:
225+
logger.warning(
226+
"EmbeddingsCache initialized with async_redis_client only. "
227+
"Use async methods (amget_by_keys) instead of sync methods (mget_by_keys)."
228+
)
229+
EmbeddingsCache._warning_shown = True
230+
205231
client = self._get_redis_client()
206232

207233
with client.pipeline(transaction=False) as pipeline:
@@ -283,6 +309,14 @@ def set(
283309
text, model_name, embedding, metadata
284310
)
285311

312+
if self._should_warn_for_async_only():
313+
if not EmbeddingsCache._warning_shown:
314+
logger.warning(
315+
"EmbeddingsCache initialized with async_redis_client only. "
316+
"Use async methods (aset) instead of sync methods (set)."
317+
)
318+
EmbeddingsCache._warning_shown = True
319+
286320
# Store in Redis
287321
client = self._get_redis_client()
288322
client.hset(name=key, mapping=cache_entry) # type: ignore
@@ -333,6 +367,14 @@ def mset(
333367
if not items:
334368
return []
335369

370+
if self._should_warn_for_async_only():
371+
if not EmbeddingsCache._warning_shown:
372+
logger.warning(
373+
"EmbeddingsCache initialized with async_redis_client only. "
374+
"Use async methods (amset) instead of sync methods (mset)."
375+
)
376+
EmbeddingsCache._warning_shown = True
377+
336378
client = self._get_redis_client()
337379
keys = []
338380

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Test warning behavior when using sync methods with async-only client."""
2+
3+
import logging
4+
from unittest.mock import patch
5+
6+
import pytest
7+
from redis import Redis
8+
9+
from redisvl.extensions.cache.embeddings import EmbeddingsCache
10+
11+
12+
@pytest.fixture(autouse=True)
13+
def reset_warning_flag():
14+
"""Reset the warning flag before each test to ensure test isolation."""
15+
EmbeddingsCache._warning_shown = False
16+
yield
17+
# Optionally reset after test as well for cleanup
18+
EmbeddingsCache._warning_shown = False
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_sync_methods_warn_with_async_only_client(async_client, caplog):
23+
"""Test that sync methods warn when only async client is provided."""
24+
# Initialize EmbeddingsCache with only async_redis_client
25+
cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)
26+
27+
# Mock _get_redis_client to prevent actual connection attempt
28+
with patch.object(cache, "_get_redis_client") as mock_get_client:
29+
# Mock the Redis client methods that would be called
30+
mock_client = mock_get_client.return_value
31+
mock_client.hgetall.return_value = {} # Empty result for get_by_key
32+
mock_client.hset.return_value = 1 # Success for set
33+
34+
# Capture log warnings
35+
with caplog.at_level(logging.WARNING):
36+
# First sync method call should warn
37+
_ = cache.get_by_key("test_key")
38+
39+
# Check warning was logged
40+
assert len(caplog.records) == 1
41+
assert (
42+
"initialized with async_redis_client only" in caplog.records[0].message
43+
)
44+
assert "Use async methods" in caplog.records[0].message
45+
46+
# Clear captured logs
47+
caplog.clear()
48+
49+
# Second sync method call should NOT warn (flag prevents spam)
50+
_ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])
51+
52+
# Should not have logged another warning
53+
assert len(caplog.records) == 0
54+
55+
56+
def test_no_warning_with_sync_client(redis_url):
57+
"""Test that no warning is shown when sync client is provided."""
58+
# Create sync redis client from redis_url
59+
sync_client = Redis.from_url(redis_url)
60+
61+
try:
62+
# Initialize EmbeddingsCache with sync_redis_client
63+
cache = EmbeddingsCache(name="test_cache", redis_client=sync_client)
64+
65+
with patch("redisvl.utils.log.get_logger") as mock_logger:
66+
# Sync methods should not warn
67+
_ = cache.get_by_key("test_key")
68+
_ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])
69+
70+
# No warnings should have been logged
71+
mock_logger.return_value.warning.assert_not_called()
72+
finally:
73+
sync_client.close()
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_async_methods_no_warning(async_client):
78+
"""Test that async methods don't trigger warnings."""
79+
# Initialize EmbeddingsCache with only async_redis_client
80+
cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)
81+
82+
with patch("redisvl.utils.log.get_logger") as mock_logger:
83+
# Async methods should not warn
84+
_ = await cache.aget_by_key("test_key")
85+
_ = await cache.aset(text="test", model_name="model", embedding=[0.1, 0.2])
86+
87+
# No warnings should have been logged
88+
mock_logger.return_value.warning.assert_not_called()

0 commit comments

Comments
 (0)