# tests/unit/services/memory/mcp/test_tools.py
"""Tests for MCP tool definitions.

Covers the enums, the pydantic argument models for each memory tool, the
``MemoryToolDefinition`` wrapper, and the ``MEMORY_TOOL_DEFINITIONS``
registry exported by ``app.services.memory.mcp.tools``.
"""

import pytest
from pydantic import ValidationError

from app.services.memory.mcp.tools import (
    MEMORY_TOOL_DEFINITIONS,
    AnalysisType,
    ForgetArgs,
    GetMemoryStatsArgs,
    MemoryToolDefinition,
    MemoryType,
    OutcomeType,
    RecallArgs,
    RecordOutcomeArgs,
    ReflectArgs,
    RememberArgs,
    SearchProceduresArgs,
    get_all_tool_schemas,
    get_tool_definition,
)


class TestMemoryType:
    """Tests for MemoryType enum."""

    def test_all_types_defined(self) -> None:
        """All memory types should be defined."""
        assert MemoryType.WORKING == "working"
        assert MemoryType.EPISODIC == "episodic"
        assert MemoryType.SEMANTIC == "semantic"
        assert MemoryType.PROCEDURAL == "procedural"

    def test_enum_values(self) -> None:
        """Enum values should match strings."""
        assert MemoryType.WORKING.value == "working"
        assert MemoryType("episodic") == MemoryType.EPISODIC


class TestAnalysisType:
    """Tests for AnalysisType enum."""

    def test_all_types_defined(self) -> None:
        """All analysis types should be defined."""
        assert AnalysisType.RECENT_PATTERNS == "recent_patterns"
        assert AnalysisType.SUCCESS_FACTORS == "success_factors"
        assert AnalysisType.FAILURE_PATTERNS == "failure_patterns"
        assert AnalysisType.COMMON_PROCEDURES == "common_procedures"
        assert AnalysisType.LEARNING_PROGRESS == "learning_progress"


class TestOutcomeType:
    """Tests for OutcomeType enum."""

    def test_all_outcomes_defined(self) -> None:
        """All outcome types should be defined."""
        assert OutcomeType.SUCCESS == "success"
        assert OutcomeType.PARTIAL == "partial"
        assert OutcomeType.FAILURE == "failure"
        assert OutcomeType.ABANDONED == "abandoned"


class TestRememberArgs:
    """Tests for RememberArgs validation."""

    def test_valid_working_memory_args(self) -> None:
        """Valid working memory args should parse."""
        args = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test content",
            key="test_key",
            ttl_seconds=3600,
        )
        assert args.memory_type == MemoryType.WORKING
        assert args.key == "test_key"
        assert args.ttl_seconds == 3600

    def test_valid_semantic_args(self) -> None:
        """Valid semantic memory args should parse."""
        args = RememberArgs(
            memory_type=MemoryType.SEMANTIC,
            content="User prefers dark mode",
            subject="User",
            predicate="prefers",
            object_value="dark mode",
        )
        assert args.subject == "User"
        assert args.predicate == "prefers"
        assert args.object_value == "dark mode"

    def test_valid_procedural_args(self) -> None:
        """Valid procedural memory args should parse."""
        args = RememberArgs(
            memory_type=MemoryType.PROCEDURAL,
            content="File creation procedure",
            trigger="When creating a new file",
            steps=[{"action": "check_exists"}, {"action": "create"}],
        )
        assert args.trigger == "When creating a new file"
        assert len(args.steps) == 2

    def test_importance_validation(self) -> None:
        """Importance must be between 0 and 1."""
        args = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test",
            importance=0.8,
        )
        assert args.importance == 0.8

        with pytest.raises(ValidationError):
            RememberArgs(
                memory_type=MemoryType.WORKING,
                content="Test",
                importance=1.5,  # Invalid
            )

        with pytest.raises(ValidationError):
            RememberArgs(
                memory_type=MemoryType.WORKING,
                content="Test",
                importance=-0.1,  # Invalid
            )

    def test_content_required(self) -> None:
        """Content is required."""
        with pytest.raises(ValidationError):
            RememberArgs(
                memory_type=MemoryType.WORKING,
                content="",  # Empty not allowed
            )

    def test_ttl_validation(self) -> None:
        """TTL must be within bounds."""
        with pytest.raises(ValidationError):
            RememberArgs(
                memory_type=MemoryType.WORKING,
                content="Test",
                ttl_seconds=0,  # Too low
            )

        with pytest.raises(ValidationError):
            RememberArgs(
                memory_type=MemoryType.WORKING,
                content="Test",
                ttl_seconds=86400 * 31,  # Over 30 days
            )

    def test_default_values(self) -> None:
        """Default values should be set correctly."""
        args = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test",
        )
        assert args.importance == 0.5
        assert args.ttl_seconds is None
        assert args.metadata == {}
        assert args.key is None


