Files
syndarix/backend/app/services/safety/mcp/integration.py
Felipe Cardoso 520c06175e refactor(safety): apply consistent formatting across services and tests
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
2026-01-03 16:23:39 +01:00

410 lines
12 KiB
Python

"""
MCP Safety Integration
Provides safety-aware wrappers for MCP tool execution.
"""
import asyncio
import inspect
import logging
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, ClassVar, TypeVar

from ..audit import AuditLogger
from ..emergency import EmergencyControls, get_emergency_controls
from ..exceptions import (
    EmergencyStopError,
    SafetyError,
)
from ..guardian import SafetyGuardian, get_safety_guardian
from ..models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    AutonomyLevel,
    SafetyDecision,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic type variable. NOTE(review): not referenced anywhere else in this
# module as shown; confirm whether it is still needed.
T = TypeVar("T")
@dataclass
class MCPToolCall:
"""Represents an MCP tool call."""
tool_name: str
arguments: dict[str, Any]
server_name: str | None = None
project_id: str | None = None
context: dict[str, Any] = field(default_factory=dict)
@dataclass
class MCPToolResult:
    """
    Outcome of one MCP tool call processed by the safety wrapper.

    Only ``success`` is required; every other field has a default so that
    early-exit paths (emergency stop, denial, approval needed) can fill in
    just what applies.
    """

    # True when the tool handler ran and returned without raising.
    success: bool
    # Value returned by the handler, if any.
    result: Any = field(default=None)
    # Human-readable failure reason, if any.
    error: str | None = field(default=None)
    # Safety decision that governed this call.
    safety_decision: SafetyDecision = field(default=SafetyDecision.ALLOW)
    # Wall-clock duration of the call in milliseconds.
    execution_time_ms: float = field(default=0.0)
    # Identifier of a pending approval request, if one was created.
    approval_id: str | None = field(default=None)
    # Identifier of a checkpoint taken before execution, if any.
    checkpoint_id: str | None = field(default=None)
    # Free-form extra information (e.g. {"emergency_stop": True}).
    metadata: dict[str, Any] = field(default_factory=dict)
