feat(memory): implement caching layer for memory operations (#98)
Add comprehensive caching layer for the Agent Memory System: - HotMemoryCache: LRU cache for frequently accessed memories - Python 3.12 type parameter syntax - Thread-safe operations with RLock - TTL-based expiration - Access count tracking for hot memory identification - Scoped invalidation by type, scope, or pattern - EmbeddingCache: Cache embeddings by content hash - Content-hash based deduplication - Optional Redis backing for persistence - LRU eviction with configurable max size - CachedEmbeddingGenerator wrapper for transparent caching - CacheManager: Unified cache management - Coordinates hot cache, embedding cache, and retrieval cache - Centralized invalidation across all caches - Aggregated statistics and hit rate tracking - Automatic cleanup scheduling - Cache warmup support Performance targets: - Cache hit rate > 80% for hot memories - Cache operations < 1ms (memory), < 5ms (Redis) 83 new tests with comprehensive coverage. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2
backend/tests/unit/services/memory/cache/__init__.py
vendored
Normal file
2
backend/tests/unit/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/cache/__init__.py
|
||||
"""Tests for memory caching layer."""
|
||||
331
backend/tests/unit/services/memory/cache/test_cache_manager.py
vendored
Normal file
331
backend/tests/unit/services/memory/cache/test_cache_manager.py
vendored
Normal file
@@ -0,0 +1,331 @@
|
||||
# tests/unit/services/memory/cache/test_cache_manager.py
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.cache_manager import (
|
||||
CacheManager,
|
||||
CacheStats,
|
||||
get_cache_manager,
|
||||
reset_cache_manager,
|
||||
)
|
||||
from app.services.memory.cache.embedding_cache import EmbeddingCache
|
||||
from app.services.memory.cache.hot_cache import HotMemoryCache
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_singleton() -> None:
    """Ensure every test starts with a fresh CacheManager singleton."""
    reset_cache_manager()
|
||||
|
||||
class TestCacheStats:
    """Tests for CacheStats."""

    def test_to_dict(self) -> None:
        """to_dict should expose every field, including the cleanup timestamp."""
        from datetime import UTC, datetime

        stats = CacheStats(
            hot_cache={"hits": 10},
            embedding_cache={"hits": 20},
            overall_hit_rate=0.75,
            last_cleanup=datetime.now(UTC),
            cleanup_count=5,
        )

        payload = stats.to_dict()

        assert payload["hot_cache"] == {"hits": 10}
        assert payload["overall_hit_rate"] == 0.75
        assert payload["cleanup_count"] == 5
        assert payload["last_cleanup"] is not None
|
||||
|
||||
class TestCacheManager:
    """Tests for CacheManager."""

    @pytest.fixture
    def manager(self) -> CacheManager:
        """Provide a fresh cache manager for each test."""
        return CacheManager()

    def test_is_enabled(self, manager: CacheManager) -> None:
        """Caching should be enabled by default (driven by settings)."""
        assert manager.is_enabled is True

    def test_has_hot_cache(self, manager: CacheManager) -> None:
        """The manager should expose a HotMemoryCache instance."""
        hot = manager.hot_cache
        assert hot is not None
        assert isinstance(hot, HotMemoryCache)

    def test_has_embedding_cache(self, manager: CacheManager) -> None:
        """The manager should expose an EmbeddingCache instance."""
        emb = manager.embedding_cache
        assert emb is not None
        assert isinstance(emb, EmbeddingCache)

    def test_cache_memory(self, manager: CacheManager) -> None:
        """A cached memory should round-trip through the hot cache."""
        mem_id = uuid4()
        payload = {"task": "test", "data": "value"}

        manager.cache_memory("episodic", mem_id, payload)

        assert manager.get_memory("episodic", mem_id) == payload

    def test_cache_memory_with_scope(self, manager: CacheManager) -> None:
        """Scoped memories should round-trip when the same scope is given."""
        mem_id = uuid4()
        payload = {"task": "test"}

        manager.cache_memory("semantic", mem_id, payload, scope="proj-123")

        assert manager.get_memory("semantic", mem_id, scope="proj-123") == payload

    async def test_cache_embedding(self, manager: CacheManager) -> None:
        """Cached embeddings should round-trip and yield a 32-char hash."""
        text = "test content"
        vector = [0.1, 0.2, 0.3]

        digest = await manager.cache_embedding(text, vector)

        assert await manager.get_embedding(text) == vector
        assert len(digest) == 32

    async def test_invalidate_memory(self, manager: CacheManager) -> None:
        """Invalidation should remove the memory from the hot cache."""
        mem_id = uuid4()
        manager.cache_memory("episodic", mem_id, {"data": "test"})

        removed = await manager.invalidate_memory("episodic", mem_id)

        assert removed >= 1
        assert manager.get_memory("episodic", mem_id) is None

    async def test_invalidate_by_type(self, manager: CacheManager) -> None:
        """Type-wide invalidation should clear every entry of that type."""
        manager.cache_memory("episodic", uuid4(), {"data": "1"})
        manager.cache_memory("episodic", uuid4(), {"data": "2"})
        manager.cache_memory("semantic", uuid4(), {"data": "3"})

        assert await manager.invalidate_by_type("episodic") >= 2

    async def test_invalidate_by_scope(self, manager: CacheManager) -> None:
        """Scope-wide invalidation should clear every entry in that scope."""
        manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1")
        manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1")
        manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2")

        assert await manager.invalidate_by_scope("proj-1") >= 2

    async def test_invalidate_embedding(self, manager: CacheManager) -> None:
        """Embedding invalidation should drop the cached vector."""
        text = "test content"
        await manager.cache_embedding(text, [0.1, 0.2])

        assert await manager.invalidate_embedding(text) is True
        assert await manager.get_embedding(text) is None

    async def test_clear_all(self, manager: CacheManager) -> None:
        """Clearing should empty every underlying cache."""
        manager.cache_memory("episodic", uuid4(), {"data": "test"})
        await manager.cache_embedding("content", [0.1])

        assert await manager.clear_all() >= 2

    async def test_cleanup_expired(self, manager: CacheManager) -> None:
        """Cleanup should run and update bookkeeping even with no expiries."""
        removed = await manager.cleanup_expired()

        # Nothing may have expired yet, so only a lower bound is checked.
        assert removed >= 0
        assert manager._cleanup_count == 1
        assert manager._last_cleanup is not None

    def test_get_stats(self, manager: CacheManager) -> None:
        """Stats should aggregate each cache plus an overall hit rate."""
        manager.cache_memory("episodic", uuid4(), {"data": "test"})

        snapshot = manager.get_stats().to_dict()

        assert "hot_cache" in snapshot
        assert "embedding_cache" in snapshot
        assert "overall_hit_rate" in snapshot

    def test_get_hot_memories(self, manager: CacheManager) -> None:
        """The hot-memory listing should honour the requested limit."""
        first = uuid4()
        second = uuid4()
        manager.cache_memory("episodic", first, {"data": "1"})
        manager.cache_memory("episodic", second, {"data": "2"})

        # Make the first entry clearly hotter than the second.
        for _ in range(5):
            manager.get_memory("episodic", first)

        assert len(manager.get_hot_memories(limit=2)) == 2

    def test_reset_stats(self, manager: CacheManager) -> None:
        """Resetting should zero out accumulated hit/miss counters."""
        manager.cache_memory("episodic", uuid4(), {"data": "test"})
        manager.get_memory("episodic", uuid4())  # guaranteed miss

        manager.reset_stats()

        assert manager.get_stats().hot_cache.get("hits", 0) == 0

    async def test_warmup(self, manager: CacheManager) -> None:
        """Warmup should load every provided memory."""
        seed = [
            ("episodic", uuid4(), {"data": "1"}),
            ("episodic", uuid4(), {"data": "2"}),
            ("semantic", uuid4(), {"data": "3"}),
        ]

        assert await manager.warmup(seed) == 3
|
||||
|
||||
class TestCacheManagerWithRetrieval:
    """Tests for CacheManager when a retrieval cache is attached."""

    @pytest.fixture
    def mock_retrieval_cache(self) -> MagicMock:
        """Build a stub retrieval cache with canned return values."""
        stub = MagicMock()
        stub.invalidate_by_memory = MagicMock(return_value=1)
        stub.clear = MagicMock(return_value=5)
        stub.get_stats = MagicMock(return_value={"entries": 10})
        return stub

    @pytest.fixture
    def manager_with_retrieval(
        self,
        mock_retrieval_cache: MagicMock,
    ) -> CacheManager:
        """Build a manager wired to the stub retrieval cache."""
        wired = CacheManager()
        wired.set_retrieval_cache(mock_retrieval_cache)
        return wired

    async def test_invalidate_clears_retrieval(
        self,
        manager_with_retrieval: CacheManager,
        mock_retrieval_cache: MagicMock,
    ) -> None:
        """Memory invalidation should cascade into the retrieval cache."""
        mem_id = uuid4()

        await manager_with_retrieval.invalidate_memory("episodic", mem_id)

        mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(mem_id)

    def test_stats_includes_retrieval(
        self,
        manager_with_retrieval: CacheManager,
    ) -> None:
        """Aggregated stats should carry a retrieval_cache section."""
        assert "retrieval_cache" in manager_with_retrieval.get_stats().to_dict()
|
||||
|
||||
class TestCacheManagerDisabled:
    """Tests for CacheManager behaviour when caching is turned off."""

    @pytest.fixture
    def disabled_manager(self) -> CacheManager:
        """Build a manager whose settings report caching as disabled."""
        with patch(
            "app.services.memory.cache.cache_manager.get_memory_settings"
        ) as mock_settings:
            fake = MagicMock()
            fake.cache_enabled = False
            fake.cache_max_items = 1000
            fake.cache_ttl_seconds = 300
            mock_settings.return_value = fake

            # Constructed inside the patch so the disabled settings apply.
            return CacheManager()

    def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None:
        """Lookups should always miss while caching is disabled."""
        disabled_manager.cache_memory("episodic", uuid4(), {"data": "test"})

        assert disabled_manager.get_memory("episodic", uuid4()) is None

    async def test_get_embedding_returns_none(
        self,
        disabled_manager: CacheManager,
    ) -> None:
        """Embedding lookups should miss while caching is disabled."""
        assert await disabled_manager.get_embedding("content") is None

    async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None:
        """Warmup should be a no-op while caching is disabled."""
        assert await disabled_manager.warmup([("episodic", uuid4(), {})]) == 0
|
||||
|
||||
class TestGetCacheManager:
    """Tests for the get_cache_manager factory."""

    def test_returns_singleton(self) -> None:
        """Repeated calls should hand back the same instance."""
        first = get_cache_manager()
        second = get_cache_manager()

        assert first is second

    def test_reset_creates_new(self) -> None:
        """An explicit reset should force a fresh instance."""
        before = get_cache_manager()
        reset_cache_manager()
        after = get_cache_manager()

        assert before is not after

    def test_reset_parameter(self) -> None:
        """Passing reset=True should force a fresh instance."""
        before = get_cache_manager()
        after = get_cache_manager(reset=True)

        assert before is not after
|
||||
|
||||
class TestResetCacheManager:
    """Tests for reset_cache_manager."""

    def test_resets_singleton(self) -> None:
        """After a reset, the factory must still produce a usable instance."""
        get_cache_manager()
        reset_cache_manager()

        assert get_cache_manager() is not None
||||
391
backend/tests/unit/services/memory/cache/test_embedding_cache.py
vendored
Normal file
391
backend/tests/unit/services/memory/cache/test_embedding_cache.py
vendored
Normal file
@@ -0,0 +1,391 @@
|
||||
# tests/unit/services/memory/cache/test_embedding_cache.py
|
||||
"""Tests for EmbeddingCache."""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.embedding_cache import (
|
||||
CachedEmbeddingGenerator,
|
||||
EmbeddingCache,
|
||||
EmbeddingCacheStats,
|
||||
EmbeddingEntry,
|
||||
create_embedding_cache,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
class TestEmbeddingEntry:
    """Tests for EmbeddingEntry."""

    def test_creates_entry(self) -> None:
        """Entries should keep their fields and default to a one-hour TTL."""
        from datetime import UTC, datetime

        entry = EmbeddingEntry(
            embedding=[0.1, 0.2, 0.3],
            content_hash="abc123",
            model="text-embedding-3-small",
            created_at=datetime.now(UTC),
        )

        assert entry.embedding == [0.1, 0.2, 0.3]
        assert entry.content_hash == "abc123"
        assert entry.ttl_seconds == 3600.0

    def test_is_expired(self) -> None:
        """An entry older than its TTL should report expired."""
        from datetime import UTC, datetime, timedelta

        stale = datetime.now(UTC) - timedelta(seconds=4000)
        entry = EmbeddingEntry(
            embedding=[0.1],
            content_hash="abc",
            model="default",
            created_at=stale,
            ttl_seconds=3600.0,
        )

        assert entry.is_expired() is True

    def test_not_expired(self) -> None:
        """A freshly created entry should not report expired."""
        from datetime import UTC, datetime

        entry = EmbeddingEntry(
            embedding=[0.1],
            content_hash="abc",
            model="default",
            created_at=datetime.now(UTC),
        )

        assert entry.is_expired() is False
|
||||
|
||||
class TestEmbeddingCacheStats:
    """Tests for EmbeddingCacheStats."""

    def test_hit_rate_calculation(self) -> None:
        """Hit rate should be hits / (hits + misses)."""
        assert EmbeddingCacheStats(hits=90, misses=10).hit_rate == 0.9

    def test_hit_rate_zero_requests(self) -> None:
        """Hit rate should be 0.0 when nothing was requested."""
        assert EmbeddingCacheStats().hit_rate == 0.0

    def test_to_dict(self) -> None:
        """to_dict should expose the raw counters."""
        payload = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000).to_dict()

        assert payload["hits"] == 10
        assert payload["bytes_saved"] == 1000
||||
|
||||
class TestEmbeddingCache:
    """Tests for EmbeddingCache."""

    @pytest.fixture
    def cache(self) -> EmbeddingCache:
        """Provide a small in-memory embedding cache."""
        return EmbeddingCache(max_size=100, default_ttl_seconds=300.0)

    async def test_put_and_get(self, cache: EmbeddingCache) -> None:
        """Stored embeddings should round-trip and yield a 32-char hash."""
        text = "Hello world"
        vector = [0.1, 0.2, 0.3, 0.4]

        digest = await cache.put(text, vector)

        assert await cache.get(text) == vector
        assert len(digest) == 32

    async def test_get_missing(self, cache: EmbeddingCache) -> None:
        """Unknown content should miss with None."""
        assert await cache.get("nonexistent content") is None

    async def test_get_by_hash(self, cache: EmbeddingCache) -> None:
        """Lookup by the returned content hash should hit."""
        vector = [0.1, 0.2]

        digest = await cache.put("Test content", vector)

        assert await cache.get_by_hash(digest) == vector

    async def test_model_separation(self, cache: EmbeddingCache) -> None:
        """The same content should cache independently per model."""
        text = "Same content"
        vec_a = [0.1, 0.2]
        vec_b = [0.3, 0.4]

        await cache.put(text, vec_a, model="model-a")
        await cache.put(text, vec_b, model="model-b")

        assert await cache.get(text, model="model-a") == vec_a
        assert await cache.get(text, model="model-b") == vec_b

    async def test_lru_eviction(self) -> None:
        """At capacity, the least recently used entry should be evicted."""
        small = EmbeddingCache(max_size=3)

        await small.put("content1", [0.1])
        await small.put("content2", [0.2])
        await small.put("content3", [0.3])

        # Touch the oldest entry so "content2" becomes the LRU victim.
        await small.get("content1")
        await small.put("content4", [0.4])

        assert await small.get("content1") is not None
        assert await small.get("content2") is None  # evicted
        assert await small.get("content3") is not None
        assert await small.get("content4") is not None

    async def test_ttl_expiration(self) -> None:
        """Entries should disappear once their TTL has elapsed."""
        short_lived = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)

        await short_lived.put("content", [0.1, 0.2])
        time.sleep(0.2)

        assert await short_lived.get("content") is None

    async def test_put_batch(self, cache: EmbeddingCache) -> None:
        """Batch puts should cache every item and return all hashes."""
        digests = await cache.put_batch(
            [
                ("content1", [0.1]),
                ("content2", [0.2]),
                ("content3", [0.3]),
            ]
        )

        assert len(digests) == 3
        assert await cache.get("content1") == [0.1]
        assert await cache.get("content2") == [0.2]

    async def test_invalidate(self, cache: EmbeddingCache) -> None:
        """Invalidation by content should remove the entry."""
        await cache.put("content", [0.1, 0.2])

        assert await cache.invalidate("content") is True
        assert await cache.get("content") is None

    async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None:
        """Invalidation by content hash should remove the entry."""
        digest = await cache.put("content", [0.1, 0.2])

        assert await cache.invalidate_by_hash(digest) is True
        assert await cache.get("content") is None

    async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None:
        """Model-wide invalidation should only touch that model's entries."""
        await cache.put("content1", [0.1], model="model-a")
        await cache.put("content2", [0.2], model="model-a")
        await cache.put("content3", [0.3], model="model-b")

        assert await cache.invalidate_by_model("model-a") == 2
        assert await cache.get("content1", model="model-a") is None
        assert await cache.get("content3", model="model-b") is not None

    async def test_clear(self, cache: EmbeddingCache) -> None:
        """Clearing should remove every entry and report the count."""
        await cache.put("content1", [0.1])
        await cache.put("content2", [0.2])

        assert await cache.clear() == 2
        assert cache.size == 0

    def test_cleanup_expired(self) -> None:
        """Cleanup should drop only the entries whose TTL has passed."""
        short_lived = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)

        # Seed synchronously; the second entry's long TTL keeps it alive.
        short_lived._put_memory("hash1", "default", [0.1])
        short_lived._put_memory("hash2", "default", [0.2], ttl_seconds=10)

        time.sleep(0.2)

        assert short_lived.cleanup_expired() == 1

    def test_get_stats(self, cache: EmbeddingCache) -> None:
        """Stats should reflect the current entry count."""
        cache._put_memory("hash1", "default", [0.1])

        assert cache.get_stats().current_size == 1

    def test_hash_content(self) -> None:
        """Identical content hashes identically; different content does not."""
        first = EmbeddingCache.hash_content("test content")
        second = EmbeddingCache.hash_content("test content")
        other = EmbeddingCache.hash_content("different content")

        assert first == second
        assert first != other
        assert len(first) == 32
