Add tests to improve backend coverage from 85% to 93%: - test_audit.py: 60 tests for AuditLogger (20% -> 99%) - Hash chain integrity, sanitization, retention, handlers - Fixed bug: hash chain modification after event creation - Fixed bug: verification not using correct prev_hash - test_hitl.py: Tests for HITL manager (0% -> 100%) - test_permissions.py: Tests for permissions manager (0% -> 99%) - test_rollback.py: Tests for rollback manager (0% -> 100%) - test_metrics.py: Tests for metrics collector (0% -> 100%) - test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%) - test_validation.py: Additional cache and edge case tests (76% -> 100%) - test_scoring.py: Lock cleanup and edge case tests (78% -> 91%) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
875 lines
28 KiB
Python
875 lines
28 KiB
Python
"""
Tests for MCP Safety Integration.

Tests cover:
- MCPToolCall and MCPToolResult data structures
- MCPSafetyWrapper: tool registration, execution, safety checks
- Tool classification and action type mapping
- SafeToolExecutor context manager
- Factory function create_mcp_wrapper
"""
|
|
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from app.services.safety.exceptions import EmergencyStopError
|
|
from app.services.safety.mcp.integration import (
|
|
MCPSafetyWrapper,
|
|
MCPToolCall,
|
|
MCPToolResult,
|
|
SafeToolExecutor,
|
|
create_mcp_wrapper,
|
|
)
|
|
from app.services.safety.models import (
|
|
ActionType,
|
|
AutonomyLevel,
|
|
SafetyDecision,
|
|
)
|
|
|
|
|
|
class TestMCPToolCall:
    """Construction behaviour of ``MCPToolCall``."""

    def test_tool_call_creation(self):
        """Every explicitly supplied field can be read back unchanged."""
        tool_call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/tmp/test.txt"},  # noqa: S108
            server_name="file-server",
            project_id="proj-1",
            context={"session_id": "sess-1"},
        )

        observed = (
            tool_call.tool_name,
            tool_call.arguments,
            tool_call.server_name,
            tool_call.project_id,
            tool_call.context,
        )

        assert observed == (
            "file_read",
            {"path": "/tmp/test.txt"},  # noqa: S108
            "file-server",
            "proj-1",
            {"session_id": "sess-1"},
        )

    def test_tool_call_defaults(self):
        """Fields that are not supplied fall back to ``None`` or an empty dict."""
        minimal_call = MCPToolCall(tool_name="test", arguments={})

        assert minimal_call.server_name is None
        assert minimal_call.project_id is None
        assert minimal_call.context == {}
|
|
|
|
|
|
class TestMCPToolResult:
    """Construction behaviour of ``MCPToolResult``."""

    def test_tool_result_success(self):
        """A successful result keeps its payload, decision and timing."""
        outcome = MCPToolResult(
            success=True,
            result={"data": "test"},
            safety_decision=SafetyDecision.ALLOW,
            execution_time_ms=50.0,
        )

        assert outcome.success is True
        assert outcome.error is None
        assert outcome.result == {"data": "test"}
        assert outcome.safety_decision == SafetyDecision.ALLOW
        assert outcome.execution_time_ms == 50.0

    def test_tool_result_failure(self):
        """A failed result carries the error text and no payload."""
        outcome = MCPToolResult(
            success=False,
            error="Permission denied",
            safety_decision=SafetyDecision.DENY,
        )

        assert outcome.success is False
        assert outcome.result is None
        assert outcome.error == "Permission denied"

    def test_tool_result_with_ids(self):
        """Approval and checkpoint identifiers are stored as given."""
        outcome = MCPToolResult(
            success=True,
            approval_id="approval-123",
            checkpoint_id="checkpoint-456",
        )

        identifiers = (outcome.approval_id, outcome.checkpoint_id)
        assert identifiers == ("approval-123", "checkpoint-456")

    def test_tool_result_defaults(self):
        """Only ``success`` is supplied; all other fields use their defaults."""
        outcome = MCPToolResult(success=True)

        # Optional fields that default to None.
        assert outcome.result is None
        assert outcome.error is None
        assert outcome.approval_id is None
        assert outcome.checkpoint_id is None

        # Fields with non-None defaults.
        assert outcome.safety_decision == SafetyDecision.ALLOW
        assert outcome.execution_time_ms == 0.0
        assert outcome.metadata == {}
