Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
617 lines
20 KiB
Python
617 lines
20 KiB
Python
"""
Safety Guardian

Main facade for the safety framework. Orchestrates all safety checks
before, during, and after action execution.
"""
|
|
|
|
import asyncio
import logging
import re
from typing import Any

from .audit import AuditLogger, get_audit_logger
from .config import (
    SafetyConfig,
    get_policy_for_autonomy_level,
    get_safety_config,
)
from .exceptions import (
    SafetyError,
)
from .models import (
    ActionRequest,
    ActionResult,
    AuditEventType,
    GuardianResult,
    SafetyDecision,
    SafetyPolicy,
)
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class SafetyGuardian:
    """
    Central orchestrator for all safety checks.

    The SafetyGuardian is the main entry point for validating agent actions.
    It coordinates multiple safety subsystems:
    - Permission checking
    - Cost/budget control
    - Rate limiting
    - Loop detection
    - Human-in-the-loop approval
    - Rollback/checkpoint management
    - Content filtering
    - Sandbox execution

    Note: in this version only the permission check and the HITL pattern
    check contain real logic. The budget, rate-limit and loop checks are
    placeholders that always allow, and checkpoint creation always returns
    ``None``.

    Usage:
        guardian = SafetyGuardian()
        await guardian.initialize()

        # Before executing an action
        result = await guardian.validate(action_request)
        if not result.allowed:
            # Handle denial
            ...

        # After action execution
        await guardian.record_execution(action_request, action_result)
    """

    def __init__(
        self,
        config: SafetyConfig | None = None,
        audit_logger: AuditLogger | None = None,
    ) -> None:
        """
        Initialize the SafetyGuardian.

        Args:
            config: Optional safety configuration. If None,
                ``get_safety_config()`` supplies it.
            audit_logger: Optional audit logger. If None, the global instance
                is fetched via ``get_audit_logger()`` during ``initialize()``.
        """
        # ``is not None`` so an explicitly passed config object is always
        # honored, independent of its truthiness.
        self._config = config if config is not None else get_safety_config()
        self._audit_logger = audit_logger
        self._initialized = False
        self._lock = asyncio.Lock()

        # Subsystem references. Nothing in this class populates them yet; the
        # code paths that consult them treat a falsy/None value as
        # "subsystem unavailable".
        self._permission_manager: Any = None
        self._cost_controller: Any = None
        self._rate_limiter: Any = None
        self._loop_detector: Any = None
        self._hitl_manager: Any = None
        self._rollback_manager: Any = None
        self._content_filter: Any = None
        self._sandbox_executor: Any = None
        self._emergency_controls: Any = None

        # Policy cache, keyed by autonomy-level value (see ``_get_policy``).
        self._policies: dict[str, SafetyPolicy] = {}
        self._default_policy: SafetyPolicy | None = None

    @property
    def is_initialized(self) -> bool:
        """Return True between a completed ``initialize()`` and ``shutdown()``."""
        return self._initialized

    async def initialize(self) -> None:
        """
        Initialize the SafetyGuardian and its subsystems.

        Idempotent: a repeated call only logs a warning. Fetches the global
        audit logger when none was injected in the constructor.
        """
        async with self._lock:
            if self._initialized:
                logger.warning("SafetyGuardian already initialized")
                return

            logger.info("Initializing SafetyGuardian")

            if self._audit_logger is None:
                self._audit_logger = await get_audit_logger()

            # Further subsystems will be initialized here as they are
            # implemented.

            self._initialized = True
            logger.info("SafetyGuardian initialized")

    async def shutdown(self) -> None:
        """
        Shut down the SafetyGuardian and its subsystems.

        A no-op when the guardian is not initialized.
        """
        async with self._lock:
            if not self._initialized:
                return

            logger.info("Shutting down SafetyGuardian")

            # Subsystem shutdown will be added here as subsystems are
            # implemented.

            self._initialized = False
            logger.info("SafetyGuardian shutdown complete")

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> GuardianResult:
        """
        Validate an action before execution.

        Runs the safety checks in order and stops at the first one that does
        not allow the action:
        1. Permission check
        2. Cost/budget check
        3. Rate limit check (may DENY or DELAY)
        4. Loop detection
        5. HITL check (may REQUIRE_APPROVAL)
        6. Checkpoint creation (destructive actions, if configured)

        Args:
            action: The action to validate.
            policy: Optional policy override. If None, the policy for the
                action's autonomy level is used.

        Returns:
            GuardianResult with the decision, reasons and the audit events
            recorded during validation.

        Raises:
            Errors from the lazy ``initialize()`` call or the policy lookup
            propagate unchanged, as do errors raised while audit-logging a
            denial.
        """
        if not self._initialized:
            await self.initialize()

        if not self._config.enabled:
            # Safety disabled - allow everything (NOT RECOMMENDED)
            logger.warning("Safety framework disabled - allowing action %s", action.id)
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Safety framework disabled"],
            )

        # An explicit override wins; ``is not None`` so a passed policy object
        # is never silently replaced because of its truthiness.
        effective_policy = policy if policy is not None else self._get_policy(action)

        reasons: list[str] = []
        audit_events: list[Any] = []

        try:
            # Record the incoming request before any decision is made.
            if self._audit_logger:
                event = await self._audit_logger.log(
                    AuditEventType.ACTION_REQUESTED,
                    agent_id=action.metadata.agent_id,
                    action_id=action.id,
                    project_id=action.metadata.project_id,
                    session_id=action.metadata.session_id,
                    details={
                        "action_type": action.action_type.value,
                        "tool_name": action.tool_name,
                        "resource": action.resource,
                    },
                    correlation_id=action.metadata.correlation_id,
                )
                audit_events.append(event)

            # 1. Permission check
            permission_result = await self._check_permissions(action, effective_policy)
            if permission_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, permission_result.reasons, audit_events
                )

            # 2. Cost/budget check
            budget_result = await self._check_budget(action, effective_policy)
            if budget_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, budget_result.reasons, audit_events
                )

            # 3. Rate limit check
            rate_result = await self._check_rate_limit(action, effective_policy)
            if rate_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action,
                    rate_result.reasons,
                    audit_events,
                    retry_after=rate_result.retry_after_seconds,
                )
            if rate_result.decision == SafetyDecision.DELAY:
                # Not a denial: the caller may retry after the given delay.
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DELAY,
                    reasons=rate_result.reasons,
                    retry_after_seconds=rate_result.retry_after_seconds,
                    audit_events=audit_events,
                )

            # 4. Loop detection
            loop_result = await self._check_loops(action, effective_policy)
            if loop_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, loop_result.reasons, audit_events
                )

            # 5. HITL check
            hitl_result = await self._check_hitl(action, effective_policy)
            if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.REQUIRE_APPROVAL,
                    reasons=hitl_result.reasons,
                    approval_id=hitl_result.approval_id,
                    audit_events=audit_events,
                )

            # 6. Create checkpoint if destructive
            checkpoint_id = None
            if action.is_destructive and self._config.auto_checkpoint_destructive:
                checkpoint_id = await self._create_checkpoint(action)

            # All checks passed
            reasons.append("All safety checks passed")

            if self._audit_logger:
                event = await self._audit_logger.log_action_request(
                    action, SafetyDecision.ALLOW, reasons
                )
                audit_events.append(event)

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=reasons,
                checkpoint_id=checkpoint_id,
                audit_events=audit_events,
            )

        except SafetyError as e:
            # Known safety error: deny, using the error message as the reason.
            return await self._create_denial_result(action, [str(e)], audit_events)
        except Exception as e:
            # Unknown error - fail closed in strict mode. ``logger.exception``
            # keeps the traceback, which is essential for diagnosing a
            # failing safety check.
            logger.exception("Unexpected error in safety validation: %s", e)
            if self._config.strict_mode:
                return await self._create_denial_result(
                    action,
                    [f"Safety validation error: {e}"],
                    audit_events,
                )
            # Non-strict mode - allow with warning
            logger.warning("Non-strict mode: allowing action despite error")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Allowed despite validation error (non-strict mode)"],
                audit_events=audit_events,
            )

    async def record_execution(
        self,
        action: ActionRequest,
        result: ActionResult,
    ) -> None:
        """
        Record an action's execution result for auditing and tracking.

        Currently only writes an audit entry; cost and loop-history tracking
        are placeholders.

        Args:
            action: The executed action.
            result: The execution result.
        """
        if self._audit_logger:
            await self._audit_logger.log_action_executed(
                action,
                success=result.success,
                execution_time_ms=result.execution_time_ms,
                error=result.error,
            )

        # Update cost tracking
        if self._cost_controller:
            # TODO: track actual cost once the CostController is wired in.
            pass

        # Update loop detection history
        if self._loop_detector:
            # TODO: add the action to the LoopDetector history.
            pass

    async def rollback(self, checkpoint_id: str) -> bool:
        """
        Roll back to a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to roll back to.

        Returns:
            False when no rollback manager is configured; otherwise whatever
            the rollback manager's ``rollback`` returns.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available")
            return False

        # Delegate to rollback manager
        return await self._rollback_manager.rollback(checkpoint_id)

    async def emergency_stop(
        self,
        stop_type: str = "kill",
        reason: str = "Manual emergency stop",
        triggered_by: str = "system",
    ) -> None:
        """
        Trigger an emergency stop.

        Logs at CRITICAL level, records an audit event when an audit logger is
        available, then delegates to the emergency controls if configured.
        The stop is executed even if audit logging raises; that error is
        re-raised afterwards.

        Args:
            stop_type: Type of stop (kill, pause, lockdown).
            reason: Reason for the stop.
            triggered_by: Who triggered the stop.
        """
        logger.critical(
            "Emergency stop triggered: type=%s, reason=%s, by=%s",
            stop_type,
            reason,
            triggered_by,
        )

        try:
            if self._audit_logger:
                await self._audit_logger.log_emergency_stop(
                    stop_type=stop_type,
                    triggered_by=triggered_by,
                    reason=reason,
                )
        finally:
            # A failing audit sink must never prevent the stop itself.
            if self._emergency_controls:
                await self._emergency_controls.execute_stop(stop_type)

    def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
        """Return the policy for the action's autonomy level, caching it per level."""
        autonomy_level = action.metadata.autonomy_level
        key = autonomy_level.value

        if key not in self._policies:
            self._policies[key] = get_policy_for_autonomy_level(autonomy_level)

        return self._policies[key]

    async def _check_permissions(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check the action's tool and resource against the policy.

        Denied tool patterns are checked first, then the allowed-tool list
        (skipped when it contains ``"*"``), then denied resource patterns.
        Returns a DENY result for the first violation found, else ALLOW.
        """
        tool_name = action.tool_name

        # Check denied tools
        if tool_name:
            for pattern in policy.denied_tools:
                if self._matches_pattern(tool_name, pattern):
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=[f"Tool '{tool_name}' denied by pattern '{pattern}'"],
                    )

        # Check allowed tools; an allow-list containing "*" permits every tool.
        if tool_name and "*" not in policy.allowed_tools:
            if not any(
                self._matches_pattern(tool_name, pattern)
                for pattern in policy.allowed_tools
            ):
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=[f"Tool '{tool_name}' not in allowed list"],
                )

        # Check file patterns
        resource = action.resource
        if resource:
            for pattern in policy.denied_file_patterns:
                if self._matches_pattern(resource, pattern):
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=[f"Resource '{resource}' denied by pattern '{pattern}'"],
                    )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Permission check passed"],
        )

    async def _check_budget(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if the action is within budget (placeholder: always ALLOW)."""
        # TODO: Implement with CostController
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Budget check passed (not fully implemented)"],
        )

    async def _check_rate_limit(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if the action is within rate limits (placeholder: always ALLOW)."""
        # TODO: Implement with RateLimiter
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Rate limit check passed (not fully implemented)"],
        )

    async def _check_loops(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check for action loops (placeholder: always ALLOW)."""
        # TODO: Implement with LoopDetector
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Loop check passed (not fully implemented)"],
        )

    async def _check_hitl(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check whether human approval is required.

        Approval is required when HITL is enabled and any
        ``policy.require_approval_for`` pattern is ``"*"`` or matches the
        action's tool name or action-type value.
        """
        if not self._config.hitl_enabled:
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["HITL disabled"],
            )

        tool_name = action.tool_name
        action_type = action.action_type.value
        requires_approval = any(
            pattern == "*"
            or (tool_name and self._matches_pattern(tool_name, pattern))
            or (action_type and self._matches_pattern(action_type, pattern))
            for pattern in policy.require_approval_for
        )

        if requires_approval:
            # TODO: Create approval request with HITLManager
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reasons=["Action requires human approval"],
                approval_id=None,  # Will be set by HITLManager
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["No approval required"],
        )

    async def _create_checkpoint(self, action: ActionRequest) -> str | None:
        """
        Create a checkpoint before a destructive action.

        Not implemented yet: always returns None, and logs a warning when no
        rollback manager is configured.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available - skipping checkpoint")
            return None

        # TODO: Implement with RollbackManager
        return None

    async def _create_denial_result(
        self,
        action: ActionRequest,
        reasons: list[str],
        audit_events: list[Any],
        retry_after: float | None = None,
    ) -> GuardianResult:
        """
        Build a DENY result, recording the denial in the audit log if available.

        The audit event (when logged) is appended to ``audit_events`` in place,
        and the same list is attached to the result.
        """
        if self._audit_logger:
            event = await self._audit_logger.log_action_request(
                action, SafetyDecision.DENY, reasons
            )
            audit_events.append(event)

        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=reasons,
            retry_after_seconds=retry_after,
            audit_events=audit_events,
        )

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """
        Return True if ``value`` matches ``pattern`` in full.

        Only ``*`` is special: it matches any run of characters (possibly
        empty, including ``/``). Every other character - including ``?``,
        ``[`` and regex metacharacters - matches literally. Any number of
        ``*`` is supported, e.g. ``"**/.env"`` matches ``"app/.env"`` and
        ``"foo*bar*baz"`` matches ``"foo-1-bar-2-baz"``.
        """
        if pattern == "*":
            return True

        if "*" not in pattern:
            return value == pattern

        # Escape the literal pieces and join them with ".*". Previously only
        # a single leading/trailing/middle "*" was understood, so patterns
        # with several wildcards never matched (denied patterns failed
        # open), and "pre*suf" let the prefix and suffix overlap.
        regex = ".*".join(re.escape(part) for part in pattern.split("*"))
        return re.fullmatch(regex, value, flags=re.DOTALL) is not None