|
||||
|
||||
class TestEmbeddingCacheWithRedis:
    """Tests for EmbeddingCache backed by Redis."""

    @pytest.fixture
    def mock_redis(self) -> MagicMock:
        """Build a stub Redis client with async no-op operations."""
        stub = MagicMock()
        stub.get = AsyncMock(return_value=None)
        stub.setex = AsyncMock()
        stub.delete = AsyncMock()
        stub.scan_iter = MagicMock(return_value=iter([]))
        return stub

    @pytest.fixture
    def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache:
        """Build a cache wired to the stub Redis client."""
        return EmbeddingCache(
            max_size=100,
            default_ttl_seconds=300.0,
            redis=mock_redis,
        )

    async def test_put_stores_in_redis(
        self,
        cache_with_redis: EmbeddingCache,
        mock_redis: MagicMock,
    ) -> None:
        """Puts should persist to Redis via setex."""
        await cache_with_redis.put("content", [0.1, 0.2])

        mock_redis.setex.assert_called_once()

    async def test_get_checks_redis_on_miss(
        self,
        cache_with_redis: EmbeddingCache,
        mock_redis: MagicMock,
    ) -> None:
        """A memory-cache miss should fall back to Redis."""
        import json

        mock_redis.get.return_value = json.dumps([0.1, 0.2])

        assert await cache_with_redis.get("content") == [0.1, 0.2]
        mock_redis.get.assert_called_once()
