Files
fast-next-template/backend/app/services/safety/guardian.py
Felipe Cardoso 520c06175e refactor(safety): apply consistent formatting across services and tests
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
2026-01-03 16:23:39 +01:00

617 lines
20 KiB
Python

"""
Safety Guardian
Main facade for the safety framework. Orchestrates all safety checks
before, during, and after action execution.
"""
import asyncio
import logging
import re
from typing import Any

from .audit import AuditLogger, get_audit_logger
from .config import (
    SafetyConfig,
    get_policy_for_autonomy_level,
    get_safety_config,
)
from .exceptions import (
    SafetyError,
)
from .models import (
    ActionRequest,
    ActionResult,
    AuditEventType,
    GuardianResult,
    SafetyDecision,
    SafetyPolicy,
)
# Module-scoped logger, named after this module so log output can be
# filtered per component.
logger: logging.Logger = logging.getLogger(__name__)
class SafetyGuardian:
    """
    Central orchestrator for all safety checks.

    The SafetyGuardian is the main entry point for validating agent actions.
    It coordinates multiple safety subsystems:
    - Permission checking
    - Cost/budget control
    - Rate limiting
    - Loop detection
    - Human-in-the-loop approval
    - Rollback/checkpoint management
    - Content filtering
    - Sandbox execution

    Several subsystems are still placeholders here: the budget, rate-limit
    and loop checks always allow, and checkpoint creation always returns
    None.

    Usage:
        guardian = SafetyGuardian()
        await guardian.initialize()

        # Before executing an action
        result = await guardian.validate(action_request)
        if not result.allowed:
            # Handle denial
            ...

        # After action execution
        await guardian.record_execution(action_request, action_result)
    """

    def __init__(
        self,
        config: SafetyConfig | None = None,
        audit_logger: AuditLogger | None = None,
    ) -> None:
        """
        Initialize the SafetyGuardian.

        Args:
            config: Optional safety configuration. If None, it is obtained
                from ``get_safety_config()``.
            audit_logger: Optional audit logger. If None, the global instance
                is fetched via ``get_audit_logger()`` during ``initialize()``.
        """
        # Explicit None check: a caller-supplied config must never be
        # silently replaced just because it evaluates as falsy.
        self._config = config if config is not None else get_safety_config()
        self._audit_logger = audit_logger
        self._initialized = False
        # Serialises initialize()/shutdown() against each other.
        self._lock = asyncio.Lock()

        # Subsystem references (will be initialized lazily)
        self._permission_manager: Any = None
        self._cost_controller: Any = None
        self._rate_limiter: Any = None
        self._loop_detector: Any = None
        self._hitl_manager: Any = None
        self._rollback_manager: Any = None
        self._content_filter: Any = None
        self._sandbox_executor: Any = None
        self._emergency_controls: Any = None

        # Policy cache, keyed by autonomy level value (see _get_policy)
        self._policies: dict[str, SafetyPolicy] = {}
        self._default_policy: SafetyPolicy | None = None

    @property
    def is_initialized(self) -> bool:
        """Check if the guardian is initialized."""
        return self._initialized

    async def initialize(self) -> None:
        """
        Initialize the SafetyGuardian and all subsystems.

        Idempotent: a second call only logs a warning.
        """
        async with self._lock:
            if self._initialized:
                logger.warning("SafetyGuardian already initialized")
                return

            logger.info("Initializing SafetyGuardian")

            # Get audit logger
            if self._audit_logger is None:
                self._audit_logger = await get_audit_logger()

            # Initialize subsystems lazily as they're implemented
            # For now, we'll import and initialize them when available

            self._initialized = True
            logger.info("SafetyGuardian initialized")

    async def shutdown(self) -> None:
        """
        Shutdown the SafetyGuardian and all subsystems.

        A no-op when the guardian is not initialized.
        """
        async with self._lock:
            if not self._initialized:
                return

            logger.info("Shutting down SafetyGuardian")

            # Shutdown subsystems
            # (Will be implemented as subsystems are added)

            self._initialized = False
            logger.info("SafetyGuardian shutdown complete")

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> GuardianResult:
        """
        Validate an action before execution.

        Runs the safety checks in order and stops at the first one that does
        not allow the action:
        1. Permission check
        2. Cost/budget check
        3. Rate limit check (may DENY or DELAY)
        4. Loop detection
        5. HITL check (may REQUIRE_APPROVAL)
        6. Checkpoint creation (if destructive and auto-checkpointing is on)

        Calls ``initialize()`` first if it has not run yet.

        Args:
            action: The action to validate
            policy: Optional policy override. If None, uses the policy for the
                action's autonomy level.

        Returns:
            GuardianResult with decision and details. If the framework is
            disabled in config, every action is allowed. An unexpected error
            denies the action in strict mode and allows it otherwise.
        """
        if not self._initialized:
            await self.initialize()

        if not self._config.enabled:
            # Safety disabled - allow everything (NOT RECOMMENDED)
            logger.warning("Safety framework disabled - allowing action %s", action.id)
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Safety framework disabled"],
            )

        # Get policy for this action (explicit None check, as in __init__)
        effective_policy = policy if policy is not None else self._get_policy(action)

        reasons: list[str] = []
        audit_events = []

        try:
            # Log action request
            if self._audit_logger:
                event = await self._audit_logger.log(
                    AuditEventType.ACTION_REQUESTED,
                    agent_id=action.metadata.agent_id,
                    action_id=action.id,
                    project_id=action.metadata.project_id,
                    session_id=action.metadata.session_id,
                    details={
                        "action_type": action.action_type.value,
                        "tool_name": action.tool_name,
                        "resource": action.resource,
                    },
                    correlation_id=action.metadata.correlation_id,
                )
                audit_events.append(event)

            # 1. Permission check
            permission_result = await self._check_permissions(action, effective_policy)
            if permission_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, permission_result.reasons, audit_events
                )

            # 2. Cost/budget check
            budget_result = await self._check_budget(action, effective_policy)
            if budget_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, budget_result.reasons, audit_events
                )

            # 3. Rate limit check
            rate_result = await self._check_rate_limit(action, effective_policy)
            if rate_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action,
                    rate_result.reasons,
                    audit_events,
                    retry_after=rate_result.retry_after_seconds,
                )
            if rate_result.decision == SafetyDecision.DELAY:
                # Return delay decision
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DELAY,
                    reasons=rate_result.reasons,
                    retry_after_seconds=rate_result.retry_after_seconds,
                    audit_events=audit_events,
                )

            # 4. Loop detection
            loop_result = await self._check_loops(action, effective_policy)
            if loop_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, loop_result.reasons, audit_events
                )

            # 5. HITL check
            hitl_result = await self._check_hitl(action, effective_policy)
            if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.REQUIRE_APPROVAL,
                    reasons=hitl_result.reasons,
                    approval_id=hitl_result.approval_id,
                    audit_events=audit_events,
                )

            # 6. Create checkpoint if destructive
            checkpoint_id = None
            if action.is_destructive and self._config.auto_checkpoint_destructive:
                checkpoint_id = await self._create_checkpoint(action)

            # All checks passed
            reasons.append("All safety checks passed")

            if self._audit_logger:
                event = await self._audit_logger.log_action_request(
                    action, SafetyDecision.ALLOW, reasons
                )
                audit_events.append(event)

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=reasons,
                checkpoint_id=checkpoint_id,
                audit_events=audit_events,
            )

        except SafetyError as e:
            # Known safety error
            return await self._create_denial_result(action, [str(e)], audit_events)

        except Exception as e:
            # Unknown error - fail closed in strict mode.
            # logger.exception records the traceback, which logger.error
            # would drop.
            logger.exception("Unexpected error in safety validation: %s", e)
            if self._config.strict_mode:
                return await self._create_denial_result(
                    action,
                    [f"Safety validation error: {e}"],
                    audit_events,
                )
            else:
                # Non-strict mode - allow with warning
                logger.warning("Non-strict mode: allowing action despite error")
                return GuardianResult(
                    action_id=action.id,
                    allowed=True,
                    decision=SafetyDecision.ALLOW,
                    reasons=["Allowed despite validation error (non-strict mode)"],
                    audit_events=audit_events,
                )

    async def record_execution(
        self,
        action: ActionRequest,
        result: ActionResult,
    ) -> None:
        """
        Record action execution result for auditing and tracking.

        Only audit logging is currently wired up; cost and loop tracking are
        placeholders.

        Args:
            action: The executed action
            result: The execution result
        """
        if self._audit_logger:
            await self._audit_logger.log_action_executed(
                action,
                success=result.success,
                execution_time_ms=result.execution_time_ms,
                error=result.error,
            )

        # Update cost tracking
        if self._cost_controller:
            # Track actual cost
            pass

        # Update loop detection history
        if self._loop_detector:
            # Add to action history
            pass

    async def rollback(self, checkpoint_id: str) -> bool:
        """
        Rollback to a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to rollback to

        Returns:
            True if rollback succeeded; False if no rollback manager is set.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available")
            return False

        # Delegate to rollback manager
        return await self._rollback_manager.rollback(checkpoint_id)

    async def emergency_stop(
        self,
        stop_type: str = "kill",
        reason: str = "Manual emergency stop",
        triggered_by: str = "system",
    ) -> None:
        """
        Trigger emergency stop.

        Always logs at CRITICAL level and to the audit log (if present), then
        delegates to the emergency controls subsystem when one is set.

        Args:
            stop_type: Type of stop (kill, pause, lockdown)
            reason: Reason for the stop
            triggered_by: Who triggered the stop
        """
        logger.critical(
            "Emergency stop triggered: type=%s, reason=%s, by=%s",
            stop_type,
            reason,
            triggered_by,
        )

        if self._audit_logger:
            await self._audit_logger.log_emergency_stop(
                stop_type=stop_type,
                triggered_by=triggered_by,
                reason=reason,
            )

        if self._emergency_controls:
            await self._emergency_controls.execute_stop(stop_type)

    def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
        """Get (and cache) the policy for the action's autonomy level."""
        # Check cached policies
        autonomy_level = action.metadata.autonomy_level
        if autonomy_level.value not in self._policies:
            self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
                autonomy_level
            )
        return self._policies[autonomy_level.value]

    async def _check_permissions(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is permitted.

        Denied tool patterns take precedence over allowed ones. After that,
        the resource is checked against the denied file patterns.
        """
        reasons: list[str] = []

        # Check denied tools
        if action.tool_name:
            for pattern in policy.denied_tools:
                if self._matches_pattern(action.tool_name, pattern):
                    reasons.append(
                        f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
                    )
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        # Check allowed tools (a literal "*" entry allows every tool)
        if action.tool_name and "*" not in policy.allowed_tools:
            allowed = any(
                self._matches_pattern(action.tool_name, pattern)
                for pattern in policy.allowed_tools
            )
            if not allowed:
                reasons.append(f"Tool '{action.tool_name}' not in allowed list")
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=reasons,
                )

        # Check file patterns
        if action.resource:
            for pattern in policy.denied_file_patterns:
                if self._matches_pattern(action.resource, pattern):
                    reasons.append(
                        f"Resource '{action.resource}' denied by pattern '{pattern}'"
                    )
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Permission check passed"],
        )

    async def _check_budget(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if action is within budget (placeholder: always allows)."""
        # TODO: Implement with CostController
        # For now, return allow
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Budget check passed (not fully implemented)"],
        )

    async def _check_rate_limit(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if action is within rate limits (placeholder: always allows)."""
        # TODO: Implement with RateLimiter
        # For now, return allow
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Rate limit check passed (not fully implemented)"],
        )

    async def _check_loops(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check for action loops (placeholder: always allows)."""
        # TODO: Implement with LoopDetector
        # For now, return allow
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Loop check passed (not fully implemented)"],
        )

    async def _check_hitl(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if human approval is required.

        An action needs approval when any ``policy.require_approval_for``
        entry is "*" or matches the tool name or the action type value.
        """
        if not self._config.hitl_enabled:
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["HITL disabled"],
            )

        # Check if action requires approval
        requires_approval = False
        for pattern in policy.require_approval_for:
            if pattern == "*":
                requires_approval = True
                break
            if action.tool_name and self._matches_pattern(action.tool_name, pattern):
                requires_approval = True
                break
            if action.action_type.value and self._matches_pattern(
                action.action_type.value, pattern
            ):
                requires_approval = True
                break

        if requires_approval:
            # TODO: Create approval request with HITLManager
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reasons=["Action requires human approval"],
                approval_id=None,  # Will be set by HITLManager
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["No approval required"],
        )

    async def _create_checkpoint(self, action: ActionRequest) -> str | None:
        """
        Create a checkpoint before destructive action.

        Currently always returns None (RollbackManager not yet integrated).
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available - skipping checkpoint")
            return None

        # TODO: Implement with RollbackManager
        return None

    async def _create_denial_result(
        self,
        action: ActionRequest,
        reasons: list[str],
        audit_events: list[Any],
        retry_after: float | None = None,
    ) -> GuardianResult:
        """
        Create a DENY result, recording the denial in the audit log.

        Note: the audit event (if any) is appended to ``audit_events`` in
        place, and that same list is attached to the returned result.
        """
        if self._audit_logger:
            event = await self._audit_logger.log_action_request(
                action, SafetyDecision.DENY, reasons
            )
            audit_events.append(event)

        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=reasons,
            retry_after_seconds=retry_after,
            audit_events=audit_events,
        )

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """
        Check whether ``value`` matches a glob-style ``pattern``.

        ``*`` matches any run of characters (including none). Every other
        character matches itself literally and case-sensitively, so a pattern
        without ``*`` requires an exact match. Any number of wildcards is
        supported, e.g. ``"db_*_delete*"``.

        Args:
            value: The string to test (tool name, resource, action type).
            pattern: The pattern to match against.

        Returns:
            True if the whole of ``value`` matches ``pattern``.
        """
        if pattern == "*":
            return True
        if "*" not in pattern:
            return value == pattern
        # Escape each literal segment and join with ".*" so only "*" is a
        # wildcard. fullmatch anchors both ends, so segments cannot overlap
        # ("ab*ba" does not match "aba"). Patterns with several wildcards
        # must match correctly: a deny pattern that silently never matched
        # would fail open. DOTALL lets "*" also span newlines.
        regex = ".*".join(re.escape(part) for part in pattern.split("*"))
        return re.fullmatch(regex, value, flags=re.DOTALL) is not None
