forked from cardosofelipe/fast-next-template
Add MCP-compatible tools that expose memory operations to agents: Tools implemented: - remember: Store data in working, episodic, semantic, or procedural memory - recall: Retrieve memories by query across multiple memory types - forget: Delete specific keys or bulk delete by pattern - reflect: Analyze patterns in recent episodes (success/failure factors) - get_memory_stats: Return usage statistics and breakdowns - search_procedures: Find procedures matching trigger patterns - record_outcome: Record task outcomes and update procedure success rates Key components: - tools.py: Pydantic schemas for tool argument validation with comprehensive field constraints (importance 0-1, TTL limits, limit ranges) - service.py: MemoryToolService coordinating memory type operations with proper scoping via ToolContext (project_id, agent_instance_id, session_id) - Lazy initialization of memory services (WorkingMemory, EpisodicMemory, SemanticMemory, ProceduralMemory) Test coverage: - 60 tests covering tool definitions, argument validation, and service execution paths - Mock-based tests for all memory type interactions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
421 lines
14 KiB
Python
421 lines
14 KiB
Python
# tests/unit/services/memory/mcp/test_tools.py
|
|
"""Tests for MCP tool definitions."""
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from app.services.memory.mcp.tools import (
|
|
MEMORY_TOOL_DEFINITIONS,
|
|
AnalysisType,
|
|
ForgetArgs,
|
|
GetMemoryStatsArgs,
|
|
MemoryToolDefinition,
|
|
MemoryType,
|
|
OutcomeType,
|
|
RecallArgs,
|
|
RecordOutcomeArgs,
|
|
ReflectArgs,
|
|
RememberArgs,
|
|
SearchProceduresArgs,
|
|
get_all_tool_schemas,
|
|
get_tool_definition,
|
|
)
|
|
|
|
|
|
class TestMemoryType:
    """Checks on the members of the ``MemoryType`` enum."""

    def test_all_types_defined(self) -> None:
        """Every memory type member compares equal to its string form."""
        expected_pairs = (
            (MemoryType.WORKING, "working"),
            (MemoryType.EPISODIC, "episodic"),
            (MemoryType.SEMANTIC, "semantic"),
            (MemoryType.PROCEDURAL, "procedural"),
        )
        for member, text in expected_pairs:
            assert member == text

    def test_enum_values(self) -> None:
        """``.value`` yields the string and the string looks the member up."""
        working_value = MemoryType.WORKING.value
        assert working_value == "working"

        looked_up = MemoryType("episodic")
        assert looked_up == MemoryType.EPISODIC
|
|
|
|
|
|
class TestAnalysisType:
    """Checks on the members of the ``AnalysisType`` enum."""

    def test_all_types_defined(self) -> None:
        """Every analysis type member compares equal to its string form."""
        expected_pairs = (
            (AnalysisType.RECENT_PATTERNS, "recent_patterns"),
            (AnalysisType.SUCCESS_FACTORS, "success_factors"),
            (AnalysisType.FAILURE_PATTERNS, "failure_patterns"),
            (AnalysisType.COMMON_PROCEDURES, "common_procedures"),
            (AnalysisType.LEARNING_PROGRESS, "learning_progress"),
        )
        for member, text in expected_pairs:
            assert member == text
|
|
|
|
|
|
class TestOutcomeType:
    """Checks on the members of the ``OutcomeType`` enum."""

    def test_all_outcomes_defined(self) -> None:
        """Every outcome member compares equal to its string form."""
        expected_pairs = (
            (OutcomeType.SUCCESS, "success"),
            (OutcomeType.PARTIAL, "partial"),
            (OutcomeType.FAILURE, "failure"),
            (OutcomeType.ABANDONED, "abandoned"),
        )
        for member, text in expected_pairs:
            assert member == text