class MCPSafetyWrapper:
    """
    Wraps MCP tool execution with safety checks.

    Each call to ``execute`` runs, in order:

    1. Emergency-stop checks for the agent scope and, when the call carries a
       ``project_id``, the project scope. These are never skipped, not even
       with ``bypass_safety``.
    2. ``SafetyGuardian.validate`` on an ActionRequest built from the call
       (skipped when ``bypass_safety`` is True). Which policies apply
       (permissions, budgets, rate limits, checkpoints) is decided by the
       guardian, not here.
    3. The registered handler for the tool, sync or async.

    Emergency stops, ``SafetyError`` from validation and exceptions raised by
    the handler are reported through the returned MCPToolResult rather than
    propagated.

    NOTE(review): ``audit_logger`` is stored but not used by this class as
    written; no audit records are emitted here.
    """

    # Tool names known to be destructive / read-only.
    # NOTE(review): these sets are not consulted by _classify_tool, which is
    # purely keyword based.
    DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
        "file_write",
        "file_delete",
        "database_mutate",
        "shell_execute",
        "git_push",
        "git_commit",
        "deploy",
    }
    READ_ONLY_TOOLS: ClassVar[set[str]] = {
        "file_read",
        "database_query",
        "git_status",
        "git_log",
        "list_files",
        "search",
    }

    def __init__(
        self,
        guardian: SafetyGuardian | None = None,
        audit_logger: AuditLogger | None = None,
        emergency_controls: EmergencyControls | None = None,
    ) -> None:
        """
        Initialize MCPSafetyWrapper.

        Args:
            guardian: SafetyGuardian instance; fetched lazily via
                get_safety_guardian() on first use when not provided.
            audit_logger: AuditLogger instance (stored only).
            emergency_controls: EmergencyControls instance; fetched lazily via
                get_emergency_controls() on first use when not provided.
        """
        self._guardian = guardian
        self._audit_logger = audit_logger
        self._emergency_controls = emergency_controls
        self._tool_handlers: dict[str, Callable[..., Any]] = {}
        # NOTE(review): not used by any method of this class as shown.
        self._lock = asyncio.Lock()

    async def _get_guardian(self) -> SafetyGuardian:
        """Return the configured SafetyGuardian, fetching it on first use."""
        if self._guardian is None:
            self._guardian = await get_safety_guardian()
        return self._guardian

    async def _get_emergency_controls(self) -> EmergencyControls:
        """Return the configured EmergencyControls, fetching them on first use."""
        if self._emergency_controls is None:
            self._emergency_controls = await get_emergency_controls()
        return self._emergency_controls

    def register_tool_handler(
        self,
        tool_name: str,
        handler: Callable[..., Any],
    ) -> None:
        """
        Register a handler for a tool, replacing any previous one.

        Args:
            tool_name: Name of the tool
            handler: Callable invoked with the call's arguments as keyword
                arguments. It may be a plain function or return an awaitable
                (async function, partial of one, object with async __call__).
        """
        self._tool_handlers[tool_name] = handler
        logger.debug("Registered handler for tool: %s", tool_name)

    @staticmethod
    def _emergency_scopes(tool_call: MCPToolCall, agent_id: str) -> list[str]:
        """Return every emergency-control scope that must allow this call."""
        scopes = [f"agent:{agent_id}"]
        if tool_call.project_id:
            scopes.append(f"project:{tool_call.project_id}")
        return scopes

    async def execute(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
        bypass_safety: bool = False,
    ) -> MCPToolResult:
        """
        Execute an MCP tool call with safety checks.

        Args:
            tool_call: The tool call to execute
            agent_id: ID of the calling agent
            autonomy_level: Agent's autonomy level
            bypass_safety: Skip SafetyGuardian validation (emergency only).
                Emergency-stop checks still apply.

        Returns:
            MCPToolResult with execution outcome. ``success`` is False for
            emergency stops, denials, pending approvals and handler failures.
        """
        start_time = datetime.now(timezone.utc)

        # Emergency controls come first and cannot be bypassed. The agent
        # scope is always checked, and the project scope is checked
        # additionally, so an agent-level stop also covers project-scoped calls.
        emergency = await self._get_emergency_controls()
        try:
            for scope in self._emergency_scopes(tool_call, agent_id):
                await emergency.check_allowed(scope=scope, raise_if_blocked=True)
        except EmergencyStopError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
                metadata={"emergency_stop": True},
            )

        action = self._build_action_request(
            tool_call=tool_call,
            agent_id=agent_id,
            autonomy_level=autonomy_level,
        )

        if bypass_safety:
            logger.warning(
                "Safety bypass enabled for tool: %s (agent: %s)",
                tool_call.tool_name,
                agent_id,
            )
            return await self._execute_tool(tool_call, action, start_time)

        guardian = await self._get_guardian()
        try:
            guardian_result = await guardian.validate(action)
        except SafetyError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        if guardian_result.decision == SafetyDecision.DENY:
            reasons = "; ".join(guardian_result.reasons)
            return MCPToolResult(
                success=False,
                # Never report a denial with an empty error string.
                error=reasons or "Action denied by safety guardian",
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
            # The approval flow itself is the caller's responsibility.
            return MCPToolResult(
                success=False,
                error="Action requires human approval",
                safety_decision=SafetyDecision.REQUIRE_APPROVAL,
                approval_id=guardian_result.approval_id,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # NOTE(review): any decision other than DENY / REQUIRE_APPROVAL
        # proceeds to execution; confirm this holds for every SafetyDecision.
        return await self._execute_tool(
            tool_call,
            action,
            start_time,
            checkpoint_id=guardian_result.checkpoint_id,
        )

    async def _execute_tool(
        self,
        tool_call: MCPToolCall,
        action: ActionRequest,
        start_time: datetime,
        checkpoint_id: str | None = None,
    ) -> MCPToolResult:
        """
        Run the registered handler for ``tool_call`` and wrap its outcome.

        Args:
            tool_call: The call whose ``arguments`` are passed as kwargs.
            action: The ActionRequest built for the call (not used here).
            start_time: Start of the overall call, for timing.
            checkpoint_id: Checkpoint to report in the result, if any.

        Returns:
            A successful MCPToolResult with the handler's return value, or a
            failed one when no handler is registered or the handler raises.
        """
        handler = self._tool_handlers.get(tool_call.tool_name)
        if handler is None:
            return MCPToolResult(
                success=False,
                error=f"No handler registered for tool: {tool_call.tool_name}",
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        try:
            result = handler(**tool_call.arguments)
            # Await whatever the handler hands back if it is awaitable. This
            # also covers partials of async functions and objects with an
            # async __call__, which a coroutine-function check would miss.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            logger.exception(
                "Tool execution failed: %s - %s", tool_call.tool_name, e
            )
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )

        return MCPToolResult(
            success=True,
            result=result,
            safety_decision=SafetyDecision.ALLOW,
            execution_time_ms=self._elapsed_ms(start_time),
            checkpoint_id=checkpoint_id,
        )

    def _build_action_request(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel,
    ) -> ActionRequest:
        """
        Build an ActionRequest from an MCP tool call.

        The resource is taken from the ``path`` argument, falling back to the
        ``resource`` argument (None when neither is present). ``session_id``
        comes from ``tool_call.context``, and missing IDs become "".
        """
        action_type = self._classify_tool(tool_call.tool_name)

        metadata = ActionMetadata(
            agent_id=agent_id,
            session_id=tool_call.context.get("session_id", ""),
            project_id=tool_call.project_id or "",
            autonomy_level=autonomy_level,
        )

        return ActionRequest(
            action_type=action_type,
            tool_name=tool_call.tool_name,
            arguments=tool_call.arguments,
            resource=tool_call.arguments.get(
                "path", tool_call.arguments.get("resource")
            ),
            metadata=metadata,
        )

    @staticmethod
    def _tool_tokens(tool_name: str) -> list[str]:
        """Split a tool name into lowercase words (snake/kebab/dotted/camelCase)."""
        spaced = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", tool_name)
        return [t for t in re.split(r"[^a-z0-9]+", spaced.lower()) if t]

    def _classify_tool(self, tool_name: str) -> ActionType:
        """
        Heuristically map a tool name onto an ActionType.

        Destructive keywords are checked before read keywords, then shell,
        git, network and LLM markers; anything else is TOOL_CALL. The short
        markers "db" and "ai" are matched against word tokens, not raw
        substrings, so names such as "feedback_get" or "send_email" are not
        mistaken for database or LLM tools.
        """
        tool_lower = tool_name.lower()
        tokens = self._tool_tokens(tool_name)

        is_file = "file" in tool_lower
        is_db = "database" in tool_lower or any(
            t.startswith("db") or t.endswith("db") for t in tokens
        )

        # Destructive patterns ("mutate" covers e.g. "database_mutate").
        if any(
            d in tool_lower
            for d in ("write", "create", "delete", "remove", "update", "mutate")
        ):
            if is_file:
                if "delete" in tool_lower or "remove" in tool_lower:
                    return ActionType.FILE_DELETE
                return ActionType.FILE_WRITE
            if is_db:
                return ActionType.DATABASE_MUTATE

        # Read patterns.
        if any(r in tool_lower for r in ("read", "get", "list", "search", "query")):
            if is_file:
                return ActionType.FILE_READ
            if is_db:
                return ActionType.DATABASE_QUERY

        # Specific categories.
        if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
            return ActionType.SHELL_COMMAND
        if "git" in tool_lower:
            return ActionType.GIT_OPERATION
        if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
            return ActionType.NETWORK_REQUEST
        if (
            "llm" in tool_lower
            or "claude" in tool_lower
            or any(t.endswith("ai") for t in tokens)  # "ai", "openai", ...
        ):
            return ActionType.LLM_CALL

        return ActionType.TOOL_CALL

    def _elapsed_ms(self, start_time: datetime) -> float:
        """
        Return milliseconds elapsed since ``start_time``.

        Naive datetimes are interpreted as UTC, the convention produced by
        ``datetime.utcnow()``; aware datetimes are used as-is.
        """
        if start_time.tzinfo is None:
            start_time = start_time.replace(tzinfo=timezone.utc)
        return (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
class SafeToolExecutor:
    """
    Async context manager that runs one tool call through an MCPSafetyWrapper.

    Entering and exiting the context performs no work of its own, and
    exceptions raised inside the ``async with`` body are never suppressed.

    Usage:
        async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
            result = await executor.execute()
            if result.success:
                ...  # use result.result
            else:
                ...  # handle result.error or result.approval_id
    """

    def __init__(
        self,
        wrapper: MCPSafetyWrapper,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
    ) -> None:
        """
        Args:
            wrapper: Wrapper used to execute the call.
            tool_call: The tool call to execute.
            agent_id: ID of the calling agent.
            autonomy_level: Autonomy level forwarded to the wrapper.
        """
        self._wrapper = wrapper
        self._tool_call = tool_call
        self._agent_id = agent_id
        self._autonomy_level = autonomy_level
        self._result: MCPToolResult | None = None

    async def __aenter__(self) -> "SafeToolExecutor":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        # No cleanup or rollback is performed yet. Returning False lets any
        # exception from the body (including CancelledError) propagate.
        return False

    async def execute(self) -> MCPToolResult:
        """Execute the tool call via the wrapper and remember the result."""
        self._result = await self._wrapper.execute(
            self._tool_call,
            self._agent_id,
            self._autonomy_level,
        )
        return self._result

    @property
    def result(self) -> MCPToolResult | None:
        """Result of the last execute() call, or None if not yet executed."""
        return self._result
# Factory function
async def create_mcp_wrapper(
    guardian: SafetyGuardian | None = None,
    audit_logger: AuditLogger | None = None,
    emergency_controls: EmergencyControls | None = None,
) -> MCPSafetyWrapper:
    """
    Create an MCPSafetyWrapper with its dependencies resolved eagerly.

    Args:
        guardian: SafetyGuardian to use; the instance returned by
            get_safety_guardian() is used when omitted.
        audit_logger: Optional AuditLogger passed through to the wrapper.
        emergency_controls: EmergencyControls to use; the instance returned by
            get_emergency_controls() is used when omitted.

    Returns:
        A configured MCPSafetyWrapper.
    """
    if guardian is None:
        guardian = await get_safety_guardian()
    if emergency_controls is None:
        emergency_controls = await get_emergency_controls()
    return MCPSafetyWrapper(
        guardian=guardian,
        audit_logger=audit_logger,
        emergency_controls=emergency_controls,
    )