feat(memory): implement caching layer for memory operations (#98)

Add comprehensive caching layer for the Agent Memory System:

- HotMemoryCache: LRU cache for frequently accessed memories
  - Generic over the cached value type (Python 3.12 / PEP 695 type parameter syntax)
  - Thread-safe operations with RLock
  - TTL-based expiration
  - Access count tracking for hot memory identification
  - Scoped invalidation by type, scope, or pattern

- EmbeddingCache: Cache embeddings by content hash
  - Content-hash based deduplication
  - Optional Redis backing for persistence
  - LRU eviction with configurable max size
  - CachedEmbeddingGenerator wrapper for transparent caching

- CacheManager: Unified cache management
  - Coordinates hot cache, embedding cache, and retrieval cache
  - Centralized invalidation across all caches
  - Aggregated statistics and hit rate tracking
  - Automatic cleanup scheduling
  - Cache warmup support

Performance targets:
- Cache hit rate > 80% for hot memories
- Cache operations < 1ms (memory), < 5ms (Redis)
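
Usage sketch (illustrative; assumes cache_enabled in the memory settings):

    from uuid import uuid4
    from app.services.memory.cache import get_cache_manager

    manager = get_cache_manager()
    mem_id = uuid4()
    manager.cache_memory("episodic", mem_id, {"task": "demo"})
    memory = manager.get_memory("episodic", mem_id)  # hot-cache hit
    # Embedding ops are async: await manager.cache_embedding(text, vector)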

83 new tests with comprehensive coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
commit 6954774e36 (parent 30e5c68304)
2026-01-05 04:04:13 +01:00
8 changed files with 2690 additions and 0 deletions

@@ -0,0 +1,21 @@
# app/services/memory/cache/__init__.py
"""
Memory Caching Layer.
Provides caching for memory operations:
- Hot Memory Cache: LRU cache for frequently accessed memories
- Embedding Cache: Cache embeddings by content hash
- Cache Manager: Unified cache management with invalidation
"""
from .cache_manager import CacheManager, CacheStats, get_cache_manager
from .embedding_cache import EmbeddingCache
from .hot_cache import HotMemoryCache
__all__ = [
"CacheManager",
"CacheStats",
"EmbeddingCache",
"HotMemoryCache",
"get_cache_manager",
]

@@ -0,0 +1,500 @@
# app/services/memory/cache/cache_manager.py
"""
Cache Manager.
Unified cache management for memory operations.
Coordinates hot cache, embedding cache, and retrieval cache.
Provides centralized invalidation and statistics.
"""
import logging
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
from app.services.memory.config import get_memory_settings
from .embedding_cache import EmbeddingCache, create_embedding_cache
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache
if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.memory.indexing.retrieval import RetrievalCache
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class CacheStats:
"""Aggregated cache statistics."""
hot_cache: dict[str, Any] = field(default_factory=dict)
embedding_cache: dict[str, Any] = field(default_factory=dict)
retrieval_cache: dict[str, Any] = field(default_factory=dict)
overall_hit_rate: float = 0.0
last_cleanup: datetime | None = None
cleanup_count: int = 0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hot_cache": self.hot_cache,
"embedding_cache": self.embedding_cache,
"retrieval_cache": self.retrieval_cache,
"overall_hit_rate": self.overall_hit_rate,
"last_cleanup": self.last_cleanup.isoformat() if self.last_cleanup else None,
"cleanup_count": self.cleanup_count,
}
class CacheManager:
"""
Unified cache manager for memory operations.
Provides:
- Centralized cache configuration
- Coordinated invalidation across caches
- Aggregated statistics
- Automatic cleanup scheduling
Performance targets:
- Overall cache hit rate > 80%
- Cache operations < 1ms (memory), < 5ms (Redis)
"""
def __init__(
self,
hot_cache: HotMemoryCache[Any] | None = None,
embedding_cache: EmbeddingCache | None = None,
retrieval_cache: "RetrievalCache | None" = None,
redis: "Redis | None" = None,
) -> None:
"""
Initialize the cache manager.
Args:
hot_cache: Optional pre-configured hot cache
embedding_cache: Optional pre-configured embedding cache
retrieval_cache: Optional pre-configured retrieval cache
redis: Optional Redis connection for persistence
"""
self._settings = get_memory_settings()
self._redis = redis
self._enabled = self._settings.cache_enabled
# Initialize caches
if hot_cache:
self._hot_cache = hot_cache
else:
self._hot_cache = create_hot_cache(
max_size=self._settings.cache_max_items,
default_ttl_seconds=self._settings.cache_ttl_seconds,
)
if embedding_cache:
self._embedding_cache = embedding_cache
else:
self._embedding_cache = create_embedding_cache(
max_size=self._settings.cache_max_items,
default_ttl_seconds=self._settings.cache_ttl_seconds * 12,  # 12x base TTL (1 hour at the default 300s)
redis=redis,
)
self._retrieval_cache = retrieval_cache
# Stats tracking
self._last_cleanup: datetime | None = None
self._cleanup_count = 0
self._lock = threading.RLock()
logger.info(
f"Initialized CacheManager: enabled={self._enabled}, "
f"redis={'connected' if redis else 'disabled'}"
)
def set_redis(self, redis: "Redis") -> None:
"""Set Redis connection for all caches."""
self._redis = redis
self._embedding_cache.set_redis(redis)
def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
"""Set retrieval cache instance."""
self._retrieval_cache = cache
@property
def is_enabled(self) -> bool:
"""Check if caching is enabled."""
return self._enabled
@property
def hot_cache(self) -> HotMemoryCache[Any]:
"""Get the hot memory cache."""
return self._hot_cache
@property
def embedding_cache(self) -> EmbeddingCache:
"""Get the embedding cache."""
return self._embedding_cache
@property
def retrieval_cache(self) -> "RetrievalCache | None":
"""Get the retrieval cache."""
return self._retrieval_cache
# =========================================================================
# Hot Memory Cache Operations
# =========================================================================
def get_memory(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> Any | None:
"""
Get a memory from hot cache.
Args:
memory_type: Type of memory
memory_id: Memory ID
scope: Optional scope
Returns:
Cached memory or None
"""
if not self._enabled:
return None
return self._hot_cache.get_by_id(memory_type, memory_id, scope)
def cache_memory(
self,
memory_type: str,
memory_id: UUID | str,
memory: Any,
scope: str | None = None,
ttl_seconds: float | None = None,
) -> None:
"""
Cache a memory in hot cache.
Args:
memory_type: Type of memory
memory_id: Memory ID
memory: Memory object
scope: Optional scope
ttl_seconds: Optional TTL override
"""
if not self._enabled:
return
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
# =========================================================================
# Embedding Cache Operations
# =========================================================================
async def get_embedding(
self,
content: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding.
Args:
content: Content text
model: Model name
Returns:
Cached embedding or None
"""
if not self._enabled:
return None
return await self._embedding_cache.get(content, model)
async def cache_embedding(
self,
content: str,
embedding: list[float],
model: str = "default",
ttl_seconds: float | None = None,
) -> str:
"""
Cache an embedding.
Args:
content: Content text
embedding: Embedding vector
model: Model name
ttl_seconds: Optional TTL override
Returns:
Content hash
"""
if not self._enabled:
return EmbeddingCache.hash_content(content)
return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
# =========================================================================
# Invalidation
# =========================================================================
async def invalidate_memory(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> int:
"""
Invalidate a memory across all caches.
Args:
memory_type: Type of memory
memory_id: Memory ID
scope: Optional scope
Returns:
Number of entries invalidated
"""
count = 0
# Invalidate hot cache
if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
count += 1
# Invalidate retrieval cache
if self._retrieval_cache:
uuid_id = UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
count += self._retrieval_cache.invalidate_by_memory(uuid_id)
logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
return count
async def invalidate_by_type(self, memory_type: str) -> int:
"""
Invalidate all entries of a memory type.
Args:
memory_type: Type of memory
Returns:
Number of entries invalidated
"""
count = self._hot_cache.invalidate_by_type(memory_type)
if self._retrieval_cache:
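# Retrieval cache has no per-type invalidation, so clear it entirely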
count += self._retrieval_cache.clear()
logger.info(f"Invalidated {count} cache entries for type {memory_type}")
return count
async def invalidate_by_scope(self, scope: str) -> int:
"""
Invalidate all entries in a scope.
Args:
scope: Scope to invalidate (e.g., project_id)
Returns:
Number of entries invalidated
"""
count = self._hot_cache.invalidate_by_scope(scope)
# Retrieval cache doesn't support scope-based invalidation
# so we clear it entirely for safety
if self._retrieval_cache:
count += self._retrieval_cache.clear()
logger.info(f"Invalidated {count} cache entries for scope {scope}")
return count
async def invalidate_embedding(
self,
content: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding.
Args:
content: Content text
model: Model name
Returns:
True if entry was found and removed
"""
return await self._embedding_cache.invalidate(content, model)
async def clear_all(self) -> int:
"""
Clear all caches.
Returns:
Total number of entries cleared
"""
count = 0
count += self._hot_cache.clear()
count += await self._embedding_cache.clear()
if self._retrieval_cache:
count += self._retrieval_cache.clear()
logger.info(f"Cleared {count} entries from all caches")
return count
# =========================================================================
# Cleanup
# =========================================================================
async def cleanup_expired(self) -> int:
"""
Clean up expired entries from all caches.
Returns:
Number of entries cleaned up
"""
with self._lock:
count = 0
count += self._hot_cache.cleanup_expired()
count += self._embedding_cache.cleanup_expired()
# Retrieval cache doesn't have a cleanup method,
# but entries expire on access
self._last_cleanup = _utcnow()
self._cleanup_count += 1
if count > 0:
logger.info(f"Cleaned up {count} expired cache entries")
return count
# =========================================================================
# Statistics
# =========================================================================
def get_stats(self) -> CacheStats:
"""
Get aggregated cache statistics.
Returns:
CacheStats with all cache metrics
"""
hot_stats = self._hot_cache.get_stats().to_dict()
emb_stats = self._embedding_cache.get_stats().to_dict()
retrieval_stats: dict[str, Any] = {}
if self._retrieval_cache:
retrieval_stats = self._retrieval_cache.get_stats()
# Calculate overall hit rate
total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
# Retrieval cache doesn't track hits/misses the same way,
# so it is excluded from the overall hit rate
total_requests = total_hits + total_misses
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
return CacheStats(
hot_cache=hot_stats,
embedding_cache=emb_stats,
retrieval_cache=retrieval_stats,
overall_hit_rate=overall_hit_rate,
last_cleanup=self._last_cleanup,
cleanup_count=self._cleanup_count,
)
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
"""
Get the most frequently accessed memories.
Args:
limit: Maximum number to return
Returns:
List of (key, access_count) tuples
"""
return self._hot_cache.get_hot_memories(limit)
def reset_stats(self) -> None:
"""Reset all cache statistics."""
self._hot_cache.reset_stats()
self._embedding_cache.reset_stats()
# =========================================================================
# Warmup
# =========================================================================
async def warmup(
self,
memories: list[tuple[str, UUID | str, Any]],
scope: str | None = None,
) -> int:
"""
Warm up the hot cache with memories.
Args:
memories: List of (memory_type, memory_id, memory) tuples
scope: Optional scope for all memories
Returns:
Number of memories cached
"""
if not self._enabled:
return 0
for memory_type, memory_id, memory in memories:
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)
logger.info(f"Warmed up cache with {len(memories)} memories")
return len(memories)
# Singleton instance
_cache_manager: CacheManager | None = None
_cache_manager_lock = threading.Lock()
def get_cache_manager(
redis: "Redis | None" = None,
reset: bool = False,
) -> CacheManager:
"""
Get the global CacheManager instance.
Thread-safe with double-checked locking pattern.
Args:
redis: Optional Redis connection
reset: Force create a new instance
Returns:
CacheManager instance
"""
global _cache_manager
if reset or _cache_manager is None:
with _cache_manager_lock:
if reset or _cache_manager is None:
_cache_manager = CacheManager(redis=redis)
return _cache_manager
def reset_cache_manager() -> None:
"""Reset the global cache manager instance."""
global _cache_manager
with _cache_manager_lock:
_cache_manager = None
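
# --- Illustrative sketch, not part of the committed file ---
# Demonstrates the singleton contract of get_cache_manager(); assumes
# memory settings are configured, since CacheManager reads them on init.
if __name__ == "__main__":
    m1 = get_cache_manager()
    m2 = get_cache_manager()
    assert m1 is m2  # same instance via double-checked locking
    reset_cache_manager()
    assert get_cache_manager() is not m1  # fresh instance after reset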

@@ -0,0 +1,627 @@
# app/services/memory/cache/embedding_cache.py
"""
Embedding Cache.
Caches embeddings by content hash to avoid recomputing.
Provides significant performance improvement for repeated content.
"""
import hashlib
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from redis.asyncio import Redis
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class EmbeddingEntry:
"""A cached embedding entry."""
embedding: list[float]
content_hash: str
model: str
created_at: datetime
ttl_seconds: float = 3600.0 # 1 hour default
def is_expired(self) -> bool:
"""Check if this entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
@dataclass
class EmbeddingCacheStats:
"""Statistics for the embedding cache."""
hits: int = 0
misses: int = 0
evictions: int = 0
expirations: int = 0
current_size: int = 0
max_size: int = 0
bytes_saved: int = 0 # Estimated bytes saved by caching
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
total = self.hits + self.misses
if total == 0:
return 0.0
return self.hits / total
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hits": self.hits,
"misses": self.misses,
"evictions": self.evictions,
"expirations": self.expirations,
"current_size": self.current_size,
"max_size": self.max_size,
"hit_rate": self.hit_rate,
"bytes_saved": self.bytes_saved,
}
class EmbeddingCache:
"""
Cache for embeddings by content hash.
Features:
- Content-hash based deduplication
- LRU eviction
- TTL-based expiration
- Optional Redis backing for persistence
- Thread-safe operations
Performance targets:
- Cache hit rate > 90% for repeated content
- Get/put operations < 1ms (memory), < 5ms (Redis)
"""
def __init__(
self,
max_size: int = 50000,
default_ttl_seconds: float = 3600.0,
redis: "Redis | None" = None,
redis_prefix: str = "mem:emb",
) -> None:
"""
Initialize the embedding cache.
Args:
max_size: Maximum number of entries in memory cache
default_ttl_seconds: Default TTL for entries (1 hour)
redis: Optional Redis connection for persistence
redis_prefix: Prefix for Redis keys
"""
self._max_size = max_size
self._default_ttl = default_ttl_seconds
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
self._lock = threading.RLock()
self._stats = EmbeddingCacheStats(max_size=max_size)
self._redis = redis
self._redis_prefix = redis_prefix
logger.info(
f"Initialized EmbeddingCache with max_size={max_size}, "
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
)
def set_redis(self, redis: "Redis") -> None:
"""Set Redis connection for persistence."""
self._redis = redis
@staticmethod
def hash_content(content: str) -> str:
"""
Compute hash of content for cache key.
Args:
content: Content to hash
Returns:
32-character hex hash
"""
return hashlib.sha256(content.encode()).hexdigest()[:32]
def _cache_key(self, content_hash: str, model: str) -> str:
"""Build cache key from content hash and model."""
return f"{content_hash}:{model}"
def _redis_key(self, content_hash: str, model: str) -> str:
"""Build Redis key from content hash and model."""
return f"{self._redis_prefix}:{content_hash}:{model}"
async def get(
self,
content: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding.
Args:
content: Content text
model: Model name
Returns:
Cached embedding or None if not found/expired
"""
content_hash = self.hash_content(content)
cache_key = self._cache_key(content_hash, model)
# Check memory cache first
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
if entry.is_expired():
del self._cache[cache_key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
else:
# Move to end (most recently used)
self._cache.move_to_end(cache_key)
self._stats.hits += 1
return entry.embedding
# Check Redis if available
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
data = await self._redis.get(redis_key)
if data:
import json
embedding = json.loads(data)
# Store in memory cache for faster access
self._put_memory(content_hash, model, embedding)
self._stats.hits += 1
return embedding
except Exception as e:
logger.warning(f"Redis get error: {e}")
self._stats.misses += 1
return None
async def get_by_hash(
self,
content_hash: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding by hash.
Args:
content_hash: Content hash
model: Model name
Returns:
Cached embedding or None if not found/expired
"""
cache_key = self._cache_key(content_hash, model)
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
if entry.is_expired():
del self._cache[cache_key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
else:
self._cache.move_to_end(cache_key)
self._stats.hits += 1
return entry.embedding
# Check Redis
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
data = await self._redis.get(redis_key)
if data:
import json
embedding = json.loads(data)
self._put_memory(content_hash, model, embedding)
self._stats.hits += 1
return embedding
except Exception as e:
logger.warning(f"Redis get error: {e}")
self._stats.misses += 1
return None
async def put(
self,
content: str,
embedding: list[float],
model: str = "default",
ttl_seconds: float | None = None,
) -> str:
"""
Cache an embedding.
Args:
content: Content text
embedding: Embedding vector
model: Model name
ttl_seconds: Optional TTL override
Returns:
Content hash
"""
content_hash = self.hash_content(content)
ttl = ttl_seconds or self._default_ttl
# Store in memory
self._put_memory(content_hash, model, embedding, ttl)
# Store in Redis if available
if self._redis:
try:
import json
redis_key = self._redis_key(content_hash, model)
await self._redis.setex(
redis_key,
int(ttl),
json.dumps(embedding),
)
except Exception as e:
logger.warning(f"Redis put error: {e}")
return content_hash
def _put_memory(
self,
content_hash: str,
model: str,
embedding: list[float],
ttl_seconds: float | None = None,
) -> None:
"""Store in memory cache."""
with self._lock:
# Evict if at capacity
self._evict_if_needed()
cache_key = self._cache_key(content_hash, model)
entry = EmbeddingEntry(
embedding=embedding,
content_hash=content_hash,
model=model,
created_at=_utcnow(),
ttl_seconds=ttl_seconds or self._default_ttl,
)
self._cache[cache_key] = entry
self._cache.move_to_end(cache_key)
self._stats.current_size = len(self._cache)
def _evict_if_needed(self) -> None:
"""Evict entries if cache is at capacity."""
# Guard against an empty cache so max_size <= 0 cannot loop forever
while self._cache and len(self._cache) >= self._max_size:
self._cache.popitem(last=False)
self._stats.evictions += 1
async def put_batch(
self,
items: list[tuple[str, list[float]]],
model: str = "default",
ttl_seconds: float | None = None,
) -> list[str]:
"""
Cache multiple embeddings.
Args:
items: List of (content, embedding) tuples
model: Model name
ttl_seconds: Optional TTL override
Returns:
List of content hashes
"""
hashes = []
for content, embedding in items:
content_hash = await self.put(content, embedding, model, ttl_seconds)
hashes.append(content_hash)
return hashes
async def invalidate(
self,
content: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding.
Args:
content: Content text
model: Model name
Returns:
True if entry was found and removed
"""
content_hash = self.hash_content(content)
return await self.invalidate_by_hash(content_hash, model)
async def invalidate_by_hash(
self,
content_hash: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding by hash.
Args:
content_hash: Content hash
model: Model name
Returns:
True if entry was found and removed
"""
cache_key = self._cache_key(content_hash, model)
removed = False
with self._lock:
if cache_key in self._cache:
del self._cache[cache_key]
self._stats.current_size = len(self._cache)
removed = True
# Remove from Redis
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
await self._redis.delete(redis_key)
removed = True
except Exception as e:
logger.warning(f"Redis delete error: {e}")
return removed
async def invalidate_by_model(self, model: str) -> int:
"""
Invalidate all embeddings for a model.
Args:
model: Model name
Returns:
Number of entries invalidated
"""
count = 0
with self._lock:
keys_to_remove = [
k for k, v in self._cache.items() if v.model == model
]
for key in keys_to_remove:
del self._cache[key]
count += 1
self._stats.current_size = len(self._cache)
# Note: Redis pattern deletion would require SCAN which is expensive
# For now, we only clear memory cache for model-based invalidation
return count
async def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
with self._lock:
count = len(self._cache)
self._cache.clear()
self._stats.current_size = 0
# Clear Redis entries
if self._redis:
try:
pattern = f"{self._redis_prefix}:*"
deleted = 0
async for key in self._redis.scan_iter(match=pattern):
await self._redis.delete(key)
deleted += 1
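# memory entries are typically mirrored in Redis, so take the max
# rather than the sum to avoid double-counting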
count = max(count, deleted)
except Exception as e:
logger.warning(f"Redis clear error: {e}")
logger.info(f"Cleared {count} entries from embedding cache")
return count
def cleanup_expired(self) -> int:
"""
Remove all expired entries from memory cache.
Returns:
Number of entries removed
"""
with self._lock:
keys_to_remove = [
k for k, v in self._cache.items() if v.is_expired()
]
for key in keys_to_remove:
del self._cache[key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
if keys_to_remove:
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
return len(keys_to_remove)
def get_stats(self) -> EmbeddingCacheStats:
"""Get cache statistics."""
with self._lock:
self._stats.current_size = len(self._cache)
return self._stats
def reset_stats(self) -> None:
"""Reset cache statistics."""
with self._lock:
self._stats = EmbeddingCacheStats(
max_size=self._max_size,
current_size=len(self._cache),
)
@property
def size(self) -> int:
"""Get current cache size."""
return len(self._cache)
@property
def max_size(self) -> int:
"""Get maximum cache size."""
return self._max_size
class CachedEmbeddingGenerator:
"""
Wrapper for embedding generators with caching.
Wraps an embedding generator to cache results.
"""
def __init__(
self,
generator: Any,
cache: EmbeddingCache,
model: str = "default",
) -> None:
"""
Initialize the cached embedding generator.
Args:
generator: Underlying embedding generator
cache: Embedding cache
model: Model name for cache keys
"""
self._generator = generator
self._cache = cache
self._model = model
self._call_count = 0
self._cache_hit_count = 0
async def generate(self, text: str) -> list[float]:
"""
Generate embedding with caching.
Args:
text: Text to embed
Returns:
Embedding vector
"""
self._call_count += 1
# Check cache first
cached = await self._cache.get(text, self._model)
if cached is not None:
self._cache_hit_count += 1
return cached
# Generate and cache
embedding = await self._generator.generate(text)
await self._cache.put(text, embedding, self._model)
return embedding
async def generate_batch(
self,
texts: list[str],
) -> list[list[float]]:
"""
Generate embeddings for multiple texts with caching.
Args:
texts: Texts to embed
Returns:
List of embedding vectors
"""
results: list[list[float] | None] = [None] * len(texts)
to_generate: list[tuple[int, str]] = []
# Check cache for each text
for i, text in enumerate(texts):
cached = await self._cache.get(text, self._model)
if cached is not None:
results[i] = cached
self._cache_hit_count += 1
else:
to_generate.append((i, text))
self._call_count += len(texts)
# Generate missing embeddings
if to_generate:
if hasattr(self._generator, "generate_batch"):
texts_to_gen = [t for _, t in to_generate]
embeddings = await self._generator.generate_batch(texts_to_gen)
for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
results[idx] = embedding
await self._cache.put(text, embedding, self._model)
else:
# Fallback to individual generation
for idx, text in to_generate:
embedding = await self._generator.generate(text)
results[idx] = embedding
await self._cache.put(text, embedding, self._model)
return results # type: ignore[return-value]
def get_stats(self) -> dict[str, Any]:
"""Get generator statistics."""
return {
"call_count": self._call_count,
"cache_hit_count": self._cache_hit_count,
"cache_hit_rate": (
self._cache_hit_count / self._call_count
if self._call_count > 0
else 0.0
),
"cache_stats": self._cache.get_stats().to_dict(),
}
# Factory function
def create_embedding_cache(
max_size: int = 50000,
default_ttl_seconds: float = 3600.0,
redis: "Redis | None" = None,
) -> EmbeddingCache:
"""
Create an embedding cache.
Args:
max_size: Maximum number of entries
default_ttl_seconds: Default TTL for entries
redis: Optional Redis connection
Returns:
Configured EmbeddingCache instance
"""
return EmbeddingCache(
max_size=max_size,
default_ttl_seconds=default_ttl_seconds,
redis=redis,
)
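
# --- Illustrative sketch, not part of the committed file ---
# Wrapping an embedding generator with CachedEmbeddingGenerator; the
# DummyGenerator below is hypothetical -- any object exposing an async
# generate(text) method works.
import asyncio

class DummyGenerator:
    async def generate(self, text: str) -> list[float]:
        return [float(len(text))]  # stand-in embedding

async def _demo() -> None:
    gen = CachedEmbeddingGenerator(DummyGenerator(), create_embedding_cache())
    await gen.generate("hello")  # miss: computed and cached
    await gen.generate("hello")  # hit: served from the cache
    print(gen.get_stats()["cache_hit_rate"])  # 0.5

if __name__ == "__main__":
    asyncio.run(_demo())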

@@ -0,0 +1,463 @@
# app/services/memory/cache/hot_cache.py
"""
Hot Memory Cache.
LRU cache for frequently accessed memories.
Provides fast access to recently used memories without database queries.
"""
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class CacheEntry[T]:
"""A cached memory entry with metadata."""
value: T
created_at: datetime
last_accessed_at: datetime
access_count: int = 1
ttl_seconds: float = 300.0
def is_expired(self) -> bool:
"""Check if this entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
def touch(self) -> None:
"""Update access time and count."""
self.last_accessed_at = _utcnow()
self.access_count += 1
@dataclass
class CacheKey:
"""A structured cache key with components."""
memory_type: str
memory_id: str
scope: str | None = None
def __hash__(self) -> int:
return hash((self.memory_type, self.memory_id, self.scope))
def __eq__(self, other: object) -> bool:
if not isinstance(other, CacheKey):
return False
return (
self.memory_type == other.memory_type
and self.memory_id == other.memory_id
and self.scope == other.scope
)
def __str__(self) -> str:
if self.scope:
return f"{self.memory_type}:{self.scope}:{self.memory_id}"
return f"{self.memory_type}:{self.memory_id}"
@dataclass
class HotCacheStats:
"""Statistics for the hot memory cache."""
hits: int = 0
misses: int = 0
evictions: int = 0
expirations: int = 0
current_size: int = 0
max_size: int = 0
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
total = self.hits + self.misses
if total == 0:
return 0.0
return self.hits / total
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hits": self.hits,
"misses": self.misses,
"evictions": self.evictions,
"expirations": self.expirations,
"current_size": self.current_size,
"max_size": self.max_size,
"hit_rate": self.hit_rate,
}
class HotMemoryCache[T]:
"""
LRU cache for frequently accessed memories.
Features:
- LRU eviction when capacity is reached
- TTL-based expiration
- Access count tracking for hot memory identification
- Thread-safe operations
- Scoped invalidation
Performance targets:
- Cache hit rate > 80% for hot memories
- Get/put operations < 1ms
"""
def __init__(
self,
max_size: int = 10000,
default_ttl_seconds: float = 300.0,
) -> None:
"""
Initialize the hot memory cache.
Args:
max_size: Maximum number of entries
default_ttl_seconds: Default TTL for entries (5 minutes)
"""
self._max_size = max_size
self._default_ttl = default_ttl_seconds
self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
self._lock = threading.RLock()
self._stats = HotCacheStats(max_size=max_size)
logger.info(
f"Initialized HotMemoryCache with max_size={max_size}, "
f"ttl={default_ttl_seconds}s"
)
def get(self, key: CacheKey) -> T | None:
"""
Get a memory from cache.
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
with self._lock:
if key not in self._cache:
self._stats.misses += 1
return None
entry = self._cache[key]
# Check expiration
if entry.is_expired():
del self._cache[key]
self._stats.expirations += 1
self._stats.misses += 1
self._stats.current_size = len(self._cache)
return None
# Move to end (most recently used)
self._cache.move_to_end(key)
entry.touch()
self._stats.hits += 1
return entry.value
def get_by_id(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> T | None:
"""
Get a memory by type and ID.
Args:
memory_type: Type of memory (episodic, semantic, procedural)
memory_id: Memory ID
scope: Optional scope (project_id, agent_id)
Returns:
Cached value or None if not found/expired
"""
key = CacheKey(
memory_type=memory_type,
memory_id=str(memory_id),
scope=scope,
)
return self.get(key)
def put(
self,
key: CacheKey,
value: T,
ttl_seconds: float | None = None,
) -> None:
"""
Put a memory into cache.
Args:
key: Cache key
value: Value to cache
ttl_seconds: Optional TTL override
"""
with self._lock:
# Evict if at capacity
self._evict_if_needed()
now = _utcnow()
entry = CacheEntry(
value=value,
created_at=now,
last_accessed_at=now,
access_count=1,
ttl_seconds=ttl_seconds or self._default_ttl,
)
self._cache[key] = entry
self._cache.move_to_end(key)
self._stats.current_size = len(self._cache)
def put_by_id(
self,
memory_type: str,
memory_id: UUID | str,
value: T,
scope: str | None = None,
ttl_seconds: float | None = None,
) -> None:
"""
Put a memory by type and ID.
Args:
memory_type: Type of memory
memory_id: Memory ID
value: Value to cache
scope: Optional scope
ttl_seconds: Optional TTL override
"""
key = CacheKey(
memory_type=memory_type,
memory_id=str(memory_id),
scope=scope,
)
self.put(key, value, ttl_seconds)
def _evict_if_needed(self) -> None:
"""Evict entries if cache is at capacity."""
# Remove least recently used (first item); guard against an empty
# cache so max_size <= 0 cannot loop forever
while self._cache and len(self._cache) >= self._max_size:
self._cache.popitem(last=False)
self._stats.evictions += 1
def invalidate(self, key: CacheKey) -> bool:
"""
Invalidate a specific cache entry.
Args:
key: Cache key to invalidate
Returns:
True if entry was found and removed
"""
with self._lock:
if key in self._cache:
del self._cache[key]
self._stats.current_size = len(self._cache)
return True
return False
def invalidate_by_id(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> bool:
"""
Invalidate a memory by type and ID.
Args:
memory_type: Type of memory
memory_id: Memory ID
scope: Optional scope
Returns:
True if entry was found and removed
"""
key = CacheKey(
memory_type=memory_type,
memory_id=str(memory_id),
scope=scope,
)
return self.invalidate(key)
def invalidate_by_type(self, memory_type: str) -> int:
"""
Invalidate all entries of a memory type.
Args:
memory_type: Type of memory to invalidate
Returns:
Number of entries invalidated
"""
with self._lock:
keys_to_remove = [
k for k in self._cache.keys() if k.memory_type == memory_type
]
for key in keys_to_remove:
del self._cache[key]
self._stats.current_size = len(self._cache)
return len(keys_to_remove)
def invalidate_by_scope(self, scope: str) -> int:
"""
Invalidate all entries in a scope.
Args:
scope: Scope to invalidate (e.g., project_id)
Returns:
Number of entries invalidated
"""
with self._lock:
keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
for key in keys_to_remove:
del self._cache[key]
self._stats.current_size = len(self._cache)
return len(keys_to_remove)
def invalidate_pattern(self, pattern: str) -> int:
"""
Invalidate entries matching a pattern.
Pattern can include * as wildcard.
Args:
pattern: Pattern to match (e.g., "episodic:*")
Returns:
Number of entries invalidated
"""
import fnmatch
with self._lock:
keys_to_remove = [
k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
]
for key in keys_to_remove:
del self._cache[key]
self._stats.current_size = len(self._cache)
return len(keys_to_remove)
def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
with self._lock:
count = len(self._cache)
self._cache.clear()
self._stats.current_size = 0
logger.info(f"Cleared {count} entries from hot cache")
return count
def cleanup_expired(self) -> int:
"""
Remove all expired entries.
Returns:
Number of entries removed
"""
with self._lock:
keys_to_remove = [
k for k, v in self._cache.items() if v.is_expired()
]
for key in keys_to_remove:
del self._cache[key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
if keys_to_remove:
logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")
return len(keys_to_remove)
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
"""
Get the most frequently accessed memories.
Args:
limit: Maximum number of memories to return
Returns:
List of (key, access_count) tuples sorted by access count
"""
with self._lock:
entries = [
(k, v.access_count)
for k, v in self._cache.items()
if not v.is_expired()
]
entries.sort(key=lambda x: x[1], reverse=True)
return entries[:limit]
def get_stats(self) -> HotCacheStats:
"""Get cache statistics."""
with self._lock:
self._stats.current_size = len(self._cache)
return self._stats
def reset_stats(self) -> None:
"""Reset cache statistics."""
with self._lock:
self._stats = HotCacheStats(
max_size=self._max_size,
current_size=len(self._cache),
)
@property
def size(self) -> int:
"""Get current cache size."""
return len(self._cache)
@property
def max_size(self) -> int:
"""Get maximum cache size."""
return self._max_size
# Factory function for typed caches
def create_hot_cache(
max_size: int = 10000,
default_ttl_seconds: float = 300.0,
) -> HotMemoryCache[Any]:
"""
Create a hot memory cache.
Args:
max_size: Maximum number of entries
default_ttl_seconds: Default TTL for entries
Returns:
Configured HotMemoryCache instance
"""
return HotMemoryCache(
max_size=max_size,
default_ttl_seconds=default_ttl_seconds,
)
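
# --- Illustrative sketch, not part of the committed file ---
# LRU eviction and scoped invalidation with a tiny cache.
if __name__ == "__main__":
    cache = create_hot_cache(max_size=2, default_ttl_seconds=300.0)
    cache.put_by_id("episodic", "1", {"n": 1}, scope="proj-1")
    cache.put_by_id("episodic", "2", {"n": 2}, scope="proj-1")
    cache.get_by_id("episodic", "1", scope="proj-1")  # touch "1"; "2" is now LRU
    cache.put_by_id("episodic", "3", {"n": 3})        # at capacity: evicts "2"
    assert cache.get_by_id("episodic", "2", scope="proj-1") is None
    print(cache.invalidate_by_scope("proj-1"))        # 1 -- only "1" remains in scope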

@@ -0,0 +1,2 @@
# tests/unit/services/memory/cache/__init__.py
"""Tests for memory caching layer."""

@@ -0,0 +1,331 @@
# tests/unit/services/memory/cache/test_cache_manager.py
"""Tests for CacheManager."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from app.services.memory.cache.cache_manager import (
CacheManager,
CacheStats,
get_cache_manager,
reset_cache_manager,
)
from app.services.memory.cache.embedding_cache import EmbeddingCache
from app.services.memory.cache.hot_cache import HotMemoryCache
pytestmark = pytest.mark.asyncio(loop_scope="function")
@pytest.fixture(autouse=True)
def reset_singleton() -> None:
"""Reset singleton before each test."""
reset_cache_manager()
class TestCacheStats:
"""Tests for CacheStats."""
def test_to_dict(self) -> None:
"""Should convert to dictionary."""
from datetime import UTC, datetime
stats = CacheStats(
hot_cache={"hits": 10},
embedding_cache={"hits": 20},
overall_hit_rate=0.75,
last_cleanup=datetime.now(UTC),
cleanup_count=5,
)
result = stats.to_dict()
assert result["hot_cache"] == {"hits": 10}
assert result["overall_hit_rate"] == 0.75
assert result["cleanup_count"] == 5
assert result["last_cleanup"] is not None
class TestCacheManager:
"""Tests for CacheManager."""
@pytest.fixture
def manager(self) -> CacheManager:
"""Create a cache manager."""
return CacheManager()
def test_is_enabled(self, manager: CacheManager) -> None:
"""Should check if caching is enabled."""
# Default is enabled from settings
assert manager.is_enabled is True
def test_has_hot_cache(self, manager: CacheManager) -> None:
"""Should have hot memory cache."""
assert manager.hot_cache is not None
assert isinstance(manager.hot_cache, HotMemoryCache)
def test_has_embedding_cache(self, manager: CacheManager) -> None:
"""Should have embedding cache."""
assert manager.embedding_cache is not None
assert isinstance(manager.embedding_cache, EmbeddingCache)
def test_cache_memory(self, manager: CacheManager) -> None:
"""Should cache memory in hot cache."""
memory_id = uuid4()
memory = {"task": "test", "data": "value"}
manager.cache_memory("episodic", memory_id, memory)
result = manager.get_memory("episodic", memory_id)
assert result == memory
def test_cache_memory_with_scope(self, manager: CacheManager) -> None:
"""Should cache memory with scope."""
memory_id = uuid4()
memory = {"task": "test"}
manager.cache_memory("semantic", memory_id, memory, scope="proj-123")
result = manager.get_memory("semantic", memory_id, scope="proj-123")
assert result == memory
async def test_cache_embedding(self, manager: CacheManager) -> None:
"""Should cache embedding."""
content = "test content"
embedding = [0.1, 0.2, 0.3]
content_hash = await manager.cache_embedding(content, embedding)
result = await manager.get_embedding(content)
assert result == embedding
assert len(content_hash) == 32
async def test_invalidate_memory(self, manager: CacheManager) -> None:
"""Should invalidate memory from hot cache."""
memory_id = uuid4()
manager.cache_memory("episodic", memory_id, {"data": "test"})
count = await manager.invalidate_memory("episodic", memory_id)
assert count >= 1
assert manager.get_memory("episodic", memory_id) is None
async def test_invalidate_by_type(self, manager: CacheManager) -> None:
"""Should invalidate all entries of a type."""
manager.cache_memory("episodic", uuid4(), {"data": "1"})
manager.cache_memory("episodic", uuid4(), {"data": "2"})
manager.cache_memory("semantic", uuid4(), {"data": "3"})
count = await manager.invalidate_by_type("episodic")
assert count >= 2
async def test_invalidate_by_scope(self, manager: CacheManager) -> None:
"""Should invalidate all entries in a scope."""
manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1")
manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1")
manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2")
count = await manager.invalidate_by_scope("proj-1")
assert count >= 2
async def test_invalidate_embedding(self, manager: CacheManager) -> None:
"""Should invalidate cached embedding."""
content = "test content"
await manager.cache_embedding(content, [0.1, 0.2])
result = await manager.invalidate_embedding(content)
assert result is True
assert await manager.get_embedding(content) is None
async def test_clear_all(self, manager: CacheManager) -> None:
"""Should clear all caches."""
manager.cache_memory("episodic", uuid4(), {"data": "test"})
await manager.cache_embedding("content", [0.1])
count = await manager.clear_all()
assert count >= 2
async def test_cleanup_expired(self, manager: CacheManager) -> None:
"""Should clean up expired entries."""
count = await manager.cleanup_expired()
# May be 0 if no expired entries
assert count >= 0
assert manager._cleanup_count == 1
assert manager._last_cleanup is not None
def test_get_stats(self, manager: CacheManager) -> None:
"""Should return aggregated statistics."""
manager.cache_memory("episodic", uuid4(), {"data": "test"})
stats = manager.get_stats()
assert "hot_cache" in stats.to_dict()
assert "embedding_cache" in stats.to_dict()
assert "overall_hit_rate" in stats.to_dict()
def test_get_hot_memories(self, manager: CacheManager) -> None:
"""Should return most accessed memories."""
id1 = uuid4()
id2 = uuid4()
manager.cache_memory("episodic", id1, {"data": "1"})
manager.cache_memory("episodic", id2, {"data": "2"})
# Access first multiple times
for _ in range(5):
manager.get_memory("episodic", id1)
hot = manager.get_hot_memories(limit=2)
assert len(hot) == 2
def test_reset_stats(self, manager: CacheManager) -> None:
"""Should reset all statistics."""
manager.cache_memory("episodic", uuid4(), {"data": "test"})
manager.get_memory("episodic", uuid4()) # Miss
manager.reset_stats()
stats = manager.get_stats()
assert stats.hot_cache.get("hits", 0) == 0
async def test_warmup(self, manager: CacheManager) -> None:
"""Should warm up cache with memories."""
memories = [
("episodic", uuid4(), {"data": "1"}),
("episodic", uuid4(), {"data": "2"}),
("semantic", uuid4(), {"data": "3"}),
]
count = await manager.warmup(memories)
assert count == 3
class TestCacheManagerWithRetrieval:
"""Tests for CacheManager with retrieval cache."""
@pytest.fixture
def mock_retrieval_cache(self) -> MagicMock:
"""Create mock retrieval cache."""
cache = MagicMock()
cache.invalidate_by_memory = MagicMock(return_value=1)
cache.clear = MagicMock(return_value=5)
cache.get_stats = MagicMock(return_value={"entries": 10})
return cache
@pytest.fixture
def manager_with_retrieval(
self,
mock_retrieval_cache: MagicMock,
) -> CacheManager:
"""Create manager with retrieval cache."""
manager = CacheManager()
manager.set_retrieval_cache(mock_retrieval_cache)
return manager
async def test_invalidate_clears_retrieval(
self,
manager_with_retrieval: CacheManager,
mock_retrieval_cache: MagicMock,
) -> None:
"""Should invalidate retrieval cache entries."""
memory_id = uuid4()
await manager_with_retrieval.invalidate_memory("episodic", memory_id)
mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(memory_id)
def test_stats_includes_retrieval(
self,
manager_with_retrieval: CacheManager,
) -> None:
"""Should include retrieval cache stats."""
stats = manager_with_retrieval.get_stats()
assert "retrieval_cache" in stats.to_dict()
class TestCacheManagerDisabled:
"""Tests for CacheManager when disabled."""
@pytest.fixture
def disabled_manager(self) -> CacheManager:
"""Create a disabled cache manager."""
with patch(
"app.services.memory.cache.cache_manager.get_memory_settings"
) as mock_settings:
settings = MagicMock()
settings.cache_enabled = False
settings.cache_max_items = 1000
settings.cache_ttl_seconds = 300
mock_settings.return_value = settings
return CacheManager()
def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None:
"""Should return None when disabled."""
disabled_manager.cache_memory("episodic", uuid4(), {"data": "test"})
result = disabled_manager.get_memory("episodic", uuid4())
assert result is None
async def test_get_embedding_returns_none(
self,
disabled_manager: CacheManager,
) -> None:
"""Should return None for embeddings when disabled."""
result = await disabled_manager.get_embedding("content")
assert result is None
async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None:
"""Should return 0 from warmup when disabled."""
count = await disabled_manager.warmup([("episodic", uuid4(), {})])
assert count == 0
class TestGetCacheManager:
"""Tests for get_cache_manager factory."""
def test_returns_singleton(self) -> None:
"""Should return same instance."""
manager1 = get_cache_manager()
manager2 = get_cache_manager()
assert manager1 is manager2
def test_reset_creates_new(self) -> None:
"""Should create new instance after reset."""
manager1 = get_cache_manager()
reset_cache_manager()
manager2 = get_cache_manager()
assert manager1 is not manager2
def test_reset_parameter(self) -> None:
"""Should create new instance with reset=True."""
manager1 = get_cache_manager()
manager2 = get_cache_manager(reset=True)
assert manager1 is not manager2
class TestResetCacheManager:
"""Tests for reset_cache_manager."""
def test_resets_singleton(self) -> None:
"""Should reset the singleton."""
get_cache_manager()
reset_cache_manager()
# Next call should create new instance
manager = get_cache_manager()
assert manager is not None

@@ -0,0 +1,391 @@
# tests/unit/services/memory/cache/test_embedding_cache.py
"""Tests for EmbeddingCache."""
import time
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.memory.cache.embedding_cache import (
CachedEmbeddingGenerator,
EmbeddingCache,
EmbeddingCacheStats,
EmbeddingEntry,
create_embedding_cache,
)
pytestmark = pytest.mark.asyncio(loop_scope="function")
class TestEmbeddingEntry:
"""Tests for EmbeddingEntry."""
def test_creates_entry(self) -> None:
"""Should create entry with embedding."""
from datetime import UTC, datetime
entry = EmbeddingEntry(
embedding=[0.1, 0.2, 0.3],
content_hash="abc123",
model="text-embedding-3-small",
created_at=datetime.now(UTC),
)
assert entry.embedding == [0.1, 0.2, 0.3]
assert entry.content_hash == "abc123"
assert entry.ttl_seconds == 3600.0
def test_is_expired(self) -> None:
"""Should detect expired entries."""
from datetime import UTC, datetime, timedelta
old_time = datetime.now(UTC) - timedelta(seconds=4000)
entry = EmbeddingEntry(
embedding=[0.1],
content_hash="abc",
model="default",
created_at=old_time,
ttl_seconds=3600.0,
)
assert entry.is_expired() is True
def test_not_expired(self) -> None:
"""Should detect non-expired entries."""
from datetime import UTC, datetime
entry = EmbeddingEntry(
embedding=[0.1],
content_hash="abc",
model="default",
created_at=datetime.now(UTC),
)
assert entry.is_expired() is False
class TestEmbeddingCacheStats:
"""Tests for EmbeddingCacheStats."""
def test_hit_rate_calculation(self) -> None:
"""Should calculate hit rate correctly."""
stats = EmbeddingCacheStats(hits=90, misses=10)
assert stats.hit_rate == 0.9
def test_hit_rate_zero_requests(self) -> None:
"""Should return 0 for no requests."""
stats = EmbeddingCacheStats()
assert stats.hit_rate == 0.0
def test_to_dict(self) -> None:
"""Should convert to dictionary."""
stats = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000)
result = stats.to_dict()
assert result["hits"] == 10
assert result["bytes_saved"] == 1000
class TestEmbeddingCache:
"""Tests for EmbeddingCache."""
@pytest.fixture
def cache(self) -> EmbeddingCache:
"""Create an embedding cache."""
return EmbeddingCache(max_size=100, default_ttl_seconds=300.0)
async def test_put_and_get(self, cache: EmbeddingCache) -> None:
"""Should store and retrieve embeddings."""
content = "Hello world"
embedding = [0.1, 0.2, 0.3, 0.4]
content_hash = await cache.put(content, embedding)
result = await cache.get(content)
assert result == embedding
assert len(content_hash) == 32
async def test_get_missing(self, cache: EmbeddingCache) -> None:
"""Should return None for missing content."""
result = await cache.get("nonexistent content")
assert result is None
async def test_get_by_hash(self, cache: EmbeddingCache) -> None:
"""Should get by content hash."""
content = "Test content"
embedding = [0.1, 0.2]
content_hash = await cache.put(content, embedding)
result = await cache.get_by_hash(content_hash)
assert result == embedding
async def test_model_separation(self, cache: EmbeddingCache) -> None:
"""Should separate embeddings by model."""
content = "Same content"
emb1 = [0.1, 0.2]
emb2 = [0.3, 0.4]
await cache.put(content, emb1, model="model-a")
await cache.put(content, emb2, model="model-b")
result1 = await cache.get(content, model="model-a")
result2 = await cache.get(content, model="model-b")
assert result1 == emb1
assert result2 == emb2
async def test_lru_eviction(self) -> None:
"""Should evict LRU entries when at capacity."""
cache = EmbeddingCache(max_size=3)
await cache.put("content1", [0.1])
await cache.put("content2", [0.2])
await cache.put("content3", [0.3])
# Access first to make it recent
await cache.get("content1")
# Add fourth, should evict second (LRU)
await cache.put("content4", [0.4])
assert await cache.get("content1") is not None
assert await cache.get("content2") is None # Evicted
assert await cache.get("content3") is not None
assert await cache.get("content4") is not None
async def test_ttl_expiration(self) -> None:
"""Should expire entries after TTL."""
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)
await cache.put("content", [0.1, 0.2])
time.sleep(0.2)
result = await cache.get("content")
assert result is None
async def test_put_batch(self, cache: EmbeddingCache) -> None:
"""Should cache multiple embeddings."""
items = [
("content1", [0.1]),
("content2", [0.2]),
("content3", [0.3]),
]
hashes = await cache.put_batch(items)
assert len(hashes) == 3
assert await cache.get("content1") == [0.1]
assert await cache.get("content2") == [0.2]
async def test_invalidate(self, cache: EmbeddingCache) -> None:
"""Should invalidate cached embedding."""
await cache.put("content", [0.1, 0.2])
result = await cache.invalidate("content")
assert result is True
assert await cache.get("content") is None
async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None:
"""Should invalidate by hash."""
content_hash = await cache.put("content", [0.1, 0.2])
result = await cache.invalidate_by_hash(content_hash)
assert result is True
assert await cache.get("content") is None
async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None:
"""Should invalidate all embeddings for a model."""
await cache.put("content1", [0.1], model="model-a")
await cache.put("content2", [0.2], model="model-a")
await cache.put("content3", [0.3], model="model-b")
count = await cache.invalidate_by_model("model-a")
assert count == 2
assert await cache.get("content1", model="model-a") is None
assert await cache.get("content3", model="model-b") is not None
async def test_clear(self, cache: EmbeddingCache) -> None:
"""Should clear all entries."""
await cache.put("content1", [0.1])
await cache.put("content2", [0.2])
count = await cache.clear()
assert count == 2
assert cache.size == 0
def test_cleanup_expired(self) -> None:
"""Should remove expired entries."""
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)
# Use synchronous put for setup
cache._put_memory("hash1", "default", [0.1])
cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)
time.sleep(0.2)
count = cache.cleanup_expired()
assert count == 1
def test_get_stats(self, cache: EmbeddingCache) -> None:
"""Should return accurate statistics."""
# Put synchronously for setup
cache._put_memory("hash1", "default", [0.1])
stats = cache.get_stats()
assert stats.current_size == 1
def test_hash_content(self) -> None:
"""Should produce consistent hashes."""
hash1 = EmbeddingCache.hash_content("test content")
hash2 = EmbeddingCache.hash_content("test content")
hash3 = EmbeddingCache.hash_content("different content")
assert hash1 == hash2
assert hash1 != hash3
assert len(hash1) == 32
class TestEmbeddingCacheWithRedis:
"""Tests for EmbeddingCache with Redis."""
@pytest.fixture
def mock_redis(self) -> MagicMock:
"""Create mock Redis."""
redis = MagicMock()
redis.get = AsyncMock(return_value=None)
redis.setex = AsyncMock()
redis.delete = AsyncMock()
redis.scan_iter = MagicMock(return_value=iter([]))
return redis
@pytest.fixture
def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache:
"""Create cache with mock Redis."""
return EmbeddingCache(
max_size=100,
default_ttl_seconds=300.0,
redis=mock_redis,
)
async def test_put_stores_in_redis(
self,
cache_with_redis: EmbeddingCache,
mock_redis: MagicMock,
) -> None:
"""Should store in Redis when available."""
await cache_with_redis.put("content", [0.1, 0.2])
mock_redis.setex.assert_called_once()
async def test_get_checks_redis_on_miss(
self,
cache_with_redis: EmbeddingCache,
mock_redis: MagicMock,
) -> None:
"""Should check Redis when memory cache misses."""
import json
mock_redis.get.return_value = json.dumps([0.1, 0.2])
result = await cache_with_redis.get("content")
assert result == [0.1, 0.2]
mock_redis.get.assert_called_once()
class TestCachedEmbeddingGenerator:
"""Tests for CachedEmbeddingGenerator."""
@pytest.fixture
def mock_generator(self) -> MagicMock:
"""Create mock embedding generator."""
gen = MagicMock()
gen.generate = AsyncMock(return_value=[0.1, 0.2, 0.3])
gen.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]])
return gen
@pytest.fixture
def cache(self) -> EmbeddingCache:
"""Create embedding cache."""
return EmbeddingCache(max_size=100)
@pytest.fixture
def cached_gen(
self,
mock_generator: MagicMock,
cache: EmbeddingCache,
) -> CachedEmbeddingGenerator:
"""Create cached generator."""
return CachedEmbeddingGenerator(mock_generator, cache)
async def test_generate_caches_result(
self,
cached_gen: CachedEmbeddingGenerator,
mock_generator: MagicMock,
) -> None:
"""Should cache generated embedding."""
result1 = await cached_gen.generate("test text")
result2 = await cached_gen.generate("test text")
assert result1 == [0.1, 0.2, 0.3]
assert result2 == [0.1, 0.2, 0.3]
mock_generator.generate.assert_called_once() # Only called once
async def test_generate_batch_uses_cache(
self,
cached_gen: CachedEmbeddingGenerator,
mock_generator: MagicMock,
cache: EmbeddingCache,
) -> None:
"""Should use cache for batch generation."""
# Pre-cache one embedding
await cache.put("text1", [0.5])
# Mock returns 2 embeddings for the 2 uncached texts
mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]])
results = await cached_gen.generate_batch(["text1", "text2", "text3"])
assert len(results) == 3
assert results[0] == [0.5] # From cache
assert results[1] == [0.2] # Generated
assert results[2] == [0.3] # Generated
async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None:
"""Should return generator statistics."""
await cached_gen.generate("text1")
await cached_gen.generate("text1") # Cache hit
stats = cached_gen.get_stats()
assert stats["call_count"] == 2
assert stats["cache_hit_count"] == 1
class TestCreateEmbeddingCache:
"""Tests for factory function."""
def test_creates_cache(self) -> None:
"""Should create cache with defaults."""
cache = create_embedding_cache()
assert cache.max_size == 50000
def test_creates_cache_with_options(self) -> None:
"""Should create cache with custom options."""
cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)
assert cache.max_size == 1000

@@ -0,0 +1,355 @@
# tests/unit/services/memory/cache/test_hot_cache.py
"""Tests for HotMemoryCache."""
import time
from uuid import uuid4
import pytest
from app.services.memory.cache.hot_cache import (
CacheEntry,
CacheKey,
HotCacheStats,
HotMemoryCache,
create_hot_cache,
)
class TestCacheKey:
"""Tests for CacheKey."""
def test_creates_key(self) -> None:
"""Should create key with required fields."""
key = CacheKey(memory_type="episodic", memory_id="123")
assert key.memory_type == "episodic"
assert key.memory_id == "123"
assert key.scope is None
def test_creates_key_with_scope(self) -> None:
"""Should create key with scope."""
key = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123")
assert key.scope == "proj-123"
def test_hash_and_equality(self) -> None:
"""Keys with same values should be equal and have same hash."""
key1 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
key2 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
assert key1 == key2
assert hash(key1) == hash(key2)
def test_str_representation(self) -> None:
"""Should produce readable string."""
key = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
assert str(key) == "episodic:proj-1:123"
def test_str_without_scope(self) -> None:
"""Should produce string without scope."""
key = CacheKey(memory_type="episodic", memory_id="123")
assert str(key) == "episodic:123"
class TestCacheEntry:
"""Tests for CacheEntry."""
def test_creates_entry(self) -> None:
"""Should create entry with value."""
from datetime import UTC, datetime
now = datetime.now(UTC)
entry = CacheEntry(
value={"data": "test"},
created_at=now,
last_accessed_at=now,
)
assert entry.value == {"data": "test"}
assert entry.access_count == 1
assert entry.ttl_seconds == 300.0
def test_is_expired(self) -> None:
"""Should detect expired entries."""
from datetime import UTC, datetime, timedelta
old_time = datetime.now(UTC) - timedelta(seconds=400)
entry = CacheEntry(
value="test",
created_at=old_time,
last_accessed_at=old_time,
ttl_seconds=300.0,
)
assert entry.is_expired() is True
def test_not_expired(self) -> None:
"""Should detect non-expired entries."""
from datetime import UTC, datetime
entry = CacheEntry(
value="test",
created_at=datetime.now(UTC),
last_accessed_at=datetime.now(UTC),
ttl_seconds=300.0,
)
assert entry.is_expired() is False
def test_touch_updates_access(self) -> None:
"""Touch should update access time and count."""
from datetime import UTC, datetime, timedelta
old_time = datetime.now(UTC) - timedelta(seconds=10)
entry = CacheEntry(
value="test",
created_at=old_time,
last_accessed_at=old_time,
access_count=5,
)
entry.touch()
assert entry.access_count == 6
assert entry.last_accessed_at > old_time
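# --- Illustrative sketch (not part of the suite) ----------------------------
# Expiry as tested above is age-based: an entry is stale once the time since
# created_at exceeds ttl_seconds. A standalone sketch of that check
# (hypothetical -- the real CacheEntry.is_expired may also consider
# last_accessed_at):
from datetime import UTC, datetime


def _is_expired_sketch(created_at: datetime, ttl_seconds: float) -> bool:
    """Age-based expiry check matching the behaviour tested above."""
    return (datetime.now(UTC) - created_at).total_seconds() > ttl_seconds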
class TestHotCacheStats:
"""Tests for HotCacheStats."""
def test_hit_rate_calculation(self) -> None:
"""Should calculate hit rate correctly."""
stats = HotCacheStats(hits=80, misses=20)
assert stats.hit_rate == 0.8
def test_hit_rate_zero_requests(self) -> None:
"""Should return 0 for no requests."""
stats = HotCacheStats()
assert stats.hit_rate == 0.0
def test_to_dict(self) -> None:
"""Should convert to dictionary."""
stats = HotCacheStats(hits=10, misses=5, evictions=2)
result = stats.to_dict()
assert result["hits"] == 10
assert result["misses"] == 5
assert result["evictions"] == 2
assert "hit_rate" in result
class TestHotMemoryCache:
"""Tests for HotMemoryCache."""
@pytest.fixture
def cache(self) -> HotMemoryCache[dict]:
"""Create a hot memory cache."""
return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0)
def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None:
"""Should store and retrieve values."""
key = CacheKey(memory_type="episodic", memory_id="123")
value = {"data": "test"}
cache.put(key, value)
result = cache.get(key)
assert result == value
def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None:
"""Should return None for missing keys."""
key = CacheKey(memory_type="episodic", memory_id="nonexistent")
result = cache.get(key)
assert result is None
def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None:
"""Should store by type and ID."""
memory_id = uuid4()
value = {"data": "test"}
cache.put_by_id("episodic", memory_id, value)
result = cache.get_by_id("episodic", memory_id)
assert result == value
def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None:
"""Should store with scope."""
memory_id = uuid4()
value = {"data": "test"}
cache.put_by_id("semantic", memory_id, value, scope="proj-123")
result = cache.get_by_id("semantic", memory_id, scope="proj-123")
assert result == value
def test_lru_eviction(self) -> None:
"""Should evict LRU entries when at capacity."""
cache = HotMemoryCache[str](max_size=3)
# Fill cache
cache.put_by_id("test", "1", "first")
cache.put_by_id("test", "2", "second")
cache.put_by_id("test", "3", "third")
# Access first to make it recent
cache.get_by_id("test", "1")
# Add fourth, should evict second (LRU)
cache.put_by_id("test", "4", "fourth")
assert cache.get_by_id("test", "1") is not None # Accessed, kept
assert cache.get_by_id("test", "2") is None # Evicted (LRU)
assert cache.get_by_id("test", "3") is not None
assert cache.get_by_id("test", "4") is not None
def test_ttl_expiration(self) -> None:
"""Should expire entries after TTL."""
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)
cache.put_by_id("test", "1", "value")
# Wait for expiration
time.sleep(0.2)
result = cache.get_by_id("test", "1")
assert result is None
def test_invalidate(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate specific entry."""
key = CacheKey(memory_type="episodic", memory_id="123")
cache.put(key, {"data": "test"})
result = cache.invalidate(key)
assert result is True
assert cache.get(key) is None
def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate by ID."""
memory_id = uuid4()
cache.put_by_id("episodic", memory_id, {"data": "test"})
result = cache.invalidate_by_id("episodic", memory_id)
assert result is True
assert cache.get_by_id("episodic", memory_id) is None
def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate all entries of a type."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.put_by_id("episodic", "2", {"data": "2"})
cache.put_by_id("semantic", "3", {"data": "3"})
count = cache.invalidate_by_type("episodic")
assert count == 2
assert cache.get_by_id("episodic", "1") is None
assert cache.get_by_id("episodic", "2") is None
assert cache.get_by_id("semantic", "3") is not None
def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate all entries in a scope."""
cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1")
cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1")
cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2")
count = cache.invalidate_by_scope("proj-1")
assert count == 2
assert cache.get_by_id("episodic", "3", scope="proj-2") is not None
def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate entries matching pattern."""
cache.put_by_id("episodic", "123", {"data": "1"})
cache.put_by_id("episodic", "124", {"data": "2"})
cache.put_by_id("semantic", "125", {"data": "3"})
count = cache.invalidate_pattern("episodic:*")
assert count == 2
def test_clear(self, cache: HotMemoryCache[dict]) -> None:
"""Should clear all entries."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.put_by_id("semantic", "2", {"data": "2"})
count = cache.clear()
assert count == 2
assert cache.size == 0
def test_cleanup_expired(self) -> None:
"""Should remove expired entries."""
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)
cache.put_by_id("test", "1", "value1")
cache.put_by_id("test", "2", "value2", ttl_seconds=10)
time.sleep(0.2)
count = cache.cleanup_expired()
assert count == 1 # Only the first one expired
assert cache.size == 1
def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None:
"""Should return most accessed memories."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.put_by_id("episodic", "2", {"data": "2"})
# Access first one multiple times
for _ in range(5):
cache.get_by_id("episodic", "1")
hot = cache.get_hot_memories(limit=2)
assert len(hot) == 2
assert hot[0][1] >= hot[1][1] # Sorted by access count
def test_get_stats(self, cache: HotMemoryCache[dict]) -> None:
"""Should return accurate statistics."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.get_by_id("episodic", "1") # Hit
cache.get_by_id("episodic", "2") # Miss
stats = cache.get_stats()
assert stats.hits == 1
assert stats.misses == 1
assert stats.current_size == 1
def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None:
"""Should reset statistics."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.get_by_id("episodic", "1")
cache.reset_stats()
stats = cache.get_stats()
assert stats.hits == 0
assert stats.misses == 0
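# --- Illustrative sketch (not part of the suite) ----------------------------
# The LRU behaviour tested above (a get refreshes recency; a put at capacity
# drops the least-recently-used entry) can be sketched with an OrderedDict.
# Hypothetical -- the real HotMemoryCache adds TTL, stats, and locking:
from collections import OrderedDict


class _LruSketch:
    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._data: OrderedDict[str, str] = OrderedDict()

    def get(self, key: str) -> str | None:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # refresh recency on access
        return self._data[key]

    def put(self, key: str, value: str) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least-recently-used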
class TestCreateHotCache:
"""Tests for factory function."""
def test_creates_cache(self) -> None:
"""Should create cache with defaults."""
cache = create_hot_cache()
assert cache.max_size == 10000
def test_creates_cache_with_options(self) -> None:
"""Should create cache with custom options."""
cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)
assert cache.max_size == 500
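# --- Illustrative usage (not part of the suite) -----------------------------
# Putting the pieces above together, assuming only the factory options and
# method signatures the tests pin down:
#
#     cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)
#     mem_id = uuid4()
#     cache.put_by_id("episodic", mem_id, {"summary": "..."}, scope="proj-1")
#     assert cache.get_by_id("episodic", mem_id, scope="proj-1") is not None
#     cache.invalidate_by_scope("proj-1")  # drops everything under proj-1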