forked from cardosofelipe/fast-next-template
feat(memory): add episodic memory implementation (Issue #90)
Implements the episodic memory service for storing and retrieving agent task execution experiences. This enables learning from past successes and failures. Components: - EpisodicMemory: Main service class combining recording and retrieval - EpisodeRecorder: Handles episode creation, importance scoring - EpisodeRetriever: Multiple retrieval strategies (recency, semantic, outcome, importance, task type) Key features: - Records task completions with context, actions, outcomes - Calculates importance scores based on outcome, duration, lessons - Semantic search with fallback to recency when embeddings unavailable - Full CRUD operations with statistics and summarization - Comprehensive unit tests (50 tests, all passing) Closes #90 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2
backend/tests/unit/services/memory/episodic/__init__.py
Normal file
2
backend/tests/unit/services/memory/episodic/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/episodic/__init__.py
|
||||
"""Unit tests for episodic memory service."""
|
||||
359
backend/tests/unit/services/memory/episodic/test_memory.py
Normal file
359
backend/tests/unit/services/memory/episodic/test_memory.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# tests/unit/services/memory/episodic/test_memory.py
|
||||
"""Unit tests for EpisodicMemory class."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.episodic.retrieval import RetrievalStrategy
|
||||
from app.services.memory.types import EpisodeCreate, Outcome, RetrievalResult
|
||||
|
||||
|
||||
class TestEpisodicMemoryInit:
|
||||
"""Tests for EpisodicMemory initialization."""
|
||||
|
||||
def test_init_creates_recorder_and_retriever(self) -> None:
|
||||
"""Test that init creates recorder and retriever."""
|
||||
mock_session = AsyncMock()
|
||||
memory = EpisodicMemory(session=mock_session)
|
||||
|
||||
assert memory._recorder is not None
|
||||
assert memory._retriever is not None
|
||||
assert memory._session is mock_session
|
||||
|
||||
def test_init_with_embedding_generator(self) -> None:
|
||||
"""Test init with embedding generator."""
|
||||
mock_session = AsyncMock()
|
||||
mock_embedding_gen = AsyncMock()
|
||||
memory = EpisodicMemory(
|
||||
session=mock_session, embedding_generator=mock_embedding_gen
|
||||
)
|
||||
|
||||
assert memory._embedding_generator is mock_embedding_gen
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_factory_method(self) -> None:
|
||||
"""Test create factory method."""
|
||||
mock_session = AsyncMock()
|
||||
memory = await EpisodicMemory.create(session=mock_session)
|
||||
|
||||
assert memory is not None
|
||||
assert memory._session is mock_session
|
||||
|
||||
|
||||
class TestEpisodicMemoryRecording:
|
||||
"""Tests for episode recording methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_episode(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test recording an episode."""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="Success",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
result = await memory.record_episode(episode_data)
|
||||
|
||||
assert result.project_id == episode_data.project_id
|
||||
assert result.task_type == "test_task"
|
||||
assert result.outcome == Outcome.SUCCESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_success(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test convenience method for recording success."""
|
||||
project_id = uuid4()
|
||||
result = await memory.record_success(
|
||||
project_id=project_id,
|
||||
session_id="test-session",
|
||||
task_type="deployment",
|
||||
task_description="Deploy to production",
|
||||
actions=[{"step": "deploy"}],
|
||||
context_summary="Deploying v1.0",
|
||||
outcome_details="Deployed successfully",
|
||||
duration_seconds=60.0,
|
||||
tokens_used=200,
|
||||
)
|
||||
|
||||
assert result.outcome == Outcome.SUCCESS
|
||||
assert result.task_type == "deployment"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_failure(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test convenience method for recording failure."""
|
||||
project_id = uuid4()
|
||||
result = await memory.record_failure(
|
||||
project_id=project_id,
|
||||
session_id="test-session",
|
||||
task_type="deployment",
|
||||
task_description="Deploy to production",
|
||||
actions=[{"step": "deploy"}],
|
||||
context_summary="Deploying v1.0",
|
||||
error_details="Connection timeout",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
assert result.outcome == Outcome.FAILURE
|
||||
assert result.outcome_details == "Connection timeout"
|
||||
|
||||
|
||||
class TestEpisodicMemoryRetrieval:
|
||||
"""Tests for episode retrieval methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test semantic search."""
|
||||
project_id = uuid4()
|
||||
results = await memory.search_similar(project_id, "authentication bug")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recent(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting recent episodes."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_recent(project_id, limit=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_outcome(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting episodes by outcome."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_task_type(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting episodes by task type."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_by_task_type(project_id, "code_review", limit=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_important(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting important episodes."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_important(project_id, limit=5, min_importance=0.8)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_full_result(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test retrieve with full result metadata."""
|
||||
project_id = uuid4()
|
||||
result = await memory.retrieve(project_id, RetrievalStrategy.RECENCY, limit=10)
|
||||
|
||||
assert isinstance(result, RetrievalResult)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
|
||||
class TestEpisodicMemorySummarization:
|
||||
"""Tests for episode summarization."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_empty_list(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test summarizing empty episode list."""
|
||||
summary = await memory.summarize_episodes([])
|
||||
assert "No episodes to summarize" in summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test summarizing when episodes not found."""
|
||||
# Mock get_by_id to return None
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
summary = await memory.summarize_episodes([uuid4(), uuid4()])
|
||||
assert "No episodes found" in summary
|
||||
|
||||
|
||||
class TestEpisodicMemoryStats:
|
||||
"""Tests for episode statistics."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test getting episode statistics."""
|
||||
# Mock empty result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
stats = await memory.get_stats(uuid4())
|
||||
|
||||
assert "total_count" in stats
|
||||
assert "success_count" in stats
|
||||
assert "failure_count" in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test counting episodes."""
|
||||
# Mock result with 3 episodes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [1, 2, 3]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
count = await memory.count(uuid4())
|
||||
assert count == 3
|
||||
|
||||
|
||||
class TestEpisodicMemoryModification:
|
||||
"""Tests for episode modification methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test get_by_id returns None when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.get_by_id(uuid4())
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_importance_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test update_importance returns None when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.update_importance(uuid4(), 0.9)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test delete returns False when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.delete(uuid4())
|
||||
assert result is False
|
||||
348
backend/tests/unit/services/memory/episodic/test_recorder.py
Normal file
348
backend/tests/unit/services/memory/episodic/test_recorder.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# tests/unit/services/memory/episodic/test_recorder.py
|
||||
"""Unit tests for EpisodeRecorder."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.services.memory.episodic.recorder import EpisodeRecorder, _outcome_to_db
|
||||
from app.services.memory.types import EpisodeCreate, Outcome
|
||||
|
||||
|
||||
class TestOutcomeConversion:
|
||||
"""Tests for outcome conversion functions."""
|
||||
|
||||
def test_outcome_to_db_success(self) -> None:
|
||||
"""Test converting success outcome."""
|
||||
result = _outcome_to_db(Outcome.SUCCESS)
|
||||
assert result == EpisodeOutcome.SUCCESS
|
||||
|
||||
def test_outcome_to_db_failure(self) -> None:
|
||||
"""Test converting failure outcome."""
|
||||
result = _outcome_to_db(Outcome.FAILURE)
|
||||
assert result == EpisodeOutcome.FAILURE
|
||||
|
||||
def test_outcome_to_db_partial(self) -> None:
|
||||
"""Test converting partial outcome."""
|
||||
result = _outcome_to_db(Outcome.PARTIAL)
|
||||
assert result == EpisodeOutcome.PARTIAL
|
||||
|
||||
|
||||
class TestEpisodeRecorderImportanceCalculation:
|
||||
"""Tests for importance score calculation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
def test_calculate_importance_success_default(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test importance for successful episode (default)."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert 0.0 <= score <= 1.0
|
||||
assert score == 0.5 # Base score for success
|
||||
|
||||
def test_calculate_importance_failure_higher(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test that failures get higher importance."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details="Error occurred",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score >= 0.7 # Failure adds 0.2 to base 0.5
|
||||
|
||||
def test_calculate_importance_with_lessons(self, recorder: EpisodeRecorder) -> None:
|
||||
"""Test that lessons increase importance."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=["Lesson 1", "Lesson 2"],
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score > 0.5 # Lessons add to importance
|
||||
|
||||
def test_calculate_importance_long_duration(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test that longer tasks get higher importance."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=400.0, # > 300 seconds
|
||||
tokens_used=100,
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score > 0.5 # Long duration adds to importance
|
||||
|
||||
def test_calculate_importance_clamped_to_max(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test that importance is clamped to 1.0 max."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.FAILURE, # +0.2
|
||||
outcome_details="Error",
|
||||
duration_seconds=400.0, # +0.1
|
||||
tokens_used=2000, # +0.05
|
||||
lessons_learned=["L1", "L2", "L3", "L4", "L5"], # +0.15
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score <= 1.0
|
||||
|
||||
|
||||
class TestEpisodeRecorderEmbeddingText:
|
||||
"""Tests for embedding text generation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
def test_create_embedding_text_basic(self, recorder: EpisodeRecorder) -> None:
|
||||
"""Test basic embedding text creation."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="code_review",
|
||||
task_description="Review PR #123",
|
||||
actions=[],
|
||||
context_summary="Reviewing authentication changes",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=60.0,
|
||||
tokens_used=500,
|
||||
)
|
||||
text = recorder._create_embedding_text(episode)
|
||||
assert "code_review" in text
|
||||
assert "Review PR #123" in text
|
||||
assert "authentication" in text
|
||||
assert "success" in text
|
||||
|
||||
def test_create_embedding_text_with_details(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test embedding text includes outcome details."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="deployment",
|
||||
task_description="Deploy to production",
|
||||
actions=[],
|
||||
context_summary="Deploying v1.0.0",
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details="Connection timeout to server",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=200,
|
||||
)
|
||||
text = recorder._create_embedding_text(episode)
|
||||
assert "Connection timeout" in text
|
||||
|
||||
def test_create_embedding_text_with_lessons(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test embedding text includes lessons learned."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="debugging",
|
||||
task_description="Fix memory leak",
|
||||
actions=[],
|
||||
context_summary="Debugging memory issues",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=120.0,
|
||||
tokens_used=800,
|
||||
lessons_learned=["Always close file handles", "Use context managers"],
|
||||
)
|
||||
text = recorder._create_embedding_text(episode)
|
||||
assert "Always close file handles" in text
|
||||
assert "context managers" in text
|
||||
|
||||
|
||||
class TestEpisodeRecorderRecord:
|
||||
"""Tests for episode recording."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_creates_episode(
|
||||
self,
|
||||
recorder: EpisodeRecorder,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that record creates an episode."""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="Success",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
result = await recorder.record(episode_data)
|
||||
|
||||
# Verify session methods were called
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result.project_id == episode_data.project_id
|
||||
assert result.session_id == episode_data.session_id
|
||||
assert result.task_type == episode_data.task_type
|
||||
assert result.outcome == Outcome.SUCCESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_with_embedding_generator(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test recording with embedding generator."""
|
||||
mock_embedding_gen = AsyncMock()
|
||||
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
|
||||
|
||||
recorder = EpisodeRecorder(
|
||||
session=mock_session, embedding_generator=mock_embedding_gen
|
||||
)
|
||||
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test description",
|
||||
actions=[],
|
||||
context_summary="Test context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=50,
|
||||
)
|
||||
|
||||
await recorder.record(episode_data)
|
||||
|
||||
# Verify embedding generator was called
|
||||
mock_embedding_gen.generate.assert_called_once()
|
||||
|
||||
|
||||
class TestEpisodeRecorderStats:
|
||||
"""Tests for episode statistics."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_empty(
|
||||
self,
|
||||
recorder: EpisodeRecorder,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test stats for project with no episodes."""
|
||||
# Mock empty result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
stats = await recorder.get_stats(uuid4())
|
||||
|
||||
assert stats["total_count"] == 0
|
||||
assert stats["success_count"] == 0
|
||||
assert stats["failure_count"] == 0
|
||||
assert stats["partial_count"] == 0
|
||||
assert stats["avg_importance"] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_by_project(
|
||||
self,
|
||||
recorder: EpisodeRecorder,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test counting episodes by project."""
|
||||
# Mock result with 5 episodes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [1, 2, 3, 4, 5]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
count = await recorder.count_by_project(uuid4())
|
||||
|
||||
assert count == 5
|
||||
400
backend/tests/unit/services/memory/episodic/test_retrieval.py
Normal file
400
backend/tests/unit/services/memory/episodic/test_retrieval.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# tests/unit/services/memory/episodic/test_retrieval.py
|
||||
"""Unit tests for episode retrieval strategies."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.services.memory.episodic.retrieval import (
|
||||
EpisodeRetriever,
|
||||
ImportanceRetriever,
|
||||
OutcomeRetriever,
|
||||
RecencyRetriever,
|
||||
RetrievalStrategy,
|
||||
SemanticRetriever,
|
||||
TaskTypeRetriever,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
|
||||
def create_mock_episode_model(
|
||||
project_id=None,
|
||||
outcome=EpisodeOutcome.SUCCESS,
|
||||
task_type="test_task",
|
||||
importance_score=0.5,
|
||||
occurred_at=None,
|
||||
):
|
||||
"""Create a mock episode model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id or uuid4()
|
||||
mock.agent_instance_id = None
|
||||
mock.agent_type_id = None
|
||||
mock.session_id = "test-session"
|
||||
mock.task_type = task_type
|
||||
mock.task_description = "Test description"
|
||||
mock.actions = []
|
||||
mock.context_summary = "Test context"
|
||||
mock.outcome = outcome
|
||||
mock.outcome_details = ""
|
||||
mock.duration_seconds = 30.0
|
||||
mock.tokens_used = 100
|
||||
mock.lessons_learned = []
|
||||
mock.importance_score = importance_score
|
||||
mock.embedding = None
|
||||
mock.occurred_at = occurred_at or datetime.now(UTC)
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestRetrievalStrategy:
|
||||
"""Tests for RetrievalStrategy enum."""
|
||||
|
||||
def test_strategy_values(self) -> None:
|
||||
"""Test that strategy enum has expected values."""
|
||||
assert RetrievalStrategy.SEMANTIC == "semantic"
|
||||
assert RetrievalStrategy.RECENCY == "recency"
|
||||
assert RetrievalStrategy.OUTCOME == "outcome"
|
||||
assert RetrievalStrategy.IMPORTANCE == "importance"
|
||||
assert RetrievalStrategy.HYBRID == "hybrid"
|
||||
|
||||
|
||||
class TestRecencyRetriever:
|
||||
"""Tests for RecencyRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> RecencyRetriever:
|
||||
"""Create a recency retriever."""
|
||||
return RecencyRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_returns_episodes(
|
||||
self,
|
||||
retriever: RecencyRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that retrieve returns episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id),
|
||||
create_mock_episode_model(project_id=project_id),
|
||||
]
|
||||
|
||||
# Mock query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(mock_session, project_id, limit=10)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert result.retrieval_type == "recency"
|
||||
assert result.latency_ms >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_since_filter(
|
||||
self,
|
||||
retriever: RecencyRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieve with since time filter."""
|
||||
project_id = uuid4()
|
||||
since = datetime.now(UTC) - timedelta(hours=1)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, since=since
|
||||
)
|
||||
|
||||
assert result.metadata["since"] == since.isoformat()
|
||||
|
||||
|
||||
class TestOutcomeRetriever:
|
||||
"""Tests for OutcomeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> OutcomeRetriever:
|
||||
"""Create an outcome retriever."""
|
||||
return OutcomeRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_success(
|
||||
self,
|
||||
retriever: OutcomeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving successful episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(
|
||||
project_id=project_id, outcome=EpisodeOutcome.SUCCESS
|
||||
),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "outcome"
|
||||
assert result.metadata["outcome"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_failure(
|
||||
self,
|
||||
retriever: OutcomeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving failed episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(
|
||||
project_id=project_id, outcome=EpisodeOutcome.FAILURE
|
||||
),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, outcome=Outcome.FAILURE
|
||||
)
|
||||
|
||||
assert result.metadata["outcome"] == "failure"
|
||||
|
||||
|
||||
class TestImportanceRetriever:
|
||||
"""Tests for ImportanceRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> ImportanceRetriever:
|
||||
"""Create an importance retriever."""
|
||||
return ImportanceRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_importance(
|
||||
self,
|
||||
retriever: ImportanceRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving by importance score."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id, importance_score=0.9),
|
||||
create_mock_episode_model(project_id=project_id, importance_score=0.8),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, min_importance=0.7
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "importance"
|
||||
assert result.metadata["min_importance"] == 0.7
|
||||
|
||||
|
||||
class TestTaskTypeRetriever:
|
||||
"""Tests for TaskTypeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> TaskTypeRetriever:
|
||||
"""Create a task type retriever."""
|
||||
return TaskTypeRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_task_type(
|
||||
self,
|
||||
retriever: TaskTypeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving by task type."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id, task_type="code_review"),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, task_type="code_review"
|
||||
)
|
||||
|
||||
assert result.metadata["task_type"] == "code_review"
|
||||
|
||||
|
||||
class TestSemanticRetriever:
|
||||
"""Tests for SemanticRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_falls_back_without_query(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that semantic search falls back to recency without query."""
|
||||
retriever = SemanticRetriever()
|
||||
project_id = uuid4()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(mock_session, project_id, limit=10)
|
||||
|
||||
# Should fall back to recency
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_embedding_generator(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test semantic retrieval with embedding generator."""
|
||||
mock_embedding_gen = AsyncMock()
|
||||
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
|
||||
|
||||
retriever = SemanticRetriever(embedding_generator=mock_embedding_gen)
|
||||
project_id = uuid4()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, query_text="test query"
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
|
||||
class TestEpisodeRetriever:
|
||||
"""Tests for unified EpisodeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
|
||||
"""Create an episode retriever."""
|
||||
return EpisodeRetriever(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_recency_strategy(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test retrieve with recency strategy."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.retrieve(
|
||||
project_id, RetrievalStrategy.RECENCY, limit=10
|
||||
)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_outcome_strategy(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test retrieve with outcome strategy."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.retrieve(
|
||||
project_id, RetrievalStrategy.OUTCOME, limit=10
|
||||
)
|
||||
assert result.retrieval_type == "outcome"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recent_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_recent convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_recent(project_id, limit=5)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_outcome_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_by_outcome convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_by_outcome(project_id, Outcome.SUCCESS, limit=5)
|
||||
assert result.retrieval_type == "outcome"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_important_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_important convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_important(project_id, limit=5, min_importance=0.8)
|
||||
assert result.retrieval_type == "importance"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test search_similar convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.search_similar(project_id, "test query", limit=5)
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_strategy_raises_error(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test that unknown strategy raises ValueError."""
|
||||
project_id = uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown retrieval strategy"):
|
||||
await retriever.retrieve(project_id, "invalid_strategy", limit=10) # type: ignore
|
||||
Reference in New Issue
Block a user