Files
syndarix/backend/tests/unit/services/memory/episodic/test_retrieval.py
Felipe Cardoso 3554efe66a feat(memory): add episodic memory implementation (Issue #90)
Implements the episodic memory service for storing and retrieving
agent task execution experiences. This enables learning from past
successes and failures.

Components:
- EpisodicMemory: Main service class combining recording and retrieval
- EpisodeRecorder: Handles episode creation, importance scoring
- EpisodeRetriever: Multiple retrieval strategies (recency, semantic,
  outcome, importance, task type)

Key features:
- Records task completions with context, actions, outcomes
- Calculates importance scores based on outcome, duration, lessons
- Semantic search with fallback to recency when embeddings unavailable
- Full CRUD operations with statistics and summarization
- Comprehensive unit tests (50 tests, all passing)

Closes #90

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:08:16 +01:00

401 lines
12 KiB
Python

# tests/unit/services/memory/episodic/test_retrieval.py
"""Unit tests for episode retrieval strategies."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.models.memory.enums import EpisodeOutcome
from app.services.memory.episodic.retrieval import (
EpisodeRetriever,
ImportanceRetriever,
OutcomeRetriever,
RecencyRetriever,
RetrievalStrategy,
SemanticRetriever,
TaskTypeRetriever,
)
from app.services.memory.types import Outcome
def create_mock_episode_model(
    project_id=None,
    outcome=EpisodeOutcome.SUCCESS,
    task_type="test_task",
    importance_score=0.5,
    occurred_at=None,
):
    """Build a ``MagicMock`` carrying the attributes of an episode record.

    Args:
        project_id: Value for ``project_id``; a random UUID is used when falsy.
        outcome: Value stored on ``outcome``.
        task_type: Value stored on ``task_type``.
        importance_score: Value stored on ``importance_score``.
        occurred_at: Value for ``occurred_at``; the current UTC time is used
            when falsy.

    Returns:
        A ``MagicMock`` whose remaining attributes hold fixed test values.
    """
    episode = MagicMock()
    # Keyword arguments are evaluated top to bottom, so the random id is drawn
    # before the fallback project id and the timestamps are taken in sequence.
    episode.configure_mock(
        id=uuid4(),
        project_id=project_id if project_id else uuid4(),
        agent_instance_id=None,
        agent_type_id=None,
        session_id="test-session",
        task_type=task_type,
        task_description="Test description",
        actions=[],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="",
        duration_seconds=30.0,
        tokens_used=100,
        lessons_learned=[],
        importance_score=importance_score,
        embedding=None,
        occurred_at=occurred_at if occurred_at else datetime.now(UTC),
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    return episode
class TestRetrievalStrategy:
    """Checks on the ``RetrievalStrategy`` enum."""

    def test_strategy_values(self) -> None:
        """Each listed member compares equal to its expected string value."""
        expected = (
            ("SEMANTIC", "semantic"),
            ("RECENCY", "recency"),
            ("OUTCOME", "outcome"),
            ("IMPORTANCE", "importance"),
            ("HYBRID", "hybrid"),
        )
        # Members are looked up one at a time so a missing member only
        # surfaces after the preceding comparisons have been made.
        for member_name, text in expected:
            assert getattr(RetrievalStrategy, member_name) == text
class TestRecencyRetriever:
    """Exercises ``RecencyRetriever.retrieve`` against a mocked session."""

    @pytest.fixture
    def retriever(self) -> RecencyRetriever:
        """Provide a fresh ``RecencyRetriever``."""
        return RecencyRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an ``AsyncMock`` standing in for the database session."""
        return AsyncMock()

    @staticmethod
    def _stub_rows(session: AsyncMock, rows: list) -> None:
        """Make awaiting ``session.execute`` yield a result whose ``scalars().all()`` is *rows*."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = rows
        session.execute.return_value = query_result

    @pytest.mark.asyncio
    async def test_retrieve_returns_episodes(
        self,
        retriever: RecencyRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Two stubbed rows come back as two items tagged ``"recency"``."""
        owner_id = uuid4()
        self._stub_rows(
            mock_session,
            [create_mock_episode_model(project_id=owner_id) for _ in range(2)],
        )

        fetched = await retriever.retrieve(mock_session, owner_id, limit=10)

        assert len(fetched.items) == 2
        assert fetched.retrieval_type == "recency"
        assert fetched.latency_ms >= 0

    @pytest.mark.asyncio
    async def test_retrieve_with_since_filter(
        self,
        retriever: RecencyRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """The ``since`` cutoff is echoed in the metadata as an ISO string."""
        owner_id = uuid4()
        cutoff = datetime.now(UTC) - timedelta(hours=1)
        self._stub_rows(mock_session, [])

        fetched = await retriever.retrieve(
            mock_session, owner_id, limit=10, since=cutoff
        )

        assert fetched.metadata["since"] == cutoff.isoformat()
class TestOutcomeRetriever:
    """Tests for ``OutcomeRetriever.retrieve`` when filtering by outcome.

    The database session is an ``AsyncMock``; each test stubs the rows
    produced by the mocked ``execute(...).scalars().all()`` chain.
    """

    @pytest.fixture
    def retriever(self) -> OutcomeRetriever:
        """Create an outcome retriever."""
        return OutcomeRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        return AsyncMock()

    @staticmethod
    def _stub_rows(session: AsyncMock, rows: list) -> None:
        """Make awaiting ``session.execute`` yield a result whose ``scalars().all()`` is *rows*."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = rows
        session.execute.return_value = query_result

    @pytest.mark.asyncio
    async def test_retrieve_by_success(
        self,
        retriever: OutcomeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Requesting ``Outcome.SUCCESS`` reports type and outcome in the result."""
        project_id = uuid4()
        self._stub_rows(
            mock_session,
            [
                create_mock_episode_model(
                    project_id=project_id, outcome=EpisodeOutcome.SUCCESS
                ),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
        )

        assert result.retrieval_type == "outcome"
        assert result.metadata["outcome"] == "success"

    @pytest.mark.asyncio
    async def test_retrieve_by_failure(
        self,
        retriever: OutcomeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Requesting ``Outcome.FAILURE`` reports type and outcome in the result."""
        project_id = uuid4()
        self._stub_rows(
            mock_session,
            [
                create_mock_episode_model(
                    project_id=project_id, outcome=EpisodeOutcome.FAILURE
                ),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, outcome=Outcome.FAILURE
        )

        # Mirror the success test: the reported retrieval type is checked on
        # the failure path too, not only the echoed outcome.
        assert result.retrieval_type == "outcome"
        assert result.metadata["outcome"] == "failure"
class TestImportanceRetriever:
    """Exercises ``ImportanceRetriever.retrieve`` against a mocked session."""

    @pytest.fixture
    def retriever(self) -> ImportanceRetriever:
        """Provide a fresh ``ImportanceRetriever``."""
        return ImportanceRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an ``AsyncMock`` standing in for the database session."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_retrieve_by_importance(
        self,
        retriever: ImportanceRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """The result is tagged ``"importance"`` and echoes ``min_importance``."""
        owner_id = uuid4()
        rows = [
            create_mock_episode_model(project_id=owner_id, importance_score=score)
            for score in (0.9, 0.8)
        ]
        # Dotted keys configure the return values along the child-mock chain.
        mock_session.execute.return_value = MagicMock(
            **{"scalars.return_value.all.return_value": rows}
        )

        fetched = await retriever.retrieve(
            mock_session, owner_id, limit=10, min_importance=0.7
        )

        assert fetched.retrieval_type == "importance"
        assert fetched.metadata["min_importance"] == 0.7
class TestTaskTypeRetriever:
    """Exercises ``TaskTypeRetriever.retrieve`` against a mocked session."""

    @pytest.fixture
    def retriever(self) -> TaskTypeRetriever:
        """Provide a fresh ``TaskTypeRetriever``."""
        return TaskTypeRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an ``AsyncMock`` standing in for the database session."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_retrieve_by_task_type(
        self,
        retriever: TaskTypeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """The requested task type is echoed in the result metadata."""
        owner_id = uuid4()
        rows = [
            create_mock_episode_model(project_id=owner_id, task_type="code_review"),
        ]
        # Dotted keys configure the return values along the child-mock chain.
        mock_session.execute.return_value = MagicMock(
            **{"scalars.return_value.all.return_value": rows}
        )

        fetched = await retriever.retrieve(
            mock_session, owner_id, limit=10, task_type="code_review"
        )

        assert fetched.metadata["task_type"] == "code_review"
class TestSemanticRetriever:
    """Exercises ``SemanticRetriever.retrieve`` against a mocked session."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an ``AsyncMock`` standing in for the database session."""
        return AsyncMock()

    @staticmethod
    def _stub_no_rows(session: AsyncMock) -> None:
        """Make awaiting ``session.execute`` yield a result whose ``scalars().all()`` is empty."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = []
        session.execute.return_value = query_result

    @pytest.mark.asyncio
    async def test_retrieve_falls_back_without_query(
        self,
        mock_session: AsyncMock,
    ) -> None:
        """Without query text the result is still tagged ``"semantic"``."""
        semantic = SemanticRetriever()
        owner_id = uuid4()
        self._stub_no_rows(mock_session)

        fetched = await semantic.retrieve(mock_session, owner_id, limit=10)

        # NOTE(review): the test name implies an internal fallback to recency;
        # only the reported retrieval type is verified here - confirm against
        # the retriever implementation.
        assert fetched.retrieval_type == "semantic"

    @pytest.mark.asyncio
    async def test_retrieve_with_embedding_generator(
        self,
        mock_session: AsyncMock,
    ) -> None:
        """With an embedding generator and query text the type is ``"semantic"``."""
        fake_vector = [0.1] * 1536
        embedder = AsyncMock()
        embedder.generate = AsyncMock(return_value=fake_vector)
        semantic = SemanticRetriever(embedding_generator=embedder)
        owner_id = uuid4()
        self._stub_no_rows(mock_session)

        fetched = await semantic.retrieve(
            mock_session, owner_id, limit=10, query_text="test query"
        )

        assert fetched.retrieval_type == "semantic"
class TestEpisodeRetriever:
    """Exercises the unified ``EpisodeRetriever`` facade with a mocked session."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide a mock session whose ``execute`` result has an empty ``scalars().all()``."""
        empty_result = MagicMock(**{"scalars.return_value.all.return_value": []})
        db = AsyncMock()
        db.execute.return_value = empty_result
        return db

    @pytest.fixture
    def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
        """Provide an ``EpisodeRetriever`` bound to the mock session."""
        return EpisodeRetriever(session=mock_session)

    @pytest.mark.asyncio
    async def test_retrieve_with_recency_strategy(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Dispatching with ``RECENCY`` reports retrieval type ``"recency"``."""
        fetched = await retriever.retrieve(
            uuid4(), RetrievalStrategy.RECENCY, limit=10
        )
        assert fetched.retrieval_type == "recency"

    @pytest.mark.asyncio
    async def test_retrieve_with_outcome_strategy(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Dispatching with ``OUTCOME`` (no outcome given) reports ``"outcome"``."""
        fetched = await retriever.retrieve(
            uuid4(), RetrievalStrategy.OUTCOME, limit=10
        )
        assert fetched.retrieval_type == "outcome"

    @pytest.mark.asyncio
    async def test_get_recent_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """``get_recent`` reports retrieval type ``"recency"``."""
        fetched = await retriever.get_recent(uuid4(), limit=5)
        assert fetched.retrieval_type == "recency"

    @pytest.mark.asyncio
    async def test_get_by_outcome_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """``get_by_outcome`` reports retrieval type ``"outcome"``."""
        fetched = await retriever.get_by_outcome(uuid4(), Outcome.SUCCESS, limit=5)
        assert fetched.retrieval_type == "outcome"

    @pytest.mark.asyncio
    async def test_get_important_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """``get_important`` reports retrieval type ``"importance"``."""
        fetched = await retriever.get_important(uuid4(), limit=5, min_importance=0.8)
        assert fetched.retrieval_type == "importance"

    @pytest.mark.asyncio
    async def test_search_similar_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """``search_similar`` reports retrieval type ``"semantic"``."""
        fetched = await retriever.search_similar(uuid4(), "test query", limit=5)
        assert fetched.retrieval_type == "semantic"

    @pytest.mark.asyncio
    async def test_unknown_strategy_raises_error(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """A strategy value outside the enum is rejected with ``ValueError``."""
        owner_id = uuid4()
        with pytest.raises(ValueError, match="Unknown retrieval strategy"):
            await retriever.retrieve(owner_id, "invalid_strategy", limit=10)  # type: ignore