|
||||
|
||||
class TestCachedEmbeddingGenerator:
    """Tests for CachedEmbeddingGenerator."""

    @pytest.fixture
    def mock_generator(self) -> MagicMock:
        """Build a stub generator with fixed embedding outputs."""
        stub = MagicMock()
        stub.generate = AsyncMock(return_value=[0.1, 0.2, 0.3])
        stub.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]])
        return stub

    @pytest.fixture
    def cache(self) -> EmbeddingCache:
        """Provide an embedding cache for the wrapper."""
        return EmbeddingCache(max_size=100)

    @pytest.fixture
    def cached_gen(
        self,
        mock_generator: MagicMock,
        cache: EmbeddingCache,
    ) -> CachedEmbeddingGenerator:
        """Wrap the stub generator with transparent caching."""
        return CachedEmbeddingGenerator(mock_generator, cache)

    async def test_generate_caches_result(
        self,
        cached_gen: CachedEmbeddingGenerator,
        mock_generator: MagicMock,
    ) -> None:
        """A repeated generate call should be served from the cache."""
        first = await cached_gen.generate("test text")
        second = await cached_gen.generate("test text")

        assert first == [0.1, 0.2, 0.3]
        assert second == [0.1, 0.2, 0.3]
        # The underlying generator must have run exactly once.
        mock_generator.generate.assert_called_once()

    async def test_generate_batch_uses_cache(
        self,
        cached_gen: CachedEmbeddingGenerator,
        mock_generator: MagicMock,
        cache: EmbeddingCache,
    ) -> None:
        """Batch generation should mix cached hits with fresh results."""
        # Pre-cache the first text so only two texts need generating.
        await cache.put("text1", [0.5])
        mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]])

        results = await cached_gen.generate_batch(["text1", "text2", "text3"])

        assert len(results) == 3
        assert results[0] == [0.5]  # cache hit
        assert results[1] == [0.2]  # freshly generated
        assert results[2] == [0.3]  # freshly generated

    async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None:
        """Stats should count total calls and cache hits separately."""
        await cached_gen.generate("text1")
        await cached_gen.generate("text1")  # second call hits the cache

        snapshot = cached_gen.get_stats()

        assert snapshot["call_count"] == 2
        assert snapshot["cache_hit_count"] == 1
