""" Safety Guardian Main facade for the safety framework. Orchestrates all safety checks before, during, and after action execution. """ import asyncio import logging from typing import Any from .audit import AuditLogger, get_audit_logger from .config import ( SafetyConfig, get_policy_for_autonomy_level, get_safety_config, ) from .exceptions import ( SafetyError, ) from .models import ( ActionRequest, ActionResult, AuditEventType, GuardianResult, SafetyDecision, SafetyPolicy, ) logger = logging.getLogger(__name__) class SafetyGuardian: """ Central orchestrator for all safety checks. The SafetyGuardian is the main entry point for validating agent actions. It coordinates multiple safety subsystems: - Permission checking - Cost/budget control - Rate limiting - Loop detection - Human-in-the-loop approval - Rollback/checkpoint management - Content filtering - Sandbox execution Usage: guardian = SafetyGuardian() await guardian.initialize() # Before executing an action result = await guardian.validate(action_request) if not result.allowed: # Handle denial # After action execution await guardian.record_execution(action_request, action_result) """ def __init__( self, config: SafetyConfig | None = None, audit_logger: AuditLogger | None = None, ) -> None: """ Initialize the SafetyGuardian. Args: config: Optional safety configuration. If None, loads from environment. audit_logger: Optional audit logger. If None, uses global instance. """ self._config = config or get_safety_config() self._audit_logger = audit_logger self._initialized = False self._lock = asyncio.Lock() # Subsystem references (will be initialized lazily) self._permission_manager: Any = None self._cost_controller: Any = None self._rate_limiter: Any = None self._loop_detector: Any = None self._hitl_manager: Any = None self._rollback_manager: Any = None self._content_filter: Any = None self._sandbox_executor: Any = None self._emergency_controls: Any = None # Policy cache self._policies: dict[str, SafetyPolicy] = {} self._default_policy: SafetyPolicy | None = None @property def is_initialized(self) -> bool: """Check if the guardian is initialized.""" return self._initialized async def initialize(self) -> None: """Initialize the SafetyGuardian and all subsystems.""" async with self._lock: if self._initialized: logger.warning("SafetyGuardian already initialized") return logger.info("Initializing SafetyGuardian") # Get audit logger if self._audit_logger is None: self._audit_logger = await get_audit_logger() # Initialize subsystems lazily as they're implemented # For now, we'll import and initialize them when available self._initialized = True logger.info("SafetyGuardian initialized") async def shutdown(self) -> None: """Shutdown the SafetyGuardian and all subsystems.""" async with self._lock: if not self._initialized: return logger.info("Shutting down SafetyGuardian") # Shutdown subsystems # (Will be implemented as subsystems are added) self._initialized = False logger.info("SafetyGuardian shutdown complete") async def validate( self, action: ActionRequest, policy: SafetyPolicy | None = None, ) -> GuardianResult: """ Validate an action before execution. Runs all safety checks in order: 1. Permission check 2. Cost/budget check 3. Rate limit check 4. Loop detection 5. HITL check (if required) 6. Checkpoint creation (if destructive) Args: action: The action to validate policy: Optional policy override. If None, uses autonomy-level policy. Returns: GuardianResult with decision and details """ if not self._initialized: await self.initialize() if not self._config.enabled: # Safety disabled - allow everything (NOT RECOMMENDED) logger.warning("Safety framework disabled - allowing action %s", action.id) return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["Safety framework disabled"], ) # Get policy for this action effective_policy = policy or self._get_policy(action) reasons: list[str] = [] audit_events = [] try: # Log action request if self._audit_logger: event = await self._audit_logger.log( AuditEventType.ACTION_REQUESTED, agent_id=action.metadata.agent_id, action_id=action.id, project_id=action.metadata.project_id, session_id=action.metadata.session_id, details={ "action_type": action.action_type.value, "tool_name": action.tool_name, "resource": action.resource, }, correlation_id=action.metadata.correlation_id, ) audit_events.append(event) # 1. Permission check permission_result = await self._check_permissions(action, effective_policy) if permission_result.decision == SafetyDecision.DENY: return await self._create_denial_result( action, permission_result.reasons, audit_events ) # 2. Cost/budget check budget_result = await self._check_budget(action, effective_policy) if budget_result.decision == SafetyDecision.DENY: return await self._create_denial_result( action, budget_result.reasons, audit_events ) # 3. Rate limit check rate_result = await self._check_rate_limit(action, effective_policy) if rate_result.decision == SafetyDecision.DENY: return await self._create_denial_result( action, rate_result.reasons, audit_events, retry_after=rate_result.retry_after_seconds, ) if rate_result.decision == SafetyDecision.DELAY: # Return delay decision return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.DELAY, reasons=rate_result.reasons, retry_after_seconds=rate_result.retry_after_seconds, audit_events=audit_events, ) # 4. Loop detection loop_result = await self._check_loops(action, effective_policy) if loop_result.decision == SafetyDecision.DENY: return await self._create_denial_result( action, loop_result.reasons, audit_events ) # 5. HITL check hitl_result = await self._check_hitl(action, effective_policy) if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL: return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.REQUIRE_APPROVAL, reasons=hitl_result.reasons, approval_id=hitl_result.approval_id, audit_events=audit_events, ) # 6. Create checkpoint if destructive checkpoint_id = None if action.is_destructive and self._config.auto_checkpoint_destructive: checkpoint_id = await self._create_checkpoint(action) # All checks passed reasons.append("All safety checks passed") if self._audit_logger: event = await self._audit_logger.log_action_request( action, SafetyDecision.ALLOW, reasons ) audit_events.append(event) return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=reasons, checkpoint_id=checkpoint_id, audit_events=audit_events, ) except SafetyError as e: # Known safety error return await self._create_denial_result(action, [str(e)], audit_events) except Exception as e: # Unknown error - fail closed in strict mode logger.error("Unexpected error in safety validation: %s", e) if self._config.strict_mode: return await self._create_denial_result( action, [f"Safety validation error: {e}"], audit_events, ) else: # Non-strict mode - allow with warning logger.warning("Non-strict mode: allowing action despite error") return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["Allowed despite validation error (non-strict mode)"], audit_events=audit_events, ) async def record_execution( self, action: ActionRequest, result: ActionResult, ) -> None: """ Record action execution result for auditing and tracking. Args: action: The executed action result: The execution result """ if self._audit_logger: await self._audit_logger.log_action_executed( action, success=result.success, execution_time_ms=result.execution_time_ms, error=result.error, ) # Update cost tracking if self._cost_controller: # Track actual cost pass # Update loop detection history if self._loop_detector: # Add to action history pass async def rollback(self, checkpoint_id: str) -> bool: """ Rollback to a checkpoint. Args: checkpoint_id: ID of the checkpoint to rollback to Returns: True if rollback succeeded """ if self._rollback_manager is None: logger.warning("Rollback manager not available") return False # Delegate to rollback manager return await self._rollback_manager.rollback(checkpoint_id) async def emergency_stop( self, stop_type: str = "kill", reason: str = "Manual emergency stop", triggered_by: str = "system", ) -> None: """ Trigger emergency stop. Args: stop_type: Type of stop (kill, pause, lockdown) reason: Reason for the stop triggered_by: Who triggered the stop """ logger.critical( "Emergency stop triggered: type=%s, reason=%s, by=%s", stop_type, reason, triggered_by, ) if self._audit_logger: await self._audit_logger.log_emergency_stop( stop_type=stop_type, triggered_by=triggered_by, reason=reason, ) if self._emergency_controls: await self._emergency_controls.execute_stop(stop_type) def _get_policy(self, action: ActionRequest) -> SafetyPolicy: """Get the effective policy for an action.""" # Check cached policies autonomy_level = action.metadata.autonomy_level if autonomy_level.value not in self._policies: self._policies[autonomy_level.value] = get_policy_for_autonomy_level( autonomy_level ) return self._policies[autonomy_level.value] async def _check_permissions( self, action: ActionRequest, policy: SafetyPolicy, ) -> GuardianResult: """Check if action is permitted.""" reasons: list[str] = [] # Check denied tools if action.tool_name: for pattern in policy.denied_tools: if self._matches_pattern(action.tool_name, pattern): reasons.append( f"Tool '{action.tool_name}' denied by pattern '{pattern}'" ) return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.DENY, reasons=reasons, ) # Check allowed tools (if not "*") if action.tool_name and "*" not in policy.allowed_tools: allowed = False for pattern in policy.allowed_tools: if self._matches_pattern(action.tool_name, pattern): allowed = True break if not allowed: reasons.append(f"Tool '{action.tool_name}' not in allowed list") return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.DENY, reasons=reasons, ) # Check file patterns if action.resource: for pattern in policy.denied_file_patterns: if self._matches_pattern(action.resource, pattern): reasons.append( f"Resource '{action.resource}' denied by pattern '{pattern}'" ) return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.DENY, reasons=reasons, ) return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["Permission check passed"], ) async def _check_budget( self, action: ActionRequest, policy: SafetyPolicy, ) -> GuardianResult: """Check if action is within budget.""" # TODO: Implement with CostController # For now, return allow return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["Budget check passed (not fully implemented)"], ) async def _check_rate_limit( self, action: ActionRequest, policy: SafetyPolicy, ) -> GuardianResult: """Check if action is within rate limits.""" # TODO: Implement with RateLimiter # For now, return allow return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["Rate limit check passed (not fully implemented)"], ) async def _check_loops( self, action: ActionRequest, policy: SafetyPolicy, ) -> GuardianResult: """Check for action loops.""" # TODO: Implement with LoopDetector # For now, return allow return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["Loop check passed (not fully implemented)"], ) async def _check_hitl( self, action: ActionRequest, policy: SafetyPolicy, ) -> GuardianResult: """Check if human approval is required.""" if not self._config.hitl_enabled: return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["HITL disabled"], ) # Check if action requires approval requires_approval = False for pattern in policy.require_approval_for: if pattern == "*": requires_approval = True break if action.tool_name and self._matches_pattern(action.tool_name, pattern): requires_approval = True break if action.action_type.value and self._matches_pattern( action.action_type.value, pattern ): requires_approval = True break if requires_approval: # TODO: Create approval request with HITLManager return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.REQUIRE_APPROVAL, reasons=["Action requires human approval"], approval_id=None, # Will be set by HITLManager ) return GuardianResult( action_id=action.id, allowed=True, decision=SafetyDecision.ALLOW, reasons=["No approval required"], ) async def _create_checkpoint(self, action: ActionRequest) -> str | None: """Create a checkpoint before destructive action.""" if self._rollback_manager is None: logger.warning("Rollback manager not available - skipping checkpoint") return None # TODO: Implement with RollbackManager return None async def _create_denial_result( self, action: ActionRequest, reasons: list[str], audit_events: list[Any], retry_after: float | None = None, ) -> GuardianResult: """Create a denial result with audit logging.""" if self._audit_logger: event = await self._audit_logger.log_action_request( action, SafetyDecision.DENY, reasons ) audit_events.append(event) return GuardianResult( action_id=action.id, allowed=False, decision=SafetyDecision.DENY, reasons=reasons, retry_after_seconds=retry_after, audit_events=audit_events, ) def _matches_pattern(self, value: str, pattern: str) -> bool: """Check if value matches a pattern (supports * wildcard).""" if pattern == "*": return True if "*" not in pattern: return value == pattern # Simple wildcard matching if pattern.startswith("*") and pattern.endswith("*"): return pattern[1:-1] in value elif pattern.startswith("*"): return value.endswith(pattern[1:]) elif pattern.endswith("*"): return value.startswith(pattern[:-1]) else: # Pattern like "foo*bar" parts = pattern.split("*") if len(parts) == 2: return value.startswith(parts[0]) and value.endswith(parts[1]) return False # Singleton instance _guardian_instance: SafetyGuardian | None = None _guardian_lock = asyncio.Lock() async def get_safety_guardian() -> SafetyGuardian: """Get the global SafetyGuardian instance.""" global _guardian_instance async with _guardian_lock: if _guardian_instance is None: _guardian_instance = SafetyGuardian() await _guardian_instance.initialize() return _guardian_instance async def shutdown_safety_guardian() -> None: """Shutdown the global SafetyGuardian.""" global _guardian_instance async with _guardian_lock: if _guardian_instance is not None: await _guardian_instance.shutdown() _guardian_instance = None def reset_safety_guardian() -> None: """Reset the SafetyGuardian (for testing).""" global _guardian_instance _guardian_instance = None