forked from cardosofelipe/fast-next-template
feat(memory): implement memory indexing and retrieval engine (#94)
Add comprehensive indexing and retrieval system for memory search:

- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2
backend/tests/unit/services/memory/indexing/__init__.py
Normal file
2
backend/tests/unit/services/memory/indexing/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/indexing/__init__.py
|
||||
"""Unit tests for memory indexing."""
|
||||
497
backend/tests/unit/services/memory/indexing/test_index.py
Normal file
497
backend/tests/unit/services/memory/indexing/test_index.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# tests/unit/services/memory/indexing/test_index.py
|
||||
"""Unit tests for memory indexing."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.indexing.index import (
|
||||
EntityIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
TemporalIndex,
|
||||
VectorIndex,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
) -> Episode:
    """Build an ``Episode`` filled with fixed placeholder data for tests.

    Args:
        embedding: Vector stored on the episode; ``None`` leaves it unset.
        outcome: Outcome recorded on the episode.
        occurred_at: When the episode happened. The current UTC time is
            used when this is omitted.

    Returns:
        A new ``Episode`` whose id, project, agent instance and agent type
        identifiers are freshly generated UUIDs.
    """
    # datetime instances are always truthy, so this is equivalent to the
    # shorter ``occurred_at or _utcnow()`` but states the intent directly.
    happened_at = _utcnow() if occurred_at is None else occurred_at
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type="test_task",
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=happened_at,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )
|
||||
|
||||
|
||||
def make_fact(
    embedding: list[float] | None = None,
    subject: str = "test_subject",
    predicate: str = "has_property",
    obj: str = "test_value",
) -> Fact:
    """Build a ``Fact`` filled with fixed placeholder data for tests.

    Args:
        embedding: Vector stored on the fact; ``None`` leaves it unset.
        subject: Subject of the fact triple.
        predicate: Predicate of the fact triple.
        obj: Value passed as the fact's ``object`` field (named ``obj`` here
            to avoid shadowing the builtin).

    Returns:
        A new ``Fact`` with a fresh UUID id/project id, one generated source
        episode id, and all timestamps set to the current UTC time.
    """
    return Fact(
        id=uuid4(),
        project_id=uuid4(),
        subject=subject,
        predicate=predicate,
        object=obj,
        confidence=0.9,
        source_episode_ids=[uuid4()],
        first_learned=_utcnow(),
        last_reinforced=_utcnow(),
        reinforcement_count=1,
        embedding=embedding,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )
|
||||
|
||||
|
||||
def make_procedure(
    embedding: list[float] | None = None,
    success_count: int = 8,
    failure_count: int = 2,
) -> Procedure:
    """Build a ``Procedure`` filled with fixed placeholder data for tests.

    Args:
        embedding: Vector stored on the procedure; ``None`` leaves it unset.
        success_count: Value for the procedure's ``success_count`` field.
        failure_count: Value for the procedure's ``failure_count`` field.

    Returns:
        A new ``Procedure`` named ``"test_procedure"`` with fresh UUIDs and
        all timestamps set to the current UTC time.
    """
    return Procedure(
        id=uuid4(),
        project_id=uuid4(),
        agent_type_id=uuid4(),
        name="test_procedure",
        trigger_pattern="test.*",
        steps=[{"step": 1, "action": "test"}],
        success_count=success_count,
        failure_count=failure_count,
        last_used=_utcnow(),
        embedding=embedding,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )
|
||||
|
||||
|
||||
class TestVectorIndex:
    """Tests for VectorIndex using 4-dimensional embeddings."""

    @pytest.fixture
    def index(self) -> VectorIndex[Episode]:
        """Provide a fresh ``VectorIndex`` built with ``dimension=4``."""
        return VectorIndex[Episode](dimension=4)

    @pytest.mark.asyncio
    async def test_add_item(self, index: VectorIndex[Episode]) -> None:
        """Adding an episode returns an entry describing it and counts it."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])

        entry = await index.add(episode)

        assert entry.memory_id == episode.id
        assert entry.memory_type == MemoryType.EPISODIC
        assert entry.dimension == 4
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_remove_item(self, index: VectorIndex[Episode]) -> None:
        """Removing an indexed episode reports True and empties the index."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await index.add(episode)

        removed = await index.remove(episode.id)

        assert removed is True
        assert await index.count() == 0

    @pytest.mark.asyncio
    async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None:
        """Removing an id that was never added reports False."""
        removed = await index.remove(uuid4())
        assert removed is False

    @pytest.mark.asyncio
    async def test_search_similar(self, index: VectorIndex[Episode]) -> None:
        """The exact-match embedding is ranked first among limited results."""
        # Three episodes: identical to the query, close to it, orthogonal.
        exact = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        close = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
        orthogonal = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await index.add(exact)
        await index.add(close)
        await index.add(orthogonal)

        # Query with the first episode's embedding, capped at two hits.
        results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2)

        assert len(results) == 2
        # The best match must come first.
        assert results[0].memory_id == exact.id

    @pytest.mark.asyncio
    async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None:
        """A high ``min_similarity`` drops the orthogonal episode."""
        matching = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        orthogonal = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await index.add(matching)
        await index.add(orthogonal)

        # Only hits scoring at least 0.9 should survive.
        results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9)

        assert len(results) == 1
        assert results[0].memory_id == matching.id

    @pytest.mark.asyncio
    async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None:
        """An empty query vector yields no results even when items exist."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await index.add(episode)

        results = await index.search([], limit=10)
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_clear(self, index: VectorIndex[Episode]) -> None:
        """``clear`` returns the number of removed entries and empties the index."""
        await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
        await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))

        cleared = await index.clear()

        assert cleared == 2
        assert await index.count() == 0