# Process-wide singleton state. The lock serialises creation and teardown
# of the shared guardian in get_safety_guardian / shutdown_safety_guardian.
_guardian_lock = asyncio.Lock()
_guardian_instance: SafetyGuardian | None = None
async def get_safety_guardian() -> SafetyGuardian:
    """
    Get the global SafetyGuardian instance, creating it on first use.

    Creation and initialization run while holding ``_guardian_lock``, so
    concurrent callers all receive the same instance.

    Returns:
        The initialized process-wide SafetyGuardian.

    Raises:
        Whatever ``SafetyGuardian.initialize()`` raises. In that case no
        instance is stored, and the next call retries.
    """
    global _guardian_instance
    async with _guardian_lock:
        if _guardian_instance is None:
            guardian = SafetyGuardian()
            # Publish only after initialize() succeeds, so a failed start-up
            # does not leave a half-initialized guardian as the singleton.
            await guardian.initialize()
            _guardian_instance = guardian
        return _guardian_instance
async def shutdown_safety_guardian() -> None:
    """
    Shut down and discard the global SafetyGuardian, if one exists.

    The global reference is cleared only after the guardian's own
    ``shutdown()`` returns.
    """
    global _guardian_instance
    async with _guardian_lock:
        current = _guardian_instance
        if current is None:
            return
        await current.shutdown()
        _guardian_instance = None
def reset_safety_guardian() -> None:
    """
    Drop the global SafetyGuardian reference (intended for tests).

    Unlike ``shutdown_safety_guardian``, this does not call ``shutdown()``
    on the existing instance; it only forgets it.
    """
    global _guardian_instance
    # Forget the cached instance; the next get_safety_guardian() builds anew.
    _guardian_instance = None