forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
410 lines
12 KiB
Python
410 lines
12 KiB
Python
"""
MCP Safety Integration

Provides safety-aware wrappers for MCP tool execution.
"""
|
|
|
|
import asyncio
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, ClassVar, TypeVar
|
|
|
|
from ..audit import AuditLogger
|
|
from ..emergency import EmergencyControls, get_emergency_controls
|
|
from ..exceptions import (
|
|
EmergencyStopError,
|
|
SafetyError,
|
|
)
|
|
from ..guardian import SafetyGuardian, get_safety_guardian
|
|
from ..models import (
|
|
ActionMetadata,
|
|
ActionRequest,
|
|
ActionType,
|
|
AutonomyLevel,
|
|
SafetyDecision,
|
|
)
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Generic type variable; not referenced by the definitions visible in this file.
T = TypeVar("T")
|
|
|
|
|
|
@dataclass
|
|
class MCPToolCall:
|
|
"""Represents an MCP tool call."""
|
|
|
|
tool_name: str
|
|
arguments: dict[str, Any]
|
|
server_name: str | None = None
|
|
project_id: str | None = None
|
|
context: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class MCPToolResult:
    """
    Outcome of an MCP tool execution attempt.

    Only ``success`` is required; every other field has a default.
    """

    # True only when the tool handler ran and returned without raising.
    success: bool
    # Value returned by the tool handler (None when nothing was executed).
    result: Any = None
    # Human-readable failure or denial reason, if any.
    error: str | None = None
    # Safety decision associated with this result; defaults to ALLOW.
    safety_decision: SafetyDecision = SafetyDecision.ALLOW
    # Time spent handling the call, in milliseconds.
    execution_time_ms: float = 0.0
    # Identifier of the pending approval when human approval is required.
    approval_id: str | None = None
    # Identifier of the checkpoint reported by the safety guardian, if any.
    checkpoint_id: str | None = None
    # Extra details (e.g. {"emergency_stop": True}); fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
