feat(backend): add safety framework foundation (Phase A) (#63)

Core safety framework architecture for autonomous agent guardrails:

**Core Components:**
- SafetyGuardian: Main orchestrator for all safety checks
- AuditLogger: Comprehensive audit logging with hash chain tamper detection
- SafetyConfig: Pydantic-based configuration
- Models: Action requests, validation results, policies, checkpoints

**Exception Hierarchy:**
- SafetyError base with context preservation
- Permission, Budget, RateLimit, Loop errors
- Approval workflow errors (Required, Denied, Timeout)
- Rollback, Sandbox, Emergency exceptions

**Safety Policy System:**
- Autonomy level based policies (FULL_CONTROL, MILESTONE, AUTONOMOUS)
- Cost limits, rate limits, permission patterns
- HITL approval requirements per action type
- Configurable loop detection thresholds

**Directory Structure:**
- validation/, costs/, limits/, loops/ - Control subsystems
- permissions/, rollback/, hitl/ - Access and recovery
- content/, sandbox/, emergency/ - Protection systems
- audit/, policies/ - Logging and configuration

Phase A establishes the architecture. Subsystems to be implemented in Phase B-C.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 11:22:25 +01:00
parent e5975fa5d0
commit 498c0a0e94
18 changed files with 2450 additions and 0 deletions

View File

@@ -0,0 +1,170 @@
"""
Safety and Guardrails Framework
Comprehensive safety framework for autonomous agent operation.
Provides multi-layered protection including:
- Pre-execution validation
- Cost and budget controls
- Rate limiting
- Loop detection and prevention
- Human-in-the-loop approval
- Rollback and checkpointing
- Content filtering
- Sandboxed execution
- Emergency controls
- Complete audit trail
Usage:
from app.services.safety import get_safety_guardian, SafetyGuardian
guardian = await get_safety_guardian()
result = await guardian.validate(action_request)
if result.allowed:
# Execute action
pass
else:
# Handle denial
print(f"Action denied: {result.reasons}")
"""
# Exceptions
# Audit
from .audit import (
AuditLogger,
get_audit_logger,
reset_audit_logger,
shutdown_audit_logger,
)
# Configuration
from .config import (
AutonomyConfig,
SafetyConfig,
get_autonomy_config,
get_default_policy,
get_policy_for_autonomy_level,
get_safety_config,
load_policies_from_directory,
load_policy_from_file,
reset_config_cache,
)
from .exceptions import (
ApprovalDeniedError,
ApprovalRequiredError,
ApprovalTimeoutError,
BudgetExceededError,
CheckpointError,
ContentFilterError,
EmergencyStopError,
LoopDetectedError,
PermissionDeniedError,
PolicyViolationError,
RateLimitExceededError,
RollbackError,
SafetyError,
SandboxError,
SandboxTimeoutError,
ValidationError,
)
# Guardian
from .guardian import (
SafetyGuardian,
get_safety_guardian,
reset_safety_guardian,
shutdown_safety_guardian,
)
# Models
from .models import (
ActionMetadata,
ActionRequest,
ActionResult,
ActionType,
ApprovalRequest,
ApprovalResponse,
ApprovalStatus,
AuditEvent,
AuditEventType,
AutonomyLevel,
BudgetScope,
BudgetStatus,
Checkpoint,
CheckpointType,
GuardianResult,
PermissionLevel,
RateLimitConfig,
RateLimitStatus,
ResourceType,
RollbackResult,
SafetyDecision,
SafetyPolicy,
ValidationResult,
ValidationRule,
)
__all__ = [
"ActionMetadata",
"ActionRequest",
"ActionResult",
# Models
"ActionType",
"ApprovalDeniedError",
"ApprovalRequest",
"ApprovalRequiredError",
"ApprovalResponse",
"ApprovalStatus",
"ApprovalTimeoutError",
"AuditEvent",
"AuditEventType",
# Audit
"AuditLogger",
"AutonomyConfig",
"AutonomyLevel",
"BudgetExceededError",
"BudgetScope",
"BudgetStatus",
"Checkpoint",
"CheckpointError",
"CheckpointType",
"ContentFilterError",
"EmergencyStopError",
"GuardianResult",
"LoopDetectedError",
"PermissionDeniedError",
"PermissionLevel",
"PolicyViolationError",
"RateLimitConfig",
"RateLimitExceededError",
"RateLimitStatus",
"ResourceType",
"RollbackError",
"RollbackResult",
# Configuration
"SafetyConfig",
"SafetyDecision",
# Exceptions
"SafetyError",
# Guardian
"SafetyGuardian",
"SafetyPolicy",
"SandboxError",
"SandboxTimeoutError",
"ValidationError",
"ValidationResult",
"ValidationRule",
"get_audit_logger",
"get_autonomy_config",
"get_default_policy",
"get_policy_for_autonomy_level",
"get_safety_config",
"get_safety_guardian",
"load_policies_from_directory",
"load_policy_from_file",
"reset_audit_logger",
"reset_config_cache",
"reset_safety_guardian",
"shutdown_audit_logger",
"shutdown_safety_guardian",
]

View File

@@ -0,0 +1,19 @@
"""
Audit System
Comprehensive audit logging for all safety-related events.
"""
from .logger import (
AuditLogger,
get_audit_logger,
reset_audit_logger,
shutdown_audit_logger,
)
__all__ = [
"AuditLogger",
"get_audit_logger",
"reset_audit_logger",
"shutdown_audit_logger",
]

View File

@@ -0,0 +1,585 @@
"""
Audit Logger
Comprehensive audit logging for all safety-related events.
Provides tamper detection, structured logging, and compliance support.
"""
import asyncio
import hashlib
import json
import logging
from collections import deque
from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..models import (
ActionRequest,
AuditEvent,
AuditEventType,
SafetyDecision,
)
logger = logging.getLogger(__name__)
class AuditLogger:
"""
Audit logger for safety events.
Features:
- Structured event logging
- In-memory buffer with async flush
- Tamper detection via hash chains
- Query/search capability
- Retention policy enforcement
"""
def __init__(
self,
max_buffer_size: int = 1000,
flush_interval_seconds: float = 10.0,
enable_hash_chain: bool = True,
) -> None:
"""
Initialize the audit logger.
Args:
max_buffer_size: Maximum events to buffer before auto-flush
flush_interval_seconds: Interval for periodic flush
enable_hash_chain: Enable tamper detection via hash chain
"""
self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
self._persisted: list[AuditEvent] = []
self._flush_interval = flush_interval_seconds
self._enable_hash_chain = enable_hash_chain
self._last_hash: str | None = None
self._lock = asyncio.Lock()
self._flush_task: asyncio.Task[None] | None = None
self._running = False
# Event handlers for real-time processing
self._handlers: list[Any] = []
config = get_safety_config()
self._retention_days = config.audit_retention_days
self._include_sensitive = config.audit_include_sensitive
async def start(self) -> None:
"""Start the audit logger background tasks."""
if self._running:
return
self._running = True
self._flush_task = asyncio.create_task(self._periodic_flush())
logger.info("Audit logger started")
async def stop(self) -> None:
"""Stop the audit logger and flush remaining events."""
self._running = False
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
# Final flush
await self.flush()
logger.info("Audit logger stopped")
async def log(
self,
event_type: AuditEventType,
*,
agent_id: str | None = None,
action_id: str | None = None,
project_id: str | None = None,
session_id: str | None = None,
user_id: str | None = None,
decision: SafetyDecision | None = None,
details: dict[str, Any] | None = None,
correlation_id: str | None = None,
) -> AuditEvent:
"""
Log an audit event.
Args:
event_type: Type of audit event
agent_id: Agent ID if applicable
action_id: Action ID if applicable
project_id: Project ID if applicable
session_id: Session ID if applicable
user_id: User ID if applicable
decision: Safety decision if applicable
details: Additional event details
correlation_id: Correlation ID for tracing
Returns:
The created audit event
"""
# Sanitize sensitive data if needed
sanitized_details = self._sanitize_details(details) if details else {}
event = AuditEvent(
id=str(uuid4()),
event_type=event_type,
timestamp=datetime.utcnow(),
agent_id=agent_id,
action_id=action_id,
project_id=project_id,
session_id=session_id,
user_id=user_id,
decision=decision,
details=sanitized_details,
correlation_id=correlation_id,
)
async with self._lock:
# Add hash chain for tamper detection
if self._enable_hash_chain:
event_hash = self._compute_hash(event)
sanitized_details["_hash"] = event_hash
sanitized_details["_prev_hash"] = self._last_hash
self._last_hash = event_hash
self._buffer.append(event)
# Notify handlers
await self._notify_handlers(event)
# Log to standard logger as well
self._log_to_logger(event)
return event
async def log_action_request(
self,
action: ActionRequest,
decision: SafetyDecision,
reasons: list[str] | None = None,
) -> AuditEvent:
"""Log an action request with its validation decision."""
event_type = (
AuditEventType.ACTION_DENIED
if decision == SafetyDecision.DENY
else AuditEventType.ACTION_VALIDATED
)
return await self.log(
event_type,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
user_id=action.metadata.user_id,
decision=decision,
details={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"resource": action.resource,
"is_destructive": action.is_destructive,
"reasons": reasons or [],
},
correlation_id=action.metadata.correlation_id,
)
async def log_action_executed(
self,
action: ActionRequest,
success: bool,
execution_time_ms: float,
error: str | None = None,
) -> AuditEvent:
"""Log an action execution result."""
event_type = (
AuditEventType.ACTION_EXECUTED
if success
else AuditEventType.ACTION_FAILED
)
return await self.log(
event_type,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
details={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"success": success,
"execution_time_ms": execution_time_ms,
"error": error,
},
correlation_id=action.metadata.correlation_id,
)
async def log_approval_event(
self,
event_type: AuditEventType,
approval_id: str,
action: ActionRequest,
decided_by: str | None = None,
reason: str | None = None,
) -> AuditEvent:
"""Log an approval-related event."""
return await self.log(
event_type,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
user_id=decided_by,
details={
"approval_id": approval_id,
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"decided_by": decided_by,
"reason": reason,
},
correlation_id=action.metadata.correlation_id,
)
async def log_budget_event(
self,
event_type: AuditEventType,
agent_id: str,
scope: str,
current_usage: float,
limit: float,
unit: str = "tokens",
) -> AuditEvent:
"""Log a budget-related event."""
return await self.log(
event_type,
agent_id=agent_id,
details={
"scope": scope,
"current_usage": current_usage,
"limit": limit,
"unit": unit,
"usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
},
)
async def log_emergency_stop(
self,
stop_type: str,
triggered_by: str,
reason: str,
affected_agents: list[str] | None = None,
) -> AuditEvent:
"""Log an emergency stop event."""
return await self.log(
AuditEventType.EMERGENCY_STOP,
user_id=triggered_by,
details={
"stop_type": stop_type,
"triggered_by": triggered_by,
"reason": reason,
"affected_agents": affected_agents or [],
},
)
async def flush(self) -> int:
"""
Flush buffered events to persistent storage.
Returns:
Number of events flushed
"""
async with self._lock:
if not self._buffer:
return 0
events = list(self._buffer)
self._buffer.clear()
# Persist events (in production, this would go to database/storage)
self._persisted.extend(events)
# Enforce retention
self._enforce_retention()
logger.debug("Flushed %d audit events", len(events))
return len(events)
async def query(
self,
*,
event_types: list[AuditEventType] | None = None,
agent_id: str | None = None,
action_id: str | None = None,
project_id: str | None = None,
session_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
correlation_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[AuditEvent]:
"""
Query audit events with filters.
Args:
event_types: Filter by event types
agent_id: Filter by agent ID
action_id: Filter by action ID
project_id: Filter by project ID
session_id: Filter by session ID
user_id: Filter by user ID
start_time: Filter events after this time
end_time: Filter events before this time
correlation_id: Filter by correlation ID
limit: Maximum results to return
offset: Result offset for pagination
Returns:
List of matching audit events
"""
# Combine buffer and persisted for query
all_events = list(self._persisted) + list(self._buffer)
results = []
for event in all_events:
if event_types and event.event_type not in event_types:
continue
if agent_id and event.agent_id != agent_id:
continue
if action_id and event.action_id != action_id:
continue
if project_id and event.project_id != project_id:
continue
if session_id and event.session_id != session_id:
continue
if user_id and event.user_id != user_id:
continue
if start_time and event.timestamp < start_time:
continue
if end_time and event.timestamp > end_time:
continue
if correlation_id and event.correlation_id != correlation_id:
continue
results.append(event)
# Sort by timestamp descending
results.sort(key=lambda e: e.timestamp, reverse=True)
# Apply pagination
return results[offset : offset + limit]
async def get_action_history(
self,
agent_id: str,
limit: int = 100,
) -> list[AuditEvent]:
"""Get action history for an agent."""
return await self.query(
agent_id=agent_id,
event_types=[
AuditEventType.ACTION_REQUESTED,
AuditEventType.ACTION_VALIDATED,
AuditEventType.ACTION_DENIED,
AuditEventType.ACTION_EXECUTED,
AuditEventType.ACTION_FAILED,
],
limit=limit,
)
async def verify_integrity(self) -> tuple[bool, list[str]]:
"""
Verify audit log integrity using hash chain.
Returns:
Tuple of (is_valid, list of issues found)
"""
if not self._enable_hash_chain:
return True, []
issues: list[str] = []
all_events = list(self._persisted) + list(self._buffer)
prev_hash: str | None = None
for event in sorted(all_events, key=lambda e: e.timestamp):
stored_prev = event.details.get("_prev_hash")
stored_hash = event.details.get("_hash")
if stored_prev != prev_hash:
issues.append(
f"Hash chain broken at event {event.id}: "
f"expected prev_hash={prev_hash}, got {stored_prev}"
)
if stored_hash:
computed = self._compute_hash(event)
if computed != stored_hash:
issues.append(
f"Hash mismatch at event {event.id}: "
f"expected {computed}, got {stored_hash}"
)
prev_hash = stored_hash
return len(issues) == 0, issues
def add_handler(self, handler: Any) -> None:
"""Add a real-time event handler."""
self._handlers.append(handler)
def remove_handler(self, handler: Any) -> None:
"""Remove an event handler."""
if handler in self._handlers:
self._handlers.remove(handler)
def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
"""Sanitize sensitive data from details."""
if self._include_sensitive:
return details
sanitized: dict[str, Any] = {}
sensitive_keys = {
"password",
"secret",
"token",
"api_key",
"apikey",
"auth",
"credential",
}
for key, value in details.items():
lower_key = key.lower()
if any(s in lower_key for s in sensitive_keys):
sanitized[key] = "[REDACTED]"
elif isinstance(value, dict):
sanitized[key] = self._sanitize_details(value)
else:
sanitized[key] = value
return sanitized
def _compute_hash(self, event: AuditEvent) -> str:
"""Compute hash for an event (excluding hash fields)."""
data = {
"id": event.id,
"event_type": event.event_type.value,
"timestamp": event.timestamp.isoformat(),
"agent_id": event.agent_id,
"action_id": event.action_id,
"project_id": event.project_id,
"session_id": event.session_id,
"user_id": event.user_id,
"decision": event.decision.value if event.decision else None,
"details": {
k: v
for k, v in event.details.items()
if not k.startswith("_")
},
"correlation_id": event.correlation_id,
}
if self._last_hash:
data["_prev_hash"] = self._last_hash
serialized = json.dumps(data, sort_keys=True, default=str)
return hashlib.sha256(serialized.encode()).hexdigest()
def _log_to_logger(self, event: AuditEvent) -> None:
"""Log event to standard Python logger."""
log_data = {
"audit_event": event.event_type.value,
"event_id": event.id,
"agent_id": event.agent_id,
"action_id": event.action_id,
"decision": event.decision.value if event.decision else None,
}
# Use appropriate log level based on event type
if event.event_type in {
AuditEventType.ACTION_DENIED,
AuditEventType.POLICY_VIOLATION,
AuditEventType.EMERGENCY_STOP,
}:
logger.warning("Audit: %s", log_data)
elif event.event_type in {
AuditEventType.ACTION_FAILED,
AuditEventType.ROLLBACK_FAILED,
}:
logger.error("Audit: %s", log_data)
else:
logger.info("Audit: %s", log_data)
def _enforce_retention(self) -> None:
"""Enforce retention policy on persisted events."""
if not self._retention_days:
return
cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
before_count = len(self._persisted)
self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]
removed = before_count - len(self._persisted)
if removed > 0:
logger.info("Removed %d expired audit events", removed)
async def _periodic_flush(self) -> None:
"""Background task for periodic flushing."""
while self._running:
try:
await asyncio.sleep(self._flush_interval)
await self.flush()
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in periodic audit flush: %s", e)
async def _notify_handlers(self, event: AuditEvent) -> None:
"""Notify all registered handlers of a new event."""
for handler in self._handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
logger.error("Error in audit event handler: %s", e)
# Singleton instance
_audit_logger: AuditLogger | None = None
_audit_lock = asyncio.Lock()
async def get_audit_logger() -> AuditLogger:
"""Get the global audit logger instance."""
global _audit_logger
async with _audit_lock:
if _audit_logger is None:
_audit_logger = AuditLogger()
await _audit_logger.start()
return _audit_logger
async def shutdown_audit_logger() -> None:
"""Shutdown the global audit logger."""
global _audit_logger
async with _audit_lock:
if _audit_logger is not None:
await _audit_logger.stop()
_audit_logger = None
def reset_audit_logger() -> None:
"""Reset the audit logger (for testing)."""
global _audit_logger
_audit_logger = None

View File

@@ -0,0 +1,300 @@
"""
Safety Framework Configuration
Pydantic settings for the safety and guardrails framework.
"""
import logging
import os
from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from .models import AutonomyLevel, SafetyPolicy
logger = logging.getLogger(__name__)
class SafetyConfig(BaseSettings):
"""Configuration for the safety framework."""
model_config = SettingsConfigDict(
env_prefix="SAFETY_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# General settings
enabled: bool = Field(True, description="Enable safety framework")
strict_mode: bool = Field(
True, description="Strict mode (fail closed on errors)"
)
log_level: str = Field("INFO", description="Logging level")
# Default autonomy level
default_autonomy_level: AutonomyLevel = Field(
AutonomyLevel.MILESTONE,
description="Default autonomy level for new agents",
)
# Default budget limits
default_session_token_budget: int = Field(
100_000, description="Default tokens per session"
)
default_daily_token_budget: int = Field(
1_000_000, description="Default tokens per day"
)
default_session_cost_limit: float = Field(
10.0, description="Default USD per session"
)
default_daily_cost_limit: float = Field(100.0, description="Default USD per day")
# Default rate limits
default_actions_per_minute: int = Field(60, description="Default actions per min")
default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min")
default_file_ops_per_minute: int = Field(100, description="Default file ops/min")
# Loop detection
loop_detection_enabled: bool = Field(True, description="Enable loop detection")
max_repeated_actions: int = Field(5, description="Max exact repetitions")
max_similar_actions: int = Field(10, description="Max similar actions")
loop_history_size: int = Field(100, description="Action history size for loops")
# HITL settings
hitl_enabled: bool = Field(True, description="Enable human-in-the-loop")
hitl_default_timeout: int = Field(300, description="Default approval timeout (s)")
hitl_notification_channels: list[str] = Field(
default_factory=list, description="Notification channels"
)
# Rollback settings
rollback_enabled: bool = Field(True, description="Enable rollback capability")
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
auto_checkpoint_destructive: bool = Field(
True, description="Auto-checkpoint destructive actions"
)
# Sandbox settings
sandbox_enabled: bool = Field(False, description="Enable sandbox execution")
sandbox_timeout: int = Field(300, description="Sandbox timeout (s)")
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)")
sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit")
sandbox_network_enabled: bool = Field(False, description="Allow sandbox network")
# Audit settings
audit_enabled: bool = Field(True, description="Enable audit logging")
audit_retention_days: int = Field(90, description="Audit log retention (days)")
audit_include_sensitive: bool = Field(
False, description="Include sensitive data in audit"
)
# Content filtering
content_filter_enabled: bool = Field(True, description="Enable content filtering")
filter_pii: bool = Field(True, description="Filter PII")
filter_secrets: bool = Field(True, description="Filter secrets")
# Emergency controls
emergency_stop_enabled: bool = Field(True, description="Enable emergency stop")
emergency_webhook_url: str | None = Field(None, description="Emergency webhook")
# Policy file path
policy_file: str | None = Field(None, description="Path to policy YAML file")
# Validation cache
validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)")
validation_cache_size: int = Field(1000, description="Validation cache size")
class AutonomyConfig(BaseSettings):
"""Configuration for autonomy levels."""
model_config = SettingsConfigDict(
env_prefix="AUTONOMY_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# FULL_CONTROL settings
full_control_cost_limit: float = Field(1.0, description="USD limit per session")
full_control_require_all_approval: bool = Field(
True, description="Require approval for all"
)
full_control_block_destructive: bool = Field(
True, description="Block destructive actions"
)
# MILESTONE settings
milestone_cost_limit: float = Field(10.0, description="USD limit per session")
milestone_require_critical_approval: bool = Field(
True, description="Require approval for critical"
)
milestone_auto_checkpoint: bool = Field(
True, description="Auto-checkpoint destructive"
)
# AUTONOMOUS settings
autonomous_cost_limit: float = Field(100.0, description="USD limit per session")
autonomous_auto_approve_normal: bool = Field(
True, description="Auto-approve normal actions"
)
autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all")
def _expand_env_vars(value: Any) -> Any:
"""Recursively expand environment variables in values."""
if isinstance(value, str):
return os.path.expandvars(value)
elif isinstance(value, dict):
return {k: _expand_env_vars(v) for k, v in value.items()}
elif isinstance(value, list):
return [_expand_env_vars(v) for v in value]
return value
def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None:
"""Load a safety policy from a YAML file."""
path = Path(file_path)
if not path.exists():
logger.warning("Policy file not found: %s", path)
return None
try:
with open(path) as f:
data = yaml.safe_load(f)
if data is None:
logger.warning("Empty policy file: %s", path)
return None
# Expand environment variables
data = _expand_env_vars(data)
return SafetyPolicy(**data)
except Exception as e:
logger.error("Failed to load policy file %s: %s", path, e)
return None
def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]:
"""Load all safety policies from a directory."""
policies: dict[str, SafetyPolicy] = {}
path = Path(directory)
if not path.exists() or not path.is_dir():
logger.warning("Policy directory not found: %s", path)
return policies
for file_path in path.glob("*.yaml"):
policy = load_policy_from_file(file_path)
if policy:
policies[policy.name] = policy
logger.info("Loaded policy: %s from %s", policy.name, file_path.name)
return policies
@lru_cache(maxsize=1)
def get_safety_config() -> SafetyConfig:
"""Get the safety configuration (cached singleton)."""
return SafetyConfig()
@lru_cache(maxsize=1)
def get_autonomy_config() -> AutonomyConfig:
"""Get the autonomy configuration (cached singleton)."""
return AutonomyConfig()
def get_default_policy() -> SafetyPolicy:
"""Get the default safety policy."""
config = get_safety_config()
return SafetyPolicy(
name="default",
description="Default safety policy",
max_tokens_per_session=config.default_session_token_budget,
max_tokens_per_day=config.default_daily_token_budget,
max_cost_per_session_usd=config.default_session_cost_limit,
max_cost_per_day_usd=config.default_daily_cost_limit,
max_actions_per_minute=config.default_actions_per_minute,
max_llm_calls_per_minute=config.default_llm_calls_per_minute,
max_file_operations_per_minute=config.default_file_ops_per_minute,
max_repeated_actions=config.max_repeated_actions,
max_similar_actions=config.max_similar_actions,
require_sandbox=config.sandbox_enabled,
sandbox_timeout_seconds=config.sandbox_timeout,
sandbox_memory_mb=config.sandbox_memory_mb,
)
def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
"""Get the safety policy for a given autonomy level."""
autonomy = get_autonomy_config()
base_policy = get_default_policy()
if level == AutonomyLevel.FULL_CONTROL:
return SafetyPolicy(
name="full_control",
description="Full control mode - all actions require approval",
max_cost_per_session_usd=autonomy.full_control_cost_limit,
max_cost_per_day_usd=autonomy.full_control_cost_limit * 10,
require_approval_for=["*"], # All actions
max_tokens_per_session=base_policy.max_tokens_per_session // 10,
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2,
denied_tools=["delete_*", "destroy_*", "drop_*"],
)
elif level == AutonomyLevel.MILESTONE:
return SafetyPolicy(
name="milestone",
description="Milestone mode - approval at milestones only",
max_cost_per_session_usd=autonomy.milestone_cost_limit,
max_cost_per_day_usd=autonomy.milestone_cost_limit * 10,
require_approval_for=[
"delete_file",
"push_to_remote",
"deploy_*",
"modify_critical_*",
"create_pull_request",
],
max_tokens_per_session=base_policy.max_tokens_per_session,
max_tokens_per_day=base_policy.max_tokens_per_day,
max_actions_per_minute=base_policy.max_actions_per_minute,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute,
)
else: # AUTONOMOUS
return SafetyPolicy(
name="autonomous",
description="Autonomous mode - minimal intervention",
max_cost_per_session_usd=autonomy.autonomous_cost_limit,
max_cost_per_day_usd=autonomy.autonomous_cost_limit * 10,
require_approval_for=[
"deploy_to_production",
"delete_repository",
"modify_production_config",
],
max_tokens_per_session=base_policy.max_tokens_per_session * 5,
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2,
)
def reset_config_cache() -> None:
"""Reset configuration caches (for testing)."""
get_safety_config.cache_clear()
get_autonomy_config.cache_clear()

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1,277 @@
"""
Safety Framework Exceptions
Custom exception classes for the safety and guardrails framework.
"""
from typing import Any
class SafetyError(Exception):
"""Base exception for all safety-related errors."""
def __init__(
self,
message: str,
*,
action_id: str | None = None,
agent_id: str | None = None,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message)
self.message = message
self.action_id = action_id
self.agent_id = agent_id
self.details = details or {}
class PermissionDeniedError(SafetyError):
"""Raised when an action is not permitted."""
def __init__(
self,
message: str = "Permission denied",
*,
action_type: str | None = None,
resource: str | None = None,
required_permission: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.action_type = action_type
self.resource = resource
self.required_permission = required_permission
class BudgetExceededError(SafetyError):
"""Raised when cost budget is exceeded."""
def __init__(
self,
message: str = "Budget exceeded",
*,
budget_type: str = "session",
current_usage: float = 0.0,
budget_limit: float = 0.0,
unit: str = "tokens",
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.budget_type = budget_type
self.current_usage = current_usage
self.budget_limit = budget_limit
self.unit = unit
class RateLimitExceededError(SafetyError):
"""Raised when rate limit is exceeded."""
def __init__(
self,
message: str = "Rate limit exceeded",
*,
limit_type: str = "actions",
limit_value: int = 0,
window_seconds: int = 60,
retry_after_seconds: float = 0.0,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.limit_type = limit_type
self.limit_value = limit_value
self.window_seconds = window_seconds
self.retry_after_seconds = retry_after_seconds
class LoopDetectedError(SafetyError):
"""Raised when an action loop is detected."""
def __init__(
self,
message: str = "Loop detected",
*,
loop_type: str = "exact",
repetition_count: int = 0,
action_pattern: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.loop_type = loop_type
self.repetition_count = repetition_count
self.action_pattern = action_pattern or []
class ApprovalRequiredError(SafetyError):
"""Raised when human approval is required."""
def __init__(
self,
message: str = "Human approval required",
*,
approval_id: str | None = None,
reason: str | None = None,
timeout_seconds: int = 300,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.approval_id = approval_id
self.reason = reason
self.timeout_seconds = timeout_seconds
class ApprovalDeniedError(SafetyError):
"""Raised when human explicitly denies an action."""
def __init__(
self,
message: str = "Approval denied by human",
*,
approval_id: str | None = None,
denied_by: str | None = None,
denial_reason: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.approval_id = approval_id
self.denied_by = denied_by
self.denial_reason = denial_reason
class ApprovalTimeoutError(SafetyError):
"""Raised when approval request times out."""
def __init__(
self,
message: str = "Approval request timed out",
*,
approval_id: str | None = None,
timeout_seconds: int = 300,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.approval_id = approval_id
self.timeout_seconds = timeout_seconds
class RollbackError(SafetyError):
"""Raised when rollback fails."""
def __init__(
self,
message: str = "Rollback failed",
*,
checkpoint_id: str | None = None,
failed_actions: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.checkpoint_id = checkpoint_id
self.failed_actions = failed_actions or []
class CheckpointError(SafetyError):
"""Raised when checkpoint creation fails."""
def __init__(
self,
message: str = "Checkpoint creation failed",
*,
checkpoint_type: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.checkpoint_type = checkpoint_type
class ValidationError(SafetyError):
"""Raised when action validation fails."""
def __init__(
self,
message: str = "Validation failed",
*,
validation_rules: list[str] | None = None,
failed_rules: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.validation_rules = validation_rules or []
self.failed_rules = failed_rules or []
class ContentFilterError(SafetyError):
"""Raised when content filtering detects prohibited content."""
def __init__(
self,
message: str = "Prohibited content detected",
*,
filter_type: str | None = None,
detected_patterns: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.filter_type = filter_type
self.detected_patterns = detected_patterns or []
class SandboxError(SafetyError):
"""Raised when sandbox execution fails."""
def __init__(
self,
message: str = "Sandbox execution failed",
*,
exit_code: int | None = None,
stderr: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.exit_code = exit_code
self.stderr = stderr
class SandboxTimeoutError(SandboxError):
"""Raised when sandbox execution times out."""
def __init__(
self,
message: str = "Sandbox execution timed out",
*,
timeout_seconds: int = 300,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.timeout_seconds = timeout_seconds
class EmergencyStopError(SafetyError):
"""Raised when emergency stop is triggered."""
def __init__(
self,
message: str = "Emergency stop triggered",
*,
stop_type: str = "kill",
triggered_by: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.stop_type = stop_type
self.triggered_by = triggered_by
class PolicyViolationError(SafetyError):
"""Raised when an action violates a safety policy."""
def __init__(
self,
message: str = "Policy violation",
*,
policy_name: str | None = None,
violated_rules: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.policy_name = policy_name
self.violated_rules = violated_rules or []

View File

@@ -0,0 +1,614 @@
"""
Safety Guardian
Main facade for the safety framework. Orchestrates all safety checks
before, during, and after action execution.
"""
import asyncio
import logging
from typing import Any
from .audit import AuditLogger, get_audit_logger
from .config import (
SafetyConfig,
get_policy_for_autonomy_level,
get_safety_config,
)
from .exceptions import (
SafetyError,
)
from .models import (
ActionRequest,
ActionResult,
AuditEventType,
GuardianResult,
SafetyDecision,
SafetyPolicy,
)
logger = logging.getLogger(__name__)
class SafetyGuardian:
"""
Central orchestrator for all safety checks.
The SafetyGuardian is the main entry point for validating agent actions.
It coordinates multiple safety subsystems:
- Permission checking
- Cost/budget control
- Rate limiting
- Loop detection
- Human-in-the-loop approval
- Rollback/checkpoint management
- Content filtering
- Sandbox execution
Usage:
guardian = SafetyGuardian()
await guardian.initialize()
# Before executing an action
result = await guardian.validate(action_request)
if not result.allowed:
# Handle denial
# After action execution
await guardian.record_execution(action_request, action_result)
"""
def __init__(
self,
config: SafetyConfig | None = None,
audit_logger: AuditLogger | None = None,
) -> None:
"""
Initialize the SafetyGuardian.
Args:
config: Optional safety configuration. If None, loads from environment.
audit_logger: Optional audit logger. If None, uses global instance.
"""
self._config = config or get_safety_config()
self._audit_logger = audit_logger
self._initialized = False
self._lock = asyncio.Lock()
# Subsystem references (will be initialized lazily)
self._permission_manager: Any = None
self._cost_controller: Any = None
self._rate_limiter: Any = None
self._loop_detector: Any = None
self._hitl_manager: Any = None
self._rollback_manager: Any = None
self._content_filter: Any = None
self._sandbox_executor: Any = None
self._emergency_controls: Any = None
# Policy cache
self._policies: dict[str, SafetyPolicy] = {}
self._default_policy: SafetyPolicy | None = None
@property
def is_initialized(self) -> bool:
"""Check if the guardian is initialized."""
return self._initialized
async def initialize(self) -> None:
"""Initialize the SafetyGuardian and all subsystems."""
async with self._lock:
if self._initialized:
logger.warning("SafetyGuardian already initialized")
return
logger.info("Initializing SafetyGuardian")
# Get audit logger
if self._audit_logger is None:
self._audit_logger = await get_audit_logger()
# Initialize subsystems lazily as they're implemented
# For now, we'll import and initialize them when available
self._initialized = True
logger.info("SafetyGuardian initialized")
async def shutdown(self) -> None:
"""Shutdown the SafetyGuardian and all subsystems."""
async with self._lock:
if not self._initialized:
return
logger.info("Shutting down SafetyGuardian")
# Shutdown subsystems
# (Will be implemented as subsystems are added)
self._initialized = False
logger.info("SafetyGuardian shutdown complete")
async def validate(
self,
action: ActionRequest,
policy: SafetyPolicy | None = None,
) -> GuardianResult:
"""
Validate an action before execution.
Runs all safety checks in order:
1. Permission check
2. Cost/budget check
3. Rate limit check
4. Loop detection
5. HITL check (if required)
6. Checkpoint creation (if destructive)
Args:
action: The action to validate
policy: Optional policy override. If None, uses autonomy-level policy.
Returns:
GuardianResult with decision and details
"""
if not self._initialized:
await self.initialize()
if not self._config.enabled:
# Safety disabled - allow everything (NOT RECOMMENDED)
logger.warning("Safety framework disabled - allowing action %s", action.id)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Safety framework disabled"],
)
# Get policy for this action
effective_policy = policy or self._get_policy(action)
reasons: list[str] = []
audit_events = []
try:
# Log action request
if self._audit_logger:
event = await self._audit_logger.log(
AuditEventType.ACTION_REQUESTED,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
details={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"resource": action.resource,
},
correlation_id=action.metadata.correlation_id,
)
audit_events.append(event)
# 1. Permission check
permission_result = await self._check_permissions(action, effective_policy)
if permission_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action, permission_result.reasons, audit_events
)
# 2. Cost/budget check
budget_result = await self._check_budget(action, effective_policy)
if budget_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action, budget_result.reasons, audit_events
)
# 3. Rate limit check
rate_result = await self._check_rate_limit(action, effective_policy)
if rate_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action,
rate_result.reasons,
audit_events,
retry_after=rate_result.retry_after_seconds,
)
if rate_result.decision == SafetyDecision.DELAY:
# Return delay decision
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DELAY,
reasons=rate_result.reasons,
retry_after_seconds=rate_result.retry_after_seconds,
audit_events=audit_events,
)
# 4. Loop detection
loop_result = await self._check_loops(action, effective_policy)
if loop_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action, loop_result.reasons, audit_events
)
# 5. HITL check
hitl_result = await self._check_hitl(action, effective_policy)
if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=hitl_result.reasons,
approval_id=hitl_result.approval_id,
audit_events=audit_events,
)
# 6. Create checkpoint if destructive
checkpoint_id = None
if action.is_destructive and self._config.auto_checkpoint_destructive:
checkpoint_id = await self._create_checkpoint(action)
# All checks passed
reasons.append("All safety checks passed")
if self._audit_logger:
event = await self._audit_logger.log_action_request(
action, SafetyDecision.ALLOW, reasons
)
audit_events.append(event)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=reasons,
checkpoint_id=checkpoint_id,
audit_events=audit_events,
)
except SafetyError as e:
# Known safety error
return await self._create_denial_result(
action, [str(e)], audit_events
)
except Exception as e:
# Unknown error - fail closed in strict mode
logger.error("Unexpected error in safety validation: %s", e)
if self._config.strict_mode:
return await self._create_denial_result(
action,
[f"Safety validation error: {e}"],
audit_events,
)
else:
# Non-strict mode - allow with warning
logger.warning("Non-strict mode: allowing action despite error")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Allowed despite validation error (non-strict mode)"],
audit_events=audit_events,
)
async def record_execution(
self,
action: ActionRequest,
result: ActionResult,
) -> None:
"""
Record action execution result for auditing and tracking.
Args:
action: The executed action
result: The execution result
"""
if self._audit_logger:
await self._audit_logger.log_action_executed(
action,
success=result.success,
execution_time_ms=result.execution_time_ms,
error=result.error,
)
# Update cost tracking
if self._cost_controller:
# Track actual cost
pass
# Update loop detection history
if self._loop_detector:
# Add to action history
pass
async def rollback(self, checkpoint_id: str) -> bool:
"""
Rollback to a checkpoint.
Args:
checkpoint_id: ID of the checkpoint to rollback to
Returns:
True if rollback succeeded
"""
if self._rollback_manager is None:
logger.warning("Rollback manager not available")
return False
# Delegate to rollback manager
return await self._rollback_manager.rollback(checkpoint_id)
async def emergency_stop(
self,
stop_type: str = "kill",
reason: str = "Manual emergency stop",
triggered_by: str = "system",
) -> None:
"""
Trigger emergency stop.
Args:
stop_type: Type of stop (kill, pause, lockdown)
reason: Reason for the stop
triggered_by: Who triggered the stop
"""
logger.critical(
"Emergency stop triggered: type=%s, reason=%s, by=%s",
stop_type,
reason,
triggered_by,
)
if self._audit_logger:
await self._audit_logger.log_emergency_stop(
stop_type=stop_type,
triggered_by=triggered_by,
reason=reason,
)
if self._emergency_controls:
await self._emergency_controls.execute_stop(stop_type)
def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
"""Get the effective policy for an action."""
# Check cached policies
autonomy_level = action.metadata.autonomy_level
if autonomy_level.value not in self._policies:
self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
autonomy_level
)
return self._policies[autonomy_level.value]
async def _check_permissions(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is permitted."""
reasons: list[str] = []
# Check denied tools
if action.tool_name:
for pattern in policy.denied_tools:
if self._matches_pattern(action.tool_name, pattern):
reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
)
# Check allowed tools (if not "*")
if action.tool_name and "*" not in policy.allowed_tools:
allowed = False
for pattern in policy.allowed_tools:
if self._matches_pattern(action.tool_name, pattern):
allowed = True
break
if not allowed:
reasons.append(f"Tool '{action.tool_name}' not in allowed list")
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
)
# Check file patterns
if action.resource:
for pattern in policy.denied_file_patterns:
if self._matches_pattern(action.resource, pattern):
reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Permission check passed"],
)
async def _check_budget(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within budget."""
# TODO: Implement with CostController
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed (not fully implemented)"],
)
async def _check_rate_limit(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within rate limits."""
# TODO: Implement with RateLimiter
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed (not fully implemented)"],
)
async def _check_loops(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check for action loops."""
# TODO: Implement with LoopDetector
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed (not fully implemented)"],
)
async def _check_hitl(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if human approval is required."""
if not self._config.hitl_enabled:
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["HITL disabled"],
)
# Check if action requires approval
requires_approval = False
for pattern in policy.require_approval_for:
if pattern == "*":
requires_approval = True
break
if action.tool_name and self._matches_pattern(action.tool_name, pattern):
requires_approval = True
break
if action.action_type.value and self._matches_pattern(
action.action_type.value, pattern
):
requires_approval = True
break
if requires_approval:
# TODO: Create approval request with HITLManager
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=["Action requires human approval"],
approval_id=None, # Will be set by HITLManager
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["No approval required"],
)
async def _create_checkpoint(self, action: ActionRequest) -> str | None:
"""Create a checkpoint before destructive action."""
if self._rollback_manager is None:
logger.warning("Rollback manager not available - skipping checkpoint")
return None
# TODO: Implement with RollbackManager
return None
async def _create_denial_result(
self,
action: ActionRequest,
reasons: list[str],
audit_events: list[Any],
retry_after: float | None = None,
) -> GuardianResult:
"""Create a denial result with audit logging."""
if self._audit_logger:
event = await self._audit_logger.log_action_request(
action, SafetyDecision.DENY, reasons
)
audit_events.append(event)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
retry_after_seconds=retry_after,
audit_events=audit_events,
)
def _matches_pattern(self, value: str, pattern: str) -> bool:
"""Check if value matches a pattern (supports * wildcard)."""
if pattern == "*":
return True
if "*" not in pattern:
return value == pattern
# Simple wildcard matching
if pattern.startswith("*") and pattern.endswith("*"):
return pattern[1:-1] in value
elif pattern.startswith("*"):
return value.endswith(pattern[1:])
elif pattern.endswith("*"):
return value.startswith(pattern[:-1])
else:
# Pattern like "foo*bar"
parts = pattern.split("*")
if len(parts) == 2:
return value.startswith(parts[0]) and value.endswith(parts[1])
return False
# Singleton instance
_guardian_instance: SafetyGuardian | None = None
_guardian_lock = asyncio.Lock()
async def get_safety_guardian() -> SafetyGuardian:
"""Get the global SafetyGuardian instance."""
global _guardian_instance
async with _guardian_lock:
if _guardian_instance is None:
_guardian_instance = SafetyGuardian()
await _guardian_instance.initialize()
return _guardian_instance
async def shutdown_safety_guardian() -> None:
"""Shutdown the global SafetyGuardian."""
global _guardian_instance
async with _guardian_lock:
if _guardian_instance is not None:
await _guardian_instance.shutdown()
_guardian_instance = None
def reset_safety_guardian() -> None:
"""Reset the SafetyGuardian (for testing)."""
global _guardian_instance
_guardian_instance = None

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1,474 @@
"""
Safety Framework Models
Core Pydantic models for actions, events, policies, and safety decisions.
"""
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
# ============================================================================
# Enums
# ============================================================================
class ActionType(str, Enum):
"""Types of actions that can be performed."""
TOOL_CALL = "tool_call"
FILE_READ = "file_read"
FILE_WRITE = "file_write"
FILE_DELETE = "file_delete"
API_CALL = "api_call"
DATABASE_QUERY = "database_query"
DATABASE_MUTATE = "database_mutate"
GIT_OPERATION = "git_operation"
SHELL_COMMAND = "shell_command"
LLM_CALL = "llm_call"
NETWORK_REQUEST = "network_request"
CUSTOM = "custom"
class ResourceType(str, Enum):
"""Types of resources that can be accessed."""
FILE = "file"
DATABASE = "database"
API = "api"
NETWORK = "network"
GIT = "git"
SHELL = "shell"
LLM = "llm"
MEMORY = "memory"
CUSTOM = "custom"
class PermissionLevel(str, Enum):
"""Permission levels for resource access."""
NONE = "none"
READ = "read"
WRITE = "write"
EXECUTE = "execute"
DELETE = "delete"
ADMIN = "admin"
class AutonomyLevel(str, Enum):
"""Autonomy levels for agent operation."""
FULL_CONTROL = "full_control" # Approve every action
MILESTONE = "milestone" # Approve at milestones
AUTONOMOUS = "autonomous" # Only major decisions
class SafetyDecision(str, Enum):
"""Result of safety validation."""
ALLOW = "allow"
DENY = "deny"
REQUIRE_APPROVAL = "require_approval"
DELAY = "delay"
SANDBOX = "sandbox"
class ApprovalStatus(str, Enum):
"""Status of approval request."""
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
TIMEOUT = "timeout"
CANCELLED = "cancelled"
class AuditEventType(str, Enum):
"""Types of audit events."""
ACTION_REQUESTED = "action_requested"
ACTION_VALIDATED = "action_validated"
ACTION_DENIED = "action_denied"
ACTION_EXECUTED = "action_executed"
ACTION_FAILED = "action_failed"
APPROVAL_REQUESTED = "approval_requested"
APPROVAL_GRANTED = "approval_granted"
APPROVAL_DENIED = "approval_denied"
APPROVAL_TIMEOUT = "approval_timeout"
CHECKPOINT_CREATED = "checkpoint_created"
ROLLBACK_STARTED = "rollback_started"
ROLLBACK_COMPLETED = "rollback_completed"
ROLLBACK_FAILED = "rollback_failed"
BUDGET_WARNING = "budget_warning"
BUDGET_EXCEEDED = "budget_exceeded"
RATE_LIMITED = "rate_limited"
LOOP_DETECTED = "loop_detected"
EMERGENCY_STOP = "emergency_stop"
POLICY_VIOLATION = "policy_violation"
CONTENT_FILTERED = "content_filtered"
# ============================================================================
# Action Models
# ============================================================================
class ActionMetadata(BaseModel):
"""Metadata associated with an action."""
agent_id: str = Field(..., description="ID of the agent performing the action")
project_id: str | None = Field(None, description="ID of the project context")
session_id: str | None = Field(None, description="ID of the current session")
task_id: str | None = Field(None, description="ID of the current task")
parent_action_id: str | None = Field(None, description="ID of the parent action")
correlation_id: str | None = Field(None, description="Correlation ID for tracing")
user_id: str | None = Field(None, description="ID of the user who initiated")
autonomy_level: AutonomyLevel = Field(
default=AutonomyLevel.MILESTONE,
description="Current autonomy level",
)
context: dict[str, Any] = Field(
default_factory=dict,
description="Additional context",
)
class ActionRequest(BaseModel):
"""Request to perform an action."""
id: str = Field(default_factory=lambda: str(uuid4()))
action_type: ActionType = Field(..., description="Type of action to perform")
tool_name: str | None = Field(None, description="Name of the tool to call")
resource: str | None = Field(None, description="Resource being accessed")
resource_type: ResourceType | None = Field(None, description="Type of resource")
arguments: dict[str, Any] = Field(
default_factory=dict,
description="Action arguments",
)
metadata: ActionMetadata = Field(..., description="Action metadata")
estimated_cost_tokens: int = Field(0, description="Estimated token cost")
estimated_cost_usd: float = Field(0.0, description="Estimated USD cost")
is_destructive: bool = Field(False, description="Whether action is destructive")
is_reversible: bool = Field(True, description="Whether action can be rolled back")
timestamp: datetime = Field(default_factory=datetime.utcnow)
class ActionResult(BaseModel):
"""Result of an executed action."""
action_id: str = Field(..., description="ID of the action")
success: bool = Field(..., description="Whether action succeeded")
data: Any = Field(None, description="Action result data")
error: str | None = Field(None, description="Error message if failed")
error_code: str | None = Field(None, description="Error code if failed")
execution_time_ms: float = Field(0.0, description="Execution time in ms")
actual_cost_tokens: int = Field(0, description="Actual token cost")
actual_cost_usd: float = Field(0.0, description="Actual USD cost")
checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
timestamp: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Validation Models
# ============================================================================
class ValidationRule(BaseModel):
"""A single validation rule."""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(..., description="Rule name")
description: str | None = Field(None, description="Rule description")
priority: int = Field(0, description="Rule priority (higher = evaluated first)")
enabled: bool = Field(True, description="Whether rule is enabled")
# Rule conditions
action_types: list[ActionType] | None = Field(
None, description="Action types this rule applies to"
)
tool_patterns: list[str] | None = Field(
None, description="Tool name patterns (supports wildcards)"
)
resource_patterns: list[str] | None = Field(
None, description="Resource patterns (supports wildcards)"
)
agent_ids: list[str] | None = Field(
None, description="Agent IDs this rule applies to"
)
# Rule decision
decision: SafetyDecision = Field(..., description="Decision when rule matches")
reason: str | None = Field(None, description="Reason for decision")
class ValidationResult(BaseModel):
"""Result of action validation."""
action_id: str = Field(..., description="ID of the validated action")
decision: SafetyDecision = Field(..., description="Validation decision")
applied_rules: list[str] = Field(
default_factory=list, description="IDs of applied rules"
)
reasons: list[str] = Field(
default_factory=list, description="Reasons for decision"
)
approval_id: str | None = Field(None, description="Approval request ID if needed")
retry_after_seconds: float | None = Field(
None, description="Retry delay if rate limited"
)
timestamp: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Budget Models
# ============================================================================
class BudgetScope(str, Enum):
"""Scope of a budget limit."""
SESSION = "session"
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
PROJECT = "project"
AGENT = "agent"
class BudgetStatus(BaseModel):
"""Current budget status."""
scope: BudgetScope = Field(..., description="Budget scope")
scope_id: str = Field(..., description="ID within scope (session/agent/project)")
tokens_used: int = Field(0, description="Tokens used in this scope")
tokens_limit: int = Field(100000, description="Token limit for this scope")
cost_used_usd: float = Field(0.0, description="USD spent in this scope")
cost_limit_usd: float = Field(10.0, description="USD limit for this scope")
tokens_remaining: int = Field(0, description="Remaining tokens")
cost_remaining_usd: float = Field(0.0, description="Remaining USD budget")
warning_threshold: float = Field(0.8, description="Warn at this usage fraction")
is_warning: bool = Field(False, description="Whether at warning level")
is_exceeded: bool = Field(False, description="Whether budget exceeded")
reset_at: datetime | None = Field(None, description="When budget resets")
# ============================================================================
# Rate Limit Models
# ============================================================================
class RateLimitConfig(BaseModel):
"""Configuration for a rate limit."""
name: str = Field(..., description="Rate limit name")
limit: int = Field(..., description="Maximum allowed in window")
window_seconds: int = Field(60, description="Time window in seconds")
burst_limit: int | None = Field(None, description="Burst allowance")
slowdown_threshold: float = Field(
0.8, description="Start slowing at this fraction"
)
class RateLimitStatus(BaseModel):
"""Current rate limit status."""
name: str = Field(..., description="Rate limit name")
current_count: int = Field(0, description="Current count in window")
limit: int = Field(..., description="Maximum allowed")
window_seconds: int = Field(..., description="Time window")
remaining: int = Field(..., description="Remaining in window")
reset_at: datetime = Field(..., description="When window resets")
is_limited: bool = Field(False, description="Whether currently limited")
retry_after_seconds: float = Field(0.0, description="Seconds until retry")
# ============================================================================
# Approval Models
# ============================================================================
class ApprovalRequest(BaseModel):
"""Request for human approval."""
id: str = Field(default_factory=lambda: str(uuid4()))
action: ActionRequest = Field(..., description="Action requiring approval")
reason: str = Field(..., description="Why approval is required")
urgency: str = Field("normal", description="Urgency level")
timeout_seconds: int = Field(300, description="Timeout for approval")
created_at: datetime = Field(default_factory=datetime.utcnow)
expires_at: datetime | None = Field(None, description="When request expires")
suggested_action: str | None = Field(None, description="Suggested response")
context: dict[str, Any] = Field(default_factory=dict, description="Extra context")
class ApprovalResponse(BaseModel):
"""Response to an approval request."""
request_id: str = Field(..., description="ID of the approval request")
status: ApprovalStatus = Field(..., description="Approval status")
decided_by: str | None = Field(None, description="Who made the decision")
reason: str | None = Field(None, description="Reason for decision")
modifications: dict[str, Any] | None = Field(
None, description="Modifications to action"
)
decided_at: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Checkpoint/Rollback Models
# ============================================================================
class CheckpointType(str, Enum):
"""Types of checkpoints."""
FILE = "file"
DATABASE = "database"
GIT = "git"
COMPOSITE = "composite"
class Checkpoint(BaseModel):
"""A rollback checkpoint."""
id: str = Field(default_factory=lambda: str(uuid4()))
checkpoint_type: CheckpointType = Field(..., description="Type of checkpoint")
action_id: str = Field(..., description="Action this checkpoint is for")
created_at: datetime = Field(default_factory=datetime.utcnow)
expires_at: datetime | None = Field(None, description="When checkpoint expires")
data: dict[str, Any] = Field(default_factory=dict, description="Checkpoint data")
description: str | None = Field(None, description="Description of checkpoint")
is_valid: bool = Field(True, description="Whether checkpoint is still valid")
class RollbackResult(BaseModel):
"""Result of a rollback operation."""
checkpoint_id: str = Field(..., description="ID of checkpoint rolled back to")
success: bool = Field(..., description="Whether rollback succeeded")
actions_rolled_back: list[str] = Field(
default_factory=list, description="IDs of rolled back actions"
)
failed_actions: list[str] = Field(
default_factory=list, description="IDs of actions that failed to rollback"
)
error: str | None = Field(None, description="Error message if failed")
timestamp: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Audit Models
# ============================================================================
class AuditEvent(BaseModel):
"""An audit log event."""
id: str = Field(default_factory=lambda: str(uuid4()))
event_type: AuditEventType = Field(..., description="Type of audit event")
timestamp: datetime = Field(default_factory=datetime.utcnow)
agent_id: str | None = Field(None, description="Agent ID if applicable")
action_id: str | None = Field(None, description="Action ID if applicable")
project_id: str | None = Field(None, description="Project ID if applicable")
session_id: str | None = Field(None, description="Session ID if applicable")
user_id: str | None = Field(None, description="User ID if applicable")
decision: SafetyDecision | None = Field(None, description="Safety decision")
details: dict[str, Any] = Field(default_factory=dict, description="Event details")
correlation_id: str | None = Field(None, description="Correlation ID for tracing")
# ============================================================================
# Policy Models
# ============================================================================
class SafetyPolicy(BaseModel):
"""A complete safety policy configuration."""
name: str = Field(..., description="Policy name")
description: str | None = Field(None, description="Policy description")
version: str = Field("1.0.0", description="Policy version")
enabled: bool = Field(True, description="Whether policy is enabled")
# Cost controls
max_tokens_per_session: int = Field(100_000, description="Max tokens per session")
max_tokens_per_day: int = Field(1_000_000, description="Max tokens per day")
max_cost_per_session_usd: float = Field(10.0, description="Max USD per session")
max_cost_per_day_usd: float = Field(100.0, description="Max USD per day")
# Rate limits
max_actions_per_minute: int = Field(60, description="Max actions per minute")
max_llm_calls_per_minute: int = Field(20, description="Max LLM calls per minute")
max_file_operations_per_minute: int = Field(
100, description="Max file ops per minute"
)
# Permissions
allowed_tools: list[str] = Field(
default_factory=lambda: ["*"],
description="Allowed tool patterns",
)
denied_tools: list[str] = Field(
default_factory=list,
description="Denied tool patterns",
)
allowed_file_patterns: list[str] = Field(
default_factory=lambda: ["**/*"],
description="Allowed file patterns",
)
denied_file_patterns: list[str] = Field(
default_factory=lambda: ["**/.env", "**/secrets/**"],
description="Denied file patterns",
)
# HITL
require_approval_for: list[str] = Field(
default_factory=lambda: [
"delete_file",
"push_to_remote",
"deploy_to_production",
"modify_critical_config",
],
description="Actions requiring approval",
)
# Loop detection
max_repeated_actions: int = Field(5, description="Max exact repetitions")
max_similar_actions: int = Field(10, description="Max similar actions")
# Sandbox
require_sandbox: bool = Field(False, description="Require sandbox execution")
sandbox_timeout_seconds: int = Field(300, description="Sandbox timeout")
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit")
# Validation rules
validation_rules: list[ValidationRule] = Field(
default_factory=list,
description="Custom validation rules",
)
# ============================================================================
# Guardian Result Models
# ============================================================================
class GuardianResult(BaseModel):
"""Result of SafetyGuardian evaluation."""
action_id: str = Field(..., description="ID of the action")
allowed: bool = Field(..., description="Whether action is allowed")
decision: SafetyDecision = Field(..., description="Safety decision")
reasons: list[str] = Field(default_factory=list, description="Decision reasons")
approval_id: str | None = Field(None, description="Approval ID if needed")
checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
retry_after_seconds: float | None = Field(None, description="Retry delay")
modified_action: ActionRequest | None = Field(
None, description="Modified action if changed"
)
audit_events: list[AuditEvent] = Field(
default_factory=list, description="Generated audit events"
)

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1 @@
"""${dir} module."""