forked from cardosofelipe/pragma-stack
feat(context): implement Redis-based caching layer (#84)
Phase 6 of Context Management Engine - Caching Layer: - Add ContextCache with Redis integration - Support fingerprint-based assembled context caching - Support token count caching (model-specific) - Support score caching (scorer + context + query) - Add in-memory fallback with LRU eviction - Add cache invalidation with pattern matching - Add cache statistics reporting Key features: - Hierarchical cache key structure (ctx:type:hash) - Automatic TTL expiration - Memory cache for fast repeated access - Graceful degradation when Redis unavailable Tests: 29 new tests, 285 total context tests passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,9 @@ from .adapters import (
|
|||||||
OpenAIAdapter,
|
OpenAIAdapter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
from .cache import ContextCache
|
||||||
|
|
||||||
# Prioritization
|
# Prioritization
|
||||||
from .prioritization import (
|
from .prioritization import (
|
||||||
ContextRanker,
|
ContextRanker,
|
||||||
@@ -132,6 +135,8 @@ __all__ = [
|
|||||||
"BudgetAllocator",
|
"BudgetAllocator",
|
||||||
"TokenBudget",
|
"TokenBudget",
|
||||||
"TokenCalculator",
|
"TokenCalculator",
|
||||||
|
# Cache
|
||||||
|
"ContextCache",
|
||||||
# Compression
|
# Compression
|
||||||
"ContextCompressor",
|
"ContextCompressor",
|
||||||
"TruncationResult",
|
"TruncationResult",
|
||||||
|
|||||||
@@ -3,3 +3,9 @@ Context Cache Module.
|
|||||||
|
|
||||||
Provides Redis-based caching for assembled contexts.
|
Provides Redis-based caching for assembled contexts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .context_cache import ContextCache
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextCache",
|
||||||
|
]
|
||||||
|
|||||||
417
backend/app/services/context/cache/context_cache.py
vendored
Normal file
417
backend/app/services/context/cache/context_cache.py
vendored
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
"""
|
||||||
|
Context Cache Implementation.
|
||||||
|
|
||||||
|
Provides Redis-based caching for context operations including
|
||||||
|
assembled contexts, token counts, and scoring results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import CacheError
|
||||||
|
from ..types import AssembledContext, BaseContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCache:
    """
    Redis-based caching for context operations.

    Provides caching for:
    - Assembled contexts (fingerprint-based)
    - Token counts (content hash-based)
    - Scoring results (context + query hash-based)

    Cache keys use a hierarchical structure:
    - ctx:assembled:{fingerprint}
    - ctx:tokens:{model}:{content_hash}
    - ctx:score:{scorer}:{context_hash}:{query_hash}

    Error policy (as exercised by the test suite): assembled-context
    operations raise ``CacheError`` on Redis failures, while token-count
    and score operations log a warning and degrade gracefully, falling
    back to a bounded in-memory cache.
    """

    def __init__(
        self,
        redis: "Redis | None" = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize the context cache.

        Args:
            redis: Redis connection (optional for testing)
            settings: Cache settings; defaults to get_context_settings()
        """
        self._redis = redis
        self._settings = settings or get_context_settings()
        self._prefix = self._settings.cache_prefix
        self._ttl = self._settings.cache_ttl_seconds

        # In-memory fallback cache when Redis unavailable.
        # Maps key -> (stringified value, insertion timestamp); the
        # timestamp drives oldest-first eviction in _set_memory.
        # NOTE(review): memory entries never expire by TTL, only by
        # eviction when _max_memory_items is reached — confirm intended.
        self._memory_cache: dict[str, tuple[str, float]] = {}
        self._max_memory_items = 1000

    def set_redis(self, redis: "Redis") -> None:
        """Set Redis connection."""
        self._redis = redis

    @property
    def is_enabled(self) -> bool:
        """Check if caching is enabled and available."""
        return self._settings.cache_enabled and self._redis is not None

    def _cache_key(self, *parts: str) -> str:
        """
        Build a cache key from parts.

        Args:
            *parts: Key components

        Returns:
            Colon-separated cache key, prefixed with the configured prefix
        """
        return f"{self._prefix}:{':'.join(parts)}"

    @staticmethod
    def _hash_content(content: str) -> str:
        """
        Compute hash of content for cache key.

        Args:
            content: Content to hash

        Returns:
            32-character hex hash (truncated SHA-256)
        """
        return hashlib.sha256(content.encode()).hexdigest()[:32]

    def compute_fingerprint(
        self,
        contexts: list[BaseContext],
        query: str,
        model: str,
    ) -> str:
        """
        Compute a fingerprint for a context assembly request.

        The fingerprint is based on:
        - Context content and metadata (type, content, source, priority)
        - Query string
        - Target model

        Args:
            contexts: List of contexts
            query: Query string
            model: Model name

        Returns:
            32-character hex fingerprint
        """
        # Build a deterministic representation of each context.
        context_data = [
            {
                "type": ctx.get_type().value,
                "content": ctx.content,
                "source": ctx.source,
                "priority": ctx.priority,  # Already an int
            }
            for ctx in contexts
        ]

        data = {
            "contexts": context_data,
            "query": query,
            "model": model,
        }

        # sort_keys canonicalizes the JSON so equal inputs hash equally.
        content = json.dumps(data, sort_keys=True)
        return self._hash_content(content)

    async def get_assembled(
        self,
        fingerprint: str,
    ) -> AssembledContext | None:
        """
        Get cached assembled context by fingerprint.

        Args:
            fingerprint: Assembly fingerprint

        Returns:
            Cached AssembledContext (with cache_hit/cache_key set) or
            None if not found or caching is disabled

        Raises:
            CacheError: If the Redis lookup or deserialization fails.
        """
        if not self.is_enabled:
            return None

        key = self._cache_key("assembled", fingerprint)

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                logger.debug(f"Cache hit for assembled context: {fingerprint}")
                result = AssembledContext.from_json(data)
                # Annotate the result so callers can tell it came from cache.
                result.cache_hit = True
                result.cache_key = fingerprint
                return result
        except Exception as e:
            logger.warning(f"Cache get error: {e}")
            raise CacheError(f"Failed to get assembled context: {e}") from e

        return None

    async def set_assembled(
        self,
        fingerprint: str,
        context: AssembledContext,
        ttl: int | None = None,
    ) -> None:
        """
        Cache an assembled context.

        Args:
            fingerprint: Assembly fingerprint
            context: Assembled context to cache
            ttl: Optional TTL override in seconds

        Raises:
            CacheError: If the Redis write fails.
        """
        if not self.is_enabled:
            return

        key = self._cache_key("assembled", fingerprint)
        # Only fall back to the default TTL when no override was given;
        # `ttl or self._ttl` would silently discard an explicit ttl=0.
        expire = ttl if ttl is not None else self._ttl

        try:
            await self._redis.setex(key, expire, context.to_json())  # type: ignore
            logger.debug(f"Cached assembled context: {fingerprint}")
        except Exception as e:
            logger.warning(f"Cache set error: {e}")
            raise CacheError(f"Failed to cache assembled context: {e}") from e

    async def get_token_count(
        self,
        content: str,
        model: str | None = None,
    ) -> int | None:
        """
        Get cached token count.

        Checks the in-memory cache first, then Redis. Redis errors are
        logged and treated as a miss (graceful degradation).

        Args:
            content: Content to look up
            model: Model name for model-specific tokenization

        Returns:
            Cached token count or None if not found
        """
        model_key = model or "default"
        content_hash = self._hash_content(content)
        key = self._cache_key("tokens", model_key, content_hash)

        # Try in-memory first
        if key in self._memory_cache:
            return int(self._memory_cache[key][0])

        if not self.is_enabled:
            return None

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                count = int(data)
                # Store in memory for faster subsequent access
                self._set_memory(key, str(count))
                return count
        except Exception as e:
            logger.warning(f"Cache get error for tokens: {e}")

        return None

    async def set_token_count(
        self,
        content: str,
        count: int,
        model: str | None = None,
        ttl: int | None = None,
    ) -> None:
        """
        Cache a token count.

        Always written to the in-memory cache; additionally written to
        Redis when enabled. Redis errors are logged, not raised.

        Args:
            content: Content that was counted
            count: Token count
            model: Model name
            ttl: Optional TTL override in seconds
        """
        model_key = model or "default"
        content_hash = self._hash_content(content)
        key = self._cache_key("tokens", model_key, content_hash)
        # Only fall back to the default TTL when no override was given.
        expire = ttl if ttl is not None else self._ttl

        # Always store in memory
        self._set_memory(key, str(count))

        if not self.is_enabled:
            return

        try:
            await self._redis.setex(key, expire, str(count))  # type: ignore
        except Exception as e:
            logger.warning(f"Cache set error for tokens: {e}")

    async def get_score(
        self,
        scorer_name: str,
        context_id: str,
        query: str,
    ) -> float | None:
        """
        Get cached score.

        Checks the in-memory cache first, then Redis. Redis errors are
        logged and treated as a miss (graceful degradation).

        Args:
            scorer_name: Name of the scorer
            context_id: Context identifier
            query: Query string

        Returns:
            Cached score or None if not found
        """
        # Scores use a shorter 16-char query hash to keep keys compact.
        query_hash = self._hash_content(query)[:16]
        key = self._cache_key("score", scorer_name, context_id, query_hash)

        # Try in-memory first
        if key in self._memory_cache:
            return float(self._memory_cache[key][0])

        if not self.is_enabled:
            return None

        try:
            data = await self._redis.get(key)  # type: ignore
            if data:
                score = float(data)
                self._set_memory(key, str(score))
                return score
        except Exception as e:
            logger.warning(f"Cache get error for score: {e}")

        return None

    async def set_score(
        self,
        scorer_name: str,
        context_id: str,
        query: str,
        score: float,
        ttl: int | None = None,
    ) -> None:
        """
        Cache a score.

        Always written to the in-memory cache; additionally written to
        Redis when enabled. Redis errors are logged, not raised.

        Args:
            scorer_name: Name of the scorer
            context_id: Context identifier
            query: Query string
            score: Score value
            ttl: Optional TTL override in seconds
        """
        query_hash = self._hash_content(query)[:16]
        key = self._cache_key("score", scorer_name, context_id, query_hash)
        # Only fall back to the default TTL when no override was given.
        expire = ttl if ttl is not None else self._ttl

        # Always store in memory
        self._set_memory(key, str(score))

        if not self.is_enabled:
            return

        try:
            await self._redis.setex(key, expire, str(score))  # type: ignore
        except Exception as e:
            logger.warning(f"Cache set error for score: {e}")

    async def invalidate(self, pattern: str) -> int:
        """
        Invalidate cache entries matching a pattern.

        Args:
            pattern: Key pattern (supports * wildcard), relative to the
                cache prefix (e.g. "assembled:*")

        Returns:
            Number of keys deleted (0 when caching is disabled)

        Raises:
            CacheError: If the Redis scan/delete fails.
        """
        if not self.is_enabled:
            return 0

        full_pattern = self._cache_key(pattern)
        deleted = 0

        try:
            # Keys are deleted one at a time so partial progress is
            # reflected in the returned count even if iteration stops.
            async for key in self._redis.scan_iter(match=full_pattern):  # type: ignore
                await self._redis.delete(key)  # type: ignore
                deleted += 1

            logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
        except Exception as e:
            logger.warning(f"Cache invalidation error: {e}")
            raise CacheError(f"Failed to invalidate cache: {e}") from e

        return deleted

    async def clear_all(self) -> int:
        """
        Clear all context cache entries.

        Clears the in-memory cache unconditionally, then invalidates all
        Redis keys under the cache prefix.

        Returns:
            Number of Redis keys deleted
        """
        self._memory_cache.clear()
        return await self.invalidate("*")

    def _set_memory(self, key: str, value: str) -> None:
        """
        Set a value in the memory cache.

        Uses LRU-style eviction when max items reached: the oldest half
        of the entries (by insertion time) is dropped.

        Args:
            key: Cache key
            value: Value to store
        """
        import time

        if len(self._memory_cache) >= self._max_memory_items:
            # Evict the oldest half in one pass to amortize sort cost.
            sorted_keys = sorted(
                self._memory_cache.keys(),
                key=lambda k: self._memory_cache[k][1],
            )
            for k in sorted_keys[: len(sorted_keys) // 2]:
                del self._memory_cache[k]

        self._memory_cache[key] = (value, time.time())

    async def get_stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats: "enabled", "redis_available",
            "memory_items", "ttl_seconds", and "redis_memory_used" when
            Redis INFO is reachable.
        """
        stats = {
            "enabled": self._settings.cache_enabled,
            "redis_available": self._redis is not None,
            "memory_items": len(self._memory_cache),
            "ttl_seconds": self._ttl,
        }

        if self.is_enabled:
            try:
                # Get Redis info; stats remain usable if INFO fails.
                info = await self._redis.info("memory")  # type: ignore
                stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
            except Exception:
                pass

        return stats
||||||
479
backend/tests/services/context/test_cache.py
Normal file
479
backend/tests/services/context/test_cache.py
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
"""Tests for context cache module."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.cache import ContextCache
|
||||||
|
from app.services.context.config import ContextSettings
|
||||||
|
from app.services.context.exceptions import CacheError
|
||||||
|
from app.services.context.types import (
|
||||||
|
AssembledContext,
|
||||||
|
ContextPriority,
|
||||||
|
KnowledgeContext,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextCacheBasics:
    """Basic tests for ContextCache."""

    def test_creation(self) -> None:
        """A cache built with no Redis connection is disabled."""
        cache = ContextCache()
        assert cache._redis is None
        assert not cache.is_enabled

    def test_creation_with_settings(self) -> None:
        """Custom settings propagate to prefix and TTL."""
        custom = ContextSettings(cache_prefix="test", cache_ttl_seconds=60)
        cache = ContextCache(settings=custom)
        assert cache._prefix == "test"
        assert cache._ttl == 60

    def test_set_redis(self) -> None:
        """set_redis wires the connection onto the cache."""
        redis_stub = MagicMock()
        cache = ContextCache()
        cache.set_redis(redis_stub)
        assert cache._redis is redis_stub

    def test_is_enabled(self) -> None:
        """is_enabled requires both the setting and a connection."""
        cache = ContextCache(settings=ContextSettings(cache_enabled=True))
        assert not cache.is_enabled  # No Redis

        cache.set_redis(MagicMock())
        assert cache.is_enabled

        # Disabled in settings
        disabled = ContextCache(
            redis=MagicMock(),
            settings=ContextSettings(cache_enabled=False),
        )
        assert not disabled.is_enabled

    def test_cache_key(self) -> None:
        """Keys are colon-joined under the configured prefix."""
        assert ContextCache()._cache_key("assembled", "abc123") == "ctx:assembled:abc123"

    def test_hash_content(self) -> None:
        """Hashing is deterministic, content-sensitive, 32 chars."""
        first = ContextCache._hash_content("hello world")
        second = ContextCache._hash_content("hello world")
        other = ContextCache._hash_content("different")

        assert first == second
        assert first != other
        assert len(first) == 32
|
class TestFingerprintComputation:
    """Tests for fingerprint computation."""

    def test_compute_fingerprint(self) -> None:
        """Equal inputs produce equal 32-char fingerprints."""
        cache = ContextCache()
        ctx_list = [
            SystemContext(content="System", source="system"),
            TaskContext(content="Task", source="task"),
        ]

        fp_a = cache.compute_fingerprint(ctx_list, "query", "claude-3")
        fp_b = cache.compute_fingerprint(ctx_list, "query", "claude-3")
        fp_other = cache.compute_fingerprint(ctx_list, "different", "claude-3")

        assert fp_a == fp_b  # Same inputs = same fingerprint
        assert fp_a != fp_other  # Different query = different fingerprint
        assert len(fp_a) == 32

    def test_fingerprint_includes_priority(self) -> None:
        """Changing only the priority changes the fingerprint."""
        cache = ContextCache()

        # Use KnowledgeContext since SystemContext has __post_init__ that may override
        normal = [
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                priority=ContextPriority.NORMAL.value,
            )
        ]
        high = [
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                priority=ContextPriority.HIGH.value,
            )
        ]

        assert cache.compute_fingerprint(normal, "query", "claude-3") != cache.compute_fingerprint(
            high, "query", "claude-3"
        )

    def test_fingerprint_includes_model(self) -> None:
        """Changing only the model changes the fingerprint."""
        cache = ContextCache()
        ctx_list = [SystemContext(content="System", source="system")]

        claude_fp = cache.compute_fingerprint(ctx_list, "query", "claude-3")
        gpt_fp = cache.compute_fingerprint(ctx_list, "query", "gpt-4")

        assert claude_fp != gpt_fp
|
class TestMemoryCache:
    """Tests for in-memory caching."""

    def test_memory_cache_fallback(self) -> None:
        """Without Redis, values land in the memory dict."""
        cache = ContextCache()

        # Should use memory cache
        cache._set_memory("test-key", "42")

        assert "test-key" in cache._memory_cache
        stored_value, _timestamp = cache._memory_cache["test-key"]
        assert stored_value == "42"

    def test_memory_cache_eviction(self) -> None:
        """Overfilling the memory cache triggers eviction."""
        cache = ContextCache()
        cache._max_memory_items = 10

        # Fill past capacity
        for idx in range(15):
            cache._set_memory(f"key-{idx}", f"value-{idx}")

        # Should have evicted some items
        assert len(cache._memory_cache) < 15
|
class TestAssembledContextCache:
    """Tests for assembled context caching."""

    @pytest.mark.asyncio
    async def test_get_assembled_no_redis(self) -> None:
        """Without Redis, get_assembled is a miss."""
        assert await ContextCache().get_assembled("fingerprint") is None

    @pytest.mark.asyncio
    async def test_get_assembled_not_found(self) -> None:
        """A Redis miss returns None."""
        redis_stub = AsyncMock()
        redis_stub.get.return_value = None
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        assert await cache.get_assembled("fingerprint") is None

    @pytest.mark.asyncio
    async def test_get_assembled_found(self) -> None:
        """A Redis hit deserializes and flags the cache hit."""
        stored = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=2,
        )

        redis_stub = AsyncMock()
        redis_stub.get.return_value = stored.to_json()
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        found = await cache.get_assembled("fingerprint")

        assert found is not None
        assert found.content == "Test content"
        assert found.total_tokens == 100
        assert found.cache_hit is True
        assert found.cache_key == "fingerprint"

    @pytest.mark.asyncio
    async def test_set_assembled(self) -> None:
        """set_assembled writes via setex with key and default TTL."""
        redis_stub = AsyncMock()
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True, cache_ttl_seconds=60),
        )
        payload = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=2,
        )

        await cache.set_assembled("fingerprint", payload)

        redis_stub.setex.assert_called_once()
        positional = redis_stub.setex.call_args[0]
        assert positional[0] == "ctx:assembled:fingerprint"
        assert positional[1] == 60  # TTL

    @pytest.mark.asyncio
    async def test_set_assembled_custom_ttl(self) -> None:
        """An explicit TTL overrides the configured default."""
        redis_stub = AsyncMock()
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )
        payload = AssembledContext(content="Test", total_tokens=10, context_count=1)

        await cache.set_assembled("fp", payload, ttl=120)

        assert redis_stub.setex.call_args[0][1] == 120

    @pytest.mark.asyncio
    async def test_cache_error_on_get(self) -> None:
        """Redis read failures surface as CacheError."""
        redis_stub = AsyncMock()
        redis_stub.get.side_effect = Exception("Redis error")
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        with pytest.raises(CacheError):
            await cache.get_assembled("fingerprint")

    @pytest.mark.asyncio
    async def test_cache_error_on_set(self) -> None:
        """Redis write failures surface as CacheError."""
        redis_stub = AsyncMock()
        redis_stub.setex.side_effect = Exception("Redis error")
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )
        payload = AssembledContext(content="Test", total_tokens=10, context_count=1)

        with pytest.raises(CacheError):
            await cache.set_assembled("fp", payload)
