""" Tests for MCP Safety Integration. Tests cover: - MCPToolCall and MCPToolResult data structures - MCPSafetyWrapper: tool registration, execution, safety checks - Tool classification and action type mapping - SafeToolExecutor context manager - Factory function create_mcp_wrapper """ from datetime import datetime, timedelta from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from app.services.safety.exceptions import EmergencyStopError from app.services.safety.mcp.integration import ( MCPSafetyWrapper, MCPToolCall, MCPToolResult, SafeToolExecutor, create_mcp_wrapper, ) from app.services.safety.models import ( ActionType, AutonomyLevel, SafetyDecision, ) class TestMCPToolCall: """Tests for MCPToolCall dataclass.""" def test_tool_call_creation(self): """Test creating a tool call.""" call = MCPToolCall( tool_name="file_read", arguments={"path": "/tmp/test.txt"}, # noqa: S108 server_name="file-server", project_id="proj-1", context={"session_id": "sess-1"}, ) assert call.tool_name == "file_read" assert call.arguments == {"path": "/tmp/test.txt"} # noqa: S108 assert call.server_name == "file-server" assert call.project_id == "proj-1" assert call.context == {"session_id": "sess-1"} def test_tool_call_defaults(self): """Test tool call default values.""" call = MCPToolCall( tool_name="test", arguments={}, ) assert call.server_name is None assert call.project_id is None assert call.context == {} class TestMCPToolResult: """Tests for MCPToolResult dataclass.""" def test_tool_result_success(self): """Test creating a successful result.""" result = MCPToolResult( success=True, result={"data": "test"}, safety_decision=SafetyDecision.ALLOW, execution_time_ms=50.0, ) assert result.success is True assert result.result == {"data": "test"} assert result.error is None assert result.safety_decision == SafetyDecision.ALLOW assert result.execution_time_ms == 50.0 def test_tool_result_failure(self): """Test creating a failed result.""" result = MCPToolResult( success=False, error="Permission denied", safety_decision=SafetyDecision.DENY, ) assert result.success is False assert result.error == "Permission denied" assert result.result is None def test_tool_result_with_ids(self): """Test result with approval and checkpoint IDs.""" result = MCPToolResult( success=True, approval_id="approval-123", checkpoint_id="checkpoint-456", ) assert result.approval_id == "approval-123" assert result.checkpoint_id == "checkpoint-456" def test_tool_result_defaults(self): """Test result default values.""" result = MCPToolResult(success=True) assert result.result is None assert result.error is None assert result.safety_decision == SafetyDecision.ALLOW assert result.execution_time_ms == 0.0 assert result.approval_id is None assert result.checkpoint_id is None assert result.metadata == {} class TestMCPSafetyWrapperClassification: """Tests for tool classification.""" def test_classify_file_read(self): """Test classifying file read tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("file_read") == ActionType.FILE_READ assert wrapper._classify_tool("get_file") == ActionType.FILE_READ assert wrapper._classify_tool("list_files") == ActionType.FILE_READ assert wrapper._classify_tool("search_file") == ActionType.FILE_READ def test_classify_file_write(self): """Test classifying file write tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("file_write") == ActionType.FILE_WRITE assert wrapper._classify_tool("create_file") == ActionType.FILE_WRITE assert wrapper._classify_tool("update_file") == ActionType.FILE_WRITE def test_classify_file_delete(self): """Test classifying file delete tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("file_delete") == ActionType.FILE_DELETE assert wrapper._classify_tool("remove_file") == ActionType.FILE_DELETE def test_classify_database_read(self): """Test classifying database read tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("database_query") == ActionType.DATABASE_QUERY assert wrapper._classify_tool("db_read") == ActionType.DATABASE_QUERY assert wrapper._classify_tool("query_database") == ActionType.DATABASE_QUERY def test_classify_database_mutate(self): """Test classifying database mutate tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("database_write") == ActionType.DATABASE_MUTATE assert wrapper._classify_tool("db_update") == ActionType.DATABASE_MUTATE assert wrapper._classify_tool("database_delete") == ActionType.DATABASE_MUTATE def test_classify_shell_command(self): """Test classifying shell command tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("shell_execute") == ActionType.SHELL_COMMAND assert wrapper._classify_tool("exec_command") == ActionType.SHELL_COMMAND assert wrapper._classify_tool("bash_run") == ActionType.SHELL_COMMAND def test_classify_git_operation(self): """Test classifying git tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("git_commit") == ActionType.GIT_OPERATION assert wrapper._classify_tool("git_push") == ActionType.GIT_OPERATION assert wrapper._classify_tool("git_status") == ActionType.GIT_OPERATION def test_classify_network_request(self): """Test classifying network tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("http_get") == ActionType.NETWORK_REQUEST assert wrapper._classify_tool("fetch_url") == ActionType.NETWORK_REQUEST assert wrapper._classify_tool("api_request") == ActionType.NETWORK_REQUEST def test_classify_llm_call(self): """Test classifying LLM tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("llm_generate") == ActionType.LLM_CALL assert wrapper._classify_tool("ai_complete") == ActionType.LLM_CALL assert wrapper._classify_tool("claude_chat") == ActionType.LLM_CALL def test_classify_default(self): """Test default classification for unknown tools.""" wrapper = MCPSafetyWrapper() assert wrapper._classify_tool("unknown_tool") == ActionType.TOOL_CALL assert wrapper._classify_tool("custom_action") == ActionType.TOOL_CALL class TestMCPSafetyWrapperToolHandlers: """Tests for tool handler registration.""" def test_register_tool_handler(self): """Test registering a tool handler.""" wrapper = MCPSafetyWrapper() def handler(path: str) -> str: return f"Read: {path}" wrapper.register_tool_handler("file_read", handler) assert "file_read" in wrapper._tool_handlers assert wrapper._tool_handlers["file_read"] is handler def test_register_multiple_handlers(self): """Test registering multiple handlers.""" wrapper = MCPSafetyWrapper() wrapper.register_tool_handler("tool1", lambda: None) wrapper.register_tool_handler("tool2", lambda: None) wrapper.register_tool_handler("tool3", lambda: None) assert len(wrapper._tool_handlers) == 3 def test_overwrite_handler(self): """Test overwriting a handler.""" wrapper = MCPSafetyWrapper() handler1 = lambda: "first" # noqa: E731 handler2 = lambda: "second" # noqa: E731 wrapper.register_tool_handler("tool", handler1) wrapper.register_tool_handler("tool", handler2) assert wrapper._tool_handlers["tool"] is handler2 class TestMCPSafetyWrapperExecution: """Tests for tool execution.""" @pytest_asyncio.fixture async def mock_guardian(self): """Create a mock SafetyGuardian.""" guardian = AsyncMock() guardian.validate = AsyncMock() return guardian @pytest_asyncio.fixture async def mock_emergency(self): """Create a mock EmergencyControls.""" emergency = AsyncMock() emergency.check_allowed = AsyncMock() return emergency @pytest.mark.asyncio async def test_execute_allowed(self, mock_guardian, mock_emergency): """Test executing an allowed tool call.""" mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) async def handler(path: str) -> dict: return {"content": f"Data from {path}"} wrapper.register_tool_handler("file_read", handler) call = MCPToolCall( tool_name="file_read", arguments={"path": "/test.txt"}, project_id="proj-1", ) result = await wrapper.execute(call, "agent-1") assert result.success is True assert result.result == {"content": "Data from /test.txt"} assert result.safety_decision == SafetyDecision.ALLOW @pytest.mark.asyncio async def test_execute_denied(self, mock_guardian, mock_emergency): """Test executing a denied tool call.""" mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.DENY, reasons=["Permission denied", "Rate limit exceeded"], ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) call = MCPToolCall( tool_name="file_write", arguments={"path": "/etc/passwd"}, ) result = await wrapper.execute(call, "agent-1") assert result.success is False assert "Permission denied" in result.error assert "Rate limit exceeded" in result.error assert result.safety_decision == SafetyDecision.DENY @pytest.mark.asyncio async def test_execute_requires_approval(self, mock_guardian, mock_emergency): """Test executing a tool that requires approval.""" mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.REQUIRE_APPROVAL, reasons=["Destructive operation requires approval"], approval_id="approval-123", ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) call = MCPToolCall( tool_name="file_delete", arguments={"path": "/important.txt"}, ) result = await wrapper.execute(call, "agent-1") assert result.success is False assert result.safety_decision == SafetyDecision.REQUIRE_APPROVAL assert result.approval_id == "approval-123" assert "requires human approval" in result.error @pytest.mark.asyncio async def test_execute_emergency_stop(self, mock_guardian, mock_emergency): """Test execution blocked by emergency stop.""" mock_emergency.check_allowed.side_effect = EmergencyStopError( "Emergency stop active" ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) call = MCPToolCall( tool_name="file_write", arguments={"path": "/test.txt"}, project_id="proj-1", ) result = await wrapper.execute(call, "agent-1") assert result.success is False assert result.safety_decision == SafetyDecision.DENY assert result.metadata.get("emergency_stop") is True @pytest.mark.asyncio async def test_execute_bypass_safety(self, mock_guardian, mock_emergency): """Test executing with safety bypass.""" wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) async def handler(data: str) -> str: return f"Processed: {data}" wrapper.register_tool_handler("custom_tool", handler) call = MCPToolCall( tool_name="custom_tool", arguments={"data": "test"}, ) result = await wrapper.execute(call, "agent-1", bypass_safety=True) assert result.success is True assert result.result == "Processed: test" # Guardian should not be called when bypassing mock_guardian.validate.assert_not_called() @pytest.mark.asyncio async def test_execute_no_handler(self, mock_guardian, mock_emergency): """Test executing a tool with no registered handler.""" mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) call = MCPToolCall( tool_name="unregistered_tool", arguments={}, ) result = await wrapper.execute(call, "agent-1") assert result.success is False assert "No handler registered" in result.error @pytest.mark.asyncio async def test_execute_handler_exception(self, mock_guardian, mock_emergency): """Test handling exceptions from tool handler.""" mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) async def failing_handler() -> None: raise ValueError("Handler failed!") wrapper.register_tool_handler("failing_tool", failing_handler) call = MCPToolCall( tool_name="failing_tool", arguments={}, ) result = await wrapper.execute(call, "agent-1") assert result.success is False assert "Handler failed!" in result.error # Decision is still ALLOW because the safety check passed assert result.safety_decision == SafetyDecision.ALLOW @pytest.mark.asyncio async def test_execute_sync_handler(self, mock_guardian, mock_emergency): """Test executing a synchronous handler.""" mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) def sync_handler(value: int) -> int: return value * 2 wrapper.register_tool_handler("sync_tool", sync_handler) call = MCPToolCall( tool_name="sync_tool", arguments={"value": 21}, ) result = await wrapper.execute(call, "agent-1") assert result.success is True assert result.result == 42 class TestBuildActionRequest: """Tests for _build_action_request.""" def test_build_action_request_basic(self): """Test building a basic action request.""" wrapper = MCPSafetyWrapper() call = MCPToolCall( tool_name="file_read", arguments={"path": "/test.txt"}, project_id="proj-1", ) action = wrapper._build_action_request(call, "agent-1", AutonomyLevel.MILESTONE) assert action.action_type == ActionType.FILE_READ assert action.tool_name == "file_read" assert action.arguments == {"path": "/test.txt"} assert action.resource == "/test.txt" assert action.metadata.agent_id == "agent-1" assert action.metadata.project_id == "proj-1" assert action.metadata.autonomy_level == AutonomyLevel.MILESTONE def test_build_action_request_with_context(self): """Test building action request with session context.""" wrapper = MCPSafetyWrapper() call = MCPToolCall( tool_name="database_query", arguments={"resource": "users", "query": "SELECT *"}, context={"session_id": "sess-123"}, project_id="proj-2", ) action = wrapper._build_action_request( call, "agent-2", AutonomyLevel.AUTONOMOUS ) assert action.resource == "users" assert action.metadata.session_id == "sess-123" assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS def test_build_action_request_no_resource(self): """Test building action request without resource.""" wrapper = MCPSafetyWrapper() call = MCPToolCall( tool_name="llm_generate", arguments={"prompt": "Hello"}, ) action = wrapper._build_action_request( call, "agent-1", AutonomyLevel.FULL_CONTROL ) assert action.resource is None class TestElapsedTime: """Tests for _elapsed_ms helper.""" def test_elapsed_ms(self): """Test calculating elapsed time.""" wrapper = MCPSafetyWrapper() start = datetime.utcnow() - timedelta(milliseconds=100) elapsed = wrapper._elapsed_ms(start) # Should be at least 100ms, but allow some tolerance assert elapsed >= 99 assert elapsed < 200 class TestSafeToolExecutor: """Tests for SafeToolExecutor context manager.""" @pytest.mark.asyncio async def test_executor_execute(self): """Test executing within context manager.""" mock_guardian = AsyncMock() mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) async def handler() -> str: return "success" wrapper.register_tool_handler("test_tool", handler) call = MCPToolCall(tool_name="test_tool", arguments={}) async with SafeToolExecutor(wrapper, call, "agent-1") as executor: result = await executor.execute() assert result.success is True assert result.result == "success" @pytest.mark.asyncio async def test_executor_result_property(self): """Test accessing result via property.""" mock_guardian = AsyncMock() mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) wrapper.register_tool_handler("tool", lambda: "data") call = MCPToolCall(tool_name="tool", arguments={}) executor = SafeToolExecutor(wrapper, call, "agent-1") # Before execution assert executor.result is None async with executor: await executor.execute() # After execution assert executor.result is not None assert executor.result.success is True @pytest.mark.asyncio async def test_executor_with_autonomy_level(self): """Test executor with custom autonomy level.""" mock_guardian = AsyncMock() mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) wrapper.register_tool_handler("tool", lambda: None) call = MCPToolCall(tool_name="tool", arguments={}) async with SafeToolExecutor( wrapper, call, "agent-1", AutonomyLevel.AUTONOMOUS ) as executor: await executor.execute() # Check that guardian was called with correct autonomy level mock_guardian.validate.assert_called_once() action = mock_guardian.validate.call_args[0][0] assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS class TestCreateMCPWrapper: """Tests for create_mcp_wrapper factory function.""" @pytest.mark.asyncio async def test_create_wrapper_with_guardian(self): """Test creating wrapper with provided guardian.""" mock_guardian = AsyncMock() with patch( "app.services.safety.mcp.integration.get_emergency_controls" ) as mock_get_emergency: mock_get_emergency.return_value = AsyncMock() wrapper = await create_mcp_wrapper(guardian=mock_guardian) assert wrapper._guardian is mock_guardian @pytest.mark.asyncio async def test_create_wrapper_default_guardian(self): """Test creating wrapper with default guardian.""" with ( patch( "app.services.safety.mcp.integration.get_safety_guardian" ) as mock_get_guardian, patch( "app.services.safety.mcp.integration.get_emergency_controls" ) as mock_get_emergency, ): mock_guardian = AsyncMock() mock_get_guardian.return_value = mock_guardian mock_get_emergency.return_value = AsyncMock() wrapper = await create_mcp_wrapper() assert wrapper._guardian is mock_guardian mock_get_guardian.assert_called_once() class TestLazyGetters: """Tests for lazy getter methods.""" @pytest.mark.asyncio async def test_get_guardian_lazy(self): """Test lazy guardian initialization.""" wrapper = MCPSafetyWrapper() with patch( "app.services.safety.mcp.integration.get_safety_guardian" ) as mock_get: mock_guardian = AsyncMock() mock_get.return_value = mock_guardian guardian = await wrapper._get_guardian() assert guardian is mock_guardian mock_get.assert_called_once() @pytest.mark.asyncio async def test_get_guardian_cached(self): """Test guardian is cached after first access.""" mock_guardian = AsyncMock() wrapper = MCPSafetyWrapper(guardian=mock_guardian) guardian = await wrapper._get_guardian() assert guardian is mock_guardian @pytest.mark.asyncio async def test_get_emergency_controls_lazy(self): """Test lazy emergency controls initialization.""" wrapper = MCPSafetyWrapper() with patch( "app.services.safety.mcp.integration.get_emergency_controls" ) as mock_get: mock_emergency = AsyncMock() mock_get.return_value = mock_emergency emergency = await wrapper._get_emergency_controls() assert emergency is mock_emergency mock_get.assert_called_once() @pytest.mark.asyncio async def test_get_emergency_controls_cached(self): """Test emergency controls is cached after first access.""" mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper(emergency_controls=mock_emergency) emergency = await wrapper._get_emergency_controls() assert emergency is mock_emergency class TestEdgeCases: """Tests for edge cases and error handling.""" @pytest.mark.asyncio async def test_execute_with_safety_error(self): """Test handling SafetyError from guardian.""" from app.services.safety.exceptions import SafetyError mock_guardian = AsyncMock() mock_guardian.validate.side_effect = SafetyError("Internal safety error") mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) call = MCPToolCall(tool_name="test", arguments={}) result = await wrapper.execute(call, "agent-1") assert result.success is False assert "Internal safety error" in result.error assert result.safety_decision == SafetyDecision.DENY @pytest.mark.asyncio async def test_execute_with_checkpoint_id(self): """Test that checkpoint_id is propagated to result.""" mock_guardian = AsyncMock() mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id="checkpoint-abc", ) mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) wrapper.register_tool_handler("tool", lambda: "result") call = MCPToolCall(tool_name="tool", arguments={}) result = await wrapper.execute(call, "agent-1") assert result.success is True assert result.checkpoint_id == "checkpoint-abc" def test_destructive_tools_constant(self): """Test DESTRUCTIVE_TOOLS class constant.""" assert "file_write" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS assert "file_delete" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS assert "shell_execute" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS assert "git_push" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS def test_read_only_tools_constant(self): """Test READ_ONLY_TOOLS class constant.""" assert "file_read" in MCPSafetyWrapper.READ_ONLY_TOOLS assert "database_query" in MCPSafetyWrapper.READ_ONLY_TOOLS assert "git_status" in MCPSafetyWrapper.READ_ONLY_TOOLS assert "search" in MCPSafetyWrapper.READ_ONLY_TOOLS @pytest.mark.asyncio async def test_scope_with_project_id(self): """Test that scope is set correctly with project_id.""" mock_guardian = AsyncMock() mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) wrapper.register_tool_handler("tool", lambda: None) call = MCPToolCall( tool_name="tool", arguments={}, project_id="proj-123", ) await wrapper.execute(call, "agent-1") # Verify emergency check was called with project scope mock_emergency.check_allowed.assert_called_once() call_kwargs = mock_emergency.check_allowed.call_args assert "project:proj-123" in str(call_kwargs) @pytest.mark.asyncio async def test_scope_without_project_id(self): """Test that scope falls back to agent when no project_id.""" mock_guardian = AsyncMock() mock_guardian.validate.return_value = MagicMock( decision=SafetyDecision.ALLOW, reasons=[], approval_id=None, checkpoint_id=None, ) mock_emergency = AsyncMock() wrapper = MCPSafetyWrapper( guardian=mock_guardian, emergency_controls=mock_emergency, ) wrapper.register_tool_handler("tool", lambda: None) call = MCPToolCall( tool_name="tool", arguments={}, # No project_id ) await wrapper.execute(call, "agent-555") # Verify emergency check was called with agent scope mock_emergency.check_allowed.assert_called_once() call_kwargs = mock_emergency.check_allowed.call_args assert "agent:agent-555" in str(call_kwargs)