"""Unit tests for the context cache (``app.services.context.cache``)."""

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.services.context.cache import ContextCache
from app.services.context.config import ContextSettings
from app.services.context.exceptions import CacheError
from app.services.context.types import (
    AssembledContext,
    ContextPriority,
    KnowledgeContext,
    SystemContext,
    TaskContext,
)


def _enabled_cache(redis: AsyncMock, **settings_kwargs) -> ContextCache:
    """Build a ContextCache around ``redis`` with ``cache_enabled=True``.

    Extra keyword arguments are forwarded unchanged to ``ContextSettings``.
    """
    return ContextCache(
        redis=redis,
        settings=ContextSettings(cache_enabled=True, **settings_kwargs),
    )


def _scan_iter_yielding(keys):
    """Return an async-generator stand-in for ``scan_iter`` that yields ``keys``."""

    async def mock_scan_iter(match=None):
        for key in keys:
            yield key

    return mock_scan_iter


class TestContextCacheBasics:
    """Construction, configuration and key/hash helpers of ContextCache."""

    def test_creation(self) -> None:
        """A default cache has no Redis client and reports itself disabled."""
        default_cache = ContextCache()

        assert default_cache._redis is None
        assert not default_cache.is_enabled

    def test_creation_with_settings(self) -> None:
        """Prefix and TTL are taken from the supplied settings."""
        custom_cache = ContextCache(
            settings=ContextSettings(cache_prefix="test", cache_ttl_seconds=60)
        )

        assert custom_cache._prefix == "test"
        assert custom_cache._ttl == 60

    def test_set_redis(self) -> None:
        """set_redis stores the very object it was given."""
        redis_client = MagicMock()
        target = ContextCache()

        target.set_redis(redis_client)

        assert target._redis is redis_client

    def test_is_enabled(self) -> None:
        """is_enabled needs both a Redis client and cache_enabled=True."""
        enabled_cache = ContextCache(settings=ContextSettings(cache_enabled=True))
        # Enabled in settings, but no Redis client attached yet.
        assert not enabled_cache.is_enabled

        enabled_cache.set_redis(MagicMock())
        assert enabled_cache.is_enabled

        # Redis client present, but caching switched off in settings.
        disabled_cache = ContextCache(
            redis=MagicMock(),
            settings=ContextSettings(cache_enabled=False),
        )
        assert not disabled_cache.is_enabled

    def test_cache_key(self) -> None:
        """_cache_key joins the default prefix and the parts with colons."""
        generated = ContextCache()._cache_key("assembled", "abc123")

        assert generated == "ctx:assembled:abc123"

    def test_hash_content(self) -> None:
        """_hash_content is deterministic, input-sensitive and 32 chars long."""
        first = ContextCache._hash_content("hello world")
        repeat = ContextCache._hash_content("hello world")
        other = ContextCache._hash_content("different")

        assert first == repeat
        assert first != other
        assert len(first) == 32


class TestFingerprintComputation:
    """Behaviour of ContextCache.compute_fingerprint."""

    def test_compute_fingerprint(self) -> None:
        """Equal inputs give equal fingerprints; a new query changes it."""
        fingerprinter = ContextCache()
        inputs = [
            SystemContext(content="System", source="system"),
            TaskContext(content="Task", source="task"),
        ]

        baseline = fingerprinter.compute_fingerprint(inputs, "query", "claude-3")
        repeated = fingerprinter.compute_fingerprint(inputs, "query", "claude-3")
        other_query = fingerprinter.compute_fingerprint(
            inputs, "different", "claude-3"
        )

        assert baseline == repeated  # identical inputs -> identical fingerprint
        assert baseline != other_query  # query text is part of the fingerprint
        assert len(baseline) == 32

    def test_fingerprint_includes_priority(self) -> None:
        """Two contexts differing only in priority fingerprint differently."""
        fingerprinter = ContextCache()

        # KnowledgeContext is used on purpose: per the original author's note,
        # SystemContext has a __post_init__ that may override the priority.
        normal_priority = [
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                priority=ContextPriority.NORMAL.value,
            )
        ]
        high_priority = [
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                priority=ContextPriority.HIGH.value,
            )
        ]

        normal_fp = fingerprinter.compute_fingerprint(
            normal_priority, "query", "claude-3"
        )
        high_fp = fingerprinter.compute_fingerprint(
            high_priority, "query", "claude-3"
        )

        assert normal_fp != high_fp

    def test_fingerprint_includes_model(self) -> None:
        """Changing only the model name changes the fingerprint."""
        fingerprinter = ContextCache()
        inputs = [SystemContext(content="System", source="system")]

        claude_fp = fingerprinter.compute_fingerprint(inputs, "query", "claude-3")
        gpt_fp = fingerprinter.compute_fingerprint(inputs, "query", "gpt-4")

        assert claude_fp != gpt_fp


