Files
syndarix/backend/tests/unit/services/memory/indexing/test_retrieval.py
Felipe Cardoso 999b7ac03f feat(memory): implement memory indexing and retrieval engine (#94)
Add comprehensive indexing and retrieval system for memory search:
- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:50:13 +01:00

451 lines
13 KiB
Python

# tests/unit/services/memory/indexing/test_retrieval.py
"""Unit tests for memory retrieval."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import (
RelevanceScorer,
RetrievalCache,
RetrievalEngine,
RetrievalQuery,
ScoredResult,
get_retrieval_engine,
)
from app.services.memory.types import Episode, MemoryType, Outcome
def _utcnow() -> datetime:
"""Get current UTC time."""
return datetime.now(UTC)
def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
    task_type: str = "test_task",
) -> Episode:
    """Build an ``Episode`` filled with placeholder values for tests.

    Args:
        embedding: Vector stored on the episode; ``None`` leaves it unset.
        outcome: Outcome recorded on the episode.
        occurred_at: When the episode happened; defaults to the current
            UTC time when omitted.
        task_type: Value stored in the episode's ``task_type`` field.

    Returns:
        A new ``Episode`` with freshly generated UUIDs for its id, project,
        agent instance and agent type.
    """
    # Read the clock once so the default occurred_at, created_at and
    # updated_at are exactly equal instead of drifting by microseconds.
    now = _utcnow()
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type=task_type,
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=now if occurred_at is None else occurred_at,
        created_at=now,
        updated_at=now,
    )
class TestRetrievalQuery:
    """Tests for RetrievalQuery defaults and cache-key generation."""

    def test_default_values(self) -> None:
        """A query built with no arguments exposes the expected defaults."""
        query = RetrievalQuery()

        assert query.query_text is None
        assert query.limit == 10
        assert query.min_relevance == 0.0
        assert query.use_vector is True
        assert query.use_temporal is True

    def test_cache_key_generation(self) -> None:
        """Equal queries share a cache key; differing queries do not."""
        first = RetrievalQuery(query_text="test", limit=10)
        duplicate = RetrievalQuery(query_text="test", limit=10)
        other = RetrievalQuery(query_text="different", limit=10)

        # Same parameters must map to the same key.
        assert first.to_cache_key() == duplicate.to_cache_key()
        # A different query text must produce a different key.
        assert first.to_cache_key() != other.to_cache_key()
class TestScoredResult:
    """Tests for ScoredResult."""

    def test_creation(self) -> None:
        """Constructor arguments are readable back from the instance."""
        result = ScoredResult(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            relevance_score=0.85,
            score_breakdown={"vector": 0.9, "recency": 0.8},
        )

        assert result.relevance_score == 0.85
        assert result.score_breakdown["vector"] == 0.9
class TestRelevanceScorer:
    """Tests for RelevanceScorer, one scoring signal per test."""

    @pytest.fixture
    def scorer(self) -> RelevanceScorer:
        """Provide a RelevanceScorer built with its default arguments."""
        return RelevanceScorer()

    def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
        """Vector similarity is reported in the breakdown and lifts the score."""
        result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            vector_similarity=0.9,
        )

        assert result.relevance_score > 0
        assert result.score_breakdown["vector"] == 0.9

    def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
        """A current timestamp earns a higher recency score than a week-old one."""
        recent_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow(),
        )
        old_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow() - timedelta(days=7),
        )

        # Recent should have the higher recency component.
        assert (
            recent_result.score_breakdown["recency"]
            > old_result.score_breakdown["recency"]
        )

    def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
        """A preferred outcome scores 1.0; a non-preferred outcome scores 0.0."""
        success_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.SUCCESS,
            preferred_outcomes=[Outcome.SUCCESS],
        )
        failure_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.FAILURE,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        assert success_result.score_breakdown["outcome"] == 1.0
        assert failure_result.score_breakdown["outcome"] == 0.0

    def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
        """Matching 3 of 3 entities scores above matching 1 of 3."""
        full_match = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=3,
            entity_total=3,
        )
        partial_match = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=1,
            entity_total=3,
        )

        assert (
            full_match.score_breakdown["entity"]
            > partial_match.score_breakdown["entity"]
        )
class TestRetrievalCache:
    """Tests for RetrievalCache storage, eviction, invalidation and stats."""

    @pytest.fixture
    def cache(self) -> RetrievalCache:
        """Provide a cache holding up to 10 entries with a 60 second TTL."""
        return RetrievalCache(max_entries=10, default_ttl_seconds=60)

    @pytest.fixture
    def results(self) -> list[ScoredResult]:
        """Provide a one-element result list with a random memory id."""
        return [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

    def test_put_and_get(
        self, cache: RetrievalCache, results: list[ScoredResult]
    ) -> None:
        """A stored entry can be read back under the same key."""
        cache.put("test_key", results)

        cached = cache.get("test_key")

        assert cached is not None
        assert len(cached) == 1

    def test_get_nonexistent(self, cache: RetrievalCache) -> None:
        """Looking up a key that was never stored yields None."""
        result = cache.get("nonexistent")

        assert result is None

    def test_lru_eviction(self, results: list[ScoredResult]) -> None:
        """Inserting beyond capacity evicts the oldest entry."""
        # Uses its own two-slot cache so the third put forces an eviction.
        cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)

        cache.put("key1", results)
        cache.put("key2", results)
        cache.put("key3", results)  # Should evict key1

        assert cache.get("key1") is None
        assert cache.get("key2") is not None
        assert cache.get("key3") is not None

    def test_invalidate(
        self, cache: RetrievalCache, results: list[ScoredResult]
    ) -> None:
        """Invalidating a stored key reports True and removes the entry."""
        cache.put("test_key", results)

        removed = cache.invalidate("test_key")

        assert removed is True
        assert cache.get("test_key") is None

    def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
        """Every entry containing the given memory id is dropped."""
        memory_id = uuid4()
        # Built inline because the test needs to know the memory id it stored.
        results = [
            ScoredResult(
                memory_id=memory_id,
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]
        cache.put("key1", results)
        cache.put("key2", results)

        count = cache.invalidate_by_memory(memory_id)

        assert count == 2
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_clear(
        self, cache: RetrievalCache, results: list[ScoredResult]
    ) -> None:
        """Clearing reports how many entries were removed and empties the cache."""
        cache.put("key1", results)
        cache.put("key2", results)

        count = cache.clear()

        assert count == 2
        # Check both keys, not just the first, so a partial clear is caught.
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_get_stats(self, cache: RetrievalCache) -> None:
        """Stats expose the entry count and the configured capacity."""
        stats = cache.get_stats()

        assert "total_entries" in stats
        assert "max_entries" in stats
        assert stats["max_entries"] == 10
class TestRetrievalEngine:
    """Tests for RetrievalEngine backed by a fresh MemoryIndexer per test."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Provide an empty MemoryIndexer."""
        return MemoryIndexer()

    @pytest.fixture
    def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
        """Provide a cache-enabled engine sharing the test's indexer."""
        return RetrievalEngine(indexer=indexer, enable_cache=True)

    @pytest.mark.asyncio
    async def test_retrieve_by_vector(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """A vector-only query returns results tagged as hybrid retrieval."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
        e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
        await indexer.index(e1)
        await indexer.index(e2)
        await indexer.index(e3)

        # Switch off every other signal so only the vector index contributes.
        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=2,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        result = await engine.retrieve(query)

        assert len(result.items) > 0
        assert result.retrieval_type == "hybrid"

    @pytest.mark.asyncio
    async def test_retrieve_recent(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Only episodes inside the requested time window are returned."""
        now = _utcnow()
        old = make_episode(occurred_at=now - timedelta(hours=2))
        recent = make_episode(occurred_at=now - timedelta(minutes=30))
        await indexer.index(old)
        await indexer.index(recent)

        result = await engine.retrieve_recent(hours=1)

        # Of the two episodes, only the 30-minute-old one is within the hour.
        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_by_entity(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Entity lookup returns only episodes with the matching value."""
        e1 = make_episode(task_type="deploy")
        e2 = make_episode(task_type="test")
        await indexer.index(e1)
        await indexer.index(e2)

        result = await engine.retrieve_by_entity("task_type", "deploy")

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_successful(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Only the successful episode is returned by retrieve_successful."""
        success = make_episode(outcome=Outcome.SUCCESS)
        failure = make_episode(outcome=Outcome.FAILURE)
        await indexer.index(success)
        await indexer.index(failure)

        result = await engine.retrieve_successful()

        assert len(result.items) == 1
        # The single hit must be the successful episode, not the failed one.
        assert result.items[0].memory_id == success.id

    @pytest.mark.asyncio
    async def test_retrieve_with_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Repeating an identical query is served from the cache."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)
        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=10,
        )

        # First retrieval is reported as a cache miss.
        result1 = await engine.retrieve(query)
        assert result1.metadata.get("cache_hit") is False

        # Second retrieval with the same query is reported as a cache hit.
        result2 = await engine.retrieve(query)
        assert result2.metadata.get("cache_hit") is True

    @pytest.mark.asyncio
    async def test_invalidate_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Invalidating after a retrieval reports at least one dropped entry."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)
        query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
        # Run one query so the cache has something to invalidate.
        await engine.retrieve(query)

        count = engine.invalidate_cache()

        assert count > 0

    @pytest.mark.asyncio
    async def test_retrieve_similar(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """retrieve_similar with limit=1 yields exactly one item."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
        await indexer.index(e1)
        await indexer.index(e2)

        result = await engine.retrieve_similar(
            embedding=[1.0, 0.0, 0.0, 0.0],
            limit=1,
        )

        assert len(result.items) == 1

    def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
        """Cache statistics include the total entry count."""
        stats = engine.get_cache_stats()

        assert "total_entries" in stats
class TestGetRetrievalEngine:
    """Tests for the get_retrieval_engine singleton getter."""

    def test_returns_instance(self) -> None:
        """The getter returns a RetrievalEngine instance."""
        engine = get_retrieval_engine()

        assert engine is not None
        assert isinstance(engine, RetrievalEngine)

    def test_returns_same_instance(self) -> None:
        """Two consecutive calls return the very same object."""
        engine1 = get_retrieval_engine()
        engine2 = get_retrieval_engine()

        assert engine1 is engine2