class TestRecallArgs:
    """Tests for RecallArgs validation."""

    def test_valid_args(self) -> None:
        """Valid recall args should parse."""
        args = RecallArgs(
            query="authentication errors",
            memory_types=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
            limit=10,
        )
        assert args.query == "authentication errors"
        assert len(args.memory_types) == 2
        assert args.limit == 10

    def test_default_memory_types(self) -> None:
        """Default memory types should be episodic and semantic."""
        args = RecallArgs(query="test query")
        assert MemoryType.EPISODIC in args.memory_types
        assert MemoryType.SEMANTIC in args.memory_types

    def test_limit_validation(self) -> None:
        """Limit must be between 1 and 100 (inclusive)."""
        # Both documented endpoints are accepted...
        assert RecallArgs(query="test", limit=1).limit == 1
        assert RecallArgs(query="test", limit=100).limit == 100

        # ...and the first values just outside the range are rejected.
        with pytest.raises(ValidationError):
            RecallArgs(query="test", limit=0)

        with pytest.raises(ValidationError):
            RecallArgs(query="test", limit=101)

    def test_min_relevance_validation(self) -> None:
        """Min relevance must be between 0 and 1."""
        args = RecallArgs(query="test", min_relevance=0.5)
        assert args.min_relevance == 0.5

        with pytest.raises(ValidationError):
            RecallArgs(query="test", min_relevance=1.5)


class TestForgetArgs:
    """Tests for ForgetArgs validation."""

    def test_valid_key_deletion(self) -> None:
        """Valid key deletion args should parse."""
        args = ForgetArgs(
            memory_type=MemoryType.WORKING,
            key="temp_key",
        )
        assert args.memory_type == MemoryType.WORKING
        assert args.key == "temp_key"

    def test_valid_id_deletion(self) -> None:
        """Valid ID deletion args should parse."""
        args = ForgetArgs(
            memory_type=MemoryType.EPISODIC,
            memory_id="12345678-1234-1234-1234-123456789012",
        )
        assert args.memory_id is not None

    def test_pattern_deletion_requires_confirm(self) -> None:
        """Pattern deletion should parse but service should validate confirm."""
        # The args model itself accepts confirm_bulk=False; enforcing the
        # confirmation is left to the service layer.
        args = ForgetArgs(
            memory_type=MemoryType.WORKING,
            pattern="cache_*",
            confirm_bulk=False,
        )
        assert args.pattern == "cache_*"
        assert args.confirm_bulk is False


class TestReflectArgs:
    """Tests for ReflectArgs validation."""

    def test_valid_args(self) -> None:
        """Valid reflect args should parse."""
        args = ReflectArgs(
            analysis_type=AnalysisType.SUCCESS_FACTORS,
            depth=3,
        )
        assert args.analysis_type == AnalysisType.SUCCESS_FACTORS
        assert args.depth == 3

    def test_depth_validation(self) -> None:
        """Depth must be between 1 and 5 (inclusive)."""
        # Both documented endpoints are accepted...
        low = ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=1)
        high = ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=5)
        assert low.depth == 1
        assert high.depth == 5

        # ...and the first values just outside the range are rejected.
        with pytest.raises(ValidationError):
            ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=0)

        with pytest.raises(ValidationError):
            ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=6)

    def test_default_values(self) -> None:
        """Default values should be set correctly."""
        args = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)
        assert args.depth == 3
        assert args.include_examples is True
        assert args.max_items == 10


class TestGetMemoryStatsArgs:
    """Tests for GetMemoryStatsArgs validation."""

    def test_valid_args(self) -> None:
        """Valid args should parse."""
        args = GetMemoryStatsArgs(
            include_breakdown=True,
            include_recent_activity=True,
            time_range_days=30,
        )
        assert args.include_breakdown is True
        assert args.time_range_days == 30

    def test_time_range_validation(self) -> None:
        """Time range must be between 1 and 90 (inclusive)."""
        # Both documented endpoints are accepted. The other flags are passed
        # explicitly, as in test_valid_args, so only the time range is probed.
        for days in (1, 90):
            args = GetMemoryStatsArgs(
                include_breakdown=True,
                include_recent_activity=True,
                time_range_days=days,
            )
            assert args.time_range_days == days

        # ...and the first values just outside the range are rejected.
        with pytest.raises(ValidationError):
            GetMemoryStatsArgs(time_range_days=0)

        with pytest.raises(ValidationError):
            GetMemoryStatsArgs(time_range_days=91)


