forked from cardosofelipe/pragma-stack
- Add tests for models: ActionMetadata, ActionRequest, ActionResult, ValidationRule, BudgetStatus, RateLimitConfig, ApprovalRequest/Response, Checkpoint, RollbackResult, AuditEvent, SafetyPolicy, GuardianResult
- Add tests for validation: ActionValidator rules, priorities, patterns, bypass mode, batch validation, rule creation helpers
- Add tests for loops: LoopDetector exact/semantic/oscillation detection, LoopBreaker throttle/backoff, history management
- Add tests for content filter: PII filtering (email, phone, SSN, credit card), secret blocking (API keys, GitHub tokens, private keys), custom patterns, scan without filtering, dict filtering
- Add tests for emergency controls: state management, pause/resume/reset, scoped emergency stops, callbacks, EmergencyTrigger events
- Fix exception kwargs in content filter and emergency controls to match exception class signatures

All 108 tests passing, with lint and type checks clean.
595 lines
18 KiB
Python
595 lines
18 KiB
Python
"""
|
|
Emergency Controls
|
|
|
|
Emergency stop and pause functionality for agent safety.
|
|
"""
|
|
|
|
import asyncio
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any

from ..exceptions import EmergencyStopError
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmergencyState(str, Enum):
    """
    Operational state of a scope under emergency control.

    Members are also plain strings, so they compare equal to their values.
    """

    # No restriction is in effect.
    NORMAL = "normal"
    # Temporarily halted; lifted again via ``EmergencyControls.resume``.
    PAUSED = "paused"
    # Hard stop; ``EmergencyControls.resume`` refuses it, ``reset`` clears it.
    STOPPED = "stopped"
|
|
|
|
|
|
class EmergencyReason(str, Enum):
    """
    Why an emergency action (stop or pause) was taken.

    Members are also plain strings, so they compare equal to their values.
    """

    # Triggered explicitly by an operator or caller.
    MANUAL = "manual"

    # Raised by automated safety subsystems.
    SAFETY_VIOLATION = "safety_violation"
    BUDGET_EXCEEDED = "budget_exceeded"
    LOOP_DETECTED = "loop_detected"
    RATE_LIMIT = "rate_limit"
    CONTENT_VIOLATION = "content_violation"

    # Infrastructure failures and triggers originating outside the system.
    SYSTEM_ERROR = "system_error"
    EXTERNAL_TRIGGER = "external_trigger"
