Add comprehensive indexing and retrieval system for memory search:

- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
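For orientation, here is a minimal sketch of how the components named above are wired together, based solely on the calls the tests below exercise (index an Episode, then run a hybrid query through the engine). The constructor arguments and result fields are assumptions inferred from that test usage, not from separate documentation.

from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import RetrievalEngine, RetrievalQuery
from app.services.memory.types import Episode


async def demo_retrieval(episode: Episode) -> None:
    """Hypothetical usage sketch mirroring the tests below."""
    # Index the episode so it becomes visible to all index types.
    indexer = MemoryIndexer()
    await indexer.index(episode)

    # Hybrid retrieval: vector similarity combined with temporal/entity/outcome signals.
    engine = RetrievalEngine(indexer=indexer, enable_cache=True)
    query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0], limit=5)
    result = await engine.retrieve(query)

    for item in result.items:
        # Each item carries an overall score plus a per-signal breakdown.
        print(item.memory_id, item.relevance_score, item.score_breakdown)

The convenience methods exercised further down (retrieve_similar, retrieve_recent, retrieve_by_entity, retrieve_successful) appear to be thin wrappers over the same retrieve() path.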
# tests/unit/services/memory/indexing/test_retrieval.py

"""Unit tests for memory retrieval."""

from datetime import UTC, datetime, timedelta
from uuid import uuid4

import pytest

from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import (
    RelevanceScorer,
    RetrievalCache,
    RetrievalEngine,
    RetrievalQuery,
    ScoredResult,
    get_retrieval_engine,
)
from app.services.memory.types import Episode, MemoryType, Outcome


def _utcnow() -> datetime:
    """Get current UTC time."""
    return datetime.now(UTC)


def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
    task_type: str = "test_task",
) -> Episode:
    """Create a test episode."""
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type=task_type,
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=occurred_at or _utcnow(),
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


class TestRetrievalQuery:
    """Tests for RetrievalQuery."""

    def test_default_values(self) -> None:
        """Test default query values."""
        query = RetrievalQuery()

        assert query.query_text is None
        assert query.limit == 10
        assert query.min_relevance == 0.0
        assert query.use_vector is True
        assert query.use_temporal is True

    def test_cache_key_generation(self) -> None:
        """Test cache key generation."""
        query1 = RetrievalQuery(query_text="test", limit=10)
        query2 = RetrievalQuery(query_text="test", limit=10)
        query3 = RetrievalQuery(query_text="different", limit=10)

        # Same queries should have same key
        assert query1.to_cache_key() == query2.to_cache_key()
        # Different queries should have different keys
        assert query1.to_cache_key() != query3.to_cache_key()


class TestScoredResult:
    """Tests for ScoredResult."""

    def test_creation(self) -> None:
        """Test creating a scored result."""
        result = ScoredResult(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            relevance_score=0.85,
            score_breakdown={"vector": 0.9, "recency": 0.8},
        )

        assert result.relevance_score == 0.85
        assert result.score_breakdown["vector"] == 0.9


class TestRelevanceScorer:
    """Tests for RelevanceScorer."""

    @pytest.fixture
    def scorer(self) -> RelevanceScorer:
        """Create a relevance scorer."""
        return RelevanceScorer()

    def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
        """Test scoring with vector similarity."""
        result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            vector_similarity=0.9,
        )

        assert result.relevance_score > 0
        assert result.score_breakdown["vector"] == 0.9

    def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
        """Test scoring with recency."""
        recent_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow(),
        )

        old_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow() - timedelta(days=7),
        )

        # Recent should have higher recency score
        assert (
            recent_result.score_breakdown["recency"]
            > old_result.score_breakdown["recency"]
        )

    def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
        """Test scoring with outcome preference."""
        success_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.SUCCESS,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        failure_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.FAILURE,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        assert success_result.score_breakdown["outcome"] == 1.0
        assert failure_result.score_breakdown["outcome"] == 0.0

    def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
        """Test scoring with entity matches."""
        full_match = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=3,
            entity_total=3,
        )

        partial_match = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=1,
            entity_total=3,
        )

        assert (
            full_match.score_breakdown["entity"]
            > partial_match.score_breakdown["entity"]
        )


class TestRetrievalCache:
    """Tests for RetrievalCache."""

    @pytest.fixture
    def cache(self) -> RetrievalCache:
        """Create a retrieval cache."""
        return RetrievalCache(max_entries=10, default_ttl_seconds=60)

    def test_put_and_get(self, cache: RetrievalCache) -> None:
        """Test putting and getting from cache."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("test_key", results)
        cached = cache.get("test_key")

        assert cached is not None
        assert len(cached) == 1

    def test_get_nonexistent(self, cache: RetrievalCache) -> None:
        """Test getting nonexistent entry."""
        result = cache.get("nonexistent")
        assert result is None

    def test_lru_eviction(self) -> None:
        """Test LRU eviction when at capacity."""
        cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)

        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)
        cache.put("key3", results)  # Should evict key1

        assert cache.get("key1") is None
        assert cache.get("key2") is not None
        assert cache.get("key3") is not None

    def test_invalidate(self, cache: RetrievalCache) -> None:
        """Test invalidating a cache entry."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("test_key", results)
        removed = cache.invalidate("test_key")

        assert removed is True
        assert cache.get("test_key") is None

    def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
        """Test invalidating by memory ID."""
        memory_id = uuid4()
        results = [
            ScoredResult(
                memory_id=memory_id,
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)

        count = cache.invalidate_by_memory(memory_id)

        assert count == 2
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_clear(self, cache: RetrievalCache) -> None:
        """Test clearing the cache."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)

        count = cache.clear()

        assert count == 2
        assert cache.get("key1") is None

    def test_get_stats(self, cache: RetrievalCache) -> None:
        """Test getting cache statistics."""
        stats = cache.get_stats()

        assert "total_entries" in stats
        assert "max_entries" in stats
        assert stats["max_entries"] == 10


class TestRetrievalEngine:
    """Tests for RetrievalEngine."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Create a memory indexer."""
        return MemoryIndexer()

    @pytest.fixture
    def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
        """Create a retrieval engine."""
        return RetrievalEngine(indexer=indexer, enable_cache=True)

    @pytest.mark.asyncio
    async def test_retrieve_by_vector(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval by vector similarity."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
        e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await indexer.index(e1)
        await indexer.index(e2)
        await indexer.index(e3)

        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=2,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )

        result = await engine.retrieve(query)

        assert len(result.items) > 0
        assert result.retrieval_type == "hybrid"

    @pytest.mark.asyncio
    async def test_retrieve_recent(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval of recent items."""
        now = _utcnow()
        old = make_episode(occurred_at=now - timedelta(hours=2))
        recent = make_episode(occurred_at=now - timedelta(minutes=30))

        await indexer.index(old)
        await indexer.index(recent)

        result = await engine.retrieve_recent(hours=1)

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_by_entity(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval by entity."""
        e1 = make_episode(task_type="deploy")
        e2 = make_episode(task_type="test")

        await indexer.index(e1)
        await indexer.index(e2)

        result = await engine.retrieve_by_entity("task_type", "deploy")

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_successful(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval of successful items."""
        success = make_episode(outcome=Outcome.SUCCESS)
        failure = make_episode(outcome=Outcome.FAILURE)

        await indexer.index(success)
        await indexer.index(failure)

        result = await engine.retrieve_successful()

        assert len(result.items) == 1
        # Check outcome index was used
        assert result.items[0].memory_id == success.id

    @pytest.mark.asyncio
    async def test_retrieve_with_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test that retrieval uses cache."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=10,
        )

        # First retrieval
        result1 = await engine.retrieve(query)
        assert result1.metadata.get("cache_hit") is False

        # Second retrieval should be cached
        result2 = await engine.retrieve(query)
        assert result2.metadata.get("cache_hit") is True

    @pytest.mark.asyncio
    async def test_invalidate_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test cache invalidation."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
        await engine.retrieve(query)

        count = engine.invalidate_cache()

        assert count > 0

    @pytest.mark.asyncio
    async def test_retrieve_similar(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieve_similar convenience method."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await indexer.index(e1)
        await indexer.index(e2)

        result = await engine.retrieve_similar(
            embedding=[1.0, 0.0, 0.0, 0.0],
            limit=1,
        )

        assert len(result.items) == 1

    def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
        """Test getting cache statistics."""
        stats = engine.get_cache_stats()

        assert "total_entries" in stats


class TestGetRetrievalEngine:
    """Tests for singleton getter."""

    def test_returns_instance(self) -> None:
        """Test that getter returns instance."""
        engine = get_retrieval_engine()
        assert engine is not None
        assert isinstance(engine, RetrievalEngine)

    def test_returns_same_instance(self) -> None:
        """Test that getter returns same instance."""
        engine1 = get_retrieval_engine()
        engine2 = get_retrieval_engine()
        assert engine1 is engine2