class TestMemoryCache:
    """The in-process memory cache exercised through _set_memory."""

    def test_memory_cache_fallback(self) -> None:
        """_set_memory stores the value as element 0 of the cached entry."""
        local_cache = ContextCache()

        # No Redis configured, so the value goes to the in-memory store.
        local_cache._set_memory("test-key", "42")

        assert "test-key" in local_cache._memory_cache
        assert local_cache._memory_cache["test-key"][0] == "42"

    def test_memory_cache_eviction(self) -> None:
        """Writing past _max_memory_items drops at least one entry."""
        local_cache = ContextCache()
        local_cache._max_memory_items = 10

        # Insert more entries than the configured maximum.
        for index in range(15):
            local_cache._set_memory(f"key-{index}", f"value-{index}")

        # Fewer than the 15 written entries may remain.
        assert len(local_cache._memory_cache) < 15


class TestAssembledContextCache:
    """get_assembled / set_assembled against a mocked Redis client."""

    @pytest.mark.asyncio
    async def test_get_assembled_no_redis(self) -> None:
        """Without Redis, get_assembled yields None."""
        lookup = await ContextCache().get_assembled("fingerprint")

        assert lookup is None

    @pytest.mark.asyncio
    async def test_get_assembled_not_found(self) -> None:
        """A Redis miss (get returns None) yields None."""
        mock_redis = AsyncMock()
        mock_redis.get.return_value = None
        cache = _enabled_cache(mock_redis)

        lookup = await cache.get_assembled("fingerprint")

        assert lookup is None

    @pytest.mark.asyncio
    async def test_get_assembled_found(self) -> None:
        """A Redis hit is deserialised and flagged as a cache hit."""
        # Serialised form of this context is what the mocked Redis returns.
        stored = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=2,
        )
        mock_redis = AsyncMock()
        mock_redis.get.return_value = stored.to_json()
        cache = _enabled_cache(mock_redis)

        loaded = await cache.get_assembled("fingerprint")

        assert loaded is not None
        assert loaded.content == "Test content"
        assert loaded.total_tokens == 100
        assert loaded.cache_hit is True
        assert loaded.cache_key == "fingerprint"

    @pytest.mark.asyncio
    async def test_set_assembled(self) -> None:
        """set_assembled calls setex once with the prefixed key and settings TTL."""
        mock_redis = AsyncMock()
        cache = _enabled_cache(mock_redis, cache_ttl_seconds=60)
        payload = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=2,
        )

        await cache.set_assembled("fingerprint", payload)

        mock_redis.setex.assert_called_once()
        positional = mock_redis.setex.call_args.args
        assert positional[0] == "ctx:assembled:fingerprint"
        assert positional[1] == 60  # TTL taken from settings

    @pytest.mark.asyncio
    async def test_set_assembled_custom_ttl(self) -> None:
        """An explicit ttl argument is what gets passed to setex."""
        mock_redis = AsyncMock()
        cache = _enabled_cache(mock_redis)
        payload = AssembledContext(
            content="Test",
            total_tokens=10,
            context_count=1,
        )

        await cache.set_assembled("fp", payload, ttl=120)

        positional = mock_redis.setex.call_args.args
        assert positional[1] == 120

    @pytest.mark.asyncio
    async def test_cache_error_on_get(self) -> None:
        """An exception from redis.get surfaces as CacheError."""
        mock_redis = AsyncMock()
        mock_redis.get.side_effect = Exception("Redis error")
        cache = _enabled_cache(mock_redis)

        with pytest.raises(CacheError):
            await cache.get_assembled("fingerprint")

    @pytest.mark.asyncio
    async def test_cache_error_on_set(self) -> None:
        """An exception from redis.setex surfaces as CacheError."""
        mock_redis = AsyncMock()
        mock_redis.setex.side_effect = Exception("Redis error")
        cache = _enabled_cache(mock_redis)
        payload = AssembledContext(
            content="Test",
            total_tokens=10,
            context_count=1,
        )

        with pytest.raises(CacheError):
            await cache.set_assembled("fp", payload)


class TestTokenCountCache:
    """get_token_count / set_token_count."""

    @pytest.mark.asyncio
    async def test_get_token_count_memory_fallback(self) -> None:
        """A value seeded in the memory cache is returned as an int."""
        cache = ContextCache()

        # Seed the memory cache under the key built from "tokens", "default"
        # and the content hash.
        content_hash = cache._hash_content("hello")
        memory_key = cache._cache_key("tokens", "default", content_hash)
        cache._set_memory(memory_key, "42")

        count = await cache.get_token_count("hello")

        assert count == 42

    @pytest.mark.asyncio
    async def test_set_token_count_memory(self) -> None:
        """A count written without Redis can be read back."""
        cache = ContextCache()

        await cache.set_token_count("hello", 42)
        count = await cache.get_token_count("hello")

        assert count == 42

    @pytest.mark.asyncio
    async def test_set_token_count_with_model(self) -> None:
        """The model name is embedded in the Redis key for token counts."""
        mock_redis = AsyncMock()
        cache = _enabled_cache(mock_redis)

        await cache.set_token_count("hello", 42, model="claude-3")
        await cache.set_token_count("hello", 50, model="gpt-4")

        # One setex per model, each under a model-specific key.
        assert mock_redis.setex.call_count == 2
        recorded = mock_redis.setex.call_args_list
        claude_key = recorded[0].args[0]
        gpt_key = recorded[1].args[0]
        assert "claude-3" in claude_key
        assert "gpt-4" in gpt_key