|
|
|
|
|
|
# Process-wide singleton: created lazily by get_safety_guardian() and cleared
# by shutdown_safety_guardian() / reset_safety_guardian().
_guardian_instance: SafetyGuardian | None = None
# Serializes singleton creation and shutdown across coroutines.
_guardian_lock: asyncio.Lock = asyncio.Lock()
|
|
|
|
|
|
async def get_safety_guardian() -> SafetyGuardian:
    """
    Return the process-wide SafetyGuardian, creating it on first use.

    Creation and initialization happen under ``_guardian_lock``. The instance
    is published to the module-level slot only after ``initialize()``
    succeeds. If initialization raises, the error reaches the caller and the
    next call retries, instead of handing out an uninitialized guardian.
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is None:
            guardian = SafetyGuardian()
            await guardian.initialize()
            # Publish only after successful initialization.
            _guardian_instance = guardian
        return _guardian_instance
|
|
|
|
|
|
async def shutdown_safety_guardian() -> None:
    """
    Tear down the process-wide SafetyGuardian, if one has been created.

    When no guardian exists the call does nothing. The module-level reference
    is cleared only after the guardian's ``shutdown()`` returns.
    """
    global _guardian_instance

    async with _guardian_lock:
        current = _guardian_instance
        if current is None:
            return
        await current.shutdown()
        _guardian_instance = None
|
|
|
|
|
|
def reset_safety_guardian() -> None:
    """
    Forget the cached global SafetyGuardian (intended for tests).

    Only the module-level reference is dropped; the instance is not shut
    down and the lock is not taken.
    """
    global _guardian_instance
    _guardian_instance = None
|