forked from cardosofelipe/fast-next-template
Implements the episodic memory service for storing and retrieving agent task-execution experiences, enabling agents to learn from past successes and failures.

Components:
- EpisodicMemory: main service class combining recording and retrieval
- EpisodeRecorder: handles episode creation and importance scoring
- EpisodeRetriever: multiple retrieval strategies (recency, semantic, outcome, importance, task type)

Key features:
- Records task completions with context, actions, and outcomes
- Calculates importance scores based on outcome, duration, and lessons learned
- Semantic search with fallback to recency when embeddings are unavailable
- Full CRUD operations with statistics and summarization
- Comprehensive unit tests (50 tests, all passing)

Closes #90

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
401 lines
12 KiB
Python
401 lines
12 KiB
Python
# tests/unit/services/memory/episodic/test_retrieval.py
|
|
"""Unit tests for episode retrieval strategies."""
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from app.models.memory.enums import EpisodeOutcome
|
|
from app.services.memory.episodic.retrieval import (
|
|
EpisodeRetriever,
|
|
ImportanceRetriever,
|
|
OutcomeRetriever,
|
|
RecencyRetriever,
|
|
RetrievalStrategy,
|
|
SemanticRetriever,
|
|
TaskTypeRetriever,
|
|
)
|
|
from app.services.memory.types import Outcome
|
|
|
|
|
|
def create_mock_episode_model(
    project_id=None,
    outcome=EpisodeOutcome.SUCCESS,
    task_type="test_task",
    importance_score=0.5,
    occurred_at=None,
):
    """Build a MagicMock that stands in for an episode model instance.

    Only the fields the retrieval tests vary are exposed as parameters; every
    other attribute is filled with a fixed placeholder value.

    Args:
        project_id: Owning project id; a fresh ``uuid4()`` is used when omitted.
        outcome: Outcome enum value stored on the episode.
        task_type: Task-type label stored on the episode.
        importance_score: Importance score stored on the episode.
        occurred_at: Event timestamp; the current UTC time is used when omitted.

    Returns:
        A ``MagicMock`` whose attributes are named after the episode fields.
    """
    # Built fresh on every call, so the list values are never shared between mocks.
    attributes = {
        "id": uuid4(),
        "project_id": project_id or uuid4(),
        "agent_instance_id": None,
        "agent_type_id": None,
        "session_id": "test-session",
        "task_type": task_type,
        "task_description": "Test description",
        "actions": [],
        "context_summary": "Test context",
        "outcome": outcome,
        "outcome_details": "",
        "duration_seconds": 30.0,
        "tokens_used": 100,
        "lessons_learned": [],
        "importance_score": importance_score,
        "embedding": None,
        "occurred_at": occurred_at or datetime.now(UTC),
        "created_at": datetime.now(UTC),
        "updated_at": datetime.now(UTC),
    }
    episode = MagicMock()
    # configure_mock applies each key with setattr, same as assigning one by one.
    episode.configure_mock(**attributes)
    return episode
|
|
|
|
|
|
class TestRetrievalStrategy:
    """Checks the string values exposed by the RetrievalStrategy enum."""

    def test_strategy_values(self) -> None:
        """Each listed strategy member compares equal to its lowercase name."""
        expected_pairs = [
            (RetrievalStrategy.SEMANTIC, "semantic"),
            (RetrievalStrategy.RECENCY, "recency"),
            (RetrievalStrategy.OUTCOME, "outcome"),
            (RetrievalStrategy.IMPORTANCE, "importance"),
            (RetrievalStrategy.HYBRID, "hybrid"),
        ]
        for member, raw_value in expected_pairs:
            assert member == raw_value