|
||||
|
||||
|
||||
class TestTemporalIndex:
    """Tests for TemporalIndex time-range, recency and ordering queries."""

    @pytest.fixture
    def index(self) -> TemporalIndex[Episode]:
        """Provide a fresh, empty ``TemporalIndex``."""
        return TemporalIndex[Episode]()

    @pytest.mark.asyncio
    async def test_add_item(self, index: TemporalIndex[Episode]) -> None:
        """Adding an episode returns an entry for it and counts it."""
        episode = make_episode()
        entry = await index.add(episode)

        assert entry.memory_id == episode.id
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None:
        """A start/end window returns only the episodes inside it."""
        now = _utcnow()
        two_hours_ago = make_episode(occurred_at=now - timedelta(hours=2))
        one_hour_ago = make_episode(occurred_at=now - timedelta(hours=1))
        current = make_episode(occurred_at=now)

        await index.add(two_hours_ago)
        await index.add(one_hour_ago)
        await index.add(current)

        # The window spans the last 90 minutes, so the 2-hour-old episode
        # falls outside it while the other two are expected to match.
        results = await index.search(
            query=None,
            start_time=now - timedelta(hours=1, minutes=30),
            end_time=now,
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_search_recent(self, index: TemporalIndex[Episode]) -> None:
        """``recent_seconds`` keeps only episodes newer than the cutoff."""
        now = _utcnow()
        stale = make_episode(occurred_at=now - timedelta(hours=2))
        fresh = make_episode(occurred_at=now - timedelta(minutes=30))

        await index.add(stale)
        await index.add(fresh)

        # 3600 seconds == the last hour.
        results = await index.search(query=None, recent_seconds=3600)

        assert len(results) == 1
        assert results[0].memory_id == fresh.id

    @pytest.mark.asyncio
    async def test_search_order(self, index: TemporalIndex[Episode]) -> None:
        """``order`` selects newest-first ("desc") or oldest-first ("asc")."""
        now = _utcnow()
        oldest = make_episode(occurred_at=now - timedelta(hours=2))
        middle = make_episode(occurred_at=now - timedelta(hours=1))
        newest = make_episode(occurred_at=now)

        await index.add(oldest)
        await index.add(middle)
        await index.add(newest)

        # Descending: the newest episode leads.
        descending = await index.search(query=None, order="desc", limit=10)
        assert descending[0].memory_id == newest.id

        # Ascending: the oldest episode leads.
        ascending = await index.search(query=None, order="asc", limit=10)
        assert ascending[0].memory_id == oldest.id
|
||||
|
||||
|
||||
class TestEntityIndex:
    """Tests for EntityIndex lookups over the entities of ``Fact`` items."""

    @pytest.fixture
    def index(self) -> EntityIndex[Fact]:
        """Provide a fresh, empty ``EntityIndex``."""
        return EntityIndex[Fact]()

    @pytest.mark.asyncio
    async def test_add_item(self, index: EntityIndex[Fact]) -> None:
        """Adding a fact returns an entry for it and counts it once."""
        fact = make_fact(subject="user", obj="admin")
        entry = await index.add(fact)

        assert entry.memory_id == fact.id
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None:
        """A single entity type/value pair selects only the matching fact."""
        user_fact = make_fact(subject="user", obj="admin")
        system_fact = make_fact(subject="system", obj="config")

        await index.add(user_fact)
        await index.add(system_fact)

        results = await index.search(
            query=None,
            entity_type="subject",
            entity_value="user",
        )

        assert len(results) == 1
        assert results[0].memory_id == user_fact.id

    @pytest.mark.asyncio
    async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None:
        """Searching through the ``entities`` list finds every matching fact.

        The list holds a single ``("subject", "user")`` pair here; both
        indexed facts share that subject.
        """
        admin_fact = make_fact(subject="user", obj="admin")
        guest_fact = make_fact(subject="user", obj="guest")

        await index.add(admin_fact)
        await index.add(guest_fact)

        results = await index.search(
            query=None,
            entities=[("subject", "user")],
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_search_match_all(self, index: EntityIndex[Fact]) -> None:
        """With ``match_all=True`` a fact must carry every requested entity."""
        admin_fact = make_fact(subject="user", obj="admin")
        guest_fact = make_fact(subject="user", obj="guest")

        await index.add(admin_fact)
        await index.add(guest_fact)

        # Both facts have subject "user"; only one also has object "admin".
        results = await index.search(
            query=None,
            entities=[("subject", "user"), ("object", "admin")],
            match_all=True,
        )

        assert len(results) == 1
        assert results[0].memory_id == admin_fact.id

    @pytest.mark.asyncio
    async def test_get_entities(self, index: EntityIndex[Fact]) -> None:
        """``get_entities`` reports the subject and object pairs of a fact."""
        fact = make_fact(subject="user", obj="admin")
        await index.add(fact)

        entities = await index.get_entities(fact.id)

        assert ("subject", "user") in entities
        assert ("object", "admin") in entities
|
||||
|
||||
|
||||
class TestOutcomeIndex:
    """Tests for OutcomeIndex filtering and statistics on episodes."""

    @pytest.fixture
    def index(self) -> OutcomeIndex[Episode]:
        """Provide a fresh, empty ``OutcomeIndex``."""
        return OutcomeIndex[Episode]()

    @pytest.mark.asyncio
    async def test_add_item(self, index: OutcomeIndex[Episode]) -> None:
        """Adding an episode returns an entry carrying its id and outcome."""
        episode = make_episode(outcome=Outcome.SUCCESS)
        entry = await index.add(episode)

        assert entry.memory_id == episode.id
        assert entry.outcome == Outcome.SUCCESS
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None:
        """Filtering on one outcome returns only episodes with that outcome."""
        succeeded = make_episode(outcome=Outcome.SUCCESS)
        failed = make_episode(outcome=Outcome.FAILURE)

        await index.add(succeeded)
        await index.add(failed)

        results = await index.search(query=None, outcome=Outcome.SUCCESS)

        assert len(results) == 1
        assert results[0].memory_id == succeeded.id

    @pytest.mark.asyncio
    async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None:
        """Filtering on several outcomes returns episodes matching any of them."""
        succeeded = make_episode(outcome=Outcome.SUCCESS)
        partly = make_episode(outcome=Outcome.PARTIAL)
        failed = make_episode(outcome=Outcome.FAILURE)

        await index.add(succeeded)
        await index.add(partly)
        await index.add(failed)

        results = await index.search(
            query=None,
            outcomes=[Outcome.SUCCESS, Outcome.PARTIAL],
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None:
        """Statistics count episodes per outcome, including zero counts."""
        await index.add(make_episode(outcome=Outcome.SUCCESS))
        await index.add(make_episode(outcome=Outcome.SUCCESS))
        await index.add(make_episode(outcome=Outcome.FAILURE))

        stats = await index.get_outcome_stats()

        assert stats[Outcome.SUCCESS] == 2
        assert stats[Outcome.FAILURE] == 1
        assert stats[Outcome.PARTIAL] == 0
|
||||
|
||||
|
||||
class TestMemoryIndexer:
    """Tests for MemoryIndexer, the combined interface over all indices."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Provide a fresh ``MemoryIndexer`` built with default arguments."""
        return MemoryIndexer()

    @pytest.mark.asyncio
    async def test_index_episode(self, indexer: MemoryIndexer) -> None:
        """Indexing an episode reports an entry for every index type."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])

        indexed = await indexer.index(episode)

        assert "vector" in indexed
        assert "temporal" in indexed
        assert "entity" in indexed
        assert "outcome" in indexed

    @pytest.mark.asyncio
    async def test_index_fact(self, indexer: MemoryIndexer) -> None:
        """Indexing a fact skips the outcome index but uses the others."""
        fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0])

        indexed = await indexer.index(fact)

        # A fact carries no outcome, so no "outcome" key is expected.
        assert "vector" in indexed
        assert "temporal" in indexed
        assert "entity" in indexed
        assert "outcome" not in indexed

    @pytest.mark.asyncio
    async def test_remove_from_all(self, indexer: MemoryIndexer) -> None:
        """Removing an indexed episode reports True for every index."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        removed = await indexer.remove(episode.id)

        assert removed["vector"] is True
        assert removed["temporal"] is True
        assert removed["entity"] is True
        assert removed["outcome"] is True

    @pytest.mark.asyncio
    async def test_clear_all(self, indexer: MemoryIndexer) -> None:
        """``clear_all`` reports how many entries each index dropped."""
        await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
        await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))

        cleared = await indexer.clear_all()

        assert cleared["vector"] == 2
        assert cleared["temporal"] == 2

    @pytest.mark.asyncio
    async def test_get_stats(self, indexer: MemoryIndexer) -> None:
        """``get_stats`` reports one entry per index after one episode."""
        await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))

        stats = await indexer.get_stats()

        assert stats["vector"] == 1
        assert stats["temporal"] == 1
        assert stats["entity"] == 1
        assert stats["outcome"] == 1