|
|
|
|
|
|
class TestMCPSafetyWrapperClassification:
    """Mapping of tool names to ``ActionType`` by ``_classify_tool``."""

    @staticmethod
    def _assert_classified(tool_names, expected_type):
        """Assert every name in *tool_names* classifies as *expected_type*.

        A single fresh ``MCPSafetyWrapper`` is used for all names, mirroring
        one wrapper per test.
        """
        wrapper = MCPSafetyWrapper()
        for tool_name in tool_names:
            assert wrapper._classify_tool(tool_name) == expected_type

    def test_classify_file_read(self):
        """File read style names map to ``FILE_READ``."""
        self._assert_classified(
            ("file_read", "get_file", "list_files", "search_file"),
            ActionType.FILE_READ,
        )

    def test_classify_file_write(self):
        """File write style names map to ``FILE_WRITE``."""
        self._assert_classified(
            ("file_write", "create_file", "update_file"),
            ActionType.FILE_WRITE,
        )

    def test_classify_file_delete(self):
        """File delete style names map to ``FILE_DELETE``."""
        self._assert_classified(
            ("file_delete", "remove_file"),
            ActionType.FILE_DELETE,
        )

    def test_classify_database_read(self):
        """Database read style names map to ``DATABASE_QUERY``."""
        self._assert_classified(
            ("database_query", "db_read", "query_database"),
            ActionType.DATABASE_QUERY,
        )

    def test_classify_database_mutate(self):
        """Database write style names map to ``DATABASE_MUTATE``."""
        self._assert_classified(
            ("database_write", "db_update", "database_delete"),
            ActionType.DATABASE_MUTATE,
        )

    def test_classify_shell_command(self):
        """Shell style names map to ``SHELL_COMMAND``."""
        self._assert_classified(
            ("shell_execute", "exec_command", "bash_run"),
            ActionType.SHELL_COMMAND,
        )

    def test_classify_git_operation(self):
        """Git style names map to ``GIT_OPERATION``."""
        self._assert_classified(
            ("git_commit", "git_push", "git_status"),
            ActionType.GIT_OPERATION,
        )

    def test_classify_network_request(self):
        """Network style names map to ``NETWORK_REQUEST``."""
        self._assert_classified(
            ("http_get", "fetch_url", "api_request"),
            ActionType.NETWORK_REQUEST,
        )

    def test_classify_llm_call(self):
        """LLM style names map to ``LLM_CALL``."""
        self._assert_classified(
            ("llm_generate", "ai_complete", "claude_chat"),
            ActionType.LLM_CALL,
        )

    def test_classify_default(self):
        """Names matching no known category map to ``TOOL_CALL``."""
        self._assert_classified(
            ("unknown_tool", "custom_action"),
            ActionType.TOOL_CALL,
        )
|
|
|
|
|
|
class TestMCPSafetyWrapperToolHandlers:
    """Registration of tool handlers on ``MCPSafetyWrapper``."""

    def test_register_tool_handler(self):
        """A registered handler is stored under its tool name."""
        wrapper = MCPSafetyWrapper()

        def read_handler(path: str) -> str:
            return f"Read: {path}"

        wrapper.register_tool_handler("file_read", read_handler)

        registry = wrapper._tool_handlers
        assert "file_read" in registry
        assert registry["file_read"] is read_handler

    def test_register_multiple_handlers(self):
        """Handlers registered under distinct names are all retained."""
        wrapper = MCPSafetyWrapper()

        for tool_name in ("tool1", "tool2", "tool3"):
            wrapper.register_tool_handler(tool_name, lambda: None)

        assert len(wrapper._tool_handlers) == 3

    def test_overwrite_handler(self):
        """Registering the same name twice keeps only the later handler."""
        wrapper = MCPSafetyWrapper()

        def original_handler():
            return "first"

        def replacement_handler():
            return "second"

        wrapper.register_tool_handler("tool", original_handler)
        wrapper.register_tool_handler("tool", replacement_handler)

        assert wrapper._tool_handlers["tool"] is replacement_handler