|
|
|
|
|
|
class TestRecencyRetriever:
    """Tests for RecencyRetriever."""

    @pytest.fixture
    def retriever(self) -> RecencyRetriever:
        """Provide a fresh RecencyRetriever."""
        return RecencyRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an AsyncMock standing in for the database session."""
        return AsyncMock()

    @staticmethod
    def _stub_rows(session: AsyncMock, rows: list) -> None:
        """Make ``await session.execute(...)`` expose ``rows`` via ``.scalars().all()``."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = rows
        session.execute.return_value = query_result

    @pytest.mark.asyncio
    async def test_retrieve_returns_episodes(
        self,
        retriever: RecencyRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Both stubbed episodes come back in a recency-typed result."""
        project_id = uuid4()
        self._stub_rows(
            mock_session,
            [create_mock_episode_model(project_id=project_id) for _ in range(2)],
        )

        result = await retriever.retrieve(mock_session, project_id, limit=10)

        assert len(result.items) == 2
        assert result.retrieval_type == "recency"
        assert result.latency_ms >= 0

    @pytest.mark.asyncio
    async def test_retrieve_with_since_filter(
        self,
        retriever: RecencyRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """The ``since`` cutoff is echoed back in ISO format in the metadata."""
        project_id = uuid4()
        cutoff = datetime.now(UTC) - timedelta(hours=1)
        self._stub_rows(mock_session, [])

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, since=cutoff
        )

        assert result.metadata["since"] == cutoff.isoformat()
|
|
|
|
|
|
class TestOutcomeRetriever:
    """Tests for OutcomeRetriever."""

    @pytest.fixture
    def retriever(self) -> OutcomeRetriever:
        """Provide a fresh OutcomeRetriever."""
        return OutcomeRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an AsyncMock standing in for the database session."""
        return AsyncMock()

    @staticmethod
    def _stub_rows(session: AsyncMock, rows: list) -> None:
        """Make ``await session.execute(...)`` expose ``rows`` via ``.scalars().all()``."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = rows
        session.execute.return_value = query_result

    @pytest.mark.asyncio
    async def test_retrieve_by_success(
        self,
        retriever: OutcomeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Filtering on SUCCESS yields an outcome-typed result tagged "success"."""
        project_id = uuid4()
        self._stub_rows(
            mock_session,
            [
                create_mock_episode_model(
                    project_id=project_id, outcome=EpisodeOutcome.SUCCESS
                ),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
        )

        assert result.retrieval_type == "outcome"
        assert result.metadata["outcome"] == "success"

    @pytest.mark.asyncio
    async def test_retrieve_by_failure(
        self,
        retriever: OutcomeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """Filtering on FAILURE yields an outcome-typed result tagged "failure"."""
        project_id = uuid4()
        self._stub_rows(
            mock_session,
            [
                create_mock_episode_model(
                    project_id=project_id, outcome=EpisodeOutcome.FAILURE
                ),
            ],
        )

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, outcome=Outcome.FAILURE
        )

        # Same retriever as the success case, so the type label must match it too.
        assert result.retrieval_type == "outcome"
        assert result.metadata["outcome"] == "failure"
|
|
|
|
|
|
class TestImportanceRetriever:
    """Tests for ImportanceRetriever."""

    @pytest.fixture
    def retriever(self) -> ImportanceRetriever:
        """Provide a fresh ImportanceRetriever."""
        return ImportanceRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an AsyncMock standing in for the database session."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_retrieve_by_importance(
        self,
        retriever: ImportanceRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """The ``min_importance`` threshold is echoed back in the metadata."""
        project_id = uuid4()
        high_scoring = [
            create_mock_episode_model(project_id=project_id, importance_score=score)
            for score in (0.9, 0.8)
        ]

        stubbed = MagicMock()
        stubbed.scalars.return_value.all.return_value = high_scoring
        mock_session.execute.return_value = stubbed

        threshold = 0.7
        result = await retriever.retrieve(
            mock_session, project_id, limit=10, min_importance=threshold
        )

        assert result.retrieval_type == "importance"
        assert result.metadata["min_importance"] == threshold
|
|
|
|
|
|
class TestTaskTypeRetriever:
    """Tests for TaskTypeRetriever."""

    @pytest.fixture
    def retriever(self) -> TaskTypeRetriever:
        """Provide a fresh TaskTypeRetriever."""
        return TaskTypeRetriever()

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an AsyncMock standing in for the database session."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_retrieve_by_task_type(
        self,
        retriever: TaskTypeRetriever,
        mock_session: AsyncMock,
    ) -> None:
        """The requested task type is echoed back in the metadata."""
        wanted_type = "code_review"
        project_id = uuid4()
        episode = create_mock_episode_model(
            project_id=project_id, task_type=wanted_type
        )

        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = [episode]
        mock_session.execute.return_value = query_result

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, task_type=wanted_type
        )

        assert result.metadata["task_type"] == wanted_type
