forked from cardosofelipe/fast-next-template
feat(memory): implement caching layer for memory operations (#98)
Add comprehensive caching layer for the Agent Memory System:

- HotMemoryCache: LRU cache for frequently accessed memories
  - Python 3.12 type parameter syntax
  - Thread-safe operations with RLock
  - TTL-based expiration
  - Access count tracking for hot memory identification
  - Scoped invalidation by type, scope, or pattern
- EmbeddingCache: Cache embeddings by content hash
  - Content-hash based deduplication
  - Optional Redis backing for persistence
  - LRU eviction with configurable max size
  - CachedEmbeddingGenerator wrapper for transparent caching
- CacheManager: Unified cache management
  - Coordinates hot cache, embedding cache, and retrieval cache
  - Centralized invalidation across all caches
  - Aggregated statistics and hit rate tracking
  - Automatic cleanup scheduling
  - Cache warmup support

Performance targets:
- Cache hit rate > 80% for hot memories
- Cache operations < 1ms (memory), < 5ms (Redis)

83 new tests with comprehensive coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
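Before the file-by-file diff, a minimal usage sketch of the new layer (a sketch only, assuming caching is enabled in settings; `get_cache_manager` and the methods called here are defined in the files below):

from uuid import uuid4

from app.services.memory.cache import get_cache_manager


async def example() -> None:
    manager = get_cache_manager()  # process-wide singleton, created lazily

    # Hot memory cache: keyed by (memory_type, memory_id, optional scope)
    memory_id = uuid4()
    manager.cache_memory("episodic", memory_id, {"task": "demo"}, scope="proj-1")
    assert manager.get_memory("episodic", memory_id, scope="proj-1") == {"task": "demo"}

    # Embedding cache: keyed by content hash plus model name
    await manager.cache_embedding("some text", [0.1, 0.2, 0.3])
    assert await manager.get_embedding("some text") == [0.1, 0.2, 0.3]

    # Coordinated invalidation across the hot and retrieval caches
    await manager.invalidate_memory("episodic", memory_id, scope="proj-1")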
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,21 @@
# app/services/memory/cache/__init__.py
"""
Memory Caching Layer.

Provides caching for memory operations:
- Hot Memory Cache: LRU cache for frequently accessed memories
- Embedding Cache: Cache embeddings by content hash
- Cache Manager: Unified cache management with invalidation
"""

from .cache_manager import CacheManager, CacheStats, get_cache_manager
from .embedding_cache import EmbeddingCache
from .hot_cache import HotMemoryCache

__all__ = [
    "CacheManager",
    "CacheStats",
    "EmbeddingCache",
    "HotMemoryCache",
    "get_cache_manager",
]
500
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
@@ -0,0 +1,500 @@
# app/services/memory/cache/cache_manager.py
"""
Cache Manager.

Unified cache management for memory operations.
Coordinates hot cache, embedding cache, and retrieval cache.
Provides centralized invalidation and statistics.
"""

import logging
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID

from app.services.memory.config import get_memory_settings

from .embedding_cache import EmbeddingCache, create_embedding_cache
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache

if TYPE_CHECKING:
    from redis.asyncio import Redis

    from app.services.memory.indexing.retrieval import RetrievalCache

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class CacheStats:
    """Aggregated cache statistics."""

    hot_cache: dict[str, Any] = field(default_factory=dict)
    embedding_cache: dict[str, Any] = field(default_factory=dict)
    retrieval_cache: dict[str, Any] = field(default_factory=dict)
    overall_hit_rate: float = 0.0
    last_cleanup: datetime | None = None
    cleanup_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "hot_cache": self.hot_cache,
            "embedding_cache": self.embedding_cache,
            "retrieval_cache": self.retrieval_cache,
            "overall_hit_rate": self.overall_hit_rate,
            "last_cleanup": self.last_cleanup.isoformat() if self.last_cleanup else None,
            "cleanup_count": self.cleanup_count,
        }


class CacheManager:
    """
    Unified cache manager for memory operations.

    Provides:
    - Centralized cache configuration
    - Coordinated invalidation across caches
    - Aggregated statistics
    - Automatic cleanup scheduling

    Performance targets:
    - Overall cache hit rate > 80%
    - Cache operations < 1ms (memory), < 5ms (Redis)
    """

    def __init__(
        self,
        hot_cache: HotMemoryCache[Any] | None = None,
        embedding_cache: EmbeddingCache | None = None,
        retrieval_cache: "RetrievalCache | None" = None,
        redis: "Redis | None" = None,
    ) -> None:
        """
        Initialize the cache manager.

        Args:
            hot_cache: Optional pre-configured hot cache
            embedding_cache: Optional pre-configured embedding cache
            retrieval_cache: Optional pre-configured retrieval cache
            redis: Optional Redis connection for persistence
        """
        self._settings = get_memory_settings()
        self._redis = redis
        self._enabled = self._settings.cache_enabled

        # Initialize caches
        if hot_cache:
            self._hot_cache = hot_cache
        else:
            self._hot_cache = create_hot_cache(
                max_size=self._settings.cache_max_items,
                default_ttl_seconds=self._settings.cache_ttl_seconds,
            )

        if embedding_cache:
            self._embedding_cache = embedding_cache
        else:
            self._embedding_cache = create_embedding_cache(
                max_size=self._settings.cache_max_items,
                default_ttl_seconds=self._settings.cache_ttl_seconds * 12,  # 1hr for embeddings
                redis=redis,
            )

        self._retrieval_cache = retrieval_cache

        # Stats tracking
        self._last_cleanup: datetime | None = None
        self._cleanup_count = 0
        self._lock = threading.RLock()

        logger.info(
            f"Initialized CacheManager: enabled={self._enabled}, "
            f"redis={'connected' if redis else 'disabled'}"
        )

    def set_redis(self, redis: "Redis") -> None:
        """Set Redis connection for all caches."""
        self._redis = redis
        self._embedding_cache.set_redis(redis)

    def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
        """Set retrieval cache instance."""
        self._retrieval_cache = cache

    @property
    def is_enabled(self) -> bool:
        """Check if caching is enabled."""
        return self._enabled

    @property
    def hot_cache(self) -> HotMemoryCache[Any]:
        """Get the hot memory cache."""
        return self._hot_cache

    @property
    def embedding_cache(self) -> EmbeddingCache:
        """Get the embedding cache."""
        return self._embedding_cache

    @property
    def retrieval_cache(self) -> "RetrievalCache | None":
        """Get the retrieval cache."""
        return self._retrieval_cache

    # =========================================================================
    # Hot Memory Cache Operations
    # =========================================================================

    def get_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> Any | None:
        """
        Get a memory from hot cache.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            Cached memory or None
        """
        if not self._enabled:
            return None
        return self._hot_cache.get_by_id(memory_type, memory_id, scope)

    def cache_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        memory: Any,
        scope: str | None = None,
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache a memory in hot cache.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            memory: Memory object
            scope: Optional scope
            ttl_seconds: Optional TTL override
        """
        if not self._enabled:
            return
        self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)

    # =========================================================================
    # Embedding Cache Operations
    # =========================================================================

    async def get_embedding(
        self,
        content: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            Cached embedding or None
        """
        if not self._enabled:
            return None
        return await self._embedding_cache.get(content, model)

    async def cache_embedding(
        self,
        content: str,
        embedding: list[float],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> str:
        """
        Cache an embedding.

        Args:
            content: Content text
            embedding: Embedding vector
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            Content hash
        """
        if not self._enabled:
            return EmbeddingCache.hash_content(content)
        return await self._embedding_cache.put(content, embedding, model, ttl_seconds)

    # =========================================================================
    # Invalidation
    # =========================================================================

    async def invalidate_memory(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> int:
        """
        Invalidate a memory across all caches.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            Number of entries invalidated
        """
        count = 0

        # Invalidate hot cache
        if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
            count += 1

        # Invalidate retrieval cache
        if self._retrieval_cache:
            uuid_id = UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
            count += self._retrieval_cache.invalidate_by_memory(uuid_id)

        logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
        return count

    async def invalidate_by_type(self, memory_type: str) -> int:
        """
        Invalidate all entries of a memory type.

        Args:
            memory_type: Type of memory

        Returns:
            Number of entries invalidated
        """
        count = self._hot_cache.invalidate_by_type(memory_type)

        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Invalidated {count} cache entries for type {memory_type}")
        return count

    async def invalidate_by_scope(self, scope: str) -> int:
        """
        Invalidate all entries in a scope.

        Args:
            scope: Scope to invalidate (e.g., project_id)

        Returns:
            Number of entries invalidated
        """
        count = self._hot_cache.invalidate_by_scope(scope)

        # Retrieval cache doesn't support scope-based invalidation,
        # so we clear it entirely for safety
        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Invalidated {count} cache entries for scope {scope}")
        return count

    async def invalidate_embedding(
        self,
        content: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            True if entry was found and removed
        """
        return await self._embedding_cache.invalidate(content, model)

    async def clear_all(self) -> int:
        """
        Clear all caches.

        Returns:
            Total number of entries cleared
        """
        count = 0

        count += self._hot_cache.clear()
        count += await self._embedding_cache.clear()

        if self._retrieval_cache:
            count += self._retrieval_cache.clear()

        logger.info(f"Cleared {count} entries from all caches")
        return count

    # =========================================================================
    # Cleanup
    # =========================================================================

    async def cleanup_expired(self) -> int:
        """
        Clean up expired entries from all caches.

        Returns:
            Number of entries cleaned up
        """
        with self._lock:
            count = 0

            count += self._hot_cache.cleanup_expired()
            count += self._embedding_cache.cleanup_expired()

            # Retrieval cache doesn't have a cleanup method,
            # but entries expire on access

            self._last_cleanup = _utcnow()
            self._cleanup_count += 1

            if count > 0:
                logger.info(f"Cleaned up {count} expired cache entries")

            return count

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_stats(self) -> CacheStats:
        """
        Get aggregated cache statistics.

        Returns:
            CacheStats with all cache metrics
        """
        hot_stats = self._hot_cache.get_stats().to_dict()
        emb_stats = self._embedding_cache.get_stats().to_dict()

        retrieval_stats: dict[str, Any] = {}
        if self._retrieval_cache:
            retrieval_stats = self._retrieval_cache.get_stats()

        # Calculate overall hit rate
        total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
        total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)

        if retrieval_stats:
            # Retrieval cache doesn't track hits/misses the same way
            pass

        total_requests = total_hits + total_misses
        overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0

        return CacheStats(
            hot_cache=hot_stats,
            embedding_cache=emb_stats,
            retrieval_cache=retrieval_stats,
            overall_hit_rate=overall_hit_rate,
            last_cleanup=self._last_cleanup,
            cleanup_count=self._cleanup_count,
        )

    def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
        """
        Get the most frequently accessed memories.

        Args:
            limit: Maximum number to return

        Returns:
            List of (key, access_count) tuples
        """
        return self._hot_cache.get_hot_memories(limit)

    def reset_stats(self) -> None:
        """Reset all cache statistics."""
        self._hot_cache.reset_stats()
        self._embedding_cache.reset_stats()

    # =========================================================================
    # Warmup
    # =========================================================================

    async def warmup(
        self,
        memories: list[tuple[str, UUID | str, Any]],
        scope: str | None = None,
    ) -> int:
        """
        Warm up the hot cache with memories.

        Args:
            memories: List of (memory_type, memory_id, memory) tuples
            scope: Optional scope for all memories

        Returns:
            Number of memories cached
        """
        if not self._enabled:
            return 0

        for memory_type, memory_id, memory in memories:
            self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)

        logger.info(f"Warmed up cache with {len(memories)} memories")
        return len(memories)


# Singleton instance
_cache_manager: CacheManager | None = None
_cache_manager_lock = threading.Lock()


def get_cache_manager(
    redis: "Redis | None" = None,
    reset: bool = False,
) -> CacheManager:
    """
    Get the global CacheManager instance.

    Thread-safe with double-checked locking pattern.

    Args:
        redis: Optional Redis connection
        reset: Force create a new instance

    Returns:
        CacheManager instance
    """
    global _cache_manager

    if reset or _cache_manager is None:
        with _cache_manager_lock:
            if reset or _cache_manager is None:
                _cache_manager = CacheManager(redis=redis)

    return _cache_manager


def reset_cache_manager() -> None:
    """Reset the global cache manager instance."""
    global _cache_manager
    with _cache_manager_lock:
        _cache_manager = None
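The commit message lists "automatic cleanup scheduling", but no scheduler appears in this diff; a background task would presumably drive `cleanup_expired` periodically. A minimal sketch of such a task (the loop itself is hypothetical; `cleanup_expired` and `get_stats` are defined above):

import asyncio

from app.services.memory.cache import get_cache_manager


async def cache_maintenance_loop(interval_seconds: float = 60.0) -> None:
    """Hypothetical background task: evict expired entries and log the hit rate."""
    manager = get_cache_manager()
    while True:
        removed = await manager.cleanup_expired()
        stats = manager.get_stats()
        # overall_hit_rate aggregates hot-cache and embedding-cache counters
        print(f"cleaned={removed} hit_rate={stats.overall_hit_rate:.2%}")
        await asyncio.sleep(interval_seconds)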
627
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
@@ -0,0 +1,627 @@
# app/services/memory/cache/embedding_cache.py
"""
Embedding Cache.

Caches embeddings by content hash to avoid recomputing.
Provides significant performance improvement for repeated content.
"""

import hashlib
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from redis.asyncio import Redis

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class EmbeddingEntry:
    """A cached embedding entry."""

    embedding: list[float]
    content_hash: str
    model: str
    created_at: datetime
    ttl_seconds: float = 3600.0  # 1 hour default

    def is_expired(self) -> bool:
        """Check if this entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds


@dataclass
class EmbeddingCacheStats:
    """Statistics for the embedding cache."""

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0
    current_size: int = 0
    max_size: int = 0
    bytes_saved: int = 0  # Estimated bytes saved by caching

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate."""
        total = self.hits + self.misses
        if total == 0:
            return 0.0
        return self.hits / total

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "evictions": self.evictions,
            "expirations": self.expirations,
            "current_size": self.current_size,
            "max_size": self.max_size,
            "hit_rate": self.hit_rate,
            "bytes_saved": self.bytes_saved,
        }


class EmbeddingCache:
    """
    Cache for embeddings by content hash.

    Features:
    - Content-hash based deduplication
    - LRU eviction
    - TTL-based expiration
    - Optional Redis backing for persistence
    - Thread-safe operations

    Performance targets:
    - Cache hit rate > 90% for repeated content
    - Get/put operations < 1ms (memory), < 5ms (Redis)
    """

    def __init__(
        self,
        max_size: int = 50000,
        default_ttl_seconds: float = 3600.0,
        redis: "Redis | None" = None,
        redis_prefix: str = "mem:emb",
    ) -> None:
        """
        Initialize the embedding cache.

        Args:
            max_size: Maximum number of entries in memory cache
            default_ttl_seconds: Default TTL for entries (1 hour)
            redis: Optional Redis connection for persistence
            redis_prefix: Prefix for Redis keys
        """
        self._max_size = max_size
        self._default_ttl = default_ttl_seconds
        self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._stats = EmbeddingCacheStats(max_size=max_size)
        self._redis = redis
        self._redis_prefix = redis_prefix

        logger.info(
            f"Initialized EmbeddingCache with max_size={max_size}, "
            f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
        )

    def set_redis(self, redis: "Redis") -> None:
        """Set Redis connection for persistence."""
        self._redis = redis

    @staticmethod
    def hash_content(content: str) -> str:
        """
        Compute hash of content for cache key.

        Args:
            content: Content to hash

        Returns:
            32-character hex hash
        """
        return hashlib.sha256(content.encode()).hexdigest()[:32]

    def _cache_key(self, content_hash: str, model: str) -> str:
        """Build cache key from content hash and model."""
        return f"{content_hash}:{model}"

    def _redis_key(self, content_hash: str, model: str) -> str:
        """Build Redis key from content hash and model."""
        return f"{self._redis_prefix}:{content_hash}:{model}"

    async def get(
        self,
        content: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            Cached embedding or None if not found/expired
        """
        content_hash = self.hash_content(content)
        cache_key = self._cache_key(content_hash, model)

        # Check memory cache first
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                if entry.is_expired():
                    del self._cache[cache_key]
                    self._stats.expirations += 1
                    self._stats.current_size = len(self._cache)
                else:
                    # Move to end (most recently used)
                    self._cache.move_to_end(cache_key)
                    self._stats.hits += 1
                    return entry.embedding

        # Check Redis if available
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                data = await self._redis.get(redis_key)
                if data:
                    import json

                    embedding = json.loads(data)
                    # Store in memory cache for faster access
                    self._put_memory(content_hash, model, embedding)
                    self._stats.hits += 1
                    return embedding
            except Exception as e:
                logger.warning(f"Redis get error: {e}")

        self._stats.misses += 1
        return None

    async def get_by_hash(
        self,
        content_hash: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding by hash.

        Args:
            content_hash: Content hash
            model: Model name

        Returns:
            Cached embedding or None if not found/expired
        """
        cache_key = self._cache_key(content_hash, model)

        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                if entry.is_expired():
                    del self._cache[cache_key]
                    self._stats.expirations += 1
                    self._stats.current_size = len(self._cache)
                else:
                    self._cache.move_to_end(cache_key)
                    self._stats.hits += 1
                    return entry.embedding

        # Check Redis
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                data = await self._redis.get(redis_key)
                if data:
                    import json

                    embedding = json.loads(data)
                    self._put_memory(content_hash, model, embedding)
                    self._stats.hits += 1
                    return embedding
            except Exception as e:
                logger.warning(f"Redis get error: {e}")

        self._stats.misses += 1
        return None

    async def put(
        self,
        content: str,
        embedding: list[float],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> str:
        """
        Cache an embedding.

        Args:
            content: Content text
            embedding: Embedding vector
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            Content hash
        """
        content_hash = self.hash_content(content)
        ttl = ttl_seconds or self._default_ttl

        # Store in memory
        self._put_memory(content_hash, model, embedding, ttl)

        # Store in Redis if available
        if self._redis:
            try:
                import json

                redis_key = self._redis_key(content_hash, model)
                await self._redis.setex(
                    redis_key,
                    int(ttl),
                    json.dumps(embedding),
                )
            except Exception as e:
                logger.warning(f"Redis put error: {e}")

        return content_hash

    def _put_memory(
        self,
        content_hash: str,
        model: str,
        embedding: list[float],
        ttl_seconds: float | None = None,
    ) -> None:
        """Store in memory cache."""
        with self._lock:
            # Evict if at capacity
            self._evict_if_needed()

            cache_key = self._cache_key(content_hash, model)
            entry = EmbeddingEntry(
                embedding=embedding,
                content_hash=content_hash,
                model=model,
                created_at=_utcnow(),
                ttl_seconds=ttl_seconds or self._default_ttl,
            )

            self._cache[cache_key] = entry
            self._cache.move_to_end(cache_key)
            self._stats.current_size = len(self._cache)

    def _evict_if_needed(self) -> None:
        """Evict entries if cache is at capacity."""
        while len(self._cache) >= self._max_size:
            if self._cache:
                self._cache.popitem(last=False)
                self._stats.evictions += 1

    async def put_batch(
        self,
        items: list[tuple[str, list[float]]],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> list[str]:
        """
        Cache multiple embeddings.

        Args:
            items: List of (content, embedding) tuples
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            List of content hashes
        """
        hashes = []
        for content, embedding in items:
            content_hash = await self.put(content, embedding, model, ttl_seconds)
            hashes.append(content_hash)
        return hashes

    async def invalidate(
        self,
        content: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            True if entry was found and removed
        """
        content_hash = self.hash_content(content)
        return await self.invalidate_by_hash(content_hash, model)

    async def invalidate_by_hash(
        self,
        content_hash: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding by hash.

        Args:
            content_hash: Content hash
            model: Model name

        Returns:
            True if entry was found and removed
        """
        cache_key = self._cache_key(content_hash, model)
        removed = False

        with self._lock:
            if cache_key in self._cache:
                del self._cache[cache_key]
                self._stats.current_size = len(self._cache)
                removed = True

        # Remove from Redis
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                await self._redis.delete(redis_key)
                removed = True
            except Exception as e:
                logger.warning(f"Redis delete error: {e}")

        return removed

    async def invalidate_by_model(self, model: str) -> int:
        """
        Invalidate all embeddings for a model.

        Args:
            model: Model name

        Returns:
            Number of entries invalidated
        """
        count = 0

        with self._lock:
            keys_to_remove = [
                k for k, v in self._cache.items() if v.model == model
            ]
            for key in keys_to_remove:
                del self._cache[key]
                count += 1

            self._stats.current_size = len(self._cache)

        # Note: Redis pattern deletion would require SCAN which is expensive
        # For now, we only clear memory cache for model-based invalidation

        return count

    async def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            self._stats.current_size = 0

        # Clear Redis entries
        if self._redis:
            try:
                pattern = f"{self._redis_prefix}:*"
                deleted = 0
                async for key in self._redis.scan_iter(match=pattern):
                    await self._redis.delete(key)
                    deleted += 1
                count = max(count, deleted)
            except Exception as e:
                logger.warning(f"Redis clear error: {e}")

        logger.info(f"Cleared {count} entries from embedding cache")
        return count

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries from memory cache.

        Returns:
            Number of entries removed
        """
        with self._lock:
            keys_to_remove = [
                k for k, v in self._cache.items() if v.is_expired()
            ]
            for key in keys_to_remove:
                del self._cache[key]
                self._stats.expirations += 1

            self._stats.current_size = len(self._cache)

            if keys_to_remove:
                logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")

            return len(keys_to_remove)

    def get_stats(self) -> EmbeddingCacheStats:
        """Get cache statistics."""
        with self._lock:
            self._stats.current_size = len(self._cache)
            return self._stats

    def reset_stats(self) -> None:
        """Reset cache statistics."""
        with self._lock:
            self._stats = EmbeddingCacheStats(
                max_size=self._max_size,
                current_size=len(self._cache),
            )

    @property
    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    @property
    def max_size(self) -> int:
        """Get maximum cache size."""
        return self._max_size


class CachedEmbeddingGenerator:
    """
    Wrapper for embedding generators with caching.

    Wraps an embedding generator to cache results.
    """

    def __init__(
        self,
        generator: Any,
        cache: EmbeddingCache,
        model: str = "default",
    ) -> None:
        """
        Initialize the cached embedding generator.

        Args:
            generator: Underlying embedding generator
            cache: Embedding cache
            model: Model name for cache keys
        """
        self._generator = generator
        self._cache = cache
        self._model = model
        self._call_count = 0
        self._cache_hit_count = 0

    async def generate(self, text: str) -> list[float]:
        """
        Generate embedding with caching.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        self._call_count += 1

        # Check cache first
        cached = await self._cache.get(text, self._model)
        if cached is not None:
            self._cache_hit_count += 1
            return cached

        # Generate and cache
        embedding = await self._generator.generate(text)
        await self._cache.put(text, embedding, self._model)

        return embedding

    async def generate_batch(
        self,
        texts: list[str],
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts with caching.

        Args:
            texts: Texts to embed

        Returns:
            List of embedding vectors
        """
        results: list[list[float] | None] = [None] * len(texts)
        to_generate: list[tuple[int, str]] = []

        # Check cache for each text
        for i, text in enumerate(texts):
            cached = await self._cache.get(text, self._model)
            if cached is not None:
                results[i] = cached
                self._cache_hit_count += 1
            else:
                to_generate.append((i, text))

        self._call_count += len(texts)

        # Generate missing embeddings
        if to_generate:
            if hasattr(self._generator, "generate_batch"):
                texts_to_gen = [t for _, t in to_generate]
                embeddings = await self._generator.generate_batch(texts_to_gen)

                for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
                    results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)
            else:
                # Fallback to individual generation
                for idx, text in to_generate:
                    embedding = await self._generator.generate(text)
                    results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)

        return results  # type: ignore[return-value]

    def get_stats(self) -> dict[str, Any]:
        """Get generator statistics."""
        return {
            "call_count": self._call_count,
            "cache_hit_count": self._cache_hit_count,
            "cache_hit_rate": (
                self._cache_hit_count / self._call_count
                if self._call_count > 0
                else 0.0
            ),
            "cache_stats": self._cache.get_stats().to_dict(),
        }


# Factory function
def create_embedding_cache(
    max_size: int = 50000,
    default_ttl_seconds: float = 3600.0,
    redis: "Redis | None" = None,
) -> EmbeddingCache:
    """
    Create an embedding cache.

    Args:
        max_size: Maximum number of entries
        default_ttl_seconds: Default TTL for entries
        redis: Optional Redis connection

    Returns:
        Configured EmbeddingCache instance
    """
    return EmbeddingCache(
        max_size=max_size,
        default_ttl_seconds=default_ttl_seconds,
        redis=redis,
    )
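To see the transparent-caching wrapper above in action, a small sketch with a stub generator (the `StubGenerator` is hypothetical; `CachedEmbeddingGenerator` and `create_embedding_cache` come from the file above):

import asyncio

from app.services.memory.cache.embedding_cache import (
    CachedEmbeddingGenerator,
    create_embedding_cache,
)


class StubGenerator:
    """Hypothetical generator that counts how often the 'real model' is invoked."""

    def __init__(self) -> None:
        self.calls = 0

    async def generate(self, text: str) -> list[float]:
        self.calls += 1
        return [float(len(text))]  # placeholder "embedding"


async def demo() -> None:
    stub = StubGenerator()
    cached = CachedEmbeddingGenerator(stub, create_embedding_cache(max_size=100))

    await cached.generate("hello")
    await cached.generate("hello")  # same content hash: served from cache
    assert stub.calls == 1
    print(cached.get_stats()["cache_hit_rate"])  # 0.5: one hit in two calls


asyncio.run(demo())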
463
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
@@ -0,0 +1,463 @@
# app/services/memory/cache/hot_cache.py
"""
Hot Memory Cache.

LRU cache for frequently accessed memories.
Provides fast access to recently used memories without database queries.
"""

import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from uuid import UUID

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class CacheEntry[T]:
    """A cached memory entry with metadata."""

    value: T
    created_at: datetime
    last_accessed_at: datetime
    access_count: int = 1
    ttl_seconds: float = 300.0

    def is_expired(self) -> bool:
        """Check if this entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds

    def touch(self) -> None:
        """Update access time and count."""
        self.last_accessed_at = _utcnow()
        self.access_count += 1


@dataclass
class CacheKey:
    """A structured cache key with components."""

    memory_type: str
    memory_id: str
    scope: str | None = None

    def __hash__(self) -> int:
        return hash((self.memory_type, self.memory_id, self.scope))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CacheKey):
            return False
        return (
            self.memory_type == other.memory_type
            and self.memory_id == other.memory_id
            and self.scope == other.scope
        )

    def __str__(self) -> str:
        if self.scope:
            return f"{self.memory_type}:{self.scope}:{self.memory_id}"
        return f"{self.memory_type}:{self.memory_id}"


@dataclass
class HotCacheStats:
    """Statistics for the hot memory cache."""

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0
    current_size: int = 0
    max_size: int = 0

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate."""
        total = self.hits + self.misses
        if total == 0:
            return 0.0
        return self.hits / total

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "evictions": self.evictions,
            "expirations": self.expirations,
            "current_size": self.current_size,
            "max_size": self.max_size,
            "hit_rate": self.hit_rate,
        }


class HotMemoryCache[T]:
    """
    LRU cache for frequently accessed memories.

    Features:
    - LRU eviction when capacity is reached
    - TTL-based expiration
    - Access count tracking for hot memory identification
    - Thread-safe operations
    - Scoped invalidation

    Performance targets:
    - Cache hit rate > 80% for hot memories
    - Get/put operations < 1ms
    """

    def __init__(
        self,
        max_size: int = 10000,
        default_ttl_seconds: float = 300.0,
    ) -> None:
        """
        Initialize the hot memory cache.

        Args:
            max_size: Maximum number of entries
            default_ttl_seconds: Default TTL for entries (5 minutes)
        """
        self._max_size = max_size
        self._default_ttl = default_ttl_seconds
        self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
        self._lock = threading.RLock()
        self._stats = HotCacheStats(max_size=max_size)
        logger.info(
            f"Initialized HotMemoryCache with max_size={max_size}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, key: CacheKey) -> T | None:
        """
        Get a memory from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        with self._lock:
            if key not in self._cache:
                self._stats.misses += 1
                return None

            entry = self._cache[key]

            # Check expiration
            if entry.is_expired():
                del self._cache[key]
                self._stats.expirations += 1
                self._stats.misses += 1
                self._stats.current_size = len(self._cache)
                return None

            # Move to end (most recently used)
            self._cache.move_to_end(key)
            entry.touch()

            self._stats.hits += 1
            return entry.value

    def get_by_id(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> T | None:
        """
        Get a memory by type and ID.

        Args:
            memory_type: Type of memory (episodic, semantic, procedural)
            memory_id: Memory ID
            scope: Optional scope (project_id, agent_id)

        Returns:
            Cached value or None if not found/expired
        """
        key = CacheKey(
            memory_type=memory_type,
            memory_id=str(memory_id),
            scope=scope,
        )
        return self.get(key)

    def put(
        self,
        key: CacheKey,
        value: T,
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Put a memory into cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl_seconds: Optional TTL override
        """
        with self._lock:
            # Evict if at capacity
            self._evict_if_needed()

            now = _utcnow()
            entry = CacheEntry(
                value=value,
                created_at=now,
                last_accessed_at=now,
                access_count=1,
                ttl_seconds=ttl_seconds or self._default_ttl,
            )

            self._cache[key] = entry
            self._cache.move_to_end(key)
            self._stats.current_size = len(self._cache)

    def put_by_id(
        self,
        memory_type: str,
        memory_id: UUID | str,
        value: T,
        scope: str | None = None,
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Put a memory by type and ID.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            value: Value to cache
            scope: Optional scope
            ttl_seconds: Optional TTL override
        """
        key = CacheKey(
            memory_type=memory_type,
            memory_id=str(memory_id),
            scope=scope,
        )
        self.put(key, value, ttl_seconds)

    def _evict_if_needed(self) -> None:
        """Evict entries if cache is at capacity."""
        while len(self._cache) >= self._max_size:
            # Remove least recently used (first item)
            if self._cache:
                self._cache.popitem(last=False)
                self._stats.evictions += 1

    def invalidate(self, key: CacheKey) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            key: Cache key to invalidate

        Returns:
            True if entry was found and removed
        """
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                self._stats.current_size = len(self._cache)
                return True
            return False

    def invalidate_by_id(
        self,
        memory_type: str,
        memory_id: UUID | str,
        scope: str | None = None,
    ) -> bool:
        """
        Invalidate a memory by type and ID.

        Args:
            memory_type: Type of memory
            memory_id: Memory ID
            scope: Optional scope

        Returns:
            True if entry was found and removed
        """
        key = CacheKey(
            memory_type=memory_type,
            memory_id=str(memory_id),
            scope=scope,
        )
        return self.invalidate(key)

    def invalidate_by_type(self, memory_type: str) -> int:
        """
        Invalidate all entries of a memory type.

        Args:
            memory_type: Type of memory to invalidate

        Returns:
            Number of entries invalidated
        """
        with self._lock:
            keys_to_remove = [
                k for k in self._cache.keys() if k.memory_type == memory_type
            ]
            for key in keys_to_remove:
                del self._cache[key]

            self._stats.current_size = len(self._cache)
            return len(keys_to_remove)

    def invalidate_by_scope(self, scope: str) -> int:
        """
        Invalidate all entries in a scope.

        Args:
            scope: Scope to invalidate (e.g., project_id)

        Returns:
            Number of entries invalidated
        """
        with self._lock:
            keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
            for key in keys_to_remove:
                del self._cache[key]

            self._stats.current_size = len(self._cache)
            return len(keys_to_remove)

    def invalidate_pattern(self, pattern: str) -> int:
        """
        Invalidate entries matching a pattern.

        Pattern can include * as wildcard.

        Args:
            pattern: Pattern to match (e.g., "episodic:*")

        Returns:
            Number of entries invalidated
        """
        import fnmatch

        with self._lock:
            keys_to_remove = [
                k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
            ]
            for key in keys_to_remove:
                del self._cache[key]

            self._stats.current_size = len(self._cache)
            return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            self._stats.current_size = 0
            logger.info(f"Cleared {count} entries from hot cache")
            return count

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries.

        Returns:
            Number of entries removed
        """
        with self._lock:
            keys_to_remove = [
                k for k, v in self._cache.items() if v.is_expired()
            ]
            for key in keys_to_remove:
                del self._cache[key]
                self._stats.expirations += 1

            self._stats.current_size = len(self._cache)

            if keys_to_remove:
                logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")

            return len(keys_to_remove)

    def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
        """
        Get the most frequently accessed memories.

        Args:
            limit: Maximum number of memories to return

        Returns:
            List of (key, access_count) tuples sorted by access count
        """
        with self._lock:
            entries = [
                (k, v.access_count)
                for k, v in self._cache.items()
                if not v.is_expired()
            ]
            entries.sort(key=lambda x: x[1], reverse=True)
            return entries[:limit]

    def get_stats(self) -> HotCacheStats:
        """Get cache statistics."""
        with self._lock:
            self._stats.current_size = len(self._cache)
            return self._stats

    def reset_stats(self) -> None:
        """Reset cache statistics."""
        with self._lock:
            self._stats = HotCacheStats(
                max_size=self._max_size,
                current_size=len(self._cache),
            )

    @property
    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    @property
    def max_size(self) -> int:
        """Get maximum cache size."""
        return self._max_size


# Factory function for typed caches
def create_hot_cache(
    max_size: int = 10000,
    default_ttl_seconds: float = 300.0,
) -> HotMemoryCache[Any]:
    """
    Create a hot memory cache.

    Args:
        max_size: Maximum number of entries
        default_ttl_seconds: Default TTL for entries

    Returns:
        Configured HotMemoryCache instance
    """
    return HotMemoryCache(
        max_size=max_size,
        default_ttl_seconds=default_ttl_seconds,
    )
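A minimal sketch of the LRU and TTL behavior above (values are illustrative; a `get` refreshes an entry's recency, so the untouched entry is evicted first when capacity is reached):

from app.services.memory.cache.hot_cache import create_hot_cache

cache = create_hot_cache(max_size=2, default_ttl_seconds=300.0)

cache.put_by_id("episodic", "a", {"n": 1})
cache.put_by_id("episodic", "b", {"n": 2})
cache.get_by_id("episodic", "a")            # touch "a", so "b" is now least recent
cache.put_by_id("episodic", "c", {"n": 3})  # at capacity: evicts the LRU entry

assert cache.get_by_id("episodic", "b") is None  # "b" was evicted
assert cache.get_by_id("episodic", "c") == {"n": 3}
print(cache.get_stats().to_dict())  # hits, misses, evictions, hit_rate, ...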
2
backend/tests/unit/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
# tests/unit/services/memory/cache/__init__.py
"""Tests for memory caching layer."""
331
backend/tests/unit/services/memory/cache/test_cache_manager.py
vendored
Normal file
@@ -0,0 +1,331 @@
|
||||
# tests/unit/services/memory/cache/test_cache_manager.py
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.cache_manager import (
|
||||
CacheManager,
|
||||
CacheStats,
|
||||
get_cache_manager,
|
||||
reset_cache_manager,
|
||||
)
|
||||
from app.services.memory.cache.embedding_cache import EmbeddingCache
|
||||
from app.services.memory.cache.hot_cache import HotMemoryCache
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> None:
|
||||
"""Reset singleton before each test."""
|
||||
reset_cache_manager()
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""Tests for CacheStats."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Should convert to dictionary."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
stats = CacheStats(
|
||||
hot_cache={"hits": 10},
|
||||
embedding_cache={"hits": 20},
|
||||
overall_hit_rate=0.75,
|
||||
last_cleanup=datetime.now(UTC),
|
||||
cleanup_count=5,
|
||||
)
|
||||
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hot_cache"] == {"hits": 10}
|
||||
assert result["overall_hit_rate"] == 0.75
|
||||
assert result["cleanup_count"] == 5
|
||||
assert result["last_cleanup"] is not None
|
||||
|
||||
|
||||
class TestCacheManager:
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self) -> CacheManager:
|
||||
"""Create a cache manager."""
|
||||
return CacheManager()
|
||||
|
||||
def test_is_enabled(self, manager: CacheManager) -> None:
|
||||
"""Should check if caching is enabled."""
|
||||
# Default is enabled from settings
|
||||
assert manager.is_enabled is True
|
||||
|
||||
def test_has_hot_cache(self, manager: CacheManager) -> None:
|
||||
"""Should have hot memory cache."""
|
||||
assert manager.hot_cache is not None
|
||||
assert isinstance(manager.hot_cache, HotMemoryCache)
|
||||
|
||||
def test_has_embedding_cache(self, manager: CacheManager) -> None:
|
||||
"""Should have embedding cache."""
|
||||
assert manager.embedding_cache is not None
|
||||
assert isinstance(manager.embedding_cache, EmbeddingCache)
|
||||
|
||||
def test_cache_memory(self, manager: CacheManager) -> None:
|
||||
"""Should cache memory in hot cache."""
|
||||
memory_id = uuid4()
|
||||
memory = {"task": "test", "data": "value"}
|
||||
|
||||
manager.cache_memory("episodic", memory_id, memory)
|
||||
result = manager.get_memory("episodic", memory_id)
|
||||
|
||||
assert result == memory
|
||||
|
||||
def test_cache_memory_with_scope(self, manager: CacheManager) -> None:
|
||||
"""Should cache memory with scope."""
|
||||
memory_id = uuid4()
|
||||
memory = {"task": "test"}
|
||||
|
||||
manager.cache_memory("semantic", memory_id, memory, scope="proj-123")
|
||||
result = manager.get_memory("semantic", memory_id, scope="proj-123")
|
||||
|
||||
assert result == memory
|
||||
|
||||
async def test_cache_embedding(self, manager: CacheManager) -> None:
|
||||
"""Should cache embedding."""
|
||||
content = "test content"
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
content_hash = await manager.cache_embedding(content, embedding)
|
||||
result = await manager.get_embedding(content)
|
||||
|
||||
assert result == embedding
|
||||
assert len(content_hash) == 32
|
||||
|
||||
async def test_invalidate_memory(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate memory from hot cache."""
|
||||
memory_id = uuid4()
|
||||
manager.cache_memory("episodic", memory_id, {"data": "test"})
|
||||
|
||||
count = await manager.invalidate_memory("episodic", memory_id)
|
||||
|
||||
assert count >= 1
|
||||
assert manager.get_memory("episodic", memory_id) is None
|
||||
|
||||
async def test_invalidate_by_type(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate all entries of a type."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "1"})
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "2"})
|
||||
manager.cache_memory("semantic", uuid4(), {"data": "3"})
|
||||
|
||||
count = await manager.invalidate_by_type("episodic")
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_invalidate_by_scope(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate all entries in a scope."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1")
|
||||
manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1")
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2")
|
||||
|
||||
count = await manager.invalidate_by_scope("proj-1")
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_invalidate_embedding(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate cached embedding."""
|
||||
content = "test content"
|
||||
await manager.cache_embedding(content, [0.1, 0.2])
|
||||
|
||||
result = await manager.invalidate_embedding(content)
|
||||
|
||||
assert result is True
|
||||
assert await manager.get_embedding(content) is None
|
||||
|
||||
async def test_clear_all(self, manager: CacheManager) -> None:
|
||||
"""Should clear all caches."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
await manager.cache_embedding("content", [0.1])
|
||||
|
||||
count = await manager.clear_all()
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_cleanup_expired(self, manager: CacheManager) -> None:
|
||||
"""Should clean up expired entries."""
|
||||
count = await manager.cleanup_expired()
|
||||
|
||||
# May be 0 if no expired entries
|
||||
assert count >= 0
|
||||
assert manager._cleanup_count == 1
|
||||
assert manager._last_cleanup is not None
|
||||
|
||||
def test_get_stats(self, manager: CacheManager) -> None:
|
||||
"""Should return aggregated statistics."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert "hot_cache" in stats.to_dict()
|
||||
assert "embedding_cache" in stats.to_dict()
|
||||
assert "overall_hit_rate" in stats.to_dict()
|
||||
|
||||
def test_get_hot_memories(self, manager: CacheManager) -> None:
|
||||
"""Should return most accessed memories."""
|
||||
id1 = uuid4()
|
||||
id2 = uuid4()
|
||||
|
||||
manager.cache_memory("episodic", id1, {"data": "1"})
|
||||
manager.cache_memory("episodic", id2, {"data": "2"})
|
||||
|
||||
# Access first multiple times
|
||||
for _ in range(5):
|
||||
manager.get_memory("episodic", id1)
|
||||
|
||||
hot = manager.get_hot_memories(limit=2)
|
||||
|
||||
assert len(hot) == 2
|
||||
|
||||
def test_reset_stats(self, manager: CacheManager) -> None:
|
||||
"""Should reset all statistics."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
manager.get_memory("episodic", uuid4()) # Miss
|
||||
|
||||
manager.reset_stats()
|
||||
|
||||
stats = manager.get_stats()
|
||||
assert stats.hot_cache.get("hits", 0) == 0
|
||||
|
||||
async def test_warmup(self, manager: CacheManager) -> None:
|
||||
"""Should warm up cache with memories."""
|
||||
memories = [
|
||||
("episodic", uuid4(), {"data": "1"}),
|
||||
("episodic", uuid4(), {"data": "2"}),
|
||||
("semantic", uuid4(), {"data": "3"}),
|
||||
]
|
||||
|
||||
count = await manager.warmup(memories)
|
||||
|
||||
assert count == 3
|
||||
|
||||
|
||||
class TestCacheManagerWithRetrieval:
|
||||
"""Tests for CacheManager with retrieval cache."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_retrieval_cache(self) -> MagicMock:
|
||||
"""Create mock retrieval cache."""
|
||||
cache = MagicMock()
|
||||
cache.invalidate_by_memory = MagicMock(return_value=1)
|
||||
cache.clear = MagicMock(return_value=5)
|
||||
cache.get_stats = MagicMock(return_value={"entries": 10})
|
||||
return cache
|
||||
|
||||
@pytest.fixture
|
||||
def manager_with_retrieval(
|
||||
self,
|
||||
mock_retrieval_cache: MagicMock,
|
||||
) -> CacheManager:
|
||||
"""Create manager with retrieval cache."""
|
||||
manager = CacheManager()
|
||||
manager.set_retrieval_cache(mock_retrieval_cache)
|
||||
return manager
|
||||
|
||||
async def test_invalidate_clears_retrieval(
|
||||
self,
|
||||
manager_with_retrieval: CacheManager,
|
||||
mock_retrieval_cache: MagicMock,
|
||||
) -> None:
|
||||
"""Should invalidate retrieval cache entries."""
|
||||
memory_id = uuid4()
|
||||
|
||||
await manager_with_retrieval.invalidate_memory("episodic", memory_id)
|
||||
|
||||
mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(memory_id)
|
||||
|
||||
def test_stats_includes_retrieval(
|
||||
self,
|
||||
manager_with_retrieval: CacheManager,
|
||||
) -> None:
|
||||
"""Should include retrieval cache stats."""
|
||||
stats = manager_with_retrieval.get_stats()
|
||||
|
||||
assert "retrieval_cache" in stats.to_dict()
|
||||
|
||||
|
||||
class TestCacheManagerDisabled:
    """Tests for CacheManager when disabled."""

    @pytest.fixture
    def disabled_manager(self) -> CacheManager:
        """Create a disabled cache manager."""
        with patch(
            "app.services.memory.cache.cache_manager.get_memory_settings"
        ) as mock_settings:
            settings = MagicMock()
            settings.cache_enabled = False
            settings.cache_max_items = 1000
            settings.cache_ttl_seconds = 300
            mock_settings.return_value = settings

            return CacheManager()

    def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None:
        """Should return None when disabled."""
        memory_id = uuid4()
        disabled_manager.cache_memory("episodic", memory_id, {"data": "test"})
        result = disabled_manager.get_memory("episodic", memory_id)

        assert result is None

    async def test_get_embedding_returns_none(
        self,
        disabled_manager: CacheManager,
    ) -> None:
        """Should return None for embeddings when disabled."""
        result = await disabled_manager.get_embedding("content")

        assert result is None

    async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None:
        """Should return 0 from warmup when disabled."""
        count = await disabled_manager.warmup([("episodic", uuid4(), {})])

        assert count == 0


class TestGetCacheManager:
    """Tests for get_cache_manager factory."""

    def test_returns_singleton(self) -> None:
        """Should return same instance."""
        manager1 = get_cache_manager()
        manager2 = get_cache_manager()

        assert manager1 is manager2

    def test_reset_creates_new(self) -> None:
        """Should create new instance after reset."""
        manager1 = get_cache_manager()
        reset_cache_manager()
        manager2 = get_cache_manager()

        assert manager1 is not manager2

    def test_reset_parameter(self) -> None:
        """Should create new instance with reset=True."""
        manager1 = get_cache_manager()
        manager2 = get_cache_manager(reset=True)

        assert manager1 is not manager2


class TestResetCacheManager:
    """Tests for reset_cache_manager."""

    def test_resets_singleton(self) -> None:
        """Should reset the singleton."""
        get_cache_manager()
        reset_cache_manager()

        # Next call should create new instance
        manager = get_cache_manager()
        assert manager is not None
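
The factory tests describe a module-level singleton with an explicit reset hook. A sketch of the shape they imply (the CacheManager import is real per the package __init__; the body is assumed, and the real module may add locking):

from app.services.memory.cache.cache_manager import CacheManager

_manager: CacheManager | None = None


def get_cache_manager(reset: bool = False) -> CacheManager:
    global _manager
    if reset or _manager is None:
        _manager = CacheManager()
    return _manager


def reset_cache_manager() -> None:
    global _manager
    _manager = None
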
391
backend/tests/unit/services/memory/cache/test_embedding_cache.py
vendored
Normal file
@@ -0,0 +1,391 @@
# tests/unit/services/memory/cache/test_embedding_cache.py
"""Tests for EmbeddingCache."""

import json
import time
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.services.memory.cache.embedding_cache import (
    CachedEmbeddingGenerator,
    EmbeddingCache,
    EmbeddingCacheStats,
    EmbeddingEntry,
    create_embedding_cache,
)

pytestmark = pytest.mark.asyncio(loop_scope="function")


class TestEmbeddingEntry:
    """Tests for EmbeddingEntry."""

    def test_creates_entry(self) -> None:
        """Should create entry with embedding."""
        entry = EmbeddingEntry(
            embedding=[0.1, 0.2, 0.3],
            content_hash="abc123",
            model="text-embedding-3-small",
            created_at=datetime.now(UTC),
        )

        assert entry.embedding == [0.1, 0.2, 0.3]
        assert entry.content_hash == "abc123"
        assert entry.ttl_seconds == 3600.0

    def test_is_expired(self) -> None:
        """Should detect expired entries."""
        old_time = datetime.now(UTC) - timedelta(seconds=4000)
        entry = EmbeddingEntry(
            embedding=[0.1],
            content_hash="abc",
            model="default",
            created_at=old_time,
            ttl_seconds=3600.0,
        )

        assert entry.is_expired() is True

    def test_not_expired(self) -> None:
        """Should detect non-expired entries."""
        entry = EmbeddingEntry(
            embedding=[0.1],
            content_hash="abc",
            model="default",
            created_at=datetime.now(UTC),
        )

        assert entry.is_expired() is False
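
The two expiry tests bracket a 3600-second default TTL. A plausible shape for is_expired, shown as a hedged standalone sketch rather than the real EmbeddingEntry:

from dataclasses import dataclass
from datetime import UTC, datetime


@dataclass
class _EntrySketch:
    created_at: datetime
    ttl_seconds: float = 3600.0

    def is_expired(self) -> bool:
        # Age in seconds versus the per-entry TTL.
        age = (datetime.now(UTC) - self.created_at).total_seconds()
        return age > self.ttl_seconds
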
class TestEmbeddingCacheStats:
    """Tests for EmbeddingCacheStats."""

    def test_hit_rate_calculation(self) -> None:
        """Should calculate hit rate correctly."""
        stats = EmbeddingCacheStats(hits=90, misses=10)

        assert stats.hit_rate == 0.9

    def test_hit_rate_zero_requests(self) -> None:
        """Should return 0 for no requests."""
        stats = EmbeddingCacheStats()

        assert stats.hit_rate == 0.0

    def test_to_dict(self) -> None:
        """Should convert to dictionary."""
        stats = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000)

        result = stats.to_dict()

        assert result["hits"] == 10
        assert result["bytes_saved"] == 1000
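
The zero-requests test fixes the divide-by-zero guard; on the stats dataclass, the property is presumably no more than:

@property
def hit_rate(self) -> float:
    # 0.0 when there have been no requests, hits/(hits+misses) otherwise.
    total = self.hits + self.misses
    return self.hits / total if total else 0.0
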
class TestEmbeddingCache:
    """Tests for EmbeddingCache."""

    @pytest.fixture
    def cache(self) -> EmbeddingCache:
        """Create an embedding cache."""
        return EmbeddingCache(max_size=100, default_ttl_seconds=300.0)

    async def test_put_and_get(self, cache: EmbeddingCache) -> None:
        """Should store and retrieve embeddings."""
        content = "Hello world"
        embedding = [0.1, 0.2, 0.3, 0.4]

        content_hash = await cache.put(content, embedding)
        result = await cache.get(content)

        assert result == embedding
        assert len(content_hash) == 32

    async def test_get_missing(self, cache: EmbeddingCache) -> None:
        """Should return None for missing content."""
        result = await cache.get("nonexistent content")

        assert result is None

    async def test_get_by_hash(self, cache: EmbeddingCache) -> None:
        """Should get by content hash."""
        content = "Test content"
        embedding = [0.1, 0.2]

        content_hash = await cache.put(content, embedding)
        result = await cache.get_by_hash(content_hash)

        assert result == embedding

    async def test_model_separation(self, cache: EmbeddingCache) -> None:
        """Should separate embeddings by model."""
        content = "Same content"
        emb1 = [0.1, 0.2]
        emb2 = [0.3, 0.4]

        await cache.put(content, emb1, model="model-a")
        await cache.put(content, emb2, model="model-b")

        result1 = await cache.get(content, model="model-a")
        result2 = await cache.get(content, model="model-b")

        assert result1 == emb1
        assert result2 == emb2

    async def test_lru_eviction(self) -> None:
        """Should evict LRU entries when at capacity."""
        cache = EmbeddingCache(max_size=3)

        await cache.put("content1", [0.1])
        await cache.put("content2", [0.2])
        await cache.put("content3", [0.3])

        # Access first to make it recent
        await cache.get("content1")

        # Add fourth, should evict second (LRU)
        await cache.put("content4", [0.4])

        assert await cache.get("content1") is not None
        assert await cache.get("content2") is None  # Evicted
        assert await cache.get("content3") is not None
        assert await cache.get("content4") is not None

    async def test_ttl_expiration(self) -> None:
        """Should expire entries after TTL."""
        cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)

        await cache.put("content", [0.1, 0.2])

        time.sleep(0.2)

        result = await cache.get("content")

        assert result is None

    async def test_put_batch(self, cache: EmbeddingCache) -> None:
        """Should cache multiple embeddings."""
        items = [
            ("content1", [0.1]),
            ("content2", [0.2]),
            ("content3", [0.3]),
        ]

        hashes = await cache.put_batch(items)

        assert len(hashes) == 3
        assert await cache.get("content1") == [0.1]
        assert await cache.get("content2") == [0.2]

    async def test_invalidate(self, cache: EmbeddingCache) -> None:
        """Should invalidate cached embedding."""
        await cache.put("content", [0.1, 0.2])

        result = await cache.invalidate("content")

        assert result is True
        assert await cache.get("content") is None

    async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None:
        """Should invalidate by hash."""
        content_hash = await cache.put("content", [0.1, 0.2])

        result = await cache.invalidate_by_hash(content_hash)

        assert result is True
        assert await cache.get("content") is None

    async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None:
        """Should invalidate all embeddings for a model."""
        await cache.put("content1", [0.1], model="model-a")
        await cache.put("content2", [0.2], model="model-a")
        await cache.put("content3", [0.3], model="model-b")

        count = await cache.invalidate_by_model("model-a")

        assert count == 2
        assert await cache.get("content1", model="model-a") is None
        assert await cache.get("content3", model="model-b") is not None

    async def test_clear(self, cache: EmbeddingCache) -> None:
        """Should clear all entries."""
        await cache.put("content1", [0.1])
        await cache.put("content2", [0.2])

        count = await cache.clear()

        assert count == 2
        assert cache.size == 0

    def test_cleanup_expired(self) -> None:
        """Should remove expired entries."""
        cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)

        # Use synchronous put for setup
        cache._put_memory("hash1", "default", [0.1])
        cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)

        time.sleep(0.2)

        count = cache.cleanup_expired()

        assert count == 1

    def test_get_stats(self, cache: EmbeddingCache) -> None:
        """Should return accurate statistics."""
        # Put synchronously for setup
        cache._put_memory("hash1", "default", [0.1])

        stats = cache.get_stats()

        assert stats.current_size == 1

    def test_hash_content(self) -> None:
        """Should produce consistent hashes."""
        hash1 = EmbeddingCache.hash_content("test content")
        hash2 = EmbeddingCache.hash_content("test content")
        hash3 = EmbeddingCache.hash_content("different content")

        assert hash1 == hash2
        assert hash1 != hash3
        assert len(hash1) == 32
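
test_hash_content asserts a stable 32-character digest, which matches a 128-bit hash in hex. A sketch assuming MD5 chosen for content addressing rather than security (the algorithm is an assumption; only the length is asserted by the tests):

import hashlib


def hash_content(content: str) -> str:
    # 32 hex characters == 128 bits; md5 fits, but so would a truncated sha256.
    return hashlib.md5(content.encode("utf-8")).hexdigest()
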
class TestEmbeddingCacheWithRedis:
    """Tests for EmbeddingCache with Redis."""

    @pytest.fixture
    def mock_redis(self) -> MagicMock:
        """Create mock Redis."""
        redis = MagicMock()
        redis.get = AsyncMock(return_value=None)
        redis.setex = AsyncMock()
        redis.delete = AsyncMock()
        redis.scan_iter = MagicMock(return_value=iter([]))
        return redis

    @pytest.fixture
    def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache:
        """Create cache with mock Redis."""
        return EmbeddingCache(
            max_size=100,
            default_ttl_seconds=300.0,
            redis=mock_redis,
        )

    async def test_put_stores_in_redis(
        self,
        cache_with_redis: EmbeddingCache,
        mock_redis: MagicMock,
    ) -> None:
        """Should store in Redis when available."""
        await cache_with_redis.put("content", [0.1, 0.2])

        mock_redis.setex.assert_called_once()

    async def test_get_checks_redis_on_miss(
        self,
        cache_with_redis: EmbeddingCache,
        mock_redis: MagicMock,
    ) -> None:
        """Should check Redis when memory cache misses."""
        mock_redis.get.return_value = json.dumps([0.1, 0.2])

        result = await cache_with_redis.get("content")

        assert result == [0.1, 0.2]
        mock_redis.get.assert_called_once()
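
The Redis tests imply a read-through, write-through pattern: put mirrors to Redis via setex, and a memory miss falls back to redis.get with JSON-encoded vectors. A hedged sketch of those two halves using the real redis.asyncio client API (function names are illustrative):

import json

from redis.asyncio import Redis


async def redis_get(redis: Redis, key: str) -> list[float] | None:
    # Fall back to Redis on a memory-cache miss; embeddings travel as JSON.
    raw = await redis.get(key)
    return json.loads(raw) if raw is not None else None


async def redis_put(redis: Redis, key: str, embedding: list[float], ttl: int) -> None:
    # setex writes the value and its expiry in a single call.
    await redis.setex(key, ttl, json.dumps(embedding))
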
class TestCachedEmbeddingGenerator:
    """Tests for CachedEmbeddingGenerator."""

    @pytest.fixture
    def mock_generator(self) -> MagicMock:
        """Create mock embedding generator."""
        gen = MagicMock()
        gen.generate = AsyncMock(return_value=[0.1, 0.2, 0.3])
        gen.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]])
        return gen

    @pytest.fixture
    def cache(self) -> EmbeddingCache:
        """Create embedding cache."""
        return EmbeddingCache(max_size=100)

    @pytest.fixture
    def cached_gen(
        self,
        mock_generator: MagicMock,
        cache: EmbeddingCache,
    ) -> CachedEmbeddingGenerator:
        """Create cached generator."""
        return CachedEmbeddingGenerator(mock_generator, cache)

    async def test_generate_caches_result(
        self,
        cached_gen: CachedEmbeddingGenerator,
        mock_generator: MagicMock,
    ) -> None:
        """Should cache generated embedding."""
        result1 = await cached_gen.generate("test text")
        result2 = await cached_gen.generate("test text")

        assert result1 == [0.1, 0.2, 0.3]
        assert result2 == [0.1, 0.2, 0.3]
        mock_generator.generate.assert_called_once()  # Only called once

    async def test_generate_batch_uses_cache(
        self,
        cached_gen: CachedEmbeddingGenerator,
        mock_generator: MagicMock,
        cache: EmbeddingCache,
    ) -> None:
        """Should use cache for batch generation."""
        # Pre-cache one embedding
        await cache.put("text1", [0.5])

        # Mock returns 2 embeddings for the 2 uncached texts
        mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]])

        results = await cached_gen.generate_batch(["text1", "text2", "text3"])

        assert len(results) == 3
        assert results[0] == [0.5]  # From cache
        assert results[1] == [0.2]  # Generated
        assert results[2] == [0.3]  # Generated

    async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None:
        """Should return generator statistics."""
        await cached_gen.generate("text1")
        await cached_gen.generate("text1")  # Cache hit

        stats = cached_gen.get_stats()

        assert stats["call_count"] == 2
        assert stats["cache_hit_count"] == 1
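
The generator tests define the wrapper's contract precisely: check the cache first, generate and backfill on a miss, and count calls and hits under the keys get_stats exposes. A hedged sketch (the constructor order matches the cached_gen fixture; the internals are assumed):

class _CachedGeneratorSketch:
    def __init__(self, generator, cache) -> None:
        self._generator = generator
        self._cache = cache
        self._call_count = 0
        self._cache_hit_count = 0

    async def generate(self, text: str) -> list[float]:
        self._call_count += 1
        cached = await self._cache.get(text)
        if cached is not None:
            self._cache_hit_count += 1
            return cached
        embedding = await self._generator.generate(text)
        await self._cache.put(text, embedding)  # backfill for next time
        return embedding

    def get_stats(self) -> dict[str, int]:
        return {
            "call_count": self._call_count,
            "cache_hit_count": self._cache_hit_count,
        }
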
class TestCreateEmbeddingCache:
    """Tests for factory function."""

    def test_creates_cache(self) -> None:
        """Should create cache with defaults."""
        cache = create_embedding_cache()

        assert cache.max_size == 50000

    def test_creates_cache_with_options(self) -> None:
        """Should create cache with custom options."""
        cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)

        assert cache.max_size == 1000

355
backend/tests/unit/services/memory/cache/test_hot_cache.py
vendored
Normal file
@@ -0,0 +1,355 @@
# tests/unit/services/memory/cache/test_hot_cache.py
"""Tests for HotMemoryCache."""

import time
from datetime import UTC, datetime, timedelta
from uuid import uuid4

import pytest

from app.services.memory.cache.hot_cache import (
    CacheEntry,
    CacheKey,
    HotCacheStats,
    HotMemoryCache,
    create_hot_cache,
)


class TestCacheKey:
    """Tests for CacheKey."""

    def test_creates_key(self) -> None:
        """Should create key with required fields."""
        key = CacheKey(memory_type="episodic", memory_id="123")

        assert key.memory_type == "episodic"
        assert key.memory_id == "123"
        assert key.scope is None

    def test_creates_key_with_scope(self) -> None:
        """Should create key with scope."""
        key = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123")

        assert key.scope == "proj-123"

    def test_hash_and_equality(self) -> None:
        """Keys with same values should be equal and have same hash."""
        key1 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
        key2 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")

        assert key1 == key2
        assert hash(key1) == hash(key2)

    def test_str_representation(self) -> None:
        """Should produce readable string."""
        key = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")

        assert str(key) == "episodic:proj-1:123"

    def test_str_without_scope(self) -> None:
        """Should produce string without scope."""
        key = CacheKey(memory_type="episodic", memory_id="123")

        assert str(key) == "episodic:123"
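
The equality, hash, and string tests together pin CacheKey down to something very close to a frozen dataclass with a "type[:scope]:id" rendering; a sketch:

from dataclasses import dataclass


@dataclass(frozen=True)
class _CacheKeySketch:
    memory_type: str
    memory_id: str
    scope: str | None = None

    def __str__(self) -> str:
        # Scope is interposed only when present: "episodic:proj-1:123" vs "episodic:123".
        if self.scope is not None:
            return f"{self.memory_type}:{self.scope}:{self.memory_id}"
        return f"{self.memory_type}:{self.memory_id}"
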
class TestCacheEntry:
    """Tests for CacheEntry."""

    def test_creates_entry(self) -> None:
        """Should create entry with value."""
        now = datetime.now(UTC)
        entry = CacheEntry(
            value={"data": "test"},
            created_at=now,
            last_accessed_at=now,
        )

        assert entry.value == {"data": "test"}
        assert entry.access_count == 1
        assert entry.ttl_seconds == 300.0

    def test_is_expired(self) -> None:
        """Should detect expired entries."""
        old_time = datetime.now(UTC) - timedelta(seconds=400)
        entry = CacheEntry(
            value="test",
            created_at=old_time,
            last_accessed_at=old_time,
            ttl_seconds=300.0,
        )

        assert entry.is_expired() is True

    def test_not_expired(self) -> None:
        """Should detect non-expired entries."""
        entry = CacheEntry(
            value="test",
            created_at=datetime.now(UTC),
            last_accessed_at=datetime.now(UTC),
            ttl_seconds=300.0,
        )

        assert entry.is_expired() is False

    def test_touch_updates_access(self) -> None:
        """Touch should update access time and count."""
        old_time = datetime.now(UTC) - timedelta(seconds=10)
        entry = CacheEntry(
            value="test",
            created_at=old_time,
            last_accessed_at=old_time,
            access_count=5,
        )

        entry.touch()

        assert entry.access_count == 6
        assert entry.last_accessed_at > old_time


class TestHotCacheStats:
    """Tests for HotCacheStats."""

    def test_hit_rate_calculation(self) -> None:
        """Should calculate hit rate correctly."""
        stats = HotCacheStats(hits=80, misses=20)

        assert stats.hit_rate == 0.8

    def test_hit_rate_zero_requests(self) -> None:
        """Should return 0 for no requests."""
        stats = HotCacheStats()

        assert stats.hit_rate == 0.0

    def test_to_dict(self) -> None:
        """Should convert to dictionary."""
        stats = HotCacheStats(hits=10, misses=5, evictions=2)

        result = stats.to_dict()

        assert result["hits"] == 10
        assert result["misses"] == 5
        assert result["evictions"] == 2
        assert "hit_rate" in result


class TestHotMemoryCache:
    """Tests for HotMemoryCache."""

    @pytest.fixture
    def cache(self) -> HotMemoryCache[dict]:
        """Create a hot memory cache."""
        return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0)

    def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None:
        """Should store and retrieve values."""
        key = CacheKey(memory_type="episodic", memory_id="123")
        value = {"data": "test"}

        cache.put(key, value)
        result = cache.get(key)

        assert result == value

    def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None:
        """Should return None for missing keys."""
        key = CacheKey(memory_type="episodic", memory_id="nonexistent")

        result = cache.get(key)

        assert result is None

    def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None:
        """Should store by type and ID."""
        memory_id = uuid4()
        value = {"data": "test"}

        cache.put_by_id("episodic", memory_id, value)
        result = cache.get_by_id("episodic", memory_id)

        assert result == value

    def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None:
        """Should store with scope."""
        memory_id = uuid4()
        value = {"data": "test"}

        cache.put_by_id("semantic", memory_id, value, scope="proj-123")
        result = cache.get_by_id("semantic", memory_id, scope="proj-123")

        assert result == value

    def test_lru_eviction(self) -> None:
        """Should evict LRU entries when at capacity."""
        cache = HotMemoryCache[str](max_size=3)

        # Fill cache
        cache.put_by_id("test", "1", "first")
        cache.put_by_id("test", "2", "second")
        cache.put_by_id("test", "3", "third")

        # Access first to make it recent
        cache.get_by_id("test", "1")

        # Add fourth, should evict second (LRU)
        cache.put_by_id("test", "4", "fourth")

        assert cache.get_by_id("test", "1") is not None  # Accessed, kept
        assert cache.get_by_id("test", "2") is None  # Evicted (LRU)
        assert cache.get_by_id("test", "3") is not None
        assert cache.get_by_id("test", "4") is not None

    def test_ttl_expiration(self) -> None:
        """Should expire entries after TTL."""
        cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)

        cache.put_by_id("test", "1", "value")

        # Wait for expiration
        time.sleep(0.2)

        result = cache.get_by_id("test", "1")

        assert result is None

    def test_invalidate(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate specific entry."""
        key = CacheKey(memory_type="episodic", memory_id="123")
        cache.put(key, {"data": "test"})

        result = cache.invalidate(key)

        assert result is True
        assert cache.get(key) is None

    def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate by ID."""
        memory_id = uuid4()
        cache.put_by_id("episodic", memory_id, {"data": "test"})

        result = cache.invalidate_by_id("episodic", memory_id)

        assert result is True
        assert cache.get_by_id("episodic", memory_id) is None

    def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate all entries of a type."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("episodic", "2", {"data": "2"})
        cache.put_by_id("semantic", "3", {"data": "3"})

        count = cache.invalidate_by_type("episodic")

        assert count == 2
        assert cache.get_by_id("episodic", "1") is None
        assert cache.get_by_id("episodic", "2") is None
        assert cache.get_by_id("semantic", "3") is not None

    def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate all entries in a scope."""
        cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1")
        cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1")
        cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2")

        count = cache.invalidate_by_scope("proj-1")

        assert count == 2
        assert cache.get_by_id("episodic", "3", scope="proj-2") is not None

    def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate entries matching pattern."""
        cache.put_by_id("episodic", "123", {"data": "1"})
        cache.put_by_id("episodic", "124", {"data": "2"})
        cache.put_by_id("semantic", "125", {"data": "3"})

        count = cache.invalidate_pattern("episodic:*")

        assert count == 2

    def test_clear(self, cache: HotMemoryCache[dict]) -> None:
        """Should clear all entries."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("semantic", "2", {"data": "2"})

        count = cache.clear()

        assert count == 2
        assert cache.size == 0

    def test_cleanup_expired(self) -> None:
        """Should remove expired entries."""
        cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)

        cache.put_by_id("test", "1", "value1")
        cache.put_by_id("test", "2", "value2", ttl_seconds=10)

        time.sleep(0.2)

        count = cache.cleanup_expired()

        assert count == 1  # Only the first one expired
        assert cache.size == 1

    def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None:
        """Should return most accessed memories."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("episodic", "2", {"data": "2"})

        # Access first one multiple times
        for _ in range(5):
            cache.get_by_id("episodic", "1")

        hot = cache.get_hot_memories(limit=2)

        assert len(hot) == 2
        assert hot[0][1] >= hot[1][1]  # Sorted by access count

    def test_get_stats(self, cache: HotMemoryCache[dict]) -> None:
        """Should return accurate statistics."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.get_by_id("episodic", "1")  # Hit
        cache.get_by_id("episodic", "2")  # Miss

        stats = cache.get_stats()

        assert stats.hits == 1
        assert stats.misses == 1
        assert stats.current_size == 1

    def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None:
        """Should reset statistics."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.get_by_id("episodic", "1")

        cache.reset_stats()
        stats = cache.get_stats()

        assert stats.hits == 0
        assert stats.misses == 0
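
test_lru_eviction above nails the eviction order: a get refreshes recency, and inserting at capacity evicts the least recently used entry. The classic OrderedDict formulation, as a hedged sketch of the mechanics only (the real class layers TTL, stats, and the RLock on top), using the same Python 3.12 type-parameter syntax as HotMemoryCache:

from collections import OrderedDict


class _LRUSketch[V]:
    def __init__(self, max_size: int) -> None:
        self._max_size = max_size
        self._entries: OrderedDict[str, V] = OrderedDict()

    def get(self, key: str) -> V | None:
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def put(self, key: str, value: V) -> None:
        if key in self._entries:
            self._entries.move_to_end(key)
        elif len(self._entries) >= self._max_size:
            self._entries.popitem(last=False)  # evict the LRU entry
        self._entries[key] = value
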
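test_invalidate_pattern feeds a glob ("episodic:*") rather than a regex, which maps naturally onto fnmatch against str(key). A hedged sketch of that matching step; the real method presumably also takes the cache lock:

from fnmatch import fnmatch


def invalidate_pattern(entries: dict, pattern: str) -> int:
    # Collect first, then delete, to avoid mutating while iterating.
    matched = [key for key in entries if fnmatch(str(key), pattern)]
    for key in matched:
        del entries[key]
    return len(matched)
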
class TestCreateHotCache:
    """Tests for factory function."""

    def test_creates_cache(self) -> None:
        """Should create cache with defaults."""
        cache = create_hot_cache()

        assert cache.max_size == 10000

    def test_creates_cache_with_options(self) -> None:
        """Should create cache with custom options."""
        cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)

        assert cache.max_size == 500