|
|
|
|
|
|
class TestMCPSafetyWrapperExecution:
    """Behaviour of ``MCPSafetyWrapper.execute`` with mocked collaborators."""

    @staticmethod
    def _allow_verdict():
        """Return a mocked guardian verdict that permits the action."""
        return MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=None,
        )

    @staticmethod
    def _make_wrapper(guardian, emergency):
        """Build a wrapper wired to the supplied guardian and emergency mocks."""
        return MCPSafetyWrapper(
            guardian=guardian,
            emergency_controls=emergency,
        )

    @pytest_asyncio.fixture
    async def mock_guardian(self):
        """Provide an ``AsyncMock`` standing in for the SafetyGuardian."""
        return AsyncMock(validate=AsyncMock())

    @pytest_asyncio.fixture
    async def mock_emergency(self):
        """Provide an ``AsyncMock`` standing in for EmergencyControls."""
        return AsyncMock(check_allowed=AsyncMock())

    @pytest.mark.asyncio
    async def test_execute_allowed(self, mock_guardian, mock_emergency):
        """An allowed call runs the async handler and returns its value."""
        mock_guardian.validate.return_value = self._allow_verdict()
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)

        async def read_file(path: str) -> dict:
            return {"content": f"Data from {path}"}

        wrapper.register_tool_handler("file_read", read_file)
        tool_call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is True
        assert outcome.result == {"content": "Data from /test.txt"}
        assert outcome.safety_decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_execute_denied(self, mock_guardian, mock_emergency):
        """A DENY verdict fails the call and surfaces every reason."""
        mock_guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.DENY,
            reasons=["Permission denied", "Rate limit exceeded"],
        )
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)
        tool_call = MCPToolCall(
            tool_name="file_write",
            arguments={"path": "/etc/passwd"},
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is False
        assert "Permission denied" in outcome.error
        assert "Rate limit exceeded" in outcome.error
        assert outcome.safety_decision == SafetyDecision.DENY

    @pytest.mark.asyncio
    async def test_execute_requires_approval(self, mock_guardian, mock_emergency):
        """A REQUIRE_APPROVAL verdict fails the call and exposes the approval id."""
        mock_guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.REQUIRE_APPROVAL,
            reasons=["Destructive operation requires approval"],
            approval_id="approval-123",
        )
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)
        tool_call = MCPToolCall(
            tool_name="file_delete",
            arguments={"path": "/important.txt"},
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is False
        assert outcome.safety_decision == SafetyDecision.REQUIRE_APPROVAL
        assert outcome.approval_id == "approval-123"
        assert "requires human approval" in outcome.error

    @pytest.mark.asyncio
    async def test_execute_emergency_stop(self, mock_guardian, mock_emergency):
        """An ``EmergencyStopError`` from the emergency check yields a DENY result."""
        mock_emergency.check_allowed.side_effect = EmergencyStopError(
            "Emergency stop active"
        )
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)
        tool_call = MCPToolCall(
            tool_name="file_write",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is False
        assert outcome.safety_decision == SafetyDecision.DENY
        assert outcome.metadata.get("emergency_stop") is True

    @pytest.mark.asyncio
    async def test_execute_bypass_safety(self, mock_guardian, mock_emergency):
        """With ``bypass_safety=True`` the handler runs without guardian validation."""
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)

        async def process(data: str) -> str:
            return f"Processed: {data}"

        wrapper.register_tool_handler("custom_tool", process)
        tool_call = MCPToolCall(
            tool_name="custom_tool",
            arguments={"data": "test"},
        )

        outcome = await wrapper.execute(tool_call, "agent-1", bypass_safety=True)

        assert outcome.success is True
        assert outcome.result == "Processed: test"
        # The guardian must be skipped entirely on the bypass path.
        mock_guardian.validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_execute_no_handler(self, mock_guardian, mock_emergency):
        """An allowed call to an unregistered tool fails with a clear error."""
        mock_guardian.validate.return_value = self._allow_verdict()
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)
        tool_call = MCPToolCall(
            tool_name="unregistered_tool",
            arguments={},
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is False
        assert "No handler registered" in outcome.error

    @pytest.mark.asyncio
    async def test_execute_handler_exception(self, mock_guardian, mock_emergency):
        """An exception raised by the handler is reported in the result."""
        mock_guardian.validate.return_value = self._allow_verdict()
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)

        async def always_fails() -> None:
            raise ValueError("Handler failed!")

        wrapper.register_tool_handler("failing_tool", always_fails)
        tool_call = MCPToolCall(
            tool_name="failing_tool",
            arguments={},
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is False
        assert "Handler failed!" in outcome.error
        # The safety check itself passed, so the decision stays ALLOW.
        assert outcome.safety_decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_execute_sync_handler(self, mock_guardian, mock_emergency):
        """A plain (non-async) handler is also executed and its value returned."""
        mock_guardian.validate.return_value = self._allow_verdict()
        wrapper = self._make_wrapper(mock_guardian, mock_emergency)

        def double(value: int) -> int:
            return value * 2

        wrapper.register_tool_handler("sync_tool", double)
        tool_call = MCPToolCall(
            tool_name="sync_tool",
            arguments={"value": 21},
        )

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is True
        assert outcome.result == 42
|
|
|
|
|
|
class TestBuildActionRequest:
    """Translation of an ``MCPToolCall`` by ``_build_action_request``."""

    def test_build_action_request_basic(self):
        """Tool name, arguments, resource and metadata are carried over."""
        tool_call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        request = MCPSafetyWrapper()._build_action_request(
            tool_call, "agent-1", AutonomyLevel.MILESTONE
        )

        assert request.action_type == ActionType.FILE_READ
        assert request.tool_name == "file_read"
        assert request.arguments == {"path": "/test.txt"}
        # The "path" argument is what ends up as the resource here.
        assert request.resource == "/test.txt"

        meta = request.metadata
        assert meta.agent_id == "agent-1"
        assert meta.project_id == "proj-1"
        assert meta.autonomy_level == AutonomyLevel.MILESTONE

    def test_build_action_request_with_context(self):
        """The session id from the call context reaches the request metadata."""
        tool_call = MCPToolCall(
            tool_name="database_query",
            arguments={"resource": "users", "query": "SELECT *"},
            context={"session_id": "sess-123"},
            project_id="proj-2",
        )

        request = MCPSafetyWrapper()._build_action_request(
            tool_call, "agent-2", AutonomyLevel.AUTONOMOUS
        )

        assert request.resource == "users"
        assert request.metadata.session_id == "sess-123"
        assert request.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS

    def test_build_action_request_no_resource(self):
        """Arguments holding only a prompt leave the resource as ``None``."""
        tool_call = MCPToolCall(
            tool_name="llm_generate",
            arguments={"prompt": "Hello"},
        )

        request = MCPSafetyWrapper()._build_action_request(
            tool_call, "agent-1", AutonomyLevel.FULL_CONTROL
        )

        assert request.resource is None
|
|
|
|
|
|
class TestElapsedTime:
    """Behaviour of the ``_elapsed_ms`` helper."""

    def test_elapsed_ms(self):
        """A start time 100 ms in the past yields roughly 100 ms elapsed."""
        wrapper = MCPSafetyWrapper()

        # NOTE(review): a naive UTC timestamp is passed; presumably
        # _elapsed_ms compares against a naive UTC "now" - confirm before
        # switching to timezone-aware datetimes.
        started_at = datetime.utcnow() - timedelta(milliseconds=100)

        elapsed = wrapper._elapsed_ms(started_at)

        # At least ~100 ms must be reported; the upper bound leaves slack
        # for the time taken by the call itself.
        assert 99 <= elapsed < 200
|
|
|
|
|
|
class TestSafeToolExecutor:
    """Use of ``SafeToolExecutor`` as an async context manager."""

    @staticmethod
    def _allowing_wrapper():
        """Return ``(wrapper, guardian)`` where the guardian allows every call.

        The emergency controls are a bare ``AsyncMock``.
        """
        guardian = AsyncMock()
        guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=None,
        )
        emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(
            guardian=guardian,
            emergency_controls=emergency,
        )
        return wrapper, guardian

    @pytest.mark.asyncio
    async def test_executor_execute(self):
        """``execute()`` inside the context returns the handler's result."""
        wrapper, _ = self._allowing_wrapper()

        async def succeed() -> str:
            return "success"

        wrapper.register_tool_handler("test_tool", succeed)
        tool_call = MCPToolCall(tool_name="test_tool", arguments={})

        async with SafeToolExecutor(wrapper, tool_call, "agent-1") as executor:
            outcome = await executor.execute()

        assert outcome.success is True
        assert outcome.result == "success"

    @pytest.mark.asyncio
    async def test_executor_result_property(self):
        """``result`` is ``None`` before execution and populated afterwards."""
        wrapper, _ = self._allowing_wrapper()
        wrapper.register_tool_handler("tool", lambda: "data")
        tool_call = MCPToolCall(tool_name="tool", arguments={})

        executor = SafeToolExecutor(wrapper, tool_call, "agent-1")

        # Nothing has run yet.
        assert executor.result is None

        async with executor:
            await executor.execute()

        # The executor now exposes the outcome of the run.
        assert executor.result is not None
        assert executor.result.success is True

    @pytest.mark.asyncio
    async def test_executor_with_autonomy_level(self):
        """The autonomy level given to the executor reaches the guardian."""
        wrapper, guardian = self._allowing_wrapper()
        wrapper.register_tool_handler("tool", lambda: None)
        tool_call = MCPToolCall(tool_name="tool", arguments={})

        async with SafeToolExecutor(
            wrapper, tool_call, "agent-1", AutonomyLevel.AUTONOMOUS
        ) as executor:
            await executor.execute()

        # The first positional argument of validate() is the action request.
        guardian.validate.assert_called_once()
        validated_action = guardian.validate.call_args.args[0]
        assert validated_action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
