diff --git a/backend/app/services/memory/cache/__init__.py b/backend/app/services/memory/cache/__init__.py new file mode 100644 index 0000000..e956620 --- /dev/null +++ b/backend/app/services/memory/cache/__init__.py @@ -0,0 +1,21 @@ +# app/services/memory/cache/__init__.py +""" +Memory Caching Layer. + +Provides caching for memory operations: +- Hot Memory Cache: LRU cache for frequently accessed memories +- Embedding Cache: Cache embeddings by content hash +- Cache Manager: Unified cache management with invalidation +""" + +from .cache_manager import CacheManager, CacheStats, get_cache_manager +from .embedding_cache import EmbeddingCache +from .hot_cache import HotMemoryCache + +__all__ = [ + "CacheManager", + "CacheStats", + "EmbeddingCache", + "HotMemoryCache", + "get_cache_manager", +] diff --git a/backend/app/services/memory/cache/cache_manager.py b/backend/app/services/memory/cache/cache_manager.py new file mode 100644 index 0000000..2412109 --- /dev/null +++ b/backend/app/services/memory/cache/cache_manager.py @@ -0,0 +1,500 @@ +# app/services/memory/cache/cache_manager.py +""" +Cache Manager. + +Unified cache management for memory operations. +Coordinates hot cache, embedding cache, and retrieval cache. +Provides centralized invalidation and statistics. +""" + +import logging +import threading +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any +from uuid import UUID + +from app.services.memory.config import get_memory_settings + +from .embedding_cache import EmbeddingCache, create_embedding_cache +from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache + +if TYPE_CHECKING: + from redis.asyncio import Redis + + from app.services.memory.indexing.retrieval import RetrievalCache + +logger = logging.getLogger(__name__) + + +def _utcnow() -> datetime: + """Get current UTC time as timezone-aware datetime.""" + return datetime.now(UTC) + + +@dataclass +class CacheStats: + """Aggregated cache statistics.""" + + hot_cache: dict[str, Any] = field(default_factory=dict) + embedding_cache: dict[str, Any] = field(default_factory=dict) + retrieval_cache: dict[str, Any] = field(default_factory=dict) + overall_hit_rate: float = 0.0 + last_cleanup: datetime | None = None + cleanup_count: int = 0 + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "hot_cache": self.hot_cache, + "embedding_cache": self.embedding_cache, + "retrieval_cache": self.retrieval_cache, + "overall_hit_rate": self.overall_hit_rate, + "last_cleanup": self.last_cleanup.isoformat() if self.last_cleanup else None, + "cleanup_count": self.cleanup_count, + } + + +class CacheManager: + """ + Unified cache manager for memory operations. + + Provides: + - Centralized cache configuration + - Coordinated invalidation across caches + - Aggregated statistics + - Automatic cleanup scheduling + + Performance targets: + - Overall cache hit rate > 80% + - Cache operations < 1ms (memory), < 5ms (Redis) + """ + + def __init__( + self, + hot_cache: HotMemoryCache[Any] | None = None, + embedding_cache: EmbeddingCache | None = None, + retrieval_cache: "RetrievalCache | None" = None, + redis: "Redis | None" = None, + ) -> None: + """ + Initialize the cache manager. 
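+
+        Example (illustrative sketch; ``memory_id`` and ``memory_obj`` are
+        hypothetical placeholders, and configuration comes from memory settings):
+
+            manager = CacheManager()
+            # Put a memory object into the in-process hot cache.
+            manager.cache_memory("episodic", memory_id, memory_obj, scope="proj-1")
+            # Subsequent reads are served from the LRU cache, not the database.
+            cached = manager.get_memory("episodic", memory_id, scope="proj-1")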
+ + Args: + hot_cache: Optional pre-configured hot cache + embedding_cache: Optional pre-configured embedding cache + retrieval_cache: Optional pre-configured retrieval cache + redis: Optional Redis connection for persistence + """ + self._settings = get_memory_settings() + self._redis = redis + self._enabled = self._settings.cache_enabled + + # Initialize caches + if hot_cache: + self._hot_cache = hot_cache + else: + self._hot_cache = create_hot_cache( + max_size=self._settings.cache_max_items, + default_ttl_seconds=self._settings.cache_ttl_seconds, + ) + + if embedding_cache: + self._embedding_cache = embedding_cache + else: + self._embedding_cache = create_embedding_cache( + max_size=self._settings.cache_max_items, + default_ttl_seconds=self._settings.cache_ttl_seconds * 12, # 1hr for embeddings + redis=redis, + ) + + self._retrieval_cache = retrieval_cache + + # Stats tracking + self._last_cleanup: datetime | None = None + self._cleanup_count = 0 + self._lock = threading.RLock() + + logger.info( + f"Initialized CacheManager: enabled={self._enabled}, " + f"redis={'connected' if redis else 'disabled'}" + ) + + def set_redis(self, redis: "Redis") -> None: + """Set Redis connection for all caches.""" + self._redis = redis + self._embedding_cache.set_redis(redis) + + def set_retrieval_cache(self, cache: "RetrievalCache") -> None: + """Set retrieval cache instance.""" + self._retrieval_cache = cache + + @property + def is_enabled(self) -> bool: + """Check if caching is enabled.""" + return self._enabled + + @property + def hot_cache(self) -> HotMemoryCache[Any]: + """Get the hot memory cache.""" + return self._hot_cache + + @property + def embedding_cache(self) -> EmbeddingCache: + """Get the embedding cache.""" + return self._embedding_cache + + @property + def retrieval_cache(self) -> "RetrievalCache | None": + """Get the retrieval cache.""" + return self._retrieval_cache + + # ========================================================================= + # Hot Memory Cache Operations + # ========================================================================= + + def get_memory( + self, + memory_type: str, + memory_id: UUID | str, + scope: str | None = None, + ) -> Any | None: + """ + Get a memory from hot cache. + + Args: + memory_type: Type of memory + memory_id: Memory ID + scope: Optional scope + + Returns: + Cached memory or None + """ + if not self._enabled: + return None + return self._hot_cache.get_by_id(memory_type, memory_id, scope) + + def cache_memory( + self, + memory_type: str, + memory_id: UUID | str, + memory: Any, + scope: str | None = None, + ttl_seconds: float | None = None, + ) -> None: + """ + Cache a memory in hot cache. + + Args: + memory_type: Type of memory + memory_id: Memory ID + memory: Memory object + scope: Optional scope + ttl_seconds: Optional TTL override + """ + if not self._enabled: + return + self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds) + + # ========================================================================= + # Embedding Cache Operations + # ========================================================================= + + async def get_embedding( + self, + content: str, + model: str = "default", + ) -> list[float] | None: + """ + Get a cached embedding. 
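+
+        Example (read-through sketch; ``embed_text`` is a hypothetical embedding
+        call, not part of this module):
+
+            embedding = await manager.get_embedding(text, model="text-embedding-3-small")
+            if embedding is None:
+                embedding = await embed_text(text)  # hypothetical external call
+                await manager.cache_embedding(text, embedding, model="text-embedding-3-small")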
+ + Args: + content: Content text + model: Model name + + Returns: + Cached embedding or None + """ + if not self._enabled: + return None + return await self._embedding_cache.get(content, model) + + async def cache_embedding( + self, + content: str, + embedding: list[float], + model: str = "default", + ttl_seconds: float | None = None, + ) -> str: + """ + Cache an embedding. + + Args: + content: Content text + embedding: Embedding vector + model: Model name + ttl_seconds: Optional TTL override + + Returns: + Content hash + """ + if not self._enabled: + return EmbeddingCache.hash_content(content) + return await self._embedding_cache.put(content, embedding, model, ttl_seconds) + + # ========================================================================= + # Invalidation + # ========================================================================= + + async def invalidate_memory( + self, + memory_type: str, + memory_id: UUID | str, + scope: str | None = None, + ) -> int: + """ + Invalidate a memory across all caches. + + Args: + memory_type: Type of memory + memory_id: Memory ID + scope: Optional scope + + Returns: + Number of entries invalidated + """ + count = 0 + + # Invalidate hot cache + if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope): + count += 1 + + # Invalidate retrieval cache + if self._retrieval_cache: + uuid_id = UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id + count += self._retrieval_cache.invalidate_by_memory(uuid_id) + + logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}") + return count + + async def invalidate_by_type(self, memory_type: str) -> int: + """ + Invalidate all entries of a memory type. + + Args: + memory_type: Type of memory + + Returns: + Number of entries invalidated + """ + count = self._hot_cache.invalidate_by_type(memory_type) + + if self._retrieval_cache: + count += self._retrieval_cache.clear() + + logger.info(f"Invalidated {count} cache entries for type {memory_type}") + return count + + async def invalidate_by_scope(self, scope: str) -> int: + """ + Invalidate all entries in a scope. + + Args: + scope: Scope to invalidate (e.g., project_id) + + Returns: + Number of entries invalidated + """ + count = self._hot_cache.invalidate_by_scope(scope) + + # Retrieval cache doesn't support scope-based invalidation + # so we clear it entirely for safety + if self._retrieval_cache: + count += self._retrieval_cache.clear() + + logger.info(f"Invalidated {count} cache entries for scope {scope}") + return count + + async def invalidate_embedding( + self, + content: str, + model: str = "default", + ) -> bool: + """ + Invalidate a cached embedding. + + Args: + content: Content text + model: Model name + + Returns: + True if entry was found and removed + """ + return await self._embedding_cache.invalidate(content, model) + + async def clear_all(self) -> int: + """ + Clear all caches. + + Returns: + Total number of entries cleared + """ + count = 0 + + count += self._hot_cache.clear() + count += await self._embedding_cache.clear() + + if self._retrieval_cache: + count += self._retrieval_cache.clear() + + logger.info(f"Cleared {count} entries from all caches") + return count + + # ========================================================================= + # Cleanup + # ========================================================================= + + async def cleanup_expired(self) -> int: + """ + Clean up expired entries from all caches. 
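+
+        Example (one possible scheduling approach, assuming the caller imports
+        ``asyncio``; this module does not prescribe a scheduler):
+
+            async def cleanup_loop(manager: "CacheManager") -> None:
+                while True:
+                    await asyncio.sleep(60)        # caller-chosen interval
+                    await manager.cleanup_expired()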
+ + Returns: + Number of entries cleaned up + """ + with self._lock: + count = 0 + + count += self._hot_cache.cleanup_expired() + count += self._embedding_cache.cleanup_expired() + + # Retrieval cache doesn't have a cleanup method, + # but entries expire on access + + self._last_cleanup = _utcnow() + self._cleanup_count += 1 + + if count > 0: + logger.info(f"Cleaned up {count} expired cache entries") + + return count + + # ========================================================================= + # Statistics + # ========================================================================= + + def get_stats(self) -> CacheStats: + """ + Get aggregated cache statistics. + + Returns: + CacheStats with all cache metrics + """ + hot_stats = self._hot_cache.get_stats().to_dict() + emb_stats = self._embedding_cache.get_stats().to_dict() + + retrieval_stats: dict[str, Any] = {} + if self._retrieval_cache: + retrieval_stats = self._retrieval_cache.get_stats() + + # Calculate overall hit rate + total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0) + total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0) + + if retrieval_stats: + # Retrieval cache doesn't track hits/misses the same way + pass + + total_requests = total_hits + total_misses + overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0 + + return CacheStats( + hot_cache=hot_stats, + embedding_cache=emb_stats, + retrieval_cache=retrieval_stats, + overall_hit_rate=overall_hit_rate, + last_cleanup=self._last_cleanup, + cleanup_count=self._cleanup_count, + ) + + def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]: + """ + Get the most frequently accessed memories. + + Args: + limit: Maximum number to return + + Returns: + List of (key, access_count) tuples + """ + return self._hot_cache.get_hot_memories(limit) + + def reset_stats(self) -> None: + """Reset all cache statistics.""" + self._hot_cache.reset_stats() + self._embedding_cache.reset_stats() + + # ========================================================================= + # Warmup + # ========================================================================= + + async def warmup( + self, + memories: list[tuple[str, UUID | str, Any]], + scope: str | None = None, + ) -> int: + """ + Warm up the hot cache with memories. + + Args: + memories: List of (memory_type, memory_id, memory) tuples + scope: Optional scope for all memories + + Returns: + Number of memories cached + """ + if not self._enabled: + return 0 + + for memory_type, memory_id, memory in memories: + self._hot_cache.put_by_id(memory_type, memory_id, memory, scope) + + logger.info(f"Warmed up cache with {len(memories)} memories") + return len(memories) + + +# Singleton instance +_cache_manager: CacheManager | None = None +_cache_manager_lock = threading.Lock() + + +def get_cache_manager( + redis: "Redis | None" = None, + reset: bool = False, +) -> CacheManager: + """ + Get the global CacheManager instance. + + Thread-safe with double-checked locking pattern. 
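+
+        Example (illustrative; the Redis client is optional and only used by the
+        embedding cache):
+
+            manager = get_cache_manager()          # first call creates the singleton
+            same = get_cache_manager()             # later calls return the same instance
+            assert manager is same
+            fresh = get_cache_manager(reset=True)  # force a new instance, e.g. in tests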
+ + Args: + redis: Optional Redis connection + reset: Force create a new instance + + Returns: + CacheManager instance + """ + global _cache_manager + + if reset or _cache_manager is None: + with _cache_manager_lock: + if reset or _cache_manager is None: + _cache_manager = CacheManager(redis=redis) + + return _cache_manager + + +def reset_cache_manager() -> None: + """Reset the global cache manager instance.""" + global _cache_manager + with _cache_manager_lock: + _cache_manager = None diff --git a/backend/app/services/memory/cache/embedding_cache.py b/backend/app/services/memory/cache/embedding_cache.py new file mode 100644 index 0000000..52bea84 --- /dev/null +++ b/backend/app/services/memory/cache/embedding_cache.py @@ -0,0 +1,627 @@ +# app/services/memory/cache/embedding_cache.py +""" +Embedding Cache. + +Caches embeddings by content hash to avoid recomputing. +Provides significant performance improvement for repeated content. +""" + +import hashlib +import logging +import threading +from collections import OrderedDict +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from redis.asyncio import Redis + +logger = logging.getLogger(__name__) + + +def _utcnow() -> datetime: + """Get current UTC time as timezone-aware datetime.""" + return datetime.now(UTC) + + +@dataclass +class EmbeddingEntry: + """A cached embedding entry.""" + + embedding: list[float] + content_hash: str + model: str + created_at: datetime + ttl_seconds: float = 3600.0 # 1 hour default + + def is_expired(self) -> bool: + """Check if this entry has expired.""" + age = (_utcnow() - self.created_at).total_seconds() + return age > self.ttl_seconds + + +@dataclass +class EmbeddingCacheStats: + """Statistics for the embedding cache.""" + + hits: int = 0 + misses: int = 0 + evictions: int = 0 + expirations: int = 0 + current_size: int = 0 + max_size: int = 0 + bytes_saved: int = 0 # Estimated bytes saved by caching + + @property + def hit_rate(self) -> float: + """Calculate cache hit rate.""" + total = self.hits + self.misses + if total == 0: + return 0.0 + return self.hits / total + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "hits": self.hits, + "misses": self.misses, + "evictions": self.evictions, + "expirations": self.expirations, + "current_size": self.current_size, + "max_size": self.max_size, + "hit_rate": self.hit_rate, + "bytes_saved": self.bytes_saved, + } + + +class EmbeddingCache: + """ + Cache for embeddings by content hash. + + Features: + - Content-hash based deduplication + - LRU eviction + - TTL-based expiration + - Optional Redis backing for persistence + - Thread-safe operations + + Performance targets: + - Cache hit rate > 90% for repeated content + - Get/put operations < 1ms (memory), < 5ms (Redis) + """ + + def __init__( + self, + max_size: int = 50000, + default_ttl_seconds: float = 3600.0, + redis: "Redis | None" = None, + redis_prefix: str = "mem:emb", + ) -> None: + """ + Initialize the embedding cache. 
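+
+        Example (in-memory only; passing a ``redis.asyncio.Redis`` client adds a
+        persistent second tier):
+
+            cache = EmbeddingCache(max_size=10_000, default_ttl_seconds=3600.0)
+            content_hash = await cache.put("hello world", [0.1, 0.2, 0.3])
+            vector = await cache.get("hello world")          # [0.1, 0.2, 0.3]
+            vector = await cache.get_by_hash(content_hash)   # same entry, keyed by hash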
+ + Args: + max_size: Maximum number of entries in memory cache + default_ttl_seconds: Default TTL for entries (1 hour) + redis: Optional Redis connection for persistence + redis_prefix: Prefix for Redis keys + """ + self._max_size = max_size + self._default_ttl = default_ttl_seconds + self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict() + self._lock = threading.RLock() + self._stats = EmbeddingCacheStats(max_size=max_size) + self._redis = redis + self._redis_prefix = redis_prefix + + logger.info( + f"Initialized EmbeddingCache with max_size={max_size}, " + f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}" + ) + + def set_redis(self, redis: "Redis") -> None: + """Set Redis connection for persistence.""" + self._redis = redis + + @staticmethod + def hash_content(content: str) -> str: + """ + Compute hash of content for cache key. + + Args: + content: Content to hash + + Returns: + 32-character hex hash + """ + return hashlib.sha256(content.encode()).hexdigest()[:32] + + def _cache_key(self, content_hash: str, model: str) -> str: + """Build cache key from content hash and model.""" + return f"{content_hash}:{model}" + + def _redis_key(self, content_hash: str, model: str) -> str: + """Build Redis key from content hash and model.""" + return f"{self._redis_prefix}:{content_hash}:{model}" + + async def get( + self, + content: str, + model: str = "default", + ) -> list[float] | None: + """ + Get a cached embedding. + + Args: + content: Content text + model: Model name + + Returns: + Cached embedding or None if not found/expired + """ + content_hash = self.hash_content(content) + cache_key = self._cache_key(content_hash, model) + + # Check memory cache first + with self._lock: + if cache_key in self._cache: + entry = self._cache[cache_key] + if entry.is_expired(): + del self._cache[cache_key] + self._stats.expirations += 1 + self._stats.current_size = len(self._cache) + else: + # Move to end (most recently used) + self._cache.move_to_end(cache_key) + self._stats.hits += 1 + return entry.embedding + + # Check Redis if available + if self._redis: + try: + redis_key = self._redis_key(content_hash, model) + data = await self._redis.get(redis_key) + if data: + import json + + embedding = json.loads(data) + # Store in memory cache for faster access + self._put_memory(content_hash, model, embedding) + self._stats.hits += 1 + return embedding + except Exception as e: + logger.warning(f"Redis get error: {e}") + + self._stats.misses += 1 + return None + + async def get_by_hash( + self, + content_hash: str, + model: str = "default", + ) -> list[float] | None: + """ + Get a cached embedding by hash. 
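+
+        Example (sketch; ``record`` is a hypothetical object that persisted the
+        content hash alongside a memory row):
+
+            vector = await cache.get_by_hash(record.embedding_hash, model="default")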
+ + Args: + content_hash: Content hash + model: Model name + + Returns: + Cached embedding or None if not found/expired + """ + cache_key = self._cache_key(content_hash, model) + + with self._lock: + if cache_key in self._cache: + entry = self._cache[cache_key] + if entry.is_expired(): + del self._cache[cache_key] + self._stats.expirations += 1 + self._stats.current_size = len(self._cache) + else: + self._cache.move_to_end(cache_key) + self._stats.hits += 1 + return entry.embedding + + # Check Redis + if self._redis: + try: + redis_key = self._redis_key(content_hash, model) + data = await self._redis.get(redis_key) + if data: + import json + + embedding = json.loads(data) + self._put_memory(content_hash, model, embedding) + self._stats.hits += 1 + return embedding + except Exception as e: + logger.warning(f"Redis get error: {e}") + + self._stats.misses += 1 + return None + + async def put( + self, + content: str, + embedding: list[float], + model: str = "default", + ttl_seconds: float | None = None, + ) -> str: + """ + Cache an embedding. + + Args: + content: Content text + embedding: Embedding vector + model: Model name + ttl_seconds: Optional TTL override + + Returns: + Content hash + """ + content_hash = self.hash_content(content) + ttl = ttl_seconds or self._default_ttl + + # Store in memory + self._put_memory(content_hash, model, embedding, ttl) + + # Store in Redis if available + if self._redis: + try: + import json + + redis_key = self._redis_key(content_hash, model) + await self._redis.setex( + redis_key, + int(ttl), + json.dumps(embedding), + ) + except Exception as e: + logger.warning(f"Redis put error: {e}") + + return content_hash + + def _put_memory( + self, + content_hash: str, + model: str, + embedding: list[float], + ttl_seconds: float | None = None, + ) -> None: + """Store in memory cache.""" + with self._lock: + # Evict if at capacity + self._evict_if_needed() + + cache_key = self._cache_key(content_hash, model) + entry = EmbeddingEntry( + embedding=embedding, + content_hash=content_hash, + model=model, + created_at=_utcnow(), + ttl_seconds=ttl_seconds or self._default_ttl, + ) + + self._cache[cache_key] = entry + self._cache.move_to_end(cache_key) + self._stats.current_size = len(self._cache) + + def _evict_if_needed(self) -> None: + """Evict entries if cache is at capacity.""" + while len(self._cache) >= self._max_size: + if self._cache: + self._cache.popitem(last=False) + self._stats.evictions += 1 + + async def put_batch( + self, + items: list[tuple[str, list[float]]], + model: str = "default", + ttl_seconds: float | None = None, + ) -> list[str]: + """ + Cache multiple embeddings. + + Args: + items: List of (content, embedding) tuples + model: Model name + ttl_seconds: Optional TTL override + + Returns: + List of content hashes + """ + hashes = [] + for content, embedding in items: + content_hash = await self.put(content, embedding, model, ttl_seconds) + hashes.append(content_hash) + return hashes + + async def invalidate( + self, + content: str, + model: str = "default", + ) -> bool: + """ + Invalidate a cached embedding. + + Args: + content: Content text + model: Model name + + Returns: + True if entry was found and removed + """ + content_hash = self.hash_content(content) + return await self.invalidate_by_hash(content_hash, model) + + async def invalidate_by_hash( + self, + content_hash: str, + model: str = "default", + ) -> bool: + """ + Invalidate a cached embedding by hash. 
+ + Args: + content_hash: Content hash + model: Model name + + Returns: + True if entry was found and removed + """ + cache_key = self._cache_key(content_hash, model) + removed = False + + with self._lock: + if cache_key in self._cache: + del self._cache[cache_key] + self._stats.current_size = len(self._cache) + removed = True + + # Remove from Redis + if self._redis: + try: + redis_key = self._redis_key(content_hash, model) + await self._redis.delete(redis_key) + removed = True + except Exception as e: + logger.warning(f"Redis delete error: {e}") + + return removed + + async def invalidate_by_model(self, model: str) -> int: + """ + Invalidate all embeddings for a model. + + Args: + model: Model name + + Returns: + Number of entries invalidated + """ + count = 0 + + with self._lock: + keys_to_remove = [ + k for k, v in self._cache.items() if v.model == model + ] + for key in keys_to_remove: + del self._cache[key] + count += 1 + + self._stats.current_size = len(self._cache) + + # Note: Redis pattern deletion would require SCAN which is expensive + # For now, we only clear memory cache for model-based invalidation + + return count + + async def clear(self) -> int: + """ + Clear all cache entries. + + Returns: + Number of entries cleared + """ + with self._lock: + count = len(self._cache) + self._cache.clear() + self._stats.current_size = 0 + + # Clear Redis entries + if self._redis: + try: + pattern = f"{self._redis_prefix}:*" + deleted = 0 + async for key in self._redis.scan_iter(match=pattern): + await self._redis.delete(key) + deleted += 1 + count = max(count, deleted) + except Exception as e: + logger.warning(f"Redis clear error: {e}") + + logger.info(f"Cleared {count} entries from embedding cache") + return count + + def cleanup_expired(self) -> int: + """ + Remove all expired entries from memory cache. + + Returns: + Number of entries removed + """ + with self._lock: + keys_to_remove = [ + k for k, v in self._cache.items() if v.is_expired() + ] + for key in keys_to_remove: + del self._cache[key] + self._stats.expirations += 1 + + self._stats.current_size = len(self._cache) + + if keys_to_remove: + logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings") + + return len(keys_to_remove) + + def get_stats(self) -> EmbeddingCacheStats: + """Get cache statistics.""" + with self._lock: + self._stats.current_size = len(self._cache) + return self._stats + + def reset_stats(self) -> None: + """Reset cache statistics.""" + with self._lock: + self._stats = EmbeddingCacheStats( + max_size=self._max_size, + current_size=len(self._cache), + ) + + @property + def size(self) -> int: + """Get current cache size.""" + return len(self._cache) + + @property + def max_size(self) -> int: + """Get maximum cache size.""" + return self._max_size + + +class CachedEmbeddingGenerator: + """ + Wrapper for embedding generators with caching. + + Wraps an embedding generator to cache results. + """ + + def __init__( + self, + generator: Any, + cache: EmbeddingCache, + model: str = "default", + ) -> None: + """ + Initialize the cached embedding generator. + + Args: + generator: Underlying embedding generator + cache: Embedding cache + model: Model name for cache keys + """ + self._generator = generator + self._cache = cache + self._model = model + self._call_count = 0 + self._cache_hit_count = 0 + + async def generate(self, text: str) -> list[float]: + """ + Generate embedding with caching. 
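+
+        Example (sketch; ``SomeEmbedder`` stands in for any object exposing an
+        async ``generate(text)`` method and is not defined here):
+
+            cache = create_embedding_cache(max_size=10_000)
+            embedder = CachedEmbeddingGenerator(SomeEmbedder(), cache, model="default")
+            vec = await embedder.generate("repeated content")   # computed and cached
+            vec = await embedder.generate("repeated content")   # served from the cache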
+ + Args: + text: Text to embed + + Returns: + Embedding vector + """ + self._call_count += 1 + + # Check cache first + cached = await self._cache.get(text, self._model) + if cached is not None: + self._cache_hit_count += 1 + return cached + + # Generate and cache + embedding = await self._generator.generate(text) + await self._cache.put(text, embedding, self._model) + + return embedding + + async def generate_batch( + self, + texts: list[str], + ) -> list[list[float]]: + """ + Generate embeddings for multiple texts with caching. + + Args: + texts: Texts to embed + + Returns: + List of embedding vectors + """ + results: list[list[float] | None] = [None] * len(texts) + to_generate: list[tuple[int, str]] = [] + + # Check cache for each text + for i, text in enumerate(texts): + cached = await self._cache.get(text, self._model) + if cached is not None: + results[i] = cached + self._cache_hit_count += 1 + else: + to_generate.append((i, text)) + + self._call_count += len(texts) + + # Generate missing embeddings + if to_generate: + if hasattr(self._generator, "generate_batch"): + texts_to_gen = [t for _, t in to_generate] + embeddings = await self._generator.generate_batch(texts_to_gen) + + for (idx, text), embedding in zip(to_generate, embeddings, strict=True): + results[idx] = embedding + await self._cache.put(text, embedding, self._model) + else: + # Fallback to individual generation + for idx, text in to_generate: + embedding = await self._generator.generate(text) + results[idx] = embedding + await self._cache.put(text, embedding, self._model) + + return results # type: ignore[return-value] + + def get_stats(self) -> dict[str, Any]: + """Get generator statistics.""" + return { + "call_count": self._call_count, + "cache_hit_count": self._cache_hit_count, + "cache_hit_rate": ( + self._cache_hit_count / self._call_count + if self._call_count > 0 + else 0.0 + ), + "cache_stats": self._cache.get_stats().to_dict(), + } + + +# Factory function +def create_embedding_cache( + max_size: int = 50000, + default_ttl_seconds: float = 3600.0, + redis: "Redis | None" = None, +) -> EmbeddingCache: + """ + Create an embedding cache. + + Args: + max_size: Maximum number of entries + default_ttl_seconds: Default TTL for entries + redis: Optional Redis connection + + Returns: + Configured EmbeddingCache instance + """ + return EmbeddingCache( + max_size=max_size, + default_ttl_seconds=default_ttl_seconds, + redis=redis, + ) diff --git a/backend/app/services/memory/cache/hot_cache.py b/backend/app/services/memory/cache/hot_cache.py new file mode 100644 index 0000000..f389521 --- /dev/null +++ b/backend/app/services/memory/cache/hot_cache.py @@ -0,0 +1,463 @@ +# app/services/memory/cache/hot_cache.py +""" +Hot Memory Cache. + +LRU cache for frequently accessed memories. +Provides fast access to recently used memories without database queries. 
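+
+Example (illustrative; the cached value can be any object, such as an ORM row or
+a Pydantic model, and ``memory_id`` is a hypothetical UUID):
+
+    cache = create_hot_cache(max_size=10_000)
+    cache.put_by_id("episodic", memory_id, {"task": "deploy", "status": "done"})
+    hit = cache.get_by_id("episodic", memory_id)       # fast path, no DB query
+    cache.invalidate_by_id("episodic", memory_id)      # drop after an update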
+""" + +import logging +import threading +from collections import OrderedDict +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any +from uuid import UUID + +logger = logging.getLogger(__name__) + + +def _utcnow() -> datetime: + """Get current UTC time as timezone-aware datetime.""" + return datetime.now(UTC) + + +@dataclass +class CacheEntry[T]: + """A cached memory entry with metadata.""" + + value: T + created_at: datetime + last_accessed_at: datetime + access_count: int = 1 + ttl_seconds: float = 300.0 + + def is_expired(self) -> bool: + """Check if this entry has expired.""" + age = (_utcnow() - self.created_at).total_seconds() + return age > self.ttl_seconds + + def touch(self) -> None: + """Update access time and count.""" + self.last_accessed_at = _utcnow() + self.access_count += 1 + + +@dataclass +class CacheKey: + """A structured cache key with components.""" + + memory_type: str + memory_id: str + scope: str | None = None + + def __hash__(self) -> int: + return hash((self.memory_type, self.memory_id, self.scope)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CacheKey): + return False + return ( + self.memory_type == other.memory_type + and self.memory_id == other.memory_id + and self.scope == other.scope + ) + + def __str__(self) -> str: + if self.scope: + return f"{self.memory_type}:{self.scope}:{self.memory_id}" + return f"{self.memory_type}:{self.memory_id}" + + +@dataclass +class HotCacheStats: + """Statistics for the hot memory cache.""" + + hits: int = 0 + misses: int = 0 + evictions: int = 0 + expirations: int = 0 + current_size: int = 0 + max_size: int = 0 + + @property + def hit_rate(self) -> float: + """Calculate cache hit rate.""" + total = self.hits + self.misses + if total == 0: + return 0.0 + return self.hits / total + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "hits": self.hits, + "misses": self.misses, + "evictions": self.evictions, + "expirations": self.expirations, + "current_size": self.current_size, + "max_size": self.max_size, + "hit_rate": self.hit_rate, + } + + +class HotMemoryCache[T]: + """ + LRU cache for frequently accessed memories. + + Features: + - LRU eviction when capacity is reached + - TTL-based expiration + - Access count tracking for hot memory identification + - Thread-safe operations + - Scoped invalidation + + Performance targets: + - Cache hit rate > 80% for hot memories + - Get/put operations < 1ms + """ + + def __init__( + self, + max_size: int = 10000, + default_ttl_seconds: float = 300.0, + ) -> None: + """ + Initialize the hot memory cache. + + Args: + max_size: Maximum number of entries + default_ttl_seconds: Default TTL for entries (5 minutes) + """ + self._max_size = max_size + self._default_ttl = default_ttl_seconds + self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict() + self._lock = threading.RLock() + self._stats = HotCacheStats(max_size=max_size) + logger.info( + f"Initialized HotMemoryCache with max_size={max_size}, " + f"ttl={default_ttl_seconds}s" + ) + + def get(self, key: CacheKey) -> T | None: + """ + Get a memory from cache. 
+ + Args: + key: Cache key + + Returns: + Cached value or None if not found/expired + """ + with self._lock: + if key not in self._cache: + self._stats.misses += 1 + return None + + entry = self._cache[key] + + # Check expiration + if entry.is_expired(): + del self._cache[key] + self._stats.expirations += 1 + self._stats.misses += 1 + self._stats.current_size = len(self._cache) + return None + + # Move to end (most recently used) + self._cache.move_to_end(key) + entry.touch() + + self._stats.hits += 1 + return entry.value + + def get_by_id( + self, + memory_type: str, + memory_id: UUID | str, + scope: str | None = None, + ) -> T | None: + """ + Get a memory by type and ID. + + Args: + memory_type: Type of memory (episodic, semantic, procedural) + memory_id: Memory ID + scope: Optional scope (project_id, agent_id) + + Returns: + Cached value or None if not found/expired + """ + key = CacheKey( + memory_type=memory_type, + memory_id=str(memory_id), + scope=scope, + ) + return self.get(key) + + def put( + self, + key: CacheKey, + value: T, + ttl_seconds: float | None = None, + ) -> None: + """ + Put a memory into cache. + + Args: + key: Cache key + value: Value to cache + ttl_seconds: Optional TTL override + """ + with self._lock: + # Evict if at capacity + self._evict_if_needed() + + now = _utcnow() + entry = CacheEntry( + value=value, + created_at=now, + last_accessed_at=now, + access_count=1, + ttl_seconds=ttl_seconds or self._default_ttl, + ) + + self._cache[key] = entry + self._cache.move_to_end(key) + self._stats.current_size = len(self._cache) + + def put_by_id( + self, + memory_type: str, + memory_id: UUID | str, + value: T, + scope: str | None = None, + ttl_seconds: float | None = None, + ) -> None: + """ + Put a memory by type and ID. + + Args: + memory_type: Type of memory + memory_id: Memory ID + value: Value to cache + scope: Optional scope + ttl_seconds: Optional TTL override + """ + key = CacheKey( + memory_type=memory_type, + memory_id=str(memory_id), + scope=scope, + ) + self.put(key, value, ttl_seconds) + + def _evict_if_needed(self) -> None: + """Evict entries if cache is at capacity.""" + while len(self._cache) >= self._max_size: + # Remove least recently used (first item) + if self._cache: + self._cache.popitem(last=False) + self._stats.evictions += 1 + + def invalidate(self, key: CacheKey) -> bool: + """ + Invalidate a specific cache entry. + + Args: + key: Cache key to invalidate + + Returns: + True if entry was found and removed + """ + with self._lock: + if key in self._cache: + del self._cache[key] + self._stats.current_size = len(self._cache) + return True + return False + + def invalidate_by_id( + self, + memory_type: str, + memory_id: UUID | str, + scope: str | None = None, + ) -> bool: + """ + Invalidate a memory by type and ID. + + Args: + memory_type: Type of memory + memory_id: Memory ID + scope: Optional scope + + Returns: + True if entry was found and removed + """ + key = CacheKey( + memory_type=memory_type, + memory_id=str(memory_id), + scope=scope, + ) + return self.invalidate(key) + + def invalidate_by_type(self, memory_type: str) -> int: + """ + Invalidate all entries of a memory type. 
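+
+        Example (typical use after a bulk write; the pattern form matches the
+        ``type:scope:id`` string representation of keys):
+
+            removed = cache.invalidate_by_type("episodic")
+            removed += cache.invalidate_by_scope("proj-1")
+            removed += cache.invalidate_pattern("semantic:proj-2:*")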
+ + Args: + memory_type: Type of memory to invalidate + + Returns: + Number of entries invalidated + """ + with self._lock: + keys_to_remove = [ + k for k in self._cache.keys() if k.memory_type == memory_type + ] + for key in keys_to_remove: + del self._cache[key] + + self._stats.current_size = len(self._cache) + return len(keys_to_remove) + + def invalidate_by_scope(self, scope: str) -> int: + """ + Invalidate all entries in a scope. + + Args: + scope: Scope to invalidate (e.g., project_id) + + Returns: + Number of entries invalidated + """ + with self._lock: + keys_to_remove = [k for k in self._cache.keys() if k.scope == scope] + for key in keys_to_remove: + del self._cache[key] + + self._stats.current_size = len(self._cache) + return len(keys_to_remove) + + def invalidate_pattern(self, pattern: str) -> int: + """ + Invalidate entries matching a pattern. + + Pattern can include * as wildcard. + + Args: + pattern: Pattern to match (e.g., "episodic:*") + + Returns: + Number of entries invalidated + """ + import fnmatch + + with self._lock: + keys_to_remove = [ + k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern) + ] + for key in keys_to_remove: + del self._cache[key] + + self._stats.current_size = len(self._cache) + return len(keys_to_remove) + + def clear(self) -> int: + """ + Clear all cache entries. + + Returns: + Number of entries cleared + """ + with self._lock: + count = len(self._cache) + self._cache.clear() + self._stats.current_size = 0 + logger.info(f"Cleared {count} entries from hot cache") + return count + + def cleanup_expired(self) -> int: + """ + Remove all expired entries. + + Returns: + Number of entries removed + """ + with self._lock: + keys_to_remove = [ + k for k, v in self._cache.items() if v.is_expired() + ] + for key in keys_to_remove: + del self._cache[key] + self._stats.expirations += 1 + + self._stats.current_size = len(self._cache) + + if keys_to_remove: + logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries") + + return len(keys_to_remove) + + def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]: + """ + Get the most frequently accessed memories. + + Args: + limit: Maximum number of memories to return + + Returns: + List of (key, access_count) tuples sorted by access count + """ + with self._lock: + entries = [ + (k, v.access_count) + for k, v in self._cache.items() + if not v.is_expired() + ] + entries.sort(key=lambda x: x[1], reverse=True) + return entries[:limit] + + def get_stats(self) -> HotCacheStats: + """Get cache statistics.""" + with self._lock: + self._stats.current_size = len(self._cache) + return self._stats + + def reset_stats(self) -> None: + """Reset cache statistics.""" + with self._lock: + self._stats = HotCacheStats( + max_size=self._max_size, + current_size=len(self._cache), + ) + + @property + def size(self) -> int: + """Get current cache size.""" + return len(self._cache) + + @property + def max_size(self) -> int: + """Get maximum cache size.""" + return self._max_size + + +# Factory function for typed caches +def create_hot_cache( + max_size: int = 10000, + default_ttl_seconds: float = 300.0, +) -> HotMemoryCache[Any]: + """ + Create a hot memory cache. 
+ + Args: + max_size: Maximum number of entries + default_ttl_seconds: Default TTL for entries + + Returns: + Configured HotMemoryCache instance + """ + return HotMemoryCache( + max_size=max_size, + default_ttl_seconds=default_ttl_seconds, + ) diff --git a/backend/tests/unit/services/memory/cache/__init__.py b/backend/tests/unit/services/memory/cache/__init__.py new file mode 100644 index 0000000..ed887de --- /dev/null +++ b/backend/tests/unit/services/memory/cache/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/cache/__init__.py +"""Tests for memory caching layer.""" diff --git a/backend/tests/unit/services/memory/cache/test_cache_manager.py b/backend/tests/unit/services/memory/cache/test_cache_manager.py new file mode 100644 index 0000000..63a3b14 --- /dev/null +++ b/backend/tests/unit/services/memory/cache/test_cache_manager.py @@ -0,0 +1,331 @@ +# tests/unit/services/memory/cache/test_cache_manager.py +"""Tests for CacheManager.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.services.memory.cache.cache_manager import ( + CacheManager, + CacheStats, + get_cache_manager, + reset_cache_manager, +) +from app.services.memory.cache.embedding_cache import EmbeddingCache +from app.services.memory.cache.hot_cache import HotMemoryCache + +pytestmark = pytest.mark.asyncio(loop_scope="function") + + +@pytest.fixture(autouse=True) +def reset_singleton() -> None: + """Reset singleton before each test.""" + reset_cache_manager() + + +class TestCacheStats: + """Tests for CacheStats.""" + + def test_to_dict(self) -> None: + """Should convert to dictionary.""" + from datetime import UTC, datetime + + stats = CacheStats( + hot_cache={"hits": 10}, + embedding_cache={"hits": 20}, + overall_hit_rate=0.75, + last_cleanup=datetime.now(UTC), + cleanup_count=5, + ) + + result = stats.to_dict() + + assert result["hot_cache"] == {"hits": 10} + assert result["overall_hit_rate"] == 0.75 + assert result["cleanup_count"] == 5 + assert result["last_cleanup"] is not None + + +class TestCacheManager: + """Tests for CacheManager.""" + + @pytest.fixture + def manager(self) -> CacheManager: + """Create a cache manager.""" + return CacheManager() + + def test_is_enabled(self, manager: CacheManager) -> None: + """Should check if caching is enabled.""" + # Default is enabled from settings + assert manager.is_enabled is True + + def test_has_hot_cache(self, manager: CacheManager) -> None: + """Should have hot memory cache.""" + assert manager.hot_cache is not None + assert isinstance(manager.hot_cache, HotMemoryCache) + + def test_has_embedding_cache(self, manager: CacheManager) -> None: + """Should have embedding cache.""" + assert manager.embedding_cache is not None + assert isinstance(manager.embedding_cache, EmbeddingCache) + + def test_cache_memory(self, manager: CacheManager) -> None: + """Should cache memory in hot cache.""" + memory_id = uuid4() + memory = {"task": "test", "data": "value"} + + manager.cache_memory("episodic", memory_id, memory) + result = manager.get_memory("episodic", memory_id) + + assert result == memory + + def test_cache_memory_with_scope(self, manager: CacheManager) -> None: + """Should cache memory with scope.""" + memory_id = uuid4() + memory = {"task": "test"} + + manager.cache_memory("semantic", memory_id, memory, scope="proj-123") + result = manager.get_memory("semantic", memory_id, scope="proj-123") + + assert result == memory + + async def test_cache_embedding(self, manager: CacheManager) -> None: + """Should cache 
embedding.""" + content = "test content" + embedding = [0.1, 0.2, 0.3] + + content_hash = await manager.cache_embedding(content, embedding) + result = await manager.get_embedding(content) + + assert result == embedding + assert len(content_hash) == 32 + + async def test_invalidate_memory(self, manager: CacheManager) -> None: + """Should invalidate memory from hot cache.""" + memory_id = uuid4() + manager.cache_memory("episodic", memory_id, {"data": "test"}) + + count = await manager.invalidate_memory("episodic", memory_id) + + assert count >= 1 + assert manager.get_memory("episodic", memory_id) is None + + async def test_invalidate_by_type(self, manager: CacheManager) -> None: + """Should invalidate all entries of a type.""" + manager.cache_memory("episodic", uuid4(), {"data": "1"}) + manager.cache_memory("episodic", uuid4(), {"data": "2"}) + manager.cache_memory("semantic", uuid4(), {"data": "3"}) + + count = await manager.invalidate_by_type("episodic") + + assert count >= 2 + + async def test_invalidate_by_scope(self, manager: CacheManager) -> None: + """Should invalidate all entries in a scope.""" + manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1") + manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1") + manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2") + + count = await manager.invalidate_by_scope("proj-1") + + assert count >= 2 + + async def test_invalidate_embedding(self, manager: CacheManager) -> None: + """Should invalidate cached embedding.""" + content = "test content" + await manager.cache_embedding(content, [0.1, 0.2]) + + result = await manager.invalidate_embedding(content) + + assert result is True + assert await manager.get_embedding(content) is None + + async def test_clear_all(self, manager: CacheManager) -> None: + """Should clear all caches.""" + manager.cache_memory("episodic", uuid4(), {"data": "test"}) + await manager.cache_embedding("content", [0.1]) + + count = await manager.clear_all() + + assert count >= 2 + + async def test_cleanup_expired(self, manager: CacheManager) -> None: + """Should clean up expired entries.""" + count = await manager.cleanup_expired() + + # May be 0 if no expired entries + assert count >= 0 + assert manager._cleanup_count == 1 + assert manager._last_cleanup is not None + + def test_get_stats(self, manager: CacheManager) -> None: + """Should return aggregated statistics.""" + manager.cache_memory("episodic", uuid4(), {"data": "test"}) + + stats = manager.get_stats() + + assert "hot_cache" in stats.to_dict() + assert "embedding_cache" in stats.to_dict() + assert "overall_hit_rate" in stats.to_dict() + + def test_get_hot_memories(self, manager: CacheManager) -> None: + """Should return most accessed memories.""" + id1 = uuid4() + id2 = uuid4() + + manager.cache_memory("episodic", id1, {"data": "1"}) + manager.cache_memory("episodic", id2, {"data": "2"}) + + # Access first multiple times + for _ in range(5): + manager.get_memory("episodic", id1) + + hot = manager.get_hot_memories(limit=2) + + assert len(hot) == 2 + + def test_reset_stats(self, manager: CacheManager) -> None: + """Should reset all statistics.""" + manager.cache_memory("episodic", uuid4(), {"data": "test"}) + manager.get_memory("episodic", uuid4()) # Miss + + manager.reset_stats() + + stats = manager.get_stats() + assert stats.hot_cache.get("hits", 0) == 0 + + async def test_warmup(self, manager: CacheManager) -> None: + """Should warm up cache with memories.""" + memories = [ + ("episodic", uuid4(), {"data": 
"1"}), + ("episodic", uuid4(), {"data": "2"}), + ("semantic", uuid4(), {"data": "3"}), + ] + + count = await manager.warmup(memories) + + assert count == 3 + + +class TestCacheManagerWithRetrieval: + """Tests for CacheManager with retrieval cache.""" + + @pytest.fixture + def mock_retrieval_cache(self) -> MagicMock: + """Create mock retrieval cache.""" + cache = MagicMock() + cache.invalidate_by_memory = MagicMock(return_value=1) + cache.clear = MagicMock(return_value=5) + cache.get_stats = MagicMock(return_value={"entries": 10}) + return cache + + @pytest.fixture + def manager_with_retrieval( + self, + mock_retrieval_cache: MagicMock, + ) -> CacheManager: + """Create manager with retrieval cache.""" + manager = CacheManager() + manager.set_retrieval_cache(mock_retrieval_cache) + return manager + + async def test_invalidate_clears_retrieval( + self, + manager_with_retrieval: CacheManager, + mock_retrieval_cache: MagicMock, + ) -> None: + """Should invalidate retrieval cache entries.""" + memory_id = uuid4() + + await manager_with_retrieval.invalidate_memory("episodic", memory_id) + + mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(memory_id) + + def test_stats_includes_retrieval( + self, + manager_with_retrieval: CacheManager, + ) -> None: + """Should include retrieval cache stats.""" + stats = manager_with_retrieval.get_stats() + + assert "retrieval_cache" in stats.to_dict() + + +class TestCacheManagerDisabled: + """Tests for CacheManager when disabled.""" + + @pytest.fixture + def disabled_manager(self) -> CacheManager: + """Create a disabled cache manager.""" + with patch( + "app.services.memory.cache.cache_manager.get_memory_settings" + ) as mock_settings: + settings = MagicMock() + settings.cache_enabled = False + settings.cache_max_items = 1000 + settings.cache_ttl_seconds = 300 + mock_settings.return_value = settings + + return CacheManager() + + def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None: + """Should return None when disabled.""" + disabled_manager.cache_memory("episodic", uuid4(), {"data": "test"}) + result = disabled_manager.get_memory("episodic", uuid4()) + + assert result is None + + async def test_get_embedding_returns_none( + self, + disabled_manager: CacheManager, + ) -> None: + """Should return None for embeddings when disabled.""" + result = await disabled_manager.get_embedding("content") + + assert result is None + + async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None: + """Should return 0 from warmup when disabled.""" + count = await disabled_manager.warmup([("episodic", uuid4(), {})]) + + assert count == 0 + + +class TestGetCacheManager: + """Tests for get_cache_manager factory.""" + + def test_returns_singleton(self) -> None: + """Should return same instance.""" + manager1 = get_cache_manager() + manager2 = get_cache_manager() + + assert manager1 is manager2 + + def test_reset_creates_new(self) -> None: + """Should create new instance after reset.""" + manager1 = get_cache_manager() + reset_cache_manager() + manager2 = get_cache_manager() + + assert manager1 is not manager2 + + def test_reset_parameter(self) -> None: + """Should create new instance with reset=True.""" + manager1 = get_cache_manager() + manager2 = get_cache_manager(reset=True) + + assert manager1 is not manager2 + + +class TestResetCacheManager: + """Tests for reset_cache_manager.""" + + def test_resets_singleton(self) -> None: + """Should reset the singleton.""" + get_cache_manager() + reset_cache_manager() + + # Next 
call should create new instance + manager = get_cache_manager() + assert manager is not None diff --git a/backend/tests/unit/services/memory/cache/test_embedding_cache.py b/backend/tests/unit/services/memory/cache/test_embedding_cache.py new file mode 100644 index 0000000..74229d0 --- /dev/null +++ b/backend/tests/unit/services/memory/cache/test_embedding_cache.py @@ -0,0 +1,391 @@ +# tests/unit/services/memory/cache/test_embedding_cache.py +"""Tests for EmbeddingCache.""" + +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.services.memory.cache.embedding_cache import ( + CachedEmbeddingGenerator, + EmbeddingCache, + EmbeddingCacheStats, + EmbeddingEntry, + create_embedding_cache, +) + +pytestmark = pytest.mark.asyncio(loop_scope="function") + + +class TestEmbeddingEntry: + """Tests for EmbeddingEntry.""" + + def test_creates_entry(self) -> None: + """Should create entry with embedding.""" + from datetime import UTC, datetime + + entry = EmbeddingEntry( + embedding=[0.1, 0.2, 0.3], + content_hash="abc123", + model="text-embedding-3-small", + created_at=datetime.now(UTC), + ) + + assert entry.embedding == [0.1, 0.2, 0.3] + assert entry.content_hash == "abc123" + assert entry.ttl_seconds == 3600.0 + + def test_is_expired(self) -> None: + """Should detect expired entries.""" + from datetime import UTC, datetime, timedelta + + old_time = datetime.now(UTC) - timedelta(seconds=4000) + entry = EmbeddingEntry( + embedding=[0.1], + content_hash="abc", + model="default", + created_at=old_time, + ttl_seconds=3600.0, + ) + + assert entry.is_expired() is True + + def test_not_expired(self) -> None: + """Should detect non-expired entries.""" + from datetime import UTC, datetime + + entry = EmbeddingEntry( + embedding=[0.1], + content_hash="abc", + model="default", + created_at=datetime.now(UTC), + ) + + assert entry.is_expired() is False + + +class TestEmbeddingCacheStats: + """Tests for EmbeddingCacheStats.""" + + def test_hit_rate_calculation(self) -> None: + """Should calculate hit rate correctly.""" + stats = EmbeddingCacheStats(hits=90, misses=10) + + assert stats.hit_rate == 0.9 + + def test_hit_rate_zero_requests(self) -> None: + """Should return 0 for no requests.""" + stats = EmbeddingCacheStats() + + assert stats.hit_rate == 0.0 + + def test_to_dict(self) -> None: + """Should convert to dictionary.""" + stats = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000) + + result = stats.to_dict() + + assert result["hits"] == 10 + assert result["bytes_saved"] == 1000 + + +class TestEmbeddingCache: + """Tests for EmbeddingCache.""" + + @pytest.fixture + def cache(self) -> EmbeddingCache: + """Create an embedding cache.""" + return EmbeddingCache(max_size=100, default_ttl_seconds=300.0) + + async def test_put_and_get(self, cache: EmbeddingCache) -> None: + """Should store and retrieve embeddings.""" + content = "Hello world" + embedding = [0.1, 0.2, 0.3, 0.4] + + content_hash = await cache.put(content, embedding) + result = await cache.get(content) + + assert result == embedding + assert len(content_hash) == 32 + + async def test_get_missing(self, cache: EmbeddingCache) -> None: + """Should return None for missing content.""" + result = await cache.get("nonexistent content") + + assert result is None + + async def test_get_by_hash(self, cache: EmbeddingCache) -> None: + """Should get by content hash.""" + content = "Test content" + embedding = [0.1, 0.2] + + content_hash = await cache.put(content, embedding) + result = await cache.get_by_hash(content_hash) + 
+ assert result == embedding + + async def test_model_separation(self, cache: EmbeddingCache) -> None: + """Should separate embeddings by model.""" + content = "Same content" + emb1 = [0.1, 0.2] + emb2 = [0.3, 0.4] + + await cache.put(content, emb1, model="model-a") + await cache.put(content, emb2, model="model-b") + + result1 = await cache.get(content, model="model-a") + result2 = await cache.get(content, model="model-b") + + assert result1 == emb1 + assert result2 == emb2 + + async def test_lru_eviction(self) -> None: + """Should evict LRU entries when at capacity.""" + cache = EmbeddingCache(max_size=3) + + await cache.put("content1", [0.1]) + await cache.put("content2", [0.2]) + await cache.put("content3", [0.3]) + + # Access first to make it recent + await cache.get("content1") + + # Add fourth, should evict second (LRU) + await cache.put("content4", [0.4]) + + assert await cache.get("content1") is not None + assert await cache.get("content2") is None # Evicted + assert await cache.get("content3") is not None + assert await cache.get("content4") is not None + + async def test_ttl_expiration(self) -> None: + """Should expire entries after TTL.""" + cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1) + + await cache.put("content", [0.1, 0.2]) + + time.sleep(0.2) + + result = await cache.get("content") + + assert result is None + + async def test_put_batch(self, cache: EmbeddingCache) -> None: + """Should cache multiple embeddings.""" + items = [ + ("content1", [0.1]), + ("content2", [0.2]), + ("content3", [0.3]), + ] + + hashes = await cache.put_batch(items) + + assert len(hashes) == 3 + assert await cache.get("content1") == [0.1] + assert await cache.get("content2") == [0.2] + + async def test_invalidate(self, cache: EmbeddingCache) -> None: + """Should invalidate cached embedding.""" + await cache.put("content", [0.1, 0.2]) + + result = await cache.invalidate("content") + + assert result is True + assert await cache.get("content") is None + + async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None: + """Should invalidate by hash.""" + content_hash = await cache.put("content", [0.1, 0.2]) + + result = await cache.invalidate_by_hash(content_hash) + + assert result is True + assert await cache.get("content") is None + + async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None: + """Should invalidate all embeddings for a model.""" + await cache.put("content1", [0.1], model="model-a") + await cache.put("content2", [0.2], model="model-a") + await cache.put("content3", [0.3], model="model-b") + + count = await cache.invalidate_by_model("model-a") + + assert count == 2 + assert await cache.get("content1", model="model-a") is None + assert await cache.get("content3", model="model-b") is not None + + async def test_clear(self, cache: EmbeddingCache) -> None: + """Should clear all entries.""" + await cache.put("content1", [0.1]) + await cache.put("content2", [0.2]) + + count = await cache.clear() + + assert count == 2 + assert cache.size == 0 + + def test_cleanup_expired(self) -> None: + """Should remove expired entries.""" + cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1) + + # Use synchronous put for setup + cache._put_memory("hash1", "default", [0.1]) + cache._put_memory("hash2", "default", [0.2], ttl_seconds=10) + + time.sleep(0.2) + + count = cache.cleanup_expired() + + assert count == 1 + + def test_get_stats(self, cache: EmbeddingCache) -> None: + """Should return accurate statistics.""" + # Put synchronously for setup + 
cache._put_memory("hash1", "default", [0.1]) + + stats = cache.get_stats() + + assert stats.current_size == 1 + + def test_hash_content(self) -> None: + """Should produce consistent hashes.""" + hash1 = EmbeddingCache.hash_content("test content") + hash2 = EmbeddingCache.hash_content("test content") + hash3 = EmbeddingCache.hash_content("different content") + + assert hash1 == hash2 + assert hash1 != hash3 + assert len(hash1) == 32 + + +class TestEmbeddingCacheWithRedis: + """Tests for EmbeddingCache with Redis.""" + + @pytest.fixture + def mock_redis(self) -> MagicMock: + """Create mock Redis.""" + redis = MagicMock() + redis.get = AsyncMock(return_value=None) + redis.setex = AsyncMock() + redis.delete = AsyncMock() + redis.scan_iter = MagicMock(return_value=iter([])) + return redis + + @pytest.fixture + def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache: + """Create cache with mock Redis.""" + return EmbeddingCache( + max_size=100, + default_ttl_seconds=300.0, + redis=mock_redis, + ) + + async def test_put_stores_in_redis( + self, + cache_with_redis: EmbeddingCache, + mock_redis: MagicMock, + ) -> None: + """Should store in Redis when available.""" + await cache_with_redis.put("content", [0.1, 0.2]) + + mock_redis.setex.assert_called_once() + + async def test_get_checks_redis_on_miss( + self, + cache_with_redis: EmbeddingCache, + mock_redis: MagicMock, + ) -> None: + """Should check Redis when memory cache misses.""" + import json + + mock_redis.get.return_value = json.dumps([0.1, 0.2]) + + result = await cache_with_redis.get("content") + + assert result == [0.1, 0.2] + mock_redis.get.assert_called_once() + + +class TestCachedEmbeddingGenerator: + """Tests for CachedEmbeddingGenerator.""" + + @pytest.fixture + def mock_generator(self) -> MagicMock: + """Create mock embedding generator.""" + gen = MagicMock() + gen.generate = AsyncMock(return_value=[0.1, 0.2, 0.3]) + gen.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]]) + return gen + + @pytest.fixture + def cache(self) -> EmbeddingCache: + """Create embedding cache.""" + return EmbeddingCache(max_size=100) + + @pytest.fixture + def cached_gen( + self, + mock_generator: MagicMock, + cache: EmbeddingCache, + ) -> CachedEmbeddingGenerator: + """Create cached generator.""" + return CachedEmbeddingGenerator(mock_generator, cache) + + async def test_generate_caches_result( + self, + cached_gen: CachedEmbeddingGenerator, + mock_generator: MagicMock, + ) -> None: + """Should cache generated embedding.""" + result1 = await cached_gen.generate("test text") + result2 = await cached_gen.generate("test text") + + assert result1 == [0.1, 0.2, 0.3] + assert result2 == [0.1, 0.2, 0.3] + mock_generator.generate.assert_called_once() # Only called once + + async def test_generate_batch_uses_cache( + self, + cached_gen: CachedEmbeddingGenerator, + mock_generator: MagicMock, + cache: EmbeddingCache, + ) -> None: + """Should use cache for batch generation.""" + # Pre-cache one embedding + await cache.put("text1", [0.5]) + + # Mock returns 2 embeddings for the 2 uncached texts + mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]]) + + results = await cached_gen.generate_batch(["text1", "text2", "text3"]) + + assert len(results) == 3 + assert results[0] == [0.5] # From cache + assert results[1] == [0.2] # Generated + assert results[2] == [0.3] # Generated + + async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None: + """Should return generator statistics.""" + await 
cached_gen.generate("text1") + await cached_gen.generate("text1") # Cache hit + + stats = cached_gen.get_stats() + + assert stats["call_count"] == 2 + assert stats["cache_hit_count"] == 1 + + +class TestCreateEmbeddingCache: + """Tests for factory function.""" + + def test_creates_cache(self) -> None: + """Should create cache with defaults.""" + cache = create_embedding_cache() + + assert cache.max_size == 50000 + + def test_creates_cache_with_options(self) -> None: + """Should create cache with custom options.""" + cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0) + + assert cache.max_size == 1000 diff --git a/backend/tests/unit/services/memory/cache/test_hot_cache.py b/backend/tests/unit/services/memory/cache/test_hot_cache.py new file mode 100644 index 0000000..5a59211 --- /dev/null +++ b/backend/tests/unit/services/memory/cache/test_hot_cache.py @@ -0,0 +1,355 @@ +# tests/unit/services/memory/cache/test_hot_cache.py +"""Tests for HotMemoryCache.""" + +import time +from uuid import uuid4 + +import pytest + +from app.services.memory.cache.hot_cache import ( + CacheEntry, + CacheKey, + HotCacheStats, + HotMemoryCache, + create_hot_cache, +) + + +class TestCacheKey: + """Tests for CacheKey.""" + + def test_creates_key(self) -> None: + """Should create key with required fields.""" + key = CacheKey(memory_type="episodic", memory_id="123") + + assert key.memory_type == "episodic" + assert key.memory_id == "123" + assert key.scope is None + + def test_creates_key_with_scope(self) -> None: + """Should create key with scope.""" + key = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123") + + assert key.scope == "proj-123" + + def test_hash_and_equality(self) -> None: + """Keys with same values should be equal and have same hash.""" + key1 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1") + key2 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1") + + assert key1 == key2 + assert hash(key1) == hash(key2) + + def test_str_representation(self) -> None: + """Should produce readable string.""" + key = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1") + + assert str(key) == "episodic:proj-1:123" + + def test_str_without_scope(self) -> None: + """Should produce string without scope.""" + key = CacheKey(memory_type="episodic", memory_id="123") + + assert str(key) == "episodic:123" + + +class TestCacheEntry: + """Tests for CacheEntry.""" + + def test_creates_entry(self) -> None: + """Should create entry with value.""" + entry = CacheEntry( + value={"data": "test"}, + created_at=pytest.importorskip("datetime").datetime.now( + pytest.importorskip("datetime").UTC + ), + last_accessed_at=pytest.importorskip("datetime").datetime.now( + pytest.importorskip("datetime").UTC + ), + ) + + assert entry.value == {"data": "test"} + assert entry.access_count == 1 + assert entry.ttl_seconds == 300.0 + + def test_is_expired(self) -> None: + """Should detect expired entries.""" + from datetime import UTC, datetime, timedelta + + old_time = datetime.now(UTC) - timedelta(seconds=400) + entry = CacheEntry( + value="test", + created_at=old_time, + last_accessed_at=old_time, + ttl_seconds=300.0, + ) + + assert entry.is_expired() is True + + def test_not_expired(self) -> None: + """Should detect non-expired entries.""" + from datetime import UTC, datetime + + entry = CacheEntry( + value="test", + created_at=datetime.now(UTC), + last_accessed_at=datetime.now(UTC), + ttl_seconds=300.0, + ) + + assert entry.is_expired() is False + + def 
test_touch_updates_access(self) -> None: + """Touch should update access time and count.""" + from datetime import UTC, datetime, timedelta + + old_time = datetime.now(UTC) - timedelta(seconds=10) + entry = CacheEntry( + value="test", + created_at=old_time, + last_accessed_at=old_time, + access_count=5, + ) + + entry.touch() + + assert entry.access_count == 6 + assert entry.last_accessed_at > old_time + + +class TestHotCacheStats: + """Tests for HotCacheStats.""" + + def test_hit_rate_calculation(self) -> None: + """Should calculate hit rate correctly.""" + stats = HotCacheStats(hits=80, misses=20) + + assert stats.hit_rate == 0.8 + + def test_hit_rate_zero_requests(self) -> None: + """Should return 0 for no requests.""" + stats = HotCacheStats() + + assert stats.hit_rate == 0.0 + + def test_to_dict(self) -> None: + """Should convert to dictionary.""" + stats = HotCacheStats(hits=10, misses=5, evictions=2) + + result = stats.to_dict() + + assert result["hits"] == 10 + assert result["misses"] == 5 + assert result["evictions"] == 2 + assert "hit_rate" in result + + +class TestHotMemoryCache: + """Tests for HotMemoryCache.""" + + @pytest.fixture + def cache(self) -> HotMemoryCache[dict]: + """Create a hot memory cache.""" + return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0) + + def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None: + """Should store and retrieve values.""" + key = CacheKey(memory_type="episodic", memory_id="123") + value = {"data": "test"} + + cache.put(key, value) + result = cache.get(key) + + assert result == value + + def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None: + """Should return None for missing keys.""" + key = CacheKey(memory_type="episodic", memory_id="nonexistent") + + result = cache.get(key) + + assert result is None + + def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None: + """Should store by type and ID.""" + memory_id = uuid4() + value = {"data": "test"} + + cache.put_by_id("episodic", memory_id, value) + result = cache.get_by_id("episodic", memory_id) + + assert result == value + + def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None: + """Should store with scope.""" + memory_id = uuid4() + value = {"data": "test"} + + cache.put_by_id("semantic", memory_id, value, scope="proj-123") + result = cache.get_by_id("semantic", memory_id, scope="proj-123") + + assert result == value + + def test_lru_eviction(self) -> None: + """Should evict LRU entries when at capacity.""" + cache = HotMemoryCache[str](max_size=3) + + # Fill cache + cache.put_by_id("test", "1", "first") + cache.put_by_id("test", "2", "second") + cache.put_by_id("test", "3", "third") + + # Access first to make it recent + cache.get_by_id("test", "1") + + # Add fourth, should evict second (LRU) + cache.put_by_id("test", "4", "fourth") + + assert cache.get_by_id("test", "1") is not None # Accessed, kept + assert cache.get_by_id("test", "2") is None # Evicted (LRU) + assert cache.get_by_id("test", "3") is not None + assert cache.get_by_id("test", "4") is not None + + def test_ttl_expiration(self) -> None: + """Should expire entries after TTL.""" + cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1) + + cache.put_by_id("test", "1", "value") + + # Wait for expiration + time.sleep(0.2) + + result = cache.get_by_id("test", "1") + + assert result is None + + def test_invalidate(self, cache: HotMemoryCache[dict]) -> None: + """Should invalidate specific entry.""" + key = CacheKey(memory_type="episodic", 
memory_id="123") + cache.put(key, {"data": "test"}) + + result = cache.invalidate(key) + + assert result is True + assert cache.get(key) is None + + def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None: + """Should invalidate by ID.""" + memory_id = uuid4() + cache.put_by_id("episodic", memory_id, {"data": "test"}) + + result = cache.invalidate_by_id("episodic", memory_id) + + assert result is True + assert cache.get_by_id("episodic", memory_id) is None + + def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None: + """Should invalidate all entries of a type.""" + cache.put_by_id("episodic", "1", {"data": "1"}) + cache.put_by_id("episodic", "2", {"data": "2"}) + cache.put_by_id("semantic", "3", {"data": "3"}) + + count = cache.invalidate_by_type("episodic") + + assert count == 2 + assert cache.get_by_id("episodic", "1") is None + assert cache.get_by_id("episodic", "2") is None + assert cache.get_by_id("semantic", "3") is not None + + def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None: + """Should invalidate all entries in a scope.""" + cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1") + cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1") + cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2") + + count = cache.invalidate_by_scope("proj-1") + + assert count == 2 + assert cache.get_by_id("episodic", "3", scope="proj-2") is not None + + def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None: + """Should invalidate entries matching pattern.""" + cache.put_by_id("episodic", "123", {"data": "1"}) + cache.put_by_id("episodic", "124", {"data": "2"}) + cache.put_by_id("semantic", "125", {"data": "3"}) + + count = cache.invalidate_pattern("episodic:*") + + assert count == 2 + + def test_clear(self, cache: HotMemoryCache[dict]) -> None: + """Should clear all entries.""" + cache.put_by_id("episodic", "1", {"data": "1"}) + cache.put_by_id("semantic", "2", {"data": "2"}) + + count = cache.clear() + + assert count == 2 + assert cache.size == 0 + + def test_cleanup_expired(self) -> None: + """Should remove expired entries.""" + cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1) + + cache.put_by_id("test", "1", "value1") + cache.put_by_id("test", "2", "value2", ttl_seconds=10) + + time.sleep(0.2) + + count = cache.cleanup_expired() + + assert count == 1 # Only the first one expired + assert cache.size == 1 + + def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None: + """Should return most accessed memories.""" + cache.put_by_id("episodic", "1", {"data": "1"}) + cache.put_by_id("episodic", "2", {"data": "2"}) + + # Access first one multiple times + for _ in range(5): + cache.get_by_id("episodic", "1") + + hot = cache.get_hot_memories(limit=2) + + assert len(hot) == 2 + assert hot[0][1] >= hot[1][1] # Sorted by access count + + def test_get_stats(self, cache: HotMemoryCache[dict]) -> None: + """Should return accurate statistics.""" + cache.put_by_id("episodic", "1", {"data": "1"}) + cache.get_by_id("episodic", "1") # Hit + cache.get_by_id("episodic", "2") # Miss + + stats = cache.get_stats() + + assert stats.hits == 1 + assert stats.misses == 1 + assert stats.current_size == 1 + + def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None: + """Should reset statistics.""" + cache.put_by_id("episodic", "1", {"data": "1"}) + cache.get_by_id("episodic", "1") + + cache.reset_stats() + stats = cache.get_stats() + + assert stats.hits == 0 + assert stats.misses == 0 + + +class 
TestCreateHotCache: + """Tests for factory function.""" + + def test_creates_cache(self) -> None: + """Should create cache with defaults.""" + cache = create_hot_cache() + + assert cache.max_size == 10000 + + def test_creates_cache_with_options(self) -> None: + """Should create cache with custom options.""" + cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0) + + assert cache.max_size == 500
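
A minimal usage sketch (not part of the diff) showing how the pieces these tests exercise might fit together: a hot cache for memory objects plus a CachedEmbeddingGenerator wrapped around an embedding backend. FakeGenerator is a hypothetical stand-in for a real backend, and the import paths are inferred from the package layout; only constructors and methods asserted in the tests above are used.

# usage_sketch.py -- illustrative only; FakeGenerator is hypothetical and the
# import paths are assumed from the package layout shown in this diff.
import asyncio
from uuid import uuid4

from app.services.memory.cache.embedding_cache import (
    CachedEmbeddingGenerator,
    create_embedding_cache,
)
from app.services.memory.cache.hot_cache import create_hot_cache


class FakeGenerator:
    """Hypothetical embedding backend with the async interface the tests assume."""

    async def generate(self, text: str) -> list[float]:
        # Deterministic toy embedding; a real backend would call a model here.
        return [float(len(text)), 0.0, 1.0]


async def main() -> None:
    # Hot cache for memory objects, keyed by (memory_type, memory_id, scope).
    hot = create_hot_cache(max_size=1_000, default_ttl_seconds=300.0)
    memory_id = uuid4()
    hot.put_by_id("episodic", memory_id, {"text": "hello"}, scope="proj-1")
    assert hot.get_by_id("episodic", memory_id, scope="proj-1") == {"text": "hello"}

    # Embedding cache wrapped around the (fake) generator: the second call for
    # identical text should be served from cache, not regenerated.
    emb_cache = create_embedding_cache(max_size=1_000, default_ttl_seconds=3_600.0)
    cached_gen = CachedEmbeddingGenerator(FakeGenerator(), emb_cache)
    first = await cached_gen.generate("hello world")
    second = await cached_gen.generate("hello world")
    assert first == second

    print(hot.get_stats().to_dict())
    # Per the tests above: call_count == 2, cache_hit_count == 1.
    print(cached_gen.get_stats())


if __name__ == "__main__":
    asyncio.run(main())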