|
|
|
|
|
|
class TestSemanticRetriever:
    """Tests for SemanticRetriever."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an AsyncMock standing in for the database session."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_retrieve_falls_back_without_query(
        self,
        mock_session: AsyncMock,
    ) -> None:
        """Test that semantic search falls back to recency without query."""
        retriever = SemanticRetriever()
        project_id = uuid4()

        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []
        mock_session.execute.return_value = mock_result

        result = await retriever.retrieve(mock_session, project_id, limit=10)

        # No query_text is supplied, so there is nothing to embed. Per this
        # test's contract the result keeps the "semantic" retrieval type even
        # on the fallback path; the fallback query itself is not observed here.
        assert result.retrieval_type == "semantic"

    @pytest.mark.asyncio
    async def test_retrieve_with_embedding_generator(
        self,
        mock_session: AsyncMock,
    ) -> None:
        """Test semantic retrieval with embedding generator."""
        # Stub generator returning a constant vector. 1536 is presumably the
        # configured embedding dimension; confirm against the model config.
        mock_embedding_gen = AsyncMock()
        mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)

        retriever = SemanticRetriever(embedding_generator=mock_embedding_gen)
        project_id = uuid4()

        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []
        mock_session.execute.return_value = mock_result

        result = await retriever.retrieve(
            mock_session, project_id, limit=10, query_text="test query"
        )

        # Only the result label is pinned; this test does not check whether
        # or how the generator was invoked.
        assert result.retrieval_type == "semantic"
|
|
|
|
|
|
class TestEpisodeRetriever:
    """Tests for the unified EpisodeRetriever facade."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Provide an async session whose queries all return zero rows."""
        empty_result = MagicMock()
        empty_result.scalars.return_value.all.return_value = []
        session = AsyncMock()
        session.execute.return_value = empty_result
        return session

    @pytest.fixture
    def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
        """Provide an EpisodeRetriever bound to the stub session."""
        return EpisodeRetriever(session=mock_session)

    @pytest.mark.asyncio
    async def test_retrieve_with_recency_strategy(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Dispatching the RECENCY strategy yields a recency-typed result."""
        outcome = await retriever.retrieve(
            uuid4(), RetrievalStrategy.RECENCY, limit=10
        )
        assert outcome.retrieval_type == "recency"

    @pytest.mark.asyncio
    async def test_retrieve_with_outcome_strategy(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """Dispatching the OUTCOME strategy yields an outcome-typed result."""
        outcome = await retriever.retrieve(
            uuid4(), RetrievalStrategy.OUTCOME, limit=10
        )
        assert outcome.retrieval_type == "outcome"

    @pytest.mark.asyncio
    async def test_get_recent_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """get_recent reports a recency retrieval."""
        outcome = await retriever.get_recent(uuid4(), limit=5)
        assert outcome.retrieval_type == "recency"

    @pytest.mark.asyncio
    async def test_get_by_outcome_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """get_by_outcome reports an outcome retrieval."""
        outcome = await retriever.get_by_outcome(uuid4(), Outcome.SUCCESS, limit=5)
        assert outcome.retrieval_type == "outcome"

    @pytest.mark.asyncio
    async def test_get_important_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """get_important reports an importance retrieval."""
        outcome = await retriever.get_important(uuid4(), limit=5, min_importance=0.8)
        assert outcome.retrieval_type == "importance"

    @pytest.mark.asyncio
    async def test_search_similar_convenience_method(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """search_similar reports a semantic retrieval."""
        outcome = await retriever.search_similar(uuid4(), "test query", limit=5)
        assert outcome.retrieval_type == "semantic"

    @pytest.mark.asyncio
    async def test_unknown_strategy_raises_error(
        self,
        retriever: EpisodeRetriever,
    ) -> None:
        """A strategy value outside the enum is rejected with ValueError."""
        project_id = uuid4()

        with pytest.raises(ValueError, match="Unknown retrieval strategy"):
            await retriever.retrieve(project_id, "invalid_strategy", limit=10)  # type: ignore
|