# tests/unit/services/memory/indexing/test_retrieval.py
"""Unit tests for memory retrieval."""

from datetime import UTC, datetime, timedelta
from uuid import uuid4

import pytest

from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import (
    RelevanceScorer,
    RetrievalCache,
    RetrievalEngine,
    RetrievalQuery,
    ScoredResult,
    get_retrieval_engine,
)
from app.services.memory.types import Episode, MemoryType, Outcome


def _utcnow() -> datetime:
    """Return the current timezone-aware UTC timestamp."""
    return datetime.now(tz=UTC)


def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
    task_type: str = "test_task",
) -> Episode:
    """Build an Episode populated with fixed test data.

    Args:
        embedding: Embedding vector stored on the episode, if any.
        outcome: Outcome recorded on the episode.
        occurred_at: Time the episode occurred; the current UTC time is
            used when this is not supplied.
        task_type: Task type recorded on the episode.

    Returns:
        A new Episode with freshly generated identifiers.
    """
    happened = occurred_at or _utcnow()
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type=task_type,
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=happened,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


def _single_result_list() -> list[ScoredResult]:
    """Return a one-element list holding an episodic ScoredResult (score 0.8)."""
    entry = ScoredResult(
        memory_id=uuid4(),
        memory_type=MemoryType.EPISODIC,
        relevance_score=0.8,
    )
    return [entry]


class TestRetrievalQuery:
    """Checks for RetrievalQuery."""

    def test_default_values(self) -> None:
        """A query built without arguments exposes the expected defaults."""
        defaults = RetrievalQuery()

        assert defaults.query_text is None
        assert defaults.limit == 10
        assert defaults.min_relevance == 0.0
        assert defaults.use_vector is True
        assert defaults.use_temporal is True

    def test_cache_key_generation(self) -> None:
        """Equal queries share a cache key; a different query text does not."""
        baseline = RetrievalQuery(query_text="test", limit=10)
        duplicate = RetrievalQuery(query_text="test", limit=10)
        other = RetrievalQuery(query_text="different", limit=10)

        # Identical parameters must map to the same key.
        assert baseline.to_cache_key() == duplicate.to_cache_key()
        # Changing the query text must change the key.
        assert baseline.to_cache_key() != other.to_cache_key()


class TestScoredResult:
    """Checks for ScoredResult."""

    def test_creation(self) -> None:
        """Constructor arguments are readable back from the instance."""
        breakdown = {"vector": 0.9, "recency": 0.8}
        scored = ScoredResult(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            relevance_score=0.85,
            score_breakdown=breakdown,
        )

        assert scored.relevance_score == 0.85
        assert scored.score_breakdown["vector"] == 0.9


class TestRelevanceScorer:
    """Checks for RelevanceScorer."""

    @pytest.fixture
    def scorer(self) -> RelevanceScorer:
        """Provide a default-configured RelevanceScorer."""
        return RelevanceScorer()

    def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
        """Vector similarity is reported in the breakdown and yields a positive score."""
        scored = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            vector_similarity=0.9,
        )

        assert scored.relevance_score > 0
        assert scored.score_breakdown["vector"] == 0.9

    def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
        """A timestamp of now gets a higher recency score than one a week old."""
        fresh = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow(),
        )
        stale = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow() - timedelta(days=7),
        )

        fresh_recency = fresh.score_breakdown["recency"]
        stale_recency = stale.score_breakdown["recency"]
        # The newer timestamp must win on the recency component.
        assert fresh_recency > stale_recency

    def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
        """A preferred outcome scores 1.0 on the outcome component, others 0.0."""
        matching = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.SUCCESS,
            preferred_outcomes=[Outcome.SUCCESS],
        )
        non_matching = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.FAILURE,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        assert matching.score_breakdown["outcome"] == 1.0
        assert non_matching.score_breakdown["outcome"] == 0.0

    def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
        """Matching 3 of 3 entities outscores matching 1 of 3."""
        all_matched = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=3,
            entity_total=3,
        )
        some_matched = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=1,
            entity_total=3,
        )

        full_entity = all_matched.score_breakdown["entity"]
        partial_entity = some_matched.score_breakdown["entity"]
        assert full_entity > partial_entity


