# tests/unit/services/memory/episodic/test_retrieval.py
"""Unit tests for episode retrieval strategies."""

from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

import pytest

from app.models.memory.enums import EpisodeOutcome
from app.services.memory.episodic.retrieval import (
    EpisodeRetriever,
    ImportanceRetriever,
    OutcomeRetriever,
    RecencyRetriever,
    RetrievalStrategy,
    SemanticRetriever,
    TaskTypeRetriever,
)
from app.services.memory.types import Outcome


def create_mock_episode_model(
    project_id=None,
    outcome=EpisodeOutcome.SUCCESS,
    task_type="test_task",
    importance_score=0.5,
    occurred_at=None,
) -> MagicMock:
    """Create a mock episode model for testing.

    Only the keyword arguments are configurable; every other attribute gets
    a fixed test value. ``project_id`` defaults to a fresh UUID and
    ``occurred_at`` to the current UTC time when omitted.
    """
    mock = MagicMock()
    mock.id = uuid4()
    mock.project_id = project_id or uuid4()
    mock.agent_instance_id = None
    mock.agent_type_id = None
    mock.session_id = "test-session"
    mock.task_type = task_type
    mock.task_description = "Test description"
    mock.actions = []
    mock.context_summary = "Test context"
    mock.outcome = outcome
    mock.outcome_details = ""
    mock.duration_seconds = 30.0
    mock.tokens_used = 100
    mock.lessons_learned = []
    mock.importance_score = importance_score
    mock.embedding = None
    mock.occurred_at = occurred_at or datetime.now(UTC)
    mock.created_at = datetime.now(UTC)
    mock.updated_at = datetime.now(UTC)
    return mock


def _stub_execute_result(session: AsyncMock, episodes: list) -> MagicMock:
    """Make ``session.execute`` yield a result whose ``scalars().all()`` is *episodes*.

    Returns the configured result mock in case a test wants to inspect it.
    """
    result = MagicMock()
    result.scalars.return_value.all.return_value = episodes
    session.execute.return_value = result
    return result


@pytest.fixture
def mock_session() -> AsyncMock:
    """Create a bare mock async database session.

    Test classes that need a pre-configured session define their own
    ``mock_session`` fixture, which takes precedence over this one.
    """
    return AsyncMock()


class TestRetrievalStrategy:
    """Tests for RetrievalStrategy enum."""

    def test_strategy_values(self) -> None:
        """Test that strategy enum has expected values."""
        assert RetrievalStrategy.SEMANTIC == "semantic"
        assert RetrievalStrategy.RECENCY == "recency"
        assert RetrievalStrategy.OUTCOME == "outcome"
        assert RetrievalStrategy.IMPORTANCE == "importance"
        assert RetrievalStrategy.HYBRID == "hybrid"


class TestRecencyRetriever:
    """Tests for RecencyRetriever."""

    @pytest.fixture
    def retriever(self) -> RecencyRetriever:
        """Create a recency retriever."""
        return RecencyRetriever()

    @pytest.mark.asyncio
    async def test_retrieve_returns_episodes(
        self,
        retriever: RecencyRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Test that retrieve returns episodes."""
        project_id = uuid4()
        _stub_execute_result(
            mock_session,
            [
                create_mock_episode_model(project_id=project_id),
                create_mock_episode_model(project_id=project_id),
            ],
        )

        result = await retriever.retrieve(mock_session, project_id, limit=10)

        assert len(result.items) == 2
        assert result.retrieval_type == "recency"
        assert result.latency_ms >= 0

    @pytest.mark.asyncio
    async def test_retrieve_with_since_filter(
        self,
        retriever: RecencyRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Test retrieve with since time filter."""
        project_id = uuid4()
        since = datetime.now(UTC) - timedelta(hours=1)
        _stub_execute_result(mock_session, [])

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, since=since
        )

        assert result.retrieval_type == "recency"
        assert result.metadata["since"] == since.isoformat()


