"""
MCP Safety Integration

Provides safety-aware wrappers for MCP tool execution.
"""

import asyncio
import inspect
import logging
import re
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, ClassVar, TypeVar

from ..audit import AuditLogger
from ..emergency import EmergencyControls, get_emergency_controls
from ..exceptions import (
    EmergencyStopError,
    SafetyError,
)
from ..guardian import SafetyGuardian, get_safety_guardian
from ..models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    AutonomyLevel,
    SafetyDecision,
)

logger = logging.getLogger(__name__)

T = TypeVar("T")

# Splits a lower-cased tool name into alphanumeric tokens so that short
# keywords can be matched as whole words rather than as substrings.
_TOOL_NAME_TOKEN_RE = re.compile(r"[^a-z0-9]+")


@dataclass
class MCPToolCall:
    """Represents an MCP tool call."""

    tool_name: str
    # Keyword arguments passed verbatim to the registered handler.
    arguments: dict[str, Any]
    server_name: str | None = None
    # When set, emergency controls are also checked for "project:<id>".
    project_id: str | None = None
    # Free-form context; "session_id" is read when building the ActionRequest.
    context: dict[str, Any] = field(default_factory=dict)


@dataclass
class MCPToolResult:
    """Result of an MCP tool execution."""

    success: bool
    result: Any = None
    error: str | None = None
    safety_decision: SafetyDecision = SafetyDecision.ALLOW
    execution_time_ms: float = 0.0
    # Set when the guardian requires human approval before execution.
    approval_id: str | None = None
    # Checkpoint reported by the guardian for the executed action, if any.
    checkpoint_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


