forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
597 lines
18 KiB
Python
597 lines
18 KiB
Python
"""
Emergency Controls

Emergency stop and pause functionality for agent safety.

Operations can be blocked globally or for a named scope such as
"project:<id>" or "agent:<id>". A PAUSED scope can be resumed, while a
STOPPED scope requires an explicit reset. Every stop and pause is recorded
as an EmergencyEvent so there is an audit trail of who blocked what and why.
"""


import asyncio
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any

from ..exceptions import EmergencyStopError


# Module logger: emergency stops are logged at CRITICAL level, pauses and
# resets at WARNING, resumes at INFO.
logger = logging.getLogger(__name__)


class EmergencyState(str, Enum):
    """Emergency control states.

    NORMAL allows operations. PAUSED blocks them until ``resume`` is
    called. STOPPED blocks them until an explicit ``reset``, because
    ``resume`` refuses to leave the STOPPED state.
    """

    NORMAL = "normal"
    PAUSED = "paused"
    STOPPED = "stopped"


class EmergencyReason(str, Enum):
    """Reasons for emergency actions.

    The member's ``.value`` string is what gets written to the log lines and
    used as the ``stop_type`` of the EmergencyStopError that
    ``EmergencyControls.check_allowed`` raises.
    """

    MANUAL = "manual"  # used by emergency_stop_global()
    SAFETY_VIOLATION = "safety_violation"
    BUDGET_EXCEEDED = "budget_exceeded"
    LOOP_DETECTED = "loop_detected"
    RATE_LIMIT = "rate_limit"
    CONTENT_VIOLATION = "content_violation"
    SYSTEM_ERROR = "system_error"
    EXTERNAL_TRIGGER = "external_trigger"


@dataclass
class EmergencyEvent:
    """Record of an emergency action (a stop or a pause).

    An event counts as active until ``resolved_at`` is set, which happens
    when its scope is resumed (pauses) or reset.
    """

    id: str  # assigned by EmergencyControls, e.g. "emerg-000001"
    state: EmergencyState  # PAUSED or STOPPED for events EmergencyControls creates
    reason: EmergencyReason
    triggered_by: str  # who/what triggered the action
    message: str  # human-readable description
    scope: str  # "global", "project:<id>", "agent:<id>"
    timestamp: datetime = field(default_factory=datetime.utcnow)  # naive UTC creation time
    metadata: dict[str, Any] = field(default_factory=dict)  # free-form context
    resolved_at: datetime | None = None  # None while the event is still active
    resolved_by: str | None = None  # who resumed/reset the scope


