feat(memory): implement MCP tools for agent memory operations (#96)

Add MCP-compatible tools that expose memory operations to agents:

Tools implemented:
- remember: Store data in working, episodic, semantic, or procedural memory
- recall: Retrieve memories by query across multiple memory types
- forget: Delete specific keys or bulk delete by pattern
- reflect: Analyze patterns in recent episodes (success/failure factors)
- get_memory_stats: Return usage statistics and breakdowns
- search_procedures: Find procedures matching trigger patterns
- record_outcome: Record task outcomes and update procedure success rates

Key components:
- tools.py: Pydantic schemas for tool argument validation with comprehensive
  field constraints (importance 0-1, TTL limits, limit ranges)
- service.py: MemoryToolService coordinating memory type operations with
  proper scoping via ToolContext (project_id, agent_instance_id, session_id)
- Lazy initialization of memory services (WorkingMemory, EpisodicMemory,
  SemanticMemory, ProceduralMemory)

Test coverage:
- 60 tests covering tool definitions, argument validation, and service
  execution paths
- Mock-based tests for all memory type interactions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 03:32:10 +01:00
parent 1670e05e0d
commit 0b24d4c6cc
7 changed files with 2648 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/mcp/__init__.py
"""Tests for memory MCP tools."""

View File

@@ -0,0 +1,651 @@
# tests/unit/services/memory/mcp/test_service.py
"""Tests for MemoryToolService."""
from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID, uuid4
import pytest
from app.services.memory.mcp.service import (
MemoryToolService,
ToolContext,
ToolResult,
get_memory_tool_service,
)
from app.services.memory.mcp.tools import (
AnalysisType,
MemoryType,
OutcomeType,
)
from app.services.memory.types import Outcome
# Apply the asyncio marker to every test collected from this module, requesting
# a function-scoped event loop (loop_scope="function").
# NOTE(review): the marker also lands on the synchronous tests in this module;
# confirm the installed pytest-asyncio version does not warn about that.
pytestmark = pytest.mark.asyncio(loop_scope="function")
def make_context(
    project_id: UUID | None = None,
    agent_instance_id: UUID | None = None,
    session_id: str | None = None,
) -> ToolContext:
    """Build a ToolContext for tests, filling in any falsy argument.

    A fresh random UUID replaces a missing project or agent instance id, and
    the literal "test-session" replaces a missing session id.
    """
    resolved_project = project_id or uuid4()
    resolved_agent = agent_instance_id or uuid4()
    resolved_session = session_id or "test-session"
    return ToolContext(
        project_id=resolved_project,
        agent_instance_id=resolved_agent,
        session_id=resolved_session,
    )
def make_mock_session() -> AsyncMock:
    """Return an AsyncMock standing in for a database session.

    The execute, commit and flush attributes are each set to their own
    explicit AsyncMock.
    """
    fake_session = AsyncMock()
    for method_name in ("execute", "commit", "flush"):
        setattr(fake_session, method_name, AsyncMock())
    return fake_session
class TestToolContext:
    """Construction checks for the ToolContext dataclass."""

    def test_context_creation(self) -> None:
        """Only project_id is given; the other checked fields default to None."""
        pid = uuid4()
        context = ToolContext(project_id=pid)

        assert context.project_id == pid
        assert context.agent_instance_id is None
        assert context.session_id is None

    def test_context_with_all_fields(self) -> None:
        """Every optional field can be passed at construction time."""
        pid = uuid4()
        instance_id = uuid4()
        context = ToolContext(
            project_id=pid,
            agent_instance_id=instance_id,
            agent_type_id=uuid4(),
            session_id="session-123",
            user_id=uuid4(),
        )

        assert context.project_id == pid
        assert context.agent_instance_id == instance_id
        assert context.session_id == "session-123"
class TestToolResult:
    """Field and serialisation checks for the ToolResult dataclass."""

    def test_success_result(self) -> None:
        """A successful result keeps its payload and has no error."""
        payload = {"key": "value"}
        outcome = ToolResult(success=True, data=payload, execution_time_ms=10.5)

        assert outcome.success is True
        assert outcome.data == {"key": "value"}
        assert outcome.error is None

    def test_error_result(self) -> None:
        """A failed result exposes both its message and its error code."""
        outcome = ToolResult(
            success=False,
            error="Something went wrong",
            error_code="VALIDATION_ERROR",
        )

        assert outcome.success is False
        assert outcome.error == "Something went wrong"
        assert outcome.error_code == "VALIDATION_ERROR"

    def test_to_dict(self) -> None:
        """to_dict() carries success, data and execution time through."""
        as_dict = ToolResult(
            success=True,
            data={"test": 1},
            execution_time_ms=5.0,
        ).to_dict()

        assert as_dict["success"] is True
        assert as_dict["data"] == {"test": 1}
        assert as_dict["execution_time_ms"] == 5.0
class TestMemoryToolService:
    """Tests for MemoryToolService.execute_tool with mocked memory backends.

    Each test patches the memory classes looked up in
    ``app.services.memory.mcp.service`` so no real storage is touched, then
    checks the ToolResult returned for one tool invocation.
    """

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock session."""
        return make_mock_session()

    @pytest.fixture
    def service(self, mock_session: AsyncMock) -> MemoryToolService:
        """Create a service with mock session."""
        return MemoryToolService(session=mock_session)

    @pytest.fixture
    def context(self) -> ToolContext:
        """Create a test context."""
        return make_context()

    async def test_execute_unknown_tool(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Unknown tool should return error."""
        result = await service.execute_tool(
            tool_name="unknown_tool",
            arguments={},
            context=context,
        )

        assert result.success is False
        assert result.error_code == "UNKNOWN_TOOL"

    async def test_execute_with_invalid_args(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Invalid arguments should return validation error."""
        result = await service.execute_tool(
            tool_name="remember",
            arguments={"memory_type": "invalid_type"},
            context=context,
        )

        assert result.success is False
        assert result.error_code == "VALIDATION_ERROR"

    @patch("app.services.memory.mcp.service.WorkingMemory")
    async def test_remember_working_memory(
        self,
        mock_working_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Remember should store in working memory."""
        # Setup mock
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await service.execute_tool(
            tool_name="remember",
            arguments={
                "memory_type": "working",
                "content": "Test content",
                "key": "test_key",
                "ttl_seconds": 3600,
            },
            context=context,
        )

        assert result.success is True
        assert result.data["stored"] is True
        assert result.data["memory_type"] == "working"
        assert result.data["key"] == "test_key"

    async def test_remember_episodic_memory(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Remember should store in episodic memory."""
        with patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls:
            # Setup mock
            mock_episode = MagicMock()
            mock_episode.id = uuid4()
            mock_episodic = AsyncMock()
            mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
            mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

            result = await service.execute_tool(
                tool_name="remember",
                arguments={
                    "memory_type": "episodic",
                    "content": "Important event happened",
                    "importance": 0.8,
                },
                context=context,
            )

            assert result.success is True
            assert result.data["stored"] is True
            assert result.data["memory_type"] == "episodic"
            assert "episode_id" in result.data

    async def test_remember_working_without_key(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Working memory without key should fail."""
        result = await service.execute_tool(
            tool_name="remember",
            arguments={
                "memory_type": "working",
                "content": "Test content",
            },
            context=context,
        )

        assert result.success is False
        # Guard first so a missing message is an assertion failure, not an
        # AttributeError raised by calling .lower() on None.
        assert result.error is not None
        assert "key is required" in result.error.lower()

    async def test_remember_working_without_session(
        self,
        service: MemoryToolService,
    ) -> None:
        """Working memory without session should fail."""
        context = ToolContext(project_id=uuid4(), session_id=None)

        result = await service.execute_tool(
            tool_name="remember",
            arguments={
                "memory_type": "working",
                "content": "Test content",
                "key": "test_key",
            },
            context=context,
        )

        assert result.success is False
        assert result.error is not None
        assert "session id is required" in result.error.lower()

    async def test_remember_semantic_memory(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Remember should store facts in semantic memory."""
        with patch("app.services.memory.mcp.service.SemanticMemory") as mock_semantic_cls:
            mock_fact = MagicMock()
            mock_fact.id = uuid4()
            mock_semantic = AsyncMock()
            mock_semantic.store_fact = AsyncMock(return_value=mock_fact)
            mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)

            result = await service.execute_tool(
                tool_name="remember",
                arguments={
                    "memory_type": "semantic",
                    "content": "User prefers dark mode",
                    "subject": "User",
                    "predicate": "prefers",
                    "object_value": "dark mode",
                },
                context=context,
            )

            assert result.success is True
            assert result.data["memory_type"] == "semantic"
            assert "fact_id" in result.data
            assert "triple" in result.data

    async def test_remember_semantic_without_fields(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Semantic memory without subject/predicate/object should fail."""
        result = await service.execute_tool(
            tool_name="remember",
            arguments={
                "memory_type": "semantic",
                "content": "Some content",
                "subject": "User",
                # Missing predicate and object_value
            },
            context=context,
        )

        assert result.success is False
        assert result.error is not None
        assert "required" in result.error.lower()

    async def test_remember_procedural_memory(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Remember should store procedures in procedural memory."""
        with patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls:
            mock_procedure = MagicMock()
            mock_procedure.id = uuid4()
            mock_procedural = AsyncMock()
            mock_procedural.record_procedure = AsyncMock(return_value=mock_procedure)
            mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)

            result = await service.execute_tool(
                tool_name="remember",
                arguments={
                    "memory_type": "procedural",
                    "content": "File creation procedure",
                    "trigger": "When creating a new file",
                    "steps": [
                        {"action": "check_exists"},
                        {"action": "create"},
                    ],
                },
                context=context,
            )

            assert result.success is True
            assert result.data["memory_type"] == "procedural"
            assert "procedure_id" in result.data
            assert result.data["steps_count"] == 2

    # Stacked @patch decorators inject mocks bottom-up, so the SemanticMemory
    # mock is the first positional argument and EpisodicMemory the second.
    @patch("app.services.memory.mcp.service.EpisodicMemory")
    @patch("app.services.memory.mcp.service.SemanticMemory")
    async def test_recall_from_multiple_types(
        self,
        mock_semantic_cls: MagicMock,
        mock_episodic_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Recall should search across multiple memory types."""
        # Mock episodic
        mock_episode = MagicMock()
        mock_episode.id = uuid4()
        mock_episode.task_description = "Test episode"
        mock_episode.outcome = Outcome.SUCCESS
        mock_episode.occurred_at = datetime.now(UTC)
        mock_episode.importance_score = 0.9
        mock_episodic = AsyncMock()
        mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        # Mock semantic
        mock_fact = MagicMock()
        mock_fact.id = uuid4()
        mock_fact.subject = "User"
        mock_fact.predicate = "prefers"
        mock_fact.object = "dark mode"
        mock_fact.confidence = 0.8
        mock_semantic = AsyncMock()
        mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
        mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)

        result = await service.execute_tool(
            tool_name="recall",
            arguments={
                "query": "user preferences",
                "memory_types": ["episodic", "semantic"],
                "limit": 10,
            },
            context=context,
        )

        assert result.success is True
        assert result.data["total_results"] == 2
        assert len(result.data["results"]) == 2

    @patch("app.services.memory.mcp.service.WorkingMemory")
    async def test_forget_working_memory(
        self,
        mock_working_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Forget should delete from working memory."""
        mock_working = AsyncMock()
        mock_working.delete = AsyncMock(return_value=True)
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await service.execute_tool(
            tool_name="forget",
            arguments={
                "memory_type": "working",
                "key": "temp_key",
            },
            context=context,
        )

        assert result.success is True
        assert result.data["deleted"] is True
        assert result.data["deleted_count"] == 1

    async def test_forget_pattern_requires_confirm(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Pattern deletion should require confirmation."""
        with patch("app.services.memory.mcp.service.WorkingMemory") as mock_working_cls:
            mock_working = AsyncMock()
            mock_working.list_keys = AsyncMock(return_value=["cache_1", "cache_2"])
            mock_working_cls.for_session = AsyncMock(return_value=mock_working)

            result = await service.execute_tool(
                tool_name="forget",
                arguments={
                    "memory_type": "working",
                    "pattern": "cache_*",
                    "confirm_bulk": False,
                },
                context=context,
            )

            assert result.success is False
            assert result.error is not None
            assert "confirm_bulk" in result.error.lower()

    @patch("app.services.memory.mcp.service.EpisodicMemory")
    async def test_reflect_recent_patterns(
        self,
        mock_episodic_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Reflect should analyze recent patterns."""
        # Create mock episodes: alternating task types, first three successful.
        mock_episodes = []
        for i in range(5):
            ep = MagicMock()
            ep.id = uuid4()
            ep.task_type = "code_review" if i % 2 == 0 else "deployment"
            ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE
            ep.task_description = f"Episode {i}"
            ep.lessons_learned = None
            ep.occurred_at = datetime.now(UTC)
            mock_episodes.append(ep)

        mock_episodic = AsyncMock()
        mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        result = await service.execute_tool(
            tool_name="reflect",
            arguments={
                "analysis_type": "recent_patterns",
                "depth": 3,
            },
            context=context,
        )

        assert result.success is True
        assert result.data["analysis_type"] == "recent_patterns"
        assert result.data["total_episodes"] == 5
        assert "top_task_types" in result.data
        assert "outcome_distribution" in result.data

    @patch("app.services.memory.mcp.service.EpisodicMemory")
    async def test_reflect_success_factors(
        self,
        mock_episodic_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Reflect should analyze success factors."""
        # Ten episodes, eight successful, so the expected rate is 8/10.
        mock_episodes = []
        for i in range(10):
            ep = MagicMock()
            ep.id = uuid4()
            ep.task_type = "code_review"
            ep.outcome = Outcome.SUCCESS if i < 8 else Outcome.FAILURE
            ep.task_description = f"Episode {i}"
            ep.lessons_learned = "Learned something" if i < 3 else None
            ep.occurred_at = datetime.now(UTC)
            mock_episodes.append(ep)

        mock_episodic = AsyncMock()
        mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        result = await service.execute_tool(
            tool_name="reflect",
            arguments={
                "analysis_type": "success_factors",
                "include_examples": True,
            },
            context=context,
        )

        assert result.success is True
        assert result.data["analysis_type"] == "success_factors"
        # The rate is a computed float; compare with a tolerance rather than
        # exact equality so the test does not hinge on how it was derived.
        assert result.data["overall_success_rate"] == pytest.approx(0.8)

    # Decorators apply bottom-up: WorkingMemory is the first injected mock,
    # EpisodicMemory the last.
    @patch("app.services.memory.mcp.service.EpisodicMemory")
    @patch("app.services.memory.mcp.service.SemanticMemory")
    @patch("app.services.memory.mcp.service.ProceduralMemory")
    @patch("app.services.memory.mcp.service.WorkingMemory")
    async def test_get_memory_stats(
        self,
        mock_working_cls: MagicMock,
        mock_procedural_cls: MagicMock,
        mock_semantic_cls: MagicMock,
        mock_episodic_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Get memory stats should return statistics."""
        # Setup mocks
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        mock_episodic = AsyncMock()
        mock_episodic.get_recent = AsyncMock(return_value=[MagicMock() for _ in range(10)])
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        mock_semantic = AsyncMock()
        mock_semantic.search_facts = AsyncMock(return_value=[MagicMock() for _ in range(5)])
        mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)

        mock_procedural = AsyncMock()
        mock_procedural.find_matching = AsyncMock(return_value=[MagicMock() for _ in range(3)])
        mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)

        result = await service.execute_tool(
            tool_name="get_memory_stats",
            arguments={
                "include_breakdown": True,
                "include_recent_activity": False,
            },
            context=context,
        )

        assert result.success is True
        assert "breakdown" in result.data
        breakdown = result.data["breakdown"]
        assert breakdown["working"] == 2
        assert breakdown["episodic"] == 10
        assert breakdown["semantic"] == 5
        assert breakdown["procedural"] == 3
        assert breakdown["total"] == 20

    @patch("app.services.memory.mcp.service.ProceduralMemory")
    async def test_search_procedures(
        self,
        mock_procedural_cls: MagicMock,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Search procedures should return matching procedures."""
        mock_proc = MagicMock()
        mock_proc.id = uuid4()
        mock_proc.name = "Deployment procedure"
        mock_proc.description = "How to deploy"
        mock_proc.trigger = "When deploying"
        mock_proc.success_rate = 0.9
        mock_proc.execution_count = 10
        mock_proc.steps = [{"action": "deploy"}]

        mock_procedural = AsyncMock()
        mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
        mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)

        result = await service.execute_tool(
            tool_name="search_procedures",
            arguments={
                "trigger": "Deploying to production",
                "min_success_rate": 0.8,
                "include_steps": True,
            },
            context=context,
        )

        assert result.success is True
        assert result.data["procedures_found"] == 1
        proc = result.data["procedures"][0]
        assert proc["name"] == "Deployment procedure"
        assert "steps" in proc

    async def test_record_outcome(
        self,
        service: MemoryToolService,
        context: ToolContext,
    ) -> None:
        """Record outcome should store outcome and update procedure."""
        with (
            patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls,
            patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls,
        ):
            mock_episode = MagicMock()
            mock_episode.id = uuid4()
            mock_episodic = AsyncMock()
            mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
            mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

            mock_procedural = AsyncMock()
            mock_procedural.record_outcome = AsyncMock()
            mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)

            result = await service.execute_tool(
                tool_name="record_outcome",
                arguments={
                    "task_type": "code_review",
                    "outcome": "success",
                    "lessons_learned": "Breaking changes caught early",
                    "duration_seconds": 120.5,
                },
                context=context,
            )

            assert result.success is True
            assert result.data["recorded"] is True
            assert result.data["outcome"] == "success"
            assert "episode_id" in result.data
class TestGetMemoryToolService:
    """Checks for the get_memory_tool_service factory coroutine."""

    async def test_creates_service(self) -> None:
        """Awaiting the factory with a session yields a MemoryToolService."""
        built = await get_memory_tool_service(make_mock_session())

        assert isinstance(built, MemoryToolService)

    async def test_accepts_embedding_generator(self) -> None:
        """A generator passed as second argument is kept on the service."""
        fake_generator = MagicMock()
        fake_session = make_mock_session()

        built = await get_memory_tool_service(fake_session, fake_generator)

        assert built._embedding_generator is fake_generator

View File

@@ -0,0 +1,420 @@
# tests/unit/services/memory/mcp/test_tools.py
"""Tests for MCP tool definitions."""
import pytest
from pydantic import ValidationError
from app.services.memory.mcp.tools import (
MEMORY_TOOL_DEFINITIONS,
AnalysisType,
ForgetArgs,
GetMemoryStatsArgs,
MemoryToolDefinition,
MemoryType,
OutcomeType,
RecallArgs,
RecordOutcomeArgs,
ReflectArgs,
RememberArgs,
SearchProceduresArgs,
get_all_tool_schemas,
get_tool_definition,
)
class TestMemoryType:
    """Checks for the MemoryType enum."""

    def test_all_types_defined(self) -> None:
        """Each of the four members compares equal to its string form."""
        expected_pairs = (
            (MemoryType.WORKING, "working"),
            (MemoryType.EPISODIC, "episodic"),
            (MemoryType.SEMANTIC, "semantic"),
            (MemoryType.PROCEDURAL, "procedural"),
        )
        for member, text in expected_pairs:
            assert member == text

    def test_enum_values(self) -> None:
        """.value gives the string, and the string looks the member up."""
        working_value = MemoryType.WORKING.value
        looked_up = MemoryType("episodic")

        assert working_value == "working"
        assert looked_up == MemoryType.EPISODIC
class TestAnalysisType:
    """Checks for the AnalysisType enum."""

    def test_all_types_defined(self) -> None:
        """Each of the five members compares equal to its string form."""
        expected_pairs = (
            (AnalysisType.RECENT_PATTERNS, "recent_patterns"),
            (AnalysisType.SUCCESS_FACTORS, "success_factors"),
            (AnalysisType.FAILURE_PATTERNS, "failure_patterns"),
            (AnalysisType.COMMON_PROCEDURES, "common_procedures"),
            (AnalysisType.LEARNING_PROGRESS, "learning_progress"),
        )
        for member, text in expected_pairs:
            assert member == text
class TestOutcomeType:
    """Checks for the OutcomeType enum."""

    def test_all_outcomes_defined(self) -> None:
        """Each of the four members compares equal to its string form."""
        expected_pairs = (
            (OutcomeType.SUCCESS, "success"),
            (OutcomeType.PARTIAL, "partial"),
            (OutcomeType.FAILURE, "failure"),
            (OutcomeType.ABANDONED, "abandoned"),
        )
        for member, text in expected_pairs:
            assert member == text
class TestRememberArgs:
    """Validation checks for the RememberArgs schema."""

    def test_valid_working_memory_args(self) -> None:
        """Key and TTL are accepted alongside working-memory content."""
        parsed = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test content",
            key="test_key",
            ttl_seconds=3600,
        )

        assert parsed.memory_type == MemoryType.WORKING
        assert parsed.key == "test_key"
        assert parsed.ttl_seconds == 3600

    def test_valid_semantic_args(self) -> None:
        """Subject, predicate and object_value are kept as given."""
        parsed = RememberArgs(
            memory_type=MemoryType.SEMANTIC,
            content="User prefers dark mode",
            subject="User",
            predicate="prefers",
            object_value="dark mode",
        )

        assert parsed.subject == "User"
        assert parsed.predicate == "prefers"
        assert parsed.object_value == "dark mode"

    def test_valid_procedural_args(self) -> None:
        """Trigger and a two-step list are accepted for procedural memory."""
        step_list = [{"action": "check_exists"}, {"action": "create"}]
        parsed = RememberArgs(
            memory_type=MemoryType.PROCEDURAL,
            content="File creation procedure",
            trigger="When creating a new file",
            steps=step_list,
        )

        assert parsed.trigger == "When creating a new file"
        assert len(parsed.steps) == 2

    def test_importance_validation(self) -> None:
        """0.8 is accepted; 1.5 and -0.1 are rejected."""
        accepted = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test",
            importance=0.8,
        )
        assert accepted.importance == 0.8

        # One value above the upper bound, then one below the lower bound.
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValidationError):
                RememberArgs(
                    memory_type=MemoryType.WORKING,
                    content="Test",
                    importance=out_of_range,
                )

    def test_content_required(self) -> None:
        """An empty content string is rejected."""
        with pytest.raises(ValidationError):
            RememberArgs(memory_type=MemoryType.WORKING, content="")

    def test_ttl_validation(self) -> None:
        """A TTL of 0 and a TTL of 31 days are both rejected."""
        too_low = 0
        over_thirty_days = 86400 * 31
        for rejected_ttl in (too_low, over_thirty_days):
            with pytest.raises(ValidationError):
                RememberArgs(
                    memory_type=MemoryType.WORKING,
                    content="Test",
                    ttl_seconds=rejected_ttl,
                )

    def test_default_values(self) -> None:
        """Fields left out fall back to their declared defaults."""
        parsed = RememberArgs(memory_type=MemoryType.WORKING, content="Test")

        assert parsed.importance == 0.5
        assert parsed.ttl_seconds is None
        assert parsed.metadata == {}
        assert parsed.key is None
class TestRecallArgs:
    """Validation checks for the RecallArgs schema."""

    def test_valid_args(self) -> None:
        """Query, two memory types and a limit are accepted."""
        wanted_types = [MemoryType.EPISODIC, MemoryType.SEMANTIC]
        parsed = RecallArgs(
            query="authentication errors",
            memory_types=wanted_types,
            limit=10,
        )

        assert parsed.query == "authentication errors"
        assert len(parsed.memory_types) == 2
        assert parsed.limit == 10

    def test_default_memory_types(self) -> None:
        """With only a query, episodic and semantic are both included."""
        defaults = RecallArgs(query="test query").memory_types

        assert MemoryType.EPISODIC in defaults
        assert MemoryType.SEMANTIC in defaults

    def test_limit_validation(self) -> None:
        """Limits of 0 and 101 are rejected."""
        for rejected_limit in (0, 101):
            with pytest.raises(ValidationError):
                RecallArgs(query="test", limit=rejected_limit)

    def test_min_relevance_validation(self) -> None:
        """0.5 is accepted as min_relevance; 1.5 is rejected."""
        parsed = RecallArgs(query="test", min_relevance=0.5)
        assert parsed.min_relevance == 0.5

        with pytest.raises(ValidationError):
            RecallArgs(query="test", min_relevance=1.5)
class TestForgetArgs:
    """Validation checks for the ForgetArgs schema."""

    def test_valid_key_deletion(self) -> None:
        """Deleting by key needs only the memory type and the key."""
        parsed = ForgetArgs(memory_type=MemoryType.WORKING, key="temp_key")

        assert parsed.memory_type == MemoryType.WORKING
        assert parsed.key == "temp_key"

    def test_valid_id_deletion(self) -> None:
        """A UUID-formatted memory_id string is accepted."""
        parsed = ForgetArgs(
            memory_type=MemoryType.EPISODIC,
            memory_id="12345678-1234-1234-1234-123456789012",
        )

        assert parsed.memory_id is not None

    def test_pattern_deletion_requires_confirm(self) -> None:
        """A pattern with confirm_bulk=False still parses at schema level."""
        parsed = ForgetArgs(
            memory_type=MemoryType.WORKING,
            pattern="cache_*",
            confirm_bulk=False,
        )

        assert parsed.pattern == "cache_*"
        assert parsed.confirm_bulk is False
class TestReflectArgs:
    """Validation checks for the ReflectArgs schema."""

    def test_valid_args(self) -> None:
        """An analysis type with an explicit depth is accepted."""
        parsed = ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=3)

        assert parsed.analysis_type == AnalysisType.SUCCESS_FACTORS
        assert parsed.depth == 3

    def test_depth_validation(self) -> None:
        """Depths of 0 and 6 are rejected."""
        for rejected_depth in (0, 6):
            with pytest.raises(ValidationError):
                ReflectArgs(
                    analysis_type=AnalysisType.SUCCESS_FACTORS,
                    depth=rejected_depth,
                )

    def test_default_values(self) -> None:
        """Depth, include_examples and max_items fall back to defaults."""
        parsed = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)

        assert parsed.depth == 3
        assert parsed.include_examples is True
        assert parsed.max_items == 10
