""" Context Cache Implementation. Provides Redis-based caching for context operations including assembled contexts, token counts, and scoring results. """ import hashlib import json import logging from typing import TYPE_CHECKING, Any from ..config import ContextSettings, get_context_settings from ..exceptions import CacheError from ..types import AssembledContext, BaseContext if TYPE_CHECKING: from redis.asyncio import Redis logger = logging.getLogger(__name__) class ContextCache: """ Redis-based caching for context operations. Provides caching for: - Assembled contexts (fingerprint-based) - Token counts (content hash-based) - Scoring results (context + query hash-based) Cache keys use a hierarchical structure: - ctx:assembled:{fingerprint} - ctx:tokens:{model}:{content_hash} - ctx:score:{scorer}:{context_hash}:{query_hash} """ def __init__( self, redis: "Redis | None" = None, settings: ContextSettings | None = None, ) -> None: """ Initialize the context cache. Args: redis: Redis connection (optional for testing) settings: Cache settings """ self._redis = redis self._settings = settings or get_context_settings() self._prefix = self._settings.cache_prefix self._ttl = self._settings.cache_ttl_seconds # In-memory fallback cache when Redis unavailable self._memory_cache: dict[str, tuple[str, float]] = {} self._max_memory_items = self._settings.cache_memory_max_items def set_redis(self, redis: "Redis") -> None: """Set Redis connection.""" self._redis = redis @property def is_enabled(self) -> bool: """Check if caching is enabled and available.""" return self._settings.cache_enabled and self._redis is not None def _cache_key(self, *parts: str) -> str: """ Build a cache key from parts. Args: *parts: Key components Returns: Colon-separated cache key """ return f"{self._prefix}:{':'.join(parts)}" @staticmethod def _hash_content(content: str) -> str: """ Compute hash of content for cache key. Args: content: Content to hash Returns: 32-character hex hash """ return hashlib.sha256(content.encode()).hexdigest()[:32] def compute_fingerprint( self, contexts: list[BaseContext], query: str, model: str, project_id: str | None = None, agent_id: str | None = None, ) -> str: """ Compute a fingerprint for a context assembly request. The fingerprint is based on: - Project and agent IDs (for tenant isolation) - Context content hash and metadata (not full content for performance) - Query string - Target model SECURITY: project_id and agent_id MUST be included to prevent cross-tenant cache pollution. Without these, one tenant could receive cached contexts from another tenant with the same query. 
    async def get_assembled(
        self,
        fingerprint: str,
    ) -> AssembledContext | None:
        """
        Get a cached assembled context by fingerprint.

        Args:
            fingerprint: Assembly fingerprint

        Returns:
            Cached AssembledContext, or None if not found

        Raises:
            CacheError: If the cache read fails
        """
        if not self.is_enabled:
            return None

        key = self._cache_key("assembled", fingerprint)
        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                logger.debug(f"Cache hit for assembled context: {fingerprint}")
                result = AssembledContext.from_json(data)
                result.cache_hit = True
                result.cache_key = fingerprint
                return result
        except Exception as e:
            logger.warning(f"Cache get error: {e}")
            raise CacheError(f"Failed to get assembled context: {e}") from e
        return None

    async def set_assembled(
        self,
        fingerprint: str,
        context: AssembledContext,
        ttl: int | None = None,
    ) -> None:
        """
        Cache an assembled context.

        Args:
            fingerprint: Assembly fingerprint
            context: Assembled context to cache
            ttl: Optional TTL override in seconds

        Raises:
            CacheError: If the cache write fails
        """
        if not self.is_enabled:
            return

        key = self._cache_key("assembled", fingerprint)
        expire = ttl if ttl is not None else self._ttl
        try:
            await self._redis.setex(key, expire, context.to_json())  # type: ignore
            logger.debug(f"Cached assembled context: {fingerprint}")
        except Exception as e:
            logger.warning(f"Cache set error: {e}")
            raise CacheError(f"Failed to cache assembled context: {e}") from e

    async def get_token_count(
        self,
        content: str,
        model: str | None = None,
    ) -> int | None:
        """
        Get a cached token count.

        Args:
            content: Content to look up
            model: Model name for model-specific tokenization

        Returns:
            Cached token count, or None if not found
        """
        model_key = model or "default"
        content_hash = self._hash_content(content)
        key = self._cache_key("tokens", model_key, content_hash)

        # Try the in-memory cache first
        if key in self._memory_cache:
            return int(self._memory_cache[key][0])

        if not self.is_enabled:
            return None

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                count = int(data)
                # Store in memory for faster subsequent access
                self._set_memory(key, str(count))
                return count
        except Exception as e:
            logger.warning(f"Cache get error for tokens: {e}")
        return None

    async def set_token_count(
        self,
        content: str,
        count: int,
        model: str | None = None,
        ttl: int | None = None,
    ) -> None:
        """
        Cache a token count.

        Args:
            content: Content that was counted
            count: Token count
            model: Model name
            ttl: Optional TTL override in seconds
        """
        model_key = model or "default"
        content_hash = self._hash_content(content)
        key = self._cache_key("tokens", model_key, content_hash)
        expire = ttl if ttl is not None else self._ttl

        # Always store in memory
        self._set_memory(key, str(count))

        if not self.is_enabled:
            return

        try:
            await self._redis.setex(key, expire, str(count))  # type: ignore
        except Exception as e:
            logger.warning(f"Cache set error for tokens: {e}")
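    # Token counts follow a write-through pattern: writes always land in the
    # in-memory map and, when Redis is available, in Redis as well. A typical
    # read-through sketch (assumes a `count_tokens` helper that is not part
    # of this module):
    #
    #     cached = await cache.get_token_count(text, model="gpt-4")
    #     if cached is None:
    #         cached = count_tokens(text, model="gpt-4")
    #         await cache.set_token_count(text, cached, model="gpt-4")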
    async def get_score(
        self,
        scorer_name: str,
        context_id: str,
        query: str,
    ) -> float | None:
        """
        Get a cached score.

        Args:
            scorer_name: Name of the scorer
            context_id: Context identifier
            query: Query string

        Returns:
            Cached score, or None if not found
        """
        query_hash = self._hash_content(query)[:16]
        key = self._cache_key("score", scorer_name, context_id, query_hash)

        # Try the in-memory cache first
        if key in self._memory_cache:
            return float(self._memory_cache[key][0])

        if not self.is_enabled:
            return None

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                score = float(data)
                self._set_memory(key, str(score))
                return score
        except Exception as e:
            logger.warning(f"Cache get error for score: {e}")
        return None

    async def set_score(
        self,
        scorer_name: str,
        context_id: str,
        query: str,
        score: float,
        ttl: int | None = None,
    ) -> None:
        """
        Cache a score.

        Args:
            scorer_name: Name of the scorer
            context_id: Context identifier
            query: Query string
            score: Score value
            ttl: Optional TTL override in seconds
        """
        query_hash = self._hash_content(query)[:16]
        key = self._cache_key("score", scorer_name, context_id, query_hash)
        expire = ttl if ttl is not None else self._ttl

        # Always store in memory
        self._set_memory(key, str(score))

        if not self.is_enabled:
            return

        try:
            await self._redis.setex(key, expire, str(score))  # type: ignore
        except Exception as e:
            logger.warning(f"Cache set error for score: {e}")

    async def invalidate(self, pattern: str) -> int:
        """
        Invalidate cache entries matching a pattern.

        Args:
            pattern: Key pattern (supports the * wildcard)

        Returns:
            Number of keys deleted

        Raises:
            CacheError: If invalidation fails
        """
        if not self.is_enabled:
            return 0

        full_pattern = self._cache_key(pattern)
        deleted = 0
        try:
            async for key in self._redis.scan_iter(match=full_pattern):  # type: ignore
                await self._redis.delete(key)  # type: ignore
                deleted += 1
            logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
        except Exception as e:
            logger.warning(f"Cache invalidation error: {e}")
            raise CacheError(f"Failed to invalidate cache: {e}") from e
        return deleted

    async def clear_all(self) -> int:
        """
        Clear all context cache entries.

        Returns:
            Number of Redis keys deleted
        """
        self._memory_cache.clear()
        return await self.invalidate("*")

    def _set_memory(self, key: str, value: str) -> None:
        """
        Set a value in the memory cache.

        When the cache is full, evicts the oldest half of the entries
        (by insertion time) to make room. Note this is insertion-order
        eviction, not true LRU: reads do not refresh an entry's timestamp.

        Args:
            key: Cache key
            value: Value to store
        """
        if len(self._memory_cache) >= self._max_memory_items:
            # Evict the oldest half of the entries
            sorted_keys = sorted(
                self._memory_cache.keys(),
                key=lambda k: self._memory_cache[k][1],
            )
            for k in sorted_keys[: len(sorted_keys) // 2]:
                del self._memory_cache[k]
        self._memory_cache[key] = (value, time.time())
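    # Invalidation sketch: patterns are glob-style and are scoped under the
    # configured prefix by _cache_key, so they match only this cache's keys.
    # (Illustrative pattern values; adjust to your key layout.)
    #
    #     await cache.invalidate("assembled:*")     # drop all assembled contexts
    #     await cache.invalidate("tokens:gpt-4:*")  # drop gpt-4 token counts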
    async def get_stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats
        """
        stats: dict[str, Any] = {
            "enabled": self._settings.cache_enabled,
            "redis_available": self._redis is not None,
            "memory_items": len(self._memory_cache),
            "ttl_seconds": self._ttl,
        }
        if self.is_enabled:
            try:
                # Get Redis memory info
                info = await self._redis.info("memory")  # type: ignore
                stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
            except Exception as e:
                logger.debug(f"Failed to get Redis stats: {e}")
        return stats
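
# End-to-end usage sketch (assumes an async entry point and a configured
# `redis.asyncio.Redis` client; `contexts`, `assemble`, and the model name
# are hypothetical placeholders, not part of this module):
#
#     cache = ContextCache(redis=redis_client)
#     fp = cache.compute_fingerprint(contexts, query, "gpt-4",
#                                    project_id=pid, agent_id=aid)
#     result = await cache.get_assembled(fp)
#     if result is None:
#         result = assemble(contexts, query)  # hypothetical assembler
#         await cache.set_assembled(fp, result)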