feat(safety): add Phase D MCP integration and metrics

- Add MCPSafetyWrapper for safe MCP tool execution
- Add MCPToolCall/MCPToolResult models for MCP interactions
- Add SafeToolExecutor context manager
- Add SafetyMetrics collector with Prometheus export support
- Track validations, approvals, rate limits, budgets, and more
- Support for counters, gauges, and histograms

Issue #63

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 11:40:14 +01:00
parent ef659cd72d
commit f36bfb3781
4 changed files with 857 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
"""MCP safety integration."""
from .integration import (
MCPSafetyWrapper,
MCPToolCall,
MCPToolResult,
SafeToolExecutor,
create_mcp_wrapper,
)
__all__ = [
"MCPSafetyWrapper",
"MCPToolCall",
"MCPToolResult",
"SafeToolExecutor",
"create_mcp_wrapper",
]

View File

@@ -0,0 +1,405 @@
"""
MCP Safety Integration
Provides safety-aware wrappers for MCP tool execution.
"""
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, ClassVar, TypeVar
from ..audit import AuditLogger
from ..emergency import EmergencyControls, get_emergency_controls
from ..exceptions import (
EmergencyStopError,
SafetyError,
)
from ..guardian import SafetyGuardian, get_safety_guardian
from ..models import (
ActionMetadata,
ActionRequest,
ActionType,
AutonomyLevel,
SafetyDecision,
)
logger = logging.getLogger(__name__)
T = TypeVar("T")
@dataclass
class MCPToolCall:
"""Represents an MCP tool call."""
tool_name: str
arguments: dict[str, Any]
server_name: str | None = None
project_id: str | None = None
context: dict[str, Any] = field(default_factory=dict)
@dataclass
class MCPToolResult:
"""Result of an MCP tool execution."""
success: bool
result: Any = None
error: str | None = None
safety_decision: SafetyDecision = SafetyDecision.ALLOW
execution_time_ms: float = 0.0
approval_id: str | None = None
checkpoint_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
class MCPSafetyWrapper:
"""
Wraps MCP tool execution with safety checks.
Features:
- Pre-execution validation via SafetyGuardian
- Permission checking per tool/resource
- Budget and rate limit enforcement
- Audit logging of all MCP calls
- Emergency stop integration
- Checkpoint creation for destructive operations
"""
# Tool categories for automatic classification
DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
"file_write",
"file_delete",
"database_mutate",
"shell_execute",
"git_push",
"git_commit",
"deploy",
}
READ_ONLY_TOOLS: ClassVar[set[str]] = {
"file_read",
"database_query",
"git_status",
"git_log",
"list_files",
"search",
}
def __init__(
self,
guardian: SafetyGuardian | None = None,
audit_logger: AuditLogger | None = None,
emergency_controls: EmergencyControls | None = None,
) -> None:
"""
Initialize MCPSafetyWrapper.
Args:
guardian: SafetyGuardian instance (uses singleton if not provided)
audit_logger: AuditLogger instance
emergency_controls: EmergencyControls instance
"""
self._guardian = guardian
self._audit_logger = audit_logger
self._emergency_controls = emergency_controls
self._tool_handlers: dict[str, Callable[..., Any]] = {}
self._lock = asyncio.Lock()
async def _get_guardian(self) -> SafetyGuardian:
"""Get or create SafetyGuardian."""
if self._guardian is None:
self._guardian = await get_safety_guardian()
return self._guardian
async def _get_emergency_controls(self) -> EmergencyControls:
"""Get or create EmergencyControls."""
if self._emergency_controls is None:
self._emergency_controls = await get_emergency_controls()
return self._emergency_controls
def register_tool_handler(
self,
tool_name: str,
handler: Callable[..., Any],
) -> None:
"""
Register a handler for a tool.
Args:
tool_name: Name of the tool
handler: Async function to handle the tool call
"""
self._tool_handlers[tool_name] = handler
logger.debug("Registered handler for tool: %s", tool_name)
async def execute(
self,
tool_call: MCPToolCall,
agent_id: str,
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
bypass_safety: bool = False,
) -> MCPToolResult:
"""
Execute an MCP tool call with safety checks.
Args:
tool_call: The tool call to execute
agent_id: ID of the calling agent
autonomy_level: Agent's autonomy level
bypass_safety: Bypass safety checks (emergency only)
Returns:
MCPToolResult with execution outcome
"""
start_time = datetime.utcnow()
# Check emergency controls first
emergency = await self._get_emergency_controls()
scope = f"agent:{agent_id}"
if tool_call.project_id:
scope = f"project:{tool_call.project_id}"
try:
await emergency.check_allowed(scope=scope, raise_if_blocked=True)
except EmergencyStopError as e:
return MCPToolResult(
success=False,
error=str(e),
safety_decision=SafetyDecision.DENY,
metadata={"emergency_stop": True},
)
# Build action request
action = self._build_action_request(
tool_call=tool_call,
agent_id=agent_id,
autonomy_level=autonomy_level,
)
# Skip safety checks if bypass is enabled
if bypass_safety:
logger.warning(
"Safety bypass enabled for tool: %s (agent: %s)",
tool_call.tool_name,
agent_id,
)
return await self._execute_tool(tool_call, action, start_time)
# Run safety validation
guardian = await self._get_guardian()
try:
guardian_result = await guardian.validate(action)
except SafetyError as e:
return MCPToolResult(
success=False,
error=str(e),
safety_decision=SafetyDecision.DENY,
execution_time_ms=self._elapsed_ms(start_time),
)
# Handle safety decision
if guardian_result.decision == SafetyDecision.DENY:
return MCPToolResult(
success=False,
error="; ".join(guardian_result.reasons),
safety_decision=SafetyDecision.DENY,
execution_time_ms=self._elapsed_ms(start_time),
)
if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
# For now, just return that approval is required
# The caller should handle the approval flow
return MCPToolResult(
success=False,
error="Action requires human approval",
safety_decision=SafetyDecision.REQUIRE_APPROVAL,
approval_id=guardian_result.approval_id,
execution_time_ms=self._elapsed_ms(start_time),
)
# Execute the tool
result = await self._execute_tool(
tool_call,
action,
start_time,
checkpoint_id=guardian_result.checkpoint_id,
)
return result
async def _execute_tool(
self,
tool_call: MCPToolCall,
action: ActionRequest,
start_time: datetime,
checkpoint_id: str | None = None,
) -> MCPToolResult:
"""Execute the actual tool call."""
handler = self._tool_handlers.get(tool_call.tool_name)
if handler is None:
return MCPToolResult(
success=False,
error=f"No handler registered for tool: {tool_call.tool_name}",
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=self._elapsed_ms(start_time),
)
try:
if asyncio.iscoroutinefunction(handler):
result = await handler(**tool_call.arguments)
else:
result = handler(**tool_call.arguments)
return MCPToolResult(
success=True,
result=result,
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=self._elapsed_ms(start_time),
checkpoint_id=checkpoint_id,
)
except Exception as e:
logger.error("Tool execution failed: %s - %s", tool_call.tool_name, e)
return MCPToolResult(
success=False,
error=str(e),
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=self._elapsed_ms(start_time),
checkpoint_id=checkpoint_id,
)
def _build_action_request(
self,
tool_call: MCPToolCall,
agent_id: str,
autonomy_level: AutonomyLevel,
) -> ActionRequest:
"""Build an ActionRequest from an MCP tool call."""
action_type = self._classify_tool(tool_call.tool_name)
metadata = ActionMetadata(
agent_id=agent_id,
session_id=tool_call.context.get("session_id", ""),
project_id=tool_call.project_id or "",
autonomy_level=autonomy_level,
)
return ActionRequest(
action_type=action_type,
tool_name=tool_call.tool_name,
arguments=tool_call.arguments,
resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
metadata=metadata,
)
def _classify_tool(self, tool_name: str) -> ActionType:
"""Classify a tool into an action type."""
tool_lower = tool_name.lower()
# Check destructive patterns
if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
if "file" in tool_lower:
if "delete" in tool_lower or "remove" in tool_lower:
return ActionType.FILE_DELETE
return ActionType.FILE_WRITE
if "database" in tool_lower or "db" in tool_lower:
return ActionType.DATABASE_MUTATE
# Check read patterns
if any(r in tool_lower for r in ["read", "get", "list", "search", "query"]):
if "file" in tool_lower:
return ActionType.FILE_READ
if "database" in tool_lower or "db" in tool_lower:
return ActionType.DATABASE_QUERY
# Check specific types
if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
return ActionType.SHELL_COMMAND
if "git" in tool_lower:
return ActionType.GIT_OPERATION
if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
return ActionType.NETWORK_REQUEST
if "llm" in tool_lower or "ai" in tool_lower or "claude" in tool_lower:
return ActionType.LLM_CALL
# Default to tool call
return ActionType.TOOL_CALL
def _elapsed_ms(self, start_time: datetime) -> float:
"""Calculate elapsed time in milliseconds."""
return (datetime.utcnow() - start_time).total_seconds() * 1000
class SafeToolExecutor:
"""
Context manager for safe tool execution with automatic cleanup.
Usage:
async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
result = await executor.execute()
if result.success:
# Use result
else:
# Handle error or approval required
"""
def __init__(
self,
wrapper: MCPSafetyWrapper,
tool_call: MCPToolCall,
agent_id: str,
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
) -> None:
self._wrapper = wrapper
self._tool_call = tool_call
self._agent_id = agent_id
self._autonomy_level = autonomy_level
self._result: MCPToolResult | None = None
async def __aenter__(self) -> "SafeToolExecutor":
return self
async def __aexit__(
self,
exc_type: type[Exception] | None,
exc_val: Exception | None,
exc_tb: Any,
) -> bool:
# Could trigger rollback here if needed
return False
async def execute(self) -> MCPToolResult:
"""Execute the tool call."""
self._result = await self._wrapper.execute(
self._tool_call,
self._agent_id,
self._autonomy_level,
)
return self._result
@property
def result(self) -> MCPToolResult | None:
"""Get the execution result."""
return self._result
# Factory function
async def create_mcp_wrapper(
guardian: SafetyGuardian | None = None,
) -> MCPSafetyWrapper:
"""Create an MCPSafetyWrapper with default configuration."""
if guardian is None:
guardian = await get_safety_guardian()
return MCPSafetyWrapper(
guardian=guardian,
emergency_controls=await get_emergency_controls(),
)

