# app/services/memory/cache/cache_manager.py
"""
Cache Manager.

Unified cache management for memory operations.
Coordinates hot cache, embedding cache, and retrieval cache.
Provides centralized invalidation and statistics.
"""

import logging
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID

from app.services.memory.config import get_memory_settings

from .embedding_cache import EmbeddingCache, create_embedding_cache
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache

if TYPE_CHECKING:
    from redis.asyncio import Redis

    from app.services.memory.indexing.retrieval import RetrievalCache

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class CacheStats:
    """Aggregated cache statistics."""

    hot_cache: dict[str, Any] = field(default_factory=dict)
    embedding_cache: dict[str, Any] = field(default_factory=dict)
    retrieval_cache: dict[str, Any] = field(default_factory=dict)
    overall_hit_rate: float = 0.0
    last_cleanup: datetime | None = None
    cleanup_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "hot_cache": self.hot_cache,
            "embedding_cache": self.embedding_cache,
            "retrieval_cache": self.retrieval_cache,
            "overall_hit_rate": self.overall_hit_rate,
            "last_cleanup": self.last_cleanup.isoformat()
            if self.last_cleanup
            else None,
            "cleanup_count": self.cleanup_count,
        }


class CacheManager:
    """
    Unified cache manager for memory operations.

    Provides:
    - Centralized cache configuration
    - Coordinated invalidation across caches
    - Aggregated statistics
    - Automatic cleanup scheduling

    Performance targets:
    - Overall cache hit rate > 80%
    - Cache operations < 1ms (memory), < 5ms (Redis)
    """

    def __init__(
        self,
        hot_cache: HotMemoryCache[Any] | None = None,
        embedding_cache: EmbeddingCache | None = None,
        retrieval_cache: "RetrievalCache | None" = None,
        redis: "Redis | None" = None,
    ) -> None:
        """
        Initialize the cache manager.

        Args:
            hot_cache: Optional pre-configured hot cache
            embedding_cache: Optional pre-configured embedding cache
            retrieval_cache: Optional pre-configured retrieval cache
            redis: Optional Redis connection for persistence
        """
        self._settings = get_memory_settings()
        self._redis = redis
        self._enabled = self._settings.cache_enabled

        # Initialize caches
        if hot_cache:
            self._hot_cache = hot_cache
        else:
            self._hot_cache = create_hot_cache(
                max_size=self._settings.cache_max_items,
                default_ttl_seconds=self._settings.cache_ttl_seconds,
            )

        if embedding_cache:
            self._embedding_cache = embedding_cache
        else:
            self._embedding_cache = create_embedding_cache(
                max_size=self._settings.cache_max_items,
                default_ttl_seconds=self._settings.cache_ttl_seconds
                * 12,  # 1hr for embeddings
                redis=redis,
            )

        self._retrieval_cache = retrieval_cache

        # Stats tracking
        self._last_cleanup: datetime | None = None
        self._cleanup_count = 0
        self._lock = threading.RLock()

        logger.info(
            f"Initialized CacheManager: enabled={self._enabled}, "
            f"redis={'connected' if redis else 'disabled'}"
        )

    def set_redis(self, redis: "Redis") -> None:
        """Set the Redis connection; only the embedding cache persists to Redis."""
        self._redis = redis
        self._embedding_cache.set_redis(redis)

    def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
        """Set retrieval cache instance."""
        self._retrieval_cache = cache

    @property
    def is_enabled(self) -> bool:
        """Check if caching is enabled."""
        return self._enabled

    @property
    def hot_cache(self) -> HotMemoryCache[Any]:
        """Get the hot memory cache."""
        return self._hot_cache

    @property
    def embedding_cache(self) -> EmbeddingCache:
        """Get the embedding cache."""
        return self._embedding_cache

    @property
    def retrieval_cache(self) -> "RetrievalCache | None":
        """Get the retrieval cache."""
        return self._retrieval_cache

    # =========================================================================
    # Hot Memory Cache Operations
    # =========================================================================

    def get_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> Any | None:
        """
        Get a memory from hot cache.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            Cached memory or None
        """
        if not self._enabled:
            return None
        return self._hot_cache.get_by_id(memory_type, memory_id, scope)

    def cache_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        memory: Any,
        scope: str | None = None,
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache a memory in hot cache.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            memory: Memory object
            scope: Optional scope
            ttl_seconds: Optional TTL override
        """
        if not self._enabled:
            return
        self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)

    # =========================================================================
    # Embedding Cache Operations
    # =========================================================================

    async def get_embedding(
        self,
        content: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            Cached embedding or None
        """
        if not self._enabled:
            return None
        return await self._embedding_cache.get(content, model)

    async def cache_embedding(
        self,
        content: str,
        embedding: list[float],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> str:
        """
        Cache an embedding.

        Args:
            content: Content text
            embedding: Embedding vector
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            Content hash
        """
        if not self._enabled:
            return EmbeddingCache.hash_content(content)
        return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
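
    # Typical read-through usage (an illustrative sketch, not called anywhere
    # in this module; `compute_embedding` is a hypothetical caller-supplied
    # coroutine):
    #
    #     vector = await manager.get_embedding(text)
    #     if vector is None:
    #         vector = await compute_embedding(text)
    #         await manager.cache_embedding(text, vector)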

    # =========================================================================
    # Invalidation
    # =========================================================================

    async def invalidate_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> int:
        """
        Invalidate a memory across all caches.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            Number of entries invalidated
        """
        count = 0

        # Invalidate hot cache
        if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
            count += 1

        # Invalidate retrieval cache
        if self._retrieval_cache:
            uuid_id = (
                UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
            )
            count += self._retrieval_cache.invalidate_by_memory(uuid_id)

        logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
        return count

    async def invalidate_by_type(self, memory_type: str) -> int:
        """
        Invalidate all entries of a memory type.

        Args:
            memory_type: Type of memory

        Returns:
            Number of entries invalidated
        """
        count = self._hot_cache.invalidate_by_type(memory_type)

        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Invalidated {count} cache entries for type {memory_type}")
        return count

    async def invalidate_by_scope(self, scope: str) -> int:
        """
        Invalidate all entries in a scope.

        Args:
            scope: Scope to invalidate (e.g., project_id)

        Returns:
            Number of entries invalidated
        """
        count = self._hot_cache.invalidate_by_scope(scope)

        # The retrieval cache doesn't support scope-based invalidation,
        # so we clear it entirely for safety.
        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Invalidated {count} cache entries for scope {scope}")
        return count

    async def invalidate_embedding(
        self,
        content: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            True if entry was found and removed
        """
        return await self._embedding_cache.invalidate(content, model)

    async def clear_all(self) -> int:
        """
        Clear all caches.

        Returns:
            Total number of entries cleared
        """
        count = 0

        count += self._hot_cache.clear()
        count += await self._embedding_cache.clear()

        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Cleared {count} entries from all caches")
        return count

    # =========================================================================
    # Cleanup
    # =========================================================================

    async def cleanup_expired(self) -> int:
        """
        Clean up expired entries from all caches.

        Returns:
            Number of entries cleaned up
        """
        with self._lock:
            count = 0

            count += self._hot_cache.cleanup_expired()
            count += self._embedding_cache.cleanup_expired()

            # The retrieval cache has no cleanup method;
            # its entries expire on access.

            self._last_cleanup = _utcnow()
            self._cleanup_count += 1

            if count > 0:
                logger.info(f"Cleaned up {count} expired cache entries")

            return count
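
    # A periodic cleanup loop would typically be scheduled by the application,
    # e.g. (an illustrative sketch, not wired up in this module; requires
    # `import asyncio`):
    #
    #     async def cleanup_loop(manager: CacheManager) -> None:
    #         while True:
    #             await asyncio.sleep(300)  # every 5 minutes
    #             await manager.cleanup_expired()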

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_stats(self) -> CacheStats:
        """
        Get aggregated cache statistics.

        Returns:
            CacheStats with all cache metrics
        """
        hot_stats = self._hot_cache.get_stats().to_dict()
        emb_stats = self._embedding_cache.get_stats().to_dict()

        retrieval_stats: dict[str, Any] = {}
        if self._retrieval_cache:
            retrieval_stats = self._retrieval_cache.get_stats()

        # Calculate overall hit rate
        total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
        total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)

        # The retrieval cache doesn't track hits/misses the same way,
        # so it is excluded from the overall hit rate.
        total_requests = total_hits + total_misses
        overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0

        return CacheStats(
            hot_cache=hot_stats,
            embedding_cache=emb_stats,
            retrieval_cache=retrieval_stats,
            overall_hit_rate=overall_hit_rate,
            last_cleanup=self._last_cleanup,
            cleanup_count=self._cleanup_count,
        )

    def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
        """
        Get the most frequently accessed memories.

        Args:
            limit: Maximum number to return

        Returns:
            List of (key, access_count) tuples
        """
        return self._hot_cache.get_hot_memories(limit)

    def reset_stats(self) -> None:
        """Reset all cache statistics."""
        self._hot_cache.reset_stats()
        self._embedding_cache.reset_stats()

    # =========================================================================
    # Warmup
    # =========================================================================

    async def warmup(
        self,
        memories: list[tuple[str, UUID | str, Any]],
        scope: str | None = None,
    ) -> int:
        """
        Warm up the hot cache with memories.

        Args:
            memories: List of (memory_type, memory_id, memory) tuples
            scope: Optional scope for all memories

        Returns:
            Number of memories cached
        """
        if not self._enabled:
            return 0

        for memory_type, memory_id, memory in memories:
            self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)

        logger.info(f"Warmed up cache with {len(memories)} memories")
        return len(memories)
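
    # Warmup sketch (illustrative; `load_recent_memories` is a hypothetical
    # loader supplied by the caller, not part of this module):
    #
    #     rows = await load_recent_memories(limit=100)
    #     await manager.warmup([("episodic", row.id, row) for row in rows])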


# Singleton instance
_cache_manager: CacheManager | None = None
_cache_manager_lock = threading.Lock()


def get_cache_manager(
    redis: "Redis | None" = None,
    reset: bool = False,
) -> CacheManager:
    """
    Get the global CacheManager instance.

    Thread-safe with double-checked locking pattern.

    Args:
        redis: Optional Redis connection
        reset: Force create a new instance

    Returns:
        CacheManager instance
    """
    global _cache_manager

    if reset or _cache_manager is None:
        with _cache_manager_lock:
            if reset or _cache_manager is None:
                _cache_manager = CacheManager(redis=redis)

    return _cache_manager
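
# Singleton usage sketch (illustrative; assumes a memory object `mem` with a
# UUID `mem.id`, which is not defined in this module):
#
#     manager = get_cache_manager()
#     manager.cache_memory("episodic", mem.id, mem, scope="project-123")
#     cached = manager.get_memory("episodic", mem.id, scope="project-123")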


def reset_cache_manager() -> None:
    """Reset the global cache manager instance."""
    global _cache_manager
    with _cache_manager_lock:
        _cache_manager = None