class TestRetrievalCache:
    """Checks for RetrievalCache."""

    @pytest.fixture
    def cache(self) -> RetrievalCache:
        """Provide a cache capped at 10 entries with a 60 second default TTL."""
        return RetrievalCache(max_entries=10, default_ttl_seconds=60)

    def test_put_and_get(self, cache: RetrievalCache) -> None:
        """A stored entry can be read back under the same key."""
        cache.put("test_key", _single_result_list())

        fetched = cache.get("test_key")

        assert fetched is not None
        assert len(fetched) == 1

    def test_get_nonexistent(self, cache: RetrievalCache) -> None:
        """Looking up a key that was never stored yields None."""
        assert cache.get("nonexistent") is None

    def test_lru_eviction(self) -> None:
        """Adding a third entry to a two-entry cache drops the first one."""
        small_cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
        payload = _single_result_list()

        # The third put exceeds capacity and should push out "key1".
        for key in ("key1", "key2", "key3"):
            small_cache.put(key, payload)

        assert small_cache.get("key1") is None
        assert small_cache.get("key2") is not None
        assert small_cache.get("key3") is not None

    def test_invalidate(self, cache: RetrievalCache) -> None:
        """Invalidating a stored key reports True and removes the entry."""
        cache.put("test_key", _single_result_list())

        assert cache.invalidate("test_key") is True
        assert cache.get("test_key") is None

    def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
        """Invalidating by memory ID removes every entry holding that ID."""
        target_id = uuid4()
        payload = [
            ScoredResult(
                memory_id=target_id,
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]
        for key in ("key1", "key2"):
            cache.put(key, payload)

        removed_count = cache.invalidate_by_memory(target_id)

        assert removed_count == 2
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_clear(self, cache: RetrievalCache) -> None:
        """Clearing reports the number of entries dropped and empties the cache."""
        payload = _single_result_list()
        for key in ("key1", "key2"):
            cache.put(key, payload)

        cleared = cache.clear()

        assert cleared == 2
        assert cache.get("key1") is None

    def test_get_stats(self, cache: RetrievalCache) -> None:
        """Statistics expose the entry count and the configured capacity."""
        stats = cache.get_stats()

        for expected_key in ("total_entries", "max_entries"):
            assert expected_key in stats
        assert stats["max_entries"] == 10


class TestRetrievalEngine:
    """Checks for RetrievalEngine."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Provide a fresh MemoryIndexer."""
        return MemoryIndexer()

    @pytest.fixture
    def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
        """Provide a cache-enabled RetrievalEngine bound to the indexer fixture."""
        return RetrievalEngine(indexer=indexer, enable_cache=True)

    @pytest.mark.asyncio
    async def test_retrieve_by_vector(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """A vector-only query returns items and reports the "hybrid" type."""
        vectors = [
            [1.0, 0.0, 0.0, 0.0],
            [0.9, 0.1, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
        ]
        episodes = [make_episode(embedding=vector) for vector in vectors]
        for episode in episodes:
            await indexer.index(episode)

        vector_query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=2,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        outcome = await engine.retrieve(vector_query)

        assert len(outcome.items) > 0
        assert outcome.retrieval_type == "hybrid"

    @pytest.mark.asyncio
    async def test_retrieve_recent(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Only the episode from the last hour is returned for hours=1."""
        reference = _utcnow()
        two_hours_ago = make_episode(occurred_at=reference - timedelta(hours=2))
        half_hour_ago = make_episode(occurred_at=reference - timedelta(minutes=30))
        for episode in (two_hours_ago, half_hour_ago):
            await indexer.index(episode)

        outcome = await engine.retrieve_recent(hours=1)

        assert len(outcome.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_by_entity(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Filtering on task_type "deploy" returns just the deploy episode."""
        deploy_episode = make_episode(task_type="deploy")
        test_episode = make_episode(task_type="test")
        for episode in (deploy_episode, test_episode):
            await indexer.index(episode)

        outcome = await engine.retrieve_by_entity("task_type", "deploy")

        assert len(outcome.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_successful(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Only the successful episode comes back from retrieve_successful."""
        succeeded = make_episode(outcome=Outcome.SUCCESS)
        failed = make_episode(outcome=Outcome.FAILURE)
        for episode in (succeeded, failed):
            await indexer.index(episode)

        outcome = await engine.retrieve_successful()

        assert len(outcome.items) == 1
        # The single hit must be the successful episode.
        assert outcome.items[0].memory_id == succeeded.id

    @pytest.mark.asyncio
    async def test_retrieve_with_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Repeating a query flips the cache_hit metadata from False to True."""
        await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))

        repeated_query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=10,
        )

        # The first call has nothing cached yet.
        first = await engine.retrieve(repeated_query)
        assert first.metadata.get("cache_hit") is False

        # The same query again should be served from the cache.
        second = await engine.retrieve(repeated_query)
        assert second.metadata.get("cache_hit") is True

    @pytest.mark.asyncio
    async def test_invalidate_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Invalidating after a retrieval reports at least one entry removed."""
        await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))

        warm_up_query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
        await engine.retrieve(warm_up_query)

        invalidated = engine.invalidate_cache()

        assert invalidated > 0

    @pytest.mark.asyncio
    async def test_retrieve_similar(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """retrieve_similar with limit=1 returns exactly one item."""
        episodes = [
            make_episode(embedding=[1.0, 0.0, 0.0, 0.0]),
            make_episode(embedding=[0.0, 1.0, 0.0, 0.0]),
        ]
        for episode in episodes:
            await indexer.index(episode)

        outcome = await engine.retrieve_similar(
            embedding=[1.0, 0.0, 0.0, 0.0],
            limit=1,
        )

        assert len(outcome.items) == 1

    def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
        """Engine cache statistics include the total entry count."""
        assert "total_entries" in engine.get_cache_stats()


class TestGetRetrievalEngine:
    """Checks for the get_retrieval_engine accessor."""

    def test_returns_instance(self) -> None:
        """The accessor hands back a RetrievalEngine."""
        obtained = get_retrieval_engine()

        assert obtained is not None
        assert isinstance(obtained, RetrievalEngine)

    def test_returns_same_instance(self) -> None:
        """Two consecutive calls yield the very same object."""
        first = get_retrieval_engine()
        second = get_retrieval_engine()

        assert first is second