# app/services/memory/cache/cache_manager.py
"""
Cache Manager.

Unified cache management for memory operations.
Coordinates hot cache, embedding cache, and retrieval cache.
Provides centralized invalidation and statistics.
"""

import logging
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID

from app.services.memory.config import get_memory_settings

from .embedding_cache import EmbeddingCache, create_embedding_cache
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache

if TYPE_CHECKING:
    from redis.asyncio import Redis

    from app.services.memory.indexing.retrieval import RetrievalCache

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class CacheStats:
    """Aggregated cache statistics."""

    hot_cache: dict[str, Any] = field(default_factory=dict)
    embedding_cache: dict[str, Any] = field(default_factory=dict)
    retrieval_cache: dict[str, Any] = field(default_factory=dict)
    overall_hit_rate: float = 0.0
    last_cleanup: datetime | None = None
    cleanup_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "hot_cache": self.hot_cache,
            "embedding_cache": self.embedding_cache,
            "retrieval_cache": self.retrieval_cache,
            "overall_hit_rate": self.overall_hit_rate,
            "last_cleanup": self.last_cleanup.isoformat() if self.last_cleanup else None,
            "cleanup_count": self.cleanup_count,
        }
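

# A minimal illustration of the serialized stats shape (values here are made
# up, not produced by a real cache):
#
#     CacheStats(overall_hit_rate=0.85).to_dict()
#     # -> {"hot_cache": {}, "embedding_cache": {}, "retrieval_cache": {},
#     #     "overall_hit_rate": 0.85, "last_cleanup": None, "cleanup_count": 0}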


class CacheManager:
    """
    Unified cache manager for memory operations.

    Provides:
    - Centralized cache configuration
    - Coordinated invalidation across caches
    - Aggregated statistics
    - Automatic cleanup scheduling

    Performance targets:
    - Overall cache hit rate > 80%
    - Cache operations < 1ms (memory), < 5ms (Redis)
    """

    def __init__(
        self,
        hot_cache: HotMemoryCache[Any] | None = None,
        embedding_cache: EmbeddingCache | None = None,
        retrieval_cache: "RetrievalCache | None" = None,
        redis: "Redis | None" = None,
    ) -> None:
        """
        Initialize the cache manager.

        Args:
            hot_cache: Optional pre-configured hot cache
            embedding_cache: Optional pre-configured embedding cache
            retrieval_cache: Optional pre-configured retrieval cache
            redis: Optional Redis connection for persistence
        """
        self._settings = get_memory_settings()
        self._redis = redis
        self._enabled = self._settings.cache_enabled

        # Initialize caches
        if hot_cache:
            self._hot_cache = hot_cache
        else:
            self._hot_cache = create_hot_cache(
                max_size=self._settings.cache_max_items,
                default_ttl_seconds=self._settings.cache_ttl_seconds,
            )

        if embedding_cache:
            self._embedding_cache = embedding_cache
        else:
            self._embedding_cache = create_embedding_cache(
                max_size=self._settings.cache_max_items,
                # 12x the base TTL (e.g., 1 hr for a 5-minute base); embeddings
                # for unchanged content stay valid far longer than memories
                default_ttl_seconds=self._settings.cache_ttl_seconds * 12,
                redis=redis,
            )

        self._retrieval_cache = retrieval_cache

        # Stats tracking
        self._last_cleanup: datetime | None = None
        self._cleanup_count = 0
        self._lock = threading.RLock()

        logger.info(
            f"Initialized CacheManager: enabled={self._enabled}, "
            f"redis={'connected' if redis else 'disabled'}"
        )

    def set_redis(self, redis: "Redis") -> None:
        """Set the Redis connection for caches that use it (the embedding cache)."""
        self._redis = redis
        self._embedding_cache.set_redis(redis)

    def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
        """Set the retrieval cache instance."""
        self._retrieval_cache = cache

    @property
    def is_enabled(self) -> bool:
        """Check if caching is enabled."""
        return self._enabled

    @property
    def hot_cache(self) -> HotMemoryCache[Any]:
        """Get the hot memory cache."""
        return self._hot_cache

    @property
    def embedding_cache(self) -> EmbeddingCache:
        """Get the embedding cache."""
        return self._embedding_cache

    @property
    def retrieval_cache(self) -> "RetrievalCache | None":
        """Get the retrieval cache."""
        return self._retrieval_cache

    # =========================================================================
    # Hot Memory Cache Operations
    # =========================================================================

    def get_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> Any | None:
        """
        Get a memory from the hot cache.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            Cached memory or None
        """
        if not self._enabled:
            return None
        return self._hot_cache.get_by_id(memory_type, memory_id, scope)

    def cache_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        memory: Any,
        scope: str | None = None,
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache a memory in the hot cache.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            memory: Memory object
            scope: Optional scope
            ttl_seconds: Optional TTL override
        """
        if not self._enabled:
            return
        self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
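
    # Read-through usage sketch (illustrative only; `manager` is a
    # CacheManager instance and `load_memory` is an assumed application-level
    # loader, not part of this module):
    #
    #     memory = manager.get_memory("episodic", memory_id, scope=project_id)
    #     if memory is None:
    #         memory = await load_memory(memory_id)
    #         manager.cache_memory("episodic", memory_id, memory, scope=project_id)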

    # =========================================================================
    # Embedding Cache Operations
    # =========================================================================

    async def get_embedding(
        self,
        content: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            Cached embedding or None
        """
        if not self._enabled:
            return None
        return await self._embedding_cache.get(content, model)

    async def cache_embedding(
        self,
        content: str,
        embedding: list[float],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> str:
        """
        Cache an embedding.

        Args:
            content: Content text
            embedding: Embedding vector
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            Content hash
        """
        if not self._enabled:
            return EmbeddingCache.hash_content(content)
        return await self._embedding_cache.put(content, embedding, model, ttl_seconds)

    # =========================================================================
    # Invalidation
    # =========================================================================

    async def invalidate_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> int:
        """
        Invalidate a memory across all caches.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            Number of entries invalidated
        """
        count = 0

        # Invalidate hot cache
        if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
            count += 1

        # Invalidate retrieval cache, which keys entries by UUID
        if self._retrieval_cache:
            uuid_id = (
                UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
            )
            count += self._retrieval_cache.invalidate_by_memory(uuid_id)

        logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
        return count

    async def invalidate_by_type(self, memory_type: str) -> int:
        """
        Invalidate all entries of a memory type.

        Args:
            memory_type: Type of memory

        Returns:
            Number of entries invalidated
        """
        count = self._hot_cache.invalidate_by_type(memory_type)

        # Retrieval cache doesn't support type-based invalidation,
        # so we clear it entirely for safety
        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Invalidated {count} cache entries for type {memory_type}")
        return count

    async def invalidate_by_scope(self, scope: str) -> int:
        """
        Invalidate all entries in a scope.

        Args:
            scope: Scope to invalidate (e.g., project_id)

        Returns:
            Number of entries invalidated
        """
        count = self._hot_cache.invalidate_by_scope(scope)

        # Retrieval cache doesn't support scope-based invalidation,
        # so we clear it entirely for safety
        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Invalidated {count} cache entries for scope {scope}")
        return count

    async def invalidate_embedding(
        self,
        content: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            True if the entry was found and removed
        """
        return await self._embedding_cache.invalidate(content, model)

    async def clear_all(self) -> int:
        """
        Clear all caches.

        Returns:
            Total number of entries cleared
        """
        count = 0
        count += self._hot_cache.clear()
        count += await self._embedding_cache.clear()
        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Cleared {count} entries from all caches")
        return count

    # =========================================================================
    # Cleanup
    # =========================================================================

    async def cleanup_expired(self) -> int:
        """
        Clean up expired entries from all caches.

        Returns:
            Number of entries cleaned up
        """
        with self._lock:
            count = 0
            count += self._hot_cache.cleanup_expired()
            count += self._embedding_cache.cleanup_expired()
            # Retrieval cache doesn't have a cleanup method,
            # but its entries expire on access

            self._last_cleanup = _utcnow()
            self._cleanup_count += 1

        if count > 0:
            logger.info(f"Cleaned up {count} expired cache entries")
        return count
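
    # Cleanup scheduling sketch (assumed to run as an application-level
    # background task with asyncio imported; the 60s interval is illustrative,
    # not a setting exposed by this module):
    #
    #     async def cleanup_loop(manager: CacheManager) -> None:
    #         while True:
    #             await manager.cleanup_expired()
    #             await asyncio.sleep(60)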

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_stats(self) -> CacheStats:
        """
        Get aggregated cache statistics.

        Returns:
            CacheStats with all cache metrics
        """
        hot_stats = self._hot_cache.get_stats().to_dict()
        emb_stats = self._embedding_cache.get_stats().to_dict()

        retrieval_stats: dict[str, Any] = {}
        if self._retrieval_cache:
            retrieval_stats = self._retrieval_cache.get_stats()

        # Calculate the overall hit rate. The retrieval cache doesn't track
        # hits/misses the same way, so it is excluded from this rate.
        total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
        total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
        total_requests = total_hits + total_misses
        overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0

        return CacheStats(
            hot_cache=hot_stats,
            embedding_cache=emb_stats,
            retrieval_cache=retrieval_stats,
            overall_hit_rate=overall_hit_rate,
            last_cleanup=self._last_cleanup,
            cleanup_count=self._cleanup_count,
        )

    def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
        """
        Get the most frequently accessed memories.

        Args:
            limit: Maximum number to return

        Returns:
            List of (key, access_count) tuples
        """
        return self._hot_cache.get_hot_memories(limit)

    def reset_stats(self) -> None:
        """Reset all cache statistics."""
        self._hot_cache.reset_stats()
        self._embedding_cache.reset_stats()

    # =========================================================================
    # Warmup
    # =========================================================================

    async def warmup(
        self,
        memories: list[tuple[str, UUID | str, Any]],
        scope: str | None = None,
    ) -> int:
        """
        Warm up the hot cache with memories.

        Args:
            memories: List of (memory_type, memory_id, memory) tuples
            scope: Optional scope for all memories

        Returns:
            Number of memories cached
        """
        if not self._enabled:
            return 0

        for memory_type, memory_id, memory in memories:
            self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)

        logger.info(f"Warmed up cache with {len(memories)} memories")
        return len(memories)


# Singleton instance
_cache_manager: CacheManager | None = None
_cache_manager_lock = threading.Lock()


def get_cache_manager(
    redis: "Redis | None" = None,
    reset: bool = False,
) -> CacheManager:
    """
    Get the global CacheManager instance.

    Thread-safe with a double-checked locking pattern.

    Args:
        redis: Optional Redis connection
        reset: Force creation of a new instance

    Returns:
        CacheManager instance
    """
    global _cache_manager

    if reset or _cache_manager is None:
        with _cache_manager_lock:
            # Re-check under the lock so concurrent callers don't both create
            if reset or _cache_manager is None:
                _cache_manager = CacheManager(redis=redis)

    return _cache_manager


def reset_cache_manager() -> None:
    """Reset the global cache manager instance."""
    global _cache_manager
    with _cache_manager_lock:
        _cache_manager = None
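

# Minimal usage sketch, assuming the package settings and cache factories are
# importable and configured. The dict payload is a stand-in for whatever
# memory objects the application actually caches.
if __name__ == "__main__":
    import asyncio
    from uuid import uuid4

    async def _demo() -> None:
        manager = get_cache_manager(reset=True)
        memory_id = uuid4()
        manager.cache_memory("episodic", memory_id, {"text": "hello"}, scope="demo")
        cached = manager.get_memory("episodic", memory_id, scope="demo")
        print("cache hit:", cached is not None)
        print("stats:", manager.get_stats().to_dict())
        await manager.clear_all()

    asyncio.run(_demo())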