class TestGetMemoryStatsArgs:
    """Validation checks for the GetMemoryStatsArgs schema."""

    def test_valid_args(self) -> None:
        """Both include flags and a 30-day range are accepted."""
        parsed = GetMemoryStatsArgs(
            include_breakdown=True,
            include_recent_activity=True,
            time_range_days=30,
        )

        assert parsed.include_breakdown is True
        assert parsed.time_range_days == 30

    def test_time_range_validation(self) -> None:
        """Ranges of 0 and 91 days are rejected."""
        for rejected_days in (0, 91):
            with pytest.raises(ValidationError):
                GetMemoryStatsArgs(time_range_days=rejected_days)
class TestSearchProceduresArgs:
    """Validation checks for the SearchProceduresArgs schema."""

    def test_valid_args(self) -> None:
        """Trigger, minimum success rate and limit are accepted."""
        parsed = SearchProceduresArgs(
            trigger="Deploying to production",
            min_success_rate=0.8,
            limit=5,
        )

        assert parsed.trigger == "Deploying to production"
        assert parsed.min_success_rate == 0.8

    def test_trigger_required(self) -> None:
        """An empty trigger string is rejected."""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="")

    def test_success_rate_validation(self) -> None:
        """A minimum success rate of 1.5 is rejected."""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="test", min_success_rate=1.5)
