feat(memory): add episodic memory implementation (Issue #90)

Implements the episodic memory service for storing and retrieving
agent task execution experiences. This enables learning from past
successes and failures.

Components:
- EpisodicMemory: Main service class combining recording and retrieval
- EpisodeRecorder: Handles episode creation, importance scoring
- EpisodeRetriever: Multiple retrieval strategies (recency, semantic,
  outcome, importance, task type)

Key features:
- Records task completions with context, actions, outcomes
- Calculates importance scores based on outcome, duration, lessons
- Semantic search with fallback to recency when embeddings unavailable
- Full CRUD operations with statistics and summarization
- Comprehensive unit tests (50 tests, all passing)

Closes #90

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 02:08:16 +01:00
parent bd988f76b0
commit 3554efe66a
8 changed files with 2472 additions and 4 deletions

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/episodic/__init__.py
"""Unit tests for episodic memory service."""

View File

@@ -0,0 +1,359 @@
# tests/unit/services/memory/episodic/test_memory.py
"""Unit tests for EpisodicMemory class."""
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.episodic.retrieval import RetrievalStrategy
from app.services.memory.types import EpisodeCreate, Outcome, RetrievalResult
class TestEpisodicMemoryInit:
"""Tests for EpisodicMemory initialization."""
def test_init_creates_recorder_and_retriever(self) -> None:
"""Test that init creates recorder and retriever."""
mock_session = AsyncMock()
memory = EpisodicMemory(session=mock_session)
assert memory._recorder is not None
assert memory._retriever is not None
assert memory._session is mock_session
def test_init_with_embedding_generator(self) -> None:
"""Test init with embedding generator."""
mock_session = AsyncMock()
mock_embedding_gen = AsyncMock()
memory = EpisodicMemory(
session=mock_session, embedding_generator=mock_embedding_gen
)
assert memory._embedding_generator is mock_embedding_gen
@pytest.mark.asyncio
async def test_create_factory_method(self) -> None:
"""Test create factory method."""
mock_session = AsyncMock()
memory = await EpisodicMemory.create(session=mock_session)
assert memory is not None
assert memory._session is mock_session
class TestEpisodicMemoryRecording:
"""Tests for episode recording methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_record_episode(
self,
memory: EpisodicMemory,
) -> None:
"""Test recording an episode."""
episode_data = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test_task",
task_description="Test description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=Outcome.SUCCESS,
outcome_details="Success",
duration_seconds=30.0,
tokens_used=100,
)
result = await memory.record_episode(episode_data)
assert result.project_id == episode_data.project_id
assert result.task_type == "test_task"
assert result.outcome == Outcome.SUCCESS
@pytest.mark.asyncio
async def test_record_success(
self,
memory: EpisodicMemory,
) -> None:
"""Test convenience method for recording success."""
project_id = uuid4()
result = await memory.record_success(
project_id=project_id,
session_id="test-session",
task_type="deployment",
task_description="Deploy to production",
actions=[{"step": "deploy"}],
context_summary="Deploying v1.0",
outcome_details="Deployed successfully",
duration_seconds=60.0,
tokens_used=200,
)
assert result.outcome == Outcome.SUCCESS
assert result.task_type == "deployment"
@pytest.mark.asyncio
async def test_record_failure(
self,
memory: EpisodicMemory,
) -> None:
"""Test convenience method for recording failure."""
project_id = uuid4()
result = await memory.record_failure(
project_id=project_id,
session_id="test-session",
task_type="deployment",
task_description="Deploy to production",
actions=[{"step": "deploy"}],
context_summary="Deploying v1.0",
error_details="Connection timeout",
duration_seconds=30.0,
tokens_used=100,
)
assert result.outcome == Outcome.FAILURE
assert result.outcome_details == "Connection timeout"
class TestEpisodicMemoryRetrieval:
"""Tests for episode retrieval methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_search_similar(
self,
memory: EpisodicMemory,
) -> None:
"""Test semantic search."""
project_id = uuid4()
results = await memory.search_similar(project_id, "authentication bug")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_recent(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting recent episodes."""
project_id = uuid4()
results = await memory.get_recent(project_id, limit=5)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_outcome(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting episodes by outcome."""
project_id = uuid4()
results = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_task_type(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting episodes by task type."""
project_id = uuid4()
results = await memory.get_by_task_type(project_id, "code_review", limit=5)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_important(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting important episodes."""
project_id = uuid4()
results = await memory.get_important(project_id, limit=5, min_importance=0.8)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_retrieve_with_full_result(
self,
memory: EpisodicMemory,
) -> None:
"""Test retrieve with full result metadata."""
project_id = uuid4()
result = await memory.retrieve(project_id, RetrievalStrategy.RECENCY, limit=10)
assert isinstance(result, RetrievalResult)
assert result.retrieval_type == "recency"
class TestEpisodicMemorySummarization:
"""Tests for episode summarization."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_summarize_empty_list(
self,
memory: EpisodicMemory,
) -> None:
"""Test summarizing empty episode list."""
summary = await memory.summarize_episodes([])
assert "No episodes to summarize" in summary
@pytest.mark.asyncio
async def test_summarize_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test summarizing when episodes not found."""
# Mock get_by_id to return None
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
summary = await memory.summarize_episodes([uuid4(), uuid4()])
assert "No episodes found" in summary
class TestEpisodicMemoryStats:
"""Tests for episode statistics."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_stats(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting episode statistics."""
# Mock empty result
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert "total_count" in stats
assert "success_count" in stats
assert "failure_count" in stats
@pytest.mark.asyncio
async def test_count(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test counting episodes."""
# Mock result with 3 episodes
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [1, 2, 3]
mock_session.execute.return_value = mock_result
count = await memory.count(uuid4())
assert count == 3
class TestEpisodicMemoryModification:
"""Tests for episode modification methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_by_id_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_by_id returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(uuid4())
assert result is None
@pytest.mark.asyncio
async def test_update_importance_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test update_importance returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.update_importance(uuid4(), 0.9)
assert result is None
@pytest.mark.asyncio
async def test_delete_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test delete returns False when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.delete(uuid4())
assert result is False

View File

@@ -0,0 +1,348 @@
# tests/unit/services/memory/episodic/test_recorder.py
"""Unit tests for EpisodeRecorder."""
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.models.memory.enums import EpisodeOutcome
from app.services.memory.episodic.recorder import EpisodeRecorder, _outcome_to_db
from app.services.memory.types import EpisodeCreate, Outcome
class TestOutcomeConversion:
"""Tests for outcome conversion functions."""
def test_outcome_to_db_success(self) -> None:
"""Test converting success outcome."""
result = _outcome_to_db(Outcome.SUCCESS)
assert result == EpisodeOutcome.SUCCESS
def test_outcome_to_db_failure(self) -> None:
"""Test converting failure outcome."""
result = _outcome_to_db(Outcome.FAILURE)
assert result == EpisodeOutcome.FAILURE
def test_outcome_to_db_partial(self) -> None:
"""Test converting partial outcome."""
result = _outcome_to_db(Outcome.PARTIAL)
assert result == EpisodeOutcome.PARTIAL
class TestEpisodeRecorderImportanceCalculation:
"""Tests for importance score calculation."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
"""Create a recorder with mocked session."""
return EpisodeRecorder(session=mock_session)
def test_calculate_importance_success_default(
self, recorder: EpisodeRecorder
) -> None:
"""Test importance for successful episode (default)."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test",
task_description="Test task",
actions=[],
context_summary="Context",
outcome=Outcome.SUCCESS,
outcome_details="",
duration_seconds=10.0,
tokens_used=100,
)
score = recorder._calculate_importance(episode)
assert 0.0 <= score <= 1.0
assert score == 0.5 # Base score for success
def test_calculate_importance_failure_higher(
self, recorder: EpisodeRecorder
) -> None:
"""Test that failures get higher importance."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test",
task_description="Test task",
actions=[],
context_summary="Context",
outcome=Outcome.FAILURE,
outcome_details="Error occurred",
duration_seconds=10.0,
tokens_used=100,
)
score = recorder._calculate_importance(episode)
assert score >= 0.7 # Failure adds 0.2 to base 0.5
def test_calculate_importance_with_lessons(self, recorder: EpisodeRecorder) -> None:
"""Test that lessons increase importance."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test",
task_description="Test task",
actions=[],
context_summary="Context",
outcome=Outcome.SUCCESS,
outcome_details="",
duration_seconds=10.0,
tokens_used=100,
lessons_learned=["Lesson 1", "Lesson 2"],
)
score = recorder._calculate_importance(episode)
assert score > 0.5 # Lessons add to importance
def test_calculate_importance_long_duration(
self, recorder: EpisodeRecorder
) -> None:
"""Test that longer tasks get higher importance."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test",
task_description="Test task",
actions=[],
context_summary="Context",
outcome=Outcome.SUCCESS,
outcome_details="",
duration_seconds=400.0, # > 300 seconds
tokens_used=100,
)
score = recorder._calculate_importance(episode)
assert score > 0.5 # Long duration adds to importance
def test_calculate_importance_clamped_to_max(
self, recorder: EpisodeRecorder
) -> None:
"""Test that importance is clamped to 1.0 max."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test",
task_description="Test task",
actions=[],
context_summary="Context",
outcome=Outcome.FAILURE, # +0.2
outcome_details="Error",
duration_seconds=400.0, # +0.1
tokens_used=2000, # +0.05
lessons_learned=["L1", "L2", "L3", "L4", "L5"], # +0.15
)
score = recorder._calculate_importance(episode)
assert score <= 1.0
class TestEpisodeRecorderEmbeddingText:
"""Tests for embedding text generation."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.fixture
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
"""Create a recorder with mocked session."""
return EpisodeRecorder(session=mock_session)
def test_create_embedding_text_basic(self, recorder: EpisodeRecorder) -> None:
"""Test basic embedding text creation."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="code_review",
task_description="Review PR #123",
actions=[],
context_summary="Reviewing authentication changes",
outcome=Outcome.SUCCESS,
outcome_details="",
duration_seconds=60.0,
tokens_used=500,
)
text = recorder._create_embedding_text(episode)
assert "code_review" in text
assert "Review PR #123" in text
assert "authentication" in text
assert "success" in text
def test_create_embedding_text_with_details(
self, recorder: EpisodeRecorder
) -> None:
"""Test embedding text includes outcome details."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="deployment",
task_description="Deploy to production",
actions=[],
context_summary="Deploying v1.0.0",
outcome=Outcome.FAILURE,
outcome_details="Connection timeout to server",
duration_seconds=30.0,
tokens_used=200,
)
text = recorder._create_embedding_text(episode)
assert "Connection timeout" in text
def test_create_embedding_text_with_lessons(
self, recorder: EpisodeRecorder
) -> None:
"""Test embedding text includes lessons learned."""
episode = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="debugging",
task_description="Fix memory leak",
actions=[],
context_summary="Debugging memory issues",
outcome=Outcome.SUCCESS,
outcome_details="",
duration_seconds=120.0,
tokens_used=800,
lessons_learned=["Always close file handles", "Use context managers"],
)
text = recorder._create_embedding_text(episode)
assert "Always close file handles" in text
assert "context managers" in text
class TestEpisodeRecorderRecord:
"""Tests for episode recording."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
"""Create a recorder with mocked session."""
return EpisodeRecorder(session=mock_session)
@pytest.mark.asyncio
async def test_record_creates_episode(
self,
recorder: EpisodeRecorder,
mock_session: AsyncMock,
) -> None:
"""Test that record creates an episode."""
episode_data = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test_task",
task_description="Test description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=Outcome.SUCCESS,
outcome_details="Success",
duration_seconds=30.0,
tokens_used=100,
)
result = await recorder.record(episode_data)
# Verify session methods were called
mock_session.add.assert_called_once()
mock_session.flush.assert_called_once()
mock_session.refresh.assert_called_once()
# Verify result
assert result.project_id == episode_data.project_id
assert result.session_id == episode_data.session_id
assert result.task_type == episode_data.task_type
assert result.outcome == Outcome.SUCCESS
@pytest.mark.asyncio
async def test_record_with_embedding_generator(
self,
mock_session: AsyncMock,
) -> None:
"""Test recording with embedding generator."""
mock_embedding_gen = AsyncMock()
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
recorder = EpisodeRecorder(
session=mock_session, embedding_generator=mock_embedding_gen
)
episode_data = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test_task",
task_description="Test description",
actions=[],
context_summary="Test context",
outcome=Outcome.SUCCESS,
outcome_details="",
duration_seconds=10.0,
tokens_used=50,
)
await recorder.record(episode_data)
# Verify embedding generator was called
mock_embedding_gen.generate.assert_called_once()
class TestEpisodeRecorderStats:
"""Tests for episode statistics."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
"""Create a recorder with mocked session."""
return EpisodeRecorder(session=mock_session)
@pytest.mark.asyncio
async def test_get_stats_empty(
self,
recorder: EpisodeRecorder,
mock_session: AsyncMock,
) -> None:
"""Test stats for project with no episodes."""
# Mock empty result
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
stats = await recorder.get_stats(uuid4())
assert stats["total_count"] == 0
assert stats["success_count"] == 0
assert stats["failure_count"] == 0
assert stats["partial_count"] == 0
assert stats["avg_importance"] == 0.0
@pytest.mark.asyncio
async def test_count_by_project(
self,
recorder: EpisodeRecorder,
mock_session: AsyncMock,
) -> None:
"""Test counting episodes by project."""
# Mock result with 5 episodes
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [1, 2, 3, 4, 5]
mock_session.execute.return_value = mock_result
count = await recorder.count_by_project(uuid4())
assert count == 5

View File

@@ -0,0 +1,400 @@
# tests/unit/services/memory/episodic/test_retrieval.py
"""Unit tests for episode retrieval strategies."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.models.memory.enums import EpisodeOutcome
from app.services.memory.episodic.retrieval import (
EpisodeRetriever,
ImportanceRetriever,
OutcomeRetriever,
RecencyRetriever,
RetrievalStrategy,
SemanticRetriever,
TaskTypeRetriever,
)
from app.services.memory.types import Outcome
def create_mock_episode_model(
project_id=None,
outcome=EpisodeOutcome.SUCCESS,
task_type="test_task",
importance_score=0.5,
occurred_at=None,
):
"""Create a mock episode model for testing."""
mock = MagicMock()
mock.id = uuid4()
mock.project_id = project_id or uuid4()
mock.agent_instance_id = None
mock.agent_type_id = None
mock.session_id = "test-session"
mock.task_type = task_type
mock.task_description = "Test description"
mock.actions = []
mock.context_summary = "Test context"
mock.outcome = outcome
mock.outcome_details = ""
mock.duration_seconds = 30.0
mock.tokens_used = 100
mock.lessons_learned = []
mock.importance_score = importance_score
mock.embedding = None
mock.occurred_at = occurred_at or datetime.now(UTC)
mock.created_at = datetime.now(UTC)
mock.updated_at = datetime.now(UTC)
return mock
class TestRetrievalStrategy:
"""Tests for RetrievalStrategy enum."""
def test_strategy_values(self) -> None:
"""Test that strategy enum has expected values."""
assert RetrievalStrategy.SEMANTIC == "semantic"
assert RetrievalStrategy.RECENCY == "recency"
assert RetrievalStrategy.OUTCOME == "outcome"
assert RetrievalStrategy.IMPORTANCE == "importance"
assert RetrievalStrategy.HYBRID == "hybrid"
class TestRecencyRetriever:
"""Tests for RecencyRetriever."""
@pytest.fixture
def retriever(self) -> RecencyRetriever:
"""Create a recency retriever."""
return RecencyRetriever()
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.mark.asyncio
async def test_retrieve_returns_episodes(
self,
retriever: RecencyRetriever,
mock_session: AsyncMock,
) -> None:
"""Test that retrieve returns episodes."""
project_id = uuid4()
mock_episodes = [
create_mock_episode_model(project_id=project_id),
create_mock_episode_model(project_id=project_id),
]
# Mock query result
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = mock_episodes
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(mock_session, project_id, limit=10)
assert len(result.items) == 2
assert result.retrieval_type == "recency"
assert result.latency_ms >= 0
@pytest.mark.asyncio
async def test_retrieve_with_since_filter(
self,
retriever: RecencyRetriever,
mock_session: AsyncMock,
) -> None:
"""Test retrieve with since time filter."""
project_id = uuid4()
since = datetime.now(UTC) - timedelta(hours=1)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(
mock_session, project_id, limit=10, since=since
)
assert result.metadata["since"] == since.isoformat()
class TestOutcomeRetriever:
"""Tests for OutcomeRetriever."""
@pytest.fixture
def retriever(self) -> OutcomeRetriever:
"""Create an outcome retriever."""
return OutcomeRetriever()
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.mark.asyncio
async def test_retrieve_by_success(
self,
retriever: OutcomeRetriever,
mock_session: AsyncMock,
) -> None:
"""Test retrieving successful episodes."""
project_id = uuid4()
mock_episodes = [
create_mock_episode_model(
project_id=project_id, outcome=EpisodeOutcome.SUCCESS
),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = mock_episodes
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(
mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
)
assert result.retrieval_type == "outcome"
assert result.metadata["outcome"] == "success"
@pytest.mark.asyncio
async def test_retrieve_by_failure(
self,
retriever: OutcomeRetriever,
mock_session: AsyncMock,
) -> None:
"""Test retrieving failed episodes."""
project_id = uuid4()
mock_episodes = [
create_mock_episode_model(
project_id=project_id, outcome=EpisodeOutcome.FAILURE
),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = mock_episodes
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(
mock_session, project_id, limit=10, outcome=Outcome.FAILURE
)
assert result.metadata["outcome"] == "failure"
class TestImportanceRetriever:
"""Tests for ImportanceRetriever."""
@pytest.fixture
def retriever(self) -> ImportanceRetriever:
"""Create an importance retriever."""
return ImportanceRetriever()
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.mark.asyncio
async def test_retrieve_by_importance(
self,
retriever: ImportanceRetriever,
mock_session: AsyncMock,
) -> None:
"""Test retrieving by importance score."""
project_id = uuid4()
mock_episodes = [
create_mock_episode_model(project_id=project_id, importance_score=0.9),
create_mock_episode_model(project_id=project_id, importance_score=0.8),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = mock_episodes
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(
mock_session, project_id, limit=10, min_importance=0.7
)
assert result.retrieval_type == "importance"
assert result.metadata["min_importance"] == 0.7
class TestTaskTypeRetriever:
"""Tests for TaskTypeRetriever."""
@pytest.fixture
def retriever(self) -> TaskTypeRetriever:
"""Create a task type retriever."""
return TaskTypeRetriever()
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.mark.asyncio
async def test_retrieve_by_task_type(
self,
retriever: TaskTypeRetriever,
mock_session: AsyncMock,
) -> None:
"""Test retrieving by task type."""
project_id = uuid4()
mock_episodes = [
create_mock_episode_model(project_id=project_id, task_type="code_review"),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = mock_episodes
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(
mock_session, project_id, limit=10, task_type="code_review"
)
assert result.metadata["task_type"] == "code_review"
class TestSemanticRetriever:
"""Tests for SemanticRetriever."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.mark.asyncio
async def test_retrieve_falls_back_without_query(
self,
mock_session: AsyncMock,
) -> None:
"""Test that semantic search falls back to recency without query."""
retriever = SemanticRetriever()
project_id = uuid4()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(mock_session, project_id, limit=10)
# Should fall back to recency
assert result.retrieval_type == "semantic"
@pytest.mark.asyncio
async def test_retrieve_with_embedding_generator(
self,
mock_session: AsyncMock,
) -> None:
"""Test semantic retrieval with embedding generator."""
mock_embedding_gen = AsyncMock()
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
retriever = SemanticRetriever(embedding_generator=mock_embedding_gen)
project_id = uuid4()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
result = await retriever.retrieve(
mock_session, project_id, limit=10, query_text="test query"
)
assert result.retrieval_type == "semantic"
class TestEpisodeRetriever:
"""Tests for unified EpisodeRetriever."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
"""Create an episode retriever."""
return EpisodeRetriever(session=mock_session)
@pytest.mark.asyncio
async def test_retrieve_with_recency_strategy(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test retrieve with recency strategy."""
project_id = uuid4()
result = await retriever.retrieve(
project_id, RetrievalStrategy.RECENCY, limit=10
)
assert result.retrieval_type == "recency"
@pytest.mark.asyncio
async def test_retrieve_with_outcome_strategy(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test retrieve with outcome strategy."""
project_id = uuid4()
result = await retriever.retrieve(
project_id, RetrievalStrategy.OUTCOME, limit=10
)
assert result.retrieval_type == "outcome"
@pytest.mark.asyncio
async def test_get_recent_convenience_method(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test get_recent convenience method."""
project_id = uuid4()
result = await retriever.get_recent(project_id, limit=5)
assert result.retrieval_type == "recency"
@pytest.mark.asyncio
async def test_get_by_outcome_convenience_method(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test get_by_outcome convenience method."""
project_id = uuid4()
result = await retriever.get_by_outcome(project_id, Outcome.SUCCESS, limit=5)
assert result.retrieval_type == "outcome"
@pytest.mark.asyncio
async def test_get_important_convenience_method(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test get_important convenience method."""
project_id = uuid4()
result = await retriever.get_important(project_id, limit=5, min_importance=0.8)
assert result.retrieval_type == "importance"
@pytest.mark.asyncio
async def test_search_similar_convenience_method(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test search_similar convenience method."""
project_id = uuid4()
result = await retriever.search_similar(project_id, "test query", limit=5)
assert result.retrieval_type == "semantic"
@pytest.mark.asyncio
async def test_unknown_strategy_raises_error(
self,
retriever: EpisodeRetriever,
) -> None:
"""Test that unknown strategy raises ValueError."""
project_id = uuid4()
with pytest.raises(ValueError, match="Unknown retrieval strategy"):
await retriever.retrieve(project_id, "invalid_strategy", limit=10) # type: ignore