From 498c0a0e94d92bf376fcf0259c2306dc44c8881f Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sat, 3 Jan 2026 11:22:25 +0100 Subject: [PATCH] feat(backend): add safety framework foundation (Phase A) (#63) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core safety framework architecture for autonomous agent guardrails: **Core Components:** - SafetyGuardian: Main orchestrator for all safety checks - AuditLogger: Comprehensive audit logging with hash chain tamper detection - SafetyConfig: Pydantic-based configuration - Models: Action requests, validation results, policies, checkpoints **Exception Hierarchy:** - SafetyError base with context preservation - Permission, Budget, RateLimit, Loop errors - Approval workflow errors (Required, Denied, Timeout) - Rollback, Sandbox, Emergency exceptions **Safety Policy System:** - Autonomy level based policies (FULL_CONTROL, MILESTONE, AUTONOMOUS) - Cost limits, rate limits, permission patterns - HITL approval requirements per action type - Configurable loop detection thresholds **Directory Structure:** - validation/, costs/, limits/, loops/ - Control subsystems - permissions/, rollback/, hitl/ - Access and recovery - content/, sandbox/, emergency/ - Protection systems - audit/, policies/ - Logging and configuration Phase A establishes the architecture. Subsystems to be implemented in Phase B-C. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/services/safety/__init__.py | 170 +++++ backend/app/services/safety/audit/__init__.py | 19 + backend/app/services/safety/audit/logger.py | 585 +++++++++++++++++ backend/app/services/safety/config.py | 300 +++++++++ .../app/services/safety/content/__init__.py | 1 + backend/app/services/safety/costs/__init__.py | 1 + .../app/services/safety/emergency/__init__.py | 1 + backend/app/services/safety/exceptions.py | 277 ++++++++ backend/app/services/safety/guardian.py | 614 ++++++++++++++++++ backend/app/services/safety/hitl/__init__.py | 1 + .../app/services/safety/limits/__init__.py | 1 + backend/app/services/safety/loops/__init__.py | 1 + backend/app/services/safety/models.py | 474 ++++++++++++++ .../services/safety/permissions/__init__.py | 1 + .../app/services/safety/policies/__init__.py | 1 + .../app/services/safety/rollback/__init__.py | 1 + .../app/services/safety/sandbox/__init__.py | 1 + .../services/safety/validation/__init__.py | 1 + 18 files changed, 2450 insertions(+) create mode 100644 backend/app/services/safety/__init__.py create mode 100644 backend/app/services/safety/audit/__init__.py create mode 100644 backend/app/services/safety/audit/logger.py create mode 100644 backend/app/services/safety/config.py create mode 100644 backend/app/services/safety/content/__init__.py create mode 100644 backend/app/services/safety/costs/__init__.py create mode 100644 backend/app/services/safety/emergency/__init__.py create mode 100644 backend/app/services/safety/exceptions.py create mode 100644 backend/app/services/safety/guardian.py create mode 100644 backend/app/services/safety/hitl/__init__.py create mode 100644 backend/app/services/safety/limits/__init__.py create mode 100644 backend/app/services/safety/loops/__init__.py create mode 100644 backend/app/services/safety/models.py create mode 100644 backend/app/services/safety/permissions/__init__.py create mode 100644 backend/app/services/safety/policies/__init__.py create mode 100644 backend/app/services/safety/rollback/__init__.py create mode 100644 backend/app/services/safety/sandbox/__init__.py create mode 100644 backend/app/services/safety/validation/__init__.py diff --git a/backend/app/services/safety/__init__.py b/backend/app/services/safety/__init__.py new file mode 100644 index 0000000..6c50e47 --- /dev/null +++ b/backend/app/services/safety/__init__.py @@ -0,0 +1,170 @@ +""" +Safety and Guardrails Framework + +Comprehensive safety framework for autonomous agent operation. +Provides multi-layered protection including: +- Pre-execution validation +- Cost and budget controls +- Rate limiting +- Loop detection and prevention +- Human-in-the-loop approval +- Rollback and checkpointing +- Content filtering +- Sandboxed execution +- Emergency controls +- Complete audit trail + +Usage: + from app.services.safety import get_safety_guardian, SafetyGuardian + + guardian = await get_safety_guardian() + result = await guardian.validate(action_request) + + if result.allowed: + # Execute action + pass + else: + # Handle denial + print(f"Action denied: {result.reasons}") +""" + +# Exceptions +# Audit +from .audit import ( + AuditLogger, + get_audit_logger, + reset_audit_logger, + shutdown_audit_logger, +) + +# Configuration +from .config import ( + AutonomyConfig, + SafetyConfig, + get_autonomy_config, + get_default_policy, + get_policy_for_autonomy_level, + get_safety_config, + load_policies_from_directory, + load_policy_from_file, + reset_config_cache, +) +from .exceptions import ( + ApprovalDeniedError, + ApprovalRequiredError, + ApprovalTimeoutError, + BudgetExceededError, + CheckpointError, + ContentFilterError, + EmergencyStopError, + LoopDetectedError, + PermissionDeniedError, + PolicyViolationError, + RateLimitExceededError, + RollbackError, + SafetyError, + SandboxError, + SandboxTimeoutError, + ValidationError, +) + +# Guardian +from .guardian import ( + SafetyGuardian, + get_safety_guardian, + reset_safety_guardian, + shutdown_safety_guardian, +) + +# Models +from .models import ( + ActionMetadata, + ActionRequest, + ActionResult, + ActionType, + ApprovalRequest, + ApprovalResponse, + ApprovalStatus, + AuditEvent, + AuditEventType, + AutonomyLevel, + BudgetScope, + BudgetStatus, + Checkpoint, + CheckpointType, + GuardianResult, + PermissionLevel, + RateLimitConfig, + RateLimitStatus, + ResourceType, + RollbackResult, + SafetyDecision, + SafetyPolicy, + ValidationResult, + ValidationRule, +) + +__all__ = [ + "ActionMetadata", + "ActionRequest", + "ActionResult", + # Models + "ActionType", + "ApprovalDeniedError", + "ApprovalRequest", + "ApprovalRequiredError", + "ApprovalResponse", + "ApprovalStatus", + "ApprovalTimeoutError", + "AuditEvent", + "AuditEventType", + # Audit + "AuditLogger", + "AutonomyConfig", + "AutonomyLevel", + "BudgetExceededError", + "BudgetScope", + "BudgetStatus", + "Checkpoint", + "CheckpointError", + "CheckpointType", + "ContentFilterError", + "EmergencyStopError", + "GuardianResult", + "LoopDetectedError", + "PermissionDeniedError", + "PermissionLevel", + "PolicyViolationError", + "RateLimitConfig", + "RateLimitExceededError", + "RateLimitStatus", + "ResourceType", + "RollbackError", + "RollbackResult", + # Configuration + "SafetyConfig", + "SafetyDecision", + # Exceptions + "SafetyError", + # Guardian + "SafetyGuardian", + "SafetyPolicy", + "SandboxError", + "SandboxTimeoutError", + "ValidationError", + "ValidationResult", + "ValidationRule", + "get_audit_logger", + "get_autonomy_config", + "get_default_policy", + "get_policy_for_autonomy_level", + "get_safety_config", + "get_safety_guardian", + "load_policies_from_directory", + "load_policy_from_file", + "reset_audit_logger", + "reset_config_cache", + "reset_safety_guardian", + "shutdown_audit_logger", + "shutdown_safety_guardian", +] diff --git a/backend/app/services/safety/audit/__init__.py b/backend/app/services/safety/audit/__init__.py new file mode 100644 index 0000000..cf456e7 --- /dev/null +++ b/backend/app/services/safety/audit/__init__.py @@ -0,0 +1,19 @@ +""" +Audit System + +Comprehensive audit logging for all safety-related events. +""" + +from .logger import ( + AuditLogger, + get_audit_logger, + reset_audit_logger, + shutdown_audit_logger, +) + +__all__ = [ + "AuditLogger", + "get_audit_logger", + "reset_audit_logger", + "shutdown_audit_logger", +] diff --git a/backend/app/services/safety/audit/logger.py b/backend/app/services/safety/audit/logger.py new file mode 100644 index 0000000..f52babe --- /dev/null +++ b/backend/app/services/safety/audit/logger.py @@ -0,0 +1,585 @@ +""" +Audit Logger + +Comprehensive audit logging for all safety-related events. +Provides tamper detection, structured logging, and compliance support. +""" + +import asyncio +import hashlib +import json +import logging +from collections import deque +from datetime import datetime, timedelta +from typing import Any +from uuid import uuid4 + +from ..config import get_safety_config +from ..models import ( + ActionRequest, + AuditEvent, + AuditEventType, + SafetyDecision, +) + +logger = logging.getLogger(__name__) + + +class AuditLogger: + """ + Audit logger for safety events. + + Features: + - Structured event logging + - In-memory buffer with async flush + - Tamper detection via hash chains + - Query/search capability + - Retention policy enforcement + """ + + def __init__( + self, + max_buffer_size: int = 1000, + flush_interval_seconds: float = 10.0, + enable_hash_chain: bool = True, + ) -> None: + """ + Initialize the audit logger. + + Args: + max_buffer_size: Maximum events to buffer before auto-flush + flush_interval_seconds: Interval for periodic flush + enable_hash_chain: Enable tamper detection via hash chain + """ + self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size) + self._persisted: list[AuditEvent] = [] + self._flush_interval = flush_interval_seconds + self._enable_hash_chain = enable_hash_chain + self._last_hash: str | None = None + self._lock = asyncio.Lock() + self._flush_task: asyncio.Task[None] | None = None + self._running = False + + # Event handlers for real-time processing + self._handlers: list[Any] = [] + + config = get_safety_config() + self._retention_days = config.audit_retention_days + self._include_sensitive = config.audit_include_sensitive + + async def start(self) -> None: + """Start the audit logger background tasks.""" + if self._running: + return + + self._running = True + self._flush_task = asyncio.create_task(self._periodic_flush()) + logger.info("Audit logger started") + + async def stop(self) -> None: + """Stop the audit logger and flush remaining events.""" + self._running = False + + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + + # Final flush + await self.flush() + logger.info("Audit logger stopped") + + async def log( + self, + event_type: AuditEventType, + *, + agent_id: str | None = None, + action_id: str | None = None, + project_id: str | None = None, + session_id: str | None = None, + user_id: str | None = None, + decision: SafetyDecision | None = None, + details: dict[str, Any] | None = None, + correlation_id: str | None = None, + ) -> AuditEvent: + """ + Log an audit event. + + Args: + event_type: Type of audit event + agent_id: Agent ID if applicable + action_id: Action ID if applicable + project_id: Project ID if applicable + session_id: Session ID if applicable + user_id: User ID if applicable + decision: Safety decision if applicable + details: Additional event details + correlation_id: Correlation ID for tracing + + Returns: + The created audit event + """ + # Sanitize sensitive data if needed + sanitized_details = self._sanitize_details(details) if details else {} + + event = AuditEvent( + id=str(uuid4()), + event_type=event_type, + timestamp=datetime.utcnow(), + agent_id=agent_id, + action_id=action_id, + project_id=project_id, + session_id=session_id, + user_id=user_id, + decision=decision, + details=sanitized_details, + correlation_id=correlation_id, + ) + + async with self._lock: + # Add hash chain for tamper detection + if self._enable_hash_chain: + event_hash = self._compute_hash(event) + sanitized_details["_hash"] = event_hash + sanitized_details["_prev_hash"] = self._last_hash + self._last_hash = event_hash + + self._buffer.append(event) + + # Notify handlers + await self._notify_handlers(event) + + # Log to standard logger as well + self._log_to_logger(event) + + return event + + async def log_action_request( + self, + action: ActionRequest, + decision: SafetyDecision, + reasons: list[str] | None = None, + ) -> AuditEvent: + """Log an action request with its validation decision.""" + event_type = ( + AuditEventType.ACTION_DENIED + if decision == SafetyDecision.DENY + else AuditEventType.ACTION_VALIDATED + ) + + return await self.log( + event_type, + agent_id=action.metadata.agent_id, + action_id=action.id, + project_id=action.metadata.project_id, + session_id=action.metadata.session_id, + user_id=action.metadata.user_id, + decision=decision, + details={ + "action_type": action.action_type.value, + "tool_name": action.tool_name, + "resource": action.resource, + "is_destructive": action.is_destructive, + "reasons": reasons or [], + }, + correlation_id=action.metadata.correlation_id, + ) + + async def log_action_executed( + self, + action: ActionRequest, + success: bool, + execution_time_ms: float, + error: str | None = None, + ) -> AuditEvent: + """Log an action execution result.""" + event_type = ( + AuditEventType.ACTION_EXECUTED + if success + else AuditEventType.ACTION_FAILED + ) + + return await self.log( + event_type, + agent_id=action.metadata.agent_id, + action_id=action.id, + project_id=action.metadata.project_id, + session_id=action.metadata.session_id, + decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY, + details={ + "action_type": action.action_type.value, + "tool_name": action.tool_name, + "success": success, + "execution_time_ms": execution_time_ms, + "error": error, + }, + correlation_id=action.metadata.correlation_id, + ) + + async def log_approval_event( + self, + event_type: AuditEventType, + approval_id: str, + action: ActionRequest, + decided_by: str | None = None, + reason: str | None = None, + ) -> AuditEvent: + """Log an approval-related event.""" + return await self.log( + event_type, + agent_id=action.metadata.agent_id, + action_id=action.id, + project_id=action.metadata.project_id, + session_id=action.metadata.session_id, + user_id=decided_by, + details={ + "approval_id": approval_id, + "action_type": action.action_type.value, + "tool_name": action.tool_name, + "decided_by": decided_by, + "reason": reason, + }, + correlation_id=action.metadata.correlation_id, + ) + + async def log_budget_event( + self, + event_type: AuditEventType, + agent_id: str, + scope: str, + current_usage: float, + limit: float, + unit: str = "tokens", + ) -> AuditEvent: + """Log a budget-related event.""" + return await self.log( + event_type, + agent_id=agent_id, + details={ + "scope": scope, + "current_usage": current_usage, + "limit": limit, + "unit": unit, + "usage_percent": (current_usage / limit * 100) if limit > 0 else 0, + }, + ) + + async def log_emergency_stop( + self, + stop_type: str, + triggered_by: str, + reason: str, + affected_agents: list[str] | None = None, + ) -> AuditEvent: + """Log an emergency stop event.""" + return await self.log( + AuditEventType.EMERGENCY_STOP, + user_id=triggered_by, + details={ + "stop_type": stop_type, + "triggered_by": triggered_by, + "reason": reason, + "affected_agents": affected_agents or [], + }, + ) + + async def flush(self) -> int: + """ + Flush buffered events to persistent storage. + + Returns: + Number of events flushed + """ + async with self._lock: + if not self._buffer: + return 0 + + events = list(self._buffer) + self._buffer.clear() + + # Persist events (in production, this would go to database/storage) + self._persisted.extend(events) + + # Enforce retention + self._enforce_retention() + + logger.debug("Flushed %d audit events", len(events)) + return len(events) + + async def query( + self, + *, + event_types: list[AuditEventType] | None = None, + agent_id: str | None = None, + action_id: str | None = None, + project_id: str | None = None, + session_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + correlation_id: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[AuditEvent]: + """ + Query audit events with filters. + + Args: + event_types: Filter by event types + agent_id: Filter by agent ID + action_id: Filter by action ID + project_id: Filter by project ID + session_id: Filter by session ID + user_id: Filter by user ID + start_time: Filter events after this time + end_time: Filter events before this time + correlation_id: Filter by correlation ID + limit: Maximum results to return + offset: Result offset for pagination + + Returns: + List of matching audit events + """ + # Combine buffer and persisted for query + all_events = list(self._persisted) + list(self._buffer) + + results = [] + for event in all_events: + if event_types and event.event_type not in event_types: + continue + if agent_id and event.agent_id != agent_id: + continue + if action_id and event.action_id != action_id: + continue + if project_id and event.project_id != project_id: + continue + if session_id and event.session_id != session_id: + continue + if user_id and event.user_id != user_id: + continue + if start_time and event.timestamp < start_time: + continue + if end_time and event.timestamp > end_time: + continue + if correlation_id and event.correlation_id != correlation_id: + continue + + results.append(event) + + # Sort by timestamp descending + results.sort(key=lambda e: e.timestamp, reverse=True) + + # Apply pagination + return results[offset : offset + limit] + + async def get_action_history( + self, + agent_id: str, + limit: int = 100, + ) -> list[AuditEvent]: + """Get action history for an agent.""" + return await self.query( + agent_id=agent_id, + event_types=[ + AuditEventType.ACTION_REQUESTED, + AuditEventType.ACTION_VALIDATED, + AuditEventType.ACTION_DENIED, + AuditEventType.ACTION_EXECUTED, + AuditEventType.ACTION_FAILED, + ], + limit=limit, + ) + + async def verify_integrity(self) -> tuple[bool, list[str]]: + """ + Verify audit log integrity using hash chain. + + Returns: + Tuple of (is_valid, list of issues found) + """ + if not self._enable_hash_chain: + return True, [] + + issues: list[str] = [] + all_events = list(self._persisted) + list(self._buffer) + + prev_hash: str | None = None + for event in sorted(all_events, key=lambda e: e.timestamp): + stored_prev = event.details.get("_prev_hash") + stored_hash = event.details.get("_hash") + + if stored_prev != prev_hash: + issues.append( + f"Hash chain broken at event {event.id}: " + f"expected prev_hash={prev_hash}, got {stored_prev}" + ) + + if stored_hash: + computed = self._compute_hash(event) + if computed != stored_hash: + issues.append( + f"Hash mismatch at event {event.id}: " + f"expected {computed}, got {stored_hash}" + ) + + prev_hash = stored_hash + + return len(issues) == 0, issues + + def add_handler(self, handler: Any) -> None: + """Add a real-time event handler.""" + self._handlers.append(handler) + + def remove_handler(self, handler: Any) -> None: + """Remove an event handler.""" + if handler in self._handlers: + self._handlers.remove(handler) + + def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]: + """Sanitize sensitive data from details.""" + if self._include_sensitive: + return details + + sanitized: dict[str, Any] = {} + sensitive_keys = { + "password", + "secret", + "token", + "api_key", + "apikey", + "auth", + "credential", + } + + for key, value in details.items(): + lower_key = key.lower() + if any(s in lower_key for s in sensitive_keys): + sanitized[key] = "[REDACTED]" + elif isinstance(value, dict): + sanitized[key] = self._sanitize_details(value) + else: + sanitized[key] = value + + return sanitized + + def _compute_hash(self, event: AuditEvent) -> str: + """Compute hash for an event (excluding hash fields).""" + data = { + "id": event.id, + "event_type": event.event_type.value, + "timestamp": event.timestamp.isoformat(), + "agent_id": event.agent_id, + "action_id": event.action_id, + "project_id": event.project_id, + "session_id": event.session_id, + "user_id": event.user_id, + "decision": event.decision.value if event.decision else None, + "details": { + k: v + for k, v in event.details.items() + if not k.startswith("_") + }, + "correlation_id": event.correlation_id, + } + + if self._last_hash: + data["_prev_hash"] = self._last_hash + + serialized = json.dumps(data, sort_keys=True, default=str) + return hashlib.sha256(serialized.encode()).hexdigest() + + def _log_to_logger(self, event: AuditEvent) -> None: + """Log event to standard Python logger.""" + log_data = { + "audit_event": event.event_type.value, + "event_id": event.id, + "agent_id": event.agent_id, + "action_id": event.action_id, + "decision": event.decision.value if event.decision else None, + } + + # Use appropriate log level based on event type + if event.event_type in { + AuditEventType.ACTION_DENIED, + AuditEventType.POLICY_VIOLATION, + AuditEventType.EMERGENCY_STOP, + }: + logger.warning("Audit: %s", log_data) + elif event.event_type in { + AuditEventType.ACTION_FAILED, + AuditEventType.ROLLBACK_FAILED, + }: + logger.error("Audit: %s", log_data) + else: + logger.info("Audit: %s", log_data) + + def _enforce_retention(self) -> None: + """Enforce retention policy on persisted events.""" + if not self._retention_days: + return + + cutoff = datetime.utcnow() - timedelta(days=self._retention_days) + before_count = len(self._persisted) + + self._persisted = [e for e in self._persisted if e.timestamp >= cutoff] + + removed = before_count - len(self._persisted) + if removed > 0: + logger.info("Removed %d expired audit events", removed) + + async def _periodic_flush(self) -> None: + """Background task for periodic flushing.""" + while self._running: + try: + await asyncio.sleep(self._flush_interval) + await self.flush() + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Error in periodic audit flush: %s", e) + + async def _notify_handlers(self, event: AuditEvent) -> None: + """Notify all registered handlers of a new event.""" + for handler in self._handlers: + try: + if asyncio.iscoroutinefunction(handler): + await handler(event) + else: + handler(event) + except Exception as e: + logger.error("Error in audit event handler: %s", e) + + +# Singleton instance +_audit_logger: AuditLogger | None = None +_audit_lock = asyncio.Lock() + + +async def get_audit_logger() -> AuditLogger: + """Get the global audit logger instance.""" + global _audit_logger + + async with _audit_lock: + if _audit_logger is None: + _audit_logger = AuditLogger() + await _audit_logger.start() + + return _audit_logger + + +async def shutdown_audit_logger() -> None: + """Shutdown the global audit logger.""" + global _audit_logger + + async with _audit_lock: + if _audit_logger is not None: + await _audit_logger.stop() + _audit_logger = None + + +def reset_audit_logger() -> None: + """Reset the audit logger (for testing).""" + global _audit_logger + _audit_logger = None diff --git a/backend/app/services/safety/config.py b/backend/app/services/safety/config.py new file mode 100644 index 0000000..2488895 --- /dev/null +++ b/backend/app/services/safety/config.py @@ -0,0 +1,300 @@ +""" +Safety Framework Configuration + +Pydantic settings for the safety and guardrails framework. +""" + +import logging +import os +from functools import lru_cache +from pathlib import Path +from typing import Any + +import yaml +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + +from .models import AutonomyLevel, SafetyPolicy + +logger = logging.getLogger(__name__) + + +class SafetyConfig(BaseSettings): + """Configuration for the safety framework.""" + + model_config = SettingsConfigDict( + env_prefix="SAFETY_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + # General settings + enabled: bool = Field(True, description="Enable safety framework") + strict_mode: bool = Field( + True, description="Strict mode (fail closed on errors)" + ) + log_level: str = Field("INFO", description="Logging level") + + # Default autonomy level + default_autonomy_level: AutonomyLevel = Field( + AutonomyLevel.MILESTONE, + description="Default autonomy level for new agents", + ) + + # Default budget limits + default_session_token_budget: int = Field( + 100_000, description="Default tokens per session" + ) + default_daily_token_budget: int = Field( + 1_000_000, description="Default tokens per day" + ) + default_session_cost_limit: float = Field( + 10.0, description="Default USD per session" + ) + default_daily_cost_limit: float = Field(100.0, description="Default USD per day") + + # Default rate limits + default_actions_per_minute: int = Field(60, description="Default actions per min") + default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min") + default_file_ops_per_minute: int = Field(100, description="Default file ops/min") + + # Loop detection + loop_detection_enabled: bool = Field(True, description="Enable loop detection") + max_repeated_actions: int = Field(5, description="Max exact repetitions") + max_similar_actions: int = Field(10, description="Max similar actions") + loop_history_size: int = Field(100, description="Action history size for loops") + + # HITL settings + hitl_enabled: bool = Field(True, description="Enable human-in-the-loop") + hitl_default_timeout: int = Field(300, description="Default approval timeout (s)") + hitl_notification_channels: list[str] = Field( + default_factory=list, description="Notification channels" + ) + + # Rollback settings + rollback_enabled: bool = Field(True, description="Enable rollback capability") + checkpoint_retention_hours: int = Field(24, description="Checkpoint retention") + auto_checkpoint_destructive: bool = Field( + True, description="Auto-checkpoint destructive actions" + ) + + # Sandbox settings + sandbox_enabled: bool = Field(False, description="Enable sandbox execution") + sandbox_timeout: int = Field(300, description="Sandbox timeout (s)") + sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)") + sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit") + sandbox_network_enabled: bool = Field(False, description="Allow sandbox network") + + # Audit settings + audit_enabled: bool = Field(True, description="Enable audit logging") + audit_retention_days: int = Field(90, description="Audit log retention (days)") + audit_include_sensitive: bool = Field( + False, description="Include sensitive data in audit" + ) + + # Content filtering + content_filter_enabled: bool = Field(True, description="Enable content filtering") + filter_pii: bool = Field(True, description="Filter PII") + filter_secrets: bool = Field(True, description="Filter secrets") + + # Emergency controls + emergency_stop_enabled: bool = Field(True, description="Enable emergency stop") + emergency_webhook_url: str | None = Field(None, description="Emergency webhook") + + # Policy file path + policy_file: str | None = Field(None, description="Path to policy YAML file") + + # Validation cache + validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)") + validation_cache_size: int = Field(1000, description="Validation cache size") + + +class AutonomyConfig(BaseSettings): + """Configuration for autonomy levels.""" + + model_config = SettingsConfigDict( + env_prefix="AUTONOMY_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + # FULL_CONTROL settings + full_control_cost_limit: float = Field(1.0, description="USD limit per session") + full_control_require_all_approval: bool = Field( + True, description="Require approval for all" + ) + full_control_block_destructive: bool = Field( + True, description="Block destructive actions" + ) + + # MILESTONE settings + milestone_cost_limit: float = Field(10.0, description="USD limit per session") + milestone_require_critical_approval: bool = Field( + True, description="Require approval for critical" + ) + milestone_auto_checkpoint: bool = Field( + True, description="Auto-checkpoint destructive" + ) + + # AUTONOMOUS settings + autonomous_cost_limit: float = Field(100.0, description="USD limit per session") + autonomous_auto_approve_normal: bool = Field( + True, description="Auto-approve normal actions" + ) + autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all") + + +def _expand_env_vars(value: Any) -> Any: + """Recursively expand environment variables in values.""" + if isinstance(value, str): + return os.path.expandvars(value) + elif isinstance(value, dict): + return {k: _expand_env_vars(v) for k, v in value.items()} + elif isinstance(value, list): + return [_expand_env_vars(v) for v in value] + return value + + +def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None: + """Load a safety policy from a YAML file.""" + path = Path(file_path) + if not path.exists(): + logger.warning("Policy file not found: %s", path) + return None + + try: + with open(path) as f: + data = yaml.safe_load(f) + + if data is None: + logger.warning("Empty policy file: %s", path) + return None + + # Expand environment variables + data = _expand_env_vars(data) + + return SafetyPolicy(**data) + + except Exception as e: + logger.error("Failed to load policy file %s: %s", path, e) + return None + + +def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]: + """Load all safety policies from a directory.""" + policies: dict[str, SafetyPolicy] = {} + path = Path(directory) + + if not path.exists() or not path.is_dir(): + logger.warning("Policy directory not found: %s", path) + return policies + + for file_path in path.glob("*.yaml"): + policy = load_policy_from_file(file_path) + if policy: + policies[policy.name] = policy + logger.info("Loaded policy: %s from %s", policy.name, file_path.name) + + return policies + + +@lru_cache(maxsize=1) +def get_safety_config() -> SafetyConfig: + """Get the safety configuration (cached singleton).""" + return SafetyConfig() + + +@lru_cache(maxsize=1) +def get_autonomy_config() -> AutonomyConfig: + """Get the autonomy configuration (cached singleton).""" + return AutonomyConfig() + + +def get_default_policy() -> SafetyPolicy: + """Get the default safety policy.""" + config = get_safety_config() + + return SafetyPolicy( + name="default", + description="Default safety policy", + max_tokens_per_session=config.default_session_token_budget, + max_tokens_per_day=config.default_daily_token_budget, + max_cost_per_session_usd=config.default_session_cost_limit, + max_cost_per_day_usd=config.default_daily_cost_limit, + max_actions_per_minute=config.default_actions_per_minute, + max_llm_calls_per_minute=config.default_llm_calls_per_minute, + max_file_operations_per_minute=config.default_file_ops_per_minute, + max_repeated_actions=config.max_repeated_actions, + max_similar_actions=config.max_similar_actions, + require_sandbox=config.sandbox_enabled, + sandbox_timeout_seconds=config.sandbox_timeout, + sandbox_memory_mb=config.sandbox_memory_mb, + ) + + +def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy: + """Get the safety policy for a given autonomy level.""" + autonomy = get_autonomy_config() + + base_policy = get_default_policy() + + if level == AutonomyLevel.FULL_CONTROL: + return SafetyPolicy( + name="full_control", + description="Full control mode - all actions require approval", + max_cost_per_session_usd=autonomy.full_control_cost_limit, + max_cost_per_day_usd=autonomy.full_control_cost_limit * 10, + require_approval_for=["*"], # All actions + max_tokens_per_session=base_policy.max_tokens_per_session // 10, + max_tokens_per_day=base_policy.max_tokens_per_day // 10, + max_actions_per_minute=base_policy.max_actions_per_minute // 2, + max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2, + max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2, + denied_tools=["delete_*", "destroy_*", "drop_*"], + ) + + elif level == AutonomyLevel.MILESTONE: + return SafetyPolicy( + name="milestone", + description="Milestone mode - approval at milestones only", + max_cost_per_session_usd=autonomy.milestone_cost_limit, + max_cost_per_day_usd=autonomy.milestone_cost_limit * 10, + require_approval_for=[ + "delete_file", + "push_to_remote", + "deploy_*", + "modify_critical_*", + "create_pull_request", + ], + max_tokens_per_session=base_policy.max_tokens_per_session, + max_tokens_per_day=base_policy.max_tokens_per_day, + max_actions_per_minute=base_policy.max_actions_per_minute, + max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute, + max_file_operations_per_minute=base_policy.max_file_operations_per_minute, + ) + + else: # AUTONOMOUS + return SafetyPolicy( + name="autonomous", + description="Autonomous mode - minimal intervention", + max_cost_per_session_usd=autonomy.autonomous_cost_limit, + max_cost_per_day_usd=autonomy.autonomous_cost_limit * 10, + require_approval_for=[ + "deploy_to_production", + "delete_repository", + "modify_production_config", + ], + max_tokens_per_session=base_policy.max_tokens_per_session * 5, + max_tokens_per_day=base_policy.max_tokens_per_day * 5, + max_actions_per_minute=base_policy.max_actions_per_minute * 2, + max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2, + max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2, + ) + + +def reset_config_cache() -> None: + """Reset configuration caches (for testing).""" + get_safety_config.cache_clear() + get_autonomy_config.cache_clear() diff --git a/backend/app/services/safety/content/__init__.py b/backend/app/services/safety/content/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/content/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/costs/__init__.py b/backend/app/services/safety/costs/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/costs/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/emergency/__init__.py b/backend/app/services/safety/emergency/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/emergency/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/exceptions.py b/backend/app/services/safety/exceptions.py new file mode 100644 index 0000000..5ebb70e --- /dev/null +++ b/backend/app/services/safety/exceptions.py @@ -0,0 +1,277 @@ +""" +Safety Framework Exceptions + +Custom exception classes for the safety and guardrails framework. +""" + +from typing import Any + + +class SafetyError(Exception): + """Base exception for all safety-related errors.""" + + def __init__( + self, + message: str, + *, + action_id: str | None = None, + agent_id: str | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message) + self.message = message + self.action_id = action_id + self.agent_id = agent_id + self.details = details or {} + + +class PermissionDeniedError(SafetyError): + """Raised when an action is not permitted.""" + + def __init__( + self, + message: str = "Permission denied", + *, + action_type: str | None = None, + resource: str | None = None, + required_permission: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.action_type = action_type + self.resource = resource + self.required_permission = required_permission + + +class BudgetExceededError(SafetyError): + """Raised when cost budget is exceeded.""" + + def __init__( + self, + message: str = "Budget exceeded", + *, + budget_type: str = "session", + current_usage: float = 0.0, + budget_limit: float = 0.0, + unit: str = "tokens", + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.budget_type = budget_type + self.current_usage = current_usage + self.budget_limit = budget_limit + self.unit = unit + + +class RateLimitExceededError(SafetyError): + """Raised when rate limit is exceeded.""" + + def __init__( + self, + message: str = "Rate limit exceeded", + *, + limit_type: str = "actions", + limit_value: int = 0, + window_seconds: int = 60, + retry_after_seconds: float = 0.0, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.limit_type = limit_type + self.limit_value = limit_value + self.window_seconds = window_seconds + self.retry_after_seconds = retry_after_seconds + + +class LoopDetectedError(SafetyError): + """Raised when an action loop is detected.""" + + def __init__( + self, + message: str = "Loop detected", + *, + loop_type: str = "exact", + repetition_count: int = 0, + action_pattern: list[str] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.loop_type = loop_type + self.repetition_count = repetition_count + self.action_pattern = action_pattern or [] + + +class ApprovalRequiredError(SafetyError): + """Raised when human approval is required.""" + + def __init__( + self, + message: str = "Human approval required", + *, + approval_id: str | None = None, + reason: str | None = None, + timeout_seconds: int = 300, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.approval_id = approval_id + self.reason = reason + self.timeout_seconds = timeout_seconds + + +class ApprovalDeniedError(SafetyError): + """Raised when human explicitly denies an action.""" + + def __init__( + self, + message: str = "Approval denied by human", + *, + approval_id: str | None = None, + denied_by: str | None = None, + denial_reason: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.approval_id = approval_id + self.denied_by = denied_by + self.denial_reason = denial_reason + + +class ApprovalTimeoutError(SafetyError): + """Raised when approval request times out.""" + + def __init__( + self, + message: str = "Approval request timed out", + *, + approval_id: str | None = None, + timeout_seconds: int = 300, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.approval_id = approval_id + self.timeout_seconds = timeout_seconds + + +class RollbackError(SafetyError): + """Raised when rollback fails.""" + + def __init__( + self, + message: str = "Rollback failed", + *, + checkpoint_id: str | None = None, + failed_actions: list[str] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.checkpoint_id = checkpoint_id + self.failed_actions = failed_actions or [] + + +class CheckpointError(SafetyError): + """Raised when checkpoint creation fails.""" + + def __init__( + self, + message: str = "Checkpoint creation failed", + *, + checkpoint_type: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.checkpoint_type = checkpoint_type + + +class ValidationError(SafetyError): + """Raised when action validation fails.""" + + def __init__( + self, + message: str = "Validation failed", + *, + validation_rules: list[str] | None = None, + failed_rules: list[str] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.validation_rules = validation_rules or [] + self.failed_rules = failed_rules or [] + + +class ContentFilterError(SafetyError): + """Raised when content filtering detects prohibited content.""" + + def __init__( + self, + message: str = "Prohibited content detected", + *, + filter_type: str | None = None, + detected_patterns: list[str] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.filter_type = filter_type + self.detected_patterns = detected_patterns or [] + + +class SandboxError(SafetyError): + """Raised when sandbox execution fails.""" + + def __init__( + self, + message: str = "Sandbox execution failed", + *, + exit_code: int | None = None, + stderr: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.exit_code = exit_code + self.stderr = stderr + + +class SandboxTimeoutError(SandboxError): + """Raised when sandbox execution times out.""" + + def __init__( + self, + message: str = "Sandbox execution timed out", + *, + timeout_seconds: int = 300, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.timeout_seconds = timeout_seconds + + +class EmergencyStopError(SafetyError): + """Raised when emergency stop is triggered.""" + + def __init__( + self, + message: str = "Emergency stop triggered", + *, + stop_type: str = "kill", + triggered_by: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.stop_type = stop_type + self.triggered_by = triggered_by + + +class PolicyViolationError(SafetyError): + """Raised when an action violates a safety policy.""" + + def __init__( + self, + message: str = "Policy violation", + *, + policy_name: str | None = None, + violated_rules: list[str] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(message, **kwargs) + self.policy_name = policy_name + self.violated_rules = violated_rules or [] diff --git a/backend/app/services/safety/guardian.py b/backend/app/services/safety/guardian.py new file mode 100644 index 0000000..a830a45 --- /dev/null +++ b/backend/app/services/safety/guardian.py @@ -0,0 +1,614 @@ +""" +Safety Guardian + +Main facade for the safety framework. Orchestrates all safety checks +before, during, and after action execution. +""" + +import asyncio +import logging +from typing import Any + +from .audit import AuditLogger, get_audit_logger +from .config import ( + SafetyConfig, + get_policy_for_autonomy_level, + get_safety_config, +) +from .exceptions import ( + SafetyError, +) +from .models import ( + ActionRequest, + ActionResult, + AuditEventType, + GuardianResult, + SafetyDecision, + SafetyPolicy, +) + +logger = logging.getLogger(__name__) + + +class SafetyGuardian: + """ + Central orchestrator for all safety checks. + + The SafetyGuardian is the main entry point for validating agent actions. + It coordinates multiple safety subsystems: + - Permission checking + - Cost/budget control + - Rate limiting + - Loop detection + - Human-in-the-loop approval + - Rollback/checkpoint management + - Content filtering + - Sandbox execution + + Usage: + guardian = SafetyGuardian() + await guardian.initialize() + + # Before executing an action + result = await guardian.validate(action_request) + if not result.allowed: + # Handle denial + + # After action execution + await guardian.record_execution(action_request, action_result) + """ + + def __init__( + self, + config: SafetyConfig | None = None, + audit_logger: AuditLogger | None = None, + ) -> None: + """ + Initialize the SafetyGuardian. + + Args: + config: Optional safety configuration. If None, loads from environment. + audit_logger: Optional audit logger. If None, uses global instance. + """ + self._config = config or get_safety_config() + self._audit_logger = audit_logger + self._initialized = False + self._lock = asyncio.Lock() + + # Subsystem references (will be initialized lazily) + self._permission_manager: Any = None + self._cost_controller: Any = None + self._rate_limiter: Any = None + self._loop_detector: Any = None + self._hitl_manager: Any = None + self._rollback_manager: Any = None + self._content_filter: Any = None + self._sandbox_executor: Any = None + self._emergency_controls: Any = None + + # Policy cache + self._policies: dict[str, SafetyPolicy] = {} + self._default_policy: SafetyPolicy | None = None + + @property + def is_initialized(self) -> bool: + """Check if the guardian is initialized.""" + return self._initialized + + async def initialize(self) -> None: + """Initialize the SafetyGuardian and all subsystems.""" + async with self._lock: + if self._initialized: + logger.warning("SafetyGuardian already initialized") + return + + logger.info("Initializing SafetyGuardian") + + # Get audit logger + if self._audit_logger is None: + self._audit_logger = await get_audit_logger() + + # Initialize subsystems lazily as they're implemented + # For now, we'll import and initialize them when available + + self._initialized = True + logger.info("SafetyGuardian initialized") + + async def shutdown(self) -> None: + """Shutdown the SafetyGuardian and all subsystems.""" + async with self._lock: + if not self._initialized: + return + + logger.info("Shutting down SafetyGuardian") + + # Shutdown subsystems + # (Will be implemented as subsystems are added) + + self._initialized = False + logger.info("SafetyGuardian shutdown complete") + + async def validate( + self, + action: ActionRequest, + policy: SafetyPolicy | None = None, + ) -> GuardianResult: + """ + Validate an action before execution. + + Runs all safety checks in order: + 1. Permission check + 2. Cost/budget check + 3. Rate limit check + 4. Loop detection + 5. HITL check (if required) + 6. Checkpoint creation (if destructive) + + Args: + action: The action to validate + policy: Optional policy override. If None, uses autonomy-level policy. + + Returns: + GuardianResult with decision and details + """ + if not self._initialized: + await self.initialize() + + if not self._config.enabled: + # Safety disabled - allow everything (NOT RECOMMENDED) + logger.warning("Safety framework disabled - allowing action %s", action.id) + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Safety framework disabled"], + ) + + # Get policy for this action + effective_policy = policy or self._get_policy(action) + + reasons: list[str] = [] + audit_events = [] + + try: + # Log action request + if self._audit_logger: + event = await self._audit_logger.log( + AuditEventType.ACTION_REQUESTED, + agent_id=action.metadata.agent_id, + action_id=action.id, + project_id=action.metadata.project_id, + session_id=action.metadata.session_id, + details={ + "action_type": action.action_type.value, + "tool_name": action.tool_name, + "resource": action.resource, + }, + correlation_id=action.metadata.correlation_id, + ) + audit_events.append(event) + + # 1. Permission check + permission_result = await self._check_permissions(action, effective_policy) + if permission_result.decision == SafetyDecision.DENY: + return await self._create_denial_result( + action, permission_result.reasons, audit_events + ) + + # 2. Cost/budget check + budget_result = await self._check_budget(action, effective_policy) + if budget_result.decision == SafetyDecision.DENY: + return await self._create_denial_result( + action, budget_result.reasons, audit_events + ) + + # 3. Rate limit check + rate_result = await self._check_rate_limit(action, effective_policy) + if rate_result.decision == SafetyDecision.DENY: + return await self._create_denial_result( + action, + rate_result.reasons, + audit_events, + retry_after=rate_result.retry_after_seconds, + ) + if rate_result.decision == SafetyDecision.DELAY: + # Return delay decision + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DELAY, + reasons=rate_result.reasons, + retry_after_seconds=rate_result.retry_after_seconds, + audit_events=audit_events, + ) + + # 4. Loop detection + loop_result = await self._check_loops(action, effective_policy) + if loop_result.decision == SafetyDecision.DENY: + return await self._create_denial_result( + action, loop_result.reasons, audit_events + ) + + # 5. HITL check + hitl_result = await self._check_hitl(action, effective_policy) + if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL: + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.REQUIRE_APPROVAL, + reasons=hitl_result.reasons, + approval_id=hitl_result.approval_id, + audit_events=audit_events, + ) + + # 6. Create checkpoint if destructive + checkpoint_id = None + if action.is_destructive and self._config.auto_checkpoint_destructive: + checkpoint_id = await self._create_checkpoint(action) + + # All checks passed + reasons.append("All safety checks passed") + + if self._audit_logger: + event = await self._audit_logger.log_action_request( + action, SafetyDecision.ALLOW, reasons + ) + audit_events.append(event) + + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=reasons, + checkpoint_id=checkpoint_id, + audit_events=audit_events, + ) + + except SafetyError as e: + # Known safety error + return await self._create_denial_result( + action, [str(e)], audit_events + ) + except Exception as e: + # Unknown error - fail closed in strict mode + logger.error("Unexpected error in safety validation: %s", e) + if self._config.strict_mode: + return await self._create_denial_result( + action, + [f"Safety validation error: {e}"], + audit_events, + ) + else: + # Non-strict mode - allow with warning + logger.warning("Non-strict mode: allowing action despite error") + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Allowed despite validation error (non-strict mode)"], + audit_events=audit_events, + ) + + async def record_execution( + self, + action: ActionRequest, + result: ActionResult, + ) -> None: + """ + Record action execution result for auditing and tracking. + + Args: + action: The executed action + result: The execution result + """ + if self._audit_logger: + await self._audit_logger.log_action_executed( + action, + success=result.success, + execution_time_ms=result.execution_time_ms, + error=result.error, + ) + + # Update cost tracking + if self._cost_controller: + # Track actual cost + pass + + # Update loop detection history + if self._loop_detector: + # Add to action history + pass + + async def rollback(self, checkpoint_id: str) -> bool: + """ + Rollback to a checkpoint. + + Args: + checkpoint_id: ID of the checkpoint to rollback to + + Returns: + True if rollback succeeded + """ + if self._rollback_manager is None: + logger.warning("Rollback manager not available") + return False + + # Delegate to rollback manager + return await self._rollback_manager.rollback(checkpoint_id) + + async def emergency_stop( + self, + stop_type: str = "kill", + reason: str = "Manual emergency stop", + triggered_by: str = "system", + ) -> None: + """ + Trigger emergency stop. + + Args: + stop_type: Type of stop (kill, pause, lockdown) + reason: Reason for the stop + triggered_by: Who triggered the stop + """ + logger.critical( + "Emergency stop triggered: type=%s, reason=%s, by=%s", + stop_type, + reason, + triggered_by, + ) + + if self._audit_logger: + await self._audit_logger.log_emergency_stop( + stop_type=stop_type, + triggered_by=triggered_by, + reason=reason, + ) + + if self._emergency_controls: + await self._emergency_controls.execute_stop(stop_type) + + def _get_policy(self, action: ActionRequest) -> SafetyPolicy: + """Get the effective policy for an action.""" + # Check cached policies + autonomy_level = action.metadata.autonomy_level + + if autonomy_level.value not in self._policies: + self._policies[autonomy_level.value] = get_policy_for_autonomy_level( + autonomy_level + ) + + return self._policies[autonomy_level.value] + + async def _check_permissions( + self, + action: ActionRequest, + policy: SafetyPolicy, + ) -> GuardianResult: + """Check if action is permitted.""" + reasons: list[str] = [] + + # Check denied tools + if action.tool_name: + for pattern in policy.denied_tools: + if self._matches_pattern(action.tool_name, pattern): + reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'") + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=reasons, + ) + + # Check allowed tools (if not "*") + if action.tool_name and "*" not in policy.allowed_tools: + allowed = False + for pattern in policy.allowed_tools: + if self._matches_pattern(action.tool_name, pattern): + allowed = True + break + if not allowed: + reasons.append(f"Tool '{action.tool_name}' not in allowed list") + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=reasons, + ) + + # Check file patterns + if action.resource: + for pattern in policy.denied_file_patterns: + if self._matches_pattern(action.resource, pattern): + reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'") + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=reasons, + ) + + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Permission check passed"], + ) + + async def _check_budget( + self, + action: ActionRequest, + policy: SafetyPolicy, + ) -> GuardianResult: + """Check if action is within budget.""" + # TODO: Implement with CostController + # For now, return allow + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Budget check passed (not fully implemented)"], + ) + + async def _check_rate_limit( + self, + action: ActionRequest, + policy: SafetyPolicy, + ) -> GuardianResult: + """Check if action is within rate limits.""" + # TODO: Implement with RateLimiter + # For now, return allow + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Rate limit check passed (not fully implemented)"], + ) + + async def _check_loops( + self, + action: ActionRequest, + policy: SafetyPolicy, + ) -> GuardianResult: + """Check for action loops.""" + # TODO: Implement with LoopDetector + # For now, return allow + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Loop check passed (not fully implemented)"], + ) + + async def _check_hitl( + self, + action: ActionRequest, + policy: SafetyPolicy, + ) -> GuardianResult: + """Check if human approval is required.""" + if not self._config.hitl_enabled: + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["HITL disabled"], + ) + + # Check if action requires approval + requires_approval = False + for pattern in policy.require_approval_for: + if pattern == "*": + requires_approval = True + break + if action.tool_name and self._matches_pattern(action.tool_name, pattern): + requires_approval = True + break + if action.action_type.value and self._matches_pattern( + action.action_type.value, pattern + ): + requires_approval = True + break + + if requires_approval: + # TODO: Create approval request with HITLManager + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.REQUIRE_APPROVAL, + reasons=["Action requires human approval"], + approval_id=None, # Will be set by HITLManager + ) + + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["No approval required"], + ) + + async def _create_checkpoint(self, action: ActionRequest) -> str | None: + """Create a checkpoint before destructive action.""" + if self._rollback_manager is None: + logger.warning("Rollback manager not available - skipping checkpoint") + return None + + # TODO: Implement with RollbackManager + return None + + async def _create_denial_result( + self, + action: ActionRequest, + reasons: list[str], + audit_events: list[Any], + retry_after: float | None = None, + ) -> GuardianResult: + """Create a denial result with audit logging.""" + if self._audit_logger: + event = await self._audit_logger.log_action_request( + action, SafetyDecision.DENY, reasons + ) + audit_events.append(event) + + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=reasons, + retry_after_seconds=retry_after, + audit_events=audit_events, + ) + + def _matches_pattern(self, value: str, pattern: str) -> bool: + """Check if value matches a pattern (supports * wildcard).""" + if pattern == "*": + return True + + if "*" not in pattern: + return value == pattern + + # Simple wildcard matching + if pattern.startswith("*") and pattern.endswith("*"): + return pattern[1:-1] in value + elif pattern.startswith("*"): + return value.endswith(pattern[1:]) + elif pattern.endswith("*"): + return value.startswith(pattern[:-1]) + else: + # Pattern like "foo*bar" + parts = pattern.split("*") + if len(parts) == 2: + return value.startswith(parts[0]) and value.endswith(parts[1]) + + return False + + +# Singleton instance +_guardian_instance: SafetyGuardian | None = None +_guardian_lock = asyncio.Lock() + + +async def get_safety_guardian() -> SafetyGuardian: + """Get the global SafetyGuardian instance.""" + global _guardian_instance + + async with _guardian_lock: + if _guardian_instance is None: + _guardian_instance = SafetyGuardian() + await _guardian_instance.initialize() + + return _guardian_instance + + +async def shutdown_safety_guardian() -> None: + """Shutdown the global SafetyGuardian.""" + global _guardian_instance + + async with _guardian_lock: + if _guardian_instance is not None: + await _guardian_instance.shutdown() + _guardian_instance = None + + +def reset_safety_guardian() -> None: + """Reset the SafetyGuardian (for testing).""" + global _guardian_instance + _guardian_instance = None diff --git a/backend/app/services/safety/hitl/__init__.py b/backend/app/services/safety/hitl/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/hitl/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/limits/__init__.py b/backend/app/services/safety/limits/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/limits/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/loops/__init__.py b/backend/app/services/safety/loops/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/loops/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/models.py b/backend/app/services/safety/models.py new file mode 100644 index 0000000..d1b1c9c --- /dev/null +++ b/backend/app/services/safety/models.py @@ -0,0 +1,474 @@ +""" +Safety Framework Models + +Core Pydantic models for actions, events, policies, and safety decisions. +""" + +from datetime import datetime +from enum import Enum +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, Field + +# ============================================================================ +# Enums +# ============================================================================ + + +class ActionType(str, Enum): + """Types of actions that can be performed.""" + + TOOL_CALL = "tool_call" + FILE_READ = "file_read" + FILE_WRITE = "file_write" + FILE_DELETE = "file_delete" + API_CALL = "api_call" + DATABASE_QUERY = "database_query" + DATABASE_MUTATE = "database_mutate" + GIT_OPERATION = "git_operation" + SHELL_COMMAND = "shell_command" + LLM_CALL = "llm_call" + NETWORK_REQUEST = "network_request" + CUSTOM = "custom" + + +class ResourceType(str, Enum): + """Types of resources that can be accessed.""" + + FILE = "file" + DATABASE = "database" + API = "api" + NETWORK = "network" + GIT = "git" + SHELL = "shell" + LLM = "llm" + MEMORY = "memory" + CUSTOM = "custom" + + +class PermissionLevel(str, Enum): + """Permission levels for resource access.""" + + NONE = "none" + READ = "read" + WRITE = "write" + EXECUTE = "execute" + DELETE = "delete" + ADMIN = "admin" + + +class AutonomyLevel(str, Enum): + """Autonomy levels for agent operation.""" + + FULL_CONTROL = "full_control" # Approve every action + MILESTONE = "milestone" # Approve at milestones + AUTONOMOUS = "autonomous" # Only major decisions + + +class SafetyDecision(str, Enum): + """Result of safety validation.""" + + ALLOW = "allow" + DENY = "deny" + REQUIRE_APPROVAL = "require_approval" + DELAY = "delay" + SANDBOX = "sandbox" + + +class ApprovalStatus(str, Enum): + """Status of approval request.""" + + PENDING = "pending" + APPROVED = "approved" + DENIED = "denied" + TIMEOUT = "timeout" + CANCELLED = "cancelled" + + +class AuditEventType(str, Enum): + """Types of audit events.""" + + ACTION_REQUESTED = "action_requested" + ACTION_VALIDATED = "action_validated" + ACTION_DENIED = "action_denied" + ACTION_EXECUTED = "action_executed" + ACTION_FAILED = "action_failed" + APPROVAL_REQUESTED = "approval_requested" + APPROVAL_GRANTED = "approval_granted" + APPROVAL_DENIED = "approval_denied" + APPROVAL_TIMEOUT = "approval_timeout" + CHECKPOINT_CREATED = "checkpoint_created" + ROLLBACK_STARTED = "rollback_started" + ROLLBACK_COMPLETED = "rollback_completed" + ROLLBACK_FAILED = "rollback_failed" + BUDGET_WARNING = "budget_warning" + BUDGET_EXCEEDED = "budget_exceeded" + RATE_LIMITED = "rate_limited" + LOOP_DETECTED = "loop_detected" + EMERGENCY_STOP = "emergency_stop" + POLICY_VIOLATION = "policy_violation" + CONTENT_FILTERED = "content_filtered" + + +# ============================================================================ +# Action Models +# ============================================================================ + + +class ActionMetadata(BaseModel): + """Metadata associated with an action.""" + + agent_id: str = Field(..., description="ID of the agent performing the action") + project_id: str | None = Field(None, description="ID of the project context") + session_id: str | None = Field(None, description="ID of the current session") + task_id: str | None = Field(None, description="ID of the current task") + parent_action_id: str | None = Field(None, description="ID of the parent action") + correlation_id: str | None = Field(None, description="Correlation ID for tracing") + user_id: str | None = Field(None, description="ID of the user who initiated") + autonomy_level: AutonomyLevel = Field( + default=AutonomyLevel.MILESTONE, + description="Current autonomy level", + ) + context: dict[str, Any] = Field( + default_factory=dict, + description="Additional context", + ) + + +class ActionRequest(BaseModel): + """Request to perform an action.""" + + id: str = Field(default_factory=lambda: str(uuid4())) + action_type: ActionType = Field(..., description="Type of action to perform") + tool_name: str | None = Field(None, description="Name of the tool to call") + resource: str | None = Field(None, description="Resource being accessed") + resource_type: ResourceType | None = Field(None, description="Type of resource") + arguments: dict[str, Any] = Field( + default_factory=dict, + description="Action arguments", + ) + metadata: ActionMetadata = Field(..., description="Action metadata") + estimated_cost_tokens: int = Field(0, description="Estimated token cost") + estimated_cost_usd: float = Field(0.0, description="Estimated USD cost") + is_destructive: bool = Field(False, description="Whether action is destructive") + is_reversible: bool = Field(True, description="Whether action can be rolled back") + timestamp: datetime = Field(default_factory=datetime.utcnow) + + +class ActionResult(BaseModel): + """Result of an executed action.""" + + action_id: str = Field(..., description="ID of the action") + success: bool = Field(..., description="Whether action succeeded") + data: Any = Field(None, description="Action result data") + error: str | None = Field(None, description="Error message if failed") + error_code: str | None = Field(None, description="Error code if failed") + execution_time_ms: float = Field(0.0, description="Execution time in ms") + actual_cost_tokens: int = Field(0, description="Actual token cost") + actual_cost_usd: float = Field(0.0, description="Actual USD cost") + checkpoint_id: str | None = Field(None, description="Checkpoint ID if created") + timestamp: datetime = Field(default_factory=datetime.utcnow) + + +# ============================================================================ +# Validation Models +# ============================================================================ + + +class ValidationRule(BaseModel): + """A single validation rule.""" + + id: str = Field(default_factory=lambda: str(uuid4())) + name: str = Field(..., description="Rule name") + description: str | None = Field(None, description="Rule description") + priority: int = Field(0, description="Rule priority (higher = evaluated first)") + enabled: bool = Field(True, description="Whether rule is enabled") + + # Rule conditions + action_types: list[ActionType] | None = Field( + None, description="Action types this rule applies to" + ) + tool_patterns: list[str] | None = Field( + None, description="Tool name patterns (supports wildcards)" + ) + resource_patterns: list[str] | None = Field( + None, description="Resource patterns (supports wildcards)" + ) + agent_ids: list[str] | None = Field( + None, description="Agent IDs this rule applies to" + ) + + # Rule decision + decision: SafetyDecision = Field(..., description="Decision when rule matches") + reason: str | None = Field(None, description="Reason for decision") + + +class ValidationResult(BaseModel): + """Result of action validation.""" + + action_id: str = Field(..., description="ID of the validated action") + decision: SafetyDecision = Field(..., description="Validation decision") + applied_rules: list[str] = Field( + default_factory=list, description="IDs of applied rules" + ) + reasons: list[str] = Field( + default_factory=list, description="Reasons for decision" + ) + approval_id: str | None = Field(None, description="Approval request ID if needed") + retry_after_seconds: float | None = Field( + None, description="Retry delay if rate limited" + ) + timestamp: datetime = Field(default_factory=datetime.utcnow) + + +# ============================================================================ +# Budget Models +# ============================================================================ + + +class BudgetScope(str, Enum): + """Scope of a budget limit.""" + + SESSION = "session" + DAILY = "daily" + WEEKLY = "weekly" + MONTHLY = "monthly" + PROJECT = "project" + AGENT = "agent" + + +class BudgetStatus(BaseModel): + """Current budget status.""" + + scope: BudgetScope = Field(..., description="Budget scope") + scope_id: str = Field(..., description="ID within scope (session/agent/project)") + tokens_used: int = Field(0, description="Tokens used in this scope") + tokens_limit: int = Field(100000, description="Token limit for this scope") + cost_used_usd: float = Field(0.0, description="USD spent in this scope") + cost_limit_usd: float = Field(10.0, description="USD limit for this scope") + tokens_remaining: int = Field(0, description="Remaining tokens") + cost_remaining_usd: float = Field(0.0, description="Remaining USD budget") + warning_threshold: float = Field(0.8, description="Warn at this usage fraction") + is_warning: bool = Field(False, description="Whether at warning level") + is_exceeded: bool = Field(False, description="Whether budget exceeded") + reset_at: datetime | None = Field(None, description="When budget resets") + + +# ============================================================================ +# Rate Limit Models +# ============================================================================ + + +class RateLimitConfig(BaseModel): + """Configuration for a rate limit.""" + + name: str = Field(..., description="Rate limit name") + limit: int = Field(..., description="Maximum allowed in window") + window_seconds: int = Field(60, description="Time window in seconds") + burst_limit: int | None = Field(None, description="Burst allowance") + slowdown_threshold: float = Field( + 0.8, description="Start slowing at this fraction" + ) + + +class RateLimitStatus(BaseModel): + """Current rate limit status.""" + + name: str = Field(..., description="Rate limit name") + current_count: int = Field(0, description="Current count in window") + limit: int = Field(..., description="Maximum allowed") + window_seconds: int = Field(..., description="Time window") + remaining: int = Field(..., description="Remaining in window") + reset_at: datetime = Field(..., description="When window resets") + is_limited: bool = Field(False, description="Whether currently limited") + retry_after_seconds: float = Field(0.0, description="Seconds until retry") + + +# ============================================================================ +# Approval Models +# ============================================================================ + + +class ApprovalRequest(BaseModel): + """Request for human approval.""" + + id: str = Field(default_factory=lambda: str(uuid4())) + action: ActionRequest = Field(..., description="Action requiring approval") + reason: str = Field(..., description="Why approval is required") + urgency: str = Field("normal", description="Urgency level") + timeout_seconds: int = Field(300, description="Timeout for approval") + created_at: datetime = Field(default_factory=datetime.utcnow) + expires_at: datetime | None = Field(None, description="When request expires") + suggested_action: str | None = Field(None, description="Suggested response") + context: dict[str, Any] = Field(default_factory=dict, description="Extra context") + + +class ApprovalResponse(BaseModel): + """Response to an approval request.""" + + request_id: str = Field(..., description="ID of the approval request") + status: ApprovalStatus = Field(..., description="Approval status") + decided_by: str | None = Field(None, description="Who made the decision") + reason: str | None = Field(None, description="Reason for decision") + modifications: dict[str, Any] | None = Field( + None, description="Modifications to action" + ) + decided_at: datetime = Field(default_factory=datetime.utcnow) + + +# ============================================================================ +# Checkpoint/Rollback Models +# ============================================================================ + + +class CheckpointType(str, Enum): + """Types of checkpoints.""" + + FILE = "file" + DATABASE = "database" + GIT = "git" + COMPOSITE = "composite" + + +class Checkpoint(BaseModel): + """A rollback checkpoint.""" + + id: str = Field(default_factory=lambda: str(uuid4())) + checkpoint_type: CheckpointType = Field(..., description="Type of checkpoint") + action_id: str = Field(..., description="Action this checkpoint is for") + created_at: datetime = Field(default_factory=datetime.utcnow) + expires_at: datetime | None = Field(None, description="When checkpoint expires") + data: dict[str, Any] = Field(default_factory=dict, description="Checkpoint data") + description: str | None = Field(None, description="Description of checkpoint") + is_valid: bool = Field(True, description="Whether checkpoint is still valid") + + +class RollbackResult(BaseModel): + """Result of a rollback operation.""" + + checkpoint_id: str = Field(..., description="ID of checkpoint rolled back to") + success: bool = Field(..., description="Whether rollback succeeded") + actions_rolled_back: list[str] = Field( + default_factory=list, description="IDs of rolled back actions" + ) + failed_actions: list[str] = Field( + default_factory=list, description="IDs of actions that failed to rollback" + ) + error: str | None = Field(None, description="Error message if failed") + timestamp: datetime = Field(default_factory=datetime.utcnow) + + +# ============================================================================ +# Audit Models +# ============================================================================ + + +class AuditEvent(BaseModel): + """An audit log event.""" + + id: str = Field(default_factory=lambda: str(uuid4())) + event_type: AuditEventType = Field(..., description="Type of audit event") + timestamp: datetime = Field(default_factory=datetime.utcnow) + agent_id: str | None = Field(None, description="Agent ID if applicable") + action_id: str | None = Field(None, description="Action ID if applicable") + project_id: str | None = Field(None, description="Project ID if applicable") + session_id: str | None = Field(None, description="Session ID if applicable") + user_id: str | None = Field(None, description="User ID if applicable") + decision: SafetyDecision | None = Field(None, description="Safety decision") + details: dict[str, Any] = Field(default_factory=dict, description="Event details") + correlation_id: str | None = Field(None, description="Correlation ID for tracing") + + +# ============================================================================ +# Policy Models +# ============================================================================ + + +class SafetyPolicy(BaseModel): + """A complete safety policy configuration.""" + + name: str = Field(..., description="Policy name") + description: str | None = Field(None, description="Policy description") + version: str = Field("1.0.0", description="Policy version") + enabled: bool = Field(True, description="Whether policy is enabled") + + # Cost controls + max_tokens_per_session: int = Field(100_000, description="Max tokens per session") + max_tokens_per_day: int = Field(1_000_000, description="Max tokens per day") + max_cost_per_session_usd: float = Field(10.0, description="Max USD per session") + max_cost_per_day_usd: float = Field(100.0, description="Max USD per day") + + # Rate limits + max_actions_per_minute: int = Field(60, description="Max actions per minute") + max_llm_calls_per_minute: int = Field(20, description="Max LLM calls per minute") + max_file_operations_per_minute: int = Field( + 100, description="Max file ops per minute" + ) + + # Permissions + allowed_tools: list[str] = Field( + default_factory=lambda: ["*"], + description="Allowed tool patterns", + ) + denied_tools: list[str] = Field( + default_factory=list, + description="Denied tool patterns", + ) + allowed_file_patterns: list[str] = Field( + default_factory=lambda: ["**/*"], + description="Allowed file patterns", + ) + denied_file_patterns: list[str] = Field( + default_factory=lambda: ["**/.env", "**/secrets/**"], + description="Denied file patterns", + ) + + # HITL + require_approval_for: list[str] = Field( + default_factory=lambda: [ + "delete_file", + "push_to_remote", + "deploy_to_production", + "modify_critical_config", + ], + description="Actions requiring approval", + ) + + # Loop detection + max_repeated_actions: int = Field(5, description="Max exact repetitions") + max_similar_actions: int = Field(10, description="Max similar actions") + + # Sandbox + require_sandbox: bool = Field(False, description="Require sandbox execution") + sandbox_timeout_seconds: int = Field(300, description="Sandbox timeout") + sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit") + + # Validation rules + validation_rules: list[ValidationRule] = Field( + default_factory=list, + description="Custom validation rules", + ) + + +# ============================================================================ +# Guardian Result Models +# ============================================================================ + + +class GuardianResult(BaseModel): + """Result of SafetyGuardian evaluation.""" + + action_id: str = Field(..., description="ID of the action") + allowed: bool = Field(..., description="Whether action is allowed") + decision: SafetyDecision = Field(..., description="Safety decision") + reasons: list[str] = Field(default_factory=list, description="Decision reasons") + approval_id: str | None = Field(None, description="Approval ID if needed") + checkpoint_id: str | None = Field(None, description="Checkpoint ID if created") + retry_after_seconds: float | None = Field(None, description="Retry delay") + modified_action: ActionRequest | None = Field( + None, description="Modified action if changed" + ) + audit_events: list[AuditEvent] = Field( + default_factory=list, description="Generated audit events" + ) diff --git a/backend/app/services/safety/permissions/__init__.py b/backend/app/services/safety/permissions/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/permissions/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/policies/__init__.py b/backend/app/services/safety/policies/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/policies/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/rollback/__init__.py b/backend/app/services/safety/rollback/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/rollback/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/sandbox/__init__.py b/backend/app/services/safety/sandbox/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/sandbox/__init__.py @@ -0,0 +1 @@ +"""${dir} module.""" diff --git a/backend/app/services/safety/validation/__init__.py b/backend/app/services/safety/validation/__init__.py new file mode 100644 index 0000000..9f4729c --- /dev/null +++ b/backend/app/services/safety/validation/__init__.py @@ -0,0 +1 @@ +"""${dir} module."""