forked from cardosofelipe/fast-next-template
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation. - Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming. - Refactored `CostController`'s session and daily budget alert handling for improved clarity. - Implemented test suites for `CostController` and `SafetyGuardian` to validate changes. - Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
865 lines
30 KiB
Python
865 lines
30 KiB
Python
"""
|
|
Safety Guardian
|
|
|
|
Main facade for the safety framework. Orchestrates all safety checks
|
|
before, during, and after action execution.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
|
|
from .audit import AuditLogger, get_audit_logger
|
|
from .config import (
|
|
SafetyConfig,
|
|
get_policy_for_autonomy_level,
|
|
get_safety_config,
|
|
)
|
|
from .costs.controller import CostController
|
|
from .exceptions import (
|
|
BudgetExceededError,
|
|
LoopDetectedError,
|
|
RateLimitExceededError,
|
|
SafetyError,
|
|
)
|
|
from .limits.limiter import RateLimiter
|
|
from .loops.detector import LoopDetector
|
|
from .models import (
|
|
ActionRequest,
|
|
ActionResult,
|
|
AuditEventType,
|
|
BudgetScope,
|
|
GuardianResult,
|
|
SafetyDecision,
|
|
SafetyPolicy,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SafetyGuardian:
    """
    Central orchestrator for all safety checks.

    The SafetyGuardian is the main entry point for validating agent actions.
    It coordinates multiple safety subsystems:
    - Permission checking
    - Cost/budget control
    - Rate limiting
    - Loop detection
    - Human-in-the-loop approval
    - Rollback/checkpoint management
    - Content filtering
    - Sandbox execution

    Only the cost controller, rate limiter and loop detector are created by
    ``initialize()``. The remaining subsystem slots (permission manager, HITL
    manager, rollback manager, content filter, sandbox executor, emergency
    controls) start as ``None`` and the methods below degrade gracefully
    while they stay unset.

    Usage:
        guardian = SafetyGuardian()
        await guardian.initialize()

        # Before executing an action
        result = await guardian.validate(action_request)
        if not result.allowed:
            # Handle denial

        # After action execution
        await guardian.record_execution(action_request, action_result)
    """

    # Longest rate-limit wait (seconds) that _check_rate_limit reports as a
    # DELAY suggestion; anything longer (or non-positive) becomes a DENY.
    _MAX_DELAY_SECONDS = 5.0

    def __init__(
        self,
        config: SafetyConfig | None = None,
        audit_logger: AuditLogger | None = None,
        cost_controller: CostController | None = None,
        rate_limiter: RateLimiter | None = None,
        loop_detector: LoopDetector | None = None,
    ) -> None:
        """
        Initialize the SafetyGuardian.

        No subsystem is created here; missing ones are filled in by
        ``initialize()``.

        Args:
            config: Optional safety configuration. If None, loads from environment.
            audit_logger: Optional audit logger. If None, uses global instance.
            cost_controller: Optional cost controller. If None, creates default.
            rate_limiter: Optional rate limiter. If None, creates default.
            loop_detector: Optional loop detector. If None, creates default.
        """
        self._config = config or get_safety_config()
        self._audit_logger = audit_logger
        self._initialized = False
        # Serializes initialize() and shutdown() against each other.
        self._lock = asyncio.Lock()

        # Core safety subsystems (always initialized)
        self._cost_controller: CostController | None = cost_controller
        self._rate_limiter: RateLimiter | None = rate_limiter
        self._loop_detector: LoopDetector | None = loop_detector

        # Optional subsystems (will be initialized when available)
        self._permission_manager: Any = None
        self._hitl_manager: Any = None
        self._rollback_manager: Any = None
        self._content_filter: Any = None
        self._sandbox_executor: Any = None
        self._emergency_controls: Any = None

        # Policy cache, keyed by autonomy_level.value (see _get_policy)
        self._policies: dict[str, SafetyPolicy] = {}
        self._default_policy: SafetyPolicy | None = None

    @property
    def is_initialized(self) -> bool:
        """Check if the guardian is initialized."""
        return self._initialized

    @property
    def cost_controller(self) -> CostController | None:
        """Get the cost controller instance (None until injected or initialized)."""
        return self._cost_controller

    @property
    def rate_limiter(self) -> RateLimiter | None:
        """Get the rate limiter instance (None until injected or initialized)."""
        return self._rate_limiter

    @property
    def loop_detector(self) -> LoopDetector | None:
        """Get the loop detector instance (None until injected or initialized)."""
        return self._loop_detector

    async def initialize(self) -> None:
        """
        Initialize the SafetyGuardian and all subsystems.

        Fills in the audit logger and any core subsystem that was not
        injected through the constructor. Calling it again while already
        initialized only logs a warning.
        """
        async with self._lock:
            if self._initialized:
                logger.warning("SafetyGuardian already initialized")
                return

            logger.info("Initializing SafetyGuardian")

            # Get audit logger
            if self._audit_logger is None:
                self._audit_logger = await get_audit_logger()

            # Initialize core safety subsystems
            if self._cost_controller is None:
                self._cost_controller = CostController()
                logger.debug("Initialized CostController")

            if self._rate_limiter is None:
                self._rate_limiter = RateLimiter()
                logger.debug("Initialized RateLimiter")

            if self._loop_detector is None:
                self._loop_detector = LoopDetector()
                logger.debug("Initialized LoopDetector")

            self._initialized = True
            logger.info(
                "SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
            )

    async def shutdown(self) -> None:
        """
        Shutdown the SafetyGuardian and all subsystems.

        Currently this only clears the initialized flag; the subsystem
        references are kept. A no-op when not initialized.
        """
        async with self._lock:
            if not self._initialized:
                return

            logger.info("Shutting down SafetyGuardian")

            # Shutdown subsystems
            # (Will be implemented as subsystems are added)

            self._initialized = False
            logger.info("SafetyGuardian shutdown complete")

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> GuardianResult:
        """
        Validate an action before execution.

        Runs all safety checks in order, returning at the first check that
        does not allow the action:
        1. Permission check
        2. Cost/budget check
        3. Rate limit check
        4. Loop detection
        5. HITL check (if required)
        6. Checkpoint creation (if destructive)

        The guardian is initialized lazily if needed. When the framework is
        disabled in the config, every action is allowed without running any
        check. DENY and ALLOW decisions are written to the audit log; DELAY
        and REQUIRE_APPROVAL results are returned with only the initial
        ACTION_REQUESTED event attached.

        Error handling: a ``SafetyError`` raised by any check becomes a
        denial. Any other exception becomes a denial in strict mode and an
        ALLOW (fail-open, with a warning) in non-strict mode.

        Args:
            action: The action to validate
            policy: Optional policy override. If None, uses autonomy-level policy.

        Returns:
            GuardianResult with decision and details
        """
        if not self._initialized:
            await self.initialize()

        if not self._config.enabled:
            # Safety disabled - allow everything (NOT RECOMMENDED)
            logger.warning("Safety framework disabled - allowing action %s", action.id)
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Safety framework disabled"],
            )

        # Get policy for this action
        effective_policy = policy or self._get_policy(action)

        reasons: list[str] = []
        audit_events: list[Any] = []

        try:
            # Log action request
            if self._audit_logger:
                event = await self._audit_logger.log(
                    AuditEventType.ACTION_REQUESTED,
                    agent_id=action.metadata.agent_id,
                    action_id=action.id,
                    project_id=action.metadata.project_id,
                    session_id=action.metadata.session_id,
                    details={
                        "action_type": action.action_type.value,
                        "tool_name": action.tool_name,
                        "resource": action.resource,
                    },
                    correlation_id=action.metadata.correlation_id,
                )
                audit_events.append(event)

            # 1. Permission check
            permission_result = await self._check_permissions(action, effective_policy)
            if permission_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, permission_result.reasons, audit_events
                )

            # 2. Cost/budget check
            budget_result = await self._check_budget(action, effective_policy)
            if budget_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, budget_result.reasons, audit_events
                )

            # 3. Rate limit check
            rate_result = await self._check_rate_limit(action, effective_policy)
            if rate_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action,
                    rate_result.reasons,
                    audit_events,
                    retry_after=rate_result.retry_after_seconds,
                )
            if rate_result.decision == SafetyDecision.DELAY:
                # Return delay decision
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DELAY,
                    reasons=rate_result.reasons,
                    retry_after_seconds=rate_result.retry_after_seconds,
                    audit_events=audit_events,
                )

            # 4. Loop detection
            loop_result = await self._check_loops(action, effective_policy)
            if loop_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, loop_result.reasons, audit_events
                )

            # 5. HITL check
            hitl_result = await self._check_hitl(action, effective_policy)
            if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.REQUIRE_APPROVAL,
                    reasons=hitl_result.reasons,
                    approval_id=hitl_result.approval_id,
                    audit_events=audit_events,
                )

            # 6. Create checkpoint if destructive
            checkpoint_id = None
            if action.is_destructive and self._config.auto_checkpoint_destructive:
                checkpoint_id = await self._create_checkpoint(action)

            # All checks passed
            reasons.append("All safety checks passed")

            if self._audit_logger:
                event = await self._audit_logger.log_action_request(
                    action, SafetyDecision.ALLOW, reasons
                )
                audit_events.append(event)

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=reasons,
                checkpoint_id=checkpoint_id,
                audit_events=audit_events,
            )

        except SafetyError as e:
            # Known safety error
            return await self._create_denial_result(action, [str(e)], audit_events)
        except Exception as e:
            # Unknown error - fail closed in strict mode.
            # logger.exception keeps the traceback, which a plain error()
            # call with only str(e) would lose.
            logger.exception("Unexpected error in safety validation: %s", e)
            if self._config.strict_mode:
                return await self._create_denial_result(
                    action,
                    [f"Safety validation error: {e}"],
                    audit_events,
                )
            else:
                # Non-strict mode - allow with warning
                logger.warning("Non-strict mode: allowing action despite error")
                return GuardianResult(
                    action_id=action.id,
                    allowed=True,
                    decision=SafetyDecision.ALLOW,
                    reasons=["Allowed despite validation error (non-strict mode)"],
                    audit_events=audit_events,
                )

    async def record_execution(
        self,
        action: ActionRequest,
        result: ActionResult,
    ) -> None:
        """
        Record action execution result for auditing and tracking.

        Writes the audit entry first, then updates cost tracking, the rate
        limiter and the loop detector. Failures in those three trackers are
        logged as warnings and do not propagate; an audit-logger failure does
        propagate. Subsystems that are None are skipped.

        Args:
            action: The executed action
            result: The execution result
        """
        if self._audit_logger:
            await self._audit_logger.log_action_executed(
                action,
                success=result.success,
                execution_time_ms=result.execution_time_ms,
                error=result.error,
            )

        # Update cost tracking
        if self._cost_controller:
            try:
                # Use explicit None check - 0 is a valid cost value
                tokens = (
                    result.actual_cost_tokens
                    if result.actual_cost_tokens is not None
                    else action.estimated_cost_tokens
                )
                cost_usd = (
                    result.actual_cost_usd
                    if result.actual_cost_usd is not None
                    else action.estimated_cost_usd
                )
                await self._cost_controller.record_usage(
                    agent_id=action.metadata.agent_id,
                    session_id=action.metadata.session_id,
                    tokens=tokens,
                    cost_usd=cost_usd,
                )
            except Exception as e:
                logger.warning("Failed to record cost: %s", e)

        # Update rate limiter - consume slots for executed actions
        if self._rate_limiter:
            try:
                await self._rate_limiter.record_action(action)
            except Exception as e:
                logger.warning("Failed to record action in rate limiter: %s", e)

        # Update loop detection history
        if self._loop_detector:
            try:
                await self._loop_detector.record(action)
            except Exception as e:
                logger.warning("Failed to record action in loop detector: %s", e)

    async def rollback(self, checkpoint_id: str) -> bool:
        """
        Rollback to a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to rollback to

        Returns:
            The rollback manager's result, or False when no rollback manager
            is configured.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available")
            return False

        # Delegate to rollback manager
        return await self._rollback_manager.rollback(checkpoint_id)

    async def emergency_stop(
        self,
        stop_type: str = "kill",
        reason: str = "Manual emergency stop",
        triggered_by: str = "system",
    ) -> None:
        """
        Trigger emergency stop.

        The stop is logged at CRITICAL level, written to the audit log (if an
        audit logger is set) and then handed to the emergency controls (if
        set). The controls are invoked even when audit logging raises; the
        audit exception still propagates afterwards.

        Args:
            stop_type: Type of stop (kill, pause, lockdown)
            reason: Reason for the stop
            triggered_by: Who triggered the stop
        """
        logger.critical(
            "Emergency stop triggered: type=%s, reason=%s, by=%s",
            stop_type,
            reason,
            triggered_by,
        )

        try:
            if self._audit_logger:
                await self._audit_logger.log_emergency_stop(
                    stop_type=stop_type,
                    triggered_by=triggered_by,
                    reason=reason,
                )
        finally:
            # The stop itself must never depend on the audit trail being
            # writable, so it runs even if the audit call above failed.
            if self._emergency_controls:
                await self._emergency_controls.execute_stop(stop_type)

    def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
        """
        Get the effective policy for an action.

        Policies are resolved from the action's autonomy level and cached per
        ``autonomy_level.value`` for the lifetime of this guardian.
        """
        # Check cached policies
        autonomy_level = action.metadata.autonomy_level

        if autonomy_level.value not in self._policies:
            self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
                autonomy_level
            )

        return self._policies[autonomy_level.value]

    async def _check_permissions(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is permitted.

        Evaluation order: denied tool patterns, then the tool allow-list
        (skipped when it contains the literal "*"), then denied file patterns
        against ``action.resource``. Tool checks are skipped when the action
        has no tool name; the resource check is skipped when it has no
        resource.

        Returns:
            A DENY result naming the first rule that matched, otherwise ALLOW.
        """
        reasons: list[str] = []

        # Check denied tools
        if action.tool_name:
            for pattern in policy.denied_tools:
                if self._matches_pattern(action.tool_name, pattern):
                    reasons.append(
                        f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
                    )
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        # Check allowed tools (if not "*")
        if action.tool_name and "*" not in policy.allowed_tools:
            allowed = any(
                self._matches_pattern(action.tool_name, pattern)
                for pattern in policy.allowed_tools
            )
            if not allowed:
                reasons.append(f"Tool '{action.tool_name}' not in allowed list")
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=reasons,
                )

        # Check file patterns
        if action.resource:
            for pattern in policy.denied_file_patterns:
                if self._matches_pattern(action.resource, pattern):
                    reasons.append(
                        f"Resource '{action.resource}' denied by pattern '{pattern}'"
                    )
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Permission check passed"],
        )

    async def _check_budget(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is within budget.

        Asks the cost controller whether the action's estimated tokens and
        cost fit the budget. On refusal, the session status (when the action
        has a session id) and then the agent's daily status are consulted to
        build a more specific denial message. ``policy`` is accepted for
        signature parity with the other checks and is not read here.

        Returns:
            ALLOW when within budget or when no cost controller is set,
            otherwise DENY. A ``BudgetExceededError`` from the controller is
            converted into a DENY result.
        """
        if self._cost_controller is None:
            logger.warning("CostController not initialized - skipping budget check")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Budget check skipped (controller not initialized)"],
            )

        agent_id = action.metadata.agent_id
        session_id = action.metadata.session_id

        try:
            # Check if we have budget for this action
            has_budget = await self._cost_controller.check_budget(
                agent_id=agent_id,
                session_id=session_id,
                estimated_tokens=action.estimated_cost_tokens,
                estimated_cost_usd=action.estimated_cost_usd,
            )

            if not has_budget:
                # Get current status for better error message
                if session_id:
                    session_status = await self._cost_controller.get_status(
                        BudgetScope.SESSION, session_id
                    )
                    if session_status and session_status.is_exceeded:
                        return GuardianResult(
                            action_id=action.id,
                            allowed=False,
                            decision=SafetyDecision.DENY,
                            reasons=[
                                f"Session budget exceeded: {session_status.tokens_used}"
                                f"/{session_status.tokens_limit} tokens"
                            ],
                        )

                agent_status = await self._cost_controller.get_status(
                    BudgetScope.DAILY, agent_id
                )
                if agent_status and agent_status.is_exceeded:
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=[
                            f"Daily budget exceeded: {agent_status.tokens_used}"
                            f"/{agent_status.tokens_limit} tokens"
                        ],
                    )

                # Generic budget exceeded
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=["Budget exceeded"],
                )

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Budget check passed"],
            )

        except BudgetExceededError as e:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=[str(e)],
            )

    async def _check_rate_limit(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if action is within rate limits.

        When the limiter refuses the action, the first status flagged as
        limited (or, failing that, the first status returned) decides the
        outcome: a positive wait of at most ``_MAX_DELAY_SECONDS`` yields
        DELAY, anything else yields DENY. ``policy`` is accepted for
        signature parity with the other checks and is not read here.

        Returns:
            ALLOW when within limits or when no rate limiter is set, otherwise
            DELAY or DENY with ``retry_after_seconds`` where available. A
            ``RateLimitExceededError`` from the limiter becomes a DENY result.
        """
        if self._rate_limiter is None:
            logger.warning("RateLimiter not initialized - skipping rate limit check")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Rate limit check skipped (limiter not initialized)"],
            )

        try:
            # Check all applicable rate limits for this action
            allowed, statuses = await self._rate_limiter.check_action(action)

            if not allowed:
                # Find the first exceeded limit for the error message
                exceeded_status = next(
                    (s for s in statuses if s.is_limited),
                    statuses[0] if statuses else None,
                )

                if exceeded_status:
                    retry_after = exceeded_status.retry_after_seconds

                    # Determine if this is a soft limit (delay) or hard limit (deny)
                    if 0 < retry_after <= self._MAX_DELAY_SECONDS:
                        # Short wait - suggest delay
                        return GuardianResult(
                            action_id=action.id,
                            allowed=False,
                            decision=SafetyDecision.DELAY,
                            reasons=[
                                f"Rate limit '{exceeded_status.name}' exceeded. "
                                f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
                            ],
                            retry_after_seconds=retry_after,
                        )
                    else:
                        # Hard deny
                        return GuardianResult(
                            action_id=action.id,
                            allowed=False,
                            decision=SafetyDecision.DENY,
                            reasons=[
                                f"Rate limit '{exceeded_status.name}' exceeded. "
                                f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
                                f"Retry after {retry_after:.1f}s"
                            ],
                            retry_after_seconds=retry_after,
                        )

                # Refused, but the limiter gave no status to report on
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=["Rate limit exceeded"],
                )

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Rate limit check passed"],
            )

        except RateLimitExceededError as e:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=[str(e)],
                retry_after_seconds=e.retry_after_seconds,
            )

    async def _check_loops(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check for action loops.

        When the detector reports a loop, the denial reasons contain the loop
        type followed by the alternatives suggested by ``LoopBreaker``.
        ``policy`` is accepted for signature parity with the other checks and
        is not read here.

        Returns:
            ALLOW when no loop is detected or when no loop detector is set,
            otherwise DENY. A ``LoopDetectedError`` from the detector becomes
            a DENY result.
        """
        if self._loop_detector is None:
            logger.warning("LoopDetector not initialized - skipping loop check")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Loop check skipped (detector not initialized)"],
            )

        try:
            # Check if this action would create a loop
            is_loop, loop_type = await self._loop_detector.check(action)

            if is_loop:
                # Get suggestions for breaking the loop
                from .loops.detector import LoopBreaker

                suggestions = await LoopBreaker.suggest_alternatives(
                    action, loop_type or "unknown"
                )

                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=[
                        f"Loop detected: {loop_type}",
                        *suggestions,
                    ],
                )

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Loop check passed"],
            )

        except LoopDetectedError as e:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=[str(e)],
            )

    async def _check_hitl(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """
        Check if human approval is required.

        Approval is required when any entry of ``policy.require_approval_for``
        is "*", or matches the action's tool name or its action type value.
        No approval request is created yet, so ``approval_id`` is always None.

        Returns:
            ALLOW when HITL is disabled in the config or nothing matches,
            otherwise REQUIRE_APPROVAL.
        """
        if not self._config.hitl_enabled:
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["HITL disabled"],
            )

        # Check if action requires approval
        requires_approval = False
        for pattern in policy.require_approval_for:
            if pattern == "*":
                requires_approval = True
                break
            if action.tool_name and self._matches_pattern(action.tool_name, pattern):
                requires_approval = True
                break
            if action.action_type.value and self._matches_pattern(
                action.action_type.value, pattern
            ):
                requires_approval = True
                break

        if requires_approval:
            # TODO: Create approval request with HITLManager
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reasons=["Action requires human approval"],
                approval_id=None,  # Will be set by HITLManager
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["No approval required"],
        )

    async def _create_checkpoint(self, action: ActionRequest) -> str | None:
        """
        Create a checkpoint before destructive action.

        Placeholder: always returns None at present, whether or not a
        rollback manager is configured.
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available - skipping checkpoint")
            return None

        # TODO: Implement with RollbackManager
        return None

    async def _create_denial_result(
        self,
        action: ActionRequest,
        reasons: list[str],
        audit_events: list[Any],
        retry_after: float | None = None,
    ) -> GuardianResult:
        """
        Create a denial result with audit logging.

        Args:
            action: The denied action
            reasons: Human-readable reasons for the denial
            audit_events: Events collected so far; the denial event is
                appended to this list in place
            retry_after: Optional number of seconds after which to retry

        Returns:
            GuardianResult with a DENY decision
        """
        if self._audit_logger:
            event = await self._audit_logger.log_action_request(
                action, SafetyDecision.DENY, reasons
            )
            audit_events.append(event)

        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=reasons,
            retry_after_seconds=retry_after,
            audit_events=audit_events,
        )

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """
        Check if value matches a pattern (supports * wildcard).

        ``*`` matches any run of characters, including an empty one, and may
        appear any number of times anywhere in the pattern. Every other
        character is matched literally and case-sensitively; a pattern
        without ``*`` must equal the value exactly.

        Args:
            value: The string to test (tool name, resource, action type)
            pattern: The pattern to test against

        Returns:
            True if the whole value matches the pattern
        """
        if pattern == "*":
            return True

        if "*" not in pattern:
            return value == pattern

        # Split into literal pieces: the first must start the value, the last
        # must end it, and the ones in between must occur in order without
        # overlapping each other or the prefix/suffix.
        prefix, *middle, suffix = pattern.split("*")

        if not value.startswith(prefix) or not value.endswith(suffix):
            return False

        cursor = len(prefix)
        limit = len(value) - len(suffix)
        if limit < cursor:
            # Value is too short: prefix and suffix would have to overlap.
            return False

        # Leftmost placement of each piece leaves the most room for the
        # following ones, so a greedy scan is sufficient.
        for segment in middle:
            found = value.find(segment, cursor, limit)
            if found == -1:
                return False
            cursor = found + len(segment)

        return True
|
|
|
|
|
|
# Singleton instance
# Process-wide guardian shared by the module-level accessor functions in this
# file; it stays None until one is created.
_guardian_instance: SafetyGuardian | None = None
# Held by those accessor functions whenever they read or replace
# `_guardian_instance`.
# NOTE(review): the lock is created at import time - confirm all callers run
# on the same event loop.
_guardian_lock = asyncio.Lock()
|
|
|
|
|
|
async def get_safety_guardian() -> SafetyGuardian:
    """
    Get the global SafetyGuardian instance.

    The instance is created and initialized on first use, under the module
    lock. It is only stored as the global singleton once ``initialize()`` has
    succeeded, so a failed initialization is retried by the next call instead
    of leaving an uninitialized guardian cached.

    Returns:
        The shared, initialized SafetyGuardian.
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is None:
            guardian = SafetyGuardian()
            await guardian.initialize()
            # Publish only after initialization succeeded.
            _guardian_instance = guardian

        return _guardian_instance
|
|
|
|
|
|
async def shutdown_safety_guardian() -> None:
    """
    Shutdown the global SafetyGuardian.

    Does nothing when no global instance exists. The global reference is
    cleared only after the instance's ``shutdown()`` has returned.
    """
    global _guardian_instance

    async with _guardian_lock:
        guardian = _guardian_instance
        if guardian is None:
            return

        await guardian.shutdown()
        _guardian_instance = None
|
|
|
|
|
|
async def reset_safety_guardian() -> None:
    """
    Reset the SafetyGuardian (for testing).

    This is an async function to properly acquire the guardian lock
    and avoid race conditions with get_safety_guardian().

    Shutdown errors are not raised, so test cleanup always ends with the
    global instance cleared; they are logged at debug level instead of being
    dropped silently.
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is not None:
            try:
                await _guardian_instance.shutdown()
            except Exception:
                # Best-effort cleanup: keep going, but leave a trace.
                logger.debug(
                    "Ignoring error while resetting SafetyGuardian",
                    exc_info=True,
                )
        _guardian_instance = None
|