# app/services/memory/cache/embedding_cache.py
"""
Embedding Cache.

Caches embeddings by content hash to avoid recomputing them, which provides
a significant performance improvement for repeated content.
"""

import hashlib
import json
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from redis.asyncio import Redis

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get the current UTC time as a timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class EmbeddingEntry:
    """A cached embedding entry."""

    embedding: list[float]
    content_hash: str
    model: str
    created_at: datetime
    ttl_seconds: float = 3600.0  # 1 hour default

    def is_expired(self) -> bool:
        """Check whether this entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds


@dataclass
class EmbeddingCacheStats:
    """Statistics for the embedding cache."""

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0
    current_size: int = 0
    max_size: int = 0
    bytes_saved: int = 0  # Estimated bytes saved by caching

    @property
    def hit_rate(self) -> float:
        """Calculate the cache hit rate."""
        total = self.hits + self.misses
        if total == 0:
            return 0.0
        return self.hits / total

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "evictions": self.evictions,
            "expirations": self.expirations,
            "current_size": self.current_size,
            "max_size": self.max_size,
            "hit_rate": self.hit_rate,
            "bytes_saved": self.bytes_saved,
        }


class EmbeddingCache:
    """
    Cache for embeddings, keyed by content hash.

    Features:
    - Content-hash based deduplication
    - LRU eviction
    - TTL-based expiration
    - Optional Redis backing for persistence
    - Thread-safe operations

    Performance targets:
    - Cache hit rate > 90% for repeated content
    - Get/put operations < 1ms (memory), < 5ms (Redis)
    """

    def __init__(
        self,
        max_size: int = 50000,
        default_ttl_seconds: float = 3600.0,
        redis: "Redis | None" = None,
        redis_prefix: str = "mem:emb",
    ) -> None:
        """
        Initialize the embedding cache.

        Args:
            max_size: Maximum number of entries in the memory cache
            default_ttl_seconds: Default TTL for entries (1 hour)
            redis: Optional Redis connection for persistence
            redis_prefix: Prefix for Redis keys
        """
        self._max_size = max_size
        self._default_ttl = default_ttl_seconds
        self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._stats = EmbeddingCacheStats(max_size=max_size)
        self._redis = redis
        self._redis_prefix = redis_prefix

        logger.info(
            f"Initialized EmbeddingCache with max_size={max_size}, "
            f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
        )

    def set_redis(self, redis: "Redis") -> None:
        """Set the Redis connection for persistence."""
        self._redis = redis

    @staticmethod
    def hash_content(content: str) -> str:
        """
        Compute the hash of content for use as a cache key.

        Args:
            content: Content to hash

        Returns:
            32-character hex hash (truncated SHA-256)
        """
        return hashlib.sha256(content.encode()).hexdigest()[:32]

    def _cache_key(self, content_hash: str, model: str) -> str:
        """Build a cache key from a content hash and model name."""
        return f"{content_hash}:{model}"

    def _redis_key(self, content_hash: str, model: str) -> str:
        """Build a Redis key from a content hash and model name."""
        return f"{self._redis_prefix}:{content_hash}:{model}"

    async def get(
        self,
        content: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            Cached embedding, or None if not found or expired
        """
        # Hash the content, then reuse the hash-based lookup path.
        content_hash = self.hash_content(content)
        return await self.get_by_hash(content_hash, model)

    async def get_by_hash(
        self,
        content_hash: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding by content hash.

        Args:
            content_hash: Content hash
            model: Model name

        Returns:
            Cached embedding, or None if not found or expired
        """
        cache_key = self._cache_key(content_hash, model)

        # Check the memory cache first.
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                if entry.is_expired():
                    del self._cache[cache_key]
                    self._stats.expirations += 1
                    self._stats.current_size = len(self._cache)
                else:
                    # Move to the end (most recently used).
                    self._cache.move_to_end(cache_key)
                    self._stats.hits += 1
                    return entry.embedding

        # Fall back to Redis if available.
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                data = await self._redis.get(redis_key)
                if data:
                    embedding = json.loads(data)
                    # Promote to the memory cache for faster access.
                    self._put_memory(content_hash, model, embedding)
                    self._stats.hits += 1
                    return embedding
            except Exception as e:
                logger.warning(f"Redis get error: {e}")

        self._stats.misses += 1
        return None

    async def put(
        self,
        content: str,
        embedding: list[float],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> str:
        """
        Cache an embedding.

        Args:
            content: Content text
            embedding: Embedding vector
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            Content hash
        """
        content_hash = self.hash_content(content)
        # Explicit None check so a ttl_seconds of 0 is honored as an override.
        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl

        # Store in memory.
        self._put_memory(content_hash, model, embedding, ttl)

        # Store in Redis if available.
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                await self._redis.setex(
                    redis_key,
                    int(ttl),
                    json.dumps(embedding),
                )
            except Exception as e:
                logger.warning(f"Redis put error: {e}")

        return content_hash

    def _put_memory(
        self,
        content_hash: str,
        model: str,
        embedding: list[float],
        ttl_seconds: float | None = None,
    ) -> None:
        """Store an embedding in the memory cache."""
        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
        with self._lock:
            # Evict if at capacity.
            self._evict_if_needed()

            cache_key = self._cache_key(content_hash, model)
            entry = EmbeddingEntry(
                embedding=embedding,
                content_hash=content_hash,
                model=model,
                created_at=_utcnow(),
                ttl_seconds=ttl,
            )
            self._cache[cache_key] = entry
            self._cache.move_to_end(cache_key)
            self._stats.current_size = len(self._cache)

    def _evict_if_needed(self) -> None:
        """Evict least-recently-used entries while the cache is at capacity."""
        # The emptiness check guards against an infinite loop when
        # max_size <= 0.
        while self._cache and len(self._cache) >= self._max_size:
            self._cache.popitem(last=False)
            self._stats.evictions += 1

    async def put_batch(
        self,
        items: list[tuple[str, list[float]]],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> list[str]:
        """
        Cache multiple embeddings.

        Args:
            items: List of (content, embedding) tuples
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            List of content hashes
        """
        hashes = []
        for content, embedding in items:
            content_hash = await self.put(content, embedding, model, ttl_seconds)
            hashes.append(content_hash)
        return hashes

    async def invalidate(
        self,
        content: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            True if an entry was found and removed
        """
        content_hash = self.hash_content(content)
        return await self.invalidate_by_hash(content_hash, model)

    async def invalidate_by_hash(
        self,
        content_hash: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding by content hash.

        Args:
            content_hash: Content hash
            model: Model name

        Returns:
            True if an entry was found and removed
        """
        cache_key = self._cache_key(content_hash, model)
        removed = False

        with self._lock:
            if cache_key in self._cache:
                del self._cache[cache_key]
                self._stats.current_size = len(self._cache)
                removed = True

        # Remove from Redis as well. DELETE returns the number of keys it
        # removed, so only report success if something was actually deleted.
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                deleted = await self._redis.delete(redis_key)
                removed = removed or bool(deleted)
            except Exception as e:
                logger.warning(f"Redis delete error: {e}")

        return removed

    async def invalidate_by_model(self, model: str) -> int:
        """
        Invalidate all embeddings for a model.

        Args:
            model: Model name

        Returns:
            Number of entries invalidated
        """
        count = 0
        with self._lock:
            keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
            for key in keys_to_remove:
                del self._cache[key]
                count += 1
            self._stats.current_size = len(self._cache)

        # Note: pattern deletion in Redis would require SCAN, which is
        # expensive, so model-based invalidation only clears the memory cache.
        return count

    async def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            self._stats.current_size = 0

        # Clear Redis entries.
        if self._redis:
            try:
                pattern = f"{self._redis_prefix}:*"
                deleted = 0
                async for key in self._redis.scan_iter(match=pattern):
                    await self._redis.delete(key)
                    deleted += 1
                count = max(count, deleted)
            except Exception as e:
                logger.warning(f"Redis clear error: {e}")

        logger.info(f"Cleared {count} entries from embedding cache")
        return count

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries from the memory cache.

        Returns:
            Number of entries removed
        """
        with self._lock:
            keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
            for key in keys_to_remove:
                del self._cache[key]
                self._stats.expirations += 1
            self._stats.current_size = len(self._cache)

        if keys_to_remove:
            logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")

        return len(keys_to_remove)

    def get_stats(self) -> EmbeddingCacheStats:
        """Get cache statistics."""
        with self._lock:
            self._stats.current_size = len(self._cache)
            return self._stats

    def reset_stats(self) -> None:
        """Reset cache statistics."""
        with self._lock:
            self._stats = EmbeddingCacheStats(
                max_size=self._max_size,
                current_size=len(self._cache),
            )

    @property
    def size(self) -> int:
        """Get the current cache size."""
        return len(self._cache)

    @property
    def max_size(self) -> int:
        """Get the maximum cache size."""
        return self._max_size
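
# Usage sketch (illustrative only; not executed anywhere in this module).
# The cache key combines the truncated SHA-256 of the content with the model
# name, so identical text for the same model is only embedded once. The
# vector below is a made-up placeholder, not output from a real model:
#
#     cache = EmbeddingCache(max_size=1000, default_ttl_seconds=600.0)
#     await cache.put("hello world", [0.1, 0.2, 0.3], model="default")
#     vec = await cache.get("hello world")      # hit: [0.1, 0.2, 0.3]
#     vec = await cache.get("something else")   # miss: None
#     print(cache.get_stats().hit_rate)         # 0.5 (one hit, one miss)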


class CachedEmbeddingGenerator:
    """
    Caching wrapper for embedding generators.

    Wraps an embedding generator so that results are cached and repeated
    inputs are served from the cache instead of being recomputed.
    """

    def __init__(
        self,
        generator: Any,
        cache: EmbeddingCache,
        model: str = "default",
    ) -> None:
        """
        Initialize the cached embedding generator.

        Args:
            generator: Underlying embedding generator
            cache: Embedding cache
            model: Model name for cache keys
        """
        self._generator = generator
        self._cache = cache
        self._model = model
        self._call_count = 0
        self._cache_hit_count = 0

    async def generate(self, text: str) -> list[float]:
        """
        Generate an embedding, with caching.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        self._call_count += 1

        # Check the cache first.
        cached = await self._cache.get(text, self._model)
        if cached is not None:
            self._cache_hit_count += 1
            return cached

        # Generate and cache.
        embedding = await self._generator.generate(text)
        await self._cache.put(text, embedding, self._model)
        return embedding

    async def generate_batch(
        self,
        texts: list[str],
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts, with caching.

        Args:
            texts: Texts to embed

        Returns:
            List of embedding vectors
        """
        results: list[list[float] | None] = [None] * len(texts)
        to_generate: list[tuple[int, str]] = []

        # Check the cache for each text.
        for i, text in enumerate(texts):
            cached = await self._cache.get(text, self._model)
            if cached is not None:
                results[i] = cached
                self._cache_hit_count += 1
            else:
                to_generate.append((i, text))

        self._call_count += len(texts)

        # Generate the missing embeddings.
        if to_generate:
            if hasattr(self._generator, "generate_batch"):
                texts_to_gen = [t for _, t in to_generate]
                embeddings = await self._generator.generate_batch(texts_to_gen)
                for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
                    results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)
            else:
                # Fall back to individual generation.
                for idx, text in to_generate:
                    embedding = await self._generator.generate(text)
                    results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)

        return results  # type: ignore[return-value]

    def get_stats(self) -> dict[str, Any]:
        """Get generator statistics."""
        return {
            "call_count": self._call_count,
            "cache_hit_count": self._cache_hit_count,
            "cache_hit_rate": (
                self._cache_hit_count / self._call_count
                if self._call_count > 0
                else 0.0
            ),
            "cache_stats": self._cache.get_stats().to_dict(),
        }


# Factory function
def create_embedding_cache(
    max_size: int = 50000,
    default_ttl_seconds: float = 3600.0,
    redis: "Redis | None" = None,
) -> EmbeddingCache:
    """
    Create an embedding cache.

    Args:
        max_size: Maximum number of entries
        default_ttl_seconds: Default TTL for entries
        redis: Optional Redis connection

    Returns:
        Configured EmbeddingCache instance
    """
    return EmbeddingCache(
        max_size=max_size,
        default_ttl_seconds=default_ttl_seconds,
        redis=redis,
    )
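

# Illustrative smoke test: a minimal sketch using a stub generator in place of
# a real embedding model. The stub class and its fake 4-dimensional vectors
# are placeholders for demonstration only; in production you would pass your
# actual generator and, if desired, a redis.asyncio.Redis connection.
if __name__ == "__main__":
    import asyncio

    class _StubGenerator:
        """Fake generator that derives a tiny vector from the text hash."""

        async def generate(self, text: str) -> list[float]:
            digest = hashlib.sha256(text.encode()).digest()
            return [b / 255.0 for b in digest[:4]]

        async def generate_batch(self, texts: list[str]) -> list[list[float]]:
            return [await self.generate(t) for t in texts]

    async def _demo() -> None:
        cache = create_embedding_cache(max_size=100, default_ttl_seconds=60.0)
        generator = CachedEmbeddingGenerator(_StubGenerator(), cache)

        await generator.generate("hello")  # computed by the stub
        await generator.generate("hello")  # served from the cache
        await generator.generate_batch(["hello", "world"])  # one hit, one miss

        print(generator.get_stats())

    asyncio.run(_demo())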