class MCPSafetyWrapper:
    """
    Wraps MCP tool execution with safety checks.

    Features:
    - Emergency stop integration (checked for the agent scope and, when the
      call carries a project id, for the project scope as well)
    - Pre-execution validation via SafetyGuardian (permission, budget and
      rate-limit enforcement are presumably performed by the guardian)
    - Checkpoint ids reported by the guardian are propagated to the result

    NOTE(review): an AuditLogger can be supplied, but this class does not
    currently invoke it; audit logging of MCP calls is not implemented here.
    """

    # Tool categories for automatic classification
    DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
        "file_write",
        "file_delete",
        "database_mutate",
        "shell_execute",
        "git_push",
        "git_commit",
        "deploy",
    }

    READ_ONLY_TOOLS: ClassVar[set[str]] = {
        "file_read",
        "database_query",
        "git_status",
        "git_log",
        "list_files",
        "search",
    }

    # Substring keywords used by _classify_tool (matched on the lower-cased name).
    _MUTATING_KEYWORDS: ClassVar[tuple[str, ...]] = (
        "write",
        "create",
        "delete",
        "remove",
        "update",
        "mutate",
    )
    _READ_KEYWORDS: ClassVar[tuple[str, ...]] = ("read", "get", "list", "search", "query")
    _SHELL_KEYWORDS: ClassVar[tuple[str, ...]] = ("shell", "exec", "bash")
    _NETWORK_KEYWORDS: ClassVar[tuple[str, ...]] = ("http", "fetch", "request")
    _LLM_KEYWORDS: ClassVar[tuple[str, ...]] = ("llm", "claude", "openai")

    def __init__(
        self,
        guardian: SafetyGuardian | None = None,
        audit_logger: AuditLogger | None = None,
        emergency_controls: EmergencyControls | None = None,
    ) -> None:
        """
        Initialize MCPSafetyWrapper.

        Args:
            guardian: SafetyGuardian instance (resolved lazily via
                get_safety_guardian() if not provided)
            audit_logger: AuditLogger instance (stored; not used by this class)
            emergency_controls: EmergencyControls instance (resolved lazily via
                get_emergency_controls() if not provided)
        """
        self._guardian = guardian
        self._audit_logger = audit_logger
        self._emergency_controls = emergency_controls
        self._tool_handlers: dict[str, Callable[..., Any]] = {}
        # Not used within this module; kept for compatibility.
        self._lock = asyncio.Lock()

    async def _get_guardian(self) -> SafetyGuardian:
        """Get or create SafetyGuardian."""
        if self._guardian is None:
            self._guardian = await get_safety_guardian()
        return self._guardian

    async def _get_emergency_controls(self) -> EmergencyControls:
        """Get or create EmergencyControls."""
        if self._emergency_controls is None:
            self._emergency_controls = await get_emergency_controls()
        return self._emergency_controls

    def register_tool_handler(
        self,
        tool_name: str,
        handler: Callable[..., Any],
    ) -> None:
        """
        Register a handler for a tool.

        Args:
            tool_name: Name of the tool
            handler: Sync or async callable invoked with the tool call's
                arguments as keyword arguments; an awaitable result is awaited
        """
        self._tool_handlers[tool_name] = handler
        logger.debug("Registered handler for tool: %s", tool_name)

    @staticmethod
    def _emergency_scopes(tool_call: MCPToolCall, agent_id: str) -> list[str]:
        """Return every emergency-control scope the call must be allowed in."""
        scopes = [f"agent:{agent_id}"]
        if tool_call.project_id:
            # Check the project as well as the agent: checking only the
            # project would let an agent-level emergency stop be bypassed.
            scopes.append(f"project:{tool_call.project_id}")
        return scopes

    async def execute(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
        bypass_safety: bool = False,
    ) -> MCPToolResult:
        """
        Execute an MCP tool call with safety checks.

        Args:
            tool_call: The tool call to execute
            agent_id: ID of the calling agent
            autonomy_level: Agent's autonomy level
            bypass_safety: Skip guardian validation (emergency only). Emergency
                controls are still enforced.

        Returns:
            MCPToolResult with execution outcome. Emergency stops, guardian
            errors and DENY decisions yield a DENY result; REQUIRE_APPROVAL
            yields an unsuccessful result carrying the approval id.
        """
        start_time = time.perf_counter()

        # Check emergency controls first; these are never bypassed.
        emergency = await self._get_emergency_controls()
        try:
            for scope in self._emergency_scopes(tool_call, agent_id):
                await emergency.check_allowed(scope=scope, raise_if_blocked=True)
        except EmergencyStopError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
                metadata={"emergency_stop": True},
            )

        # Build action request
        action = self._build_action_request(
            tool_call=tool_call,
            agent_id=agent_id,
            autonomy_level=autonomy_level,
        )

        # Skip safety checks if bypass is enabled
        if bypass_safety:
            logger.warning(
                "Safety bypass enabled for tool: %s (agent: %s)",
                tool_call.tool_name,
                agent_id,
            )
            return await self._execute_tool(tool_call, action, start_time)

        # Run safety validation
        guardian = await self._get_guardian()
        try:
            guardian_result = await guardian.validate(action)
        except SafetyError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Handle safety decision
        if guardian_result.decision == SafetyDecision.DENY:
            return MCPToolResult(
                success=False,
                error="; ".join(guardian_result.reasons),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
            # For now, just return that approval is required
            # The caller should handle the approval flow
            return MCPToolResult(
                success=False,
                error="Action requires human approval",
                safety_decision=SafetyDecision.REQUIRE_APPROVAL,
                approval_id=guardian_result.approval_id,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # NOTE(review): any decision other than DENY / REQUIRE_APPROVAL falls
        # through to execution. Confirm SafetyDecision has no other blocking
        # members (e.g. a delay/throttle decision).
        return await self._execute_tool(
            tool_call,
            action,
            start_time,
            checkpoint_id=guardian_result.checkpoint_id,
        )

    async def _execute_tool(
        self,
        tool_call: MCPToolCall,
        action: ActionRequest,
        start_time: float | datetime,
        checkpoint_id: str | None = None,
    ) -> MCPToolResult:
        """
        Invoke the registered handler for ``tool_call``.

        Handler exceptions are caught and reported as an unsuccessful result.
        A result that is awaitable (coroutine functions, partials of them,
        objects with an async ``__call__``) is awaited.
        """
        handler = self._tool_handlers.get(tool_call.tool_name)
        if handler is None:
            return MCPToolResult(
                success=False,
                error=f"No handler registered for tool: {tool_call.tool_name}",
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        try:
            result = handler(**tool_call.arguments)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            logger.exception("Tool execution failed: %s - %s", tool_call.tool_name, e)
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )

        return MCPToolResult(
            success=True,
            result=result,
            safety_decision=SafetyDecision.ALLOW,
            execution_time_ms=self._elapsed_ms(start_time),
            checkpoint_id=checkpoint_id,
        )

    def _build_action_request(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel,
    ) -> ActionRequest:
        """Build an ActionRequest from an MCP tool call."""
        action_type = self._classify_tool(tool_call.tool_name)

        metadata = ActionMetadata(
            agent_id=agent_id,
            session_id=tool_call.context.get("session_id", ""),
            project_id=tool_call.project_id or "",
            autonomy_level=autonomy_level,
        )

        return ActionRequest(
            action_type=action_type,
            tool_name=tool_call.tool_name,
            arguments=tool_call.arguments,
            # Prefer an explicit "path" argument, falling back to "resource".
            resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
            metadata=metadata,
        )

    def _classify_tool(self, tool_name: str) -> ActionType:
        """
        Classify a tool into an action type using name heuristics.

        Mutating file/database keywords are checked first, then read keywords,
        then shell, git, network and LLM hints; anything else is TOOL_CALL.
        """
        tool_lower = tool_name.lower()
        tokens = set(_TOOL_NAME_TOKEN_RE.split(tool_lower))
        is_file = "file" in tool_lower
        is_db = "database" in tool_lower or "db" in tool_lower

        # Check destructive patterns
        if any(k in tool_lower for k in self._MUTATING_KEYWORDS):
            if is_file:
                if "delete" in tool_lower or "remove" in tool_lower:
                    return ActionType.FILE_DELETE
                return ActionType.FILE_WRITE
            if is_db:
                return ActionType.DATABASE_MUTATE

        # Check read patterns
        if any(k in tool_lower for k in self._READ_KEYWORDS):
            if is_file:
                return ActionType.FILE_READ
            if is_db:
                return ActionType.DATABASE_QUERY

        # Check specific types
        if any(k in tool_lower for k in self._SHELL_KEYWORDS):
            return ActionType.SHELL_COMMAND
        if "git" in tool_lower:
            return ActionType.GIT_OPERATION
        if any(k in tool_lower for k in self._NETWORK_KEYWORDS):
            return ActionType.NETWORK_REQUEST
        # "ai" must be a whole token: as a substring it matches unrelated
        # names such as "email", "detail" or "wait".
        if "ai" in tokens or any(k in tool_lower for k in self._LLM_KEYWORDS):
            return ActionType.LLM_CALL

        # Default to tool call
        return ActionType.TOOL_CALL

    def _elapsed_ms(self, start_time: float | datetime) -> float:
        """
        Return milliseconds elapsed since ``start_time``.

        ``start_time`` is normally a ``time.perf_counter()`` reading (monotonic,
        unaffected by wall-clock changes). A ``datetime`` is still accepted for
        backward compatibility: naive values are treated as UTC, aware values
        are compared in their own timezone.
        """
        if isinstance(start_time, datetime):
            if start_time.tzinfo is None:
                now = datetime.now(timezone.utc).replace(tzinfo=None)
            else:
                now = datetime.now(start_time.tzinfo)
            return (now - start_time).total_seconds() * 1000
        return (time.perf_counter() - start_time) * 1000


class SafeToolExecutor:
    """
    Context manager for safe tool execution.

    Usage:
        async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
            result = await executor.execute()
            if result.success:
                ...  # Use result
            else:
                ...  # Handle error or approval required

    Exceptions raised inside the ``async with`` body are never suppressed.
    """

    def __init__(
        self,
        wrapper: MCPSafetyWrapper,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
    ) -> None:
        self._wrapper = wrapper
        self._tool_call = tool_call
        self._agent_id = agent_id
        self._autonomy_level = autonomy_level
        self._result: MCPToolResult | None = None

    async def __aenter__(self) -> "SafeToolExecutor":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        # Could trigger rollback here if needed
        return False

    async def execute(self) -> MCPToolResult:
        """Execute the tool call through the wrapper (safety checks enabled)."""
        self._result = await self._wrapper.execute(
            self._tool_call,
            self._agent_id,
            self._autonomy_level,
        )
        return self._result

    @property
    def result(self) -> MCPToolResult | None:
        """Get the execution result (None until execute() has completed)."""
        return self._result


# Factory function
async def create_mcp_wrapper(
    guardian: SafetyGuardian | None = None,
    audit_logger: AuditLogger | None = None,
) -> MCPSafetyWrapper:
    """
    Create an MCPSafetyWrapper with default configuration.

    Args:
        guardian: SafetyGuardian to use (defaults to get_safety_guardian())
        audit_logger: Optional AuditLogger to attach to the wrapper
    """
    if guardian is None:
        guardian = await get_safety_guardian()

    return MCPSafetyWrapper(
        guardian=guardian,
        audit_logger=audit_logger,
        emergency_controls=await get_emergency_controls(),
    )