View File

@@ -0,0 +1,19 @@
"""Safety metrics collection and export."""
from .collector import (
MetricType,
MetricValue,
SafetyMetrics,
get_safety_metrics,
record_mcp_call,
record_validation,
)
__all__ = [
"MetricType",
"MetricValue",
"SafetyMetrics",
"get_safety_metrics",
"record_mcp_call",
"record_validation",
]

View File

@@ -0,0 +1,416 @@
"""
Safety Metrics Collector
Collects and exposes metrics for the safety framework.
"""
import asyncio
import logging
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
logger = logging.getLogger(__name__)
class MetricType(str, Enum):
"""Types of metrics."""
COUNTER = "counter"
GAUGE = "gauge"
HISTOGRAM = "histogram"
@dataclass
class MetricValue:
"""A single metric value."""
name: str
metric_type: MetricType
value: float
labels: dict[str, str] = field(default_factory=dict)
timestamp: datetime = field(default_factory=datetime.utcnow)
@dataclass
class HistogramBucket:
"""Histogram bucket for distribution metrics."""
le: float # Less than or equal
count: int = 0
class SafetyMetrics:
"""
Collects safety framework metrics.
Metrics tracked:
- Action validation counts (by decision type)
- Approval request counts and latencies
- Budget usage and remaining
- Rate limit hits
- Loop detections
- Emergency events
- Content filter matches
"""
def __init__(self) -> None:
"""Initialize SafetyMetrics."""
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
self._histograms: dict[str, list[float]] = defaultdict(list)
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
self._lock = asyncio.Lock()
# Initialize histogram buckets
self._init_histogram_buckets()
def _init_histogram_buckets(self) -> None:
"""Initialize histogram buckets for latency metrics."""
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
for name in [
"validation_latency_seconds",
"approval_latency_seconds",
"mcp_execution_latency_seconds",
]:
self._histogram_buckets[name] = [
HistogramBucket(le=b) for b in latency_buckets
]
# Counter methods
async def inc_validations(
self,
decision: str,
agent_id: str | None = None,
) -> None:
"""Increment validation counter."""
async with self._lock:
labels = f"decision={decision}"
if agent_id:
labels += f",agent_id={agent_id}"
self._counters["safety_validations_total"][labels] += 1
async def inc_approvals_requested(self, urgency: str = "normal") -> None:
"""Increment approval requests counter."""
async with self._lock:
labels = f"urgency={urgency}"
self._counters["safety_approvals_requested_total"][labels] += 1
async def inc_approvals_granted(self) -> None:
"""Increment approvals granted counter."""
async with self._lock:
self._counters["safety_approvals_granted_total"][""] += 1
async def inc_approvals_denied(self, reason: str = "manual") -> None:
"""Increment approvals denied counter."""
async with self._lock:
labels = f"reason={reason}"
self._counters["safety_approvals_denied_total"][labels] += 1
async def inc_rate_limit_exceeded(self, limit_type: str) -> None:
"""Increment rate limit exceeded counter."""
async with self._lock:
labels = f"limit_type={limit_type}"
self._counters["safety_rate_limit_exceeded_total"][labels] += 1
async def inc_budget_exceeded(self, budget_type: str) -> None:
"""Increment budget exceeded counter."""
async with self._lock:
labels = f"budget_type={budget_type}"
self._counters["safety_budget_exceeded_total"][labels] += 1
async def inc_loops_detected(self, loop_type: str) -> None:
"""Increment loop detection counter."""
async with self._lock:
labels = f"loop_type={loop_type}"
self._counters["safety_loops_detected_total"][labels] += 1
async def inc_emergency_events(self, event_type: str, scope: str) -> None:
"""Increment emergency events counter."""
async with self._lock:
labels = f"event_type={event_type},scope={scope}"
self._counters["safety_emergency_events_total"][labels] += 1
async def inc_content_filtered(self, category: str, action: str) -> None:
"""Increment content filter counter."""
async with self._lock:
labels = f"category={category},action={action}"
self._counters["safety_content_filtered_total"][labels] += 1
async def inc_checkpoints_created(self) -> None:
"""Increment checkpoints created counter."""
async with self._lock:
self._counters["safety_checkpoints_created_total"][""] += 1
async def inc_rollbacks_executed(self, success: bool) -> None:
"""Increment rollbacks counter."""
async with self._lock:
labels = f"success={str(success).lower()}"
self._counters["safety_rollbacks_total"][labels] += 1
async def inc_mcp_calls(self, tool_name: str, success: bool) -> None:
"""Increment MCP tool calls counter."""
async with self._lock:
labels = f"tool_name={tool_name},success={str(success).lower()}"
self._counters["safety_mcp_calls_total"][labels] += 1
# Gauge methods
async def set_budget_remaining(
self,
scope: str,
budget_type: str,
remaining: float,
) -> None:
"""Set remaining budget gauge."""
async with self._lock:
labels = f"scope={scope},budget_type={budget_type}"
self._gauges["safety_budget_remaining"][labels] = remaining
async def set_rate_limit_remaining(
self,
scope: str,
limit_type: str,
remaining: int,
) -> None:
"""Set remaining rate limit gauge."""
async with self._lock:
labels = f"scope={scope},limit_type={limit_type}"
self._gauges["safety_rate_limit_remaining"][labels] = float(remaining)
async def set_pending_approvals(self, count: int) -> None:
"""Set pending approvals gauge."""
async with self._lock:
self._gauges["safety_pending_approvals"][""] = float(count)
async def set_active_checkpoints(self, count: int) -> None:
"""Set active checkpoints gauge."""
async with self._lock:
self._gauges["safety_active_checkpoints"][""] = float(count)
async def set_emergency_state(self, scope: str, state: str) -> None:
"""Set emergency state gauge (0=normal, 1=paused, 2=stopped)."""
async with self._lock:
state_value = {"normal": 0, "paused": 1, "stopped": 2}.get(state, -1)
labels = f"scope={scope}"
self._gauges["safety_emergency_state"][labels] = float(state_value)
# Histogram methods
async def observe_validation_latency(self, latency_seconds: float) -> None:
"""Observe validation latency."""
async with self._lock:
self._observe_histogram("validation_latency_seconds", latency_seconds)
async def observe_approval_latency(self, latency_seconds: float) -> None:
"""Observe approval latency."""
async with self._lock:
self._observe_histogram("approval_latency_seconds", latency_seconds)
async def observe_mcp_execution_latency(self, latency_seconds: float) -> None:
"""Observe MCP execution latency."""
async with self._lock:
self._observe_histogram("mcp_execution_latency_seconds", latency_seconds)
def _observe_histogram(self, name: str, value: float) -> None:
"""Record a value in a histogram."""
self._histograms[name].append(value)
# Update buckets
if name in self._histogram_buckets:
for bucket in self._histogram_buckets[name]:
if value <= bucket.le:
bucket.count += 1
# Export methods
async def get_all_metrics(self) -> list[MetricValue]:
"""Get all metrics as MetricValue objects."""
metrics: list[MetricValue] = []
async with self._lock:
# Export counters
for name, counter in self._counters.items():
for labels_str, value in counter.items():
labels = self._parse_labels(labels_str)
metrics.append(
MetricValue(
name=name,
metric_type=MetricType.COUNTER,
value=float(value),
labels=labels,
)
)
# Export gauges
for name, gauge_dict in self._gauges.items():
for labels_str, gauge_value in gauge_dict.items():
gauge_labels = self._parse_labels(labels_str)
metrics.append(
MetricValue(
name=name,
metric_type=MetricType.GAUGE,
value=gauge_value,
labels=gauge_labels,
)
)
# Export histogram summaries
for name, values in self._histograms.items():
if values:
metrics.append(
MetricValue(
name=f"{name}_count",
metric_type=MetricType.COUNTER,
value=float(len(values)),
)
)
metrics.append(
MetricValue(
name=f"{name}_sum",
metric_type=MetricType.COUNTER,
value=sum(values),
)
)
return metrics
async def get_prometheus_format(self) -> str:
"""Export metrics in Prometheus text format."""
lines: list[str] = []
async with self._lock:
# Export counters
for name, counter in self._counters.items():
lines.append(f"# TYPE {name} counter")
for labels_str, value in counter.items():
if labels_str:
lines.append(f"{name}{{{labels_str}}} {value}")
else:
lines.append(f"{name} {value}")
# Export gauges
for name, gauge_dict in self._gauges.items():
lines.append(f"# TYPE {name} gauge")
for labels_str, gauge_value in gauge_dict.items():
if labels_str:
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
else:
lines.append(f"{name} {gauge_value}")
# Export histograms
for name, buckets in self._histogram_buckets.items():
lines.append(f"# TYPE {name} histogram")
for bucket in buckets:
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
if name in self._histograms:
values = self._histograms[name]
lines.append(f"{name}_count {len(values)}")
lines.append(f"{name}_sum {sum(values)}")
return "\n".join(lines)
async def get_summary(self) -> dict[str, Any]:
"""Get a summary of key metrics."""
async with self._lock:
total_validations = sum(self._counters["safety_validations_total"].values())
denied_validations = sum(
v for k, v in self._counters["safety_validations_total"].items()
if "decision=deny" in k
)
return {
"total_validations": total_validations,
"denied_validations": denied_validations,
"approval_requests": sum(
self._counters["safety_approvals_requested_total"].values()
),
"approvals_granted": sum(
self._counters["safety_approvals_granted_total"].values()
),
"approvals_denied": sum(
self._counters["safety_approvals_denied_total"].values()
),
"rate_limit_hits": sum(
self._counters["safety_rate_limit_exceeded_total"].values()
),
"budget_exceeded": sum(
self._counters["safety_budget_exceeded_total"].values()
),
"loops_detected": sum(
self._counters["safety_loops_detected_total"].values()
),
"emergency_events": sum(
self._counters["safety_emergency_events_total"].values()
),
"content_filtered": sum(
self._counters["safety_content_filtered_total"].values()
),
"checkpoints_created": sum(
self._counters["safety_checkpoints_created_total"].values()
),
"rollbacks_executed": sum(
self._counters["safety_rollbacks_total"].values()
),
"mcp_calls": sum(
self._counters["safety_mcp_calls_total"].values()
),
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
}
async def reset(self) -> None:
"""Reset all metrics."""
async with self._lock:
self._counters.clear()
self._gauges.clear()
self._histograms.clear()
self._init_histogram_buckets()
def _parse_labels(self, labels_str: str) -> dict[str, str]:
"""Parse labels string into dictionary."""
if not labels_str:
return {}
labels = {}
for pair in labels_str.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
labels[key.strip()] = value.strip()
return labels
# Singleton instance
_metrics: SafetyMetrics | None = None
_lock = asyncio.Lock()
async def get_safety_metrics() -> SafetyMetrics:
"""Get the singleton SafetyMetrics instance."""
global _metrics
async with _lock:
if _metrics is None:
_metrics = SafetyMetrics()
return _metrics
# Convenience functions
async def record_validation(decision: str, agent_id: str | None = None) -> None:
"""Record a validation event."""
metrics = await get_safety_metrics()
await metrics.inc_validations(decision, agent_id)
async def record_mcp_call(tool_name: str, success: bool, latency_ms: float) -> None:
"""Record an MCP tool call."""
metrics = await get_safety_metrics()
await metrics.inc_mcp_calls(tool_name, success)
await metrics.observe_mcp_execution_latency(latency_ms / 1000)