forked from cardosofelipe/fast-next-template
feat(safety): add Phase C advanced controls
- Add rollback manager with file checkpointing and transaction context
- Add HITL manager with approval queues and notification handlers
- Add content filter with PII, secrets, and injection detection
- Add emergency controls with stop/pause/resume capabilities
- Update SafetyConfig with checkpoint_dir setting

Issue #63

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1 +1,23 @@
|
||||
"""${dir} module."""
|
||||
"""Emergency controls for agent safety."""
|
||||
|
||||
from .controls import (
|
||||
EmergencyControls,
|
||||
EmergencyEvent,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
check_emergency_allowed,
|
||||
emergency_stop_global,
|
||||
get_emergency_controls,
|
||||
)
|
||||
|
||||
# Public API of the emergency package; every name is re-exported from
# the sibling ``controls`` module.
__all__ = [
    # Classes and enums
    "EmergencyControls",
    "EmergencyEvent",
    "EmergencyReason",
    "EmergencyState",
    "EmergencyTrigger",
    # Module-level convenience coroutines
    "check_emergency_allowed",
    "emergency_stop_global",
    "get_emergency_controls",
]
|
||||
|
||||
594
backend/app/services/safety/emergency/controls.py
Normal file
594
backend/app/services/safety/emergency/controls.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
Emergency Controls
|
||||
|
||||
Emergency stop and pause functionality for agent safety.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import EmergencyStopError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmergencyState(str, Enum):
    """State tracked for the global scope and for each named scope."""

    # No restriction is in effect.
    NORMAL = "normal"
    # Operations are halted; lifted by ``EmergencyControls.resume``.
    PAUSED = "paused"
    # Operations are halted; ``resume`` refuses this state, ``reset`` clears it.
    STOPPED = "stopped"
|
||||
|
||||
|
||||
class EmergencyReason(str, Enum):
    """Machine-readable cause attached to every emergency event."""

    # Requested explicitly (used by ``emergency_stop_global``).
    MANUAL = "manual"

    # Causes used by the automatic triggers in ``EmergencyTrigger``.
    SAFETY_VIOLATION = "safety_violation"
    BUDGET_EXCEEDED = "budget_exceeded"
    LOOP_DETECTED = "loop_detected"
    CONTENT_VIOLATION = "content_violation"

    # Additional causes available to callers.
    RATE_LIMIT = "rate_limit"
    SYSTEM_ERROR = "system_error"
    EXTERNAL_TRIGGER = "external_trigger"
|
||||
|
||||
|
||||
@dataclass
class EmergencyEvent:
    """
    Audit record for a single emergency stop or pause.

    An event is considered active while ``resolved_at`` is ``None``;
    ``resolved_at`` / ``resolved_by`` are filled in when the scope is
    resumed or reset.
    """

    # --- Set when the event is raised ---
    id: str
    state: EmergencyState
    reason: EmergencyReason
    triggered_by: str
    message: str
    scope: str  # "global", "project:<id>", "agent:<id>"

    # --- Defaults ---
    # Creation time as returned by datetime.utcnow (naive datetime).
    timestamp: datetime = field(default_factory=datetime.utcnow)
    # Free-form context; each event gets its own dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)

    # --- Set on resolution ---
    resolved_at: datetime | None = None
    resolved_by: str | None = None