|
|
|
|
|
class TestCreateMCPWrapper:
    """Behaviour of the ``create_mcp_wrapper`` factory."""

    # Patch targets: the factory names as looked up in the integration module.
    _GUARDIAN_FACTORY = "app.services.safety.mcp.integration.get_safety_guardian"
    _EMERGENCY_FACTORY = "app.services.safety.mcp.integration.get_emergency_controls"

    @pytest.mark.asyncio
    async def test_create_wrapper_with_guardian(self):
        """A guardian passed to the factory is the one the wrapper holds."""
        supplied_guardian = AsyncMock()

        with patch(self._EMERGENCY_FACTORY) as emergency_factory:
            emergency_factory.return_value = AsyncMock()

            wrapper = await create_mcp_wrapper(guardian=supplied_guardian)

            assert wrapper._guardian is supplied_guardian

    @pytest.mark.asyncio
    async def test_create_wrapper_default_guardian(self):
        """Without a guardian argument the factory obtains one itself."""
        with patch(self._GUARDIAN_FACTORY) as guardian_factory:
            with patch(self._EMERGENCY_FACTORY) as emergency_factory:
                default_guardian = AsyncMock()
                guardian_factory.return_value = default_guardian
                emergency_factory.return_value = AsyncMock()

                wrapper = await create_mcp_wrapper()

                assert wrapper._guardian is default_guardian
                guardian_factory.assert_called_once()