|
|
|
|
|
|
class TestRememberArgs:
    """Validation behaviour of the ``RememberArgs`` schema."""

    def test_valid_working_memory_args(self) -> None:
        """A working-memory payload with key and TTL is accepted as given."""
        parsed = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test content",
            key="test_key",
            ttl_seconds=3600,
        )

        observed = (parsed.memory_type, parsed.key, parsed.ttl_seconds)
        assert observed == (MemoryType.WORKING, "test_key", 3600)

    def test_valid_semantic_args(self) -> None:
        """A semantic payload keeps its subject / predicate / object triple."""
        parsed = RememberArgs(
            memory_type=MemoryType.SEMANTIC,
            content="User prefers dark mode",
            subject="User",
            predicate="prefers",
            object_value="dark mode",
        )

        triple = (parsed.subject, parsed.predicate, parsed.object_value)
        assert triple == ("User", "prefers", "dark mode")

    def test_valid_procedural_args(self) -> None:
        """A procedural payload keeps its trigger and both steps."""
        step_list = [{"action": "check_exists"}, {"action": "create"}]
        parsed = RememberArgs(
            memory_type=MemoryType.PROCEDURAL,
            content="File creation procedure",
            trigger="When creating a new file",
            steps=step_list,
        )

        assert parsed.trigger == "When creating a new file"
        assert len(parsed.steps) == 2

    def test_importance_validation(self) -> None:
        """Importance 0.8 is accepted; 1.5 and -0.1 are rejected."""
        accepted = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test",
            importance=0.8,
        )
        assert accepted.importance == 0.8

        # One value above the range, one below it.
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValidationError):
                RememberArgs(
                    memory_type=MemoryType.WORKING,
                    content="Test",
                    importance=out_of_range,
                )

    def test_content_required(self) -> None:
        """An empty content string is rejected."""
        empty_content = ""
        with pytest.raises(ValidationError):
            RememberArgs(memory_type=MemoryType.WORKING, content=empty_content)

    def test_ttl_validation(self) -> None:
        """A TTL of zero and a TTL of 31 days are both rejected."""
        too_low = 0
        over_thirty_days = 86400 * 31

        for bad_ttl in (too_low, over_thirty_days):
            with pytest.raises(ValidationError):
                RememberArgs(
                    memory_type=MemoryType.WORKING,
                    content="Test",
                    ttl_seconds=bad_ttl,
                )

    def test_default_values(self) -> None:
        """Optional fields take their defaults when omitted."""
        minimal = RememberArgs(memory_type=MemoryType.WORKING, content="Test")

        assert minimal.importance == 0.5
        assert minimal.ttl_seconds is None
        assert minimal.metadata == {}
        assert minimal.key is None
|
|
|
|
|
|
class TestRecallArgs:
    """Validation behaviour of the ``RecallArgs`` schema."""

    def test_valid_args(self) -> None:
        """Query, memory types and limit are kept as supplied."""
        requested_types = [MemoryType.EPISODIC, MemoryType.SEMANTIC]
        parsed = RecallArgs(
            query="authentication errors",
            memory_types=requested_types,
            limit=10,
        )

        assert parsed.query == "authentication errors"
        assert len(parsed.memory_types) == 2
        assert parsed.limit == 10

    def test_default_memory_types(self) -> None:
        """Episodic and semantic are both present when no types are given."""
        defaults = RecallArgs(query="test query").memory_types

        for expected_type in (MemoryType.EPISODIC, MemoryType.SEMANTIC):
            assert expected_type in defaults

    def test_limit_validation(self) -> None:
        """Limits of 0 and 101 are rejected."""
        for bad_limit in (0, 101):
            with pytest.raises(ValidationError):
                RecallArgs(query="test", limit=bad_limit)

    def test_min_relevance_validation(self) -> None:
        """A min relevance of 0.5 is accepted and 1.5 is rejected."""
        accepted = RecallArgs(query="test", min_relevance=0.5)
        assert accepted.min_relevance == 0.5

        with pytest.raises(ValidationError):
            RecallArgs(query="test", min_relevance=1.5)
