forked from cardosofelipe/fast-next-template
feat(safety): add Phase D MCP integration and metrics
- Add MCPSafetyWrapper for safe MCP tool execution - Add MCPToolCall/MCPToolResult models for MCP interactions - Add SafeToolExecutor context manager - Add SafetyMetrics collector with Prometheus export support - Track validations, approvals, rate limits, budgets, and more - Support for counters, gauges, and histograms Issue #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
17
backend/app/services/safety/mcp/__init__.py
Normal file
17
backend/app/services/safety/mcp/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""MCP safety integration."""
|
||||
|
||||
from .integration import (
|
||||
MCPSafetyWrapper,
|
||||
MCPToolCall,
|
||||
MCPToolResult,
|
||||
SafeToolExecutor,
|
||||
create_mcp_wrapper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MCPSafetyWrapper",
|
||||
"MCPToolCall",
|
||||
"MCPToolResult",
|
||||
"SafeToolExecutor",
|
||||
"create_mcp_wrapper",
|
||||
]
|
||||
405
backend/app/services/safety/mcp/integration.py
Normal file
405
backend/app/services/safety/mcp/integration.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
MCP Safety Integration
|
||||
|
||||
Provides safety-aware wrappers for MCP tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, ClassVar, TypeVar
|
||||
|
||||
from ..audit import AuditLogger
|
||||
from ..emergency import EmergencyControls, get_emergency_controls
|
||||
from ..exceptions import (
|
||||
EmergencyStopError,
|
||||
SafetyError,
|
||||
)
|
||||
from ..guardian import SafetyGuardian, get_safety_guardian
|
||||
from ..models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPToolCall:
|
||||
"""Represents an MCP tool call."""
|
||||
|
||||
tool_name: str
|
||||
arguments: dict[str, Any]
|
||||
server_name: str | None = None
|
||||
project_id: str | None = None
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class MCPToolResult:
    """
    Outcome of an MCP tool execution attempt.

    Attributes:
        success: Whether the tool ran and completed without error.
        result: Value returned by the tool handler, if any.
        error: Human-readable failure description, if any.
        safety_decision: Safety decision associated with this outcome.
        execution_time_ms: Elapsed time in milliseconds.
        approval_id: Identifier of a pending approval, if one was requested.
        checkpoint_id: Identifier of an associated checkpoint, if any.
        metadata: Free-form extra data; each instance gets its own dict.
    """

    # Only `success` is mandatory; everything else has a neutral default.
    success: bool

    # Payload / failure description
    result: Any = None
    error: str | None = None

    # Safety bookkeeping
    safety_decision: SafetyDecision = SafetyDecision.ALLOW
    execution_time_ms: float = 0.0
    approval_id: str | None = None
    checkpoint_id: str | None = None

    # default_factory avoids sharing one dict between instances
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class MCPSafetyWrapper:
    """
    Wraps MCP tool execution with safety checks.

    ``execute`` runs, in order:

    1. an emergency-stop check for the call's scope,
    2. validation of the call by the SafetyGuardian (skipped when
       ``bypass_safety`` is set),
    3. the handler registered for the tool.

    ``EmergencyStopError``, ``SafetyError`` and exceptions raised by the tool
    handler are converted into a failed ``MCPToolResult``; other exceptions
    raised by the guardian or emergency controls propagate to the caller.

    NOTE(review): ``audit_logger`` and ``_lock`` are stored but not used by
    any method in this class, and ``DESTRUCTIVE_TOOLS`` / ``READ_ONLY_TOOLS``
    are not consulted by the classifier — confirm whether other modules rely
    on them.
    """

    # Tool categories for automatic classification
    DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
        "file_write",
        "file_delete",
        "database_mutate",
        "shell_execute",
        "git_push",
        "git_commit",
        "deploy",
    }

    READ_ONLY_TOOLS: ClassVar[set[str]] = {
        "file_read",
        "database_query",
        "git_status",
        "git_log",
        "list_files",
        "search",
    }

    # Substrings that mark a tool name as mutating / reading its target.
    # "mutate" is included so that "database_mutate" (see DESTRUCTIVE_TOOLS)
    # is classified as a database mutation instead of a generic tool call.
    _MUTATING_KEYWORDS: ClassVar[tuple[str, ...]] = (
        "write",
        "create",
        "delete",
        "remove",
        "update",
        "mutate",
    )
    _READING_KEYWORDS: ClassVar[tuple[str, ...]] = (
        "read",
        "get",
        "list",
        "search",
        "query",
    )

    def __init__(
        self,
        guardian: SafetyGuardian | None = None,
        audit_logger: AuditLogger | None = None,
        emergency_controls: EmergencyControls | None = None,
    ) -> None:
        """
        Initialize MCPSafetyWrapper.

        Args:
            guardian: SafetyGuardian instance; when omitted it is obtained
                lazily from ``get_safety_guardian()`` on first use.
            audit_logger: AuditLogger instance (stored only).
            emergency_controls: EmergencyControls instance; when omitted it is
                obtained lazily from ``get_emergency_controls()``.
        """
        self._guardian = guardian
        self._audit_logger = audit_logger
        self._emergency_controls = emergency_controls
        self._tool_handlers: dict[str, Callable[..., Any]] = {}
        self._lock = asyncio.Lock()

    async def _get_guardian(self) -> SafetyGuardian:
        """Return the configured SafetyGuardian, fetching it on first use."""
        if self._guardian is None:
            self._guardian = await get_safety_guardian()
        return self._guardian

    async def _get_emergency_controls(self) -> EmergencyControls:
        """Return the configured EmergencyControls, fetching it on first use."""
        if self._emergency_controls is None:
            self._emergency_controls = await get_emergency_controls()
        return self._emergency_controls

    def register_tool_handler(
        self,
        tool_name: str,
        handler: Callable[..., Any],
    ) -> None:
        """
        Register a handler for a tool, replacing any previous one.

        Args:
            tool_name: Name of the tool
            handler: Callable invoked with the tool call's arguments as
                keyword arguments. It may be synchronous or return an
                awaitable (async function, partial, callable object, ...).
        """
        self._tool_handlers[tool_name] = handler
        logger.debug("Registered handler for tool: %s", tool_name)

    async def execute(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
        bypass_safety: bool = False,
    ) -> MCPToolResult:
        """
        Execute an MCP tool call with safety checks.

        Args:
            tool_call: The tool call to execute
            agent_id: ID of the calling agent
            autonomy_level: Agent's autonomy level
            bypass_safety: Skip guardian validation (emergency only). The
                emergency-stop check is still performed.

        Returns:
            MCPToolResult with execution outcome. ``success`` is False when
            the call was blocked, denied, needs approval, has no registered
            handler, or the handler raised.
        """
        start_time = datetime.now(timezone.utc)

        # Check emergency controls first. A project id, when present, takes
        # precedence over the agent id for scoping.
        emergency = await self._get_emergency_controls()
        scope = f"agent:{agent_id}"
        if tool_call.project_id:
            scope = f"project:{tool_call.project_id}"

        try:
            await emergency.check_allowed(scope=scope, raise_if_blocked=True)
        except EmergencyStopError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
                metadata={"emergency_stop": True},
            )

        # Build action request
        action = self._build_action_request(
            tool_call=tool_call,
            agent_id=agent_id,
            autonomy_level=autonomy_level,
        )

        # Skip safety checks if bypass is enabled
        if bypass_safety:
            logger.warning(
                "Safety bypass enabled for tool: %s (agent: %s)",
                tool_call.tool_name,
                agent_id,
            )
            return await self._execute_tool(tool_call, action, start_time)

        # Run safety validation
        guardian = await self._get_guardian()
        try:
            guardian_result = await guardian.validate(action)
        except SafetyError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Handle safety decision
        if guardian_result.decision == SafetyDecision.DENY:
            return MCPToolResult(
                success=False,
                error="; ".join(guardian_result.reasons),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
            # For now, just return that approval is required.
            # The caller should handle the approval flow.
            return MCPToolResult(
                success=False,
                error="Action requires human approval",
                safety_decision=SafetyDecision.REQUIRE_APPROVAL,
                approval_id=guardian_result.approval_id,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Execute the tool
        return await self._execute_tool(
            tool_call,
            action,
            start_time,
            checkpoint_id=guardian_result.checkpoint_id,
        )

    async def _execute_tool(
        self,
        tool_call: MCPToolCall,
        action: ActionRequest,
        start_time: datetime,
        checkpoint_id: str | None = None,
    ) -> MCPToolResult:
        """
        Invoke the registered handler for ``tool_call``.

        Args:
            tool_call: Call whose ``arguments`` are passed as keyword args.
            action: The ActionRequest built for this call (currently unused
                here).
            start_time: When ``execute`` started, used for timing.
            checkpoint_id: Checkpoint id to attach to the result, if any.

        Returns:
            A successful result carrying the handler's return value, or a
            failed result when no handler is registered or the handler raised.
        """
        handler = self._tool_handlers.get(tool_call.tool_name)

        if handler is None:
            return MCPToolResult(
                success=False,
                error=f"No handler registered for tool: {tool_call.tool_name}",
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        try:
            result = handler(**tool_call.arguments)
            # Inspect the returned value rather than the handler itself:
            # partials, callable objects and lambdas wrapping async code are
            # not "coroutine functions" but still return something that must
            # be awaited.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            logger.exception("Tool execution failed: %s - %s", tool_call.tool_name, e)
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )

        return MCPToolResult(
            success=True,
            result=result,
            safety_decision=SafetyDecision.ALLOW,
            execution_time_ms=self._elapsed_ms(start_time),
            checkpoint_id=checkpoint_id,
        )

    def _build_action_request(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel,
    ) -> ActionRequest:
        """
        Build an ActionRequest from an MCP tool call.

        The resource is taken from the ``path`` argument, falling back to
        ``resource``; missing session/project ids become empty strings.
        """
        action_type = self._classify_tool(tool_call.tool_name)

        metadata = ActionMetadata(
            agent_id=agent_id,
            session_id=tool_call.context.get("session_id", ""),
            project_id=tool_call.project_id or "",
            autonomy_level=autonomy_level,
        )

        return ActionRequest(
            action_type=action_type,
            tool_name=tool_call.tool_name,
            arguments=tool_call.arguments,
            resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
            metadata=metadata,
        )

    def _classify_tool(self, tool_name: str) -> ActionType:
        """
        Classify a tool into an action type from keywords in its name.

        Matching is case-insensitive and checked in this order: file/database
        mutations, file/database reads, shell, git, network, LLM. Names that
        match nothing are classified as a generic ``TOOL_CALL``.
        """
        tool_lower = tool_name.lower()
        targets_file = "file" in tool_lower
        targets_db = "database" in tool_lower or "db" in tool_lower

        # Check destructive patterns
        if any(keyword in tool_lower for keyword in self._MUTATING_KEYWORDS):
            if targets_file:
                if "delete" in tool_lower or "remove" in tool_lower:
                    return ActionType.FILE_DELETE
                return ActionType.FILE_WRITE
            if targets_db:
                return ActionType.DATABASE_MUTATE

        # Check read patterns
        if any(keyword in tool_lower for keyword in self._READING_KEYWORDS):
            if targets_file:
                return ActionType.FILE_READ
            if targets_db:
                return ActionType.DATABASE_QUERY

        # Check specific types
        if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
            return ActionType.SHELL_COMMAND

        if "git" in tool_lower:
            return ActionType.GIT_OPERATION

        if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
            return ActionType.NETWORK_REQUEST

        # "ai" is matched per word (the word is "ai" or ends with it, e.g.
        # "openai"). A plain substring test would also hit unrelated names
        # such as "send_email" or "get_domain".
        words = "".join(ch if ch.isalnum() else " " for ch in tool_lower).split()
        if (
            "llm" in tool_lower
            or "claude" in tool_lower
            or any(word.endswith("ai") for word in words)
        ):
            return ActionType.LLM_CALL

        # Default to tool call
        return ActionType.TOOL_CALL

    def _elapsed_ms(self, start_time: datetime) -> float:
        """
        Calculate elapsed time since ``start_time`` in milliseconds.

        A naive ``start_time`` is interpreted as UTC so that values produced
        by ``datetime.utcnow()`` remain accepted.
        """
        if start_time.tzinfo is None:
            start_time = start_time.replace(tzinfo=timezone.utc)
        return (datetime.now(timezone.utc) - start_time).total_seconds() * 1000
|
||||
|
||||
|
||||
class SafeToolExecutor:
    """
    Async context manager that runs one tool call through an MCPSafetyWrapper.

    Entering the context returns the executor itself; leaving it performs no
    cleanup and never suppresses exceptions.

    Usage:
        async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
            result = await executor.execute()
            if result.success:
                # Use result
            else:
                # Handle error or approval required
    """

    def __init__(
        self,
        wrapper: MCPSafetyWrapper,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
    ) -> None:
        """
        Store everything needed to execute the call later.

        Args:
            wrapper: Safety wrapper that performs the execution.
            tool_call: The tool call to execute.
            agent_id: ID of the calling agent.
            autonomy_level: Agent's autonomy level.
        """
        self._wrapper = wrapper
        self._tool_call = tool_call
        self._agent_id = agent_id
        self._autonomy_level = autonomy_level
        self._result: MCPToolResult | None = None

    async def __aenter__(self) -> "SafeToolExecutor":
        """Return the executor itself; no setup is performed."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        """
        Leave the context without suppressing any exception.

        The parameters are typed as ``BaseException`` because an async
        context manager is also exited with non-``Exception`` errors such as
        ``asyncio.CancelledError``.
        """
        # Could trigger rollback here if needed
        return False

    async def execute(self) -> MCPToolResult:
        """Execute the tool call via the wrapper and remember the result."""
        self._result = await self._wrapper.execute(
            self._tool_call,
            self._agent_id,
            self._autonomy_level,
        )
        return self._result

    @property
    def result(self) -> MCPToolResult | None:
        """Result of the last ``execute`` call, or None if it has not run."""
        return self._result
|
||||
|
||||
|
||||
# Factory function
|
||||
async def create_mcp_wrapper(
    guardian: SafetyGuardian | None = None,
) -> MCPSafetyWrapper:
    """
    Build an MCPSafetyWrapper wired to the default collaborators.

    Args:
        guardian: Guardian to use; when omitted, the one returned by
            ``get_safety_guardian()`` is used.

    Returns:
        A wrapper configured with the guardian and the emergency controls
        returned by ``get_emergency_controls()``.
    """
    # Resolve the guardian first, then the emergency controls.
    resolved_guardian = guardian if guardian is not None else await get_safety_guardian()
    controls = await get_emergency_controls()
    return MCPSafetyWrapper(guardian=resolved_guardian, emergency_controls=controls)
|
||||
19
backend/app/services/safety/metrics/__init__.py
Normal file
19
backend/app/services/safety/metrics/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Safety metrics collection and export."""
|
||||
|
||||
from .collector import (
|
||||
MetricType,
|
||||
MetricValue,
|
||||
SafetyMetrics,
|
||||
get_safety_metrics,
|
||||
record_mcp_call,
|
||||
record_validation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MetricType",
|
||||
"MetricValue",
|
||||
"SafetyMetrics",
|
||||
"get_safety_metrics",
|
||||
"record_mcp_call",
|
||||
"record_validation",
|
||||
]
|
||||
416
backend/app/services/safety/metrics/collector.py
Normal file
416
backend/app/services/safety/metrics/collector.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Safety Metrics Collector
|
||||
|
||||
Collects and exposes metrics for the safety framework.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(str, Enum):
    """Kinds of metric the collector can report; members are also strings."""

    # Monotonically incremented count
    COUNTER = "counter"
    # Value that is set directly
    GAUGE = "gauge"
    # Distribution of observed values
    HISTOGRAM = "histogram"