|
||||
|
||||
|
||||
class EmergencyControls:
    """
    Emergency stop and pause controls for agent safety.

    State is kept for the ``"global"`` scope and for any number of named
    scopes (by convention ``"project:<id>"`` / ``"agent:<id>"``). Every
    stop or pause is recorded as an ``EmergencyEvent``.

    Features:
    - Global emergency stop
    - Per-project/agent emergency controls
    - Pause that can be lifted with ``resume``; stop that requires ``reset``
    - Automatic triggers from safety violations (see ``EmergencyTrigger``)
    - Manual override capabilities
    - Event history and audit trail

    State changes happen under an ``asyncio.Lock``. Callbacks and
    notification handlers are invoked after the lock has been released,
    so they may call back into this object without deadlocking.
    """

    def __init__(
        self,
        notification_handlers: list[Callable[..., Any]] | None = None,
    ) -> None:
        """
        Initialize EmergencyControls.

        Args:
            notification_handlers: Handlers to call on emergency events.
                Each is called as ``handler(event_type, data)``.
        """
        self._global_state = EmergencyState.NORMAL
        self._scoped_states: dict[str, EmergencyState] = {}
        self._events: list[EmergencyEvent] = []
        # Copy so add_notification_handler() never mutates the caller's list.
        self._notification_handlers = list(notification_handlers or [])
        self._lock = asyncio.Lock()
        self._event_id_counter = 0

        # Callbacks for state changes
        self._on_stop_callbacks: list[Callable[..., Any]] = []
        self._on_pause_callbacks: list[Callable[..., Any]] = []
        self._on_resume_callbacks: list[Callable[..., Any]] = []

    def _generate_event_id(self) -> str:
        """Generate a unique event ID (sequential, per instance)."""
        self._event_id_counter += 1
        return f"emerg-{self._event_id_counter:06d}"

    def _set_state(self, scope: str, state: EmergencyState) -> None:
        """Store ``state`` for ``scope``; "global" has its own attribute."""
        if scope == "global":
            self._global_state = state
        else:
            self._scoped_states[scope] = state

    async def _record_event(
        self,
        state: EmergencyState,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str,
        metadata: dict[str, Any] | None,
    ) -> EmergencyEvent:
        """Create and store an event and apply its state, under the lock."""
        async with self._lock:
            event = EmergencyEvent(
                id=self._generate_event_id(),
                state=state,
                reason=reason,
                triggered_by=triggered_by,
                message=message,
                scope=scope,
                metadata=metadata or {},
            )

            # A pause must never weaken an existing stop: otherwise a later
            # resume() would lift the stop without the explicit reset().
            # The pause is still recorded for the audit trail.
            if self._get_state(scope) != EmergencyState.STOPPED:
                self._set_state(scope, state)

            self._events.append(event)
        return event

    def _resolve_events(
        self,
        scope: str,
        resolved_by: str,
        *states: EmergencyState,
    ) -> None:
        """
        Mark unresolved events of ``scope`` as resolved.

        Args:
            scope: Scope whose events are resolved
            resolved_by: Stored in ``resolved_by`` of each resolved event
            *states: Restrict to events in these states; all states if empty
        """
        now = datetime.utcnow()
        for event in self._events:
            if event.scope != scope or event.resolved_at is not None:
                continue
            if states and event.state not in states:
                continue
            event.resolved_at = now
            event.resolved_by = resolved_by

    async def emergency_stop(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Trigger emergency stop.

        Args:
            reason: Reason for the stop
            triggered_by: Who/what triggered the stop
            message: Human-readable message
            scope: Scope of the stop (global, project:<id>, agent:<id>)
            metadata: Additional context

        Returns:
            The emergency event record
        """
        event = await self._record_event(
            EmergencyState.STOPPED, reason, triggered_by, message, scope, metadata
        )

        logger.critical(
            "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        # Execute callbacks (lock already released)
        await self._execute_callbacks(self._on_stop_callbacks, event)
        await self._notify_handlers("emergency_stop", event)

        return event

    async def pause(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Pause operations (can be resumed).

        If the scope is already STOPPED the pause is recorded and announced
        but the scope stays STOPPED.

        Args:
            reason: Reason for the pause
            triggered_by: Who/what triggered the pause
            message: Human-readable message
            scope: Scope of the pause
            metadata: Additional context

        Returns:
            The emergency event record
        """
        event = await self._record_event(
            EmergencyState.PAUSED, reason, triggered_by, message, scope, metadata
        )

        logger.warning(
            "PAUSE: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        await self._execute_callbacks(self._on_pause_callbacks, event)
        await self._notify_handlers("pause", event)

        return event

    async def resume(
        self,
        scope: str = "global",
        resumed_by: str = "system",
        message: str | None = None,
    ) -> bool:
        """
        Resume operations from paused state.

        Args:
            scope: Scope to resume
            resumed_by: Who/what is resuming
            message: Optional message

        Returns:
            True if the scope is NORMAL afterwards (it was resumed, or was
            already NORMAL); False if the scope is STOPPED, which requires
            ``reset`` instead.
        """
        async with self._lock:
            current_state = self._get_state(scope)

            if current_state == EmergencyState.STOPPED:
                logger.warning(
                    "Cannot resume from STOPPED state: %s (requires reset)",
                    scope,
                )
                return False

            if current_state == EmergencyState.NORMAL:
                return True  # Already normal

            # Resolve every outstanding pause for the scope, not only the
            # latest one, so none is left dangling as "active".
            self._resolve_events(scope, resumed_by, EmergencyState.PAUSED)
            self._set_state(scope, EmergencyState.NORMAL)

        logger.info(
            "RESUMED: scope=%s, by=%s%s",
            scope,
            resumed_by,
            f" - {message}" if message else "",
        )

        # Callbacks and handlers each get their own payload dict.
        await self._execute_callbacks(
            self._on_resume_callbacks,
            {"scope": scope, "resumed_by": resumed_by},
        )
        await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})

        return True

    async def reset(
        self,
        scope: str = "global",
        reset_by: str = "admin",
        message: str | None = None,
    ) -> bool:
        """
        Reset a scope to NORMAL (the only way out of STOPPED).

        Also clears a PAUSED scope. All unresolved events of the scope are
        marked resolved. Does nothing when the scope is already NORMAL.

        Args:
            scope: Scope to reset
            reset_by: Who is resetting (should be admin)
            message: Optional message

        Returns:
            Always True; the scope is NORMAL on return.
        """
        async with self._lock:
            if self._get_state(scope) == EmergencyState.NORMAL:
                return True

            self._resolve_events(scope, reset_by)
            self._set_state(scope, EmergencyState.NORMAL)

        logger.warning(
            "EMERGENCY RESET: scope=%s, by=%s%s",
            scope,
            reset_by,
            f" - {message}" if message else "",
        )

        await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})

        return True

    async def check_allowed(
        self,
        scope: str | None = None,
        raise_if_blocked: bool = True,
    ) -> bool:
        """
        Check if operations are allowed.

        Args:
            scope: Specific scope to check (also checks global)
            raise_if_blocked: Raise exception if blocked

        Returns:
            True if operations are allowed

        Raises:
            EmergencyStopError: If blocked and raise_if_blocked=True
        """
        async with self._lock:
            # Always check global state
            if self._global_state != EmergencyState.NORMAL:
                if raise_if_blocked:
                    raise EmergencyStopError(
                        f"Global emergency state: {self._global_state.value}",
                        stop_reason=self._get_last_reason("global"),
                        triggered_by=self._get_last_triggered_by("global"),
                    )
                return False

            # Check specific scope (only scopes that have a recorded state)
            if scope and scope in self._scoped_states:
                state = self._scoped_states[scope]
                if state != EmergencyState.NORMAL:
                    if raise_if_blocked:
                        raise EmergencyStopError(
                            f"Emergency state for {scope}: {state.value}",
                            stop_reason=self._get_last_reason(scope),
                            triggered_by=self._get_last_triggered_by(scope),
                            scope=scope,
                        )
                    return False

            return True

    def _get_state(self, scope: str) -> EmergencyState:
        """Get state for a scope (NORMAL when nothing is recorded)."""
        if scope == "global":
            return self._global_state
        return self._scoped_states.get(scope, EmergencyState.NORMAL)

    def _latest_unresolved(self, scope: str) -> EmergencyEvent | None:
        """
        Return the newest unresolved event for ``scope``, or None.

        An event whose state equals the scope's current state is preferred,
        so a pause recorded on top of a stop does not mask the stop.
        """
        current = self._get_state(scope)
        fallback: EmergencyEvent | None = None
        for event in reversed(self._events):
            if event.scope != scope or event.resolved_at is not None:
                continue
            if event.state == current:
                return event
            if fallback is None:
                fallback = event
        return fallback

    def _get_last_reason(self, scope: str) -> str:
        """Get reason from last unresolved event for scope, or "unknown"."""
        event = self._latest_unresolved(scope)
        return event.reason.value if event is not None else "unknown"

    def _get_last_triggered_by(self, scope: str) -> str:
        """Get triggered_by from last unresolved event, or "unknown"."""
        event = self._latest_unresolved(scope)
        return event.triggered_by if event is not None else "unknown"

    async def get_state(self, scope: str = "global") -> EmergencyState:
        """Get current state for a scope."""
        async with self._lock:
            return self._get_state(scope)

    async def get_all_states(self) -> dict[str, EmergencyState]:
        """Get all current states, keyed by scope (always includes "global")."""
        async with self._lock:
            states = {"global": self._global_state}
            states.update(self._scoped_states)
            return states

    async def get_active_events(self) -> list[EmergencyEvent]:
        """Get all unresolved emergency events, oldest first."""
        async with self._lock:
            return [e for e in self._events if e.resolved_at is None]

    async def get_event_history(
        self,
        scope: str | None = None,
        limit: int = 100,
    ) -> list[EmergencyEvent]:
        """
        Get emergency event history.

        Args:
            scope: Only return events of this scope; all scopes if falsy
            limit: Maximum number of events to return

        Returns:
            Up to ``limit`` most recent events, oldest first. Empty when
            ``limit`` is zero or negative.
        """
        # events[-0:] would be the whole list, so handle this explicitly.
        if limit <= 0:
            return []

        async with self._lock:
            events = list(self._events)

        if scope:
            events = [e for e in events if e.scope == scope]

        return events[-limit:]

    def on_stop(self, callback: Callable[..., Any]) -> None:
        """Register callback for stop events; called with the event."""
        self._on_stop_callbacks.append(callback)

    def on_pause(self, callback: Callable[..., Any]) -> None:
        """Register callback for pause events; called with the event."""
        self._on_pause_callbacks.append(callback)

    def on_resume(self, callback: Callable[..., Any]) -> None:
        """Register callback for resume events; called with a payload dict."""
        self._on_resume_callbacks.append(callback)

    def add_notification_handler(self, handler: Callable[..., Any]) -> None:
        """Add a notification handler, called as ``handler(event_type, data)``."""
        self._notification_handlers.append(handler)

    @staticmethod
    async def _invoke(error_message: str, target: Callable[..., Any], *args: Any) -> None:
        """
        Call ``target(*args)``, awaiting the result if it is awaitable.

        Checking the result rather than the callable means partials and
        callable objects that return coroutines are awaited too. Errors
        are logged and swallowed: a failing callback must not interrupt
        emergency handling.
        """
        # Local import so this class is self-contained with respect to the
        # module's import block.
        import inspect

        try:
            result = target(*args)
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            logger.error(error_message, e, exc_info=True)

    async def _execute_callbacks(
        self,
        callbacks: list[Callable[..., Any]],
        data: Any,
    ) -> None:
        """Execute callbacks safely, one after another, each with ``data``."""
        for callback in callbacks:
            await self._invoke("Error in callback: %s", callback, data)

    async def _notify_handlers(self, event_type: str, data: Any) -> None:
        """Notify all handlers of an event."""
        for handler in self._notification_handlers:
            await self._invoke(
                "Error in notification handler: %s", handler, event_type, data
            )