|
|
|
|
|
|
class TestLazyGetters:
    """Tests for the lazy ``_get_guardian`` / ``_get_emergency_controls`` getters.

    The "lazy" tests check that a wrapper built without a collaborator obtains
    one from the module-level factory. The "cached" tests check that a wrapper
    built with a collaborator keeps returning that object on repeated access
    and never calls the factory.
    """

    # Patch targets: the factory names as looked up in the integration module.
    _GUARDIAN_FACTORY = "app.services.safety.mcp.integration.get_safety_guardian"
    _EMERGENCY_FACTORY = "app.services.safety.mcp.integration.get_emergency_controls"

    @pytest.mark.asyncio
    async def test_get_guardian_lazy(self):
        """Test lazy guardian initialization."""
        wrapper = MCPSafetyWrapper()

        with patch(self._GUARDIAN_FACTORY) as mock_get:
            mock_guardian = AsyncMock()
            mock_get.return_value = mock_guardian

            guardian = await wrapper._get_guardian()

            assert guardian is mock_guardian
            mock_get.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_guardian_cached(self):
        """Test an injected guardian is reused and the factory is never called."""
        mock_guardian = AsyncMock()
        wrapper = MCPSafetyWrapper(guardian=mock_guardian)

        # Patch the factory so any fallback to it is detectable. Checking the
        # return value of a single call alone would not prove the factory
        # was skipped.
        with patch(self._GUARDIAN_FACTORY) as mock_get:
            first = await wrapper._get_guardian()
            second = await wrapper._get_guardian()

            assert first is mock_guardian
            assert second is mock_guardian
            mock_get.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_emergency_controls_lazy(self):
        """Test lazy emergency controls initialization."""
        wrapper = MCPSafetyWrapper()

        with patch(self._EMERGENCY_FACTORY) as mock_get:
            mock_emergency = AsyncMock()
            mock_get.return_value = mock_emergency

            emergency = await wrapper._get_emergency_controls()

            assert emergency is mock_emergency
            mock_get.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_emergency_controls_cached(self):
        """Test injected emergency controls are reused without the factory."""
        mock_emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(emergency_controls=mock_emergency)

        # Same reasoning as the guardian test: the factory is patched so a
        # fallback call cannot go unnoticed.
        with patch(self._EMERGENCY_FACTORY) as mock_get:
            first = await wrapper._get_emergency_controls()
            second = await wrapper._get_emergency_controls()

            assert first is mock_emergency
            assert second is mock_emergency
            mock_get.assert_not_called()