class TestRecordOutcomeArgs:
    """Validation checks for the RecordOutcomeArgs schema."""

    def test_valid_success_args(self) -> None:
        """A success outcome with lessons learned is accepted."""
        parsed = RecordOutcomeArgs(
            task_type="code_review",
            outcome=OutcomeType.SUCCESS,
            lessons_learned="Breaking changes caught early",
        )

        assert parsed.task_type == "code_review"
        assert parsed.outcome == OutcomeType.SUCCESS

    def test_valid_failure_args(self) -> None:
        """A failure outcome with error details and a duration is accepted."""
        parsed = RecordOutcomeArgs(
            task_type="deployment",
            outcome=OutcomeType.FAILURE,
            error_details="Database migration timeout",
            duration_seconds=120.5,
        )

        assert parsed.outcome == OutcomeType.FAILURE
        assert parsed.error_details is not None

    def test_task_type_required(self) -> None:
        """An empty task_type string is rejected."""
        with pytest.raises(ValidationError):
            RecordOutcomeArgs(task_type="", outcome=OutcomeType.SUCCESS)
class TestMemoryToolDefinition:
    """Checks for the MemoryToolDefinition wrapper class."""

    def test_to_mcp_format(self) -> None:
        """to_mcp_format() exposes name, description and an inputSchema."""
        definition = MemoryToolDefinition(
            name="test_tool",
            description="A test tool",
            args_schema=RememberArgs,
        )

        exported = definition.to_mcp_format()

        assert exported["name"] == "test_tool"
        assert exported["description"] == "A test tool"
        assert "inputSchema" in exported
        assert "properties" in exported["inputSchema"]

    def test_validate_args(self) -> None:
        """validate_args() returns a schema instance or raises on bad input."""
        definition = MemoryToolDefinition(
            name="remember",
            description="Store in memory",
            args_schema=RememberArgs,
        )

        # Well-formed arguments come back as a RememberArgs instance.
        good_arguments = {"memory_type": "working", "content": "Test content"}
        parsed = definition.validate_args(good_arguments)
        assert isinstance(parsed, RememberArgs)

        # An unrecognised memory type is rejected.
        with pytest.raises(ValidationError):
            definition.validate_args({"memory_type": "invalid"})
