Skip to content
42 changes: 42 additions & 0 deletions redisvl/extensions/cache/embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
class EmbeddingsCache(BaseCache):
"""Embeddings Cache for storing embedding vectors with exact key matching."""

_warning_shown: bool = False # Class-level flag to prevent warning spam

def __init__(
self,
name: str = "embedcache",
Expand Down Expand Up @@ -124,6 +126,14 @@ def _process_cache_data(
cache_hit = CacheEntry(**convert_bytes(data))
return cache_hit.model_dump(exclude_none=True)

def _should_warn_for_async_only(self) -> bool:
"""Check if warning should be shown for async-only client usage.

Returns:
bool: True if only async client is available and warning hasn't been shown.
"""
return self._owns_redis_client is False and self._redis_client is None

def get(
self,
text: str,
Expand Down Expand Up @@ -167,6 +177,14 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:

embedding_data = cache.get_by_key("embedcache:1234567890abcdef")
"""
if self._should_warn_for_async_only():
if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (aget_by_key) instead of sync methods (get_by_key)."
)
EmbeddingsCache._warning_shown = True

client = self._get_redis_client()

# Get all fields
Expand Down Expand Up @@ -202,6 +220,14 @@ def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
if not keys:
return []

if self._should_warn_for_async_only():
if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (amget_by_keys) instead of sync methods (mget_by_keys)."
)
EmbeddingsCache._warning_shown = True

client = self._get_redis_client()

with client.pipeline(transaction=False) as pipeline:
Expand Down Expand Up @@ -283,6 +309,14 @@ def set(
text, model_name, embedding, metadata
)

if self._should_warn_for_async_only():
if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (aset) instead of sync methods (set)."
)
EmbeddingsCache._warning_shown = True

# Store in Redis
client = self._get_redis_client()
client.hset(name=key, mapping=cache_entry) # type: ignore
Expand Down Expand Up @@ -333,6 +367,14 @@ def mset(
if not items:
return []

if self._should_warn_for_async_only():
if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (amset) instead of sync methods (mset)."
)
EmbeddingsCache._warning_shown = True

client = self._get_redis_client()
keys = []

Expand Down
88 changes: 88 additions & 0 deletions tests/integration/test_embedcache_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Test warning behavior when using sync methods with async-only client."""

import logging
from unittest.mock import patch

import pytest
from redis import Redis

from redisvl.extensions.cache.embeddings import EmbeddingsCache


@pytest.fixture(autouse=True)
def reset_warning_flag():
    """Give each test in this module a fresh ``EmbeddingsCache._warning_shown``.

    The flag is stored on the class rather than on an instance, so without
    this reset a warning emitted by one test would suppress the warning in
    the tests that run after it.
    """
    # Setup: clear the flag before the test body runs.
    EmbeddingsCache._warning_shown = False
    yield
    # Teardown: clear it again so the flag state does not leak past the test.
    EmbeddingsCache._warning_shown = False


@pytest.mark.asyncio
async def test_sync_methods_warn_with_async_only_client(async_client, caplog):
    """Sync calls on an async-only cache warn once, then stay quiet.

    The cache is built with only ``async_redis_client``. ``_get_redis_client``
    is replaced with a mock so that the sync code paths run without opening a
    real synchronous connection.
    """
    cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)

    with patch.object(cache, "_get_redis_client") as client_factory:
        # Stub the two Redis commands the sync methods below end up calling.
        fake_redis = client_factory.return_value
        fake_redis.hgetall.return_value = {}  # get_by_key sees a cache miss
        fake_redis.hset.return_value = 1  # set reports one field written

        with caplog.at_level(logging.WARNING):
            # First sync call: exactly one warning is expected.
            cache.get_by_key("test_key")

            assert len(caplog.records) == 1
            first_message = caplog.records[0].message
            assert "initialized with async_redis_client only" in first_message
            assert "Use async methods" in first_message

            # Start from an empty capture before the second call.
            caplog.clear()

            # Second sync call: the class-level flag must suppress a repeat.
            cache.set(text="test", model_name="model", embedding=[0.1, 0.2])

            assert len(caplog.records) == 0


def test_no_warning_with_sync_client(redis_url):
    """Test that sync methods log no warning when a sync client is provided.

    The ``logger`` object of the cache module is patched directly. The cache
    methods call ``logger.warning(...)`` on that module-level name, so
    patching ``redisvl.utils.log.get_logger`` after the module has been
    imported never reaches it and the assertion could not fail.
    """
    # Create sync redis client from redis_url
    sync_client = Redis.from_url(redis_url)

    try:
        # Initialize EmbeddingsCache with an explicit sync client
        cache = EmbeddingsCache(name="test_cache", redis_client=sync_client)

        # Patch the name the cache methods actually look up at call time.
        with patch(
            "redisvl.extensions.cache.embeddings.embeddings.logger"
        ) as mock_logger:
            # Sync methods should not warn
            _ = cache.get_by_key("test_key")
            _ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])

            # No warnings should have been logged
            mock_logger.warning.assert_not_called()
    finally:
        # Always release the connection, even when an assertion fails.
        sync_client.close()


@pytest.mark.asyncio
async def test_async_methods_no_warning(async_client):
    """Test that async methods don't trigger warnings on an async-only cache.

    The ``logger`` object of the cache module is patched directly. The cache
    methods call ``logger.warning(...)`` on that module-level name, so
    patching ``redisvl.utils.log.get_logger`` after the module has been
    imported never reaches it and the assertion could not fail.
    """
    # Initialize EmbeddingsCache with only async_redis_client
    cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)

    # Patch the name the cache methods actually look up at call time.
    with patch(
        "redisvl.extensions.cache.embeddings.embeddings.logger"
    ) as mock_logger:
        # Async methods should not warn
        _ = await cache.aget_by_key("test_key")
        _ = await cache.aset(text="test", model_name="model", embedding=[0.1, 0.2])

        # No warnings should have been logged
        mock_logger.warning.assert_not_called()
Loading