""" Emergency Controls Emergency stop and pause functionality for agent safety. """ import asyncio import logging from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime from enum import Enum from typing import Any from ..exceptions import EmergencyStopError logger = logging.getLogger(__name__) class EmergencyState(str, Enum): """Emergency control states.""" NORMAL = "normal" PAUSED = "paused" STOPPED = "stopped" class EmergencyReason(str, Enum): """Reasons for emergency actions.""" MANUAL = "manual" SAFETY_VIOLATION = "safety_violation" BUDGET_EXCEEDED = "budget_exceeded" LOOP_DETECTED = "loop_detected" RATE_LIMIT = "rate_limit" CONTENT_VIOLATION = "content_violation" SYSTEM_ERROR = "system_error" EXTERNAL_TRIGGER = "external_trigger" @dataclass class EmergencyEvent: """Record of an emergency action.""" id: str state: EmergencyState reason: EmergencyReason triggered_by: str message: str scope: str # "global", "project:", "agent:" timestamp: datetime = field(default_factory=datetime.utcnow) metadata: dict[str, Any] = field(default_factory=dict) resolved_at: datetime | None = None resolved_by: str | None = None class EmergencyControls: """ Emergency stop and pause controls for agent safety. Features: - Global emergency stop - Per-project/agent emergency controls - Graceful pause with state preservation - Automatic triggers from safety violations - Manual override capabilities - Event history and audit trail """ def __init__( self, notification_handlers: list[Callable[..., Any]] | None = None, ) -> None: """ Initialize EmergencyControls. Args: notification_handlers: Handlers to call on emergency events """ self._global_state = EmergencyState.NORMAL self._scoped_states: dict[str, EmergencyState] = {} self._events: list[EmergencyEvent] = [] self._notification_handlers = notification_handlers or [] self._lock = asyncio.Lock() self._event_id_counter = 0 # Callbacks for state changes self._on_stop_callbacks: list[Callable[..., Any]] = [] self._on_pause_callbacks: list[Callable[..., Any]] = [] self._on_resume_callbacks: list[Callable[..., Any]] = [] def _generate_event_id(self) -> str: """Generate a unique event ID.""" self._event_id_counter += 1 return f"emerg-{self._event_id_counter:06d}" async def emergency_stop( self, reason: EmergencyReason, triggered_by: str, message: str, scope: str = "global", metadata: dict[str, Any] | None = None, ) -> EmergencyEvent: """ Trigger emergency stop. Args: reason: Reason for the stop triggered_by: Who/what triggered the stop message: Human-readable message scope: Scope of the stop (global, project:, agent:) metadata: Additional context Returns: The emergency event record """ async with self._lock: event = EmergencyEvent( id=self._generate_event_id(), state=EmergencyState.STOPPED, reason=reason, triggered_by=triggered_by, message=message, scope=scope, metadata=metadata or {}, ) if scope == "global": self._global_state = EmergencyState.STOPPED else: self._scoped_states[scope] = EmergencyState.STOPPED self._events.append(event) logger.critical( "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s", scope, reason.value, triggered_by, message, ) # Execute callbacks await self._execute_callbacks(self._on_stop_callbacks, event) await self._notify_handlers("emergency_stop", event) return event async def pause( self, reason: EmergencyReason, triggered_by: str, message: str, scope: str = "global", metadata: dict[str, Any] | None = None, ) -> EmergencyEvent: """ Pause operations (can be resumed). Args: reason: Reason for the pause triggered_by: Who/what triggered the pause message: Human-readable message scope: Scope of the pause metadata: Additional context Returns: The emergency event record """ async with self._lock: event = EmergencyEvent( id=self._generate_event_id(), state=EmergencyState.PAUSED, reason=reason, triggered_by=triggered_by, message=message, scope=scope, metadata=metadata or {}, ) if scope == "global": self._global_state = EmergencyState.PAUSED else: self._scoped_states[scope] = EmergencyState.PAUSED self._events.append(event) logger.warning( "PAUSE: scope=%s, reason=%s, by=%s - %s", scope, reason.value, triggered_by, message, ) await self._execute_callbacks(self._on_pause_callbacks, event) await self._notify_handlers("pause", event) return event async def resume( self, scope: str = "global", resumed_by: str = "system", message: str | None = None, ) -> bool: """ Resume operations from paused state. Args: scope: Scope to resume resumed_by: Who/what is resuming message: Optional message Returns: True if resumed, False if not in paused state """ async with self._lock: current_state = self._get_state(scope) if current_state == EmergencyState.STOPPED: logger.warning( "Cannot resume from STOPPED state: %s (requires reset)", scope, ) return False if current_state == EmergencyState.NORMAL: return True # Already normal # Find the pause event and mark as resolved for event in reversed(self._events): if event.scope == scope and event.state == EmergencyState.PAUSED: if event.resolved_at is None: event.resolved_at = datetime.utcnow() event.resolved_by = resumed_by break if scope == "global": self._global_state = EmergencyState.NORMAL else: self._scoped_states[scope] = EmergencyState.NORMAL logger.info( "RESUMED: scope=%s, by=%s%s", scope, resumed_by, f" - {message}" if message else "", ) await self._execute_callbacks( self._on_resume_callbacks, {"scope": scope, "resumed_by": resumed_by}, ) await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by}) return True async def reset( self, scope: str = "global", reset_by: str = "admin", message: str | None = None, ) -> bool: """ Reset from stopped state (requires explicit action). Args: scope: Scope to reset reset_by: Who is resetting (should be admin) message: Optional message Returns: True if reset successful """ async with self._lock: current_state = self._get_state(scope) if current_state == EmergencyState.NORMAL: return True # Find the stop event and mark as resolved for event in reversed(self._events): if event.scope == scope and event.state == EmergencyState.STOPPED: if event.resolved_at is None: event.resolved_at = datetime.utcnow() event.resolved_by = reset_by break if scope == "global": self._global_state = EmergencyState.NORMAL else: self._scoped_states[scope] = EmergencyState.NORMAL logger.warning( "EMERGENCY RESET: scope=%s, by=%s%s", scope, reset_by, f" - {message}" if message else "", ) await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by}) return True async def check_allowed( self, scope: str | None = None, raise_if_blocked: bool = True, ) -> bool: """ Check if operations are allowed. Args: scope: Specific scope to check (also checks global) raise_if_blocked: Raise exception if blocked Returns: True if operations are allowed Raises: EmergencyStopError: If blocked and raise_if_blocked=True """ async with self._lock: # Always check global state if self._global_state != EmergencyState.NORMAL: if raise_if_blocked: raise EmergencyStopError( f"Global emergency state: {self._global_state.value}", stop_type=self._get_last_reason("global") or "emergency", triggered_by=self._get_last_triggered_by("global"), ) return False # Check specific scope if scope and scope in self._scoped_states: state = self._scoped_states[scope] if state != EmergencyState.NORMAL: if raise_if_blocked: raise EmergencyStopError( f"Emergency state for {scope}: {state.value}", stop_type=self._get_last_reason(scope) or "emergency", triggered_by=self._get_last_triggered_by(scope), details={"scope": scope}, ) return False return True def _get_state(self, scope: str) -> EmergencyState: """Get state for a scope.""" if scope == "global": return self._global_state return self._scoped_states.get(scope, EmergencyState.NORMAL) def _get_last_reason(self, scope: str) -> str: """Get reason from last event for scope.""" for event in reversed(self._events): if event.scope == scope and event.resolved_at is None: return event.reason.value return "unknown" def _get_last_triggered_by(self, scope: str) -> str: """Get triggered_by from last event for scope.""" for event in reversed(self._events): if event.scope == scope and event.resolved_at is None: return event.triggered_by return "unknown" async def get_state(self, scope: str = "global") -> EmergencyState: """Get current state for a scope.""" async with self._lock: return self._get_state(scope) async def get_all_states(self) -> dict[str, EmergencyState]: """Get all current states.""" async with self._lock: states = {"global": self._global_state} states.update(self._scoped_states) return states async def get_active_events(self) -> list[EmergencyEvent]: """Get all unresolved emergency events.""" async with self._lock: return [e for e in self._events if e.resolved_at is None] async def get_event_history( self, scope: str | None = None, limit: int = 100, ) -> list[EmergencyEvent]: """Get emergency event history.""" async with self._lock: events = list(self._events) if scope: events = [e for e in events if e.scope == scope] return events[-limit:] def on_stop(self, callback: Callable[..., Any]) -> None: """Register callback for stop events.""" self._on_stop_callbacks.append(callback) def on_pause(self, callback: Callable[..., Any]) -> None: """Register callback for pause events.""" self._on_pause_callbacks.append(callback) def on_resume(self, callback: Callable[..., Any]) -> None: """Register callback for resume events.""" self._on_resume_callbacks.append(callback) def add_notification_handler(self, handler: Callable[..., Any]) -> None: """Add a notification handler.""" self._notification_handlers.append(handler) async def _execute_callbacks( self, callbacks: list[Callable[..., Any]], data: Any, ) -> None: """Execute callbacks safely.""" for callback in callbacks: try: if asyncio.iscoroutinefunction(callback): await callback(data) else: callback(data) except Exception as e: logger.error("Error in callback: %s", e) async def _notify_handlers(self, event_type: str, data: Any) -> None: """Notify all handlers of an event.""" for handler in self._notification_handlers: try: if asyncio.iscoroutinefunction(handler): await handler(event_type, data) else: handler(event_type, data) except Exception as e: logger.error("Error in notification handler: %s", e) class EmergencyTrigger: """ Automatic emergency triggers based on conditions. """ def __init__(self, controls: EmergencyControls) -> None: """ Initialize EmergencyTrigger. Args: controls: EmergencyControls instance to trigger """ self._controls = controls async def trigger_on_safety_violation( self, violation_type: str, details: dict[str, Any], scope: str = "global", ) -> EmergencyEvent: """ Trigger emergency from safety violation. Args: violation_type: Type of violation details: Violation details scope: Scope for the emergency Returns: Emergency event """ return await self._controls.emergency_stop( reason=EmergencyReason.SAFETY_VIOLATION, triggered_by="safety_system", message=f"Safety violation: {violation_type}", scope=scope, metadata={"violation_type": violation_type, **details}, ) async def trigger_on_budget_exceeded( self, budget_type: str, current: float, limit: float, scope: str = "global", ) -> EmergencyEvent: """ Trigger emergency from budget exceeded. Args: budget_type: Type of budget current: Current usage limit: Budget limit scope: Scope for the emergency Returns: Emergency event """ return await self._controls.pause( reason=EmergencyReason.BUDGET_EXCEEDED, triggered_by="budget_controller", message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})", scope=scope, metadata={"budget_type": budget_type, "current": current, "limit": limit}, ) async def trigger_on_loop_detected( self, loop_type: str, agent_id: str, details: dict[str, Any], ) -> EmergencyEvent: """ Trigger emergency from loop detection. Args: loop_type: Type of loop agent_id: Agent that's looping details: Loop details Returns: Emergency event """ return await self._controls.pause( reason=EmergencyReason.LOOP_DETECTED, triggered_by="loop_detector", message=f"Loop detected: {loop_type} in agent {agent_id}", scope=f"agent:{agent_id}", metadata={"loop_type": loop_type, "agent_id": agent_id, **details}, ) async def trigger_on_content_violation( self, category: str, pattern: str, scope: str = "global", ) -> EmergencyEvent: """ Trigger emergency from content violation. Args: category: Content category pattern: Pattern that matched scope: Scope for the emergency Returns: Emergency event """ return await self._controls.emergency_stop( reason=EmergencyReason.CONTENT_VIOLATION, triggered_by="content_filter", message=f"Content violation: {category} ({pattern})", scope=scope, metadata={"category": category, "pattern": pattern}, ) # Singleton instance _emergency_controls: EmergencyControls | None = None _lock = asyncio.Lock() async def get_emergency_controls() -> EmergencyControls: """Get the singleton EmergencyControls instance.""" global _emergency_controls async with _lock: if _emergency_controls is None: _emergency_controls = EmergencyControls() return _emergency_controls async def emergency_stop_global( reason: str, triggered_by: str = "system", ) -> EmergencyEvent: """Quick global emergency stop.""" controls = await get_emergency_controls() return await controls.emergency_stop( reason=EmergencyReason.MANUAL, triggered_by=triggered_by, message=reason, scope="global", ) async def check_emergency_allowed(scope: str | None = None) -> bool: """Quick check if operations are allowed.""" controls = await get_emergency_controls() return await controls.check_allowed(scope=scope, raise_if_blocked=False)