|
|
|
|
|
|
def _naive_utcnow() -> datetime:
    """
    Return the current UTC time as a naive ``datetime``.

    Produces the same values as the deprecated ``datetime.utcnow()`` (naive,
    in UTC) so stored timestamps stay comparable with existing ones, without
    emitting the DeprecationWarning introduced in Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class EmergencyEvent:
    """
    Record of a single emergency action (stop or pause).

    ``resolved_at`` / ``resolved_by`` stay ``None`` while the event is still
    active and are filled in when the event is resolved.
    """

    # Unique identifier of the event.
    id: str
    # State the action requested for the scope.
    state: EmergencyState
    # Why the action was taken.
    reason: EmergencyReason
    # Who or what triggered the action.
    triggered_by: str
    # Human-readable description.
    message: str
    scope: str  # "global", "project:<id>", "agent:<id>"
    # Creation time as naive UTC.
    timestamp: datetime = field(default_factory=_naive_utcnow)
    # Free-form additional context; a fresh dict per instance by default.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Resolution bookkeeping (None while unresolved).
    resolved_at: datetime | None = None
    resolved_by: str | None = None
|
|
|
|
|
|
class EmergencyControls:
    """
    Emergency stop and pause controls for agent safety.

    Features:
    - Global emergency stop
    - Per-project/agent emergency controls
    - Graceful pause with state preservation
    - Automatic triggers from safety violations
    - Manual override capabilities
    - Event history and audit trail

    State is tracked per scope ("global", "project:<id>", "agent:<id>").
    A PAUSED scope is lifted with ``resume()``; a STOPPED scope only with
    ``reset()``. A pause never downgrades an active stop.

    All state mutations happen under an ``asyncio.Lock``. Logging, callbacks
    and notification handlers run after the lock is released, because the
    lock is not reentrant and a callback may call back into this object.
    """

    def __init__(
        self,
        notification_handlers: list[Callable[..., Any]] | None = None,
    ) -> None:
        """
        Initialize EmergencyControls.

        Args:
            notification_handlers: Handlers to call on emergency events.
                Each is called as ``handler(event_type, data)``. The list is
                copied, so later registrations do not mutate the caller's list.
        """
        self._global_state = EmergencyState.NORMAL
        self._scoped_states: dict[str, EmergencyState] = {}
        self._events: list[EmergencyEvent] = []
        self._notification_handlers: list[Callable[..., Any]] = list(
            notification_handlers or []
        )
        self._lock = asyncio.Lock()
        self._event_id_counter = 0

        # Callbacks for state changes
        self._on_stop_callbacks: list[Callable[..., Any]] = []
        self._on_pause_callbacks: list[Callable[..., Any]] = []
        self._on_resume_callbacks: list[Callable[..., Any]] = []

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _now() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same values as the deprecated ``datetime.utcnow()``, without the
        DeprecationWarning emitted since Python 3.12.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _generate_event_id(self) -> str:
        """Generate a unique (per instance) sequential event ID."""
        self._event_id_counter += 1
        return f"emerg-{self._event_id_counter:06d}"

    def _get_state(self, scope: str) -> EmergencyState:
        """Get state for a scope; unknown scopes are NORMAL."""
        if scope == "global":
            return self._global_state
        return self._scoped_states.get(scope, EmergencyState.NORMAL)

    def _set_state(self, scope: str, state: EmergencyState) -> None:
        """Store ``state`` for ``scope`` (caller must hold the lock)."""
        if scope == "global":
            self._global_state = state
        else:
            self._scoped_states[scope] = state

    def _resolve_events(
        self,
        scope: str,
        resolved_by: str,
        state: EmergencyState | None = None,
    ) -> None:
        """
        Mark unresolved events of ``scope`` as resolved (lock must be held).

        Args:
            scope: Scope whose events are resolved
            resolved_by: Recorded as ``resolved_by`` on each event
            state: If given, only events requesting this state are resolved
        """
        resolved_at = self._now()
        for event in self._events:
            if event.scope != scope or event.resolved_at is not None:
                continue
            if state is not None and event.state != state:
                continue
            event.resolved_at = resolved_at
            event.resolved_by = resolved_by

    def _latest_unresolved(self, scope: str) -> EmergencyEvent | None:
        """
        Return the most recent unresolved event for ``scope``.

        Events whose state matches the scope's current state are preferred,
        so a STOPPED scope reports the stop rather than a later pause request.
        """
        current = self._get_state(scope)
        fallback: EmergencyEvent | None = None
        for event in reversed(self._events):
            if event.scope != scope or event.resolved_at is not None:
                continue
            if event.state == current:
                return event
            if fallback is None:
                fallback = event
        return fallback

    def _get_last_reason(self, scope: str) -> str:
        """Get reason from last unresolved event for scope ("unknown" if none)."""
        event = self._latest_unresolved(scope)
        return event.reason.value if event is not None else "unknown"

    def _get_last_triggered_by(self, scope: str) -> str:
        """Get triggered_by from last unresolved event for scope ("unknown" if none)."""
        event = self._latest_unresolved(scope)
        return event.triggered_by if event is not None else "unknown"

    async def _record_event(
        self,
        state: EmergencyState,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str,
        metadata: dict[str, Any] | None,
    ) -> EmergencyEvent:
        """Create and store an event and apply its state, all under the lock."""
        async with self._lock:
            event = EmergencyEvent(
                id=self._generate_event_id(),
                state=state,
                reason=reason,
                triggered_by=triggered_by,
                message=message,
                scope=scope,
                metadata=metadata or {},
            )

            downgrade = (
                state == EmergencyState.PAUSED
                and self._get_state(scope) == EmergencyState.STOPPED
            )
            if downgrade:
                # Letting a pause overwrite a stop would allow resume() to
                # lift a stop that is supposed to require reset().
                logger.warning(
                    "Pause requested for STOPPED scope %s; state stays STOPPED",
                    scope,
                )
            else:
                self._set_state(scope, state)

            self._events.append(event)
        return event

    @staticmethod
    async def _invoke(func: Callable[..., Any], *args: Any) -> None:
        """Call ``func`` and await the result if it is awaitable.

        Checking the result (rather than the callable) also covers sync
        callables that return an awaitable, e.g. a lambda wrapping a coroutine.
        """
        result = func(*args)
        if inspect.isawaitable(result):
            await result

    async def _execute_callbacks(
        self,
        callbacks: list[Callable[..., Any]],
        data: Any,
    ) -> None:
        """Execute callbacks safely; a failing callback never stops the rest."""
        # Iterate over a snapshot so a callback may register further callbacks.
        for callback in tuple(callbacks):
            try:
                await self._invoke(callback, data)
            except Exception as e:
                logger.exception("Error in callback: %s", e)

    async def _notify_handlers(self, event_type: str, data: Any) -> None:
        """Notify all handlers of an event; handler errors are logged only."""
        for handler in tuple(self._notification_handlers):
            try:
                await self._invoke(handler, event_type, data)
            except Exception as e:
                logger.exception("Error in notification handler: %s", e)

    # ------------------------------------------------------------------
    # State transitions
    # ------------------------------------------------------------------

    async def emergency_stop(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Trigger emergency stop.

        Args:
            reason: Reason for the stop
            triggered_by: Who/what triggered the stop
            message: Human-readable message
            scope: Scope of the stop (global, project:<id>, agent:<id>)
            metadata: Additional context

        Returns:
            The emergency event record
        """
        event = await self._record_event(
            EmergencyState.STOPPED, reason, triggered_by, message, scope, metadata
        )

        logger.critical(
            "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        # Execute callbacks
        await self._execute_callbacks(self._on_stop_callbacks, event)
        await self._notify_handlers("emergency_stop", event)

        return event

    async def pause(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Pause operations (can be resumed).

        If the scope is already STOPPED the event is still recorded and
        announced, but the scope stays STOPPED.

        Args:
            reason: Reason for the pause
            triggered_by: Who/what triggered the pause
            message: Human-readable message
            scope: Scope of the pause
            metadata: Additional context

        Returns:
            The emergency event record
        """
        event = await self._record_event(
            EmergencyState.PAUSED, reason, triggered_by, message, scope, metadata
        )

        logger.warning(
            "PAUSE: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        await self._execute_callbacks(self._on_pause_callbacks, event)
        await self._notify_handlers("pause", event)

        return event

    async def resume(
        self,
        scope: str = "global",
        resumed_by: str = "system",
        message: str | None = None,
    ) -> bool:
        """
        Resume operations from paused state.

        Args:
            scope: Scope to resume
            resumed_by: Who/what is resuming
            message: Optional message

        Returns:
            True if the scope is NORMAL afterwards (including when it already
            was), False if the scope is STOPPED and therefore needs ``reset()``
        """
        async with self._lock:
            current_state = self._get_state(scope)

            if current_state == EmergencyState.STOPPED:
                logger.warning(
                    "Cannot resume from STOPPED state: %s (requires reset)",
                    scope,
                )
                return False

            if current_state == EmergencyState.NORMAL:
                return True  # Already normal

            # Resolve every outstanding pause, not just the latest one, so
            # repeated pauses do not leave stale "active" events behind.
            self._resolve_events(scope, resumed_by, EmergencyState.PAUSED)
            self._set_state(scope, EmergencyState.NORMAL)

        logger.info(
            "RESUMED: scope=%s, by=%s%s",
            scope,
            resumed_by,
            f" - {message}" if message else "",
        )

        # Callbacks and handlers each get their own payload dict.
        await self._execute_callbacks(
            self._on_resume_callbacks,
            {"scope": scope, "resumed_by": resumed_by},
        )
        await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})

        return True

    async def reset(
        self,
        scope: str = "global",
        reset_by: str = "admin",
        message: str | None = None,
    ) -> bool:
        """
        Reset from stopped state (requires explicit action).

        Any non-NORMAL scope (STOPPED or PAUSED) is returned to NORMAL and all
        of its unresolved events are marked resolved.

        Args:
            scope: Scope to reset
            reset_by: Who is resetting (should be admin)
            message: Optional message

        Returns:
            True if reset successful
        """
        async with self._lock:
            if self._get_state(scope) == EmergencyState.NORMAL:
                return True

            # The scope is fully cleared, so nothing may remain active for it.
            self._resolve_events(scope, reset_by)
            self._set_state(scope, EmergencyState.NORMAL)

        logger.warning(
            "EMERGENCY RESET: scope=%s, by=%s%s",
            scope,
            reset_by,
            f" - {message}" if message else "",
        )

        await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})

        return True

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    async def check_allowed(
        self,
        scope: str | None = None,
        raise_if_blocked: bool = True,
    ) -> bool:
        """
        Check if operations are allowed.

        Args:
            scope: Specific scope to check (also checks global)
            raise_if_blocked: Raise exception if blocked

        Returns:
            True if operations are allowed

        Raises:
            EmergencyStopError: If blocked and raise_if_blocked=True
        """
        async with self._lock:
            # Always check global state
            if self._global_state != EmergencyState.NORMAL:
                if raise_if_blocked:
                    raise EmergencyStopError(
                        f"Global emergency state: {self._global_state.value}",
                        stop_type=self._get_last_reason("global"),
                        triggered_by=self._get_last_triggered_by("global"),
                    )
                return False

            # Check specific scope
            if scope:
                state = self._scoped_states.get(scope, EmergencyState.NORMAL)
                if state != EmergencyState.NORMAL:
                    if raise_if_blocked:
                        raise EmergencyStopError(
                            f"Emergency state for {scope}: {state.value}",
                            stop_type=self._get_last_reason(scope),
                            triggered_by=self._get_last_triggered_by(scope),
                            details={"scope": scope},
                        )
                    return False

            return True

    async def get_state(self, scope: str = "global") -> EmergencyState:
        """Get current state for a scope."""
        async with self._lock:
            return self._get_state(scope)

    async def get_all_states(self) -> dict[str, EmergencyState]:
        """Get all current states, keyed by scope (always includes "global")."""
        async with self._lock:
            return {"global": self._global_state, **self._scoped_states}

    async def get_active_events(self) -> list[EmergencyEvent]:
        """Get all unresolved emergency events."""
        async with self._lock:
            return [e for e in self._events if e.resolved_at is None]

    async def get_event_history(
        self,
        scope: str | None = None,
        limit: int = 100,
    ) -> list[EmergencyEvent]:
        """
        Get emergency event history, oldest first.

        Args:
            scope: If given (non-empty), only events of this scope
            limit: Maximum number of most recent events to return

        Returns:
            Up to ``limit`` events; empty when ``limit`` is zero or negative
        """
        # events[-0:] would be the whole list, so handle non-positive limits
        # explicitly instead of slicing.
        if limit <= 0:
            return []

        async with self._lock:
            events = [e for e in self._events if not scope or e.scope == scope]

        return events[-limit:]

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def on_stop(self, callback: Callable[..., Any]) -> None:
        """Register callback for stop events (called with the event)."""
        self._on_stop_callbacks.append(callback)

    def on_pause(self, callback: Callable[..., Any]) -> None:
        """Register callback for pause events (called with the event)."""
        self._on_pause_callbacks.append(callback)

    def on_resume(self, callback: Callable[..., Any]) -> None:
        """Register callback for resume events (called with a payload dict)."""
        self._on_resume_callbacks.append(callback)

    def add_notification_handler(self, handler: Callable[..., Any]) -> None:
        """Add a notification handler, called as ``handler(event_type, data)``."""
        self._notification_handlers.append(handler)
|
|
|
|
|
|
class EmergencyTrigger:
    """
    Translates detected conditions into emergency actions.

    Each ``trigger_on_*`` method builds a message and metadata for one kind
    of condition and forwards it to the wrapped EmergencyControls instance.
    """

    def __init__(self, controls: EmergencyControls) -> None:
        """
        Initialize EmergencyTrigger.

        Args:
            controls: EmergencyControls instance to trigger
        """
        self._controls = controls

    async def trigger_on_safety_violation(
        self,
        violation_type: str,
        details: dict[str, Any],
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Issue an emergency stop for a safety violation.

        Args:
            violation_type: Type of violation
            details: Violation details, merged into the event metadata
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        # Entries in ``details`` are unpacked last, as in the metadata literal.
        context = {"violation_type": violation_type, **details}
        event = await self._controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety_system",
            message=f"Safety violation: {violation_type}",
            scope=scope,
            metadata=context,
        )
        return event

    async def trigger_on_budget_exceeded(
        self,
        budget_type: str,
        current: float,
        limit: float,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Pause a scope whose budget has been exceeded.

        Args:
            budget_type: Type of budget
            current: Current usage
            limit: Budget limit
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        description = f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})"
        context = {"budget_type": budget_type, "current": current, "limit": limit}
        event = await self._controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message=description,
            scope=scope,
            metadata=context,
        )
        return event

    async def trigger_on_loop_detected(
        self,
        loop_type: str,
        agent_id: str,
        details: dict[str, Any],
    ) -> EmergencyEvent:
        """
        Pause the agent in which a loop was detected.

        The pause is scoped to ``agent:<agent_id>``.

        Args:
            loop_type: Type of loop
            agent_id: Agent that's looping
            details: Loop details, merged into the event metadata

        Returns:
            Emergency event
        """
        context = {"loop_type": loop_type, "agent_id": agent_id, **details}
        event = await self._controls.pause(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="loop_detector",
            message=f"Loop detected: {loop_type} in agent {agent_id}",
            scope=f"agent:{agent_id}",
            metadata=context,
        )
        return event

    async def trigger_on_content_violation(
        self,
        category: str,
        pattern: str,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Issue an emergency stop for a content violation.

        Args:
            category: Content category
            pattern: Pattern that matched
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        context = {"category": category, "pattern": pattern}
        event = await self._controls.emergency_stop(
            reason=EmergencyReason.CONTENT_VIOLATION,
            triggered_by="content_filter",
            message=f"Content violation: {category} ({pattern})",
            scope=scope,
            metadata=context,
        )
        return event
|
|
|
|
|
|
# Singleton instance, created lazily by get_emergency_controls().
_emergency_controls: EmergencyControls | None = None
# Guards first-time creation of the singleton.
_lock = asyncio.Lock()


async def get_emergency_controls() -> EmergencyControls:
    """Return the shared EmergencyControls instance, creating it on first use."""
    global _emergency_controls

    async with _lock:
        instance = _emergency_controls
        if instance is None:
            instance = EmergencyControls()
            _emergency_controls = instance
    return instance
|
|
|
|
|
|
async def emergency_stop_global(
    reason: str,
    triggered_by: str = "system",
) -> EmergencyEvent:
    """
    Issue a global emergency stop on the shared controls instance.

    Args:
        reason: Free-text explanation; stored as the event message. The
            event's reason is always EmergencyReason.MANUAL.
        triggered_by: Who/what triggered the stop

    Returns:
        The emergency event record
    """
    shared = await get_emergency_controls()
    event = await shared.emergency_stop(
        reason=EmergencyReason.MANUAL,
        triggered_by=triggered_by,
        message=reason,
        scope="global",
    )
    return event
|
|
|
|
|
|
async def check_emergency_allowed(scope: str | None = None) -> bool:
    """
    Report whether operations are currently allowed, without raising.

    Args:
        scope: Optional scope to check in addition to the global state

    Returns:
        True if operations are allowed, False if blocked
    """
    shared = await get_emergency_controls()
    allowed = await shared.check_allowed(scope=scope, raise_if_blocked=False)
    return allowed
|