|
class TestTokenCountCache:
    """Tests for token count caching."""

    @pytest.mark.asyncio
    async def test_get_token_count_memory_fallback(self) -> None:
        """A memory-cached count is returned without Redis."""
        cache = ContextCache()

        # Seed the in-memory cache directly under the real key
        memory_key = cache._cache_key("tokens", "default", cache._hash_content("hello"))
        cache._set_memory(memory_key, "42")

        assert await cache.get_token_count("hello") == 42

    @pytest.mark.asyncio
    async def test_set_token_count_memory(self) -> None:
        """set_token_count round-trips through the memory cache."""
        cache = ContextCache()

        await cache.set_token_count("hello", 42)

        assert await cache.get_token_count("hello") == 42

    @pytest.mark.asyncio
    async def test_set_token_count_with_model(self) -> None:
        """Per-model counts are stored under distinct keys."""
        redis_stub = AsyncMock()
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        await cache.set_token_count("hello", 42, model="claude-3")
        await cache.set_token_count("hello", 50, model="gpt-4")

        # Different models should have different keys
        assert redis_stub.setex.call_count == 2
        recorded = redis_stub.setex.call_args_list
        assert "claude-3" in recorded[0][0][0]
        assert "gpt-4" in recorded[1][0][0]
|
class TestScoreCache:
    """Tests for score caching."""

    @pytest.mark.asyncio
    async def test_get_score_memory_fallback(self) -> None:
        """A memory-cached score is returned without Redis."""
        cache = ContextCache()

        # Seed the in-memory cache under the real score key
        truncated_hash = cache._hash_content("query")[:16]
        memory_key = cache._cache_key("score", "relevance", "ctx-123", truncated_hash)
        cache._set_memory(memory_key, "0.85")

        assert await cache.get_score("relevance", "ctx-123", "query") == 0.85

    @pytest.mark.asyncio
    async def test_set_score_memory(self) -> None:
        """set_score round-trips through the memory cache."""
        cache = ContextCache()

        await cache.set_score("relevance", "ctx-123", "query", 0.85)

        assert await cache.get_score("relevance", "ctx-123", "query") == 0.85

    @pytest.mark.asyncio
    async def test_set_score_with_redis(self) -> None:
        """With Redis enabled, set_score writes via setex."""
        redis_stub = AsyncMock()
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        await cache.set_score("relevance", "ctx-123", "query", 0.85)

        redis_stub.setex.assert_called_once()