class MCPSafetyWrapper:
    """
    Wraps MCP tool execution with safety checks.

    Each call to :meth:`execute` goes through these stages:

    1. Emergency-stop check for the project (or agent) scope.
    2. Construction of an ``ActionRequest`` describing the call.
    3. Validation by ``SafetyGuardian`` (skipped when ``bypass_safety`` is set).
    4. Dispatch to the handler registered for the tool.

    What the guardian actually enforces (e.g. permissions, budgets, rate
    limits, checkpoints) is decided inside ``SafetyGuardian.validate``; this
    class only acts on the decision it returns.
    """

    # Known tool names grouped by side-effect category.
    # NOTE(review): these sets are not consulted by _classify_tool, which
    # matches on substrings of the tool name instead.
    DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
        "file_write",
        "file_delete",
        "database_mutate",
        "shell_execute",
        "git_push",
        "git_commit",
        "deploy",
    }

    READ_ONLY_TOOLS: ClassVar[set[str]] = {
        "file_read",
        "database_query",
        "git_status",
        "git_log",
        "list_files",
        "search",
    }

    def __init__(
        self,
        guardian: SafetyGuardian | None = None,
        audit_logger: AuditLogger | None = None,
        emergency_controls: EmergencyControls | None = None,
    ) -> None:
        """
        Initialize MCPSafetyWrapper.

        Args:
            guardian: SafetyGuardian instance; resolved lazily through
                ``get_safety_guardian()`` on first use if not provided.
            audit_logger: AuditLogger instance (stored on the wrapper).
            emergency_controls: EmergencyControls instance; resolved lazily
                through ``get_emergency_controls()`` on first use if not
                provided.
        """
        self._guardian = guardian
        self._audit_logger = audit_logger
        self._emergency_controls = emergency_controls
        # Maps tool name -> callable invoked with the call's arguments.
        self._tool_handlers: dict[str, Callable[..., Any]] = {}
        self._lock = asyncio.Lock()

    async def _get_guardian(self) -> SafetyGuardian:
        """Return the configured SafetyGuardian, fetching it on first use."""
        if self._guardian is None:
            self._guardian = await get_safety_guardian()
        return self._guardian

    async def _get_emergency_controls(self) -> EmergencyControls:
        """Return the configured EmergencyControls, fetching them on first use."""
        if self._emergency_controls is None:
            self._emergency_controls = await get_emergency_controls()
        return self._emergency_controls

    def register_tool_handler(
        self,
        tool_name: str,
        handler: Callable[..., Any],
    ) -> None:
        """
        Register a handler for a tool.

        A later registration for the same name replaces the earlier one.

        Args:
            tool_name: Name of the tool
            handler: Callable invoked with the tool call's arguments as
                keyword arguments. It may be synchronous or return an
                awaitable, which is then awaited.
        """
        self._tool_handlers[tool_name] = handler
        logger.debug("Registered handler for tool: %s", tool_name)

    async def execute(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
        bypass_safety: bool = False,
    ) -> MCPToolResult:
        """
        Execute an MCP tool call with safety checks.

        Args:
            tool_call: The tool call to execute
            agent_id: ID of the calling agent
            autonomy_level: Agent's autonomy level
            bypass_safety: Skip guardian validation (emergency only). The
                emergency-stop check still runs before the bypass applies.

        Returns:
            MCPToolResult with execution outcome. Denials and approval
            requests are reported through the result (``success=False``)
            rather than raised.
        """
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated.
        start_time = datetime.now(timezone.utc)

        # Check emergency controls first. A project-level scope takes
        # precedence over the per-agent scope when the call names a project.
        emergency = await self._get_emergency_controls()
        if tool_call.project_id:
            scope = f"project:{tool_call.project_id}"
        else:
            scope = f"agent:{agent_id}"

        try:
            await emergency.check_allowed(scope=scope, raise_if_blocked=True)
        except EmergencyStopError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                # Report timing here too, like every other return path.
                execution_time_ms=self._elapsed_ms(start_time),
                metadata={"emergency_stop": True},
            )

        # Build action request
        action = self._build_action_request(
            tool_call=tool_call,
            agent_id=agent_id,
            autonomy_level=autonomy_level,
        )

        # Skip safety checks if bypass is enabled
        if bypass_safety:
            logger.warning(
                "Safety bypass enabled for tool: %s (agent: %s)",
                tool_call.tool_name,
                agent_id,
            )
            return await self._execute_tool(tool_call, action, start_time)

        # Run safety validation; a SafetyError is reported as a denial.
        guardian = await self._get_guardian()
        try:
            guardian_result = await guardian.validate(action)
        except SafetyError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Handle safety decision
        if guardian_result.decision == SafetyDecision.DENY:
            return MCPToolResult(
                success=False,
                error="; ".join(guardian_result.reasons),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
            # For now, just return that approval is required.
            # The caller should handle the approval flow.
            return MCPToolResult(
                success=False,
                error="Action requires human approval",
                safety_decision=SafetyDecision.REQUIRE_APPROVAL,
                approval_id=guardian_result.approval_id,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Any other decision proceeds to execution.
        return await self._execute_tool(
            tool_call,
            action,
            start_time,
            checkpoint_id=guardian_result.checkpoint_id,
        )

    async def _execute_tool(
        self,
        tool_call: MCPToolCall,
        action: ActionRequest,
        start_time: datetime,
        checkpoint_id: str | None = None,
    ) -> MCPToolResult:
        """
        Execute the actual tool call.

        Args:
            tool_call: Call whose ``tool_name`` selects the handler and whose
                ``arguments`` are passed to it as keyword arguments.
            action: Action request built for this call (not used here).
            start_time: When handling of the call began, for timing.
            checkpoint_id: Checkpoint id to attach to the result, if any.

        Returns:
            A successful result carrying the handler's return value, or a
            failed result if no handler is registered or the handler raised.
        """
        handler = self._tool_handlers.get(tool_call.tool_name)

        if handler is None:
            return MCPToolResult(
                success=False,
                error=f"No handler registered for tool: {tool_call.tool_name}",
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        try:
            result = handler(**tool_call.arguments)
            # Inspect the returned value rather than the callable: this also
            # covers sync wrappers and callable objects that hand back a
            # coroutine, which would otherwise leak out un-awaited.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            # Handler failures are reported in the result, not propagated.
            logger.error("Tool execution failed: %s - %s", tool_call.tool_name, e)
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )
        else:
            return MCPToolResult(
                success=True,
                result=result,
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )

    def _build_action_request(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel,
    ) -> ActionRequest:
        """
        Build an ActionRequest from an MCP tool call.

        The resource is taken from the ``"path"`` argument, falling back to
        the ``"resource"`` argument, and is ``None`` when neither is present.
        Missing session and project ids are passed as empty strings.
        """
        action_type = self._classify_tool(tool_call.tool_name)

        metadata = ActionMetadata(
            agent_id=agent_id,
            session_id=tool_call.context.get("session_id", ""),
            project_id=tool_call.project_id or "",
            autonomy_level=autonomy_level,
        )

        return ActionRequest(
            action_type=action_type,
            tool_name=tool_call.tool_name,
            arguments=tool_call.arguments,
            resource=tool_call.arguments.get(
                "path", tool_call.arguments.get("resource")
            ),
            metadata=metadata,
        )

    def _classify_tool(self, tool_name: str) -> ActionType:
        """
        Classify a tool into an action type.

        Classification is by case-insensitive substring match on the tool
        name. Checks run in this order: destructive file/database operations,
        read-only file/database operations, shell, git, network, LLM; anything
        else is a generic ``TOOL_CALL``.

        NOTE(review): substring matching is loose. Short fragments such as
        "ai", "db" or "get" also match inside unrelated words (e.g. "email"
        contains "ai"), so unusual tool names may be classified unexpectedly.
        """
        tool_lower = tool_name.lower()
        is_file = "file" in tool_lower
        is_database = "database" in tool_lower or "db" in tool_lower

        # Check destructive patterns. "mutate" is included so that names such
        # as "database_mutate" map to DATABASE_MUTATE instead of TOOL_CALL.
        destructive_markers = (
            "write",
            "create",
            "delete",
            "remove",
            "update",
            "mutate",
        )
        if any(marker in tool_lower for marker in destructive_markers):
            if is_file:
                if "delete" in tool_lower or "remove" in tool_lower:
                    return ActionType.FILE_DELETE
                return ActionType.FILE_WRITE
            if is_database:
                return ActionType.DATABASE_MUTATE

        # Check read patterns
        read_markers = ("read", "get", "list", "search", "query")
        if any(marker in tool_lower for marker in read_markers):
            if is_file:
                return ActionType.FILE_READ
            if is_database:
                return ActionType.DATABASE_QUERY

        # Check specific types
        if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
            return ActionType.SHELL_COMMAND

        if "git" in tool_lower:
            return ActionType.GIT_OPERATION

        if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
            return ActionType.NETWORK_REQUEST

        if "llm" in tool_lower or "ai" in tool_lower or "claude" in tool_lower:
            return ActionType.LLM_CALL

        # Default to tool call
        return ActionType.TOOL_CALL

    def _elapsed_ms(self, start_time: datetime) -> float:
        """
        Calculate elapsed time in milliseconds since ``start_time``.

        Naive timestamps are interpreted as UTC, so values produced by the
        older ``datetime.utcnow()`` convention are still accepted.
        """
        if start_time.tzinfo is None:
            start_time = start_time.replace(tzinfo=timezone.utc)
        return (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
|
|
|
|
|
|
class SafeToolExecutor:
    """
    Async context manager that runs one MCP tool call through a wrapper.

    Entering the context returns the executor itself. Leaving it performs no
    work and never suppresses an exception raised inside the ``async with``
    body. The outcome of the most recent :meth:`execute` call is available
    through :attr:`result`.

    Usage:
        async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
            result = await executor.execute()
            if result.success:
                ...  # Use result
            else:
                ...  # Handle error or approval required
    """

    def __init__(
        self,
        wrapper: MCPSafetyWrapper,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
    ) -> None:
        """
        Store everything needed to run the call later.

        Args:
            wrapper: Safety wrapper that performs the execution.
            tool_call: The tool call to run.
            agent_id: ID of the calling agent.
            autonomy_level: Autonomy level forwarded to the wrapper.
        """
        self._wrapper = wrapper
        self._tool_call = tool_call
        self._agent_id = agent_id
        self._autonomy_level = autonomy_level
        # Populated by execute(); stays None until then.
        self._result: MCPToolResult | None = None

    async def __aenter__(self) -> "SafeToolExecutor":
        """Enter the context; no setup is performed."""
        return self

    async def __aexit__(
        self,
        exc_type: type[Exception] | None,
        exc_val: Exception | None,
        exc_tb: Any,
    ) -> bool:
        """Leave the context without suppressing any in-flight exception."""
        # Could trigger rollback here if needed
        return False

    async def execute(self) -> MCPToolResult:
        """Run the stored tool call via the wrapper and remember the outcome."""
        outcome = await self._wrapper.execute(
            self._tool_call,
            self._agent_id,
            self._autonomy_level,
        )
        self._result = outcome
        return outcome

    @property
    def result(self) -> MCPToolResult | None:
        """Outcome of the last execute() call, or None if it has not run."""
        return self._result
|
|
|
|
|
|
# Factory function
|
|
async def create_mcp_wrapper(
    guardian: SafetyGuardian | None = None,
) -> MCPSafetyWrapper:
    """
    Build an MCPSafetyWrapper wired to the default safety services.

    Args:
        guardian: Guardian to use; when omitted, the one returned by
            ``get_safety_guardian()`` is used.

    Returns:
        A wrapper holding the guardian and the emergency controls returned
        by ``get_emergency_controls()``. No audit logger is passed.
    """
    # Resolve the guardian before the emergency controls, as callers may
    # rely on that initialization order.
    resolved_guardian = (
        guardian if guardian is not None else await get_safety_guardian()
    )
    controls = await get_emergency_controls()

    return MCPSafetyWrapper(
        guardian=resolved_guardian,
        emergency_controls=controls,
    )
|