class EmergencyControls:
    """
    Emergency stop and pause controls for agent safety.

    Features:
    - Global emergency stop
    - Per-project/agent emergency controls
    - Graceful pause with state preservation
    - Automatic triggers from safety violations
    - Manual override capabilities
    - Event history and audit trail

    A state is kept for the "global" scope and for any named scope
    (e.g. "project:<id>", "agent:<id>"). A PAUSED scope can be resumed; a
    STOPPED scope can only be cleared with ``reset``. A pause never
    downgrades a STOPPED scope.

    All state mutation happens under ``self._lock``. Callbacks and
    notification handlers run only after the lock has been released.
    ``asyncio.Lock`` is not reentrant, so a callback that queried this
    object (``check_allowed``, ``get_state``, ...) while the lock was still
    held would deadlock.
    """

    def __init__(
        self,
        notification_handlers: list[Callable[..., Any]] | None = None,
    ) -> None:
        """
        Initialize EmergencyControls.

        Args:
            notification_handlers: Handlers to call on emergency events;
                each is called as ``handler(event_type, data)``
        """
        self._global_state = EmergencyState.NORMAL
        self._scoped_states: dict[str, EmergencyState] = {}
        self._events: list[EmergencyEvent] = []
        self._notification_handlers = notification_handlers or []
        self._lock = asyncio.Lock()
        self._event_id_counter = 0

        # Callbacks for state changes
        self._on_stop_callbacks: list[Callable[..., Any]] = []
        self._on_pause_callbacks: list[Callable[..., Any]] = []
        self._on_resume_callbacks: list[Callable[..., Any]] = []

    def _generate_event_id(self) -> str:
        """Generate a unique, sequential event ID ("emerg-000001", ...)."""
        self._event_id_counter += 1
        return f"emerg-{self._event_id_counter:06d}"

    def _set_state(self, scope: str, state: EmergencyState) -> None:
        """Set the state of ``scope``; must be called with the lock held."""
        if scope == "global":
            self._global_state = state
        else:
            self._scoped_states[scope] = state

    def _resolve_events(
        self,
        scope: str,
        resolved_by: str,
        state: EmergencyState | None = None,
    ) -> None:
        """
        Mark every still-open event for ``scope`` as resolved.

        Must be called with the lock held.

        Args:
            scope: Scope whose events should be closed
            resolved_by: Who/what resolved them
            state: If given, only events in this state are resolved
        """
        resolved_at = datetime.utcnow()
        for event in self._events:
            if event.scope != scope or event.resolved_at is not None:
                continue
            if state is not None and event.state != state:
                continue
            event.resolved_at = resolved_at
            event.resolved_by = resolved_by

    async def _record_event(
        self,
        state: EmergencyState,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str,
        metadata: dict[str, Any] | None,
    ) -> EmergencyEvent:
        """
        Create and store an event, and move ``scope`` into ``state``.

        A pause never downgrades a STOPPED scope. The pause is still recorded
        for the audit trail, but the scope stays STOPPED. Otherwise ``resume``
        could clear a stop that is supposed to require an explicit ``reset``.

        Returns:
            The newly created event
        """
        async with self._lock:
            event = EmergencyEvent(
                id=self._generate_event_id(),
                state=state,
                reason=reason,
                triggered_by=triggered_by,
                message=message,
                scope=scope,
                # Copy so later mutation of the caller's dict cannot alter
                # the audit record.
                metadata=dict(metadata) if metadata else {},
            )

            downgrade = (
                state == EmergencyState.PAUSED
                and self._get_state(scope) == EmergencyState.STOPPED
            )
            if not downgrade:
                self._set_state(scope, state)

            self._events.append(event)
        return event

    async def emergency_stop(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Trigger emergency stop.

        Args:
            reason: Reason for the stop
            triggered_by: Who/what triggered the stop
            message: Human-readable message
            scope: Scope of the stop (global, project:<id>, agent:<id>)
            metadata: Additional context (copied into the event)

        Returns:
            The emergency event record
        """
        event = await self._record_event(
            EmergencyState.STOPPED, reason, triggered_by, message, scope, metadata
        )

        logger.critical(
            "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        # Run outside the lock so callbacks may safely query this object.
        await self._execute_callbacks(self._on_stop_callbacks, event)
        await self._notify_handlers("emergency_stop", event)

        return event

    async def pause(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Pause operations (can be resumed).

        If ``scope`` is already STOPPED, the pause is recorded but the scope
        stays STOPPED.

        Args:
            reason: Reason for the pause
            triggered_by: Who/what triggered the pause
            message: Human-readable message
            scope: Scope of the pause
            metadata: Additional context (copied into the event)

        Returns:
            The emergency event record
        """
        event = await self._record_event(
            EmergencyState.PAUSED, reason, triggered_by, message, scope, metadata
        )

        logger.warning(
            "PAUSE: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        # Run outside the lock so callbacks may safely query this object.
        await self._execute_callbacks(self._on_pause_callbacks, event)
        await self._notify_handlers("pause", event)

        return event

    async def resume(
        self,
        scope: str = "global",
        resumed_by: str = "system",
        message: str | None = None,
    ) -> bool:
        """
        Resume operations from paused state.

        Args:
            scope: Scope to resume
            resumed_by: Who/what is resuming
            message: Optional message

        Returns:
            True if resumed (or already normal), False if the scope is
            STOPPED and therefore requires ``reset``
        """
        async with self._lock:
            current_state = self._get_state(scope)

            if current_state == EmergencyState.STOPPED:
                logger.warning(
                    "Cannot resume from STOPPED state: %s (requires reset)",
                    scope,
                )
                return False

            if current_state == EmergencyState.NORMAL:
                return True  # Already normal

            # Close every open pause for this scope, not just the latest;
            # otherwise repeated pauses would stay "active" forever.
            self._resolve_events(scope, resumed_by, EmergencyState.PAUSED)
            self._set_state(scope, EmergencyState.NORMAL)

        logger.info(
            "RESUMED: scope=%s, by=%s%s",
            scope,
            resumed_by,
            f" - {message}" if message else "",
        )

        # Run outside the lock so callbacks may safely query this object.
        await self._execute_callbacks(
            self._on_resume_callbacks,
            {"scope": scope, "resumed_by": resumed_by},
        )
        await self._notify_handlers(
            "resume", {"scope": scope, "resumed_by": resumed_by}
        )

        return True

    async def reset(
        self,
        scope: str = "global",
        reset_by: str = "admin",
        message: str | None = None,
    ) -> bool:
        """
        Reset from stopped state (requires explicit action).

        Also clears a PAUSED scope. Every open event for the scope is
        resolved, including a pause that preceded the stop.

        Args:
            scope: Scope to reset
            reset_by: Who is resetting (should be admin)
            message: Optional message

        Returns:
            True if reset successful
        """
        async with self._lock:
            if self._get_state(scope) == EmergencyState.NORMAL:
                return True

            self._resolve_events(scope, reset_by)
            self._set_state(scope, EmergencyState.NORMAL)

        logger.warning(
            "EMERGENCY RESET: scope=%s, by=%s%s",
            scope,
            reset_by,
            f" - {message}" if message else "",
        )

        # Run outside the lock so handlers may safely query this object.
        await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})

        return True

    async def check_allowed(
        self,
        scope: str | None = None,
        raise_if_blocked: bool = True,
    ) -> bool:
        """
        Check if operations are allowed.

        Args:
            scope: Specific scope to check (also checks global)
            raise_if_blocked: Raise exception if blocked

        Returns:
            True if operations are allowed

        Raises:
            EmergencyStopError: If blocked and raise_if_blocked=True
        """
        async with self._lock:
            # Always check global state
            if self._global_state != EmergencyState.NORMAL:
                if raise_if_blocked:
                    raise EmergencyStopError(
                        f"Global emergency state: {self._global_state.value}",
                        stop_type=self._get_last_reason("global") or "emergency",
                        triggered_by=self._get_last_triggered_by("global"),
                    )
                return False

            # Check specific scope
            if scope and scope in self._scoped_states:
                state = self._scoped_states[scope]
                if state != EmergencyState.NORMAL:
                    if raise_if_blocked:
                        raise EmergencyStopError(
                            f"Emergency state for {scope}: {state.value}",
                            stop_type=self._get_last_reason(scope) or "emergency",
                            triggered_by=self._get_last_triggered_by(scope),
                            details={"scope": scope},
                        )
                    return False

            return True

    def _get_state(self, scope: str) -> EmergencyState:
        """Get state for a scope (unknown scopes are NORMAL)."""
        if scope == "global":
            return self._global_state
        return self._scoped_states.get(scope, EmergencyState.NORMAL)

    def _get_last_active_event(self, scope: str) -> EmergencyEvent | None:
        """
        Return the most recent unresolved event for ``scope``.

        An event whose state matches the scope's current state is preferred.
        For example, a pause recorded on a STOPPED scope does not mask the
        stop that is actually blocking it.
        """
        current = self._get_state(scope)
        fallback: EmergencyEvent | None = None
        for event in reversed(self._events):
            if event.scope != scope or event.resolved_at is not None:
                continue
            if event.state == current:
                return event
            if fallback is None:
                fallback = event
        return fallback

    def _get_last_reason(self, scope: str) -> str:
        """Get the reason value of the active event for scope, or "unknown"."""
        event = self._get_last_active_event(scope)
        return event.reason.value if event is not None else "unknown"

    def _get_last_triggered_by(self, scope: str) -> str:
        """Get triggered_by of the active event for scope, or "unknown"."""
        event = self._get_last_active_event(scope)
        return event.triggered_by if event is not None else "unknown"

    async def get_state(self, scope: str = "global") -> EmergencyState:
        """Get current state for a scope."""
        async with self._lock:
            return self._get_state(scope)

    async def get_all_states(self) -> dict[str, EmergencyState]:
        """Get all current states, keyed by scope ("global" included)."""
        async with self._lock:
            states = {"global": self._global_state}
            states.update(self._scoped_states)
            return states

    async def get_active_events(self) -> list[EmergencyEvent]:
        """Get all unresolved emergency events."""
        async with self._lock:
            return [e for e in self._events if e.resolved_at is None]

    async def get_event_history(
        self,
        scope: str | None = None,
        limit: int = 100,
    ) -> list[EmergencyEvent]:
        """
        Get emergency event history, oldest first.

        Args:
            scope: If given, only events for this scope are returned
            limit: Maximum number of most recent events to return; values
                <= 0 return an empty list

        Returns:
            Up to ``limit`` most recent matching events
        """
        if limit <= 0:
            # ``events[-0:]`` would be the whole list, not an empty one.
            return []

        async with self._lock:
            events = [e for e in self._events if not scope or e.scope == scope]

        return events[-limit:]

    def on_stop(self, callback: Callable[..., Any]) -> None:
        """Register callback for stop events (called with the EmergencyEvent)."""
        self._on_stop_callbacks.append(callback)

    def on_pause(self, callback: Callable[..., Any]) -> None:
        """Register callback for pause events (called with the EmergencyEvent)."""
        self._on_pause_callbacks.append(callback)

    def on_resume(self, callback: Callable[..., Any]) -> None:
        """Register callback for resume events (called with a scope/resumed_by dict)."""
        self._on_resume_callbacks.append(callback)

    def add_notification_handler(self, handler: Callable[..., Any]) -> None:
        """Add a notification handler, called as ``handler(event_type, data)``."""
        self._notification_handlers.append(handler)

    async def _execute_callbacks(
        self,
        callbacks: list[Callable[..., Any]],
        data: Any,
    ) -> None:
        """
        Execute callbacks safely, in registration order.

        Each callback is called with ``data``. If it returns an awaitable (an
        async function, a ``functools.partial`` of one, or an object with an
        async ``__call__``), the result is awaited. Exceptions are logged
        with their traceback and do not prevent the remaining callbacks from
        running.
        """
        # Iterate a snapshot so a callback that registers another callback
        # does not alter this dispatch.
        for callback in list(callbacks):
            try:
                result = callback(data)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Error in callback: %s", e)

    async def _notify_handlers(self, event_type: str, data: Any) -> None:
        """
        Notify all handlers of an event.

        Handlers are called as ``handler(event_type, data)``. Awaitable
        results are awaited, and exceptions are logged with their traceback
        without stopping the remaining handlers.
        """
        for handler in list(self._notification_handlers):
            try:
                result = handler(event_type, data)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Error in notification handler: %s", e)


