"""
Safety Guardian

Main facade for the safety framework.
Orchestrates all safety checks before, during, and after action execution.
"""

import asyncio
import logging
from typing import Any

from .audit import AuditLogger, get_audit_logger
from .config import (
    SafetyConfig,
    get_policy_for_autonomy_level,
    get_safety_config,
)
from .costs.controller import CostController
from .exceptions import (
    BudgetExceededError,
    LoopDetectedError,
    RateLimitExceededError,
    SafetyError,
)
from .limits.limiter import RateLimiter
from .loops.detector import LoopDetector
from .models import (
    ActionRequest,
    ActionResult,
    AuditEventType,
    BudgetScope,
    GuardianResult,
    SafetyDecision,
    SafetyPolicy,
)

logger = logging.getLogger(__name__)


class SafetyGuardian:
    """
    Central orchestrator for all safety checks.

    The SafetyGuardian is the main entry point for validating agent actions.
    It coordinates multiple safety subsystems:
    - Permission checking
    - Cost/budget control
    - Rate limiting
    - Loop detection
    - Human-in-the-loop approval
    - Rollback/checkpoint management
    - Content filtering
    - Sandbox execution

    Usage:
        guardian = SafetyGuardian()
        await guardian.initialize()

        # Before executing an action
        result = await guardian.validate(action_request)
        if not result.allowed:
            # Handle denial

        # After action execution
        await guardian.record_execution(action_request, action_result)
    """

    def __init__(
        self,
        config: SafetyConfig | None = None,
        audit_logger: AuditLogger | None = None,
        cost_controller: CostController | None = None,
        rate_limiter: RateLimiter | None = None,
        loop_detector: LoopDetector | None = None,
    ) -> None:
        """
        Initialize the SafetyGuardian.

        Only stores the collaborators; anything left as None is created
        lazily in ``initialize()``.

        Args:
            config: Optional safety configuration. If None, loads from environment.
            audit_logger: Optional audit logger. If None, uses global instance.
            cost_controller: Optional cost controller. If None, creates default.
            rate_limiter: Optional rate limiter. If None, creates default.
            loop_detector: Optional loop detector. If None, creates default.
        """
        self._config = config or get_safety_config()
        self._audit_logger = audit_logger
        self._initialized = False
        # Serializes initialize()/shutdown() so they cannot interleave.
        self._lock = asyncio.Lock()

        # Core safety subsystems (always initialized)
        self._cost_controller: CostController | None = cost_controller
        self._rate_limiter: RateLimiter | None = rate_limiter
        self._loop_detector: LoopDetector | None = loop_detector

        # Optional subsystems (will be initialized when available)
        self._permission_manager: Any = None
        self._hitl_manager: Any = None
        self._rollback_manager: Any = None
        self._content_filter: Any = None
        self._sandbox_executor: Any = None
        self._emergency_controls: Any = None

        # Policy cache, keyed by autonomy-level value (see _get_policy)
        self._policies: dict[str, SafetyPolicy] = {}
        self._default_policy: SafetyPolicy | None = None

    @property
    def is_initialized(self) -> bool:
        """Check if the guardian is initialized."""
        return self._initialized

    @property
    def cost_controller(self) -> CostController | None:
        """Get the cost controller instance."""
        return self._cost_controller

    @property
    def rate_limiter(self) -> RateLimiter | None:
        """Get the rate limiter instance."""
        return self._rate_limiter

    @property
    def loop_detector(self) -> LoopDetector | None:
        """Get the loop detector instance."""
        return self._loop_detector

    async def initialize(self) -> None:
        """
        Initialize the SafetyGuardian and all subsystems.

        Safe to call more than once: a second call logs a warning and
        returns without touching the already-created subsystems.
        """
        async with self._lock:
            if self._initialized:
                logger.warning("SafetyGuardian already initialized")
                return

            logger.info("Initializing SafetyGuardian")

            # Get audit logger
            if self._audit_logger is None:
                self._audit_logger = await get_audit_logger()

            # Initialize core safety subsystems
            if self._cost_controller is None:
                self._cost_controller = CostController()
                logger.debug("Initialized CostController")

            if self._rate_limiter is None:
                self._rate_limiter = RateLimiter()
                logger.debug("Initialized RateLimiter")

            if self._loop_detector is None:
                self._loop_detector = LoopDetector()
                logger.debug("Initialized LoopDetector")

            self._initialized = True
            logger.info(
                "SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
            )

    async def shutdown(self) -> None:
        """
        Shutdown the SafetyGuardian and all subsystems.

        A no-op when the guardian was never initialized.
        """
        async with self._lock:
            if not self._initialized:
                return

            logger.info("Shutting down SafetyGuardian")

            # Shutdown subsystems
            # (Will be implemented as subsystems are added)

            self._initialized = False
            logger.info("SafetyGuardian shutdown complete")

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> GuardianResult:
        """
        Validate an action before execution.

        Runs all safety checks in order, stopping at the first one that
        does not allow the action:
        1. Permission check
        2. Cost/budget check
        3. Rate limit check
        4. Loop detection
        5. HITL check (if required)
        6. Checkpoint creation (if destructive)

        Initializes the guardian on first use. When the framework is
        disabled in the configuration, every action is allowed.

        Args:
            action: The action to validate
            policy: Optional policy override. If None, uses autonomy-level policy.

        Returns:
            GuardianResult with decision and details. A ``SafetyError``
            raised by a check becomes a DENY result; any other exception
            becomes DENY in strict mode and ALLOW otherwise.
        """
        if not self._initialized:
            await self.initialize()

        if not self._config.enabled:
            # Safety disabled - allow everything (NOT RECOMMENDED)
            logger.warning("Safety framework disabled - allowing action %s", action.id)
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Safety framework disabled"],
            )

        # Get policy for this action. Explicit None check so that a
        # caller-supplied policy is never discarded because of its truthiness.
        effective_policy = policy if policy is not None else self._get_policy(action)

        reasons: list[str] = []
        audit_events: list[Any] = []

        try:
            # Log action request
            if self._audit_logger:
                event = await self._audit_logger.log(
                    AuditEventType.ACTION_REQUESTED,
                    agent_id=action.metadata.agent_id,
                    action_id=action.id,
                    project_id=action.metadata.project_id,
                    session_id=action.metadata.session_id,
                    details={
                        "action_type": action.action_type.value,
                        "tool_name": action.tool_name,
                        "resource": action.resource,
                    },
                    correlation_id=action.metadata.correlation_id,
                )
                audit_events.append(event)

            # 1. Permission check
            permission_result = await self._check_permissions(action, effective_policy)
            if permission_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, permission_result.reasons, audit_events
                )

            # 2. Cost/budget check
            budget_result = await self._check_budget(action, effective_policy)
            if budget_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, budget_result.reasons, audit_events
                )

            # 3. Rate limit check
            rate_result = await self._check_rate_limit(action, effective_policy)
            if rate_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action,
                    rate_result.reasons,
                    audit_events,
                    retry_after=rate_result.retry_after_seconds,
                )
            if rate_result.decision == SafetyDecision.DELAY:
                # Return delay decision
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DELAY,
                    reasons=rate_result.reasons,
                    retry_after_seconds=rate_result.retry_after_seconds,
                    audit_events=audit_events,
                )

            # 4. Loop detection
            loop_result = await self._check_loops(action, effective_policy)
            if loop_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, loop_result.reasons, audit_events
                )

            # 5. HITL check
            hitl_result = await self._check_hitl(action, effective_policy)
            if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.REQUIRE_APPROVAL,
                    reasons=hitl_result.reasons,
                    approval_id=hitl_result.approval_id,
                    audit_events=audit_events,
                )

            # 6. Create checkpoint if destructive
            checkpoint_id = None
            if action.is_destructive and self._config.auto_checkpoint_destructive:
                checkpoint_id = await self._create_checkpoint(action)

            # All checks passed
            reasons.append("All safety checks passed")

            if self._audit_logger:
                event = await self._audit_logger.log_action_request(
                    action, SafetyDecision.ALLOW, reasons
                )
                audit_events.append(event)

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=reasons,
                checkpoint_id=checkpoint_id,
                audit_events=audit_events,
            )

        except SafetyError as e:
            # Known safety error
            return await self._create_denial_result(action, [str(e)], audit_events)

        except Exception as e:
            # Unknown error - fail closed in strict mode.
            # logger.exception keeps the traceback, which str(e) alone loses.
            logger.exception("Unexpected error in safety validation: %s", e)
            if self._config.strict_mode:
                return await self._create_denial_result(
                    action,
                    [f"Safety validation error: {e}"],
                    audit_events,
                )
            else:
                # Non-strict mode - allow with warning
                logger.warning("Non-strict mode: allowing action despite error")
                return GuardianResult(
                    action_id=action.id,
                    allowed=True,
                    decision=SafetyDecision.ALLOW,
                    reasons=["Allowed despite validation error (non-strict mode)"],
                    audit_events=audit_events,
                )

    async def record_execution(
        self,
        action: ActionRequest,
        result: ActionResult,
    ) -> None:
        """
        Record action execution result for auditing and tracking.

        Writes the audit entry first, then updates the cost controller,
        rate limiter and loop detector. Failures in those three updates
        are logged as warnings and do not propagate.

        Args:
            action: The executed action
            result: The execution result
        """
        if self._audit_logger:
            await self._audit_logger.log_action_executed(
                action,
                success=result.success,
                execution_time_ms=result.execution_time_ms,
                error=result.error,
            )

        # Update cost tracking
        if self._cost_controller:
            try:
                # Use explicit None check - 0 is a valid cost value
                tokens = (
                    result.actual_cost_tokens
                    if result.actual_cost_tokens is not None
                    else action.estimated_cost_tokens
                )
                cost_usd = (
                    result.actual_cost_usd
                    if result.actual_cost_usd is not None
                    else action.estimated_cost_usd
                )
                await self._cost_controller.record_usage(
                    agent_id=action.metadata.agent_id,
                    session_id=action.metadata.session_id,
                    tokens=tokens,
                    cost_usd=cost_usd,
                )
            except Exception as e:
                logger.warning("Failed to record cost: %s", e)

        # Update rate limiter - consume slots for executed actions
        if self._rate_limiter:
            try:
                await self._rate_limiter.record_action(action)
            except Exception as e:
                logger.warning("Failed to record action in rate limiter: %s", e)

        # Update loop detection history
        if self._loop_detector:
            try:
                await self._loop_detector.record(action)
            except Exception as e:
                logger.warning("Failed to record action in loop detector: %s", e)

    async def rollback(self, checkpoint_id: str) -> bool:
        """
        Rollback to a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to rollback to

        Returns:
            True if rollback succeeded; False when no rollback manager is
            available.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available")
            return False

        # Delegate to rollback manager
        return await self._rollback_manager.rollback(checkpoint_id)

    async def emergency_stop(
        self,
        stop_type: str = "kill",
        reason: str = "Manual emergency stop",
        triggered_by: str = "system",
    ) -> None:
        """
        Trigger emergency stop.

        The stop is always handed to the emergency controls (when present),
        even if writing the audit entry fails; the audit error is then
        re-raised after the stop has been attempted.

        Args:
            stop_type: Type of stop (kill, pause, lockdown)
            reason: Reason for the stop
            triggered_by: Who triggered the stop
        """
        logger.critical(
            "Emergency stop triggered: type=%s, reason=%s, by=%s",
            stop_type,
            reason,
            triggered_by,
        )

        try:
            if self._audit_logger:
                await self._audit_logger.log_emergency_stop(
                    stop_type=stop_type,
                    triggered_by=triggered_by,
                    reason=reason,
                )
        finally:
            # A failing audit write must never prevent the stop itself.
            if self._emergency_controls:
                await self._emergency_controls.execute_stop(stop_type)

    def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
        """Get the effective policy for an action, cached per autonomy level."""
        # Check cached policies
        autonomy_level = action.metadata.autonomy_level

        if autonomy_level.value not in self._policies:
            self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
                autonomy_level
            )

        return self._policies[autonomy_level.value]

    async def _check_permissions(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is permitted.

        Denied-tool patterns are evaluated first, then the allow list
        (skipped when it contains "*"), then denied resource patterns.
        """
        reasons: list[str] = []

        # Check denied tools
        if action.tool_name:
            for pattern in policy.denied_tools:
                if self._matches_pattern(action.tool_name, pattern):
                    reasons.append(
                        f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
                    )
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        # Check allowed tools (if not "*")
        if action.tool_name and "*" not in policy.allowed_tools:
            allowed = any(
                self._matches_pattern(action.tool_name, pattern)
                for pattern in policy.allowed_tools
            )
            if not allowed:
                reasons.append(f"Tool '{action.tool_name}' not in allowed list")
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=reasons,
                )

        # Check file patterns
        if action.resource:
            for pattern in policy.denied_file_patterns:
                if self._matches_pattern(action.resource, pattern):
                    reasons.append(
                        f"Resource '{action.resource}' denied by pattern '{pattern}'"
                    )
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Permission check passed"],
        )

    async def _check_budget(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is within budget.

        When the controller reports no budget, the session status and then
        the daily agent status are consulted to build a specific message;
        otherwise a generic "Budget exceeded" denial is returned.
        """
        if self._cost_controller is None:
            logger.warning("CostController not initialized - skipping budget check")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Budget check skipped (controller not initialized)"],
            )

        agent_id = action.metadata.agent_id
        session_id = action.metadata.session_id

        try:
            # Check if we have budget for this action
            has_budget = await self._cost_controller.check_budget(
                agent_id=agent_id,
                session_id=session_id,
                estimated_tokens=action.estimated_cost_tokens,
                estimated_cost_usd=action.estimated_cost_usd,
            )

            if not has_budget:
                # Get current status for better error message
                if session_id:
                    session_status = await self._cost_controller.get_status(
                        BudgetScope.SESSION, session_id
                    )
                    if session_status and session_status.is_exceeded:
                        return GuardianResult(
                            action_id=action.id,
                            allowed=False,
                            decision=SafetyDecision.DENY,
                            reasons=[
                                f"Session budget exceeded: {session_status.tokens_used}"
                                f"/{session_status.tokens_limit} tokens"
                            ],
                        )

                agent_status = await self._cost_controller.get_status(
                    BudgetScope.DAILY, agent_id
                )
                if agent_status and agent_status.is_exceeded:
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=[
                            f"Daily budget exceeded: {agent_status.tokens_used}"
                            f"/{agent_status.tokens_limit} tokens"
                        ],
                    )

                # Generic budget exceeded
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=["Budget exceeded"],
                )

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Budget check passed"],
            )

        except BudgetExceededError as e:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=[str(e)],
            )

    async def _check_rate_limit(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is within rate limits.

        A limited action whose retry time is in (0, 5.0] seconds yields a
        DELAY decision; any other limited action yields DENY.
        """
        if self._rate_limiter is None:
            logger.warning("RateLimiter not initialized - skipping rate limit check")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Rate limit check skipped (limiter not initialized)"],
            )

        try:
            # Check all applicable rate limits for this action
            allowed, statuses = await self._rate_limiter.check_action(action)

            if not allowed:
                # Find the first exceeded limit for the error message;
                # fall back to the first status (or None) if none is flagged.
                exceeded_status = next(
                    (s for s in statuses if s.is_limited),
                    statuses[0] if statuses else None,
                )

                if exceeded_status:
                    retry_after = exceeded_status.retry_after_seconds

                    # Determine if this is a soft limit (delay) or hard limit (deny)
                    if 0 < retry_after <= 5.0:
                        # Short wait - suggest delay
                        return GuardianResult(
                            action_id=action.id,
                            allowed=False,
                            decision=SafetyDecision.DELAY,
                            reasons=[
                                f"Rate limit '{exceeded_status.name}' exceeded. "
                                f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
                            ],
                            retry_after_seconds=retry_after,
                        )
                    else:
                        # Hard deny
                        return GuardianResult(
                            action_id=action.id,
                            allowed=False,
                            decision=SafetyDecision.DENY,
                            reasons=[
                                f"Rate limit '{exceeded_status.name}' exceeded. "
                                f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
                                f"Retry after {retry_after:.1f}s"
                            ],
                            retry_after_seconds=retry_after,
                        )

                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=["Rate limit exceeded"],
                )

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Rate limit check passed"],
            )

        except RateLimitExceededError as e:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=[str(e)],
                retry_after_seconds=e.retry_after_seconds,
            )

    async def _check_loops(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check for action loops.

        On a detected loop the denial reasons contain the loop type
        followed by the alternatives suggested by ``LoopBreaker``.
        """
        if self._loop_detector is None:
            logger.warning("LoopDetector not initialized - skipping loop check")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Loop check skipped (detector not initialized)"],
            )

        try:
            # Check if this action would create a loop
            is_loop, loop_type = await self._loop_detector.check(action)

            if is_loop:
                # Get suggestions for breaking the loop
                from .loops.detector import LoopBreaker

                suggestions = await LoopBreaker.suggest_alternatives(
                    action, loop_type or "unknown"
                )

                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=[
                        f"Loop detected: {loop_type}",
                        *suggestions,
                    ],
                )

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Loop check passed"],
            )

        except LoopDetectedError as e:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=[str(e)],
            )

    async def _check_hitl(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if human approval is required.

        Approval is required when any pattern in
        ``policy.require_approval_for`` is "*" or matches the tool name or
        the action type value.
        """
        if not self._config.hitl_enabled:
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["HITL disabled"],
            )

        # Check if action requires approval
        requires_approval = False
        for pattern in policy.require_approval_for:
            if pattern == "*":
                requires_approval = True
                break
            if action.tool_name and self._matches_pattern(action.tool_name, pattern):
                requires_approval = True
                break
            if action.action_type.value and self._matches_pattern(
                action.action_type.value, pattern
            ):
                requires_approval = True
                break

        if requires_approval:
            # TODO: Create approval request with HITLManager
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reasons=["Action requires human approval"],
                approval_id=None,  # Will be set by HITLManager
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["No approval required"],
        )

    async def _create_checkpoint(self, action: ActionRequest) -> str | None:
        """
        Create a checkpoint before destructive action.

        Currently always returns None: checkpointing is not implemented yet.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available - skipping checkpoint")
            return None

        # TODO: Implement with RollbackManager
        return None

    async def _create_denial_result(
        self,
        action: ActionRequest,
        reasons: list[str],
        audit_events: list[Any],
        retry_after: float | None = None,
    ) -> GuardianResult:
        """
        Create a denial result with audit logging.

        Appends the audit event for the denial to ``audit_events`` (the
        caller's list is mutated) before building the result.
        """
        if self._audit_logger:
            event = await self._audit_logger.log_action_request(
                action, SafetyDecision.DENY, reasons
            )
            audit_events.append(event)

        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=reasons,
            retry_after_seconds=retry_after,
            audit_events=audit_events,
        )

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """
        Check if value matches a pattern (supports * wildcard).

        ``*`` matches any run of characters, including an empty one, and
        may appear any number of times. All other characters are matched
        literally and case-sensitively.

        Args:
            value: The string to test (tool name, resource, action type).
            pattern: The pattern, e.g. "rm", "file_*", "*.env", "a*b*c".

        Returns:
            True if the whole of ``value`` matches ``pattern``.
        """
        if "*" not in pattern:
            return value == pattern

        # Literal segments between wildcards: the first must be a prefix,
        # the last a suffix, and the rest must appear in order in between.
        head, *middle, tail = pattern.split("*")

        # Prefix and suffix may not overlap inside value
        # (e.g. "ab*ba" must not match "aba").
        if len(value) < len(head) + len(tail):
            return False
        if not value.startswith(head) or not value.endswith(tail):
            return False

        cursor = len(head)
        limit = len(value) - len(tail)
        for segment in middle:
            # Leftmost match leaves the most room for later segments.
            found = value.find(segment, cursor, limit)
            if found == -1:
                return False
            cursor = found + len(segment)

        return True


# Singleton instance
_guardian_instance: SafetyGuardian | None = None
_guardian_lock = asyncio.Lock()


async def get_safety_guardian() -> SafetyGuardian:
    """
    Get the global SafetyGuardian instance.

    Creates and initializes the instance on first use, under the module
    lock so that concurrent callers share one guardian.
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is None:
            _guardian_instance = SafetyGuardian()
            await _guardian_instance.initialize()

    return _guardian_instance


async def shutdown_safety_guardian() -> None:
    """Shutdown the global SafetyGuardian and drop the singleton reference."""
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is not None:
            await _guardian_instance.shutdown()
            _guardian_instance = None


async def reset_safety_guardian() -> None:
    """
    Reset the SafetyGuardian (for testing).

    This is an async function to properly acquire the guardian lock
    and avoid race conditions with get_safety_guardian().
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is not None:
            try:
                await _guardian_instance.shutdown()
            except Exception:  # noqa: S110
                pass  # Ignore errors during test cleanup
        _guardian_instance = None