|
||||
|
||||
|
||||
@dataclass
class MetricValue:
    """
    One exported metric sample.

    Attributes:
        name: Metric name.
        metric_type: Kind of metric this sample belongs to.
        value: Numeric value of the sample.
        labels: Label name/value pairs; each instance gets its own dict.
        timestamp: Creation time, taken from ``datetime.utcnow()`` (a naive
            datetime) when not supplied.
    """

    # Identity and value of the sample
    name: str
    metric_type: MetricType
    value: float

    # Per-instance defaults
    labels: dict[str, str] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@dataclass
class HistogramBucket:
    """
    One bucket of a histogram.

    Attributes:
        le: Upper bound of the bucket ("less than or equal").
        count: Number of observations recorded in this bucket.
    """

    # Upper bound, inclusive
    le: float
    # Observations counted so far
    count: int = 0
|
||||
|
||||
|
||||
class SafetyMetrics:
|
||||
"""
|
||||
Collects safety framework metrics.
|
||||
|
||||
Metrics tracked:
|
||||
- Action validation counts (by decision type)
|
||||
- Approval request counts and latencies
|
||||
- Budget usage and remaining
|
||||
- Rate limit hits
|
||||
- Loop detections
|
||||
- Emergency events
|
||||
- Content filter matches
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize SafetyMetrics."""
|
||||
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
|
||||
self._histograms: dict[str, list[float]] = defaultdict(list)
|
||||
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Initialize histogram buckets
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _init_histogram_buckets(self) -> None:
|
||||
"""Initialize histogram buckets for latency metrics."""
|
||||
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
|
||||
|
||||
for name in [
|
||||
"validation_latency_seconds",
|
||||
"approval_latency_seconds",
|
||||
"mcp_execution_latency_seconds",
|
||||
]:
|
||||
self._histogram_buckets[name] = [
|
||||
HistogramBucket(le=b) for b in latency_buckets
|
||||
]
|
||||
|
||||
# Counter methods
|
||||
|
||||
async def inc_validations(
|
||||
self,
|
||||
decision: str,
|
||||
agent_id: str | None = None,
|
||||
) -> None:
|
||||
"""Increment validation counter."""
|
||||
async with self._lock:
|
||||
labels = f"decision={decision}"
|
||||
if agent_id:
|
||||
labels += f",agent_id={agent_id}"
|
||||
self._counters["safety_validations_total"][labels] += 1
|
||||
|
||||
async def inc_approvals_requested(self, urgency: str = "normal") -> None:
|
||||
"""Increment approval requests counter."""
|
||||
async with self._lock:
|
||||
labels = f"urgency={urgency}"
|
||||
self._counters["safety_approvals_requested_total"][labels] += 1
|
||||
|
||||
async def inc_approvals_granted(self) -> None:
|
||||
"""Increment approvals granted counter."""
|
||||
async with self._lock:
|
||||
self._counters["safety_approvals_granted_total"][""] += 1
|
||||
|
||||
async def inc_approvals_denied(self, reason: str = "manual") -> None:
|
||||
"""Increment approvals denied counter."""
|
||||
async with self._lock:
|
||||
labels = f"reason={reason}"
|
||||
self._counters["safety_approvals_denied_total"][labels] += 1
|
||||
|
||||
async def inc_rate_limit_exceeded(self, limit_type: str) -> None:
|
||||
"""Increment rate limit exceeded counter."""
|
||||
async with self._lock:
|
||||
labels = f"limit_type={limit_type}"
|
||||
self._counters["safety_rate_limit_exceeded_total"][labels] += 1
|
||||
|
||||
async def inc_budget_exceeded(self, budget_type: str) -> None:
|
||||
"""Increment budget exceeded counter."""
|
||||
async with self._lock:
|
||||
labels = f"budget_type={budget_type}"
|
||||
self._counters["safety_budget_exceeded_total"][labels] += 1
|
||||
|
||||
async def inc_loops_detected(self, loop_type: str) -> None:
|
||||
"""Increment loop detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"loop_type={loop_type}"
|
||||
self._counters["safety_loops_detected_total"][labels] += 1
|
||||
|
||||
async def inc_emergency_events(self, event_type: str, scope: str) -> None:
|
||||
"""Increment emergency events counter."""
|
||||
async with self._lock:
|
||||
labels = f"event_type={event_type},scope={scope}"
|
||||
self._counters["safety_emergency_events_total"][labels] += 1
|
||||
|
||||
async def inc_content_filtered(self, category: str, action: str) -> None:
|
||||
"""Increment content filter counter."""
|
||||
async with self._lock:
|
||||
labels = f"category={category},action={action}"
|
||||
self._counters["safety_content_filtered_total"][labels] += 1
|
||||
|
||||
async def inc_checkpoints_created(self) -> None:
|
||||
"""Increment checkpoints created counter."""
|
||||
async with self._lock:
|
||||
self._counters["safety_checkpoints_created_total"][""] += 1
|
||||
|
||||
async def inc_rollbacks_executed(self, success: bool) -> None:
|
||||
"""Increment rollbacks counter."""
|
||||
async with self._lock:
|
||||
labels = f"success={str(success).lower()}"
|
||||
self._counters["safety_rollbacks_total"][labels] += 1
|
||||
|
||||
async def inc_mcp_calls(self, tool_name: str, success: bool) -> None:
|
||||
"""Increment MCP tool calls counter."""
|
||||
async with self._lock:
|
||||
labels = f"tool_name={tool_name},success={str(success).lower()}"
|
||||
self._counters["safety_mcp_calls_total"][labels] += 1
|
||||
|
||||
# Gauge methods
|
||||
|
||||
async def set_budget_remaining(
|
||||
self,
|
||||
scope: str,
|
||||
budget_type: str,
|
||||
remaining: float,
|
||||
) -> None:
|
||||
"""Set remaining budget gauge."""
|
||||
async with self._lock:
|
||||
labels = f"scope={scope},budget_type={budget_type}"
|
||||
self._gauges["safety_budget_remaining"][labels] = remaining
|
||||
|
||||
async def set_rate_limit_remaining(
|
||||
self,
|
||||
scope: str,
|
||||
limit_type: str,
|
||||
remaining: int,
|
||||
) -> None:
|
||||
"""Set remaining rate limit gauge."""
|
||||
async with self._lock:
|
||||
labels = f"scope={scope},limit_type={limit_type}"
|
||||
self._gauges["safety_rate_limit_remaining"][labels] = float(remaining)
|
||||
|
||||
async def set_pending_approvals(self, count: int) -> None:
|
||||
"""Set pending approvals gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["safety_pending_approvals"][""] = float(count)
|
||||
|
||||
async def set_active_checkpoints(self, count: int) -> None:
|
||||
"""Set active checkpoints gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["safety_active_checkpoints"][""] = float(count)
|
||||
|
||||
async def set_emergency_state(self, scope: str, state: str) -> None:
|
||||
"""Set emergency state gauge (0=normal, 1=paused, 2=stopped)."""
|
||||
async with self._lock:
|
||||
state_value = {"normal": 0, "paused": 1, "stopped": 2}.get(state, -1)
|
||||
labels = f"scope={scope}"
|
||||
self._gauges["safety_emergency_state"][labels] = float(state_value)
|
||||
|
||||
# Histogram methods
|
||||
|
||||
async def observe_validation_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe validation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("validation_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_approval_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe approval latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("approval_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_mcp_execution_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe MCP execution latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("mcp_execution_latency_seconds", latency_seconds)
|
||||
|
||||
def _observe_histogram(self, name: str, value: float) -> None:
|
||||
"""Record a value in a histogram."""
|
||||
self._histograms[name].append(value)
|
||||
|
||||
# Update buckets
|
||||
if name in self._histogram_buckets:
|
||||
for bucket in self._histogram_buckets[name]:
|
||||
if value <= bucket.le:
|
||||
bucket.count += 1
|
||||
|
||||
# Export methods
|
||||
|
||||
async def get_all_metrics(self) -> list[MetricValue]:
|
||||
"""Get all metrics as MetricValue objects."""
|
||||
metrics: list[MetricValue] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
for labels_str, value in counter.items():
|
||||
labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(value),
|
||||
labels=labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
gauge_labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.GAUGE,
|
||||
value=gauge_value,
|
||||
labels=gauge_labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export histogram summaries
|
||||
for name, values in self._histograms.items():
|
||||
if values:
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_count",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(len(values)),
|
||||
)
|
||||
)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_sum",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=sum(values),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
async def get_prometheus_format(self) -> str:
|
||||
"""Export metrics in Prometheus text format."""
|
||||
lines: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
lines.append(f"# TYPE {name} counter")
|
||||
for labels_str, value in counter.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {value}")
|
||||
else:
|
||||
lines.append(f"{name} {value}")
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
lines.append(f"# TYPE {name} gauge")
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
|
||||
else:
|
||||
lines.append(f"{name} {gauge_value}")
|
||||
|
||||
# Export histograms
|
||||
for name, buckets in self._histogram_buckets.items():
|
||||
lines.append(f"# TYPE {name} histogram")
|
||||
for bucket in buckets:
|
||||
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
|
||||
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
|
||||
|
||||
if name in self._histograms:
|
||||
values = self._histograms[name]
|
||||
lines.append(f"{name}_count {len(values)}")
|
||||
lines.append(f"{name}_sum {sum(values)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def get_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of key metrics."""
|
||||
async with self._lock:
|
||||
total_validations = sum(self._counters["safety_validations_total"].values())
|
||||
denied_validations = sum(
|
||||
v for k, v in self._counters["safety_validations_total"].items()
|
||||
if "decision=deny" in k
|
||||
)
|
||||
|
||||
return {
|
||||
"total_validations": total_validations,
|
||||
"denied_validations": denied_validations,
|
||||
"approval_requests": sum(
|
||||
self._counters["safety_approvals_requested_total"].values()
|
||||
),
|
||||
"approvals_granted": sum(
|
||||
self._counters["safety_approvals_granted_total"].values()
|
||||
),
|
||||
"approvals_denied": sum(
|
||||
self._counters["safety_approvals_denied_total"].values()
|
||||
),
|
||||
"rate_limit_hits": sum(
|
||||
self._counters["safety_rate_limit_exceeded_total"].values()
|
||||
),
|
||||
"budget_exceeded": sum(
|
||||
self._counters["safety_budget_exceeded_total"].values()
|
||||
),
|
||||
"loops_detected": sum(
|
||||
self._counters["safety_loops_detected_total"].values()
|
||||
),
|
||||
"emergency_events": sum(
|
||||
self._counters["safety_emergency_events_total"].values()
|
||||
),
|
||||
"content_filtered": sum(
|
||||
self._counters["safety_content_filtered_total"].values()
|
||||
),
|
||||
"checkpoints_created": sum(
|
||||
self._counters["safety_checkpoints_created_total"].values()
|
||||
),
|
||||
"rollbacks_executed": sum(
|
||||
self._counters["safety_rollbacks_total"].values()
|
||||
),
|
||||
"mcp_calls": sum(
|
||||
self._counters["safety_mcp_calls_total"].values()
|
||||
),
|
||||
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
|
||||
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
|
||||
}
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset all metrics."""
|
||||
async with self._lock:
|
||||
self._counters.clear()
|
||||
self._gauges.clear()
|
||||
self._histograms.clear()
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _parse_labels(self, labels_str: str) -> dict[str, str]:
|
||||
"""Parse labels string into dictionary."""
|
||||
if not labels_str:
|
||||
return {}
|
||||
|
||||
labels = {}
|
||||
for pair in labels_str.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
labels[key.strip()] = value.strip()
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
# Singleton instance
|
||||
# Process-wide collector instance, created on first request.
_metrics: SafetyMetrics | None = None
# Serializes creation of the shared instance.
_lock = asyncio.Lock()


async def get_safety_metrics() -> SafetyMetrics:
    """
    Return the shared SafetyMetrics instance.

    The instance is created on the first call; later calls return the same
    object. The module-level lock is held while checking and creating it.
    """
    global _metrics

    async with _lock:
        instance = _metrics
        if instance is None:
            instance = SafetyMetrics()
            _metrics = instance
        return instance
|
||||
|
||||
|
||||
# Convenience functions
|
||||
async def record_validation(decision: str, agent_id: str | None = None) -> None:
    """
    Count one validation on the shared metrics collector.

    Args:
        decision: Decision label to record.
        agent_id: Optional agent label.
    """
    collector = await get_safety_metrics()
    await collector.inc_validations(decision, agent_id)
|
||||
|
||||
|
||||
async def record_mcp_call(tool_name: str, success: bool, latency_ms: float) -> None:
    """
    Record one MCP tool call on the shared metrics collector.

    Increments the call counter for ``tool_name``/``success`` and observes
    the latency, converted from milliseconds to seconds.
    """
    collector = await get_safety_metrics()
    latency_seconds = latency_ms / 1000
    await collector.inc_mcp_calls(tool_name, success)
    await collector.observe_mcp_execution_latency(latency_seconds)
|
||||
Reference in New Issue
Block a user