class TestToolDefinitions:
    """Checks for the MEMORY_TOOL_DEFINITIONS registry and its accessors."""

    def test_all_tools_defined(self) -> None:
        """Each of the seven expected tool names maps to a definition."""
        expected_names = (
            "remember",
            "recall",
            "forget",
            "reflect",
            "get_memory_stats",
            "search_procedures",
            "record_outcome",
        )
        for expected in expected_names:
            assert expected in MEMORY_TOOL_DEFINITIONS
            assert isinstance(MEMORY_TOOL_DEFINITIONS[expected], MemoryToolDefinition)

    def test_get_tool_definition(self) -> None:
        """A known name returns its tool; an unknown name returns None."""
        found = get_tool_definition("remember")
        assert found is not None
        assert found.name == "remember"

        missing = get_tool_definition("unknown_tool")
        assert missing is None

    def test_get_all_tool_schemas(self) -> None:
        """Seven schemas come back, each with the three MCP keys."""
        all_schemas = get_all_tool_schemas()

        assert len(all_schemas) == 7
        for entry in all_schemas:
            assert "name" in entry
            assert "description" in entry
            assert "inputSchema" in entry

    def test_tool_descriptions_not_empty(self) -> None:
        """Every description is present and longer than 50 characters."""
        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
            text = tool.description
            assert text, f"Tool {name} has empty description"
            assert len(text) > 50, f"Tool {name} description too short"

    def test_input_schemas_have_properties(self) -> None:
        """Every exported inputSchema contains a properties entry."""
        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
            input_schema = tool.to_mcp_format()["inputSchema"]
            assert "properties" in input_schema, f"Tool {name} missing properties"