Files
syndarix/backend/app/services/memory/cache/embedding_cache.py
Felipe Cardoso 6954774e36 feat(memory): implement caching layer for memory operations (#98)
Add comprehensive caching layer for the Agent Memory System:

- HotMemoryCache: LRU cache for frequently accessed memories
  - Python 3.12 type parameter syntax
  - Thread-safe operations with RLock
  - TTL-based expiration
  - Access count tracking for hot memory identification
  - Scoped invalidation by type, scope, or pattern

- EmbeddingCache: Cache embeddings by content hash
  - Content-hash based deduplication
  - Optional Redis backing for persistence
  - LRU eviction with configurable max size
  - CachedEmbeddingGenerator wrapper for transparent caching

- CacheManager: Unified cache management
  - Coordinates hot cache, embedding cache, and retrieval cache
  - Centralized invalidation across all caches
  - Aggregated statistics and hit rate tracking
  - Automatic cleanup scheduling
  - Cache warmup support

Performance targets:
- Cache hit rate > 80% for hot memories
- Cache operations < 1ms (memory), < 5ms (Redis)

83 new tests with comprehensive coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 04:04:13 +01:00

628 lines
18 KiB
Python

# app/services/memory/cache/embedding_cache.py
"""
Embedding Cache.
Caches embeddings by content hash to avoid recomputing.
Provides significant performance improvement for repeated content.
"""
import hashlib
import json
import logging
import math
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from redis.asyncio import Redis
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class EmbeddingEntry:
"""A cached embedding entry."""
embedding: list[float]
content_hash: str
model: str
created_at: datetime
ttl_seconds: float = 3600.0 # 1 hour default
def is_expired(self) -> bool:
"""Check if this entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
@dataclass
class EmbeddingCacheStats:
    """Counters that describe how the embedding cache has been used."""

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0
    current_size: int = 0
    max_size: int = 0
    # Estimated bytes saved by caching.
    bytes_saved: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that were hits; 0.0 before any lookup happened."""
        lookups = self.hits + self.misses
        return self.hits / lookups if lookups else 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of every counter plus the derived hit rate."""
        snapshot: dict[str, Any] = dict(
            hits=self.hits,
            misses=self.misses,
            evictions=self.evictions,
            expirations=self.expirations,
            current_size=self.current_size,
            max_size=self.max_size,
            hit_rate=self.hit_rate,
            bytes_saved=self.bytes_saved,
        )
        return snapshot
class EmbeddingCache:
"""
Cache for embeddings by content hash.
Features:
- Content-hash based deduplication
- LRU eviction
- TTL-based expiration
- Optional Redis backing for persistence
- Thread-safe operations
Performance targets:
- Cache hit rate > 90% for repeated content
- Get/put operations < 1ms (memory), < 5ms (Redis)
"""
def __init__(
self,
max_size: int = 50000,
default_ttl_seconds: float = 3600.0,
redis: "Redis | None" = None,
redis_prefix: str = "mem:emb",
) -> None:
"""
Initialize the embedding cache.
Args:
max_size: Maximum number of entries in memory cache
default_ttl_seconds: Default TTL for entries (1 hour)
redis: Optional Redis connection for persistence
redis_prefix: Prefix for Redis keys
"""
self._max_size = max_size
self._default_ttl = default_ttl_seconds
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
self._lock = threading.RLock()
self._stats = EmbeddingCacheStats(max_size=max_size)
self._redis = redis
self._redis_prefix = redis_prefix
logger.info(
f"Initialized EmbeddingCache with max_size={max_size}, "
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
)
def set_redis(self, redis: "Redis") -> None:
"""Set Redis connection for persistence."""
self._redis = redis
@staticmethod
def hash_content(content: str) -> str:
"""
Compute hash of content for cache key.
Args:
content: Content to hash
Returns:
32-character hex hash
"""
return hashlib.sha256(content.encode()).hexdigest()[:32]
def _cache_key(self, content_hash: str, model: str) -> str:
"""Build cache key from content hash and model."""
return f"{content_hash}:{model}"
def _redis_key(self, content_hash: str, model: str) -> str:
"""Build Redis key from content hash and model."""
return f"{self._redis_prefix}:{content_hash}:{model}"
async def get(
self,
content: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding.
Args:
content: Content text
model: Model name
Returns:
Cached embedding or None if not found/expired
"""
content_hash = self.hash_content(content)
cache_key = self._cache_key(content_hash, model)
# Check memory cache first
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
if entry.is_expired():
del self._cache[cache_key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
else:
# Move to end (most recently used)
self._cache.move_to_end(cache_key)
self._stats.hits += 1
return entry.embedding
# Check Redis if available
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
data = await self._redis.get(redis_key)
if data:
import json
embedding = json.loads(data)
# Store in memory cache for faster access
self._put_memory(content_hash, model, embedding)
self._stats.hits += 1
return embedding
except Exception as e:
logger.warning(f"Redis get error: {e}")
self._stats.misses += 1
return None
async def get_by_hash(
self,
content_hash: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding by hash.
Args:
content_hash: Content hash
model: Model name
Returns:
Cached embedding or None if not found/expired
"""
cache_key = self._cache_key(content_hash, model)
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
if entry.is_expired():
del self._cache[cache_key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
else:
self._cache.move_to_end(cache_key)
self._stats.hits += 1
return entry.embedding
# Check Redis
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
data = await self._redis.get(redis_key)
if data:
import json
embedding = json.loads(data)
self._put_memory(content_hash, model, embedding)
self._stats.hits += 1
return embedding
except Exception as e:
logger.warning(f"Redis get error: {e}")
self._stats.misses += 1
return None
async def put(
self,
content: str,
embedding: list[float],
model: str = "default",
ttl_seconds: float | None = None,
) -> str:
"""
Cache an embedding.
Args:
content: Content text
embedding: Embedding vector
model: Model name
ttl_seconds: Optional TTL override
Returns:
Content hash
"""
content_hash = self.hash_content(content)
ttl = ttl_seconds or self._default_ttl
# Store in memory
self._put_memory(content_hash, model, embedding, ttl)
# Store in Redis if available
if self._redis:
try:
import json
redis_key = self._redis_key(content_hash, model)
await self._redis.setex(
redis_key,
int(ttl),
json.dumps(embedding),
)
except Exception as e:
logger.warning(f"Redis put error: {e}")
return content_hash
def _put_memory(
self,
content_hash: str,
model: str,
embedding: list[float],
ttl_seconds: float | None = None,
) -> None:
"""Store in memory cache."""
with self._lock:
# Evict if at capacity
self._evict_if_needed()
cache_key = self._cache_key(content_hash, model)
entry = EmbeddingEntry(
embedding=embedding,
content_hash=content_hash,
model=model,
created_at=_utcnow(),
ttl_seconds=ttl_seconds or self._default_ttl,
)
self._cache[cache_key] = entry
self._cache.move_to_end(cache_key)
self._stats.current_size = len(self._cache)
def _evict_if_needed(self) -> None:
"""Evict entries if cache is at capacity."""
while len(self._cache) >= self._max_size:
if self._cache:
self._cache.popitem(last=False)
self._stats.evictions += 1
async def put_batch(
self,
items: list[tuple[str, list[float]]],
model: str = "default",
ttl_seconds: float | None = None,
) -> list[str]:
"""
Cache multiple embeddings.
Args:
items: List of (content, embedding) tuples
model: Model name
ttl_seconds: Optional TTL override
Returns:
List of content hashes
"""
hashes = []
for content, embedding in items:
content_hash = await self.put(content, embedding, model, ttl_seconds)
hashes.append(content_hash)
return hashes
async def invalidate(
self,
content: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding.
Args:
content: Content text
model: Model name
Returns:
True if entry was found and removed
"""
content_hash = self.hash_content(content)
return await self.invalidate_by_hash(content_hash, model)
async def invalidate_by_hash(
self,
content_hash: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding by hash.
Args:
content_hash: Content hash
model: Model name
Returns:
True if entry was found and removed
"""
cache_key = self._cache_key(content_hash, model)
removed = False
with self._lock:
if cache_key in self._cache:
del self._cache[cache_key]
self._stats.current_size = len(self._cache)
removed = True
# Remove from Redis
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
await self._redis.delete(redis_key)
removed = True
except Exception as e:
logger.warning(f"Redis delete error: {e}")
return removed
async def invalidate_by_model(self, model: str) -> int:
"""
Invalidate all embeddings for a model.
Args:
model: Model name
Returns:
Number of entries invalidated
"""
count = 0
with self._lock:
keys_to_remove = [
k for k, v in self._cache.items() if v.model == model
]
for key in keys_to_remove:
del self._cache[key]
count += 1
self._stats.current_size = len(self._cache)
# Note: Redis pattern deletion would require SCAN which is expensive
# For now, we only clear memory cache for model-based invalidation
return count
async def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
with self._lock:
count = len(self._cache)
self._cache.clear()
self._stats.current_size = 0
# Clear Redis entries
if self._redis:
try:
pattern = f"{self._redis_prefix}:*"
deleted = 0
async for key in self._redis.scan_iter(match=pattern):
await self._redis.delete(key)
deleted += 1
count = max(count, deleted)
except Exception as e:
logger.warning(f"Redis clear error: {e}")
logger.info(f"Cleared {count} entries from embedding cache")
return count
def cleanup_expired(self) -> int:
"""
Remove all expired entries from memory cache.
Returns:
Number of entries removed
"""
with self._lock:
keys_to_remove = [
k for k, v in self._cache.items() if v.is_expired()
]
for key in keys_to_remove:
del self._cache[key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
if keys_to_remove:
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
return len(keys_to_remove)
def get_stats(self) -> EmbeddingCacheStats:
"""Get cache statistics."""
with self._lock:
self._stats.current_size = len(self._cache)
return self._stats
def reset_stats(self) -> None:
"""Reset cache statistics."""
with self._lock:
self._stats = EmbeddingCacheStats(
max_size=self._max_size,
current_size=len(self._cache),
)
@property
def size(self) -> int:
"""Get current cache size."""
return len(self._cache)
@property
def max_size(self) -> int:
"""Get maximum cache size."""
return self._max_size
class CachedEmbeddingGenerator:
    """
    Wrapper for embedding generators with caching.

    Wraps an embedding generator so that results are looked up in, and
    stored to, an ``EmbeddingCache`` under a fixed model name. The wrapped
    object must provide ``async generate(text)``. If it also has
    ``generate_batch``, that method is used for batched cache misses.
    """

    def __init__(
        self,
        generator: Any,
        cache: EmbeddingCache,
        model: str = "default",
    ) -> None:
        """
        Initialize the cached embedding generator.

        Args:
            generator: Underlying embedding generator.
            cache: Embedding cache.
            model: Model name used in cache keys.
        """
        self._generator = generator
        self._cache = cache
        self._model = model
        self._call_count = 0  # texts requested via generate/generate_batch
        self._cache_hit_count = 0  # texts answered from the cache

    async def generate(self, text: str) -> list[float]:
        """
        Generate an embedding, using the cache when possible.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector.
        """
        self._call_count += 1

        cached = await self._cache.get(text, self._model)
        if cached is not None:
            self._cache_hit_count += 1
            return cached

        embedding = await self._generator.generate(text)
        await self._cache.put(text, embedding, self._model)
        return embedding

    async def generate_batch(
        self,
        texts: list[str],
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts, using the cache when possible.

        Duplicate texts that miss the cache are embedded and cached once per
        batch. Each occurrence still receives the resulting vector, and the
        counters are unchanged (every occurrence counts as a call).

        Args:
            texts: Texts to embed.

        Returns:
            List of embedding vectors, aligned with ``texts``.

        Raises:
            ValueError: If the underlying ``generate_batch`` returns a
                different number of embeddings than it was given texts.
        """
        results: list[list[float] | None] = [None] * len(texts)
        # Cache misses, mapped text -> positions in ``texts``. A dict keeps
        # first-seen order and collapses duplicates.
        pending: dict[str, list[int]] = {}

        for i, text in enumerate(texts):
            cached = await self._cache.get(text, self._model)
            if cached is not None:
                results[i] = cached
                self._cache_hit_count += 1
            else:
                pending.setdefault(text, []).append(i)

        self._call_count += len(texts)

        if pending:
            unique_texts = list(pending)
            if hasattr(self._generator, "generate_batch"):
                embeddings = await self._generator.generate_batch(unique_texts)
                for text, embedding in zip(unique_texts, embeddings, strict=True):
                    for idx in pending[text]:
                        results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)
            else:
                # Fallback: one call per distinct text. Each result is cached
                # as soon as it is produced.
                for text in unique_texts:
                    embedding = await self._generator.generate(text)
                    for idx in pending[text]:
                        results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)

        return results  # type: ignore[return-value]

    def get_stats(self) -> dict[str, Any]:
        """Return call and hit counters plus the wrapped cache's statistics."""
        return {
            "call_count": self._call_count,
            "cache_hit_count": self._cache_hit_count,
            "cache_hit_rate": (
                self._cache_hit_count / self._call_count
                if self._call_count > 0
                else 0.0
            ),
            "cache_stats": self._cache.get_stats().to_dict(),
        }
# Factory function
def create_embedding_cache(
    max_size: int = 50000,
    default_ttl_seconds: float = 3600.0,
    redis: "Redis | None" = None,
) -> EmbeddingCache:
    """
    Build an ``EmbeddingCache`` from the given settings.

    Args:
        max_size: Upper bound on in-memory entries.
        default_ttl_seconds: TTL applied when a put gives no override.
        redis: Optional Redis connection used as a persistent second tier.

    Returns:
        A new ``EmbeddingCache`` using the default Redis key prefix.
    """
    cache = EmbeddingCache(
        redis=redis,
        max_size=max_size,
        default_ttl_seconds=default_ttl_seconds,
    )
    return cache