|
||||
|
||||
|
||||
class EmergencyTrigger:
    """
    Automatic emergency triggers based on conditions.

    Each method maps one detected condition to a stop or pause on the
    wrapped ``EmergencyControls``: safety and content violations stop,
    budget overruns and loops pause.
    """

    def __init__(self, controls: EmergencyControls) -> None:
        """
        Initialize EmergencyTrigger.

        Args:
            controls: EmergencyControls instance to trigger
        """
        self._controls = controls

    @staticmethod
    def _metadata(details: dict[str, Any], **fixed: Any) -> dict[str, Any]:
        """
        Build event metadata from fixed keys plus free-form details.

        The fixed keys come first and always win, so a colliding key in
        ``details`` cannot overwrite the authoritative value.
        """
        extras = {key: value for key, value in details.items() if key not in fixed}
        return {**fixed, **extras}

    async def trigger_on_safety_violation(
        self,
        violation_type: str,
        details: dict[str, Any],
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency stop from safety violation.

        Args:
            violation_type: Type of violation
            details: Violation details, merged into the event metadata
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety_system",
            message=f"Safety violation: {violation_type}",
            scope=scope,
            metadata=self._metadata(details, violation_type=violation_type),
        )

    async def trigger_on_budget_exceeded(
        self,
        budget_type: str,
        current: float,
        limit: float,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger pause from budget exceeded.

        Args:
            budget_type: Type of budget
            current: Current usage
            limit: Budget limit
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
            scope=scope,
            metadata={"budget_type": budget_type, "current": current, "limit": limit},
        )

    async def trigger_on_loop_detected(
        self,
        loop_type: str,
        agent_id: str,
        details: dict[str, Any],
    ) -> EmergencyEvent:
        """
        Trigger pause from loop detection.

        The pause is scoped to the looping agent (``agent:<agent_id>``).

        Args:
            loop_type: Type of loop
            agent_id: Agent that's looping
            details: Loop details, merged into the event metadata

        Returns:
            Emergency event
        """
        return await self._controls.pause(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="loop_detector",
            message=f"Loop detected: {loop_type} in agent {agent_id}",
            scope=f"agent:{agent_id}",
            metadata=self._metadata(details, loop_type=loop_type, agent_id=agent_id),
        )

    async def trigger_on_content_violation(
        self,
        category: str,
        pattern: str,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency stop from content violation.

        Args:
            category: Content category
            pattern: Pattern that matched
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.emergency_stop(
            reason=EmergencyReason.CONTENT_VIOLATION,
            triggered_by="content_filter",
            message=f"Content violation: {category} ({pattern})",
            scope=scope,
            metadata={"category": category, "pattern": pattern},
        )
|
||||
|
||||
|
||||
# Module-level singleton, created lazily by get_emergency_controls().
_emergency_controls: EmergencyControls | None = None
# Guards first-time creation of the singleton.
_lock = asyncio.Lock()


async def get_emergency_controls() -> EmergencyControls:
    """Return the shared EmergencyControls, creating it on first use."""
    global _emergency_controls

    async with _lock:
        instance = _emergency_controls
        if instance is None:
            instance = EmergencyControls()
            _emergency_controls = instance
    return instance
|
||||
|
||||
|
||||
async def emergency_stop_global(
    reason: str,
    triggered_by: str = "system",
) -> EmergencyEvent:
    """
    Stop the global scope through the shared EmergencyControls.

    Args:
        reason: Free-text explanation; stored as the event message. The
            event's reason is always ``EmergencyReason.MANUAL``.
        triggered_by: Who/what triggered the stop

    Returns:
        The emergency event record
    """
    shared = await get_emergency_controls()
    stop_event = await shared.emergency_stop(
        scope="global",
        reason=EmergencyReason.MANUAL,
        message=reason,
        triggered_by=triggered_by,
    )
    return stop_event
|
||||
|
||||
|
||||
async def check_emergency_allowed(scope: str | None = None) -> bool:
    """
    Report whether operations are allowed, without raising.

    Args:
        scope: Specific scope to check in addition to the global state

    Returns:
        True if operations are allowed, False if blocked
    """
    shared = await get_emergency_controls()
    allowed = await shared.check_allowed(raise_if_blocked=False, scope=scope)
    return allowed
|
||||
Reference in New Issue
Block a user