class TestOutcomeRetriever:
    """Tests for OutcomeRetriever."""

    @pytest.fixture
    def retriever(self) -> OutcomeRetriever:
        """Create an outcome retriever."""
        return OutcomeRetriever()

    @pytest.mark.asyncio
    async def test_retrieve_by_success(
        self,
        retriever: OutcomeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Test retrieving successful episodes."""
        project_id = uuid4()
        _stub_execute_result(
            mock_session,
            [
                create_mock_episode_model(
                    project_id=project_id, outcome=EpisodeOutcome.SUCCESS
                ),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
        )

        assert result.retrieval_type == "outcome"
        assert result.metadata["outcome"] == "success"

    @pytest.mark.asyncio
    async def test_retrieve_by_failure(
        self,
        retriever: OutcomeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Test retrieving failed episodes."""
        project_id = uuid4()
        _stub_execute_result(
            mock_session,
            [
                create_mock_episode_model(
                    project_id=project_id, outcome=EpisodeOutcome.FAILURE
                ),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, outcome=Outcome.FAILURE
        )

        assert result.retrieval_type == "outcome"
        assert result.metadata["outcome"] == "failure"


class TestImportanceRetriever:
    """Tests for ImportanceRetriever."""

    @pytest.fixture
    def retriever(self) -> ImportanceRetriever:
        """Create an importance retriever."""
        return ImportanceRetriever()

    @pytest.mark.asyncio
    async def test_retrieve_by_importance(
        self,
        retriever: ImportanceRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Test retrieving by importance score."""
        project_id = uuid4()
        _stub_execute_result(
            mock_session,
            [
                create_mock_episode_model(project_id=project_id, importance_score=0.9),
                create_mock_episode_model(project_id=project_id, importance_score=0.8),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, min_importance=0.7
        )

        assert result.retrieval_type == "importance"
        assert result.metadata["min_importance"] == 0.7


class TestTaskTypeRetriever:
    """Tests for TaskTypeRetriever."""

    @pytest.fixture
    def retriever(self) -> TaskTypeRetriever:
        """Create a task type retriever."""
        return TaskTypeRetriever()

    @pytest.mark.asyncio
    async def test_retrieve_by_task_type(
        self,
        retriever: TaskTypeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Test retrieving by task type."""
        project_id = uuid4()
        _stub_execute_result(
            mock_session,
            [create_mock_episode_model(project_id=project_id, task_type="code_review")],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, task_type="code_review"
        )

        assert result.metadata["task_type"] == "code_review"


class TestSemanticRetriever:
    """Tests for SemanticRetriever."""

    @pytest.mark.asyncio
    async def test_retrieve_falls_back_without_query(
        self,
        mock_session: AsyncMock,
    ) -> None:
        """Test that semantic search falls back to recency without query."""
        retriever = SemanticRetriever()
        project_id = uuid4()
        _stub_execute_result(mock_session, [])

        result = await retriever.retrieve(mock_session, project_id, limit=10)

        # Even when no query text is given (the recency fallback path), the
        # result is still labelled with the semantic retrieval type.
        assert result.retrieval_type == "semantic"

    @pytest.mark.asyncio
    async def test_retrieve_with_embedding_generator(
        self,
        mock_session: AsyncMock,
    ) -> None:
        """Test semantic retrieval with embedding generator."""
        mock_embedding_gen = AsyncMock()
        mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
        retriever = SemanticRetriever(embedding_generator=mock_embedding_gen)
        project_id = uuid4()
        _stub_execute_result(mock_session, [])

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, query_text="test query"
        )

        assert result.retrieval_type == "semantic"


class TestEpisodeRetriever:
    """Tests for unified EpisodeRetriever."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session whose queries return no episodes.

        Overrides the module-level fixture so that every ``execute`` call on
        the session yields an empty ``scalars().all()`` result.
        """
        session = AsyncMock()
        _stub_execute_result(session, [])
        return session

    @pytest.fixture
    def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
        """Create an episode retriever."""
        return EpisodeRetriever(session=mock_session)

    @pytest.mark.asyncio
    async def test_retrieve_with_recency_strategy(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test retrieve with recency strategy."""
        project_id = uuid4()
        result = await retriever.retrieve(
            project_id, RetrievalStrategy.RECENCY, limit=10
        )
        assert result.retrieval_type == "recency"

    @pytest.mark.asyncio
    async def test_retrieve_with_outcome_strategy(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test retrieve with outcome strategy."""
        project_id = uuid4()
        result = await retriever.retrieve(
            project_id, RetrievalStrategy.OUTCOME, limit=10
        )
        assert result.retrieval_type == "outcome"

    @pytest.mark.asyncio
    async def test_get_recent_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test get_recent convenience method."""
        project_id = uuid4()
        result = await retriever.get_recent(project_id, limit=5)
        assert result.retrieval_type == "recency"

    @pytest.mark.asyncio
    async def test_get_by_outcome_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test get_by_outcome convenience method."""
        project_id = uuid4()
        result = await retriever.get_by_outcome(project_id, Outcome.SUCCESS, limit=5)
        assert result.retrieval_type == "outcome"

    @pytest.mark.asyncio
    async def test_get_important_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test get_important convenience method."""
        project_id = uuid4()
        result = await retriever.get_important(project_id, limit=5, min_importance=0.8)
        assert result.retrieval_type == "importance"

    @pytest.mark.asyncio
    async def test_search_similar_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test search_similar convenience method."""
        project_id = uuid4()
        result = await retriever.search_similar(project_id, "test query", limit=5)
        assert result.retrieval_type == "semantic"

    @pytest.mark.asyncio
    async def test_unknown_strategy_raises_error(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Test that unknown strategy raises ValueError."""
        project_id = uuid4()
        with pytest.raises(ValueError, match="Unknown retrieval strategy"):
            await retriever.retrieve(project_id, "invalid_strategy", limit=10)  # type: ignore