Add comprehensive caching layer for the Agent Memory System:

- HotMemoryCache: LRU cache for frequently accessed memories
  - Python 3.12 type parameter syntax
  - Thread-safe operations with RLock
  - TTL-based expiration
  - Access count tracking for hot memory identification
  - Scoped invalidation by type, scope, or pattern
- EmbeddingCache: Cache embeddings by content hash
  - Content-hash based deduplication
  - Optional Redis backing for persistence
  - LRU eviction with configurable max size
  - CachedEmbeddingGenerator wrapper for transparent caching
- CacheManager: Unified cache management
  - Coordinates hot cache, embedding cache, and retrieval cache
  - Centralized invalidation across all caches
  - Aggregated statistics and hit rate tracking
  - Automatic cleanup scheduling
  - Cache warmup support

Performance targets:
- Cache hit rate > 80% for hot memories
- Cache operations < 1ms (memory), < 5ms (Redis)

83 new tests with comprehensive coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
628 lines · 18 KiB · Python
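A minimal usage sketch of the `EmbeddingCache` defined in the file below (a hedged illustration, not part of the module: the three-element vector stands in for a real embedding, and the import path follows the file header):

```python
import asyncio

from app.services.memory.cache.embedding_cache import create_embedding_cache


async def main() -> None:
    # In-memory only here; pass a redis.asyncio.Redis client for persistence.
    cache = create_embedding_cache(max_size=1_000, default_ttl_seconds=600.0)

    # First lookup misses; the caller computes the vector and stores it.
    if await cache.get("hello world") is None:
        await cache.put("hello world", [0.1, 0.2, 0.3])

    # Second lookup hits the in-memory LRU and returns the cached vector.
    assert await cache.get("hello world") == [0.1, 0.2, 0.3]
    print(cache.get_stats().to_dict())  # hits=1, misses=1, hit_rate=0.5


asyncio.run(main())
```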
# app/services/memory/cache/embedding_cache.py

"""
Embedding Cache.

Caches embeddings by content hash to avoid recomputing them.
Provides a significant performance improvement for repeated content.
"""

import hashlib
import json
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from redis.asyncio import Redis

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class EmbeddingEntry:
    """A cached embedding entry."""

    embedding: list[float]
    content_hash: str
    model: str
    created_at: datetime
    ttl_seconds: float = 3600.0  # 1 hour default

    def is_expired(self) -> bool:
        """Check if this entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds


@dataclass
class EmbeddingCacheStats:
    """Statistics for the embedding cache."""

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0
    current_size: int = 0
    max_size: int = 0
    bytes_saved: int = 0  # Estimated bytes saved by caching

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate."""
        total = self.hits + self.misses
        if total == 0:
            return 0.0
        return self.hits / total

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "evictions": self.evictions,
            "expirations": self.expirations,
            "current_size": self.current_size,
            "max_size": self.max_size,
            "hit_rate": self.hit_rate,
            "bytes_saved": self.bytes_saved,
        }


class EmbeddingCache:
    """
    Cache for embeddings by content hash.

    Features:
    - Content-hash based deduplication
    - LRU eviction
    - TTL-based expiration
    - Optional Redis backing for persistence
    - Thread-safe operations

    Performance targets:
    - Cache hit rate > 90% for repeated content
    - Get/put operations < 1ms (memory), < 5ms (Redis)
    """

    def __init__(
        self,
        max_size: int = 50000,
        default_ttl_seconds: float = 3600.0,
        redis: "Redis | None" = None,
        redis_prefix: str = "mem:emb",
    ) -> None:
        """
        Initialize the embedding cache.

        Args:
            max_size: Maximum number of entries in memory cache
            default_ttl_seconds: Default TTL for entries (1 hour)
            redis: Optional Redis connection for persistence
            redis_prefix: Prefix for Redis keys
        """
        self._max_size = max_size
        self._default_ttl = default_ttl_seconds
        self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._stats = EmbeddingCacheStats(max_size=max_size)
        self._redis = redis
        self._redis_prefix = redis_prefix

        logger.info(
            f"Initialized EmbeddingCache with max_size={max_size}, "
            f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
        )

    def set_redis(self, redis: "Redis") -> None:
        """Set Redis connection for persistence."""
        self._redis = redis

    @staticmethod
    def hash_content(content: str) -> str:
        """
        Compute hash of content for cache key.

        Args:
            content: Content to hash

        Returns:
            32-character hex hash
        """
        return hashlib.sha256(content.encode()).hexdigest()[:32]
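
    # Dedup property, illustrated (the hash is a deterministic SHA-256 prefix,
    # so identical content collapses to a single cache entry per model):
    #   >>> EmbeddingCache.hash_content("abc") == EmbeddingCache.hash_content("abc")
    #   True
    #   >>> len(EmbeddingCache.hash_content("abc"))
    #   32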

    def _cache_key(self, content_hash: str, model: str) -> str:
        """Build cache key from content hash and model."""
        return f"{content_hash}:{model}"

    def _redis_key(self, content_hash: str, model: str) -> str:
        """Build Redis key from content hash and model."""
        return f"{self._redis_prefix}:{content_hash}:{model}"

    async def get(
        self,
        content: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            Cached embedding or None if not found/expired
        """
        # Lookup logic is identical to get_by_hash, so delegate to it.
        return await self.get_by_hash(self.hash_content(content), model)

    async def get_by_hash(
        self,
        content_hash: str,
        model: str = "default",
    ) -> list[float] | None:
        """
        Get a cached embedding by hash.

        Args:
            content_hash: Content hash
            model: Model name

        Returns:
            Cached embedding or None if not found/expired
        """
        cache_key = self._cache_key(content_hash, model)

        # Check memory cache first
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                if entry.is_expired():
                    del self._cache[cache_key]
                    self._stats.expirations += 1
                    self._stats.current_size = len(self._cache)
                else:
                    # Move to end (most recently used)
                    self._cache.move_to_end(cache_key)
                    self._stats.hits += 1
                    return entry.embedding

        # Check Redis if available
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                data = await self._redis.get(redis_key)
                if data:
                    embedding = json.loads(data)
                    # Promote to memory cache for faster access
                    self._put_memory(content_hash, model, embedding)
                    with self._lock:
                        self._stats.hits += 1
                    return embedding
            except Exception as e:
                logger.warning(f"Redis get error: {e}")

        # Hold the lock for stats updates too, as the class promises
        # thread-safe operations.
        with self._lock:
            self._stats.misses += 1
        return None

    async def put(
        self,
        content: str,
        embedding: list[float],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> str:
        """
        Cache an embedding.

        Args:
            content: Content text
            embedding: Embedding vector
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            Content hash
        """
        content_hash = self.hash_content(content)
        # Explicit None check so a caller-supplied TTL of 0 is honored
        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl

        # Store in memory
        self._put_memory(content_hash, model, embedding, ttl)

        # Store in Redis if available
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                await self._redis.setex(
                    redis_key,
                    int(ttl),
                    json.dumps(embedding),
                )
            except Exception as e:
                logger.warning(f"Redis put error: {e}")

        return content_hash

    def _put_memory(
        self,
        content_hash: str,
        model: str,
        embedding: list[float],
        ttl_seconds: float | None = None,
    ) -> None:
        """Store in memory cache."""
        with self._lock:
            # Evict if at capacity
            self._evict_if_needed()

            cache_key = self._cache_key(content_hash, model)
            entry = EmbeddingEntry(
                embedding=embedding,
                content_hash=content_hash,
                model=model,
                created_at=_utcnow(),
                ttl_seconds=(
                    ttl_seconds if ttl_seconds is not None else self._default_ttl
                ),
            )

            self._cache[cache_key] = entry
            self._cache.move_to_end(cache_key)
            self._stats.current_size = len(self._cache)

    def _evict_if_needed(self) -> None:
        """Evict least-recently-used entries while the cache is at capacity."""
        # Guard on a non-empty cache so a max_size of 0 cannot loop forever.
        while self._cache and len(self._cache) >= self._max_size:
            self._cache.popitem(last=False)
            self._stats.evictions += 1

    async def put_batch(
        self,
        items: list[tuple[str, list[float]]],
        model: str = "default",
        ttl_seconds: float | None = None,
    ) -> list[str]:
        """
        Cache multiple embeddings.

        Args:
            items: List of (content, embedding) tuples
            model: Model name
            ttl_seconds: Optional TTL override

        Returns:
            List of content hashes
        """
        hashes = []
        for content, embedding in items:
            content_hash = await self.put(content, embedding, model, ttl_seconds)
            hashes.append(content_hash)
        return hashes

    async def invalidate(
        self,
        content: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding.

        Args:
            content: Content text
            model: Model name

        Returns:
            True if entry was found and removed
        """
        content_hash = self.hash_content(content)
        return await self.invalidate_by_hash(content_hash, model)

    async def invalidate_by_hash(
        self,
        content_hash: str,
        model: str = "default",
    ) -> bool:
        """
        Invalidate a cached embedding by hash.

        Args:
            content_hash: Content hash
            model: Model name

        Returns:
            True if entry was found and removed
        """
        cache_key = self._cache_key(content_hash, model)
        removed = False

        with self._lock:
            if cache_key in self._cache:
                del self._cache[cache_key]
                self._stats.current_size = len(self._cache)
                removed = True

        # Remove from Redis; DEL returns the number of keys actually deleted,
        # so only report removal when something was there.
        if self._redis:
            try:
                redis_key = self._redis_key(content_hash, model)
                deleted = await self._redis.delete(redis_key)
                removed = removed or deleted > 0
            except Exception as e:
                logger.warning(f"Redis delete error: {e}")

        return removed

    async def invalidate_by_model(self, model: str) -> int:
        """
        Invalidate all embeddings for a model.

        Args:
            model: Model name

        Returns:
            Number of entries invalidated
        """
        count = 0

        with self._lock:
            keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
            for key in keys_to_remove:
                del self._cache[key]
                count += 1

            self._stats.current_size = len(self._cache)

        # Note: Redis pattern deletion would require SCAN, which is expensive.
        # For now, we only clear the memory cache for model-based invalidation.

        return count

    async def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            self._stats.current_size = 0

        # Clear Redis entries
        if self._redis:
            try:
                pattern = f"{self._redis_prefix}:*"
                deleted = 0
                async for key in self._redis.scan_iter(match=pattern):
                    await self._redis.delete(key)
                    deleted += 1
                count = max(count, deleted)
            except Exception as e:
                logger.warning(f"Redis clear error: {e}")

        logger.info(f"Cleared {count} entries from embedding cache")
        return count

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries from memory cache.

        Returns:
            Number of entries removed
        """
        with self._lock:
            keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
            for key in keys_to_remove:
                del self._cache[key]
                self._stats.expirations += 1

            self._stats.current_size = len(self._cache)

        if keys_to_remove:
            logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")

        return len(keys_to_remove)

    def get_stats(self) -> EmbeddingCacheStats:
        """Get cache statistics."""
        with self._lock:
            self._stats.current_size = len(self._cache)
            return self._stats

    def reset_stats(self) -> None:
        """Reset cache statistics."""
        with self._lock:
            self._stats = EmbeddingCacheStats(
                max_size=self._max_size,
                current_size=len(self._cache),
            )

    @property
    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    @property
    def max_size(self) -> int:
        """Get maximum cache size."""
        return self._max_size


class CachedEmbeddingGenerator:
    """
    Wrapper that adds caching to an embedding generator.

    Checks the cache before delegating to the wrapped generator,
    so repeated content is embedded only once.
    """

    def __init__(
        self,
        generator: Any,
        cache: EmbeddingCache,
        model: str = "default",
    ) -> None:
        """
        Initialize the cached embedding generator.

        Args:
            generator: Underlying embedding generator
            cache: Embedding cache
            model: Model name for cache keys
        """
        self._generator = generator
        self._cache = cache
        self._model = model
        self._call_count = 0
        self._cache_hit_count = 0

    async def generate(self, text: str) -> list[float]:
        """
        Generate embedding with caching.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        self._call_count += 1

        # Check cache first
        cached = await self._cache.get(text, self._model)
        if cached is not None:
            self._cache_hit_count += 1
            return cached

        # Generate and cache
        embedding = await self._generator.generate(text)
        await self._cache.put(text, embedding, self._model)

        return embedding

    async def generate_batch(
        self,
        texts: list[str],
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts with caching.

        Args:
            texts: Texts to embed

        Returns:
            List of embedding vectors
        """
        results: list[list[float] | None] = [None] * len(texts)
        to_generate: list[tuple[int, str]] = []

        # Check cache for each text
        for i, text in enumerate(texts):
            cached = await self._cache.get(text, self._model)
            if cached is not None:
                results[i] = cached
                self._cache_hit_count += 1
            else:
                to_generate.append((i, text))

        self._call_count += len(texts)

        # Generate missing embeddings
        if to_generate:
            if hasattr(self._generator, "generate_batch"):
                texts_to_gen = [t for _, t in to_generate]
                embeddings = await self._generator.generate_batch(texts_to_gen)

                for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
                    results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)
            else:
                # Fall back to individual generation
                for idx, text in to_generate:
                    embedding = await self._generator.generate(text)
                    results[idx] = embedding
                    await self._cache.put(text, embedding, self._model)

        return results  # type: ignore[return-value]

    def get_stats(self) -> dict[str, Any]:
        """Get generator statistics."""
        return {
            "call_count": self._call_count,
            "cache_hit_count": self._cache_hit_count,
            "cache_hit_rate": (
                self._cache_hit_count / self._call_count
                if self._call_count > 0
                else 0.0
            ),
            "cache_stats": self._cache.get_stats().to_dict(),
        }


# Factory function
def create_embedding_cache(
    max_size: int = 50000,
    default_ttl_seconds: float = 3600.0,
    redis: "Redis | None" = None,
) -> EmbeddingCache:
    """
    Create an embedding cache.

    Args:
        max_size: Maximum number of entries
        default_ttl_seconds: Default TTL for entries
        redis: Optional Redis connection

    Returns:
        Configured EmbeddingCache instance
    """
    return EmbeddingCache(
        max_size=max_size,
        default_ttl_seconds=default_ttl_seconds,
        redis=redis,
    )
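

# Illustrative usage sketch (an assumption, not part of the production module):
# a stub generator shows the transparent caching behavior of
# CachedEmbeddingGenerator. A real deployment would pass an actual embedding
# generator and, optionally, a Redis client.
if __name__ == "__main__":
    import asyncio

    class _StubGenerator:
        """Deterministic fake generator used only for this demo."""

        async def generate(self, text: str) -> list[float]:
            return [float(len(text)), 0.0, 1.0]

    async def _demo() -> None:
        cache = create_embedding_cache(max_size=100)
        generator = CachedEmbeddingGenerator(_StubGenerator(), cache)

        await generator.generate("hello")  # miss: computed and cached
        await generator.generate("hello")  # hit: served from the cache
        print(generator.get_stats())  # cache_hit_rate == 0.5

    asyncio.run(_demo())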