class TestScoreCache:
    """get_score / set_score."""

    @pytest.mark.asyncio
    async def test_get_score_memory_fallback(self) -> None:
        """A score seeded in the memory cache is returned as a float."""
        cache = ContextCache()

        # Key layout: "score", scorer name, context id, first 16 hash chars
        # of the query.
        short_query_hash = cache._hash_content("query")[:16]
        memory_key = cache._cache_key(
            "score", "relevance", "ctx-123", short_query_hash
        )
        cache._set_memory(memory_key, "0.85")

        score = await cache.get_score("relevance", "ctx-123", "query")

        assert score == 0.85

    @pytest.mark.asyncio
    async def test_set_score_memory(self) -> None:
        """A score written without Redis can be read back."""
        cache = ContextCache()

        await cache.set_score("relevance", "ctx-123", "query", 0.85)
        score = await cache.get_score("relevance", "ctx-123", "query")

        assert score == 0.85

    @pytest.mark.asyncio
    async def test_set_score_with_redis(self) -> None:
        """With Redis enabled, set_score issues exactly one setex."""
        mock_redis = AsyncMock()
        cache = _enabled_cache(mock_redis)

        await cache.set_score("relevance", "ctx-123", "query", 0.85)

        mock_redis.setex.assert_called_once()


class TestCacheInvalidation:
    """invalidate and clear_all."""

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self) -> None:
        """invalidate reports and deletes each key yielded by scan_iter."""
        mock_redis = AsyncMock()
        # scan_iter must be an async generator, so replace the AsyncMock
        # attribute with a real one yielding two matching keys.
        mock_redis.scan_iter = _scan_iter_yielding(
            ["ctx:assembled:1", "ctx:assembled:2"]
        )
        cache = _enabled_cache(mock_redis)

        removed = await cache.invalidate("assembled:*")

        assert removed == 2
        assert mock_redis.delete.call_count == 2

    @pytest.mark.asyncio
    async def test_clear_all(self) -> None:
        """clear_all reports the scanned keys and empties the memory cache."""
        mock_redis = AsyncMock()
        mock_redis.scan_iter = _scan_iter_yielding(["ctx:1", "ctx:2", "ctx:3"])
        cache = _enabled_cache(mock_redis)

        # Put something in the memory cache so its clearing is observable.
        cache._set_memory("test", "value")
        assert len(cache._memory_cache) > 0

        removed = await cache.clear_all()

        assert removed == 3
        assert len(cache._memory_cache) == 0


class TestCacheStats:
    """get_stats output with and without Redis."""

    @pytest.mark.asyncio
    async def test_get_stats_no_redis(self) -> None:
        """Stats without Redis report the memory item count."""
        cache = ContextCache()
        cache._set_memory("key", "value")

        report = await cache.get_stats()

        assert report["enabled"] is True
        assert report["redis_available"] is False
        assert report["memory_items"] == 1

    @pytest.mark.asyncio
    async def test_get_stats_with_redis(self) -> None:
        """Stats with Redis include the TTL and Redis memory usage."""
        mock_redis = AsyncMock()
        mock_redis.info.return_value = {"used_memory_human": "1.5M"}
        cache = _enabled_cache(mock_redis, cache_ttl_seconds=300)

        report = await cache.get_stats()

        assert report["enabled"] is True
        assert report["redis_available"] is True
        assert report["ttl_seconds"] == 300
        assert report["redis_memory_used"] == "1.5M"


class TestCacheIntegration:
    """End-to-end flow: fingerprint, miss, store, hit."""

    @pytest.mark.asyncio
    async def test_full_workflow(self) -> None:
        """Walk through a miss, a store and a subsequent hit."""
        mock_redis = AsyncMock()
        mock_redis.get.return_value = None
        cache = _enabled_cache(mock_redis)

        inputs = [
            SystemContext(content="System", source="system"),
            KnowledgeContext(content="Knowledge", source="docs"),
        ]

        # Step 1: fingerprint the inputs.
        fingerprint = cache.compute_fingerprint(inputs, "query", "claude-3")
        assert len(fingerprint) == 32

        # Step 2: first lookup misses because the mocked get returns None.
        first_lookup = await cache.get_assembled(fingerprint)
        assert first_lookup is None

        # Step 3: build an assembled context and store it.
        assembled = AssembledContext(
            content="Assembled content",
            total_tokens=100,
            context_count=2,
            model="claude-3",
        )
        await cache.set_assembled(fingerprint, assembled)

        # Step 4: the store must have gone through setex exactly once.
        mock_redis.setex.assert_called_once()

        # Step 5: simulate a hit by making get return the serialised context.
        mock_redis.get.return_value = assembled.to_json()
        second_lookup = await cache.get_assembled(fingerprint)

        assert second_lookup is not None
        assert second_lookup.cache_hit is True
        assert second_lookup.content == "Assembled content"