class TestSearchProceduresArgs:
    """Tests for SearchProceduresArgs validation."""

    def test_valid_args(self) -> None:
        """Valid args should parse."""
        args = SearchProceduresArgs(
            trigger="Deploying to production",
            min_success_rate=0.8,
            limit=5,
        )
        assert args.trigger == "Deploying to production"
        assert args.min_success_rate == 0.8

    def test_trigger_required(self) -> None:
        """Trigger is required."""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="")

    def test_success_rate_validation(self) -> None:
        """Success rate must be between 0 and 1."""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="test", min_success_rate=1.5)


class TestRecordOutcomeArgs:
    """Tests for RecordOutcomeArgs validation."""

    def test_valid_success_args(self) -> None:
        """Valid success args should parse."""
        args = RecordOutcomeArgs(
            task_type="code_review",
            outcome=OutcomeType.SUCCESS,
            lessons_learned="Breaking changes caught early",
        )
        assert args.task_type == "code_review"
        assert args.outcome == OutcomeType.SUCCESS

    def test_valid_failure_args(self) -> None:
        """Valid failure args should parse."""
        args = RecordOutcomeArgs(
            task_type="deployment",
            outcome=OutcomeType.FAILURE,
            error_details="Database migration timeout",
            duration_seconds=120.5,
        )
        assert args.outcome == OutcomeType.FAILURE
        assert args.error_details is not None

    def test_task_type_required(self) -> None:
        """Task type is required."""
        with pytest.raises(ValidationError):
            RecordOutcomeArgs(task_type="", outcome=OutcomeType.SUCCESS)


class TestMemoryToolDefinition:
    """Tests for MemoryToolDefinition class."""

    def test_to_mcp_format(self) -> None:
        """Tool should convert to MCP format."""
        tool = MemoryToolDefinition(
            name="test_tool",
            description="A test tool",
            args_schema=RememberArgs,
        )

        mcp_format = tool.to_mcp_format()

        assert mcp_format["name"] == "test_tool"
        assert mcp_format["description"] == "A test tool"
        assert "inputSchema" in mcp_format
        assert "properties" in mcp_format["inputSchema"]

    def test_validate_args(self) -> None:
        """Tool should validate args using schema."""
        tool = MemoryToolDefinition(
            name="remember",
            description="Store in memory",
            args_schema=RememberArgs,
        )

        # Valid args
        validated = tool.validate_args({
            "memory_type": "working",
            "content": "Test content",
        })
        assert isinstance(validated, RememberArgs)

        # Invalid args
        with pytest.raises(ValidationError):
            tool.validate_args({"memory_type": "invalid"})

        # The case above also omits "content", so it would fail even if enum
        # validation were broken. Supplying valid content here makes the
        # unknown memory_type the only possible cause of the error.
        with pytest.raises(ValidationError):
            tool.validate_args({
                "memory_type": "invalid",
                "content": "Test content",
            })


class TestToolDefinitions:
    """Tests for the tool definitions dictionary."""

    def test_all_tools_defined(self) -> None:
        """All expected tools should be defined."""
        expected_tools = [
            "remember",
            "recall",
            "forget",
            "reflect",
            "get_memory_stats",
            "search_procedures",
            "record_outcome",
        ]

        for tool_name in expected_tools:
            assert tool_name in MEMORY_TOOL_DEFINITIONS
            assert isinstance(MEMORY_TOOL_DEFINITIONS[tool_name], MemoryToolDefinition)

    def test_get_tool_definition(self) -> None:
        """get_tool_definition should return correct tool."""
        tool = get_tool_definition("remember")
        assert tool is not None
        assert tool.name == "remember"

        unknown = get_tool_definition("unknown_tool")
        assert unknown is None

    def test_get_all_tool_schemas(self) -> None:
        """get_all_tool_schemas should return MCP-formatted schemas."""
        schemas = get_all_tool_schemas()
        assert len(schemas) == 7

        for schema in schemas:
            assert "name" in schema
            assert "description" in schema
            assert "inputSchema" in schema

    def test_tool_descriptions_not_empty(self) -> None:
        """All tools should have descriptions."""
        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
            assert tool.description, f"Tool {name} has empty description"
            assert len(tool.description) > 50, f"Tool {name} description too short"

    def test_input_schemas_have_properties(self) -> None:
        """All tool schemas should have properties defined."""
        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
            schema = tool.to_mcp_format()
            assert "properties" in schema["inputSchema"], f"Tool {name} missing properties"