|
|
|
|
|
|
class TestEdgeCases:
    """Error paths, class constants and emergency-scope selection."""

    @staticmethod
    def _allowing_wrapper(checkpoint_id=None):
        """Return ``(wrapper, emergency)`` with a guardian that allows every call.

        ``checkpoint_id`` is placed on the mocked verdict so tests can check
        that it is propagated.
        """
        guardian = AsyncMock()
        guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=checkpoint_id,
        )
        emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(
            guardian=guardian,
            emergency_controls=emergency,
        )
        return wrapper, emergency

    @pytest.mark.asyncio
    async def test_execute_with_safety_error(self):
        """A ``SafetyError`` raised by the guardian becomes a DENY result."""
        from app.services.safety.exceptions import SafetyError

        guardian = AsyncMock()
        guardian.validate.side_effect = SafetyError("Internal safety error")
        emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(
            guardian=guardian,
            emergency_controls=emergency,
        )
        tool_call = MCPToolCall(tool_name="test", arguments={})

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is False
        assert "Internal safety error" in outcome.error
        assert outcome.safety_decision == SafetyDecision.DENY

    @pytest.mark.asyncio
    async def test_execute_with_checkpoint_id(self):
        """The checkpoint id on the guardian verdict appears on the result."""
        wrapper, _ = self._allowing_wrapper(checkpoint_id="checkpoint-abc")
        wrapper.register_tool_handler("tool", lambda: "result")
        tool_call = MCPToolCall(tool_name="tool", arguments={})

        outcome = await wrapper.execute(tool_call, "agent-1")

        assert outcome.success is True
        assert outcome.checkpoint_id == "checkpoint-abc"

    def test_destructive_tools_constant(self):
        """``DESTRUCTIVE_TOOLS`` contains the expected tool names."""
        for tool_name in ("file_write", "file_delete", "shell_execute", "git_push"):
            assert tool_name in MCPSafetyWrapper.DESTRUCTIVE_TOOLS

    def test_read_only_tools_constant(self):
        """``READ_ONLY_TOOLS`` contains the expected tool names."""
        for tool_name in ("file_read", "database_query", "git_status", "search"):
            assert tool_name in MCPSafetyWrapper.READ_ONLY_TOOLS

    @pytest.mark.asyncio
    async def test_scope_with_project_id(self):
        """With a project id, the emergency check is made with a project scope."""
        wrapper, emergency = self._allowing_wrapper()
        wrapper.register_tool_handler("tool", lambda: None)
        tool_call = MCPToolCall(
            tool_name="tool",
            arguments={},
            project_id="proj-123",
        )

        await wrapper.execute(tool_call, "agent-1")

        emergency.check_allowed.assert_called_once()
        # The scope is matched against the repr of the recorded call, so it
        # does not matter whether it was passed positionally or by keyword.
        recorded_call = emergency.check_allowed.call_args
        assert "project:proj-123" in str(recorded_call)

    @pytest.mark.asyncio
    async def test_scope_without_project_id(self):
        """Without a project id, the emergency check uses an agent scope."""
        wrapper, emergency = self._allowing_wrapper()
        wrapper.register_tool_handler("tool", lambda: None)
        # project_id is deliberately omitted.
        tool_call = MCPToolCall(
            tool_name="tool",
            arguments={},
        )

        await wrapper.execute(tool_call, "agent-555")

        emergency.check_allowed.assert_called_once()
        # Matched against the repr of the recorded call, as above.
        recorded_call = emergency.check_allowed.call_args
        assert "agent:agent-555" in str(recorded_call)
|