|
class TestCacheInvalidation:
    """Tests for cache invalidation."""

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self) -> None:
        """invalidate deletes each key that scan_iter yields."""
        redis_stub = AsyncMock()

        # Pretend Redis scan finds two matching keys
        async def fake_scan_iter(match=None):
            for matched in ["ctx:assembled:1", "ctx:assembled:2"]:
                yield matched

        redis_stub.scan_iter = fake_scan_iter
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        removed = await cache.invalidate("assembled:*")

        assert removed == 2
        assert redis_stub.delete.call_count == 2

    @pytest.mark.asyncio
    async def test_clear_all(self) -> None:
        """clear_all empties memory and deletes every Redis key."""
        redis_stub = AsyncMock()

        async def fake_scan_iter(match=None):
            for matched in ["ctx:1", "ctx:2", "ctx:3"]:
                yield matched

        redis_stub.scan_iter = fake_scan_iter
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        # Put something in the memory cache first
        cache._set_memory("test", "value")
        assert len(cache._memory_cache) > 0

        removed = await cache.clear_all()

        assert removed == 3
        assert len(cache._memory_cache) == 0
|
class TestCacheStats:
    """Tests for cache statistics."""

    @pytest.mark.asyncio
    async def test_get_stats_no_redis(self) -> None:
        """Stats report no Redis but count memory items."""
        cache = ContextCache()
        cache._set_memory("key", "value")

        report = await cache.get_stats()

        assert report["enabled"] is True
        assert report["redis_available"] is False
        assert report["memory_items"] == 1

    @pytest.mark.asyncio
    async def test_get_stats_with_redis(self) -> None:
        """Stats include Redis memory usage when available."""
        redis_stub = AsyncMock()
        redis_stub.info.return_value = {"used_memory_human": "1.5M"}
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True, cache_ttl_seconds=300),
        )

        report = await cache.get_stats()

        assert report["enabled"] is True
        assert report["redis_available"] is True
        assert report["ttl_seconds"] == 300
        assert report["redis_memory_used"] == "1.5M"
|
class TestCacheIntegration:
    """Integration tests for cache."""

    @pytest.mark.asyncio
    async def test_full_workflow(self) -> None:
        """Fingerprint -> miss -> store -> hit, end to end."""
        redis_stub = AsyncMock()
        redis_stub.get.return_value = None
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        ctx_list = [
            SystemContext(content="System", source="system"),
            KnowledgeContext(content="Knowledge", source="docs"),
        ]

        # Compute fingerprint
        fingerprint = cache.compute_fingerprint(ctx_list, "query", "claude-3")
        assert len(fingerprint) == 32

        # First lookup is a miss
        assert await cache.get_assembled(fingerprint) is None

        # Create and cache the assembled context
        assembled = AssembledContext(
            content="Assembled content",
            total_tokens=100,
            context_count=2,
            model="claude-3",
        )
        await cache.set_assembled(fingerprint, assembled)

        # Verify setex was called
        redis_stub.setex.assert_called_once()

        # Simulate a Redis hit on the next lookup
        redis_stub.get.return_value = assembled.to_json()
        found = await cache.get_assembled(fingerprint)

        assert found is not None
        assert found.cache_hit is True
        assert found.content == "Assembled content"
Reference in New Issue
Block a user