|
||||
|
||||
class TestCreateEmbeddingCache:
    """Tests for the create_embedding_cache factory."""

    def test_creates_cache(self) -> None:
        """Defaults should yield a 50k-entry cache."""
        assert create_embedding_cache().max_size == 50000

    def test_creates_cache_with_options(self) -> None:
        """Explicit options should override the defaults."""
        custom = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)

        assert custom.max_size == 1000
||||
355
backend/tests/unit/services/memory/cache/test_hot_cache.py
vendored
Normal file
355
backend/tests/unit/services/memory/cache/test_hot_cache.py
vendored
Normal file
@@ -0,0 +1,355 @@
|
||||
# tests/unit/services/memory/cache/test_hot_cache.py
|
||||
"""Tests for HotMemoryCache."""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.hot_cache import (
|
||||
CacheEntry,
|
||||
CacheKey,
|
||||
HotCacheStats,
|
||||
HotMemoryCache,
|
||||
create_hot_cache,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheKey:
    """Tests for CacheKey."""

    def test_creates_key(self) -> None:
        """A key needs only a type and an ID; scope defaults to None."""
        key = CacheKey(memory_type="episodic", memory_id="123")

        assert key.memory_type == "episodic"
        assert key.memory_id == "123"
        assert key.scope is None

    def test_creates_key_with_scope(self) -> None:
        """An explicit scope should be stored on the key."""
        scoped = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123")

        assert scoped.scope == "proj-123"

    def test_hash_and_equality(self) -> None:
        """Value-equal keys must compare and hash identically."""
        left = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
        right = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")

        assert left == right
        assert hash(left) == hash(right)

    def test_str_representation(self) -> None:
        """Scoped keys stringify as type:scope:id."""
        scoped = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")

        assert str(scoped) == "episodic:proj-1:123"

    def test_str_without_scope(self) -> None:
        """Unscoped keys stringify as type:id."""
        plain = CacheKey(memory_type="episodic", memory_id="123")

        assert str(plain) == "episodic:123"
|
||||
|
||||
class TestCacheEntry:
    """Tests for CacheEntry."""

    def test_creates_entry(self) -> None:
        """Should create an entry with defaults (access_count=1, 300s TTL).

        Fixed: the original reached `datetime` through four separate
        `pytest.importorskip("datetime")` calls. `importorskip` exists to
        skip tests when an *optional* dependency is absent; `datetime` is
        stdlib and can never be missing, so a plain import (matching every
        sibling test in this file) is both correct and far clearer.
        """
        from datetime import UTC, datetime

        now = datetime.now(UTC)
        entry = CacheEntry(
            value={"data": "test"},
            created_at=now,
            last_accessed_at=now,
        )

        assert entry.value == {"data": "test"}
        assert entry.access_count == 1
        assert entry.ttl_seconds == 300.0

    def test_is_expired(self) -> None:
        """Should detect entries older than their TTL."""
        from datetime import UTC, datetime, timedelta

        old_time = datetime.now(UTC) - timedelta(seconds=400)
        entry = CacheEntry(
            value="test",
            created_at=old_time,
            last_accessed_at=old_time,
            ttl_seconds=300.0,
        )

        assert entry.is_expired() is True

    def test_not_expired(self) -> None:
        """Should treat freshly created entries as live."""
        from datetime import UTC, datetime

        entry = CacheEntry(
            value="test",
            created_at=datetime.now(UTC),
            last_accessed_at=datetime.now(UTC),
            ttl_seconds=300.0,
        )

        assert entry.is_expired() is False

    def test_touch_updates_access(self) -> None:
        """Touch should bump the access count and refresh the access time."""
        from datetime import UTC, datetime, timedelta

        old_time = datetime.now(UTC) - timedelta(seconds=10)
        entry = CacheEntry(
            value="test",
            created_at=old_time,
            last_accessed_at=old_time,
            access_count=5,
        )

        entry.touch()

        assert entry.access_count == 6
        assert entry.last_accessed_at > old_time
|
||||
|
||||
class TestHotCacheStats:
    """Tests for HotCacheStats."""

    def test_hit_rate_calculation(self) -> None:
        """Hit rate should be hits / (hits + misses)."""
        assert HotCacheStats(hits=80, misses=20).hit_rate == 0.8

    def test_hit_rate_zero_requests(self) -> None:
        """Hit rate should be 0.0 when nothing was requested."""
        assert HotCacheStats().hit_rate == 0.0

    def test_to_dict(self) -> None:
        """to_dict should expose counters plus a derived hit_rate."""
        payload = HotCacheStats(hits=10, misses=5, evictions=2).to_dict()

        assert payload["hits"] == 10
        assert payload["misses"] == 5
        assert payload["evictions"] == 2
        assert "hit_rate" in payload
|
||||
|
||||
class TestHotMemoryCache:
    """Behavioural tests for HotMemoryCache (LRU + TTL in-memory cache)."""

    @pytest.fixture
    def cache(self) -> HotMemoryCache[dict]:
        """Fresh cache per test: 100 entries max, 5-minute default TTL."""
        return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0)

    def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None:
        """A value stored under a CacheKey is retrievable with that key."""
        entry_key = CacheKey(memory_type="episodic", memory_id="123")
        payload = {"data": "test"}

        cache.put(entry_key, payload)

        assert cache.get(entry_key) == payload

    def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None:
        """Looking up an absent key yields None rather than raising."""
        absent = CacheKey(memory_type="episodic", memory_id="nonexistent")

        assert cache.get(absent) is None

    def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None:
        """put_by_id/get_by_id round-trip a value keyed by type + UUID."""
        mem_id = uuid4()
        payload = {"data": "test"}

        cache.put_by_id("episodic", mem_id, payload)

        assert cache.get_by_id("episodic", mem_id) == payload

    def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None:
        """Scoped entries round-trip when the same scope is supplied on read."""
        mem_id = uuid4()
        payload = {"data": "test"}

        cache.put_by_id("semantic", mem_id, payload, scope="proj-123")

        assert cache.get_by_id("semantic", mem_id, scope="proj-123") == payload

    def test_lru_eviction(self) -> None:
        """At capacity, the least-recently-used entry is evicted first."""
        tiny = HotMemoryCache[str](max_size=3)

        # Fill to capacity in insertion order 1, 2, 3.
        for mem_id, label in (("1", "first"), ("2", "second"), ("3", "third")):
            tiny.put_by_id("test", mem_id, label)

        # Touch "1" so "2" becomes the least recently used entry.
        tiny.get_by_id("test", "1")

        # Inserting a fourth entry should push out "2".
        tiny.put_by_id("test", "4", "fourth")

        assert tiny.get_by_id("test", "1") is not None  # recently touched, survives
        assert tiny.get_by_id("test", "2") is None  # LRU victim
        assert tiny.get_by_id("test", "3") is not None
        assert tiny.get_by_id("test", "4") is not None

    def test_ttl_expiration(self) -> None:
        """An entry becomes unreadable once its TTL has elapsed."""
        short_lived = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)

        short_lived.put_by_id("test", "1", "value")

        # Sleep past the 0.1s TTL so the entry is guaranteed stale.
        time.sleep(0.2)

        assert short_lived.get_by_id("test", "1") is None

    def test_invalidate(self, cache: HotMemoryCache[dict]) -> None:
        """invalidate() removes exactly the addressed entry and reports True."""
        entry_key = CacheKey(memory_type="episodic", memory_id="123")
        cache.put(entry_key, {"data": "test"})

        assert cache.invalidate(entry_key) is True
        assert cache.get(entry_key) is None

    def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None:
        """invalidate_by_id() removes the entry addressed by type + UUID."""
        mem_id = uuid4()
        cache.put_by_id("episodic", mem_id, {"data": "test"})

        assert cache.invalidate_by_id("episodic", mem_id) is True
        assert cache.get_by_id("episodic", mem_id) is None

    def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None:
        """Type-level invalidation clears one memory type, leaving others."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("episodic", "2", {"data": "2"})
        cache.put_by_id("semantic", "3", {"data": "3"})

        removed = cache.invalidate_by_type("episodic")

        assert removed == 2
        assert cache.get_by_id("episodic", "1") is None
        assert cache.get_by_id("episodic", "2") is None
        # The semantic entry is untouched.
        assert cache.get_by_id("semantic", "3") is not None

    def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None:
        """Scope-level invalidation clears one scope across memory types."""
        cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1")
        cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1")
        cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2")

        removed = cache.invalidate_by_scope("proj-1")

        assert removed == 2
        # Entries in other scopes survive.
        assert cache.get_by_id("episodic", "3", scope="proj-2") is not None

    def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None:
        """Glob-style pattern invalidation removes all matching keys."""
        cache.put_by_id("episodic", "123", {"data": "1"})
        cache.put_by_id("episodic", "124", {"data": "2"})
        cache.put_by_id("semantic", "125", {"data": "3"})

        assert cache.invalidate_pattern("episodic:*") == 2

    def test_clear(self, cache: HotMemoryCache[dict]) -> None:
        """clear() empties the cache and reports how many entries it dropped."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("semantic", "2", {"data": "2"})

        assert cache.clear() == 2
        assert cache.size == 0

    def test_cleanup_expired(self) -> None:
        """cleanup_expired() removes only entries whose TTL has lapsed."""
        short_lived = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)

        short_lived.put_by_id("test", "1", "value1")  # default 0.1s TTL
        short_lived.put_by_id("test", "2", "value2", ttl_seconds=10)  # long-lived

        time.sleep(0.2)

        # Only the default-TTL entry has expired by now.
        assert short_lived.cleanup_expired() == 1
        assert short_lived.size == 1

    def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None:
        """get_hot_memories() returns entries sorted by descending access count."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("episodic", "2", {"data": "2"})

        # Drive up the access count of the first entry.
        for _ in range(5):
            cache.get_by_id("episodic", "1")

        ranked = cache.get_hot_memories(limit=2)

        assert len(ranked) == 2
        # (entry, access_count) pairs must be in non-increasing count order.
        assert ranked[0][1] >= ranked[1][1]

    def test_get_stats(self, cache: HotMemoryCache[dict]) -> None:
        """get_stats() reflects hits, misses, and current occupancy."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.get_by_id("episodic", "1")  # counted as a hit
        cache.get_by_id("episodic", "2")  # counted as a miss

        snapshot = cache.get_stats()

        assert snapshot.hits == 1
        assert snapshot.misses == 1
        assert snapshot.current_size == 1

    def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None:
        """reset_stats() zeroes the hit/miss counters."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.get_by_id("episodic", "1")

        cache.reset_stats()
        snapshot = cache.get_stats()

        assert snapshot.hits == 0
        assert snapshot.misses == 0
||||
class TestCreateHotCache:
    """Tests for the create_hot_cache() factory function."""

    def test_creates_cache(self) -> None:
        """With no arguments, the factory uses the documented default size."""
        assert create_hot_cache().max_size == 10000

    def test_creates_cache_with_options(self) -> None:
        """Keyword overrides are honoured by the factory."""
        configured = create_hot_cache(max_size=500, default_ttl_seconds=60.0)

        assert configured.max_size == 500
||||
Reference in New Issue
Block a user