class EmergencyTrigger:
    """
    Automatic emergency triggers based on conditions.

    Safety and content violations trigger an emergency stop. Budget overruns
    and detected loops trigger a pause.
    """

    def __init__(self, controls: EmergencyControls) -> None:
        """
        Initialize EmergencyTrigger.

        Args:
            controls: EmergencyControls instance to trigger
        """
        self._controls = controls

    @staticmethod
    def _build_metadata(
        details: dict[str, Any] | None, **fixed: Any
    ) -> dict[str, Any]:
        """
        Merge free-form ``details`` with the trigger's own arguments.

        The fixed keys come first and are authoritative. A colliding key in
        ``details`` is dropped, so the metadata can never contradict the
        event's message or scope. ``details=None`` is treated as empty.
        """
        extra = {k: v for k, v in (details or {}).items() if k not in fixed}
        return {**fixed, **extra}

    async def trigger_on_safety_violation(
        self,
        violation_type: str,
        details: dict[str, Any] | None,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency from safety violation.

        Args:
            violation_type: Type of violation
            details: Violation details (may be None)
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety_system",
            message=f"Safety violation: {violation_type}",
            scope=scope,
            metadata=self._build_metadata(details, violation_type=violation_type),
        )

    async def trigger_on_budget_exceeded(
        self,
        budget_type: str,
        current: float,
        limit: float,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency from budget exceeded.

        Args:
            budget_type: Type of budget
            current: Current usage
            limit: Budget limit
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
            scope=scope,
            metadata={"budget_type": budget_type, "current": current, "limit": limit},
        )

    async def trigger_on_loop_detected(
        self,
        loop_type: str,
        agent_id: str,
        details: dict[str, Any] | None,
    ) -> EmergencyEvent:
        """
        Trigger emergency from loop detection.

        The pause is scoped to the looping agent ("agent:<agent_id>").

        Args:
            loop_type: Type of loop
            agent_id: Agent that's looping
            details: Loop details (may be None)

        Returns:
            Emergency event
        """
        return await self._controls.pause(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="loop_detector",
            message=f"Loop detected: {loop_type} in agent {agent_id}",
            scope=f"agent:{agent_id}",
            metadata=self._build_metadata(
                details, loop_type=loop_type, agent_id=agent_id
            ),
        )

    async def trigger_on_content_violation(
        self,
        category: str,
        pattern: str,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency from content violation.

        Args:
            category: Content category
            pattern: Pattern that matched
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.emergency_stop(
            reason=EmergencyReason.CONTENT_VIOLATION,
            triggered_by="content_filter",
            message=f"Content violation: {category} ({pattern})",
            scope=scope,
            metadata={"category": category, "pattern": pattern},
        )


# Process-wide EmergencyControls singleton, created lazily on first request,
# and the lock that serializes its creation.
_emergency_controls: EmergencyControls | None = None
_lock = asyncio.Lock()


async def get_emergency_controls() -> EmergencyControls:
    """Return the shared EmergencyControls instance, creating it on first use."""
    global _emergency_controls

    async with _lock:
        if _emergency_controls is not None:
            return _emergency_controls
        _emergency_controls = EmergencyControls()
        return _emergency_controls


async def emergency_stop_global(
    reason: str,
    triggered_by: str = "system",
) -> EmergencyEvent:
    """Trigger a manual emergency stop on the global scope.

    Despite its name, ``reason`` becomes the event's human-readable message.
    The recorded EmergencyReason is always MANUAL.
    """
    controls = await get_emergency_controls()
    event = await controls.emergency_stop(
        scope="global",
        reason=EmergencyReason.MANUAL,
        message=reason,
        triggered_by=triggered_by,
    )
    return event


async def check_emergency_allowed(scope: str | None = None) -> bool:
    """Report whether operations are allowed, without raising when blocked.

    Checks the global state and, if given, ``scope`` on the shared
    EmergencyControls instance.
    """
    shared_controls = await get_emergency_controls()
    allowed = await shared_controls.check_allowed(
        raise_if_blocked=False,
        scope=scope,
    )
    return allowed