feat(memory): add episodic memory implementation (Issue #90)
Implements the episodic memory service for storing and retrieving agent task execution experiences. This enables learning from past successes and failures. Components: - EpisodicMemory: Main service class combining recording and retrieval - EpisodeRecorder: Handles episode creation, importance scoring - EpisodeRetriever: Multiple retrieval strategies (recency, semantic, outcome, importance, task type) Key features: - Records task completions with context, actions, outcomes - Calculates importance scores based on outcome, duration, lessons - Semantic search with fallback to recency when embeddings unavailable - Full CRUD operations with statistics and summarization - Comprehensive unit tests (50 tests, all passing) Closes #90 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
400
backend/tests/unit/services/memory/episodic/test_retrieval.py
Normal file
400
backend/tests/unit/services/memory/episodic/test_retrieval.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# tests/unit/services/memory/episodic/test_retrieval.py
|
||||
"""Unit tests for episode retrieval strategies."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.services.memory.episodic.retrieval import (
|
||||
EpisodeRetriever,
|
||||
ImportanceRetriever,
|
||||
OutcomeRetriever,
|
||||
RecencyRetriever,
|
||||
RetrievalStrategy,
|
||||
SemanticRetriever,
|
||||
TaskTypeRetriever,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
|
||||
def create_mock_episode_model(
|
||||
project_id=None,
|
||||
outcome=EpisodeOutcome.SUCCESS,
|
||||
task_type="test_task",
|
||||
importance_score=0.5,
|
||||
occurred_at=None,
|
||||
):
|
||||
"""Create a mock episode model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id or uuid4()
|
||||
mock.agent_instance_id = None
|
||||
mock.agent_type_id = None
|
||||
mock.session_id = "test-session"
|
||||
mock.task_type = task_type
|
||||
mock.task_description = "Test description"
|
||||
mock.actions = []
|
||||
mock.context_summary = "Test context"
|
||||
mock.outcome = outcome
|
||||
mock.outcome_details = ""
|
||||
mock.duration_seconds = 30.0
|
||||
mock.tokens_used = 100
|
||||
mock.lessons_learned = []
|
||||
mock.importance_score = importance_score
|
||||
mock.embedding = None
|
||||
mock.occurred_at = occurred_at or datetime.now(UTC)
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestRetrievalStrategy:
|
||||
"""Tests for RetrievalStrategy enum."""
|
||||
|
||||
def test_strategy_values(self) -> None:
|
||||
"""Test that strategy enum has expected values."""
|
||||
assert RetrievalStrategy.SEMANTIC == "semantic"
|
||||
assert RetrievalStrategy.RECENCY == "recency"
|
||||
assert RetrievalStrategy.OUTCOME == "outcome"
|
||||
assert RetrievalStrategy.IMPORTANCE == "importance"
|
||||
assert RetrievalStrategy.HYBRID == "hybrid"
|
||||
|
||||
|
||||
class TestRecencyRetriever:
|
||||
"""Tests for RecencyRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> RecencyRetriever:
|
||||
"""Create a recency retriever."""
|
||||
return RecencyRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_returns_episodes(
|
||||
self,
|
||||
retriever: RecencyRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that retrieve returns episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id),
|
||||
create_mock_episode_model(project_id=project_id),
|
||||
]
|
||||
|
||||
# Mock query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(mock_session, project_id, limit=10)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert result.retrieval_type == "recency"
|
||||
assert result.latency_ms >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_since_filter(
|
||||
self,
|
||||
retriever: RecencyRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieve with since time filter."""
|
||||
project_id = uuid4()
|
||||
since = datetime.now(UTC) - timedelta(hours=1)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, since=since
|
||||
)
|
||||
|
||||
assert result.metadata["since"] == since.isoformat()
|
||||
|
||||
|
||||
class TestOutcomeRetriever:
|
||||
"""Tests for OutcomeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> OutcomeRetriever:
|
||||
"""Create an outcome retriever."""
|
||||
return OutcomeRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_success(
|
||||
self,
|
||||
retriever: OutcomeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving successful episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(
|
||||
project_id=project_id, outcome=EpisodeOutcome.SUCCESS
|
||||
),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "outcome"
|
||||
assert result.metadata["outcome"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_failure(
|
||||
self,
|
||||
retriever: OutcomeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving failed episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(
|
||||
project_id=project_id, outcome=EpisodeOutcome.FAILURE
|
||||
),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, outcome=Outcome.FAILURE
|
||||
)
|
||||
|
||||
assert result.metadata["outcome"] == "failure"
|
||||
|
||||
|
||||
class TestImportanceRetriever:
|
||||
"""Tests for ImportanceRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> ImportanceRetriever:
|
||||
"""Create an importance retriever."""
|
||||
return ImportanceRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_importance(
|
||||
self,
|
||||
retriever: ImportanceRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving by importance score."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id, importance_score=0.9),
|
||||
create_mock_episode_model(project_id=project_id, importance_score=0.8),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, min_importance=0.7
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "importance"
|
||||
assert result.metadata["min_importance"] == 0.7
|
||||
|
||||
|
||||
class TestTaskTypeRetriever:
|
||||
"""Tests for TaskTypeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> TaskTypeRetriever:
|
||||
"""Create a task type retriever."""
|
||||
return TaskTypeRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_task_type(
|
||||
self,
|
||||
retriever: TaskTypeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving by task type."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id, task_type="code_review"),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, task_type="code_review"
|
||||
)
|
||||
|
||||
assert result.metadata["task_type"] == "code_review"
|
||||
|
||||
|
||||
class TestSemanticRetriever:
|
||||
"""Tests for SemanticRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_falls_back_without_query(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that semantic search falls back to recency without query."""
|
||||
retriever = SemanticRetriever()
|
||||
project_id = uuid4()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(mock_session, project_id, limit=10)
|
||||
|
||||
# Should fall back to recency
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_embedding_generator(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test semantic retrieval with embedding generator."""
|
||||
mock_embedding_gen = AsyncMock()
|
||||
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
|
||||
|
||||
retriever = SemanticRetriever(embedding_generator=mock_embedding_gen)
|
||||
project_id = uuid4()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, query_text="test query"
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
|
||||
class TestEpisodeRetriever:
|
||||
"""Tests for unified EpisodeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
|
||||
"""Create an episode retriever."""
|
||||
return EpisodeRetriever(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_recency_strategy(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test retrieve with recency strategy."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.retrieve(
|
||||
project_id, RetrievalStrategy.RECENCY, limit=10
|
||||
)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_outcome_strategy(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test retrieve with outcome strategy."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.retrieve(
|
||||
project_id, RetrievalStrategy.OUTCOME, limit=10
|
||||
)
|
||||
assert result.retrieval_type == "outcome"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recent_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_recent convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_recent(project_id, limit=5)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_outcome_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_by_outcome convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_by_outcome(project_id, Outcome.SUCCESS, limit=5)
|
||||
assert result.retrieval_type == "outcome"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_important_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_important convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_important(project_id, limit=5, min_importance=0.8)
|
||||
assert result.retrieval_type == "importance"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test search_similar convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.search_similar(project_id, "test query", limit=5)
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_strategy_raises_error(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test that unknown strategy raises ValueError."""
|
||||
project_id = uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown retrieval strategy"):
|
||||
await retriever.retrieve(project_id, "invalid_strategy", limit=10) # type: ignore
|
||||
Reference in New Issue
Block a user