forked from cardosofelipe/fast-next-template
feat(safety): add Phase D MCP integration and metrics
- Add MCPSafetyWrapper for safe MCP tool execution - Add MCPToolCall/MCPToolResult models for MCP interactions - Add SafeToolExecutor context manager - Add SafetyMetrics collector with Prometheus export support - Track validations, approvals, rate limits, budgets, and more - Support for counters, gauges, and histograms Issue #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
17
backend/app/services/safety/mcp/__init__.py
Normal file
17
backend/app/services/safety/mcp/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""MCP safety integration."""
|
||||||
|
|
||||||
|
from .integration import (
|
||||||
|
MCPSafetyWrapper,
|
||||||
|
MCPToolCall,
|
||||||
|
MCPToolResult,
|
||||||
|
SafeToolExecutor,
|
||||||
|
create_mcp_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MCPSafetyWrapper",
|
||||||
|
"MCPToolCall",
|
||||||
|
"MCPToolResult",
|
||||||
|
"SafeToolExecutor",
|
||||||
|
"create_mcp_wrapper",
|
||||||
|
]
|
||||||
405
backend/app/services/safety/mcp/integration.py
Normal file
405
backend/app/services/safety/mcp/integration.py
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
"""
|
||||||
|
MCP Safety Integration
|
||||||
|
|
||||||
|
Provides safety-aware wrappers for MCP tool execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, ClassVar, TypeVar
|
||||||
|
|
||||||
|
from ..audit import AuditLogger
|
||||||
|
from ..emergency import EmergencyControls, get_emergency_controls
|
||||||
|
from ..exceptions import (
|
||||||
|
EmergencyStopError,
|
||||||
|
SafetyError,
|
||||||
|
)
|
||||||
|
from ..guardian import SafetyGuardian, get_safety_guardian
|
||||||
|
from ..models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
AutonomyLevel,
|
||||||
|
SafetyDecision,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPToolCall:
|
||||||
|
"""Represents an MCP tool call."""
|
||||||
|
|
||||||
|
tool_name: str
|
||||||
|
arguments: dict[str, Any]
|
||||||
|
server_name: str | None = None
|
||||||
|
project_id: str | None = None
|
||||||
|
context: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPToolResult:
|
||||||
|
"""Result of an MCP tool execution."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
result: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
safety_decision: SafetyDecision = SafetyDecision.ALLOW
|
||||||
|
execution_time_ms: float = 0.0
|
||||||
|
approval_id: str | None = None
|
||||||
|
checkpoint_id: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSafetyWrapper:
|
||||||
|
"""
|
||||||
|
Wraps MCP tool execution with safety checks.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Pre-execution validation via SafetyGuardian
|
||||||
|
- Permission checking per tool/resource
|
||||||
|
- Budget and rate limit enforcement
|
||||||
|
- Audit logging of all MCP calls
|
||||||
|
- Emergency stop integration
|
||||||
|
- Checkpoint creation for destructive operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Tool categories for automatic classification
|
||||||
|
DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
|
||||||
|
"file_write",
|
||||||
|
"file_delete",
|
||||||
|
"database_mutate",
|
||||||
|
"shell_execute",
|
||||||
|
"git_push",
|
||||||
|
"git_commit",
|
||||||
|
"deploy",
|
||||||
|
}
|
||||||
|
|
||||||
|
READ_ONLY_TOOLS: ClassVar[set[str]] = {
|
||||||
|
"file_read",
|
||||||
|
"database_query",
|
||||||
|
"git_status",
|
||||||
|
"git_log",
|
||||||
|
"list_files",
|
||||||
|
"search",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
guardian: SafetyGuardian | None = None,
|
||||||
|
audit_logger: AuditLogger | None = None,
|
||||||
|
emergency_controls: EmergencyControls | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize MCPSafetyWrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
guardian: SafetyGuardian instance (uses singleton if not provided)
|
||||||
|
audit_logger: AuditLogger instance
|
||||||
|
emergency_controls: EmergencyControls instance
|
||||||
|
"""
|
||||||
|
self._guardian = guardian
|
||||||
|
self._audit_logger = audit_logger
|
||||||
|
self._emergency_controls = emergency_controls
|
||||||
|
self._tool_handlers: dict[str, Callable[..., Any]] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def _get_guardian(self) -> SafetyGuardian:
|
||||||
|
"""Get or create SafetyGuardian."""
|
||||||
|
if self._guardian is None:
|
||||||
|
self._guardian = await get_safety_guardian()
|
||||||
|
return self._guardian
|
||||||
|
|
||||||
|
async def _get_emergency_controls(self) -> EmergencyControls:
|
||||||
|
"""Get or create EmergencyControls."""
|
||||||
|
if self._emergency_controls is None:
|
||||||
|
self._emergency_controls = await get_emergency_controls()
|
||||||
|
return self._emergency_controls
|
||||||
|
|
||||||
|
def register_tool_handler(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
handler: Callable[..., Any],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a handler for a tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
handler: Async function to handle the tool call
|
||||||
|
"""
|
||||||
|
self._tool_handlers[tool_name] = handler
|
||||||
|
logger.debug("Registered handler for tool: %s", tool_name)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
tool_call: MCPToolCall,
|
||||||
|
agent_id: str,
|
||||||
|
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
|
||||||
|
bypass_safety: bool = False,
|
||||||
|
) -> MCPToolResult:
|
||||||
|
"""
|
||||||
|
Execute an MCP tool call with safety checks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call: The tool call to execute
|
||||||
|
agent_id: ID of the calling agent
|
||||||
|
autonomy_level: Agent's autonomy level
|
||||||
|
bypass_safety: Bypass safety checks (emergency only)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MCPToolResult with execution outcome
|
||||||
|
"""
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
|
# Check emergency controls first
|
||||||
|
emergency = await self._get_emergency_controls()
|
||||||
|
scope = f"agent:{agent_id}"
|
||||||
|
if tool_call.project_id:
|
||||||
|
scope = f"project:{tool_call.project_id}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await emergency.check_allowed(scope=scope, raise_if_blocked=True)
|
||||||
|
except EmergencyStopError as e:
|
||||||
|
return MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
safety_decision=SafetyDecision.DENY,
|
||||||
|
metadata={"emergency_stop": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build action request
|
||||||
|
action = self._build_action_request(
|
||||||
|
tool_call=tool_call,
|
||||||
|
agent_id=agent_id,
|
||||||
|
autonomy_level=autonomy_level,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip safety checks if bypass is enabled
|
||||||
|
if bypass_safety:
|
||||||
|
logger.warning(
|
||||||
|
"Safety bypass enabled for tool: %s (agent: %s)",
|
||||||
|
tool_call.tool_name,
|
||||||
|
agent_id,
|
||||||
|
)
|
||||||
|
return await self._execute_tool(tool_call, action, start_time)
|
||||||
|
|
||||||
|
# Run safety validation
|
||||||
|
guardian = await self._get_guardian()
|
||||||
|
try:
|
||||||
|
guardian_result = await guardian.validate(action)
|
||||||
|
except SafetyError as e:
|
||||||
|
return MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
safety_decision=SafetyDecision.DENY,
|
||||||
|
execution_time_ms=self._elapsed_ms(start_time),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle safety decision
|
||||||
|
if guardian_result.decision == SafetyDecision.DENY:
|
||||||
|
return MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error="; ".join(guardian_result.reasons),
|
||||||
|
safety_decision=SafetyDecision.DENY,
|
||||||
|
execution_time_ms=self._elapsed_ms(start_time),
|
||||||
|
)
|
||||||
|
|
||||||
|
if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
|
||||||
|
# For now, just return that approval is required
|
||||||
|
# The caller should handle the approval flow
|
||||||
|
return MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error="Action requires human approval",
|
||||||
|
safety_decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||||
|
approval_id=guardian_result.approval_id,
|
||||||
|
execution_time_ms=self._elapsed_ms(start_time),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the tool
|
||||||
|
result = await self._execute_tool(
|
||||||
|
tool_call,
|
||||||
|
action,
|
||||||
|
start_time,
|
||||||
|
checkpoint_id=guardian_result.checkpoint_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _execute_tool(
|
||||||
|
self,
|
||||||
|
tool_call: MCPToolCall,
|
||||||
|
action: ActionRequest,
|
||||||
|
start_time: datetime,
|
||||||
|
checkpoint_id: str | None = None,
|
||||||
|
) -> MCPToolResult:
|
||||||
|
"""Execute the actual tool call."""
|
||||||
|
handler = self._tool_handlers.get(tool_call.tool_name)
|
||||||
|
|
||||||
|
if handler is None:
|
||||||
|
return MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error=f"No handler registered for tool: {tool_call.tool_name}",
|
||||||
|
safety_decision=SafetyDecision.ALLOW,
|
||||||
|
execution_time_ms=self._elapsed_ms(start_time),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if asyncio.iscoroutinefunction(handler):
|
||||||
|
result = await handler(**tool_call.arguments)
|
||||||
|
else:
|
||||||
|
result = handler(**tool_call.arguments)
|
||||||
|
|
||||||
|
return MCPToolResult(
|
||||||
|
success=True,
|
||||||
|
result=result,
|
||||||
|
safety_decision=SafetyDecision.ALLOW,
|
||||||
|
execution_time_ms=self._elapsed_ms(start_time),
|
||||||
|
checkpoint_id=checkpoint_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Tool execution failed: %s - %s", tool_call.tool_name, e)
|
||||||
|
return MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
safety_decision=SafetyDecision.ALLOW,
|
||||||
|
execution_time_ms=self._elapsed_ms(start_time),
|
||||||
|
checkpoint_id=checkpoint_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_action_request(
|
||||||
|
self,
|
||||||
|
tool_call: MCPToolCall,
|
||||||
|
agent_id: str,
|
||||||
|
autonomy_level: AutonomyLevel,
|
||||||
|
) -> ActionRequest:
|
||||||
|
"""Build an ActionRequest from an MCP tool call."""
|
||||||
|
action_type = self._classify_tool(tool_call.tool_name)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=tool_call.context.get("session_id", ""),
|
||||||
|
project_id=tool_call.project_id or "",
|
||||||
|
autonomy_level=autonomy_level,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ActionRequest(
|
||||||
|
action_type=action_type,
|
||||||
|
tool_name=tool_call.tool_name,
|
||||||
|
arguments=tool_call.arguments,
|
||||||
|
resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _classify_tool(self, tool_name: str) -> ActionType:
|
||||||
|
"""Classify a tool into an action type."""
|
||||||
|
tool_lower = tool_name.lower()
|
||||||
|
|
||||||
|
# Check destructive patterns
|
||||||
|
if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
|
||||||
|
if "file" in tool_lower:
|
||||||
|
if "delete" in tool_lower or "remove" in tool_lower:
|
||||||
|
return ActionType.FILE_DELETE
|
||||||
|
return ActionType.FILE_WRITE
|
||||||
|
if "database" in tool_lower or "db" in tool_lower:
|
||||||
|
return ActionType.DATABASE_MUTATE
|
||||||
|
|
||||||
|
# Check read patterns
|
||||||
|
if any(r in tool_lower for r in ["read", "get", "list", "search", "query"]):
|
||||||
|
if "file" in tool_lower:
|
||||||
|
return ActionType.FILE_READ
|
||||||
|
if "database" in tool_lower or "db" in tool_lower:
|
||||||
|
return ActionType.DATABASE_QUERY
|
||||||
|
|
||||||
|
# Check specific types
|
||||||
|
if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
|
||||||
|
return ActionType.SHELL_COMMAND
|
||||||
|
|
||||||
|
if "git" in tool_lower:
|
||||||
|
return ActionType.GIT_OPERATION
|
||||||
|
|
||||||
|
if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
|
||||||
|
return ActionType.NETWORK_REQUEST
|
||||||
|
|
||||||
|
if "llm" in tool_lower or "ai" in tool_lower or "claude" in tool_lower:
|
||||||
|
return ActionType.LLM_CALL
|
||||||
|
|
||||||
|
# Default to tool call
|
||||||
|
return ActionType.TOOL_CALL
|
||||||
|
|
||||||
|
def _elapsed_ms(self, start_time: datetime) -> float:
|
||||||
|
"""Calculate elapsed time in milliseconds."""
|
||||||
|
return (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||||
|
|
||||||
|
|
||||||
|
class SafeToolExecutor:
|
||||||
|
"""
|
||||||
|
Context manager for safe tool execution with automatic cleanup.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
|
||||||
|
result = await executor.execute()
|
||||||
|
if result.success:
|
||||||
|
# Use result
|
||||||
|
else:
|
||||||
|
# Handle error or approval required
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
wrapper: MCPSafetyWrapper,
|
||||||
|
tool_call: MCPToolCall,
|
||||||
|
agent_id: str,
|
||||||
|
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
|
||||||
|
) -> None:
|
||||||
|
self._wrapper = wrapper
|
||||||
|
self._tool_call = tool_call
|
||||||
|
self._agent_id = agent_id
|
||||||
|
self._autonomy_level = autonomy_level
|
||||||
|
self._result: MCPToolResult | None = None
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "SafeToolExecutor":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[Exception] | None,
|
||||||
|
exc_val: Exception | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> bool:
|
||||||
|
# Could trigger rollback here if needed
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def execute(self) -> MCPToolResult:
|
||||||
|
"""Execute the tool call."""
|
||||||
|
self._result = await self._wrapper.execute(
|
||||||
|
self._tool_call,
|
||||||
|
self._agent_id,
|
||||||
|
self._autonomy_level,
|
||||||
|
)
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self) -> MCPToolResult | None:
|
||||||
|
"""Get the execution result."""
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function
|
||||||
|
async def create_mcp_wrapper(
|
||||||
|
guardian: SafetyGuardian | None = None,
|
||||||
|
) -> MCPSafetyWrapper:
|
||||||
|
"""Create an MCPSafetyWrapper with default configuration."""
|
||||||
|
if guardian is None:
|
||||||
|
guardian = await get_safety_guardian()
|
||||||
|
|
||||||
|
return MCPSafetyWrapper(
|
||||||
|
guardian=guardian,
|
||||||
|
emergency_controls=await get_emergency_controls(),
|
||||||
|
)
|
||||||
19
backend/app/services/safety/metrics/__init__.py
Normal file
19
backend/app/services/safety/metrics/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Safety metrics collection and export."""
|
||||||
|
|
||||||
|
from .collector import (
|
||||||
|
MetricType,
|
||||||
|
MetricValue,
|
||||||
|
SafetyMetrics,
|
||||||
|
get_safety_metrics,
|
||||||
|
record_mcp_call,
|
||||||
|
record_validation,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MetricType",
|
||||||
|
"MetricValue",
|
||||||
|
"SafetyMetrics",
|
||||||
|
"get_safety_metrics",
|
||||||
|
"record_mcp_call",
|
||||||
|
"record_validation",
|
||||||
|
]
|
||||||
416
backend/app/services/safety/metrics/collector.py
Normal file
416
backend/app/services/safety/metrics/collector.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""
|
||||||
|
Safety Metrics Collector
|
||||||
|
|
||||||
|
Collects and exposes metrics for the safety framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricType(str, Enum):
|
||||||
|
"""Types of metrics."""
|
||||||
|
|
||||||
|
COUNTER = "counter"
|
||||||
|
GAUGE = "gauge"
|
||||||
|
HISTOGRAM = "histogram"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetricValue:
|
||||||
|
"""A single metric value."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
metric_type: MetricType
|
||||||
|
value: float
|
||||||
|
labels: dict[str, str] = field(default_factory=dict)
|
||||||
|
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HistogramBucket:
|
||||||
|
"""Histogram bucket for distribution metrics."""
|
||||||
|
|
||||||
|
le: float # Less than or equal
|
||||||
|
count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyMetrics:
|
||||||
|
"""
|
||||||
|
Collects safety framework metrics.
|
||||||
|
|
||||||
|
Metrics tracked:
|
||||||
|
- Action validation counts (by decision type)
|
||||||
|
- Approval request counts and latencies
|
||||||
|
- Budget usage and remaining
|
||||||
|
- Rate limit hits
|
||||||
|
- Loop detections
|
||||||
|
- Emergency events
|
||||||
|
- Content filter matches
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize SafetyMetrics."""
|
||||||
|
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
|
||||||
|
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
|
||||||
|
self._histograms: dict[str, list[float]] = defaultdict(list)
|
||||||
|
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Initialize histogram buckets
|
||||||
|
self._init_histogram_buckets()
|
||||||
|
|
||||||
|
def _init_histogram_buckets(self) -> None:
|
||||||
|
"""Initialize histogram buckets for latency metrics."""
|
||||||
|
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
|
||||||
|
|
||||||
|
for name in [
|
||||||
|
"validation_latency_seconds",
|
||||||
|
"approval_latency_seconds",
|
||||||
|
"mcp_execution_latency_seconds",
|
||||||
|
]:
|
||||||
|
self._histogram_buckets[name] = [
|
||||||
|
HistogramBucket(le=b) for b in latency_buckets
|
||||||
|
]
|
||||||
|
|
||||||
|
# Counter methods
|
||||||
|
|
||||||
|
async def inc_validations(
|
||||||
|
self,
|
||||||
|
decision: str,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Increment validation counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"decision={decision}"
|
||||||
|
if agent_id:
|
||||||
|
labels += f",agent_id={agent_id}"
|
||||||
|
self._counters["safety_validations_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_approvals_requested(self, urgency: str = "normal") -> None:
|
||||||
|
"""Increment approval requests counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"urgency={urgency}"
|
||||||
|
self._counters["safety_approvals_requested_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_approvals_granted(self) -> None:
|
||||||
|
"""Increment approvals granted counter."""
|
||||||
|
async with self._lock:
|
||||||
|
self._counters["safety_approvals_granted_total"][""] += 1
|
||||||
|
|
||||||
|
async def inc_approvals_denied(self, reason: str = "manual") -> None:
|
||||||
|
"""Increment approvals denied counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"reason={reason}"
|
||||||
|
self._counters["safety_approvals_denied_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_rate_limit_exceeded(self, limit_type: str) -> None:
|
||||||
|
"""Increment rate limit exceeded counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"limit_type={limit_type}"
|
||||||
|
self._counters["safety_rate_limit_exceeded_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_budget_exceeded(self, budget_type: str) -> None:
|
||||||
|
"""Increment budget exceeded counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"budget_type={budget_type}"
|
||||||
|
self._counters["safety_budget_exceeded_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_loops_detected(self, loop_type: str) -> None:
|
||||||
|
"""Increment loop detection counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"loop_type={loop_type}"
|
||||||
|
self._counters["safety_loops_detected_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_emergency_events(self, event_type: str, scope: str) -> None:
|
||||||
|
"""Increment emergency events counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"event_type={event_type},scope={scope}"
|
||||||
|
self._counters["safety_emergency_events_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_content_filtered(self, category: str, action: str) -> None:
|
||||||
|
"""Increment content filter counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"category={category},action={action}"
|
||||||
|
self._counters["safety_content_filtered_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_checkpoints_created(self) -> None:
|
||||||
|
"""Increment checkpoints created counter."""
|
||||||
|
async with self._lock:
|
||||||
|
self._counters["safety_checkpoints_created_total"][""] += 1
|
||||||
|
|
||||||
|
async def inc_rollbacks_executed(self, success: bool) -> None:
|
||||||
|
"""Increment rollbacks counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"success={str(success).lower()}"
|
||||||
|
self._counters["safety_rollbacks_total"][labels] += 1
|
||||||
|
|
||||||
|
async def inc_mcp_calls(self, tool_name: str, success: bool) -> None:
|
||||||
|
"""Increment MCP tool calls counter."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"tool_name={tool_name},success={str(success).lower()}"
|
||||||
|
self._counters["safety_mcp_calls_total"][labels] += 1
|
||||||
|
|
||||||
|
# Gauge methods
|
||||||
|
|
||||||
|
async def set_budget_remaining(
|
||||||
|
self,
|
||||||
|
scope: str,
|
||||||
|
budget_type: str,
|
||||||
|
remaining: float,
|
||||||
|
) -> None:
|
||||||
|
"""Set remaining budget gauge."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"scope={scope},budget_type={budget_type}"
|
||||||
|
self._gauges["safety_budget_remaining"][labels] = remaining
|
||||||
|
|
||||||
|
async def set_rate_limit_remaining(
|
||||||
|
self,
|
||||||
|
scope: str,
|
||||||
|
limit_type: str,
|
||||||
|
remaining: int,
|
||||||
|
) -> None:
|
||||||
|
"""Set remaining rate limit gauge."""
|
||||||
|
async with self._lock:
|
||||||
|
labels = f"scope={scope},limit_type={limit_type}"
|
||||||
|
self._gauges["safety_rate_limit_remaining"][labels] = float(remaining)
|
||||||
|
|
||||||
|
async def set_pending_approvals(self, count: int) -> None:
|
||||||
|
"""Set pending approvals gauge."""
|
||||||
|
async with self._lock:
|
||||||
|
self._gauges["safety_pending_approvals"][""] = float(count)
|
||||||
|
|
||||||
|
async def set_active_checkpoints(self, count: int) -> None:
|
||||||
|
"""Set active checkpoints gauge."""
|
||||||
|
async with self._lock:
|
||||||
|
self._gauges["safety_active_checkpoints"][""] = float(count)
|
||||||
|
|
||||||
|
async def set_emergency_state(self, scope: str, state: str) -> None:
|
||||||
|
"""Set emergency state gauge (0=normal, 1=paused, 2=stopped)."""
|
||||||
|
async with self._lock:
|
||||||
|
state_value = {"normal": 0, "paused": 1, "stopped": 2}.get(state, -1)
|
||||||
|
labels = f"scope={scope}"
|
||||||
|
self._gauges["safety_emergency_state"][labels] = float(state_value)
|
||||||
|
|
||||||
|
# Histogram methods
|
||||||
|
|
||||||
|
async def observe_validation_latency(self, latency_seconds: float) -> None:
|
||||||
|
"""Observe validation latency."""
|
||||||
|
async with self._lock:
|
||||||
|
self._observe_histogram("validation_latency_seconds", latency_seconds)
|
||||||
|
|
||||||
|
async def observe_approval_latency(self, latency_seconds: float) -> None:
|
||||||
|
"""Observe approval latency."""
|
||||||
|
async with self._lock:
|
||||||
|
self._observe_histogram("approval_latency_seconds", latency_seconds)
|
||||||
|
|
||||||
|
async def observe_mcp_execution_latency(self, latency_seconds: float) -> None:
|
||||||
|
"""Observe MCP execution latency."""
|
||||||
|
async with self._lock:
|
||||||
|
self._observe_histogram("mcp_execution_latency_seconds", latency_seconds)
|
||||||
|
|
||||||
|
def _observe_histogram(self, name: str, value: float) -> None:
|
||||||
|
"""Record a value in a histogram."""
|
||||||
|
self._histograms[name].append(value)
|
||||||
|
|
||||||
|
# Update buckets
|
||||||
|
if name in self._histogram_buckets:
|
||||||
|
for bucket in self._histogram_buckets[name]:
|
||||||
|
if value <= bucket.le:
|
||||||
|
bucket.count += 1
|
||||||
|
|
||||||
|
# Export methods
|
||||||
|
|
||||||
|
async def get_all_metrics(self) -> list[MetricValue]:
|
||||||
|
"""Get all metrics as MetricValue objects."""
|
||||||
|
metrics: list[MetricValue] = []
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Export counters
|
||||||
|
for name, counter in self._counters.items():
|
||||||
|
for labels_str, value in counter.items():
|
||||||
|
labels = self._parse_labels(labels_str)
|
||||||
|
metrics.append(
|
||||||
|
MetricValue(
|
||||||
|
name=name,
|
||||||
|
metric_type=MetricType.COUNTER,
|
||||||
|
value=float(value),
|
||||||
|
labels=labels,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Export gauges
|
||||||
|
for name, gauge_dict in self._gauges.items():
|
||||||
|
for labels_str, gauge_value in gauge_dict.items():
|
||||||
|
gauge_labels = self._parse_labels(labels_str)
|
||||||
|
metrics.append(
|
||||||
|
MetricValue(
|
||||||
|
name=name,
|
||||||
|
metric_type=MetricType.GAUGE,
|
||||||
|
value=gauge_value,
|
||||||
|
labels=gauge_labels,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Export histogram summaries
|
||||||
|
for name, values in self._histograms.items():
|
||||||
|
if values:
|
||||||
|
metrics.append(
|
||||||
|
MetricValue(
|
||||||
|
name=f"{name}_count",
|
||||||
|
metric_type=MetricType.COUNTER,
|
||||||
|
value=float(len(values)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metrics.append(
|
||||||
|
MetricValue(
|
||||||
|
name=f"{name}_sum",
|
||||||
|
metric_type=MetricType.COUNTER,
|
||||||
|
value=sum(values),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
async def get_prometheus_format(self) -> str:
|
||||||
|
"""Export metrics in Prometheus text format."""
|
||||||
|
lines: list[str] = []
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Export counters
|
||||||
|
for name, counter in self._counters.items():
|
||||||
|
lines.append(f"# TYPE {name} counter")
|
||||||
|
for labels_str, value in counter.items():
|
||||||
|
if labels_str:
|
||||||
|
lines.append(f"{name}{{{labels_str}}} {value}")
|
||||||
|
else:
|
||||||
|
lines.append(f"{name} {value}")
|
||||||
|
|
||||||
|
# Export gauges
|
||||||
|
for name, gauge_dict in self._gauges.items():
|
||||||
|
lines.append(f"# TYPE {name} gauge")
|
||||||
|
for labels_str, gauge_value in gauge_dict.items():
|
||||||
|
if labels_str:
|
||||||
|
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
|
||||||
|
else:
|
||||||
|
lines.append(f"{name} {gauge_value}")
|
||||||
|
|
||||||
|
# Export histograms
|
||||||
|
for name, buckets in self._histogram_buckets.items():
|
||||||
|
lines.append(f"# TYPE {name} histogram")
|
||||||
|
for bucket in buckets:
|
||||||
|
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
|
||||||
|
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
|
||||||
|
|
||||||
|
if name in self._histograms:
|
||||||
|
values = self._histograms[name]
|
||||||
|
lines.append(f"{name}_count {len(values)}")
|
||||||
|
lines.append(f"{name}_sum {sum(values)}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
async def get_summary(self) -> dict[str, Any]:
|
||||||
|
"""Get a summary of key metrics."""
|
||||||
|
async with self._lock:
|
||||||
|
total_validations = sum(self._counters["safety_validations_total"].values())
|
||||||
|
denied_validations = sum(
|
||||||
|
v for k, v in self._counters["safety_validations_total"].items()
|
||||||
|
if "decision=deny" in k
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_validations": total_validations,
|
||||||
|
"denied_validations": denied_validations,
|
||||||
|
"approval_requests": sum(
|
||||||
|
self._counters["safety_approvals_requested_total"].values()
|
||||||
|
),
|
||||||
|
"approvals_granted": sum(
|
||||||
|
self._counters["safety_approvals_granted_total"].values()
|
||||||
|
),
|
||||||
|
"approvals_denied": sum(
|
||||||
|
self._counters["safety_approvals_denied_total"].values()
|
||||||
|
),
|
||||||
|
"rate_limit_hits": sum(
|
||||||
|
self._counters["safety_rate_limit_exceeded_total"].values()
|
||||||
|
),
|
||||||
|
"budget_exceeded": sum(
|
||||||
|
self._counters["safety_budget_exceeded_total"].values()
|
||||||
|
),
|
||||||
|
"loops_detected": sum(
|
||||||
|
self._counters["safety_loops_detected_total"].values()
|
||||||
|
),
|
||||||
|
"emergency_events": sum(
|
||||||
|
self._counters["safety_emergency_events_total"].values()
|
||||||
|
),
|
||||||
|
"content_filtered": sum(
|
||||||
|
self._counters["safety_content_filtered_total"].values()
|
||||||
|
),
|
||||||
|
"checkpoints_created": sum(
|
||||||
|
self._counters["safety_checkpoints_created_total"].values()
|
||||||
|
),
|
||||||
|
"rollbacks_executed": sum(
|
||||||
|
self._counters["safety_rollbacks_total"].values()
|
||||||
|
),
|
||||||
|
"mcp_calls": sum(
|
||||||
|
self._counters["safety_mcp_calls_total"].values()
|
||||||
|
),
|
||||||
|
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
|
||||||
|
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reset(self) -> None:
|
||||||
|
"""Reset all metrics."""
|
||||||
|
async with self._lock:
|
||||||
|
self._counters.clear()
|
||||||
|
self._gauges.clear()
|
||||||
|
self._histograms.clear()
|
||||||
|
self._init_histogram_buckets()
|
||||||
|
|
||||||
|
def _parse_labels(self, labels_str: str) -> dict[str, str]:
|
||||||
|
"""Parse labels string into dictionary."""
|
||||||
|
if not labels_str:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
labels = {}
|
||||||
|
for pair in labels_str.split(","):
|
||||||
|
if "=" in pair:
|
||||||
|
key, value = pair.split("=", 1)
|
||||||
|
labels[key.strip()] = value.strip()
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_metrics: SafetyMetrics | None = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_safety_metrics() -> SafetyMetrics:
|
||||||
|
"""Get the singleton SafetyMetrics instance."""
|
||||||
|
global _metrics
|
||||||
|
|
||||||
|
async with _lock:
|
||||||
|
if _metrics is None:
|
||||||
|
_metrics = SafetyMetrics()
|
||||||
|
return _metrics
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions
|
||||||
|
async def record_validation(decision: str, agent_id: str | None = None) -> None:
|
||||||
|
"""Record a validation event."""
|
||||||
|
metrics = await get_safety_metrics()
|
||||||
|
await metrics.inc_validations(decision, agent_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def record_mcp_call(tool_name: str, success: bool, latency_ms: float) -> None:
|
||||||
|
"""Record an MCP tool call."""
|
||||||
|
metrics = await get_safety_metrics()
|
||||||
|
await metrics.inc_mcp_calls(tool_name, success)
|
||||||
|
await metrics.observe_mcp_execution_latency(latency_ms / 1000)
|
||||||
Reference in New Issue
Block a user