feat(memory): implement memory indexing and retrieval engine (#94)

Add comprehensive indexing and retrieval system for memory search:
- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 02:50:13 +01:00
parent 48ecb40f18
commit 999b7ac03f
6 changed files with 2602 additions and 3 deletions

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/indexing/__init__.py
"""Unit tests for memory indexing."""

View File

@@ -0,0 +1,497 @@
# tests/unit/services/memory/indexing/test_index.py
"""Unit tests for memory indexing."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.services.memory.indexing.index import (
EntityIndex,
MemoryIndexer,
OutcomeIndex,
TemporalIndex,
VectorIndex,
get_memory_indexer,
)
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
def _utcnow() -> datetime:
"""Get current UTC time."""
return datetime.now(UTC)
def make_episode(
embedding: list[float] | None = None,
outcome: Outcome = Outcome.SUCCESS,
occurred_at: datetime | None = None,
) -> Episode:
"""Create a test episode."""
return Episode(
id=uuid4(),
project_id=uuid4(),
agent_instance_id=uuid4(),
agent_type_id=uuid4(),
session_id="test-session",
task_type="test_task",
task_description="Test task description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=outcome,
outcome_details="Test outcome",
duration_seconds=10.0,
tokens_used=100,
lessons_learned=["lesson1"],
importance_score=0.8,
embedding=embedding,
occurred_at=occurred_at or _utcnow(),
created_at=_utcnow(),
updated_at=_utcnow(),
)
def make_fact(
embedding: list[float] | None = None,
subject: str = "test_subject",
predicate: str = "has_property",
obj: str = "test_value",
) -> Fact:
"""Create a test fact."""
return Fact(
id=uuid4(),
project_id=uuid4(),
subject=subject,
predicate=predicate,
object=obj,
confidence=0.9,
source_episode_ids=[uuid4()],
first_learned=_utcnow(),
last_reinforced=_utcnow(),
reinforcement_count=1,
embedding=embedding,
created_at=_utcnow(),
updated_at=_utcnow(),
)
def make_procedure(
embedding: list[float] | None = None,
success_count: int = 8,
failure_count: int = 2,
) -> Procedure:
"""Create a test procedure."""
return Procedure(
id=uuid4(),
project_id=uuid4(),
agent_type_id=uuid4(),
name="test_procedure",
trigger_pattern="test.*",
steps=[{"step": 1, "action": "test"}],
success_count=success_count,
failure_count=failure_count,
last_used=_utcnow(),
embedding=embedding,
created_at=_utcnow(),
updated_at=_utcnow(),
)
class TestVectorIndex:
"""Tests for VectorIndex."""
@pytest.fixture
def index(self) -> VectorIndex[Episode]:
"""Create a vector index."""
return VectorIndex[Episode](dimension=4)
@pytest.mark.asyncio
async def test_add_item(self, index: VectorIndex[Episode]) -> None:
"""Test adding an item to the index."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
entry = await index.add(episode)
assert entry.memory_id == episode.id
assert entry.memory_type == MemoryType.EPISODIC
assert entry.dimension == 4
assert await index.count() == 1
@pytest.mark.asyncio
async def test_remove_item(self, index: VectorIndex[Episode]) -> None:
"""Test removing an item from the index."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await index.add(episode)
result = await index.remove(episode.id)
assert result is True
assert await index.count() == 0
@pytest.mark.asyncio
async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None:
"""Test removing a nonexistent item."""
result = await index.remove(uuid4())
assert result is False
@pytest.mark.asyncio
async def test_search_similar(self, index: VectorIndex[Episode]) -> None:
"""Test searching for similar items."""
# Add items with different embeddings
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
await index.add(e1)
await index.add(e2)
await index.add(e3)
# Search for similar to first
results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2)
assert len(results) == 2
# First result should be most similar
assert results[0].memory_id == e1.id
@pytest.mark.asyncio
async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None:
"""Test minimum similarity threshold."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0]) # Orthogonal
await index.add(e1)
await index.add(e2)
# Search with high threshold
results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9)
assert len(results) == 1
assert results[0].memory_id == e1.id
@pytest.mark.asyncio
async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None:
"""Test search with empty query."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await index.add(e1)
results = await index.search([], limit=10)
assert len(results) == 0
@pytest.mark.asyncio
async def test_clear(self, index: VectorIndex[Episode]) -> None:
"""Test clearing the index."""
await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
count = await index.clear()
assert count == 2
assert await index.count() == 0
class TestTemporalIndex:
"""Tests for TemporalIndex."""
@pytest.fixture
def index(self) -> TemporalIndex[Episode]:
"""Create a temporal index."""
return TemporalIndex[Episode]()
@pytest.mark.asyncio
async def test_add_item(self, index: TemporalIndex[Episode]) -> None:
"""Test adding an item."""
episode = make_episode()
entry = await index.add(episode)
assert entry.memory_id == episode.id
assert await index.count() == 1
@pytest.mark.asyncio
async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None:
"""Test searching by time range."""
now = _utcnow()
old = make_episode(occurred_at=now - timedelta(hours=2))
recent = make_episode(occurred_at=now - timedelta(hours=1))
newest = make_episode(occurred_at=now)
await index.add(old)
await index.add(recent)
await index.add(newest)
# Search last hour
results = await index.search(
query=None,
start_time=now - timedelta(hours=1, minutes=30),
end_time=now,
)
assert len(results) == 2
@pytest.mark.asyncio
async def test_search_recent(self, index: TemporalIndex[Episode]) -> None:
"""Test searching for recent items."""
now = _utcnow()
old = make_episode(occurred_at=now - timedelta(hours=2))
recent = make_episode(occurred_at=now - timedelta(minutes=30))
await index.add(old)
await index.add(recent)
# Search last hour (3600 seconds)
results = await index.search(query=None, recent_seconds=3600)
assert len(results) == 1
assert results[0].memory_id == recent.id
@pytest.mark.asyncio
async def test_search_order(self, index: TemporalIndex[Episode]) -> None:
"""Test result ordering."""
now = _utcnow()
e1 = make_episode(occurred_at=now - timedelta(hours=2))
e2 = make_episode(occurred_at=now - timedelta(hours=1))
e3 = make_episode(occurred_at=now)
await index.add(e1)
await index.add(e2)
await index.add(e3)
# Descending order (newest first)
results_desc = await index.search(query=None, order="desc", limit=10)
assert results_desc[0].memory_id == e3.id
# Ascending order (oldest first)
results_asc = await index.search(query=None, order="asc", limit=10)
assert results_asc[0].memory_id == e1.id
class TestEntityIndex:
"""Tests for EntityIndex."""
@pytest.fixture
def index(self) -> EntityIndex[Fact]:
"""Create an entity index."""
return EntityIndex[Fact]()
@pytest.mark.asyncio
async def test_add_item(self, index: EntityIndex[Fact]) -> None:
"""Test adding an item."""
fact = make_fact(subject="user", obj="admin")
entry = await index.add(fact)
assert entry.memory_id == fact.id
assert await index.count() == 1
@pytest.mark.asyncio
async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None:
"""Test searching by entity."""
f1 = make_fact(subject="user", obj="admin")
f2 = make_fact(subject="system", obj="config")
await index.add(f1)
await index.add(f2)
results = await index.search(
query=None,
entity_type="subject",
entity_value="user",
)
assert len(results) == 1
assert results[0].memory_id == f1.id
@pytest.mark.asyncio
async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None:
"""Test searching with multiple entities."""
f1 = make_fact(subject="user", obj="admin")
f2 = make_fact(subject="user", obj="guest")
await index.add(f1)
await index.add(f2)
# Search for facts about "user" subject
results = await index.search(
query=None,
entities=[("subject", "user")],
)
assert len(results) == 2
@pytest.mark.asyncio
async def test_search_match_all(self, index: EntityIndex[Fact]) -> None:
"""Test matching all entities."""
f1 = make_fact(subject="user", obj="admin")
f2 = make_fact(subject="user", obj="guest")
await index.add(f1)
await index.add(f2)
# Search for user+admin (match all)
results = await index.search(
query=None,
entities=[("subject", "user"), ("object", "admin")],
match_all=True,
)
assert len(results) == 1
assert results[0].memory_id == f1.id
@pytest.mark.asyncio
async def test_get_entities(self, index: EntityIndex[Fact]) -> None:
"""Test getting entities for a memory."""
fact = make_fact(subject="user", obj="admin")
await index.add(fact)
entities = await index.get_entities(fact.id)
assert ("subject", "user") in entities
assert ("object", "admin") in entities
class TestOutcomeIndex:
"""Tests for OutcomeIndex."""
@pytest.fixture
def index(self) -> OutcomeIndex[Episode]:
"""Create an outcome index."""
return OutcomeIndex[Episode]()
@pytest.mark.asyncio
async def test_add_item(self, index: OutcomeIndex[Episode]) -> None:
"""Test adding an item."""
episode = make_episode(outcome=Outcome.SUCCESS)
entry = await index.add(episode)
assert entry.memory_id == episode.id
assert entry.outcome == Outcome.SUCCESS
assert await index.count() == 1
@pytest.mark.asyncio
async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None:
"""Test searching by outcome."""
success = make_episode(outcome=Outcome.SUCCESS)
failure = make_episode(outcome=Outcome.FAILURE)
await index.add(success)
await index.add(failure)
results = await index.search(query=None, outcome=Outcome.SUCCESS)
assert len(results) == 1
assert results[0].memory_id == success.id
@pytest.mark.asyncio
async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None:
"""Test searching with multiple outcomes."""
success = make_episode(outcome=Outcome.SUCCESS)
partial = make_episode(outcome=Outcome.PARTIAL)
failure = make_episode(outcome=Outcome.FAILURE)
await index.add(success)
await index.add(partial)
await index.add(failure)
results = await index.search(
query=None,
outcomes=[Outcome.SUCCESS, Outcome.PARTIAL],
)
assert len(results) == 2
@pytest.mark.asyncio
async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None:
"""Test getting outcome statistics."""
await index.add(make_episode(outcome=Outcome.SUCCESS))
await index.add(make_episode(outcome=Outcome.SUCCESS))
await index.add(make_episode(outcome=Outcome.FAILURE))
stats = await index.get_outcome_stats()
assert stats[Outcome.SUCCESS] == 2
assert stats[Outcome.FAILURE] == 1
assert stats[Outcome.PARTIAL] == 0
class TestMemoryIndexer:
"""Tests for MemoryIndexer."""
@pytest.fixture
def indexer(self) -> MemoryIndexer:
"""Create a memory indexer."""
return MemoryIndexer()
@pytest.mark.asyncio
async def test_index_episode(self, indexer: MemoryIndexer) -> None:
"""Test indexing an episode."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
results = await indexer.index(episode)
assert "vector" in results
assert "temporal" in results
assert "entity" in results
assert "outcome" in results
@pytest.mark.asyncio
async def test_index_fact(self, indexer: MemoryIndexer) -> None:
"""Test indexing a fact."""
fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0])
results = await indexer.index(fact)
# Facts don't have outcomes
assert "vector" in results
assert "temporal" in results
assert "entity" in results
assert "outcome" not in results
@pytest.mark.asyncio
async def test_remove_from_all(self, indexer: MemoryIndexer) -> None:
"""Test removing from all indices."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await indexer.index(episode)
results = await indexer.remove(episode.id)
assert results["vector"] is True
assert results["temporal"] is True
assert results["entity"] is True
assert results["outcome"] is True
@pytest.mark.asyncio
async def test_clear_all(self, indexer: MemoryIndexer) -> None:
"""Test clearing all indices."""
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
counts = await indexer.clear_all()
assert counts["vector"] == 2
assert counts["temporal"] == 2
@pytest.mark.asyncio
async def test_get_stats(self, indexer: MemoryIndexer) -> None:
"""Test getting index statistics."""
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
stats = await indexer.get_stats()
assert stats["vector"] == 1
assert stats["temporal"] == 1
assert stats["entity"] == 1
assert stats["outcome"] == 1
class TestGetMemoryIndexer:
"""Tests for singleton getter."""
def test_returns_instance(self) -> None:
"""Test that getter returns instance."""
indexer = get_memory_indexer()
assert indexer is not None
assert isinstance(indexer, MemoryIndexer)
def test_returns_same_instance(self) -> None:
"""Test that getter returns same instance."""
indexer1 = get_memory_indexer()
indexer2 = get_memory_indexer()
assert indexer1 is indexer2

View File

@@ -0,0 +1,450 @@
# tests/unit/services/memory/indexing/test_retrieval.py
"""Unit tests for memory retrieval."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import (
RelevanceScorer,
RetrievalCache,
RetrievalEngine,
RetrievalQuery,
ScoredResult,
get_retrieval_engine,
)
from app.services.memory.types import Episode, MemoryType, Outcome
def _utcnow() -> datetime:
"""Get current UTC time."""
return datetime.now(UTC)
def make_episode(
embedding: list[float] | None = None,
outcome: Outcome = Outcome.SUCCESS,
occurred_at: datetime | None = None,
task_type: str = "test_task",
) -> Episode:
"""Create a test episode."""
return Episode(
id=uuid4(),
project_id=uuid4(),
agent_instance_id=uuid4(),
agent_type_id=uuid4(),
session_id="test-session",
task_type=task_type,
task_description="Test task description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=outcome,
outcome_details="Test outcome",
duration_seconds=10.0,
tokens_used=100,
lessons_learned=["lesson1"],
importance_score=0.8,
embedding=embedding,
occurred_at=occurred_at or _utcnow(),
created_at=_utcnow(),
updated_at=_utcnow(),
)
class TestRetrievalQuery:
"""Tests for RetrievalQuery."""
def test_default_values(self) -> None:
"""Test default query values."""
query = RetrievalQuery()
assert query.query_text is None
assert query.limit == 10
assert query.min_relevance == 0.0
assert query.use_vector is True
assert query.use_temporal is True
def test_cache_key_generation(self) -> None:
"""Test cache key generation."""
query1 = RetrievalQuery(query_text="test", limit=10)
query2 = RetrievalQuery(query_text="test", limit=10)
query3 = RetrievalQuery(query_text="different", limit=10)
# Same queries should have same key
assert query1.to_cache_key() == query2.to_cache_key()
# Different queries should have different keys
assert query1.to_cache_key() != query3.to_cache_key()
class TestScoredResult:
"""Tests for ScoredResult."""
def test_creation(self) -> None:
"""Test creating a scored result."""
result = ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.85,
score_breakdown={"vector": 0.9, "recency": 0.8},
)
assert result.relevance_score == 0.85
assert result.score_breakdown["vector"] == 0.9
class TestRelevanceScorer:
"""Tests for RelevanceScorer."""
@pytest.fixture
def scorer(self) -> RelevanceScorer:
"""Create a relevance scorer."""
return RelevanceScorer()
def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
"""Test scoring with vector similarity."""
result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
vector_similarity=0.9,
)
assert result.relevance_score > 0
assert result.score_breakdown["vector"] == 0.9
def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
"""Test scoring with recency."""
recent_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
timestamp=_utcnow(),
)
old_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
timestamp=_utcnow() - timedelta(days=7),
)
# Recent should have higher recency score
assert (
recent_result.score_breakdown["recency"]
> old_result.score_breakdown["recency"]
)
def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
"""Test scoring with outcome preference."""
success_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
outcome=Outcome.SUCCESS,
preferred_outcomes=[Outcome.SUCCESS],
)
failure_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
outcome=Outcome.FAILURE,
preferred_outcomes=[Outcome.SUCCESS],
)
assert success_result.score_breakdown["outcome"] == 1.0
assert failure_result.score_breakdown["outcome"] == 0.0
def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
"""Test scoring with entity matches."""
full_match = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
entity_match_count=3,
entity_total=3,
)
partial_match = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
entity_match_count=1,
entity_total=3,
)
assert (
full_match.score_breakdown["entity"]
> partial_match.score_breakdown["entity"]
)
class TestRetrievalCache:
"""Tests for RetrievalCache."""
@pytest.fixture
def cache(self) -> RetrievalCache:
"""Create a retrieval cache."""
return RetrievalCache(max_entries=10, default_ttl_seconds=60)
def test_put_and_get(self, cache: RetrievalCache) -> None:
"""Test putting and getting from cache."""
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("test_key", results)
cached = cache.get("test_key")
assert cached is not None
assert len(cached) == 1
def test_get_nonexistent(self, cache: RetrievalCache) -> None:
"""Test getting nonexistent entry."""
result = cache.get("nonexistent")
assert result is None
def test_lru_eviction(self) -> None:
"""Test LRU eviction when at capacity."""
cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("key1", results)
cache.put("key2", results)
cache.put("key3", results) # Should evict key1
assert cache.get("key1") is None
assert cache.get("key2") is not None
assert cache.get("key3") is not None
def test_invalidate(self, cache: RetrievalCache) -> None:
"""Test invalidating a cache entry."""
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("test_key", results)
removed = cache.invalidate("test_key")
assert removed is True
assert cache.get("test_key") is None
def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
"""Test invalidating by memory ID."""
memory_id = uuid4()
results = [
ScoredResult(
memory_id=memory_id,
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("key1", results)
cache.put("key2", results)
count = cache.invalidate_by_memory(memory_id)
assert count == 2
assert cache.get("key1") is None
assert cache.get("key2") is None
def test_clear(self, cache: RetrievalCache) -> None:
"""Test clearing the cache."""
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("key1", results)
cache.put("key2", results)
count = cache.clear()
assert count == 2
assert cache.get("key1") is None
def test_get_stats(self, cache: RetrievalCache) -> None:
"""Test getting cache statistics."""
stats = cache.get_stats()
assert "total_entries" in stats
assert "max_entries" in stats
assert stats["max_entries"] == 10
class TestRetrievalEngine:
"""Tests for RetrievalEngine."""
@pytest.fixture
def indexer(self) -> MemoryIndexer:
"""Create a memory indexer."""
return MemoryIndexer()
@pytest.fixture
def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
"""Create a retrieval engine."""
return RetrievalEngine(indexer=indexer, enable_cache=True)
@pytest.mark.asyncio
async def test_retrieve_by_vector(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval by vector similarity."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
await indexer.index(e1)
await indexer.index(e2)
await indexer.index(e3)
query = RetrievalQuery(
query_embedding=[1.0, 0.0, 0.0, 0.0],
limit=2,
use_temporal=False,
use_entity=False,
use_outcome=False,
)
result = await engine.retrieve(query)
assert len(result.items) > 0
assert result.retrieval_type == "hybrid"
@pytest.mark.asyncio
async def test_retrieve_recent(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval of recent items."""
now = _utcnow()
old = make_episode(occurred_at=now - timedelta(hours=2))
recent = make_episode(occurred_at=now - timedelta(minutes=30))
await indexer.index(old)
await indexer.index(recent)
result = await engine.retrieve_recent(hours=1)
assert len(result.items) == 1
@pytest.mark.asyncio
async def test_retrieve_by_entity(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval by entity."""
e1 = make_episode(task_type="deploy")
e2 = make_episode(task_type="test")
await indexer.index(e1)
await indexer.index(e2)
result = await engine.retrieve_by_entity("task_type", "deploy")
assert len(result.items) == 1
@pytest.mark.asyncio
async def test_retrieve_successful(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval of successful items."""
success = make_episode(outcome=Outcome.SUCCESS)
failure = make_episode(outcome=Outcome.FAILURE)
await indexer.index(success)
await indexer.index(failure)
result = await engine.retrieve_successful()
assert len(result.items) == 1
# Check outcome index was used
assert result.items[0].memory_id == success.id
@pytest.mark.asyncio
async def test_retrieve_with_cache(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test that retrieval uses cache."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await indexer.index(episode)
query = RetrievalQuery(
query_embedding=[1.0, 0.0, 0.0, 0.0],
limit=10,
)
# First retrieval
result1 = await engine.retrieve(query)
assert result1.metadata.get("cache_hit") is False
# Second retrieval should be cached
result2 = await engine.retrieve(query)
assert result2.metadata.get("cache_hit") is True
@pytest.mark.asyncio
async def test_invalidate_cache(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test cache invalidation."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await indexer.index(episode)
query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
await engine.retrieve(query)
count = engine.invalidate_cache()
assert count > 0
@pytest.mark.asyncio
async def test_retrieve_similar(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieve_similar convenience method."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
await indexer.index(e1)
await indexer.index(e2)
result = await engine.retrieve_similar(
embedding=[1.0, 0.0, 0.0, 0.0],
limit=1,
)
assert len(result.items) == 1
def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
"""Test getting cache statistics."""
stats = engine.get_cache_stats()
assert "total_entries" in stats
class TestGetRetrievalEngine:
"""Tests for singleton getter."""
def test_returns_instance(self) -> None:
"""Test that getter returns instance."""
engine = get_retrieval_engine()
assert engine is not None
assert isinstance(engine, RetrievalEngine)
def test_returns_same_instance(self) -> None:
"""Test that getter returns same instance."""
engine1 = get_retrieval_engine()
engine2 = get_retrieval_engine()
assert engine1 is engine2