|
||||
|
||||
|
||||
class TestGetMemoryIndexer:
    """Tests for the ``get_memory_indexer`` singleton getter."""

    def test_returns_instance(self) -> None:
        """The getter returns a ``MemoryIndexer`` object."""
        indexer = get_memory_indexer()
        assert indexer is not None
        assert isinstance(indexer, MemoryIndexer)

    def test_returns_same_instance(self) -> None:
        """Two consecutive calls return the very same object."""
        first = get_memory_indexer()
        second = get_memory_indexer()
        assert first is second
|
||||
450
backend/tests/unit/services/memory/indexing/test_retrieval.py
Normal file
450
backend/tests/unit/services/memory/indexing/test_retrieval.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# tests/unit/services/memory/indexing/test_retrieval.py
|
||||
"""Unit tests for memory retrieval."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.indexing.index import MemoryIndexer
|
||||
from app.services.memory.indexing.retrieval import (
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
from app.services.memory.types import Episode, MemoryType, Outcome
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
    task_type: str = "test_task",
) -> Episode:
    """Build an ``Episode`` filled with fixed placeholder data for tests.

    Args:
        embedding: Vector stored on the episode; ``None`` leaves it unset.
        outcome: Outcome recorded on the episode.
        occurred_at: When the episode happened. The current UTC time is
            used when this is omitted.
        task_type: Value for the episode's ``task_type`` field.

    Returns:
        A new ``Episode`` whose id, project, agent instance and agent type
        identifiers are freshly generated UUIDs.
    """
    # datetime instances are always truthy, so this is equivalent to the
    # shorter ``occurred_at or _utcnow()`` but states the intent directly.
    happened_at = _utcnow() if occurred_at is None else occurred_at
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type=task_type,
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=happened_at,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )
|
||||
|
||||
|
||||
class TestRetrievalQuery:
    """Tests for RetrievalQuery defaults and cache-key generation."""

    def test_default_values(self) -> None:
        """A query built without arguments exposes the expected defaults."""
        query = RetrievalQuery()

        assert query.query_text is None
        assert query.limit == 10
        assert query.min_relevance == 0.0
        assert query.use_vector is True
        assert query.use_temporal is True

    def test_cache_key_generation(self) -> None:
        """Equal queries share a cache key; a different text changes it."""
        original = RetrievalQuery(query_text="test", limit=10)
        duplicate = RetrievalQuery(query_text="test", limit=10)
        other = RetrievalQuery(query_text="different", limit=10)

        # Identical parameters must map to the same key.
        assert original.to_cache_key() == duplicate.to_cache_key()
        # A different query text must map to a different key.
        assert original.to_cache_key() != other.to_cache_key()
|
||||
|
||||
|
||||
class TestScoredResult:
    """Tests for ScoredResult."""

    def test_creation(self) -> None:
        """A scored result keeps the score and breakdown it was built with."""
        result = ScoredResult(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            relevance_score=0.85,
            score_breakdown={"vector": 0.9, "recency": 0.8},
        )

        assert result.relevance_score == 0.85
        assert result.score_breakdown["vector"] == 0.9
|
||||
|
||||
|
||||
class TestRelevanceScorer:
    """Tests for RelevanceScorer and the per-signal score breakdown."""

    @pytest.fixture
    def scorer(self) -> RelevanceScorer:
        """Provide a ``RelevanceScorer`` built with default arguments."""
        return RelevanceScorer()

    def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
        """A vector similarity is recorded as-is and yields a positive score."""
        scored = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            vector_similarity=0.9,
        )

        assert scored.relevance_score > 0
        assert scored.score_breakdown["vector"] == 0.9

    def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
        """A current timestamp scores higher on recency than a week-old one."""
        fresh = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow(),
        )

        week_old = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow() - timedelta(days=7),
        )

        # The newer memory must win on the recency component.
        assert (
            fresh.score_breakdown["recency"]
            > week_old.score_breakdown["recency"]
        )

    def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
        """A preferred outcome scores 1.0; a non-preferred one scores 0.0."""
        preferred = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.SUCCESS,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        not_preferred = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.FAILURE,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        assert preferred.score_breakdown["outcome"] == 1.0
        assert not_preferred.score_breakdown["outcome"] == 0.0

    def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
        """Matching all entities scores higher than matching one of three."""
        all_matched = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=3,
            entity_total=3,
        )

        one_matched = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=1,
            entity_total=3,
        )

        assert (
            all_matched.score_breakdown["entity"]
            > one_matched.score_breakdown["entity"]
        )
|
||||
|
||||
|
||||
class TestRetrievalCache:
    """Tests for RetrievalCache storage, eviction and invalidation."""

    @pytest.fixture
    def cache(self) -> RetrievalCache:
        """Provide a cache holding up to 10 entries with a 60-second TTL."""
        return RetrievalCache(max_entries=10, default_ttl_seconds=60)

    def test_put_and_get(self, cache: RetrievalCache) -> None:
        """A stored result list can be read back under the same key."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("test_key", results)
        cached = cache.get("test_key")

        assert cached is not None
        assert len(cached) == 1

    def test_get_nonexistent(self, cache: RetrievalCache) -> None:
        """Looking up an unknown key returns ``None``."""
        cached = cache.get("nonexistent")
        assert cached is None

    def test_lru_eviction(self) -> None:
        """Inserting beyond capacity evicts the oldest entry."""
        cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)

        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)
        # Third insert exceeds max_entries=2, so key1 is expected to go.
        cache.put("key3", results)

        assert cache.get("key1") is None
        assert cache.get("key2") is not None
        assert cache.get("key3") is not None

    def test_invalidate(self, cache: RetrievalCache) -> None:
        """Invalidating a stored key reports True and removes the entry."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("test_key", results)
        removed = cache.invalidate("test_key")

        assert removed is True
        assert cache.get("test_key") is None

    def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
        """Invalidating by memory id drops every entry containing that id."""
        memory_id = uuid4()
        results = [
            ScoredResult(
                memory_id=memory_id,
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        # The same result list is cached under two different keys.
        cache.put("key1", results)
        cache.put("key2", results)

        invalidated = cache.invalidate_by_memory(memory_id)

        assert invalidated == 2
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_clear(self, cache: RetrievalCache) -> None:
        """``clear`` returns the number of dropped entries and empties the cache."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)

        cleared = cache.clear()

        assert cleared == 2
        assert cache.get("key1") is None

    def test_get_stats(self, cache: RetrievalCache) -> None:
        """Statistics expose entry counts and the configured capacity."""
        stats = cache.get_stats()

        assert "total_entries" in stats
        assert "max_entries" in stats
        assert stats["max_entries"] == 10
