forked from cardosofelipe/pragma-stack
test(safety): add comprehensive tests for safety framework modules
Add tests to improve backend coverage from 85% to 93%: - test_audit.py: 60 tests for AuditLogger (20% -> 99%) - Hash chain integrity, sanitization, retention, handlers - Fixed bug: hash chain modification after event creation - Fixed bug: verification not using correct prev_hash - test_hitl.py: Tests for HITL manager (0% -> 100%) - test_permissions.py: Tests for permissions manager (0% -> 99%) - test_rollback.py: Tests for rollback manager (0% -> 100%) - test_metrics.py: Tests for metrics collector (0% -> 100%) - test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%) - test_validation.py: Additional cache and edge case tests (76% -> 100%) - test_scoring.py: Lock cleanup and edge case tests (78% -> 91%) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
874
backend/tests/services/safety/test_mcp_integration.py
Normal file
874
backend/tests/services/safety/test_mcp_integration.py
Normal file
@@ -0,0 +1,874 @@
|
||||
"""
|
||||
Tests for MCP Safety Integration.
|
||||
|
||||
Tests cover:
|
||||
- MCPToolCall and MCPToolResult data structures
|
||||
- MCPSafetyWrapper: tool registration, execution, safety checks
|
||||
- Tool classification and action type mapping
|
||||
- SafeToolExecutor context manager
|
||||
- Factory function create_mcp_wrapper
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.exceptions import EmergencyStopError
|
||||
from app.services.safety.mcp.integration import (
|
||||
MCPSafetyWrapper,
|
||||
MCPToolCall,
|
||||
MCPToolResult,
|
||||
SafeToolExecutor,
|
||||
create_mcp_wrapper,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPToolCall:
|
||||
"""Tests for MCPToolCall dataclass."""
|
||||
|
||||
def test_tool_call_creation(self):
|
||||
"""Test creating a tool call."""
|
||||
call = MCPToolCall(
|
||||
tool_name="file_read",
|
||||
arguments={"path": "/tmp/test.txt"}, # noqa: S108
|
||||
server_name="file-server",
|
||||
project_id="proj-1",
|
||||
context={"session_id": "sess-1"},
|
||||
)
|
||||
|
||||
assert call.tool_name == "file_read"
|
||||
assert call.arguments == {"path": "/tmp/test.txt"} # noqa: S108
|
||||
assert call.server_name == "file-server"
|
||||
assert call.project_id == "proj-1"
|
||||
assert call.context == {"session_id": "sess-1"}
|
||||
|
||||
def test_tool_call_defaults(self):
|
||||
"""Test tool call default values."""
|
||||
call = MCPToolCall(
|
||||
tool_name="test",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
assert call.server_name is None
|
||||
assert call.project_id is None
|
||||
assert call.context == {}
|
||||
|
||||
|
||||
class TestMCPToolResult:
|
||||
"""Tests for MCPToolResult dataclass."""
|
||||
|
||||
def test_tool_result_success(self):
|
||||
"""Test creating a successful result."""
|
||||
result = MCPToolResult(
|
||||
success=True,
|
||||
result={"data": "test"},
|
||||
safety_decision=SafetyDecision.ALLOW,
|
||||
execution_time_ms=50.0,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.result == {"data": "test"}
|
||||
assert result.error is None
|
||||
assert result.safety_decision == SafetyDecision.ALLOW
|
||||
assert result.execution_time_ms == 50.0
|
||||
|
||||
def test_tool_result_failure(self):
|
||||
"""Test creating a failed result."""
|
||||
result = MCPToolResult(
|
||||
success=False,
|
||||
error="Permission denied",
|
||||
safety_decision=SafetyDecision.DENY,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Permission denied"
|
||||
assert result.result is None
|
||||
|
||||
def test_tool_result_with_ids(self):
|
||||
"""Test result with approval and checkpoint IDs."""
|
||||
result = MCPToolResult(
|
||||
success=True,
|
||||
approval_id="approval-123",
|
||||
checkpoint_id="checkpoint-456",
|
||||
)
|
||||
|
||||
assert result.approval_id == "approval-123"
|
||||
assert result.checkpoint_id == "checkpoint-456"
|
||||
|
||||
def test_tool_result_defaults(self):
|
||||
"""Test result default values."""
|
||||
result = MCPToolResult(success=True)
|
||||
|
||||
assert result.result is None
|
||||
assert result.error is None
|
||||
assert result.safety_decision == SafetyDecision.ALLOW
|
||||
assert result.execution_time_ms == 0.0
|
||||
assert result.approval_id is None
|
||||
assert result.checkpoint_id is None
|
||||
assert result.metadata == {}
|
||||
|
||||
|
||||
class TestMCPSafetyWrapperClassification:
|
||||
"""Tests for tool classification."""
|
||||
|
||||
def test_classify_file_read(self):
|
||||
"""Test classifying file read tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("file_read") == ActionType.FILE_READ
|
||||
assert wrapper._classify_tool("get_file") == ActionType.FILE_READ
|
||||
assert wrapper._classify_tool("list_files") == ActionType.FILE_READ
|
||||
assert wrapper._classify_tool("search_file") == ActionType.FILE_READ
|
||||
|
||||
def test_classify_file_write(self):
|
||||
"""Test classifying file write tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("file_write") == ActionType.FILE_WRITE
|
||||
assert wrapper._classify_tool("create_file") == ActionType.FILE_WRITE
|
||||
assert wrapper._classify_tool("update_file") == ActionType.FILE_WRITE
|
||||
|
||||
def test_classify_file_delete(self):
|
||||
"""Test classifying file delete tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("file_delete") == ActionType.FILE_DELETE
|
||||
assert wrapper._classify_tool("remove_file") == ActionType.FILE_DELETE
|
||||
|
||||
def test_classify_database_read(self):
|
||||
"""Test classifying database read tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("database_query") == ActionType.DATABASE_QUERY
|
||||
assert wrapper._classify_tool("db_read") == ActionType.DATABASE_QUERY
|
||||
assert wrapper._classify_tool("query_database") == ActionType.DATABASE_QUERY
|
||||
|
||||
def test_classify_database_mutate(self):
|
||||
"""Test classifying database mutate tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("database_write") == ActionType.DATABASE_MUTATE
|
||||
assert wrapper._classify_tool("db_update") == ActionType.DATABASE_MUTATE
|
||||
assert wrapper._classify_tool("database_delete") == ActionType.DATABASE_MUTATE
|
||||
|
||||
def test_classify_shell_command(self):
|
||||
"""Test classifying shell command tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("shell_execute") == ActionType.SHELL_COMMAND
|
||||
assert wrapper._classify_tool("exec_command") == ActionType.SHELL_COMMAND
|
||||
assert wrapper._classify_tool("bash_run") == ActionType.SHELL_COMMAND
|
||||
|
||||
def test_classify_git_operation(self):
|
||||
"""Test classifying git tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("git_commit") == ActionType.GIT_OPERATION
|
||||
assert wrapper._classify_tool("git_push") == ActionType.GIT_OPERATION
|
||||
assert wrapper._classify_tool("git_status") == ActionType.GIT_OPERATION
|
||||
|
||||
def test_classify_network_request(self):
|
||||
"""Test classifying network tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("http_get") == ActionType.NETWORK_REQUEST
|
||||
assert wrapper._classify_tool("fetch_url") == ActionType.NETWORK_REQUEST
|
||||
assert wrapper._classify_tool("api_request") == ActionType.NETWORK_REQUEST
|
||||
|
||||
def test_classify_llm_call(self):
|
||||
"""Test classifying LLM tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("llm_generate") == ActionType.LLM_CALL
|
||||
assert wrapper._classify_tool("ai_complete") == ActionType.LLM_CALL
|
||||
assert wrapper._classify_tool("claude_chat") == ActionType.LLM_CALL
|
||||
|
||||
def test_classify_default(self):
|
||||
"""Test default classification for unknown tools."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
assert wrapper._classify_tool("unknown_tool") == ActionType.TOOL_CALL
|
||||
assert wrapper._classify_tool("custom_action") == ActionType.TOOL_CALL
|
||||
|
||||
|
||||
class TestMCPSafetyWrapperToolHandlers:
|
||||
"""Tests for tool handler registration."""
|
||||
|
||||
def test_register_tool_handler(self):
|
||||
"""Test registering a tool handler."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
def handler(path: str) -> str:
|
||||
return f"Read: {path}"
|
||||
|
||||
wrapper.register_tool_handler("file_read", handler)
|
||||
|
||||
assert "file_read" in wrapper._tool_handlers
|
||||
assert wrapper._tool_handlers["file_read"] is handler
|
||||
|
||||
def test_register_multiple_handlers(self):
|
||||
"""Test registering multiple handlers."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
wrapper.register_tool_handler("tool1", lambda: None)
|
||||
wrapper.register_tool_handler("tool2", lambda: None)
|
||||
wrapper.register_tool_handler("tool3", lambda: None)
|
||||
|
||||
assert len(wrapper._tool_handlers) == 3
|
||||
|
||||
def test_overwrite_handler(self):
|
||||
"""Test overwriting a handler."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
handler1 = lambda: "first" # noqa: E731
|
||||
handler2 = lambda: "second" # noqa: E731
|
||||
|
||||
wrapper.register_tool_handler("tool", handler1)
|
||||
wrapper.register_tool_handler("tool", handler2)
|
||||
|
||||
assert wrapper._tool_handlers["tool"] is handler2
|
||||
|
||||
|
||||
class TestMCPSafetyWrapperExecution:
|
||||
"""Tests for tool execution."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_guardian(self):
|
||||
"""Create a mock SafetyGuardian."""
|
||||
guardian = AsyncMock()
|
||||
guardian.validate = AsyncMock()
|
||||
return guardian
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_emergency(self):
|
||||
"""Create a mock EmergencyControls."""
|
||||
emergency = AsyncMock()
|
||||
emergency.check_allowed = AsyncMock()
|
||||
return emergency
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_allowed(self, mock_guardian, mock_emergency):
|
||||
"""Test executing an allowed tool call."""
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
async def handler(path: str) -> dict:
|
||||
return {"content": f"Data from {path}"}
|
||||
|
||||
wrapper.register_tool_handler("file_read", handler)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="file_read",
|
||||
arguments={"path": "/test.txt"},
|
||||
project_id="proj-1",
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.result == {"content": "Data from /test.txt"}
|
||||
assert result.safety_decision == SafetyDecision.ALLOW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_denied(self, mock_guardian, mock_emergency):
|
||||
"""Test executing a denied tool call."""
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=["Permission denied", "Rate limit exceeded"],
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="file_write",
|
||||
arguments={"path": "/etc/passwd"},
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is False
|
||||
assert "Permission denied" in result.error
|
||||
assert "Rate limit exceeded" in result.error
|
||||
assert result.safety_decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_requires_approval(self, mock_guardian, mock_emergency):
|
||||
"""Test executing a tool that requires approval."""
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||
reasons=["Destructive operation requires approval"],
|
||||
approval_id="approval-123",
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="file_delete",
|
||||
arguments={"path": "/important.txt"},
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is False
|
||||
assert result.safety_decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
assert result.approval_id == "approval-123"
|
||||
assert "requires human approval" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_emergency_stop(self, mock_guardian, mock_emergency):
|
||||
"""Test execution blocked by emergency stop."""
|
||||
mock_emergency.check_allowed.side_effect = EmergencyStopError(
|
||||
"Emergency stop active"
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="file_write",
|
||||
arguments={"path": "/test.txt"},
|
||||
project_id="proj-1",
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is False
|
||||
assert result.safety_decision == SafetyDecision.DENY
|
||||
assert result.metadata.get("emergency_stop") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_bypass_safety(self, mock_guardian, mock_emergency):
|
||||
"""Test executing with safety bypass."""
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
async def handler(data: str) -> str:
|
||||
return f"Processed: {data}"
|
||||
|
||||
wrapper.register_tool_handler("custom_tool", handler)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="custom_tool",
|
||||
arguments={"data": "test"},
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1", bypass_safety=True)
|
||||
|
||||
assert result.success is True
|
||||
assert result.result == "Processed: test"
|
||||
# Guardian should not be called when bypassing
|
||||
mock_guardian.validate.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_no_handler(self, mock_guardian, mock_emergency):
|
||||
"""Test executing a tool with no registered handler."""
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="unregistered_tool",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is False
|
||||
assert "No handler registered" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handler_exception(self, mock_guardian, mock_emergency):
|
||||
"""Test handling exceptions from tool handler."""
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
async def failing_handler() -> None:
|
||||
raise ValueError("Handler failed!")
|
||||
|
||||
wrapper.register_tool_handler("failing_tool", failing_handler)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="failing_tool",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is False
|
||||
assert "Handler failed!" in result.error
|
||||
# Decision is still ALLOW because the safety check passed
|
||||
assert result.safety_decision == SafetyDecision.ALLOW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sync_handler(self, mock_guardian, mock_emergency):
|
||||
"""Test executing a synchronous handler."""
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
def sync_handler(value: int) -> int:
|
||||
return value * 2
|
||||
|
||||
wrapper.register_tool_handler("sync_tool", sync_handler)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="sync_tool",
|
||||
arguments={"value": 21},
|
||||
)
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.result == 42
|
||||
|
||||
|
||||
class TestBuildActionRequest:
|
||||
"""Tests for _build_action_request."""
|
||||
|
||||
def test_build_action_request_basic(self):
|
||||
"""Test building a basic action request."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="file_read",
|
||||
arguments={"path": "/test.txt"},
|
||||
project_id="proj-1",
|
||||
)
|
||||
|
||||
action = wrapper._build_action_request(call, "agent-1", AutonomyLevel.MILESTONE)
|
||||
|
||||
assert action.action_type == ActionType.FILE_READ
|
||||
assert action.tool_name == "file_read"
|
||||
assert action.arguments == {"path": "/test.txt"}
|
||||
assert action.resource == "/test.txt"
|
||||
assert action.metadata.agent_id == "agent-1"
|
||||
assert action.metadata.project_id == "proj-1"
|
||||
assert action.metadata.autonomy_level == AutonomyLevel.MILESTONE
|
||||
|
||||
def test_build_action_request_with_context(self):
|
||||
"""Test building action request with session context."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="database_query",
|
||||
arguments={"resource": "users", "query": "SELECT *"},
|
||||
context={"session_id": "sess-123"},
|
||||
project_id="proj-2",
|
||||
)
|
||||
|
||||
action = wrapper._build_action_request(
|
||||
call, "agent-2", AutonomyLevel.AUTONOMOUS
|
||||
)
|
||||
|
||||
assert action.resource == "users"
|
||||
assert action.metadata.session_id == "sess-123"
|
||||
assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
|
||||
def test_build_action_request_no_resource(self):
|
||||
"""Test building action request without resource."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="llm_generate",
|
||||
arguments={"prompt": "Hello"},
|
||||
)
|
||||
|
||||
action = wrapper._build_action_request(
|
||||
call, "agent-1", AutonomyLevel.FULL_CONTROL
|
||||
)
|
||||
|
||||
assert action.resource is None
|
||||
|
||||
|
||||
class TestElapsedTime:
|
||||
"""Tests for _elapsed_ms helper."""
|
||||
|
||||
def test_elapsed_ms(self):
|
||||
"""Test calculating elapsed time."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
start = datetime.utcnow() - timedelta(milliseconds=100)
|
||||
elapsed = wrapper._elapsed_ms(start)
|
||||
|
||||
# Should be at least 100ms, but allow some tolerance
|
||||
assert elapsed >= 99
|
||||
assert elapsed < 200
|
||||
|
||||
|
||||
class TestSafeToolExecutor:
|
||||
"""Tests for SafeToolExecutor context manager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_executor_execute(self):
|
||||
"""Test executing within context manager."""
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
async def handler() -> str:
|
||||
return "success"
|
||||
|
||||
wrapper.register_tool_handler("test_tool", handler)
|
||||
|
||||
call = MCPToolCall(tool_name="test_tool", arguments={})
|
||||
|
||||
async with SafeToolExecutor(wrapper, call, "agent-1") as executor:
|
||||
result = await executor.execute()
|
||||
|
||||
assert result.success is True
|
||||
assert result.result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_executor_result_property(self):
|
||||
"""Test accessing result via property."""
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
wrapper.register_tool_handler("tool", lambda: "data")
|
||||
|
||||
call = MCPToolCall(tool_name="tool", arguments={})
|
||||
executor = SafeToolExecutor(wrapper, call, "agent-1")
|
||||
|
||||
# Before execution
|
||||
assert executor.result is None
|
||||
|
||||
async with executor:
|
||||
await executor.execute()
|
||||
|
||||
# After execution
|
||||
assert executor.result is not None
|
||||
assert executor.result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_executor_with_autonomy_level(self):
|
||||
"""Test executor with custom autonomy level."""
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
wrapper.register_tool_handler("tool", lambda: None)
|
||||
|
||||
call = MCPToolCall(tool_name="tool", arguments={})
|
||||
|
||||
async with SafeToolExecutor(
|
||||
wrapper, call, "agent-1", AutonomyLevel.AUTONOMOUS
|
||||
) as executor:
|
||||
await executor.execute()
|
||||
|
||||
# Check that guardian was called with correct autonomy level
|
||||
mock_guardian.validate.assert_called_once()
|
||||
action = mock_guardian.validate.call_args[0][0]
|
||||
assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
|
||||
|
||||
class TestCreateMCPWrapper:
|
||||
"""Tests for create_mcp_wrapper factory function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_wrapper_with_guardian(self):
|
||||
"""Test creating wrapper with provided guardian."""
|
||||
mock_guardian = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"app.services.safety.mcp.integration.get_emergency_controls"
|
||||
) as mock_get_emergency:
|
||||
mock_get_emergency.return_value = AsyncMock()
|
||||
|
||||
wrapper = await create_mcp_wrapper(guardian=mock_guardian)
|
||||
|
||||
assert wrapper._guardian is mock_guardian
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_wrapper_default_guardian(self):
|
||||
"""Test creating wrapper with default guardian."""
|
||||
with (
|
||||
patch(
|
||||
"app.services.safety.mcp.integration.get_safety_guardian"
|
||||
) as mock_get_guardian,
|
||||
patch(
|
||||
"app.services.safety.mcp.integration.get_emergency_controls"
|
||||
) as mock_get_emergency,
|
||||
):
|
||||
mock_guardian = AsyncMock()
|
||||
mock_get_guardian.return_value = mock_guardian
|
||||
mock_get_emergency.return_value = AsyncMock()
|
||||
|
||||
wrapper = await create_mcp_wrapper()
|
||||
|
||||
assert wrapper._guardian is mock_guardian
|
||||
mock_get_guardian.assert_called_once()
|
||||
|
||||
|
||||
class TestLazyGetters:
|
||||
"""Tests for lazy getter methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_guardian_lazy(self):
|
||||
"""Test lazy guardian initialization."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
with patch(
|
||||
"app.services.safety.mcp.integration.get_safety_guardian"
|
||||
) as mock_get:
|
||||
mock_guardian = AsyncMock()
|
||||
mock_get.return_value = mock_guardian
|
||||
|
||||
guardian = await wrapper._get_guardian()
|
||||
|
||||
assert guardian is mock_guardian
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_guardian_cached(self):
|
||||
"""Test guardian is cached after first access."""
|
||||
mock_guardian = AsyncMock()
|
||||
wrapper = MCPSafetyWrapper(guardian=mock_guardian)
|
||||
|
||||
guardian = await wrapper._get_guardian()
|
||||
|
||||
assert guardian is mock_guardian
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_emergency_controls_lazy(self):
|
||||
"""Test lazy emergency controls initialization."""
|
||||
wrapper = MCPSafetyWrapper()
|
||||
|
||||
with patch(
|
||||
"app.services.safety.mcp.integration.get_emergency_controls"
|
||||
) as mock_get:
|
||||
mock_emergency = AsyncMock()
|
||||
mock_get.return_value = mock_emergency
|
||||
|
||||
emergency = await wrapper._get_emergency_controls()
|
||||
|
||||
assert emergency is mock_emergency
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_emergency_controls_cached(self):
|
||||
"""Test emergency controls is cached after first access."""
|
||||
mock_emergency = AsyncMock()
|
||||
wrapper = MCPSafetyWrapper(emergency_controls=mock_emergency)
|
||||
|
||||
emergency = await wrapper._get_emergency_controls()
|
||||
|
||||
assert emergency is mock_emergency
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases and error handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_safety_error(self):
|
||||
"""Test handling SafetyError from guardian."""
|
||||
from app.services.safety.exceptions import SafetyError
|
||||
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.side_effect = SafetyError("Internal safety error")
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
call = MCPToolCall(tool_name="test", arguments={})
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is False
|
||||
assert "Internal safety error" in result.error
|
||||
assert result.safety_decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_checkpoint_id(self):
|
||||
"""Test that checkpoint_id is propagated to result."""
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id="checkpoint-abc",
|
||||
)
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
wrapper.register_tool_handler("tool", lambda: "result")
|
||||
|
||||
call = MCPToolCall(tool_name="tool", arguments={})
|
||||
|
||||
result = await wrapper.execute(call, "agent-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.checkpoint_id == "checkpoint-abc"
|
||||
|
||||
def test_destructive_tools_constant(self):
|
||||
"""Test DESTRUCTIVE_TOOLS class constant."""
|
||||
assert "file_write" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||
assert "file_delete" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||
assert "shell_execute" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||
assert "git_push" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||
|
||||
def test_read_only_tools_constant(self):
|
||||
"""Test READ_ONLY_TOOLS class constant."""
|
||||
assert "file_read" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||
assert "database_query" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||
assert "git_status" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||
assert "search" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scope_with_project_id(self):
|
||||
"""Test that scope is set correctly with project_id."""
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
wrapper.register_tool_handler("tool", lambda: None)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="tool",
|
||||
arguments={},
|
||||
project_id="proj-123",
|
||||
)
|
||||
|
||||
await wrapper.execute(call, "agent-1")
|
||||
|
||||
# Verify emergency check was called with project scope
|
||||
mock_emergency.check_allowed.assert_called_once()
|
||||
call_kwargs = mock_emergency.check_allowed.call_args
|
||||
assert "project:proj-123" in str(call_kwargs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scope_without_project_id(self):
|
||||
"""Test that scope falls back to agent when no project_id."""
|
||||
mock_guardian = AsyncMock()
|
||||
mock_guardian.validate.return_value = MagicMock(
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=[],
|
||||
approval_id=None,
|
||||
checkpoint_id=None,
|
||||
)
|
||||
|
||||
mock_emergency = AsyncMock()
|
||||
|
||||
wrapper = MCPSafetyWrapper(
|
||||
guardian=mock_guardian,
|
||||
emergency_controls=mock_emergency,
|
||||
)
|
||||
|
||||
wrapper.register_tool_handler("tool", lambda: None)
|
||||
|
||||
call = MCPToolCall(
|
||||
tool_name="tool",
|
||||
arguments={},
|
||||
# No project_id
|
||||
)
|
||||
|
||||
await wrapper.execute(call, "agent-555")
|
||||
|
||||
# Verify emergency check was called with agent scope
|
||||
mock_emergency.check_allowed.assert_called_once()
|
||||
call_kwargs = mock_emergency.check_allowed.call_args
|
||||
assert "agent:agent-555" in str(call_kwargs)
|
||||
Reference in New Issue
Block a user