feat(memory): implement MCP tools for agent memory operations (#96)
Add MCP-compatible tools that expose memory operations to agents: Tools implemented: - remember: Store data in working, episodic, semantic, or procedural memory - recall: Retrieve memories by query across multiple memory types - forget: Delete specific keys or bulk delete by pattern - reflect: Analyze patterns in recent episodes (success/failure factors) - get_memory_stats: Return usage statistics and breakdowns - search_procedures: Find procedures matching trigger patterns - record_outcome: Record task outcomes and update procedure success rates Key components: - tools.py: Pydantic schemas for tool argument validation with comprehensive field constraints (importance 0-1, TTL limits, limit ranges) - service.py: MemoryToolService coordinating memory type operations with proper scoping via ToolContext (project_id, agent_instance_id, session_id) - Lazy initialization of memory services (WorkingMemory, EpisodicMemory, SemanticMemory, ProceduralMemory) Test coverage: - 60 tests covering tool definitions, argument validation, and service execution paths - Mock-based tests for all memory type interactions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
651
backend/tests/unit/services/memory/mcp/test_service.py
Normal file
651
backend/tests/unit/services/memory/mcp/test_service.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# tests/unit/services/memory/mcp/test_service.py
|
||||
"""Tests for MemoryToolService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.mcp.service import (
|
||||
MemoryToolService,
|
||||
ToolContext,
|
||||
ToolResult,
|
||||
get_memory_tool_service,
|
||||
)
|
||||
from app.services.memory.mcp.tools import (
|
||||
AnalysisType,
|
||||
MemoryType,
|
||||
OutcomeType,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
def make_context(
|
||||
project_id: UUID | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> ToolContext:
|
||||
"""Create a test context."""
|
||||
return ToolContext(
|
||||
project_id=project_id or uuid4(),
|
||||
agent_instance_id=agent_instance_id or uuid4(),
|
||||
session_id=session_id or "test-session",
|
||||
)
|
||||
|
||||
|
||||
def make_mock_session() -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.flush = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
"""Tests for ToolContext dataclass."""
|
||||
|
||||
def test_context_creation(self) -> None:
|
||||
"""Context should be creatable with required fields."""
|
||||
project_id = uuid4()
|
||||
ctx = ToolContext(project_id=project_id)
|
||||
assert ctx.project_id == project_id
|
||||
assert ctx.agent_instance_id is None
|
||||
assert ctx.session_id is None
|
||||
|
||||
def test_context_with_all_fields(self) -> None:
|
||||
"""Context should accept all optional fields."""
|
||||
project_id = uuid4()
|
||||
agent_id = uuid4()
|
||||
ctx = ToolContext(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_id,
|
||||
agent_type_id=uuid4(),
|
||||
session_id="session-123",
|
||||
user_id=uuid4(),
|
||||
)
|
||||
assert ctx.project_id == project_id
|
||||
assert ctx.agent_instance_id == agent_id
|
||||
assert ctx.session_id == "session-123"
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""Tests for ToolResult dataclass."""
|
||||
|
||||
def test_success_result(self) -> None:
|
||||
"""Success result should have correct fields."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"key": "value"},
|
||||
execution_time_ms=10.5,
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.data == {"key": "value"}
|
||||
assert result.error is None
|
||||
|
||||
def test_error_result(self) -> None:
|
||||
"""Error result should have correct fields."""
|
||||
result = ToolResult(
|
||||
success=False,
|
||||
error="Something went wrong",
|
||||
error_code="VALIDATION_ERROR",
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error == "Something went wrong"
|
||||
assert result.error_code == "VALIDATION_ERROR"
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Result should convert to dict correctly."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"test": 1},
|
||||
execution_time_ms=5.0,
|
||||
)
|
||||
result_dict = result.to_dict()
|
||||
assert result_dict["success"] is True
|
||||
assert result_dict["data"] == {"test": 1}
|
||||
assert result_dict["execution_time_ms"] == 5.0
|
||||
|
||||
|
||||
class TestMemoryToolService:
|
||||
"""Tests for MemoryToolService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock session."""
|
||||
return make_mock_session()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session: AsyncMock) -> MemoryToolService:
|
||||
"""Create a service with mock session."""
|
||||
return MemoryToolService(session=mock_session)
|
||||
|
||||
@pytest.fixture
|
||||
def context(self) -> ToolContext:
|
||||
"""Create a test context."""
|
||||
return make_context()
|
||||
|
||||
async def test_execute_unknown_tool(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Unknown tool should return error."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="unknown_tool",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error_code == "UNKNOWN_TOOL"
|
||||
|
||||
async def test_execute_with_invalid_args(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Invalid arguments should return validation error."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={"memory_type": "invalid_type"},
|
||||
context=context,
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error_code == "VALIDATION_ERROR"
|
||||
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_remember_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store in working memory."""
|
||||
# Setup mock
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
"key": "test_key",
|
||||
"ttl_seconds": 3600,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["stored"] is True
|
||||
assert result.data["memory_type"] == "working"
|
||||
assert result.data["key"] == "test_key"
|
||||
|
||||
async def test_remember_episodic_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store in episodic memory."""
|
||||
with patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls:
|
||||
# Setup mock
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "episodic",
|
||||
"content": "Important event happened",
|
||||
"importance": 0.8,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["stored"] is True
|
||||
assert result.data["memory_type"] == "episodic"
|
||||
assert "episode_id" in result.data
|
||||
|
||||
async def test_remember_working_without_key(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Working memory without key should fail."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "key is required" in result.error.lower()
|
||||
|
||||
async def test_remember_working_without_session(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
) -> None:
|
||||
"""Working memory without session should fail."""
|
||||
context = ToolContext(project_id=uuid4(), session_id=None)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
"key": "test_key",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "session id is required" in result.error.lower()
|
||||
|
||||
async def test_remember_semantic_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store facts in semantic memory."""
|
||||
with patch("app.services.memory.mcp.service.SemanticMemory") as mock_semantic_cls:
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.store_fact = AsyncMock(return_value=mock_fact)
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "semantic",
|
||||
"content": "User prefers dark mode",
|
||||
"subject": "User",
|
||||
"predicate": "prefers",
|
||||
"object_value": "dark mode",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["memory_type"] == "semantic"
|
||||
assert "fact_id" in result.data
|
||||
assert "triple" in result.data
|
||||
|
||||
async def test_remember_semantic_without_fields(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Semantic memory without subject/predicate/object should fail."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "semantic",
|
||||
"content": "Some content",
|
||||
"subject": "User",
|
||||
# Missing predicate and object_value
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "required" in result.error.lower()
|
||||
|
||||
async def test_remember_procedural_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store procedures in procedural memory."""
|
||||
with patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls:
|
||||
mock_procedure = MagicMock()
|
||||
mock_procedure.id = uuid4()
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.record_procedure = AsyncMock(return_value=mock_procedure)
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "procedural",
|
||||
"content": "File creation procedure",
|
||||
"trigger": "When creating a new file",
|
||||
"steps": [
|
||||
{"action": "check_exists"},
|
||||
{"action": "create"},
|
||||
],
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["memory_type"] == "procedural"
|
||||
assert "procedure_id" in result.data
|
||||
assert result.data["steps_count"] == 2
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
@patch("app.services.memory.mcp.service.SemanticMemory")
|
||||
async def test_recall_from_multiple_types(
|
||||
self,
|
||||
mock_semantic_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Recall should search across multiple memory types."""
|
||||
# Mock episodic
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
mock_episode.task_description = "Test episode"
|
||||
mock_episode.outcome = Outcome.SUCCESS
|
||||
mock_episode.occurred_at = datetime.now(UTC)
|
||||
mock_episode.importance_score = 0.9
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
# Mock semantic
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
mock_fact.subject = "User"
|
||||
mock_fact.predicate = "prefers"
|
||||
mock_fact.object = "dark mode"
|
||||
mock_fact.confidence = 0.8
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="recall",
|
||||
arguments={
|
||||
"query": "user preferences",
|
||||
"memory_types": ["episodic", "semantic"],
|
||||
"limit": 10,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["total_results"] == 2
|
||||
assert len(result.data["results"]) == 2
|
||||
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_forget_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Forget should delete from working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.delete = AsyncMock(return_value=True)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="forget",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"key": "temp_key",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["deleted"] is True
|
||||
assert result.data["deleted_count"] == 1
|
||||
|
||||
async def test_forget_pattern_requires_confirm(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Pattern deletion should require confirmation."""
|
||||
with patch("app.services.memory.mcp.service.WorkingMemory") as mock_working_cls:
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["cache_1", "cache_2"])
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="forget",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"pattern": "cache_*",
|
||||
"confirm_bulk": False,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "confirm_bulk" in result.error.lower()
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
async def test_reflect_recent_patterns(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Reflect should analyze recent patterns."""
|
||||
# Create mock episodes
|
||||
mock_episodes = []
|
||||
for i in range(5):
|
||||
ep = MagicMock()
|
||||
ep.id = uuid4()
|
||||
ep.task_type = "code_review" if i % 2 == 0 else "deployment"
|
||||
ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE
|
||||
ep.task_description = f"Episode {i}"
|
||||
ep.lessons_learned = None
|
||||
ep.occurred_at = datetime.now(UTC)
|
||||
mock_episodes.append(ep)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="reflect",
|
||||
arguments={
|
||||
"analysis_type": "recent_patterns",
|
||||
"depth": 3,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["analysis_type"] == "recent_patterns"
|
||||
assert result.data["total_episodes"] == 5
|
||||
assert "top_task_types" in result.data
|
||||
assert "outcome_distribution" in result.data
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
async def test_reflect_success_factors(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Reflect should analyze success factors."""
|
||||
mock_episodes = []
|
||||
for i in range(10):
|
||||
ep = MagicMock()
|
||||
ep.id = uuid4()
|
||||
ep.task_type = "code_review"
|
||||
ep.outcome = Outcome.SUCCESS if i < 8 else Outcome.FAILURE
|
||||
ep.task_description = f"Episode {i}"
|
||||
ep.lessons_learned = "Learned something" if i < 3 else None
|
||||
ep.occurred_at = datetime.now(UTC)
|
||||
mock_episodes.append(ep)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="reflect",
|
||||
arguments={
|
||||
"analysis_type": "success_factors",
|
||||
"include_examples": True,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["analysis_type"] == "success_factors"
|
||||
assert result.data["overall_success_rate"] == 0.8
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
@patch("app.services.memory.mcp.service.SemanticMemory")
|
||||
@patch("app.services.memory.mcp.service.ProceduralMemory")
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_get_memory_stats(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
mock_procedural_cls: MagicMock,
|
||||
mock_semantic_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Get memory stats should return statistics."""
|
||||
# Setup mocks
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=[MagicMock() for _ in range(10)])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[MagicMock() for _ in range(5)])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[MagicMock() for _ in range(3)])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="get_memory_stats",
|
||||
arguments={
|
||||
"include_breakdown": True,
|
||||
"include_recent_activity": False,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert "breakdown" in result.data
|
||||
breakdown = result.data["breakdown"]
|
||||
assert breakdown["working"] == 2
|
||||
assert breakdown["episodic"] == 10
|
||||
assert breakdown["semantic"] == 5
|
||||
assert breakdown["procedural"] == 3
|
||||
assert breakdown["total"] == 20
|
||||
|
||||
@patch("app.services.memory.mcp.service.ProceduralMemory")
|
||||
async def test_search_procedures(
|
||||
self,
|
||||
mock_procedural_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Search procedures should return matching procedures."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.id = uuid4()
|
||||
mock_proc.name = "Deployment procedure"
|
||||
mock_proc.description = "How to deploy"
|
||||
mock_proc.trigger = "When deploying"
|
||||
mock_proc.success_rate = 0.9
|
||||
mock_proc.execution_count = 10
|
||||
mock_proc.steps = [{"action": "deploy"}]
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="search_procedures",
|
||||
arguments={
|
||||
"trigger": "Deploying to production",
|
||||
"min_success_rate": 0.8,
|
||||
"include_steps": True,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["procedures_found"] == 1
|
||||
proc = result.data["procedures"][0]
|
||||
assert proc["name"] == "Deployment procedure"
|
||||
assert "steps" in proc
|
||||
|
||||
async def test_record_outcome(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Record outcome should store outcome and update procedure."""
|
||||
with (
|
||||
patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls,
|
||||
patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls,
|
||||
):
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.record_outcome = AsyncMock()
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="record_outcome",
|
||||
arguments={
|
||||
"task_type": "code_review",
|
||||
"outcome": "success",
|
||||
"lessons_learned": "Breaking changes caught early",
|
||||
"duration_seconds": 120.5,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["recorded"] is True
|
||||
assert result.data["outcome"] == "success"
|
||||
assert "episode_id" in result.data
|
||||
|
||||
|
||||
class TestGetMemoryToolService:
|
||||
"""Tests for get_memory_tool_service factory."""
|
||||
|
||||
async def test_creates_service(self) -> None:
|
||||
"""Factory should create a service."""
|
||||
mock_session = make_mock_session()
|
||||
service = await get_memory_tool_service(mock_session)
|
||||
assert isinstance(service, MemoryToolService)
|
||||
|
||||
async def test_accepts_embedding_generator(self) -> None:
|
||||
"""Factory should accept embedding generator."""
|
||||
mock_session = make_mock_session()
|
||||
mock_generator = MagicMock()
|
||||
service = await get_memory_tool_service(mock_session, mock_generator)
|
||||
assert service._embedding_generator is mock_generator
|
||||
Reference in New Issue
Block a user