forked from cardosofelipe/fast-next-template
feat(backend): add safety framework foundation (Phase A) (#63)
Core safety framework architecture for autonomous agent guardrails: **Core Components:** - SafetyGuardian: Main orchestrator for all safety checks - AuditLogger: Comprehensive audit logging with hash chain tamper detection - SafetyConfig: Pydantic-based configuration - Models: Action requests, validation results, policies, checkpoints **Exception Hierarchy:** - SafetyError base with context preservation - Permission, Budget, RateLimit, Loop errors - Approval workflow errors (Required, Denied, Timeout) - Rollback, Sandbox, Emergency exceptions **Safety Policy System:** - Autonomy level based policies (FULL_CONTROL, MILESTONE, AUTONOMOUS) - Cost limits, rate limits, permission patterns - HITL approval requirements per action type - Configurable loop detection thresholds **Directory Structure:** - validation/, costs/, limits/, loops/ - Control subsystems - permissions/, rollback/, hitl/ - Access and recovery - content/, sandbox/, emergency/ - Protection systems - audit/, policies/ - Logging and configuration Phase A establishes the architecture. Subsystems to be implemented in Phase B-C. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
170
backend/app/services/safety/__init__.py
Normal file
170
backend/app/services/safety/__init__.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Safety and Guardrails Framework
|
||||
|
||||
Comprehensive safety framework for autonomous agent operation.
|
||||
Provides multi-layered protection including:
|
||||
- Pre-execution validation
|
||||
- Cost and budget controls
|
||||
- Rate limiting
|
||||
- Loop detection and prevention
|
||||
- Human-in-the-loop approval
|
||||
- Rollback and checkpointing
|
||||
- Content filtering
|
||||
- Sandboxed execution
|
||||
- Emergency controls
|
||||
- Complete audit trail
|
||||
|
||||
Usage:
|
||||
from app.services.safety import get_safety_guardian, SafetyGuardian
|
||||
|
||||
guardian = await get_safety_guardian()
|
||||
result = await guardian.validate(action_request)
|
||||
|
||||
if result.allowed:
|
||||
# Execute action
|
||||
pass
|
||||
else:
|
||||
# Handle denial
|
||||
print(f"Action denied: {result.reasons}")
|
||||
"""
|
||||
|
||||
# Exceptions
|
||||
# Audit
|
||||
from .audit import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
AutonomyConfig,
|
||||
SafetyConfig,
|
||||
get_autonomy_config,
|
||||
get_default_policy,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
load_policies_from_directory,
|
||||
load_policy_from_file,
|
||||
reset_config_cache,
|
||||
)
|
||||
from .exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
BudgetExceededError,
|
||||
CheckpointError,
|
||||
ContentFilterError,
|
||||
EmergencyStopError,
|
||||
LoopDetectedError,
|
||||
PermissionDeniedError,
|
||||
PolicyViolationError,
|
||||
RateLimitExceededError,
|
||||
RollbackError,
|
||||
SafetyError,
|
||||
SandboxError,
|
||||
SandboxTimeoutError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
# Guardian
|
||||
from .guardian import (
|
||||
SafetyGuardian,
|
||||
get_safety_guardian,
|
||||
reset_safety_guardian,
|
||||
shutdown_safety_guardian,
|
||||
)
|
||||
|
||||
# Models
|
||||
from .models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
GuardianResult,
|
||||
PermissionLevel,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
ResourceType,
|
||||
RollbackResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionMetadata",
|
||||
"ActionRequest",
|
||||
"ActionResult",
|
||||
# Models
|
||||
"ActionType",
|
||||
"ApprovalDeniedError",
|
||||
"ApprovalRequest",
|
||||
"ApprovalRequiredError",
|
||||
"ApprovalResponse",
|
||||
"ApprovalStatus",
|
||||
"ApprovalTimeoutError",
|
||||
"AuditEvent",
|
||||
"AuditEventType",
|
||||
# Audit
|
||||
"AuditLogger",
|
||||
"AutonomyConfig",
|
||||
"AutonomyLevel",
|
||||
"BudgetExceededError",
|
||||
"BudgetScope",
|
||||
"BudgetStatus",
|
||||
"Checkpoint",
|
||||
"CheckpointError",
|
||||
"CheckpointType",
|
||||
"ContentFilterError",
|
||||
"EmergencyStopError",
|
||||
"GuardianResult",
|
||||
"LoopDetectedError",
|
||||
"PermissionDeniedError",
|
||||
"PermissionLevel",
|
||||
"PolicyViolationError",
|
||||
"RateLimitConfig",
|
||||
"RateLimitExceededError",
|
||||
"RateLimitStatus",
|
||||
"ResourceType",
|
||||
"RollbackError",
|
||||
"RollbackResult",
|
||||
# Configuration
|
||||
"SafetyConfig",
|
||||
"SafetyDecision",
|
||||
# Exceptions
|
||||
"SafetyError",
|
||||
# Guardian
|
||||
"SafetyGuardian",
|
||||
"SafetyPolicy",
|
||||
"SandboxError",
|
||||
"SandboxTimeoutError",
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"ValidationRule",
|
||||
"get_audit_logger",
|
||||
"get_autonomy_config",
|
||||
"get_default_policy",
|
||||
"get_policy_for_autonomy_level",
|
||||
"get_safety_config",
|
||||
"get_safety_guardian",
|
||||
"load_policies_from_directory",
|
||||
"load_policy_from_file",
|
||||
"reset_audit_logger",
|
||||
"reset_config_cache",
|
||||
"reset_safety_guardian",
|
||||
"shutdown_audit_logger",
|
||||
"shutdown_safety_guardian",
|
||||
]
|
||||
19
backend/app/services/safety/audit/__init__.py
Normal file
19
backend/app/services/safety/audit/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Audit System
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuditLogger",
|
||||
"get_audit_logger",
|
||||
"reset_audit_logger",
|
||||
"shutdown_audit_logger",
|
||||
]
|
||||
585
backend/app/services/safety/audit/logger.py
Normal file
585
backend/app/services/safety/audit/logger.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
Audit Logger
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
Provides tamper detection, structured logging, and compliance support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""
|
||||
Audit logger for safety events.
|
||||
|
||||
Features:
|
||||
- Structured event logging
|
||||
- In-memory buffer with async flush
|
||||
- Tamper detection via hash chains
|
||||
- Query/search capability
|
||||
- Retention policy enforcement
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_buffer_size: int = 1000,
|
||||
flush_interval_seconds: float = 10.0,
|
||||
enable_hash_chain: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the audit logger.
|
||||
|
||||
Args:
|
||||
max_buffer_size: Maximum events to buffer before auto-flush
|
||||
flush_interval_seconds: Interval for periodic flush
|
||||
enable_hash_chain: Enable tamper detection via hash chain
|
||||
"""
|
||||
self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
|
||||
self._persisted: list[AuditEvent] = []
|
||||
self._flush_interval = flush_interval_seconds
|
||||
self._enable_hash_chain = enable_hash_chain
|
||||
self._last_hash: str | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
# Event handlers for real-time processing
|
||||
self._handlers: list[Any] = []
|
||||
|
||||
config = get_safety_config()
|
||||
self._retention_days = config.audit_retention_days
|
||||
self._include_sensitive = config.audit_include_sensitive
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the audit logger background tasks."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||
logger.info("Audit logger started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the audit logger and flush remaining events."""
|
||||
self._running = False
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Final flush
|
||||
await self.flush()
|
||||
logger.info("Audit logger stopped")
|
||||
|
||||
async def log(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
*,
|
||||
agent_id: str | None = None,
|
||||
action_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
decision: SafetyDecision | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
) -> AuditEvent:
|
||||
"""
|
||||
Log an audit event.
|
||||
|
||||
Args:
|
||||
event_type: Type of audit event
|
||||
agent_id: Agent ID if applicable
|
||||
action_id: Action ID if applicable
|
||||
project_id: Project ID if applicable
|
||||
session_id: Session ID if applicable
|
||||
user_id: User ID if applicable
|
||||
decision: Safety decision if applicable
|
||||
details: Additional event details
|
||||
correlation_id: Correlation ID for tracing
|
||||
|
||||
Returns:
|
||||
The created audit event
|
||||
"""
|
||||
# Sanitize sensitive data if needed
|
||||
sanitized_details = self._sanitize_details(details) if details else {}
|
||||
|
||||
event = AuditEvent(
|
||||
id=str(uuid4()),
|
||||
event_type=event_type,
|
||||
timestamp=datetime.utcnow(),
|
||||
agent_id=agent_id,
|
||||
action_id=action_id,
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
decision=decision,
|
||||
details=sanitized_details,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
# Add hash chain for tamper detection
|
||||
if self._enable_hash_chain:
|
||||
event_hash = self._compute_hash(event)
|
||||
sanitized_details["_hash"] = event_hash
|
||||
sanitized_details["_prev_hash"] = self._last_hash
|
||||
self._last_hash = event_hash
|
||||
|
||||
self._buffer.append(event)
|
||||
|
||||
# Notify handlers
|
||||
await self._notify_handlers(event)
|
||||
|
||||
# Log to standard logger as well
|
||||
self._log_to_logger(event)
|
||||
|
||||
return event
|
||||
|
||||
async def log_action_request(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
decision: SafetyDecision,
|
||||
reasons: list[str] | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an action request with its validation decision."""
|
||||
event_type = (
|
||||
AuditEventType.ACTION_DENIED
|
||||
if decision == SafetyDecision.DENY
|
||||
else AuditEventType.ACTION_VALIDATED
|
||||
)
|
||||
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
user_id=action.metadata.user_id,
|
||||
decision=decision,
|
||||
details={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"resource": action.resource,
|
||||
"is_destructive": action.is_destructive,
|
||||
"reasons": reasons or [],
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
|
||||
async def log_action_executed(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
success: bool,
|
||||
execution_time_ms: float,
|
||||
error: str | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an action execution result."""
|
||||
event_type = (
|
||||
AuditEventType.ACTION_EXECUTED
|
||||
if success
|
||||
else AuditEventType.ACTION_FAILED
|
||||
)
|
||||
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
|
||||
details={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"success": success,
|
||||
"execution_time_ms": execution_time_ms,
|
||||
"error": error,
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
|
||||
async def log_approval_event(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
approval_id: str,
|
||||
action: ActionRequest,
|
||||
decided_by: str | None = None,
|
||||
reason: str | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an approval-related event."""
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
user_id=decided_by,
|
||||
details={
|
||||
"approval_id": approval_id,
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"decided_by": decided_by,
|
||||
"reason": reason,
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
|
||||
async def log_budget_event(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
agent_id: str,
|
||||
scope: str,
|
||||
current_usage: float,
|
||||
limit: float,
|
||||
unit: str = "tokens",
|
||||
) -> AuditEvent:
|
||||
"""Log a budget-related event."""
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=agent_id,
|
||||
details={
|
||||
"scope": scope,
|
||||
"current_usage": current_usage,
|
||||
"limit": limit,
|
||||
"unit": unit,
|
||||
"usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
|
||||
},
|
||||
)
|
||||
|
||||
async def log_emergency_stop(
|
||||
self,
|
||||
stop_type: str,
|
||||
triggered_by: str,
|
||||
reason: str,
|
||||
affected_agents: list[str] | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an emergency stop event."""
|
||||
return await self.log(
|
||||
AuditEventType.EMERGENCY_STOP,
|
||||
user_id=triggered_by,
|
||||
details={
|
||||
"stop_type": stop_type,
|
||||
"triggered_by": triggered_by,
|
||||
"reason": reason,
|
||||
"affected_agents": affected_agents or [],
|
||||
},
|
||||
)
|
||||
|
||||
async def flush(self) -> int:
|
||||
"""
|
||||
Flush buffered events to persistent storage.
|
||||
|
||||
Returns:
|
||||
Number of events flushed
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._buffer:
|
||||
return 0
|
||||
|
||||
events = list(self._buffer)
|
||||
self._buffer.clear()
|
||||
|
||||
# Persist events (in production, this would go to database/storage)
|
||||
self._persisted.extend(events)
|
||||
|
||||
# Enforce retention
|
||||
self._enforce_retention()
|
||||
|
||||
logger.debug("Flushed %d audit events", len(events))
|
||||
return len(events)
|
||||
|
||||
async def query(
|
||||
self,
|
||||
*,
|
||||
event_types: list[AuditEventType] | None = None,
|
||||
agent_id: str | None = None,
|
||||
action_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
correlation_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[AuditEvent]:
|
||||
"""
|
||||
Query audit events with filters.
|
||||
|
||||
Args:
|
||||
event_types: Filter by event types
|
||||
agent_id: Filter by agent ID
|
||||
action_id: Filter by action ID
|
||||
project_id: Filter by project ID
|
||||
session_id: Filter by session ID
|
||||
user_id: Filter by user ID
|
||||
start_time: Filter events after this time
|
||||
end_time: Filter events before this time
|
||||
correlation_id: Filter by correlation ID
|
||||
limit: Maximum results to return
|
||||
offset: Result offset for pagination
|
||||
|
||||
Returns:
|
||||
List of matching audit events
|
||||
"""
|
||||
# Combine buffer and persisted for query
|
||||
all_events = list(self._persisted) + list(self._buffer)
|
||||
|
||||
results = []
|
||||
for event in all_events:
|
||||
if event_types and event.event_type not in event_types:
|
||||
continue
|
||||
if agent_id and event.agent_id != agent_id:
|
||||
continue
|
||||
if action_id and event.action_id != action_id:
|
||||
continue
|
||||
if project_id and event.project_id != project_id:
|
||||
continue
|
||||
if session_id and event.session_id != session_id:
|
||||
continue
|
||||
if user_id and event.user_id != user_id:
|
||||
continue
|
||||
if start_time and event.timestamp < start_time:
|
||||
continue
|
||||
if end_time and event.timestamp > end_time:
|
||||
continue
|
||||
if correlation_id and event.correlation_id != correlation_id:
|
||||
continue
|
||||
|
||||
results.append(event)
|
||||
|
||||
# Sort by timestamp descending
|
||||
results.sort(key=lambda e: e.timestamp, reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return results[offset : offset + limit]
|
||||
|
||||
async def get_action_history(
|
||||
self,
|
||||
agent_id: str,
|
||||
limit: int = 100,
|
||||
) -> list[AuditEvent]:
|
||||
"""Get action history for an agent."""
|
||||
return await self.query(
|
||||
agent_id=agent_id,
|
||||
event_types=[
|
||||
AuditEventType.ACTION_REQUESTED,
|
||||
AuditEventType.ACTION_VALIDATED,
|
||||
AuditEventType.ACTION_DENIED,
|
||||
AuditEventType.ACTION_EXECUTED,
|
||||
AuditEventType.ACTION_FAILED,
|
||||
],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def verify_integrity(self) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Verify audit log integrity using hash chain.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, list of issues found)
|
||||
"""
|
||||
if not self._enable_hash_chain:
|
||||
return True, []
|
||||
|
||||
issues: list[str] = []
|
||||
all_events = list(self._persisted) + list(self._buffer)
|
||||
|
||||
prev_hash: str | None = None
|
||||
for event in sorted(all_events, key=lambda e: e.timestamp):
|
||||
stored_prev = event.details.get("_prev_hash")
|
||||
stored_hash = event.details.get("_hash")
|
||||
|
||||
if stored_prev != prev_hash:
|
||||
issues.append(
|
||||
f"Hash chain broken at event {event.id}: "
|
||||
f"expected prev_hash={prev_hash}, got {stored_prev}"
|
||||
)
|
||||
|
||||
if stored_hash:
|
||||
computed = self._compute_hash(event)
|
||||
if computed != stored_hash:
|
||||
issues.append(
|
||||
f"Hash mismatch at event {event.id}: "
|
||||
f"expected {computed}, got {stored_hash}"
|
||||
)
|
||||
|
||||
prev_hash = stored_hash
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
def add_handler(self, handler: Any) -> None:
|
||||
"""Add a real-time event handler."""
|
||||
self._handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: Any) -> None:
|
||||
"""Remove an event handler."""
|
||||
if handler in self._handlers:
|
||||
self._handlers.remove(handler)
|
||||
|
||||
def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sanitize sensitive data from details."""
|
||||
if self._include_sensitive:
|
||||
return details
|
||||
|
||||
sanitized: dict[str, Any] = {}
|
||||
sensitive_keys = {
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"auth",
|
||||
"credential",
|
||||
}
|
||||
|
||||
for key, value in details.items():
|
||||
lower_key = key.lower()
|
||||
if any(s in lower_key for s in sensitive_keys):
|
||||
sanitized[key] = "[REDACTED]"
|
||||
elif isinstance(value, dict):
|
||||
sanitized[key] = self._sanitize_details(value)
|
||||
else:
|
||||
sanitized[key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
def _compute_hash(self, event: AuditEvent) -> str:
|
||||
"""Compute hash for an event (excluding hash fields)."""
|
||||
data = {
|
||||
"id": event.id,
|
||||
"event_type": event.event_type.value,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"agent_id": event.agent_id,
|
||||
"action_id": event.action_id,
|
||||
"project_id": event.project_id,
|
||||
"session_id": event.session_id,
|
||||
"user_id": event.user_id,
|
||||
"decision": event.decision.value if event.decision else None,
|
||||
"details": {
|
||||
k: v
|
||||
for k, v in event.details.items()
|
||||
if not k.startswith("_")
|
||||
},
|
||||
"correlation_id": event.correlation_id,
|
||||
}
|
||||
|
||||
if self._last_hash:
|
||||
data["_prev_hash"] = self._last_hash
|
||||
|
||||
serialized = json.dumps(data, sort_keys=True, default=str)
|
||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
def _log_to_logger(self, event: AuditEvent) -> None:
|
||||
"""Log event to standard Python logger."""
|
||||
log_data = {
|
||||
"audit_event": event.event_type.value,
|
||||
"event_id": event.id,
|
||||
"agent_id": event.agent_id,
|
||||
"action_id": event.action_id,
|
||||
"decision": event.decision.value if event.decision else None,
|
||||
}
|
||||
|
||||
# Use appropriate log level based on event type
|
||||
if event.event_type in {
|
||||
AuditEventType.ACTION_DENIED,
|
||||
AuditEventType.POLICY_VIOLATION,
|
||||
AuditEventType.EMERGENCY_STOP,
|
||||
}:
|
||||
logger.warning("Audit: %s", log_data)
|
||||
elif event.event_type in {
|
||||
AuditEventType.ACTION_FAILED,
|
||||
AuditEventType.ROLLBACK_FAILED,
|
||||
}:
|
||||
logger.error("Audit: %s", log_data)
|
||||
else:
|
||||
logger.info("Audit: %s", log_data)
|
||||
|
||||
def _enforce_retention(self) -> None:
|
||||
"""Enforce retention policy on persisted events."""
|
||||
if not self._retention_days:
|
||||
return
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
|
||||
before_count = len(self._persisted)
|
||||
|
||||
self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]
|
||||
|
||||
removed = before_count - len(self._persisted)
|
||||
if removed > 0:
|
||||
logger.info("Removed %d expired audit events", removed)
|
||||
|
||||
async def _periodic_flush(self) -> None:
|
||||
"""Background task for periodic flushing."""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self._flush_interval)
|
||||
await self.flush()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in periodic audit flush: %s", e)
|
||||
|
||||
async def _notify_handlers(self, event: AuditEvent) -> None:
|
||||
"""Notify all registered handlers of a new event."""
|
||||
for handler in self._handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(event)
|
||||
else:
|
||||
handler(event)
|
||||
except Exception as e:
|
||||
logger.error("Error in audit event handler: %s", e)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_audit_logger: AuditLogger | None = None
|
||||
_audit_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_audit_logger() -> AuditLogger:
|
||||
"""Get the global audit logger instance."""
|
||||
global _audit_logger
|
||||
|
||||
async with _audit_lock:
|
||||
if _audit_logger is None:
|
||||
_audit_logger = AuditLogger()
|
||||
await _audit_logger.start()
|
||||
|
||||
return _audit_logger
|
||||
|
||||
|
||||
async def shutdown_audit_logger() -> None:
|
||||
"""Shutdown the global audit logger."""
|
||||
global _audit_logger
|
||||
|
||||
async with _audit_lock:
|
||||
if _audit_logger is not None:
|
||||
await _audit_logger.stop()
|
||||
_audit_logger = None
|
||||
|
||||
|
||||
def reset_audit_logger() -> None:
|
||||
"""Reset the audit logger (for testing)."""
|
||||
global _audit_logger
|
||||
_audit_logger = None
|
||||
300
backend/app/services/safety/config.py
Normal file
300
backend/app/services/safety/config.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Safety Framework Configuration
|
||||
|
||||
Pydantic settings for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .models import AutonomyLevel, SafetyPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyConfig(BaseSettings):
|
||||
"""Configuration for the safety framework."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="SAFETY_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# General settings
|
||||
enabled: bool = Field(True, description="Enable safety framework")
|
||||
strict_mode: bool = Field(
|
||||
True, description="Strict mode (fail closed on errors)"
|
||||
)
|
||||
log_level: str = Field("INFO", description="Logging level")
|
||||
|
||||
# Default autonomy level
|
||||
default_autonomy_level: AutonomyLevel = Field(
|
||||
AutonomyLevel.MILESTONE,
|
||||
description="Default autonomy level for new agents",
|
||||
)
|
||||
|
||||
# Default budget limits
|
||||
default_session_token_budget: int = Field(
|
||||
100_000, description="Default tokens per session"
|
||||
)
|
||||
default_daily_token_budget: int = Field(
|
||||
1_000_000, description="Default tokens per day"
|
||||
)
|
||||
default_session_cost_limit: float = Field(
|
||||
10.0, description="Default USD per session"
|
||||
)
|
||||
default_daily_cost_limit: float = Field(100.0, description="Default USD per day")
|
||||
|
||||
# Default rate limits
|
||||
default_actions_per_minute: int = Field(60, description="Default actions per min")
|
||||
default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min")
|
||||
default_file_ops_per_minute: int = Field(100, description="Default file ops/min")
|
||||
|
||||
# Loop detection
|
||||
loop_detection_enabled: bool = Field(True, description="Enable loop detection")
|
||||
max_repeated_actions: int = Field(5, description="Max exact repetitions")
|
||||
max_similar_actions: int = Field(10, description="Max similar actions")
|
||||
loop_history_size: int = Field(100, description="Action history size for loops")
|
||||
|
||||
# HITL settings
|
||||
hitl_enabled: bool = Field(True, description="Enable human-in-the-loop")
|
||||
hitl_default_timeout: int = Field(300, description="Default approval timeout (s)")
|
||||
hitl_notification_channels: list[str] = Field(
|
||||
default_factory=list, description="Notification channels"
|
||||
)
|
||||
|
||||
# Rollback settings
|
||||
rollback_enabled: bool = Field(True, description="Enable rollback capability")
|
||||
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
|
||||
auto_checkpoint_destructive: bool = Field(
|
||||
True, description="Auto-checkpoint destructive actions"
|
||||
)
|
||||
|
||||
# Sandbox settings
|
||||
sandbox_enabled: bool = Field(False, description="Enable sandbox execution")
|
||||
sandbox_timeout: int = Field(300, description="Sandbox timeout (s)")
|
||||
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)")
|
||||
sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit")
|
||||
sandbox_network_enabled: bool = Field(False, description="Allow sandbox network")
|
||||
|
||||
# Audit settings
|
||||
audit_enabled: bool = Field(True, description="Enable audit logging")
|
||||
audit_retention_days: int = Field(90, description="Audit log retention (days)")
|
||||
audit_include_sensitive: bool = Field(
|
||||
False, description="Include sensitive data in audit"
|
||||
)
|
||||
|
||||
# Content filtering
|
||||
content_filter_enabled: bool = Field(True, description="Enable content filtering")
|
||||
filter_pii: bool = Field(True, description="Filter PII")
|
||||
filter_secrets: bool = Field(True, description="Filter secrets")
|
||||
|
||||
# Emergency controls
|
||||
emergency_stop_enabled: bool = Field(True, description="Enable emergency stop")
|
||||
emergency_webhook_url: str | None = Field(None, description="Emergency webhook")
|
||||
|
||||
# Policy file path
|
||||
policy_file: str | None = Field(None, description="Path to policy YAML file")
|
||||
|
||||
# Validation cache
|
||||
validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)")
|
||||
validation_cache_size: int = Field(1000, description="Validation cache size")
|
||||
|
||||
|
||||
class AutonomyConfig(BaseSettings):
|
||||
"""Configuration for autonomy levels."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="AUTONOMY_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# FULL_CONTROL settings
|
||||
full_control_cost_limit: float = Field(1.0, description="USD limit per session")
|
||||
full_control_require_all_approval: bool = Field(
|
||||
True, description="Require approval for all"
|
||||
)
|
||||
full_control_block_destructive: bool = Field(
|
||||
True, description="Block destructive actions"
|
||||
)
|
||||
|
||||
# MILESTONE settings
|
||||
milestone_cost_limit: float = Field(10.0, description="USD limit per session")
|
||||
milestone_require_critical_approval: bool = Field(
|
||||
True, description="Require approval for critical"
|
||||
)
|
||||
milestone_auto_checkpoint: bool = Field(
|
||||
True, description="Auto-checkpoint destructive"
|
||||
)
|
||||
|
||||
# AUTONOMOUS settings
|
||||
autonomous_cost_limit: float = Field(100.0, description="USD limit per session")
|
||||
autonomous_auto_approve_normal: bool = Field(
|
||||
True, description="Auto-approve normal actions"
|
||||
)
|
||||
autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all")
|
||||
|
||||
|
||||
def _expand_env_vars(value: Any) -> Any:
|
||||
"""Recursively expand environment variables in values."""
|
||||
if isinstance(value, str):
|
||||
return os.path.expandvars(value)
|
||||
elif isinstance(value, dict):
|
||||
return {k: _expand_env_vars(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_expand_env_vars(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None:
|
||||
"""Load a safety policy from a YAML file."""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
logger.warning("Policy file not found: %s", path)
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
logger.warning("Empty policy file: %s", path)
|
||||
return None
|
||||
|
||||
# Expand environment variables
|
||||
data = _expand_env_vars(data)
|
||||
|
||||
return SafetyPolicy(**data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to load policy file %s: %s", path, e)
|
||||
return None
|
||||
|
||||
|
||||
def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]:
|
||||
"""Load all safety policies from a directory."""
|
||||
policies: dict[str, SafetyPolicy] = {}
|
||||
path = Path(directory)
|
||||
|
||||
if not path.exists() or not path.is_dir():
|
||||
logger.warning("Policy directory not found: %s", path)
|
||||
return policies
|
||||
|
||||
for file_path in path.glob("*.yaml"):
|
||||
policy = load_policy_from_file(file_path)
|
||||
if policy:
|
||||
policies[policy.name] = policy
|
||||
logger.info("Loaded policy: %s from %s", policy.name, file_path.name)
|
||||
|
||||
return policies
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_safety_config() -> SafetyConfig:
|
||||
"""Get the safety configuration (cached singleton)."""
|
||||
return SafetyConfig()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_autonomy_config() -> AutonomyConfig:
|
||||
"""Get the autonomy configuration (cached singleton)."""
|
||||
return AutonomyConfig()
|
||||
|
||||
|
||||
def get_default_policy() -> SafetyPolicy:
|
||||
"""Get the default safety policy."""
|
||||
config = get_safety_config()
|
||||
|
||||
return SafetyPolicy(
|
||||
name="default",
|
||||
description="Default safety policy",
|
||||
max_tokens_per_session=config.default_session_token_budget,
|
||||
max_tokens_per_day=config.default_daily_token_budget,
|
||||
max_cost_per_session_usd=config.default_session_cost_limit,
|
||||
max_cost_per_day_usd=config.default_daily_cost_limit,
|
||||
max_actions_per_minute=config.default_actions_per_minute,
|
||||
max_llm_calls_per_minute=config.default_llm_calls_per_minute,
|
||||
max_file_operations_per_minute=config.default_file_ops_per_minute,
|
||||
max_repeated_actions=config.max_repeated_actions,
|
||||
max_similar_actions=config.max_similar_actions,
|
||||
require_sandbox=config.sandbox_enabled,
|
||||
sandbox_timeout_seconds=config.sandbox_timeout,
|
||||
sandbox_memory_mb=config.sandbox_memory_mb,
|
||||
)
|
||||
|
||||
|
||||
def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
|
||||
"""Get the safety policy for a given autonomy level."""
|
||||
autonomy = get_autonomy_config()
|
||||
|
||||
base_policy = get_default_policy()
|
||||
|
||||
if level == AutonomyLevel.FULL_CONTROL:
|
||||
return SafetyPolicy(
|
||||
name="full_control",
|
||||
description="Full control mode - all actions require approval",
|
||||
max_cost_per_session_usd=autonomy.full_control_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.full_control_cost_limit * 10,
|
||||
require_approval_for=["*"], # All actions
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session // 10,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2,
|
||||
denied_tools=["delete_*", "destroy_*", "drop_*"],
|
||||
)
|
||||
|
||||
elif level == AutonomyLevel.MILESTONE:
|
||||
return SafetyPolicy(
|
||||
name="milestone",
|
||||
description="Milestone mode - approval at milestones only",
|
||||
max_cost_per_session_usd=autonomy.milestone_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.milestone_cost_limit * 10,
|
||||
require_approval_for=[
|
||||
"delete_file",
|
||||
"push_to_remote",
|
||||
"deploy_*",
|
||||
"modify_critical_*",
|
||||
"create_pull_request",
|
||||
],
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute,
|
||||
)
|
||||
|
||||
else: # AUTONOMOUS
|
||||
return SafetyPolicy(
|
||||
name="autonomous",
|
||||
description="Autonomous mode - minimal intervention",
|
||||
max_cost_per_session_usd=autonomy.autonomous_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.autonomous_cost_limit * 10,
|
||||
require_approval_for=[
|
||||
"deploy_to_production",
|
||||
"delete_repository",
|
||||
"modify_production_config",
|
||||
],
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session * 5,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2,
|
||||
)
|
||||
|
||||
|
||||
def reset_config_cache() -> None:
|
||||
"""Reset configuration caches (for testing)."""
|
||||
get_safety_config.cache_clear()
|
||||
get_autonomy_config.cache_clear()
|
||||
1
backend/app/services/safety/content/__init__.py
Normal file
1
backend/app/services/safety/content/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/costs/__init__.py
Normal file
1
backend/app/services/safety/costs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/emergency/__init__.py
Normal file
1
backend/app/services/safety/emergency/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
277
backend/app/services/safety/exceptions.py
Normal file
277
backend/app/services/safety/exceptions.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Safety Framework Exceptions
|
||||
|
||||
Custom exception classes for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SafetyError(Exception):
|
||||
"""Base exception for all safety-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
action_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.action_id = action_id
|
||||
self.agent_id = agent_id
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class PermissionDeniedError(SafetyError):
|
||||
"""Raised when an action is not permitted."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Permission denied",
|
||||
*,
|
||||
action_type: str | None = None,
|
||||
resource: str | None = None,
|
||||
required_permission: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.action_type = action_type
|
||||
self.resource = resource
|
||||
self.required_permission = required_permission
|
||||
|
||||
|
||||
class BudgetExceededError(SafetyError):
|
||||
"""Raised when cost budget is exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Budget exceeded",
|
||||
*,
|
||||
budget_type: str = "session",
|
||||
current_usage: float = 0.0,
|
||||
budget_limit: float = 0.0,
|
||||
unit: str = "tokens",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.budget_type = budget_type
|
||||
self.current_usage = current_usage
|
||||
self.budget_limit = budget_limit
|
||||
self.unit = unit
|
||||
|
||||
|
||||
class RateLimitExceededError(SafetyError):
|
||||
"""Raised when rate limit is exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Rate limit exceeded",
|
||||
*,
|
||||
limit_type: str = "actions",
|
||||
limit_value: int = 0,
|
||||
window_seconds: int = 60,
|
||||
retry_after_seconds: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.limit_type = limit_type
|
||||
self.limit_value = limit_value
|
||||
self.window_seconds = window_seconds
|
||||
self.retry_after_seconds = retry_after_seconds
|
||||
|
||||
|
||||
class LoopDetectedError(SafetyError):
|
||||
"""Raised when an action loop is detected."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Loop detected",
|
||||
*,
|
||||
loop_type: str = "exact",
|
||||
repetition_count: int = 0,
|
||||
action_pattern: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.loop_type = loop_type
|
||||
self.repetition_count = repetition_count
|
||||
self.action_pattern = action_pattern or []
|
||||
|
||||
|
||||
class ApprovalRequiredError(SafetyError):
|
||||
"""Raised when human approval is required."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Human approval required",
|
||||
*,
|
||||
approval_id: str | None = None,
|
||||
reason: str | None = None,
|
||||
timeout_seconds: int = 300,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.approval_id = approval_id
|
||||
self.reason = reason
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
|
||||
class ApprovalDeniedError(SafetyError):
|
||||
"""Raised when human explicitly denies an action."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Approval denied by human",
|
||||
*,
|
||||
approval_id: str | None = None,
|
||||
denied_by: str | None = None,
|
||||
denial_reason: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.approval_id = approval_id
|
||||
self.denied_by = denied_by
|
||||
self.denial_reason = denial_reason
|
||||
|
||||
|
||||
class ApprovalTimeoutError(SafetyError):
|
||||
"""Raised when approval request times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Approval request timed out",
|
||||
*,
|
||||
approval_id: str | None = None,
|
||||
timeout_seconds: int = 300,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.approval_id = approval_id
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
|
||||
class RollbackError(SafetyError):
|
||||
"""Raised when rollback fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Rollback failed",
|
||||
*,
|
||||
checkpoint_id: str | None = None,
|
||||
failed_actions: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.failed_actions = failed_actions or []
|
||||
|
||||
|
||||
class CheckpointError(SafetyError):
|
||||
"""Raised when checkpoint creation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Checkpoint creation failed",
|
||||
*,
|
||||
checkpoint_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.checkpoint_type = checkpoint_type
|
||||
|
||||
|
||||
class ValidationError(SafetyError):
|
||||
"""Raised when action validation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Validation failed",
|
||||
*,
|
||||
validation_rules: list[str] | None = None,
|
||||
failed_rules: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.validation_rules = validation_rules or []
|
||||
self.failed_rules = failed_rules or []
|
||||
|
||||
|
||||
class ContentFilterError(SafetyError):
|
||||
"""Raised when content filtering detects prohibited content."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Prohibited content detected",
|
||||
*,
|
||||
filter_type: str | None = None,
|
||||
detected_patterns: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.filter_type = filter_type
|
||||
self.detected_patterns = detected_patterns or []
|
||||
|
||||
|
||||
class SandboxError(SafetyError):
|
||||
"""Raised when sandbox execution fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Sandbox execution failed",
|
||||
*,
|
||||
exit_code: int | None = None,
|
||||
stderr: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.exit_code = exit_code
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
class SandboxTimeoutError(SandboxError):
|
||||
"""Raised when sandbox execution times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Sandbox execution timed out",
|
||||
*,
|
||||
timeout_seconds: int = 300,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
|
||||
class EmergencyStopError(SafetyError):
|
||||
"""Raised when emergency stop is triggered."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Emergency stop triggered",
|
||||
*,
|
||||
stop_type: str = "kill",
|
||||
triggered_by: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.stop_type = stop_type
|
||||
self.triggered_by = triggered_by
|
||||
|
||||
|
||||
class PolicyViolationError(SafetyError):
|
||||
"""Raised when an action violates a safety policy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Policy violation",
|
||||
*,
|
||||
policy_name: str | None = None,
|
||||
violated_rules: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.policy_name = policy_name
|
||||
self.violated_rules = violated_rules or []
|
||||
614
backend/app/services/safety/guardian.py
Normal file
614
backend/app/services/safety/guardian.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Safety Guardian
|
||||
|
||||
Main facade for the safety framework. Orchestrates all safety checks
|
||||
before, during, and after action execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .audit import AuditLogger, get_audit_logger
|
||||
from .config import (
|
||||
SafetyConfig,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
)
|
||||
from .exceptions import (
|
||||
SafetyError,
|
||||
)
|
||||
from .models import (
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
AuditEventType,
|
||||
GuardianResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyGuardian:
|
||||
"""
|
||||
Central orchestrator for all safety checks.
|
||||
|
||||
The SafetyGuardian is the main entry point for validating agent actions.
|
||||
It coordinates multiple safety subsystems:
|
||||
- Permission checking
|
||||
- Cost/budget control
|
||||
- Rate limiting
|
||||
- Loop detection
|
||||
- Human-in-the-loop approval
|
||||
- Rollback/checkpoint management
|
||||
- Content filtering
|
||||
- Sandbox execution
|
||||
|
||||
Usage:
|
||||
guardian = SafetyGuardian()
|
||||
await guardian.initialize()
|
||||
|
||||
# Before executing an action
|
||||
result = await guardian.validate(action_request)
|
||||
if not result.allowed:
|
||||
# Handle denial
|
||||
|
||||
# After action execution
|
||||
await guardian.record_execution(action_request, action_result)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SafetyConfig | None = None,
|
||||
audit_logger: AuditLogger | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the SafetyGuardian.
|
||||
|
||||
Args:
|
||||
config: Optional safety configuration. If None, loads from environment.
|
||||
audit_logger: Optional audit logger. If None, uses global instance.
|
||||
"""
|
||||
self._config = config or get_safety_config()
|
||||
self._audit_logger = audit_logger
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Subsystem references (will be initialized lazily)
|
||||
self._permission_manager: Any = None
|
||||
self._cost_controller: Any = None
|
||||
self._rate_limiter: Any = None
|
||||
self._loop_detector: Any = None
|
||||
self._hitl_manager: Any = None
|
||||
self._rollback_manager: Any = None
|
||||
self._content_filter: Any = None
|
||||
self._sandbox_executor: Any = None
|
||||
self._emergency_controls: Any = None
|
||||
|
||||
# Policy cache
|
||||
self._policies: dict[str, SafetyPolicy] = {}
|
||||
self._default_policy: SafetyPolicy | None = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the guardian is initialized."""
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the SafetyGuardian and all subsystems."""
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
logger.warning("SafetyGuardian already initialized")
|
||||
return
|
||||
|
||||
logger.info("Initializing SafetyGuardian")
|
||||
|
||||
# Get audit logger
|
||||
if self._audit_logger is None:
|
||||
self._audit_logger = await get_audit_logger()
|
||||
|
||||
# Initialize subsystems lazily as they're implemented
|
||||
# For now, we'll import and initialize them when available
|
||||
|
||||
self._initialized = True
|
||||
logger.info("SafetyGuardian initialized")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the SafetyGuardian and all subsystems."""
|
||||
async with self._lock:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Shutting down SafetyGuardian")
|
||||
|
||||
# Shutdown subsystems
|
||||
# (Will be implemented as subsystems are added)
|
||||
|
||||
self._initialized = False
|
||||
logger.info("SafetyGuardian shutdown complete")
|
||||
|
||||
async def validate(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
policy: SafetyPolicy | None = None,
|
||||
) -> GuardianResult:
|
||||
"""
|
||||
Validate an action before execution.
|
||||
|
||||
Runs all safety checks in order:
|
||||
1. Permission check
|
||||
2. Cost/budget check
|
||||
3. Rate limit check
|
||||
4. Loop detection
|
||||
5. HITL check (if required)
|
||||
6. Checkpoint creation (if destructive)
|
||||
|
||||
Args:
|
||||
action: The action to validate
|
||||
policy: Optional policy override. If None, uses autonomy-level policy.
|
||||
|
||||
Returns:
|
||||
GuardianResult with decision and details
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if not self._config.enabled:
|
||||
# Safety disabled - allow everything (NOT RECOMMENDED)
|
||||
logger.warning("Safety framework disabled - allowing action %s", action.id)
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Safety framework disabled"],
|
||||
)
|
||||
|
||||
# Get policy for this action
|
||||
effective_policy = policy or self._get_policy(action)
|
||||
|
||||
reasons: list[str] = []
|
||||
audit_events = []
|
||||
|
||||
try:
|
||||
# Log action request
|
||||
if self._audit_logger:
|
||||
event = await self._audit_logger.log(
|
||||
AuditEventType.ACTION_REQUESTED,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
details={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"resource": action.resource,
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
audit_events.append(event)
|
||||
|
||||
# 1. Permission check
|
||||
permission_result = await self._check_permissions(action, effective_policy)
|
||||
if permission_result.decision == SafetyDecision.DENY:
|
||||
return await self._create_denial_result(
|
||||
action, permission_result.reasons, audit_events
|
||||
)
|
||||
|
||||
# 2. Cost/budget check
|
||||
budget_result = await self._check_budget(action, effective_policy)
|
||||
if budget_result.decision == SafetyDecision.DENY:
|
||||
return await self._create_denial_result(
|
||||
action, budget_result.reasons, audit_events
|
||||
)
|
||||
|
||||
# 3. Rate limit check
|
||||
rate_result = await self._check_rate_limit(action, effective_policy)
|
||||
if rate_result.decision == SafetyDecision.DENY:
|
||||
return await self._create_denial_result(
|
||||
action,
|
||||
rate_result.reasons,
|
||||
audit_events,
|
||||
retry_after=rate_result.retry_after_seconds,
|
||||
)
|
||||
if rate_result.decision == SafetyDecision.DELAY:
|
||||
# Return delay decision
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DELAY,
|
||||
reasons=rate_result.reasons,
|
||||
retry_after_seconds=rate_result.retry_after_seconds,
|
||||
audit_events=audit_events,
|
||||
)
|
||||
|
||||
# 4. Loop detection
|
||||
loop_result = await self._check_loops(action, effective_policy)
|
||||
if loop_result.decision == SafetyDecision.DENY:
|
||||
return await self._create_denial_result(
|
||||
action, loop_result.reasons, audit_events
|
||||
)
|
||||
|
||||
# 5. HITL check
|
||||
hitl_result = await self._check_hitl(action, effective_policy)
|
||||
if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||
reasons=hitl_result.reasons,
|
||||
approval_id=hitl_result.approval_id,
|
||||
audit_events=audit_events,
|
||||
)
|
||||
|
||||
# 6. Create checkpoint if destructive
|
||||
checkpoint_id = None
|
||||
if action.is_destructive and self._config.auto_checkpoint_destructive:
|
||||
checkpoint_id = await self._create_checkpoint(action)
|
||||
|
||||
# All checks passed
|
||||
reasons.append("All safety checks passed")
|
||||
|
||||
if self._audit_logger:
|
||||
event = await self._audit_logger.log_action_request(
|
||||
action, SafetyDecision.ALLOW, reasons
|
||||
)
|
||||
audit_events.append(event)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=reasons,
|
||||
checkpoint_id=checkpoint_id,
|
||||
audit_events=audit_events,
|
||||
)
|
||||
|
||||
except SafetyError as e:
|
||||
# Known safety error
|
||||
return await self._create_denial_result(
|
||||
action, [str(e)], audit_events
|
||||
)
|
||||
except Exception as e:
|
||||
# Unknown error - fail closed in strict mode
|
||||
logger.error("Unexpected error in safety validation: %s", e)
|
||||
if self._config.strict_mode:
|
||||
return await self._create_denial_result(
|
||||
action,
|
||||
[f"Safety validation error: {e}"],
|
||||
audit_events,
|
||||
)
|
||||
else:
|
||||
# Non-strict mode - allow with warning
|
||||
logger.warning("Non-strict mode: allowing action despite error")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Allowed despite validation error (non-strict mode)"],
|
||||
audit_events=audit_events,
|
||||
)
|
||||
|
||||
async def record_execution(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
result: ActionResult,
|
||||
) -> None:
|
||||
"""
|
||||
Record action execution result for auditing and tracking.
|
||||
|
||||
Args:
|
||||
action: The executed action
|
||||
result: The execution result
|
||||
"""
|
||||
if self._audit_logger:
|
||||
await self._audit_logger.log_action_executed(
|
||||
action,
|
||||
success=result.success,
|
||||
execution_time_ms=result.execution_time_ms,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
# Update cost tracking
|
||||
if self._cost_controller:
|
||||
# Track actual cost
|
||||
pass
|
||||
|
||||
# Update loop detection history
|
||||
if self._loop_detector:
|
||||
# Add to action history
|
||||
pass
|
||||
|
||||
async def rollback(self, checkpoint_id: str) -> bool:
|
||||
"""
|
||||
Rollback to a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint to rollback to
|
||||
|
||||
Returns:
|
||||
True if rollback succeeded
|
||||
"""
|
||||
if self._rollback_manager is None:
|
||||
logger.warning("Rollback manager not available")
|
||||
return False
|
||||
|
||||
# Delegate to rollback manager
|
||||
return await self._rollback_manager.rollback(checkpoint_id)
|
||||
|
||||
async def emergency_stop(
|
||||
self,
|
||||
stop_type: str = "kill",
|
||||
reason: str = "Manual emergency stop",
|
||||
triggered_by: str = "system",
|
||||
) -> None:
|
||||
"""
|
||||
Trigger emergency stop.
|
||||
|
||||
Args:
|
||||
stop_type: Type of stop (kill, pause, lockdown)
|
||||
reason: Reason for the stop
|
||||
triggered_by: Who triggered the stop
|
||||
"""
|
||||
logger.critical(
|
||||
"Emergency stop triggered: type=%s, reason=%s, by=%s",
|
||||
stop_type,
|
||||
reason,
|
||||
triggered_by,
|
||||
)
|
||||
|
||||
if self._audit_logger:
|
||||
await self._audit_logger.log_emergency_stop(
|
||||
stop_type=stop_type,
|
||||
triggered_by=triggered_by,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
if self._emergency_controls:
|
||||
await self._emergency_controls.execute_stop(stop_type)
|
||||
|
||||
def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
|
||||
"""Get the effective policy for an action."""
|
||||
# Check cached policies
|
||||
autonomy_level = action.metadata.autonomy_level
|
||||
|
||||
if autonomy_level.value not in self._policies:
|
||||
self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
|
||||
autonomy_level
|
||||
)
|
||||
|
||||
return self._policies[autonomy_level.value]
|
||||
|
||||
async def _check_permissions(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check if action is permitted."""
|
||||
reasons: list[str] = []
|
||||
|
||||
# Check denied tools
|
||||
if action.tool_name:
|
||||
for pattern in policy.denied_tools:
|
||||
if self._matches_pattern(action.tool_name, pattern):
|
||||
reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=reasons,
|
||||
)
|
||||
|
||||
# Check allowed tools (if not "*")
|
||||
if action.tool_name and "*" not in policy.allowed_tools:
|
||||
allowed = False
|
||||
for pattern in policy.allowed_tools:
|
||||
if self._matches_pattern(action.tool_name, pattern):
|
||||
allowed = True
|
||||
break
|
||||
if not allowed:
|
||||
reasons.append(f"Tool '{action.tool_name}' not in allowed list")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=reasons,
|
||||
)
|
||||
|
||||
# Check file patterns
|
||||
if action.resource:
|
||||
for pattern in policy.denied_file_patterns:
|
||||
if self._matches_pattern(action.resource, pattern):
|
||||
reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=reasons,
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Permission check passed"],
|
||||
)
|
||||
|
||||
async def _check_budget(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check if action is within budget."""
|
||||
# TODO: Implement with CostController
|
||||
# For now, return allow
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Budget check passed (not fully implemented)"],
|
||||
)
|
||||
|
||||
async def _check_rate_limit(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check if action is within rate limits."""
|
||||
# TODO: Implement with RateLimiter
|
||||
# For now, return allow
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Rate limit check passed (not fully implemented)"],
|
||||
)
|
||||
|
||||
async def _check_loops(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check for action loops."""
|
||||
# TODO: Implement with LoopDetector
|
||||
# For now, return allow
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Loop check passed (not fully implemented)"],
|
||||
)
|
||||
|
||||
async def _check_hitl(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check if human approval is required."""
|
||||
if not self._config.hitl_enabled:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["HITL disabled"],
|
||||
)
|
||||
|
||||
# Check if action requires approval
|
||||
requires_approval = False
|
||||
for pattern in policy.require_approval_for:
|
||||
if pattern == "*":
|
||||
requires_approval = True
|
||||
break
|
||||
if action.tool_name and self._matches_pattern(action.tool_name, pattern):
|
||||
requires_approval = True
|
||||
break
|
||||
if action.action_type.value and self._matches_pattern(
|
||||
action.action_type.value, pattern
|
||||
):
|
||||
requires_approval = True
|
||||
break
|
||||
|
||||
if requires_approval:
|
||||
# TODO: Create approval request with HITLManager
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||
reasons=["Action requires human approval"],
|
||||
approval_id=None, # Will be set by HITLManager
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["No approval required"],
|
||||
)
|
||||
|
||||
async def _create_checkpoint(self, action: ActionRequest) -> str | None:
|
||||
"""Create a checkpoint before destructive action."""
|
||||
if self._rollback_manager is None:
|
||||
logger.warning("Rollback manager not available - skipping checkpoint")
|
||||
return None
|
||||
|
||||
# TODO: Implement with RollbackManager
|
||||
return None
|
||||
|
||||
async def _create_denial_result(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
reasons: list[str],
|
||||
audit_events: list[Any],
|
||||
retry_after: float | None = None,
|
||||
) -> GuardianResult:
|
||||
"""Create a denial result with audit logging."""
|
||||
if self._audit_logger:
|
||||
event = await self._audit_logger.log_action_request(
|
||||
action, SafetyDecision.DENY, reasons
|
||||
)
|
||||
audit_events.append(event)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=reasons,
|
||||
retry_after_seconds=retry_after,
|
||||
audit_events=audit_events,
|
||||
)
|
||||
|
||||
def _matches_pattern(self, value: str, pattern: str) -> bool:
|
||||
"""Check if value matches a pattern (supports * wildcard)."""
|
||||
if pattern == "*":
|
||||
return True
|
||||
|
||||
if "*" not in pattern:
|
||||
return value == pattern
|
||||
|
||||
# Simple wildcard matching
|
||||
if pattern.startswith("*") and pattern.endswith("*"):
|
||||
return pattern[1:-1] in value
|
||||
elif pattern.startswith("*"):
|
||||
return value.endswith(pattern[1:])
|
||||
elif pattern.endswith("*"):
|
||||
return value.startswith(pattern[:-1])
|
||||
else:
|
||||
# Pattern like "foo*bar"
|
||||
parts = pattern.split("*")
|
||||
if len(parts) == 2:
|
||||
return value.startswith(parts[0]) and value.endswith(parts[1])
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_guardian_instance: SafetyGuardian | None = None
|
||||
_guardian_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_safety_guardian() -> SafetyGuardian:
|
||||
"""Get the global SafetyGuardian instance."""
|
||||
global _guardian_instance
|
||||
|
||||
async with _guardian_lock:
|
||||
if _guardian_instance is None:
|
||||
_guardian_instance = SafetyGuardian()
|
||||
await _guardian_instance.initialize()
|
||||
|
||||
return _guardian_instance
|
||||
|
||||
|
||||
async def shutdown_safety_guardian() -> None:
|
||||
"""Shutdown the global SafetyGuardian."""
|
||||
global _guardian_instance
|
||||
|
||||
async with _guardian_lock:
|
||||
if _guardian_instance is not None:
|
||||
await _guardian_instance.shutdown()
|
||||
_guardian_instance = None
|
||||
|
||||
|
||||
def reset_safety_guardian() -> None:
|
||||
"""Reset the SafetyGuardian (for testing)."""
|
||||
global _guardian_instance
|
||||
_guardian_instance = None
|
||||
1
backend/app/services/safety/hitl/__init__.py
Normal file
1
backend/app/services/safety/hitl/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/limits/__init__.py
Normal file
1
backend/app/services/safety/limits/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/loops/__init__.py
Normal file
1
backend/app/services/safety/loops/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
474
backend/app/services/safety/models.py
Normal file
474
backend/app/services/safety/models.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
Safety Framework Models
|
||||
|
||||
Core Pydantic models for actions, events, policies, and safety decisions.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ============================================================================
|
||||
# Enums
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
"""Types of actions that can be performed."""
|
||||
|
||||
TOOL_CALL = "tool_call"
|
||||
FILE_READ = "file_read"
|
||||
FILE_WRITE = "file_write"
|
||||
FILE_DELETE = "file_delete"
|
||||
API_CALL = "api_call"
|
||||
DATABASE_QUERY = "database_query"
|
||||
DATABASE_MUTATE = "database_mutate"
|
||||
GIT_OPERATION = "git_operation"
|
||||
SHELL_COMMAND = "shell_command"
|
||||
LLM_CALL = "llm_call"
|
||||
NETWORK_REQUEST = "network_request"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
|
||||
"""Types of resources that can be accessed."""
|
||||
|
||||
FILE = "file"
|
||||
DATABASE = "database"
|
||||
API = "api"
|
||||
NETWORK = "network"
|
||||
GIT = "git"
|
||||
SHELL = "shell"
|
||||
LLM = "llm"
|
||||
MEMORY = "memory"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class PermissionLevel(str, Enum):
|
||||
"""Permission levels for resource access."""
|
||||
|
||||
NONE = "none"
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
EXECUTE = "execute"
|
||||
DELETE = "delete"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class AutonomyLevel(str, Enum):
|
||||
"""Autonomy levels for agent operation."""
|
||||
|
||||
FULL_CONTROL = "full_control" # Approve every action
|
||||
MILESTONE = "milestone" # Approve at milestones
|
||||
AUTONOMOUS = "autonomous" # Only major decisions
|
||||
|
||||
|
||||
class SafetyDecision(str, Enum):
|
||||
"""Result of safety validation."""
|
||||
|
||||
ALLOW = "allow"
|
||||
DENY = "deny"
|
||||
REQUIRE_APPROVAL = "require_approval"
|
||||
DELAY = "delay"
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
"""Status of approval request."""
|
||||
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
TIMEOUT = "timeout"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AuditEventType(str, Enum):
|
||||
"""Types of audit events."""
|
||||
|
||||
ACTION_REQUESTED = "action_requested"
|
||||
ACTION_VALIDATED = "action_validated"
|
||||
ACTION_DENIED = "action_denied"
|
||||
ACTION_EXECUTED = "action_executed"
|
||||
ACTION_FAILED = "action_failed"
|
||||
APPROVAL_REQUESTED = "approval_requested"
|
||||
APPROVAL_GRANTED = "approval_granted"
|
||||
APPROVAL_DENIED = "approval_denied"
|
||||
APPROVAL_TIMEOUT = "approval_timeout"
|
||||
CHECKPOINT_CREATED = "checkpoint_created"
|
||||
ROLLBACK_STARTED = "rollback_started"
|
||||
ROLLBACK_COMPLETED = "rollback_completed"
|
||||
ROLLBACK_FAILED = "rollback_failed"
|
||||
BUDGET_WARNING = "budget_warning"
|
||||
BUDGET_EXCEEDED = "budget_exceeded"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
LOOP_DETECTED = "loop_detected"
|
||||
EMERGENCY_STOP = "emergency_stop"
|
||||
POLICY_VIOLATION = "policy_violation"
|
||||
CONTENT_FILTERED = "content_filtered"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionMetadata(BaseModel):
|
||||
"""Metadata associated with an action."""
|
||||
|
||||
agent_id: str = Field(..., description="ID of the agent performing the action")
|
||||
project_id: str | None = Field(None, description="ID of the project context")
|
||||
session_id: str | None = Field(None, description="ID of the current session")
|
||||
task_id: str | None = Field(None, description="ID of the current task")
|
||||
parent_action_id: str | None = Field(None, description="ID of the parent action")
|
||||
correlation_id: str | None = Field(None, description="Correlation ID for tracing")
|
||||
user_id: str | None = Field(None, description="ID of the user who initiated")
|
||||
autonomy_level: AutonomyLevel = Field(
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
description="Current autonomy level",
|
||||
)
|
||||
context: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional context",
|
||||
)
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
|
||||
"""Request to perform an action."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
action_type: ActionType = Field(..., description="Type of action to perform")
|
||||
tool_name: str | None = Field(None, description="Name of the tool to call")
|
||||
resource: str | None = Field(None, description="Resource being accessed")
|
||||
resource_type: ResourceType | None = Field(None, description="Type of resource")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Action arguments",
|
||||
)
|
||||
metadata: ActionMetadata = Field(..., description="Action metadata")
|
||||
estimated_cost_tokens: int = Field(0, description="Estimated token cost")
|
||||
estimated_cost_usd: float = Field(0.0, description="Estimated USD cost")
|
||||
is_destructive: bool = Field(False, description="Whether action is destructive")
|
||||
is_reversible: bool = Field(True, description="Whether action can be rolled back")
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class ActionResult(BaseModel):
|
||||
"""Result of an executed action."""
|
||||
|
||||
action_id: str = Field(..., description="ID of the action")
|
||||
success: bool = Field(..., description="Whether action succeeded")
|
||||
data: Any = Field(None, description="Action result data")
|
||||
error: str | None = Field(None, description="Error message if failed")
|
||||
error_code: str | None = Field(None, description="Error code if failed")
|
||||
execution_time_ms: float = Field(0.0, description="Execution time in ms")
|
||||
actual_cost_tokens: int = Field(0, description="Actual token cost")
|
||||
actual_cost_usd: float = Field(0.0, description="Actual USD cost")
|
||||
checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Validation Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ValidationRule(BaseModel):
|
||||
"""A single validation rule."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str = Field(..., description="Rule name")
|
||||
description: str | None = Field(None, description="Rule description")
|
||||
priority: int = Field(0, description="Rule priority (higher = evaluated first)")
|
||||
enabled: bool = Field(True, description="Whether rule is enabled")
|
||||
|
||||
# Rule conditions
|
||||
action_types: list[ActionType] | None = Field(
|
||||
None, description="Action types this rule applies to"
|
||||
)
|
||||
tool_patterns: list[str] | None = Field(
|
||||
None, description="Tool name patterns (supports wildcards)"
|
||||
)
|
||||
resource_patterns: list[str] | None = Field(
|
||||
None, description="Resource patterns (supports wildcards)"
|
||||
)
|
||||
agent_ids: list[str] | None = Field(
|
||||
None, description="Agent IDs this rule applies to"
|
||||
)
|
||||
|
||||
# Rule decision
|
||||
decision: SafetyDecision = Field(..., description="Decision when rule matches")
|
||||
reason: str | None = Field(None, description="Reason for decision")
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
|
||||
"""Result of action validation."""
|
||||
|
||||
action_id: str = Field(..., description="ID of the validated action")
|
||||
decision: SafetyDecision = Field(..., description="Validation decision")
|
||||
applied_rules: list[str] = Field(
|
||||
default_factory=list, description="IDs of applied rules"
|
||||
)
|
||||
reasons: list[str] = Field(
|
||||
default_factory=list, description="Reasons for decision"
|
||||
)
|
||||
approval_id: str | None = Field(None, description="Approval request ID if needed")
|
||||
retry_after_seconds: float | None = Field(
|
||||
None, description="Retry delay if rate limited"
|
||||
)
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Budget Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BudgetScope(str, Enum):
|
||||
"""Scope of a budget limit."""
|
||||
|
||||
SESSION = "session"
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
PROJECT = "project"
|
||||
AGENT = "agent"
|
||||
|
||||
|
||||
class BudgetStatus(BaseModel):
|
||||
"""Current budget status."""
|
||||
|
||||
scope: BudgetScope = Field(..., description="Budget scope")
|
||||
scope_id: str = Field(..., description="ID within scope (session/agent/project)")
|
||||
tokens_used: int = Field(0, description="Tokens used in this scope")
|
||||
tokens_limit: int = Field(100000, description="Token limit for this scope")
|
||||
cost_used_usd: float = Field(0.0, description="USD spent in this scope")
|
||||
cost_limit_usd: float = Field(10.0, description="USD limit for this scope")
|
||||
tokens_remaining: int = Field(0, description="Remaining tokens")
|
||||
cost_remaining_usd: float = Field(0.0, description="Remaining USD budget")
|
||||
warning_threshold: float = Field(0.8, description="Warn at this usage fraction")
|
||||
is_warning: bool = Field(False, description="Whether at warning level")
|
||||
is_exceeded: bool = Field(False, description="Whether budget exceeded")
|
||||
reset_at: datetime | None = Field(None, description="When budget resets")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Rate Limit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RateLimitConfig(BaseModel):
|
||||
"""Configuration for a rate limit."""
|
||||
|
||||
name: str = Field(..., description="Rate limit name")
|
||||
limit: int = Field(..., description="Maximum allowed in window")
|
||||
window_seconds: int = Field(60, description="Time window in seconds")
|
||||
burst_limit: int | None = Field(None, description="Burst allowance")
|
||||
slowdown_threshold: float = Field(
|
||||
0.8, description="Start slowing at this fraction"
|
||||
)
|
||||
|
||||
|
||||
class RateLimitStatus(BaseModel):
|
||||
"""Current rate limit status."""
|
||||
|
||||
name: str = Field(..., description="Rate limit name")
|
||||
current_count: int = Field(0, description="Current count in window")
|
||||
limit: int = Field(..., description="Maximum allowed")
|
||||
window_seconds: int = Field(..., description="Time window")
|
||||
remaining: int = Field(..., description="Remaining in window")
|
||||
reset_at: datetime = Field(..., description="When window resets")
|
||||
is_limited: bool = Field(False, description="Whether currently limited")
|
||||
retry_after_seconds: float = Field(0.0, description="Seconds until retry")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Approval Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""Request for human approval."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
action: ActionRequest = Field(..., description="Action requiring approval")
|
||||
reason: str = Field(..., description="Why approval is required")
|
||||
urgency: str = Field("normal", description="Urgency level")
|
||||
timeout_seconds: int = Field(300, description="Timeout for approval")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = Field(None, description="When request expires")
|
||||
suggested_action: str | None = Field(None, description="Suggested response")
|
||||
context: dict[str, Any] = Field(default_factory=dict, description="Extra context")
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
|
||||
"""Response to an approval request."""
|
||||
|
||||
request_id: str = Field(..., description="ID of the approval request")
|
||||
status: ApprovalStatus = Field(..., description="Approval status")
|
||||
decided_by: str | None = Field(None, description="Who made the decision")
|
||||
reason: str | None = Field(None, description="Reason for decision")
|
||||
modifications: dict[str, Any] | None = Field(
|
||||
None, description="Modifications to action"
|
||||
)
|
||||
decided_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Checkpoint/Rollback Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CheckpointType(str, Enum):
|
||||
"""Types of checkpoints."""
|
||||
|
||||
FILE = "file"
|
||||
DATABASE = "database"
|
||||
GIT = "git"
|
||||
COMPOSITE = "composite"
|
||||
|
||||
|
||||
class Checkpoint(BaseModel):
|
||||
"""A rollback checkpoint."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
checkpoint_type: CheckpointType = Field(..., description="Type of checkpoint")
|
||||
action_id: str = Field(..., description="Action this checkpoint is for")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = Field(None, description="When checkpoint expires")
|
||||
data: dict[str, Any] = Field(default_factory=dict, description="Checkpoint data")
|
||||
description: str | None = Field(None, description="Description of checkpoint")
|
||||
is_valid: bool = Field(True, description="Whether checkpoint is still valid")
|
||||
|
||||
|
||||
class RollbackResult(BaseModel):
|
||||
"""Result of a rollback operation."""
|
||||
|
||||
checkpoint_id: str = Field(..., description="ID of checkpoint rolled back to")
|
||||
success: bool = Field(..., description="Whether rollback succeeded")
|
||||
actions_rolled_back: list[str] = Field(
|
||||
default_factory=list, description="IDs of rolled back actions"
|
||||
)
|
||||
failed_actions: list[str] = Field(
|
||||
default_factory=list, description="IDs of actions that failed to rollback"
|
||||
)
|
||||
error: str | None = Field(None, description="Error message if failed")
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuditEvent(BaseModel):
|
||||
"""An audit log event."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
event_type: AuditEventType = Field(..., description="Type of audit event")
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
agent_id: str | None = Field(None, description="Agent ID if applicable")
|
||||
action_id: str | None = Field(None, description="Action ID if applicable")
|
||||
project_id: str | None = Field(None, description="Project ID if applicable")
|
||||
session_id: str | None = Field(None, description="Session ID if applicable")
|
||||
user_id: str | None = Field(None, description="User ID if applicable")
|
||||
decision: SafetyDecision | None = Field(None, description="Safety decision")
|
||||
details: dict[str, Any] = Field(default_factory=dict, description="Event details")
|
||||
correlation_id: str | None = Field(None, description="Correlation ID for tracing")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Policy Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SafetyPolicy(BaseModel):
|
||||
"""A complete safety policy configuration."""
|
||||
|
||||
name: str = Field(..., description="Policy name")
|
||||
description: str | None = Field(None, description="Policy description")
|
||||
version: str = Field("1.0.0", description="Policy version")
|
||||
enabled: bool = Field(True, description="Whether policy is enabled")
|
||||
|
||||
# Cost controls
|
||||
max_tokens_per_session: int = Field(100_000, description="Max tokens per session")
|
||||
max_tokens_per_day: int = Field(1_000_000, description="Max tokens per day")
|
||||
max_cost_per_session_usd: float = Field(10.0, description="Max USD per session")
|
||||
max_cost_per_day_usd: float = Field(100.0, description="Max USD per day")
|
||||
|
||||
# Rate limits
|
||||
max_actions_per_minute: int = Field(60, description="Max actions per minute")
|
||||
max_llm_calls_per_minute: int = Field(20, description="Max LLM calls per minute")
|
||||
max_file_operations_per_minute: int = Field(
|
||||
100, description="Max file ops per minute"
|
||||
)
|
||||
|
||||
# Permissions
|
||||
allowed_tools: list[str] = Field(
|
||||
default_factory=lambda: ["*"],
|
||||
description="Allowed tool patterns",
|
||||
)
|
||||
denied_tools: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Denied tool patterns",
|
||||
)
|
||||
allowed_file_patterns: list[str] = Field(
|
||||
default_factory=lambda: ["**/*"],
|
||||
description="Allowed file patterns",
|
||||
)
|
||||
denied_file_patterns: list[str] = Field(
|
||||
default_factory=lambda: ["**/.env", "**/secrets/**"],
|
||||
description="Denied file patterns",
|
||||
)
|
||||
|
||||
# HITL
|
||||
require_approval_for: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"delete_file",
|
||||
"push_to_remote",
|
||||
"deploy_to_production",
|
||||
"modify_critical_config",
|
||||
],
|
||||
description="Actions requiring approval",
|
||||
)
|
||||
|
||||
# Loop detection
|
||||
max_repeated_actions: int = Field(5, description="Max exact repetitions")
|
||||
max_similar_actions: int = Field(10, description="Max similar actions")
|
||||
|
||||
# Sandbox
|
||||
require_sandbox: bool = Field(False, description="Require sandbox execution")
|
||||
sandbox_timeout_seconds: int = Field(300, description="Sandbox timeout")
|
||||
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit")
|
||||
|
||||
# Validation rules
|
||||
validation_rules: list[ValidationRule] = Field(
|
||||
default_factory=list,
|
||||
description="Custom validation rules",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Guardian Result Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class GuardianResult(BaseModel):
|
||||
"""Result of SafetyGuardian evaluation."""
|
||||
|
||||
action_id: str = Field(..., description="ID of the action")
|
||||
allowed: bool = Field(..., description="Whether action is allowed")
|
||||
decision: SafetyDecision = Field(..., description="Safety decision")
|
||||
reasons: list[str] = Field(default_factory=list, description="Decision reasons")
|
||||
approval_id: str | None = Field(None, description="Approval ID if needed")
|
||||
checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
|
||||
retry_after_seconds: float | None = Field(None, description="Retry delay")
|
||||
modified_action: ActionRequest | None = Field(
|
||||
None, description="Modified action if changed"
|
||||
)
|
||||
audit_events: list[AuditEvent] = Field(
|
||||
default_factory=list, description="Generated audit events"
|
||||
)
|
||||
1
backend/app/services/safety/permissions/__init__.py
Normal file
1
backend/app/services/safety/permissions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/policies/__init__.py
Normal file
1
backend/app/services/safety/policies/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/rollback/__init__.py
Normal file
1
backend/app/services/safety/rollback/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/sandbox/__init__.py
Normal file
1
backend/app/services/safety/sandbox/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
1
backend/app/services/safety/validation/__init__.py
Normal file
1
backend/app/services/safety/validation/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
Reference in New Issue
Block a user