|
|
|
|
|
|
class TestForgetArgs:
    """Validation behaviour of the ``ForgetArgs`` schema."""

    def test_valid_key_deletion(self) -> None:
        """Deleting by key parses and keeps the type and key."""
        parsed = ForgetArgs(memory_type=MemoryType.WORKING, key="temp_key")

        observed = (parsed.memory_type, parsed.key)
        assert observed == (MemoryType.WORKING, "temp_key")

    def test_valid_id_deletion(self) -> None:
        """Deleting by memory ID parses and leaves ``memory_id`` populated."""
        target_id = "12345678-1234-1234-1234-123456789012"
        parsed = ForgetArgs(memory_type=MemoryType.EPISODIC, memory_id=target_id)

        assert parsed.memory_id is not None

    def test_pattern_deletion_requires_confirm(self) -> None:
        """Pattern deletion parses without confirmation.

        The schema accepts ``confirm_bulk=False``; per the original test's
        intent, enforcing the confirmation is left to the service layer.
        """
        parsed = ForgetArgs(
            memory_type=MemoryType.WORKING,
            pattern="cache_*",
            confirm_bulk=False,
        )

        assert parsed.pattern == "cache_*"
        assert parsed.confirm_bulk is False
|
|
|
|
|
|
class TestReflectArgs:
    """Validation behaviour of the ``ReflectArgs`` schema."""

    def test_valid_args(self) -> None:
        """Analysis type and depth are kept as supplied."""
        parsed = ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=3)

        observed = (parsed.analysis_type, parsed.depth)
        assert observed == (AnalysisType.SUCCESS_FACTORS, 3)

    def test_depth_validation(self) -> None:
        """Depths of 0 and 6 are rejected."""
        for bad_depth in (0, 6):
            with pytest.raises(ValidationError):
                ReflectArgs(
                    analysis_type=AnalysisType.SUCCESS_FACTORS,
                    depth=bad_depth,
                )

    def test_default_values(self) -> None:
        """Depth, example inclusion and item cap take their defaults."""
        minimal = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)

        assert minimal.depth == 3
        assert minimal.include_examples is True
        assert minimal.max_items == 10
|
|
|
|
|
|
class TestGetMemoryStatsArgs:
    """Validation behaviour of the ``GetMemoryStatsArgs`` schema."""

    def test_valid_args(self) -> None:
        """A fully specified stats request parses and keeps its values."""
        parsed = GetMemoryStatsArgs(
            include_breakdown=True,
            include_recent_activity=True,
            time_range_days=30,
        )

        assert parsed.include_breakdown is True
        assert parsed.time_range_days == 30

    def test_time_range_validation(self) -> None:
        """Time ranges of 0 and 91 days are rejected."""
        for bad_range in (0, 91):
            with pytest.raises(ValidationError):
                GetMemoryStatsArgs(time_range_days=bad_range)
|
|
|
|
|
|
class TestSearchProceduresArgs:
    """Validation behaviour of the ``SearchProceduresArgs`` schema."""

    def test_valid_args(self) -> None:
        """Trigger and minimum success rate are kept as supplied."""
        parsed = SearchProceduresArgs(
            trigger="Deploying to production",
            min_success_rate=0.8,
            limit=5,
        )

        observed = (parsed.trigger, parsed.min_success_rate)
        assert observed == ("Deploying to production", 0.8)

    def test_trigger_required(self) -> None:
        """An empty trigger string is rejected."""
        empty_trigger = ""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger=empty_trigger)

    def test_success_rate_validation(self) -> None:
        """A minimum success rate of 1.5 is rejected."""
        above_range = 1.5
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="test", min_success_rate=above_range)
|
|
|
|
|
|
class TestRecordOutcomeArgs:
    """Validation behaviour of the ``RecordOutcomeArgs`` schema."""

    def test_valid_success_args(self) -> None:
        """A success outcome with lessons learned parses."""
        parsed = RecordOutcomeArgs(
            task_type="code_review",
            outcome=OutcomeType.SUCCESS,
            lessons_learned="Breaking changes caught early",
        )

        observed = (parsed.task_type, parsed.outcome)
        assert observed == ("code_review", OutcomeType.SUCCESS)

    def test_valid_failure_args(self) -> None:
        """A failure outcome with error details and duration parses."""
        parsed = RecordOutcomeArgs(
            task_type="deployment",
            outcome=OutcomeType.FAILURE,
            error_details="Database migration timeout",
            duration_seconds=120.5,
        )

        assert parsed.outcome == OutcomeType.FAILURE
        assert parsed.error_details is not None

    def test_task_type_required(self) -> None:
        """An empty task type string is rejected."""
        empty_task_type = ""
        with pytest.raises(ValidationError):
            RecordOutcomeArgs(task_type=empty_task_type, outcome=OutcomeType.SUCCESS)