|
||||
|
||||
|
||||
class TestRetrievalEngine:
    """Tests for RetrievalEngine running on top of a fresh MemoryIndexer."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Provide a fresh ``MemoryIndexer`` built with default arguments."""
        return MemoryIndexer()

    @pytest.fixture
    def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
        """Provide an engine bound to ``indexer`` with caching enabled."""
        return RetrievalEngine(indexer=indexer, enable_cache=True)

    @pytest.mark.asyncio
    async def test_retrieve_by_vector(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """A vector-only query returns hits and is labelled "hybrid"."""
        exact = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        close = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
        orthogonal = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await indexer.index(exact)
        await indexer.index(close)
        await indexer.index(orthogonal)

        # Every non-vector signal is switched off for this query.
        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=2,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )

        result = await engine.retrieve(query)

        assert len(result.items) > 0
        assert result.retrieval_type == "hybrid"

    @pytest.mark.asyncio
    async def test_retrieve_recent(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """``retrieve_recent(hours=1)`` returns only the 30-minute-old episode."""
        now = _utcnow()
        stale = make_episode(occurred_at=now - timedelta(hours=2))
        fresh = make_episode(occurred_at=now - timedelta(minutes=30))

        await indexer.index(stale)
        await indexer.index(fresh)

        result = await engine.retrieve_recent(hours=1)

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_by_entity(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Retrieving by ``task_type`` returns only the matching episode."""
        deploy_episode = make_episode(task_type="deploy")
        test_episode = make_episode(task_type="test")

        await indexer.index(deploy_episode)
        await indexer.index(test_episode)

        result = await engine.retrieve_by_entity("task_type", "deploy")

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_successful(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """``retrieve_successful`` returns only the successful episode."""
        succeeded = make_episode(outcome=Outcome.SUCCESS)
        failed = make_episode(outcome=Outcome.FAILURE)

        await indexer.index(succeeded)
        await indexer.index(failed)

        result = await engine.retrieve_successful()

        assert len(result.items) == 1
        # The single hit must be the successful episode.
        assert result.items[0].memory_id == succeeded.id

    @pytest.mark.asyncio
    async def test_retrieve_with_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Repeating a query is served from the cache the second time."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=10,
        )

        # First call: nothing cached yet.
        first = await engine.retrieve(query)
        assert first.metadata.get("cache_hit") is False

        # Second call with the same query: expected cache hit.
        second = await engine.retrieve(query)
        assert second.metadata.get("cache_hit") is True

    @pytest.mark.asyncio
    async def test_invalidate_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """``invalidate_cache`` reports a positive count after a retrieval."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
        await engine.retrieve(query)

        invalidated = engine.invalidate_cache()

        assert invalidated > 0

    @pytest.mark.asyncio
    async def test_retrieve_similar(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """``retrieve_similar`` honours ``limit`` and returns a single item."""
        matching = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        orthogonal = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await indexer.index(matching)
        await indexer.index(orthogonal)

        result = await engine.retrieve_similar(
            embedding=[1.0, 0.0, 0.0, 0.0],
            limit=1,
        )

        assert len(result.items) == 1

    def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
        """Cache statistics from the engine include ``total_entries``."""
        stats = engine.get_cache_stats()

        assert "total_entries" in stats
|
||||
|
||||
|
||||
class TestGetRetrievalEngine:
    """Tests for the ``get_retrieval_engine`` singleton getter."""

    def test_returns_instance(self) -> None:
        """The getter returns a ``RetrievalEngine`` object."""
        engine = get_retrieval_engine()
        assert engine is not None
        assert isinstance(engine, RetrievalEngine)

    def test_returns_same_instance(self) -> None:
        """Two consecutive calls return the very same object."""
        first = get_retrieval_engine()
        second = get_retrieval_engine()
        assert first is second
|
||||
Reference in New Issue
Block a user