- Added timeout enforcement for token counting, scoring, and compression with detailed error handling.
- Introduced tenant isolation in context caching using project and agent identifiers.
- Enhanced budget management with stricter checks for critical-context overspending and buffer limits.
- Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments.
- Updated default assembly timeout settings for improved performance and reliability.
- Improved XML escaping in the Claude adapter to protect against injection attacks.
- Standardized token estimation using model-specific ratios.
"""
|
|
Context Cache Implementation.
|
|
|
|
Provides Redis-based caching for context operations including
|
|
assembled contexts, token counts, and scoring results.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from ..config import ContextSettings, get_context_settings
|
|
from ..exceptions import CacheError
|
|
from ..types import AssembledContext, BaseContext
|
|
|
|
if TYPE_CHECKING:
|
|
from redis.asyncio import Redis
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ContextCache:
    """
    Redis-based caching for context operations.

    Provides caching for:
    - Assembled contexts (fingerprint-based)
    - Token counts (content hash-based)
    - Scoring results (context + query hash-based)

    Cache keys use a hierarchical structure:
    - ctx:assembled:{fingerprint}
    - ctx:tokens:{model}:{content_hash}
    - ctx:score:{scorer}:{context_hash}:{query_hash}
    """

    def __init__(
        self,
        redis: "Redis | None" = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize the context cache.

        Args:
            redis: Redis connection (optional for testing)
            settings: Cache settings
        """
        self._redis = redis
        self._settings = settings or get_context_settings()
        self._prefix = self._settings.cache_prefix
        self._ttl = self._settings.cache_ttl_seconds

        # In-memory fallback cache when Redis unavailable
        self._memory_cache: dict[str, tuple[str, float]] = {}
        self._max_memory_items = self._settings.cache_memory_max_items

    def set_redis(self, redis: "Redis") -> None:
        """Set Redis connection."""
        self._redis = redis

    @property
    def is_enabled(self) -> bool:
        """Check if caching is enabled and available."""
        return self._settings.cache_enabled and self._redis is not None

    def _cache_key(self, *parts: str) -> str:
        """
        Build a cache key from parts.

        Args:
            *parts: Key components

        Returns:
            Colon-separated cache key
        """
        return f"{self._prefix}:{':'.join(parts)}"

    @staticmethod
    def _hash_content(content: str) -> str:
        """
        Compute hash of content for cache key.

        Args:
            content: Content to hash

        Returns:
            32-character hex hash
        """
        return hashlib.sha256(content.encode()).hexdigest()[:32]

    def compute_fingerprint(
        self,
        contexts: list[BaseContext],
        query: str,
        model: str,
        project_id: str | None = None,
        agent_id: str | None = None,
    ) -> str:
        """
        Compute a fingerprint for a context assembly request.

        The fingerprint is based on:
        - Project and agent IDs (for tenant isolation)
        - Context content hash and metadata (not full content, for performance)
        - Query string
        - Target model

        SECURITY: project_id and agent_id MUST be included to prevent
        cross-tenant cache pollution. Without these, one tenant could
        receive cached contexts from another tenant with the same query.

        Args:
            contexts: List of contexts
            query: Query string
            model: Model name
            project_id: Project ID for tenant isolation
            agent_id: Agent ID for tenant isolation

        Returns:
            32-character hex fingerprint
        """
        # Build a deterministic representation using content hashes for performance.
        # This avoids JSON-serializing potentially large content strings.
        context_data = []
        for ctx in contexts:
            context_data.append(
                {
                    "type": ctx.get_type().value,
                    "content_hash": self._hash_content(
                        ctx.content
                    ),  # Hash instead of full content
                    "source": ctx.source,
                    "priority": ctx.priority,  # Already an int
                }
            )

        data = {
            # CRITICAL: Include tenant identifiers for cache isolation
            "project_id": project_id or "",
            "agent_id": agent_id or "",
            "contexts": context_data,
            "query": query,
            "model": model,
        }

        content = json.dumps(data, sort_keys=True)
        return self._hash_content(content)

    async def get_assembled(
        self,
        fingerprint: str,
    ) -> AssembledContext | None:
        """
        Get cached assembled context by fingerprint.

        Args:
            fingerprint: Assembly fingerprint

        Returns:
            Cached AssembledContext or None if not found
        """
        if not self.is_enabled:
            return None

        key = self._cache_key("assembled", fingerprint)

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                logger.debug(f"Cache hit for assembled context: {fingerprint}")
                result = AssembledContext.from_json(data)
                result.cache_hit = True
                result.cache_key = fingerprint
                return result
        except Exception as e:
            logger.warning(f"Cache get error: {e}")
            raise CacheError(f"Failed to get assembled context: {e}") from e

        return None

    async def set_assembled(
        self,
        fingerprint: str,
        context: AssembledContext,
        ttl: int | None = None,
    ) -> None:
        """
        Cache an assembled context.

        Args:
            fingerprint: Assembly fingerprint
            context: Assembled context to cache
            ttl: Optional TTL override in seconds
        """
        if not self.is_enabled:
            return

        key = self._cache_key("assembled", fingerprint)
        expire = ttl or self._ttl

        try:
            await self._redis.setex(key, expire, context.to_json())  # type: ignore
            logger.debug(f"Cached assembled context: {fingerprint}")
        except Exception as e:
            logger.warning(f"Cache set error: {e}")
            raise CacheError(f"Failed to cache assembled context: {e}") from e

    async def get_token_count(
        self,
        content: str,
        model: str | None = None,
    ) -> int | None:
        """
        Get cached token count.

        Args:
            content: Content to look up
            model: Model name for model-specific tokenization

        Returns:
            Cached token count or None if not found
        """
        model_key = model or "default"
        content_hash = self._hash_content(content)
        key = self._cache_key("tokens", model_key, content_hash)
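        # NOTE: the in-memory fallback ignores TTL; entries persist until
        # evicted by _set_memory's size-based eviction (the stored timestamp
        # only orders eviction, it is not an expiry).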

        # Try in-memory first
        if key in self._memory_cache:
            return int(self._memory_cache[key][0])

        if not self.is_enabled:
            return None

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                count = int(data)
                # Store in memory for faster subsequent access
                self._set_memory(key, str(count))
                return count
        except Exception as e:
            logger.warning(f"Cache get error for tokens: {e}")

        return None

    async def set_token_count(
        self,
        content: str,
        count: int,
        model: str | None = None,
        ttl: int | None = None,
    ) -> None:
        """
        Cache a token count.

        Args:
            content: Content that was counted
            count: Token count
            model: Model name
            ttl: Optional TTL override in seconds
        """
        model_key = model or "default"
        content_hash = self._hash_content(content)
        key = self._cache_key("tokens", model_key, content_hash)
        expire = ttl or self._ttl

        # Always store in memory
        self._set_memory(key, str(count))

        if not self.is_enabled:
            return

        try:
            await self._redis.setex(key, expire, str(count))  # type: ignore
        except Exception as e:
            logger.warning(f"Cache set error for tokens: {e}")

    async def get_score(
        self,
        scorer_name: str,
        context_id: str,
        query: str,
    ) -> float | None:
        """
        Get cached score.

        Args:
            scorer_name: Name of the scorer
            context_id: Context identifier
            query: Query string

        Returns:
            Cached score or None if not found
        """
        query_hash = self._hash_content(query)[:16]
        key = self._cache_key("score", scorer_name, context_id, query_hash)

        # Try in-memory first
        if key in self._memory_cache:
            return float(self._memory_cache[key][0])

        if not self.is_enabled:
            return None

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                score = float(data)
                self._set_memory(key, str(score))
                return score
        except Exception as e:
            logger.warning(f"Cache get error for score: {e}")

        return None

    async def set_score(
        self,
        scorer_name: str,
        context_id: str,
        query: str,
        score: float,
        ttl: int | None = None,
    ) -> None:
        """
        Cache a score.

        Args:
            scorer_name: Name of the scorer
            context_id: Context identifier
            query: Query string
            score: Score value
            ttl: Optional TTL override in seconds
        """
        query_hash = self._hash_content(query)[:16]
        key = self._cache_key("score", scorer_name, context_id, query_hash)
        expire = ttl or self._ttl

        # Always store in memory
        self._set_memory(key, str(score))

        if not self.is_enabled:
            return

        try:
            await self._redis.setex(key, expire, str(score))  # type: ignore
        except Exception as e:
            logger.warning(f"Cache set error for score: {e}")

    async def invalidate(self, pattern: str) -> int:
        """
        Invalidate cache entries matching a pattern.

        Args:
            pattern: Key pattern (supports * wildcard)

        Returns:
            Number of keys deleted
        """
        if not self.is_enabled:
            return 0

        full_pattern = self._cache_key(pattern)
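        # For example, invalidate("assembled:*") expands to "ctx:assembled:*"
        # (assuming the default "ctx" prefix).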
        deleted = 0

        try:
            async for key in self._redis.scan_iter(match=full_pattern):  # type: ignore
                await self._redis.delete(key)  # type: ignore
                deleted += 1

            logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
        except Exception as e:
            logger.warning(f"Cache invalidation error: {e}")
            raise CacheError(f"Failed to invalidate cache: {e}") from e

        return deleted

    async def clear_all(self) -> int:
        """
        Clear all context cache entries.

        Returns:
            Number of keys deleted
        """
        self._memory_cache.clear()
        return await self.invalidate("*")

    def _set_memory(self, key: str, value: str) -> None:
        """
        Set a value in the memory cache.

        Uses LRU-style eviction when max items reached.

        Args:
            key: Cache key
            value: Value to store
        """
        if len(self._memory_cache) >= self._max_memory_items:
            # Evict oldest entries
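            # Dropping the oldest half (by last-write timestamp) in one pass
            # means eviction runs at most once per max_items // 2 insertions,
            # amortizing the sort cost.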
            sorted_keys = sorted(
                self._memory_cache.keys(),
                key=lambda k: self._memory_cache[k][1],
            )
            for k in sorted_keys[: len(sorted_keys) // 2]:
                del self._memory_cache[k]

        self._memory_cache[key] = (value, time.time())

    async def get_stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats
        """
        stats = {
            "enabled": self._settings.cache_enabled,
            "redis_available": self._redis is not None,
            "memory_items": len(self._memory_cache),
            "ttl_seconds": self._ttl,
        }

        if self.is_enabled:
            try:
                # Get Redis info
                info = await self._redis.info("memory")  # type: ignore
                stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
            except Exception as e:
                logger.debug(f"Failed to get Redis stats: {e}")

        return stats
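
# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the module. It assumes a
# reachable local Redis instance; names like "proj-1", "agent-1", and
# "example-model" are illustrative only.
# ---------------------------------------------------------------------------
#
#   import asyncio
#   from redis.asyncio import Redis
#
#   async def main() -> None:
#       cache = ContextCache(redis=Redis.from_url("redis://localhost:6379"))
#       fp = cache.compute_fingerprint(
#           contexts=[],
#           query="summarize recent changes",
#           model="example-model",
#           project_id="proj-1",  # tenant isolation (see compute_fingerprint)
#           agent_id="agent-1",
#       )
#       if (ctx := await cache.get_assembled(fp)) is not None:
#           print("cache hit:", ctx.cache_key)
#       else:
#           print("cache miss")
#       print(await cache.get_stats())
#
#   asyncio.run(main())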