|
|
|
|
|
|
class TestMemoryToolDefinition:
    """Tests for MemoryToolDefinition class."""

    def test_to_mcp_format(self) -> None:
        """Tool should convert to MCP format.

        The converted mapping must carry the tool's name and description and
        an ``inputSchema`` entry that itself contains ``properties``.
        """
        tool = MemoryToolDefinition(
            name="test_tool",
            description="A test tool",
            args_schema=RememberArgs,
        )

        mcp_format = tool.to_mcp_format()

        assert mcp_format["name"] == "test_tool"
        assert mcp_format["description"] == "A test tool"
        assert "inputSchema" in mcp_format
        assert "properties" in mcp_format["inputSchema"]

    def test_validate_args(self) -> None:
        """Tool should validate args using schema.

        Valid raw arguments come back as a ``RememberArgs`` instance; invalid
        ones raise ``ValidationError``.
        """
        tool = MemoryToolDefinition(
            name="remember",
            description="Store in memory",
            args_schema=RememberArgs,
        )

        # Valid args
        validated = tool.validate_args(
            {
                "memory_type": "working",
                "content": "Test content",
            }
        )
        assert isinstance(validated, RememberArgs)

        # Invalid args: the original payload (bad memory_type, no content).
        with pytest.raises(ValidationError):
            tool.validate_args({"memory_type": "invalid"})

        # Invalid args: same bad memory_type, but with valid content, so the
        # rejection can only come from the memory_type value itself rather
        # than from anything else missing in the payload.
        with pytest.raises(ValidationError):
            tool.validate_args(
                {
                    "memory_type": "invalid",
                    "content": "Test content",
                }
            )
|
|
|
|
|
|
class TestToolDefinitions:
    """Checks on the ``MEMORY_TOOL_DEFINITIONS`` registry and its helpers."""

    def test_all_tools_defined(self) -> None:
        """Each expected tool name maps to a ``MemoryToolDefinition``."""
        expected_tools = (
            "remember",
            "recall",
            "forget",
            "reflect",
            "get_memory_stats",
            "search_procedures",
            "record_outcome",
        )

        for tool_name in expected_tools:
            assert tool_name in MEMORY_TOOL_DEFINITIONS
            definition = MEMORY_TOOL_DEFINITIONS[tool_name]
            assert isinstance(definition, MemoryToolDefinition)

    def test_get_tool_definition(self) -> None:
        """A known name returns its tool; an unknown name returns ``None``."""
        found = get_tool_definition("remember")
        assert found is not None
        assert found.name == "remember"

        missing = get_tool_definition("unknown_tool")
        assert missing is None

    def test_get_all_tool_schemas(self) -> None:
        """Seven schemas come back, each with the three MCP keys."""
        schemas = get_all_tool_schemas()
        assert len(schemas) == 7

        required_keys = ("name", "description", "inputSchema")
        for schema in schemas:
            for key in required_keys:
                assert key in schema

    def test_tool_descriptions_not_empty(self) -> None:
        """Every description is non-empty and longer than 50 characters."""
        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
            description = tool.description
            assert description, f"Tool {name} has empty description"
            assert len(description) > 50, f"Tool {name} description too short"

    def test_input_schemas_have_properties(self) -> None:
        """Every tool's MCP input schema contains a ``properties`` entry."""
        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
            input_schema = tool.to_mcp_format()["inputSchema"]
            assert "properties" in input_schema, f"Tool {name} missing properties"
|