From ef659cd72d790ea32f653c6cd4f1376a4248b2a1 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sat, 3 Jan 2026 11:36:24 +0100 Subject: [PATCH] feat(safety): add Phase C advanced controls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add rollback manager with file checkpointing and transaction context - Add HITL manager with approval queues and notification handlers - Add content filter with PII, secrets, and injection detection - Add emergency controls with stop/pause/resume capabilities - Update SafetyConfig with checkpoint_dir setting Issue #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/services/safety/config.py | 4 + .../app/services/safety/content/__init__.py | 24 +- backend/app/services/safety/content/filter.py | 532 ++++++++++++++++ .../app/services/safety/emergency/__init__.py | 24 +- .../app/services/safety/emergency/controls.py | 594 ++++++++++++++++++ backend/app/services/safety/hitl/__init__.py | 6 +- backend/app/services/safety/hitl/manager.py | 449 +++++++++++++ .../app/services/safety/rollback/__init__.py | 6 +- .../app/services/safety/rollback/manager.py | 418 ++++++++++++ 9 files changed, 2053 insertions(+), 4 deletions(-) create mode 100644 backend/app/services/safety/content/filter.py create mode 100644 backend/app/services/safety/emergency/controls.py create mode 100644 backend/app/services/safety/hitl/manager.py create mode 100644 backend/app/services/safety/rollback/manager.py diff --git a/backend/app/services/safety/config.py b/backend/app/services/safety/config.py index 2488895..a4fa90a 100644 --- a/backend/app/services/safety/config.py +++ b/backend/app/services/safety/config.py @@ -74,6 +74,10 @@ class SafetyConfig(BaseSettings): # Rollback settings rollback_enabled: bool = Field(True, description="Enable rollback capability") + checkpoint_dir: str = Field( + "/tmp/syndarix_checkpoints", # noqa: S108 + description="Directory 
for checkpoint storage", + ) checkpoint_retention_hours: int = Field(24, description="Checkpoint retention") auto_checkpoint_destructive: bool = Field( True, description="Auto-checkpoint destructive actions" diff --git a/backend/app/services/safety/content/__init__.py b/backend/app/services/safety/content/__init__.py index 9f4729c..4cfa43a 100644 --- a/backend/app/services/safety/content/__init__.py +++ b/backend/app/services/safety/content/__init__.py @@ -1 +1,23 @@ -"""${dir} module.""" +"""Content filtering for safety.""" + +from .filter import ( + ContentCategory, + ContentFilter, + FilterAction, + FilterMatch, + FilterPattern, + FilterResult, + filter_content, + scan_for_secrets, +) + +__all__ = [ + "ContentCategory", + "ContentFilter", + "FilterAction", + "FilterMatch", + "FilterPattern", + "FilterResult", + "filter_content", + "scan_for_secrets", +] diff --git a/backend/app/services/safety/content/filter.py b/backend/app/services/safety/content/filter.py new file mode 100644 index 0000000..8f3b795 --- /dev/null +++ b/backend/app/services/safety/content/filter.py @@ -0,0 +1,532 @@ +""" +Content Filter + +Filters and sanitizes content for safety, including PII detection and secret scanning. 
+""" + +import asyncio +import logging +import re +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, ClassVar + +from ..exceptions import ContentFilterError + +logger = logging.getLogger(__name__) + + +class ContentCategory(str, Enum): + """Categories of sensitive content.""" + + PII = "pii" + SECRETS = "secrets" + CREDENTIALS = "credentials" + FINANCIAL = "financial" + HEALTH = "health" + PROFANITY = "profanity" + INJECTION = "injection" + CUSTOM = "custom" + + +class FilterAction(str, Enum): + """Actions to take on detected content.""" + + ALLOW = "allow" + REDACT = "redact" + BLOCK = "block" + WARN = "warn" + + +@dataclass +class FilterMatch: + """A match found by a filter.""" + + category: ContentCategory + pattern_name: str + matched_text: str + start_pos: int + end_pos: int + confidence: float = 1.0 + redacted_text: str | None = None + + +@dataclass +class FilterResult: + """Result of content filtering.""" + + original_content: str + filtered_content: str + matches: list[FilterMatch] = field(default_factory=list) + blocked: bool = False + block_reason: str | None = None + warnings: list[str] = field(default_factory=list) + + @property + def has_sensitive_content(self) -> bool: + """Check if any sensitive content was found.""" + return len(self.matches) > 0 + + +@dataclass +class FilterPattern: + """A pattern for detecting sensitive content.""" + + name: str + category: ContentCategory + pattern: str # Regex pattern + action: FilterAction = FilterAction.REDACT + replacement: str = "[REDACTED]" + confidence: float = 1.0 + enabled: bool = True + + def __post_init__(self) -> None: + """Compile the regex pattern.""" + self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE) + + def find_matches(self, content: str) -> list[FilterMatch]: + """Find all matches in content.""" + matches = [] + for match in self._compiled.finditer(content): + matches.append( + FilterMatch( + category=self.category, + 
pattern_name=self.name, + matched_text=match.group(), + start_pos=match.start(), + end_pos=match.end(), + confidence=self.confidence, + redacted_text=self.replacement, + ) + ) + return matches + + +class ContentFilter: + """ + Filters content for sensitive information. + + Features: + - PII detection (emails, phones, SSN, etc.) + - Secret scanning (API keys, tokens, passwords) + - Credential detection + - Injection attack prevention + - Custom pattern support + - Configurable actions (allow, redact, block, warn) + """ + + # Default patterns for common sensitive data + DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [ + # PII Patterns + FilterPattern( + name="email", + category=ContentCategory.PII, + pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", # no "|" in the TLD class: inside [...] it is a literal pipe, not alternation + action=FilterAction.REDACT, + replacement="[EMAIL]", + ), + FilterPattern( + name="phone_us", + category=ContentCategory.PII, + pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b", + action=FilterAction.REDACT, + replacement="[PHONE]", + ), + FilterPattern( + name="ssn", + category=ContentCategory.PII, + pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b", + action=FilterAction.REDACT, + replacement="[SSN]", + ), + FilterPattern( + name="credit_card", + category=ContentCategory.FINANCIAL, + pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b", + action=FilterAction.REDACT, + replacement="[CREDIT_CARD]", + ), + FilterPattern( + name="ip_address", + category=ContentCategory.PII, + pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b", + action=FilterAction.WARN, + replacement="[IP]", + confidence=0.8, + ), + # Secret Patterns + FilterPattern( + name="api_key_generic", + category=ContentCategory.SECRETS, + pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?", + action=FilterAction.BLOCK, + replacement="[API_KEY]", + ), + FilterPattern( + name="aws_access_key", + category=ContentCategory.SECRETS, + pattern=r"\bAKIA[0-9A-Z]{16}\b", + action=FilterAction.BLOCK, + replacement="[AWS_KEY]", + ), +
FilterPattern( + name="aws_secret_key", + category=ContentCategory.SECRETS, + pattern=r"\b[A-Za-z0-9/+=]{40}\b", + action=FilterAction.WARN, + replacement="[AWS_SECRET]", + confidence=0.6, # Lower confidence - might be false positive + ), + FilterPattern( + name="github_token", + category=ContentCategory.SECRETS, + pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b", + action=FilterAction.BLOCK, + replacement="[GITHUB_TOKEN]", + ), + FilterPattern( + name="jwt_token", + category=ContentCategory.SECRETS, + pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b", + action=FilterAction.BLOCK, + replacement="[JWT]", + ), + # Credential Patterns + FilterPattern( + name="password_in_url", + category=ContentCategory.CREDENTIALS, + pattern=r"://[^\s:/@]+:([^\s/@]+)@", # URL userinfo only: the match may not cross whitespace, "/" or an earlier "@", otherwise any text with a URL, a later ":" and a later "@" is blocked + action=FilterAction.BLOCK, + replacement="://[REDACTED]@", + ), + FilterPattern( + name="password_assignment", + category=ContentCategory.CREDENTIALS, + pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?", + action=FilterAction.REDACT, + replacement="[PASSWORD]", + ), + FilterPattern( + name="private_key", + category=ContentCategory.SECRETS, + pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----", + action=FilterAction.BLOCK, + replacement="[PRIVATE_KEY]", + ), + # Injection Patterns + FilterPattern( + name="sql_injection", + category=ContentCategory.INJECTION, + pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))", + action=FilterAction.BLOCK, + replacement="[BLOCKED]", + ), + FilterPattern( + name="command_injection", + category=ContentCategory.INJECTION, + pattern=r"[;&|`$]|\$\(|\$\{", + action=FilterAction.WARN, + replacement="[CMD]", + confidence=0.5, # Low confidence - common in code + ), + ] + + def __init__( + self, + enable_pii_filter: bool = True, + enable_secret_filter: bool = True, + enable_injection_filter: bool = True, + custom_patterns: list[FilterPattern] | None = None, + default_action: FilterAction = FilterAction.REDACT, +
) -> None: + """ + Initialize the ContentFilter. + + Args: + enable_pii_filter: Enable PII detection + enable_secret_filter: Enable secret scanning + enable_injection_filter: Enable injection detection + custom_patterns: Additional custom patterns + default_action: Default action for matches + """ + self._patterns: list[FilterPattern] = [] + self._default_action = default_action + self._lock = asyncio.Lock() + + # Load default patterns based on configuration + for pattern in self.DEFAULT_PATTERNS: + if pattern.category == ContentCategory.PII and not enable_pii_filter: + continue + if pattern.category == ContentCategory.SECRETS and not enable_secret_filter: + continue + if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter: + continue + if pattern.category == ContentCategory.INJECTION and not enable_injection_filter: + continue + self._patterns.append(pattern) + + # Add custom patterns + if custom_patterns: + self._patterns.extend(custom_patterns) + + logger.info("ContentFilter initialized with %d patterns", len(self._patterns)) + + def add_pattern(self, pattern: FilterPattern) -> None: + """Add a custom pattern.""" + self._patterns.append(pattern) + logger.debug("Added pattern: %s", pattern.name) + + def remove_pattern(self, pattern_name: str) -> bool: + """Remove a pattern by name.""" + for i, pattern in enumerate(self._patterns): + if pattern.name == pattern_name: + del self._patterns[i] + logger.debug("Removed pattern: %s", pattern_name) + return True + return False + + def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool: + """Enable or disable a pattern.""" + for pattern in self._patterns: + if pattern.name == pattern_name: + pattern.enabled = enabled + return True + return False + + async def filter( + self, + content: str, + context: dict[str, Any] | None = None, + raise_on_block: bool = False, + ) -> FilterResult: + """ + Filter content for sensitive information. 
+ + Args: + content: Content to filter + context: Optional context for filtering decisions + raise_on_block: Raise exception if content is blocked + + Returns: + FilterResult with filtered content and match details + + Raises: + ContentFilterError: If content is blocked and raise_on_block=True + """ + all_matches: list[FilterMatch] = [] + blocked = False + block_reason: str | None = None + warnings: list[str] = [] + + # Find all matches + for pattern in self._patterns: + if not pattern.enabled: + continue + + matches = pattern.find_matches(content) + for match in matches: + all_matches.append(match) + + if pattern.action == FilterAction.BLOCK: + blocked = True + block_reason = f"Blocked by pattern: {pattern.name}" + elif pattern.action == FilterAction.WARN: + warnings.append( + f"Warning: {pattern.name} detected at position {match.start_pos}" + ) + + # Walk matches left to right; on equal starts the longest match supplies the label + ordered = sorted(all_matches, key=lambda m: (m.start_pos, -m.end_pos)) + + # Apply redactions against the ORIGINAL offsets, merging overlapping spans so no offset goes stale + pieces: list[str] = [] + cursor = 0 + for match in ordered: + matched_pattern = self._get_pattern(match.pattern_name) + if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK): + if match.start_pos >= cursor: # not overlapping: keep the text before it, then emit the label + pieces.extend((content[cursor : match.start_pos], match.redacted_text or "[REDACTED]")) + cursor = max(cursor, match.end_pos) # overlapping: only extend the span already redacted + filtered_content = "".join([*pieces, content[cursor:]]) + + # Re-sort for result + all_matches.sort(key=lambda m: m.start_pos) + + result = FilterResult( + original_content=content, + filtered_content=filtered_content if not blocked else "", + matches=all_matches, + blocked=blocked, + block_reason=block_reason, + warnings=warnings, + ) + + if blocked: + logger.warning( + "Content blocked: %s (%d matches)", + block_reason, + len(all_matches), + ) + if raise_on_block: + raise ContentFilterError( + block_reason or "Content blocked", + detected_category=all_matches[0].category.value if all_matches else "unknown", + pattern_name=all_matches[0].pattern_name if all_matches else None, + ) + elif
all_matches: + logger.debug( + "Content filtered: %d matches, %d warnings", + len(all_matches), + len(warnings), + ) + + return result + + async def filter_dict( + self, + data: dict[str, Any], + keys_to_filter: list[str] | None = None, + recursive: bool = True, + ) -> dict[str, Any]: + """ + Filter string values in a dictionary. + + Args: + data: Dictionary to filter + keys_to_filter: Specific keys to filter (None = all) + recursive: Filter nested dictionaries + + Returns: + Filtered dictionary + """ + result: dict[str, Any] = {} + + for key, value in data.items(): + if isinstance(value, str): + if keys_to_filter is None or key in keys_to_filter: + filter_result = await self.filter(value) + result[key] = filter_result.filtered_content + else: + result[key] = value + elif isinstance(value, dict) and recursive: + result[key] = await self.filter_dict(value, keys_to_filter, recursive) + elif isinstance(value, list): + result[key] = [ + (await self.filter(item)).filtered_content + if isinstance(item, str) + else item + for item in value + ] + else: + result[key] = value + + return result + + async def scan( + self, + content: str, + categories: list[ContentCategory] | None = None, + ) -> list[FilterMatch]: + """ + Scan content without filtering (detection only). + + Args: + content: Content to scan + categories: Limit to specific categories + + Returns: + List of matches found + """ + all_matches: list[FilterMatch] = [] + + for pattern in self._patterns: + if not pattern.enabled: + continue + if categories and pattern.category not in categories: + continue + + matches = pattern.find_matches(content) + all_matches.extend(matches) + + all_matches.sort(key=lambda m: m.start_pos) + return all_matches + + async def validate_safe( + self, + content: str, + categories: list[ContentCategory] | None = None, + allow_warnings: bool = True, + ) -> tuple[bool, list[str]]: + """ + Validate that content is safe (no blocked patterns). 
+ + Args: + content: Content to validate + categories: Limit to specific categories + allow_warnings: Allow content with warnings + + Returns: + Tuple of (is_safe, list of issues) + """ + issues: list[str] = [] + + for pattern in self._patterns: + if not pattern.enabled: + continue + if categories and pattern.category not in categories: + continue + + matches = pattern.find_matches(content) + for match in matches: + if pattern.action == FilterAction.BLOCK: + issues.append(f"Blocked: {pattern.name} at position {match.start_pos}") + elif pattern.action == FilterAction.WARN and not allow_warnings: + issues.append(f"Warning: {pattern.name} at position {match.start_pos}") + + return len(issues) == 0, issues + + def _get_pattern(self, name: str) -> FilterPattern | None: + """Get a pattern by name.""" + for pattern in self._patterns: + if pattern.name == name: + return pattern + return None + + def get_pattern_stats(self) -> dict[str, Any]: + """Get statistics about configured patterns.""" + by_category: dict[str, int] = {} + by_action: dict[str, int] = {} + + for pattern in self._patterns: + cat = pattern.category.value + by_category[cat] = by_category.get(cat, 0) + 1 + + act = pattern.action.value + by_action[act] = by_action.get(act, 0) + 1 + + return { + "total_patterns": len(self._patterns), + "enabled_patterns": sum(1 for p in self._patterns if p.enabled), + "by_category": by_category, + "by_action": by_action, + } + + +# Convenience function for quick filtering +async def filter_content(content: str) -> str: + """Quick filter content with default settings.""" + filter_instance = ContentFilter() + result = await filter_instance.filter(content) + return result.filtered_content + + +async def scan_for_secrets(content: str) -> list[FilterMatch]: + """Quick scan for secrets only.""" + filter_instance = ContentFilter( + enable_pii_filter=False, + enable_injection_filter=False, + ) + return await filter_instance.scan( + content, + categories=[ContentCategory.SECRETS, 
ContentCategory.CREDENTIALS], + ) diff --git a/backend/app/services/safety/emergency/__init__.py b/backend/app/services/safety/emergency/__init__.py index 9f4729c..8998ae2 100644 --- a/backend/app/services/safety/emergency/__init__.py +++ b/backend/app/services/safety/emergency/__init__.py @@ -1 +1,23 @@ -"""${dir} module.""" +"""Emergency controls for agent safety.""" + +from .controls import ( + EmergencyControls, + EmergencyEvent, + EmergencyReason, + EmergencyState, + EmergencyTrigger, + check_emergency_allowed, + emergency_stop_global, + get_emergency_controls, +) + +__all__ = [ + "EmergencyControls", + "EmergencyEvent", + "EmergencyReason", + "EmergencyState", + "EmergencyTrigger", + "check_emergency_allowed", + "emergency_stop_global", + "get_emergency_controls", +] diff --git a/backend/app/services/safety/emergency/controls.py b/backend/app/services/safety/emergency/controls.py new file mode 100644 index 0000000..b565515 --- /dev/null +++ b/backend/app/services/safety/emergency/controls.py @@ -0,0 +1,594 @@ +""" +Emergency Controls + +Emergency stop and pause functionality for agent safety. 
+""" + +import asyncio +import logging +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any + +from ..exceptions import EmergencyStopError + +logger = logging.getLogger(__name__) + + +class EmergencyState(str, Enum): + """Emergency control states.""" + + NORMAL = "normal" + PAUSED = "paused" + STOPPED = "stopped" + + +class EmergencyReason(str, Enum): + """Reasons for emergency actions.""" + + MANUAL = "manual" + SAFETY_VIOLATION = "safety_violation" + BUDGET_EXCEEDED = "budget_exceeded" + LOOP_DETECTED = "loop_detected" + RATE_LIMIT = "rate_limit" + CONTENT_VIOLATION = "content_violation" + SYSTEM_ERROR = "system_error" + EXTERNAL_TRIGGER = "external_trigger" + + +@dataclass +class EmergencyEvent: + """Record of an emergency action.""" + + id: str + state: EmergencyState + reason: EmergencyReason + triggered_by: str + message: str + scope: str # "global", "project:", "agent:" + timestamp: datetime = field(default_factory=datetime.utcnow) + metadata: dict[str, Any] = field(default_factory=dict) + resolved_at: datetime | None = None + resolved_by: str | None = None + + +class EmergencyControls: + """ + Emergency stop and pause controls for agent safety. + + Features: + - Global emergency stop + - Per-project/agent emergency controls + - Graceful pause with state preservation + - Automatic triggers from safety violations + - Manual override capabilities + - Event history and audit trail + """ + + def __init__( + self, + notification_handlers: list[Callable[..., Any]] | None = None, + ) -> None: + """ + Initialize EmergencyControls. 
+ + Args: + notification_handlers: Handlers to call on emergency events + """ + self._global_state = EmergencyState.NORMAL + self._scoped_states: dict[str, EmergencyState] = {} + self._events: list[EmergencyEvent] = [] + self._notification_handlers = notification_handlers or [] + self._lock = asyncio.Lock() + self._event_id_counter = 0 + + # Callbacks for state changes + self._on_stop_callbacks: list[Callable[..., Any]] = [] + self._on_pause_callbacks: list[Callable[..., Any]] = [] + self._on_resume_callbacks: list[Callable[..., Any]] = [] + + def _generate_event_id(self) -> str: + """Generate a unique event ID.""" + self._event_id_counter += 1 + return f"emerg-{self._event_id_counter:06d}" + + async def emergency_stop( + self, + reason: EmergencyReason, + triggered_by: str, + message: str, + scope: str = "global", + metadata: dict[str, Any] | None = None, + ) -> EmergencyEvent: + """ + Trigger emergency stop. + + Args: + reason: Reason for the stop + triggered_by: Who/what triggered the stop + message: Human-readable message + scope: Scope of the stop (global, project:, agent:) + metadata: Additional context + + Returns: + The emergency event record + """ + async with self._lock: + event = EmergencyEvent( + id=self._generate_event_id(), + state=EmergencyState.STOPPED, + reason=reason, + triggered_by=triggered_by, + message=message, + scope=scope, + metadata=metadata or {}, + ) + + if scope == "global": + self._global_state = EmergencyState.STOPPED + else: + self._scoped_states[scope] = EmergencyState.STOPPED + + self._events.append(event) + + logger.critical( + "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s", + scope, + reason.value, + triggered_by, + message, + ) + + # Execute callbacks + await self._execute_callbacks(self._on_stop_callbacks, event) + await self._notify_handlers("emergency_stop", event) + + return event + + async def pause( + self, + reason: EmergencyReason, + triggered_by: str, + message: str, + scope: str = "global", + metadata: dict[str, 
Any] | None = None, + ) -> EmergencyEvent: + """ + Pause operations (can be resumed). A STOPPED scope is never downgraded to PAUSED. + + Args: + reason: Reason for the pause + triggered_by: Who/what triggered the pause + message: Human-readable message + scope: Scope of the pause + metadata: Additional context + + Returns: + The emergency event record + """ + async with self._lock: + event = EmergencyEvent( + id=self._generate_event_id(), + state=EmergencyState.PAUSED, + reason=reason, + triggered_by=triggered_by, + message=message, + scope=scope, + metadata=metadata or {}, + ) + + if scope == "global" and self._global_state != EmergencyState.STOPPED: # a stop outranks a pause, otherwise resume() could bypass reset() + self._global_state = EmergencyState.PAUSED + elif scope != "global" and self._scoped_states.get(scope) != EmergencyState.STOPPED: + self._scoped_states[scope] = EmergencyState.PAUSED + + self._events.append(event) + + logger.warning( + "PAUSE: scope=%s, reason=%s, by=%s - %s", + scope, + reason.value, + triggered_by, + message, + ) + + await self._execute_callbacks(self._on_pause_callbacks, event) + await self._notify_handlers("pause", event) + + return event + + async def resume( + self, + scope: str = "global", + resumed_by: str = "system", + message: str | None = None, + ) -> bool: + """ + Resume operations from paused state.
+ + Args: + scope: Scope to resume + resumed_by: Who/what is resuming + message: Optional message + + Returns: + True if resumed, False if not in paused state + """ + async with self._lock: + current_state = self._get_state(scope) + + if current_state == EmergencyState.STOPPED: + logger.warning( + "Cannot resume from STOPPED state: %s (requires reset)", + scope, + ) + return False + + if current_state == EmergencyState.NORMAL: + return True # Already normal + + # Find the pause event and mark as resolved + for event in reversed(self._events): + if event.scope == scope and event.state == EmergencyState.PAUSED: + if event.resolved_at is None: + event.resolved_at = datetime.utcnow() + event.resolved_by = resumed_by + break + + if scope == "global": + self._global_state = EmergencyState.NORMAL + else: + self._scoped_states[scope] = EmergencyState.NORMAL + + logger.info( + "RESUMED: scope=%s, by=%s%s", + scope, + resumed_by, + f" - {message}" if message else "", + ) + + await self._execute_callbacks( + self._on_resume_callbacks, + {"scope": scope, "resumed_by": resumed_by}, + ) + await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by}) + + return True + + async def reset( + self, + scope: str = "global", + reset_by: str = "admin", + message: str | None = None, + ) -> bool: + """ + Reset from stopped state (requires explicit action). 
+ + Args: + scope: Scope to reset + reset_by: Who is resetting (should be admin) + message: Optional message + + Returns: + True if reset successful + """ + async with self._lock: + current_state = self._get_state(scope) + + if current_state == EmergencyState.NORMAL: + return True + + # Find the stop event and mark as resolved + for event in reversed(self._events): + if event.scope == scope and event.state == EmergencyState.STOPPED: + if event.resolved_at is None: + event.resolved_at = datetime.utcnow() + event.resolved_by = reset_by + break + + if scope == "global": + self._global_state = EmergencyState.NORMAL + else: + self._scoped_states[scope] = EmergencyState.NORMAL + + logger.warning( + "EMERGENCY RESET: scope=%s, by=%s%s", + scope, + reset_by, + f" - {message}" if message else "", + ) + + await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by}) + + return True + + async def check_allowed( + self, + scope: str | None = None, + raise_if_blocked: bool = True, + ) -> bool: + """ + Check if operations are allowed. 
+ + Args: + scope: Specific scope to check (also checks global) + raise_if_blocked: Raise exception if blocked + + Returns: + True if operations are allowed + + Raises: + EmergencyStopError: If blocked and raise_if_blocked=True + """ + async with self._lock: + # Always check global state + if self._global_state != EmergencyState.NORMAL: + if raise_if_blocked: + raise EmergencyStopError( + f"Global emergency state: {self._global_state.value}", + stop_reason=self._get_last_reason("global"), + triggered_by=self._get_last_triggered_by("global"), + ) + return False + + # Check specific scope + if scope and scope in self._scoped_states: + state = self._scoped_states[scope] + if state != EmergencyState.NORMAL: + if raise_if_blocked: + raise EmergencyStopError( + f"Emergency state for {scope}: {state.value}", + stop_reason=self._get_last_reason(scope), + triggered_by=self._get_last_triggered_by(scope), + scope=scope, + ) + return False + + return True + + def _get_state(self, scope: str) -> EmergencyState: + """Get state for a scope.""" + if scope == "global": + return self._global_state + return self._scoped_states.get(scope, EmergencyState.NORMAL) + + def _get_last_reason(self, scope: str) -> str: + """Get reason from last event for scope.""" + for event in reversed(self._events): + if event.scope == scope and event.resolved_at is None: + return event.reason.value + return "unknown" + + def _get_last_triggered_by(self, scope: str) -> str: + """Get triggered_by from last event for scope.""" + for event in reversed(self._events): + if event.scope == scope and event.resolved_at is None: + return event.triggered_by + return "unknown" + + async def get_state(self, scope: str = "global") -> EmergencyState: + """Get current state for a scope.""" + async with self._lock: + return self._get_state(scope) + + async def get_all_states(self) -> dict[str, EmergencyState]: + """Get all current states.""" + async with self._lock: + states = {"global": self._global_state} + 
states.update(self._scoped_states) + return states + + async def get_active_events(self) -> list[EmergencyEvent]: + """Get all unresolved emergency events.""" + async with self._lock: + return [e for e in self._events if e.resolved_at is None] + + async def get_event_history( + self, + scope: str | None = None, + limit: int = 100, + ) -> list[EmergencyEvent]: + """Get emergency event history.""" + async with self._lock: + events = list(self._events) + + if scope: + events = [e for e in events if e.scope == scope] + + return events[-limit:] + + def on_stop(self, callback: Callable[..., Any]) -> None: + """Register callback for stop events.""" + self._on_stop_callbacks.append(callback) + + def on_pause(self, callback: Callable[..., Any]) -> None: + """Register callback for pause events.""" + self._on_pause_callbacks.append(callback) + + def on_resume(self, callback: Callable[..., Any]) -> None: + """Register callback for resume events.""" + self._on_resume_callbacks.append(callback) + + def add_notification_handler(self, handler: Callable[..., Any]) -> None: + """Add a notification handler.""" + self._notification_handlers.append(handler) + + async def _execute_callbacks( + self, + callbacks: list[Callable[..., Any]], + data: Any, + ) -> None: + """Execute callbacks safely.""" + for callback in callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback(data) + else: + callback(data) + except Exception as e: + logger.error("Error in callback: %s", e) + + async def _notify_handlers(self, event_type: str, data: Any) -> None: + """Notify all handlers of an event.""" + for handler in self._notification_handlers: + try: + if asyncio.iscoroutinefunction(handler): + await handler(event_type, data) + else: + handler(event_type, data) + except Exception as e: + logger.error("Error in notification handler: %s", e) + + +class EmergencyTrigger: + """ + Automatic emergency triggers based on conditions. 
+ """ + + def __init__(self, controls: EmergencyControls) -> None: + """ + Initialize EmergencyTrigger. + + Args: + controls: EmergencyControls instance to trigger + """ + self._controls = controls + + async def trigger_on_safety_violation( + self, + violation_type: str, + details: dict[str, Any], + scope: str = "global", + ) -> EmergencyEvent: + """ + Trigger emergency from safety violation. + + Args: + violation_type: Type of violation + details: Violation details + scope: Scope for the emergency + + Returns: + Emergency event + """ + return await self._controls.emergency_stop( + reason=EmergencyReason.SAFETY_VIOLATION, + triggered_by="safety_system", + message=f"Safety violation: {violation_type}", + scope=scope, + metadata={"violation_type": violation_type, **details}, + ) + + async def trigger_on_budget_exceeded( + self, + budget_type: str, + current: float, + limit: float, + scope: str = "global", + ) -> EmergencyEvent: + """ + Trigger emergency from budget exceeded. + + Args: + budget_type: Type of budget + current: Current usage + limit: Budget limit + scope: Scope for the emergency + + Returns: + Emergency event + """ + return await self._controls.pause( + reason=EmergencyReason.BUDGET_EXCEEDED, + triggered_by="budget_controller", + message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})", + scope=scope, + metadata={"budget_type": budget_type, "current": current, "limit": limit}, + ) + + async def trigger_on_loop_detected( + self, + loop_type: str, + agent_id: str, + details: dict[str, Any], + ) -> EmergencyEvent: + """ + Trigger emergency from loop detection. 
+ + Args: + loop_type: Type of loop + agent_id: Agent that's looping + details: Loop details + + Returns: + Emergency event + """ + return await self._controls.pause( + reason=EmergencyReason.LOOP_DETECTED, + triggered_by="loop_detector", + message=f"Loop detected: {loop_type} in agent {agent_id}", + scope=f"agent:{agent_id}", + metadata={"loop_type": loop_type, "agent_id": agent_id, **details}, + ) + + async def trigger_on_content_violation( + self, + category: str, + pattern: str, + scope: str = "global", + ) -> EmergencyEvent: + """ + Trigger emergency from content violation. + + Args: + category: Content category + pattern: Pattern that matched + scope: Scope for the emergency + + Returns: + Emergency event + """ + return await self._controls.emergency_stop( + reason=EmergencyReason.CONTENT_VIOLATION, + triggered_by="content_filter", + message=f"Content violation: {category} ({pattern})", + scope=scope, + metadata={"category": category, "pattern": pattern}, + ) + + +# Singleton instance +_emergency_controls: EmergencyControls | None = None +_lock = asyncio.Lock() + + +async def get_emergency_controls() -> EmergencyControls: + """Get the singleton EmergencyControls instance.""" + global _emergency_controls + + async with _lock: + if _emergency_controls is None: + _emergency_controls = EmergencyControls() + return _emergency_controls + + +async def emergency_stop_global( + reason: str, + triggered_by: str = "system", +) -> EmergencyEvent: + """Quick global emergency stop.""" + controls = await get_emergency_controls() + return await controls.emergency_stop( + reason=EmergencyReason.MANUAL, + triggered_by=triggered_by, + message=reason, + scope="global", + ) + + +async def check_emergency_allowed(scope: str | None = None) -> bool: + """Quick check if operations are allowed.""" + controls = await get_emergency_controls() + return await controls.check_allowed(scope=scope, raise_if_blocked=False) diff --git a/backend/app/services/safety/hitl/__init__.py 
# ---------------------------------------------------------------------------
# File: backend/app/services/safety/hitl/__init__.py
# ---------------------------------------------------------------------------
"""Human-in-the-Loop approval workflows."""

from .manager import ApprovalQueue, HITLManager

__all__ = ["ApprovalQueue", "HITLManager"]


# ---------------------------------------------------------------------------
# File: backend/app/services/safety/hitl/manager.py
# ---------------------------------------------------------------------------
"""
Human-in-the-Loop (HITL) Manager

Manages approval workflows for actions requiring human oversight.
"""

import asyncio
import inspect
import logging
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4

from ..config import get_safety_config
from ..exceptions import (
    ApprovalDeniedError,
    ApprovalRequiredError,
    ApprovalTimeoutError,
)
from ..models import (
    ActionRequest,
    ApprovalRequest,
    ApprovalResponse,
    ApprovalStatus,
)

logger = logging.getLogger(__name__)


class ApprovalQueue:
    """
    In-memory store of pending approval requests and their decisions.

    A request lives in ``_pending`` until it is completed, cancelled or
    expired, at which point its response is stored in ``_completed`` and any
    coroutine blocked in ``wait_for_response`` is woken.

    NOTE(review): ``_completed`` is never pruned in this class, so it grows
    for the lifetime of the queue.
    """

    def __init__(self) -> None:
        self._pending: dict[str, ApprovalRequest] = {}
        self._completed: dict[str, ApprovalResponse] = {}
        # One event per pending request; removed once the request resolves.
        self._waiters: dict[str, asyncio.Event] = {}
        self._lock = asyncio.Lock()

    def _resolve(self, request_id: str, response: ApprovalResponse) -> None:
        """Record the final response for a pending request and wake waiters.

        The caller must hold ``self._lock`` and must have verified that
        ``request_id`` is pending.
        """
        del self._pending[request_id]
        self._completed[request_id] = response

        # Coroutines already waiting hold their own reference to the event,
        # so dropping it from the dict is safe and stops the map from growing.
        waiter = self._waiters.pop(request_id, None)
        if waiter is not None:
            waiter.set()

    async def add(self, request: ApprovalRequest) -> None:
        """Add an approval request to the queue."""
        async with self._lock:
            self._pending[request.id] = request
            self._waiters[request.id] = asyncio.Event()

    async def get_pending(self, request_id: str) -> ApprovalRequest | None:
        """Get a pending request by ID, or None if it is not pending."""
        async with self._lock:
            return self._pending.get(request_id)

    async def get_completed(self, request_id: str) -> ApprovalResponse | None:
        """Get the recorded response for a resolved request, or None."""
        async with self._lock:
            return self._completed.get(request_id)

    async def complete(self, response: ApprovalResponse) -> bool:
        """
        Complete an approval request.

        Returns:
            True if the request was pending and the response was recorded,
            False if no such pending request exists.
        """
        async with self._lock:
            if response.request_id not in self._pending:
                return False

            self._resolve(response.request_id, response)
            return True

    async def wait_for_response(
        self,
        request_id: str,
        timeout_seconds: float,
    ) -> ApprovalResponse | None:
        """
        Wait for a response to an approval request.

        Returns:
            The recorded response, or None if the wait timed out or the
            request is unknown.
        """
        async with self._lock:
            waiter = self._waiters.get(request_id)
            if not waiter:
                # Already resolved (or never added): nothing to wait on.
                return self._completed.get(request_id)

        try:
            await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is only an alias of the builtin
            # TimeoutError from Python 3.11; this spelling works on both.
            return None

        async with self._lock:
            return self._completed.get(request_id)

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending requests."""
        async with self._lock:
            return list(self._pending.values())

    async def cancel(self, request_id: str) -> bool:
        """
        Cancel a pending request.

        Records a CANCELLED response and wakes any waiter.

        Returns:
            True if the request was pending, False otherwise.
        """
        async with self._lock:
            if request_id not in self._pending:
                return False

            self._resolve(
                request_id,
                ApprovalResponse(
                    request_id=request_id,
                    status=ApprovalStatus.CANCELLED,
                    reason="Cancelled",
                ),
            )
            return True

    async def cleanup_expired(self) -> int:
        """
        Resolve every pending request whose ``expires_at`` has passed.

        Each expired request gets a TIMEOUT response and its waiter is woken.

        Returns:
            Number of requests that were expired.
        """
        now = datetime.utcnow()

        async with self._lock:
            expired = [
                request_id
                for request_id, request in self._pending.items()
                if request.expires_at and request.expires_at < now
            ]
            for request_id in expired:
                self._resolve(
                    request_id,
                    ApprovalResponse(
                        request_id=request_id,
                        status=ApprovalStatus.TIMEOUT,
                        reason="Request timed out",
                    ),
                )

        return len(expired)


class HITLManager:
    """
    Manages Human-in-the-Loop approval workflows.

    Implemented here:
    - Approval request queue
    - Timeout handling (a timed-out request is treated as denied)
    - Approval with optional modifications
    - Notification handlers (sync or async callables)
    - Periodic expiry of stale requests while the manager is running
    """

    def __init__(
        self,
        default_timeout: int | None = None,
    ) -> None:
        """
        Initialize the HITLManager.

        Args:
            default_timeout: Default timeout for approval requests in
                seconds; a falsy value falls back to
                ``config.hitl_default_timeout``
        """
        config = get_safety_config()

        self._default_timeout = default_timeout or config.hitl_default_timeout
        self._queue = ApprovalQueue()
        self._notification_handlers: list[Callable[..., Any]] = []
        self._running = False
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start(self) -> None:
        """Start the HITL manager background tasks (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("HITL Manager started")

    async def stop(self) -> None:
        """Stop the HITL manager and cancel the cleanup task."""
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Drop the finished task so a later stop() does not touch it again.
            self._cleanup_task = None

        logger.info("HITL Manager stopped")

    async def request_approval(
        self,
        action: ActionRequest,
        reason: str,
        timeout_seconds: int | None = None,
        urgency: str = "normal",
        context: dict[str, Any] | None = None,
    ) -> ApprovalRequest:
        """
        Create an approval request for an action.

        The request is queued and handlers are notified with the
        ``"approval_requested"`` event before this method returns.

        Args:
            action: The action requiring approval
            reason: Why approval is required
            timeout_seconds: Timeout for this request; a falsy value uses
                the manager default
            urgency: Urgency level (low, normal, high, critical)
            context: Additional context for the approver

        Returns:
            The created approval request
        """
        timeout = timeout_seconds or self._default_timeout
        expires_at = datetime.utcnow() + timedelta(seconds=timeout)

        request = ApprovalRequest(
            id=str(uuid4()),
            action=action,
            reason=reason,
            urgency=urgency,
            timeout_seconds=timeout,
            expires_at=expires_at,
            context=context or {},
        )

        await self._queue.add(request)

        # Notify handlers
        await self._notify_handlers("approval_requested", request)

        logger.info(
            "Approval requested: %s for action %s (timeout: %ds)",
            request.id,
            action.id,
            timeout,
        )

        return request

    async def wait_for_approval(
        self,
        request_id: str,
        timeout_seconds: int | None = None,
    ) -> ApprovalResponse:
        """
        Wait for an approval decision.

        Args:
            request_id: ID of the approval request
            timeout_seconds: Override timeout

        Returns:
            The approval response (only returned for non-denied,
            non-timeout, non-cancelled outcomes)

        Raises:
            ApprovalRequiredError: If the request is neither pending nor
                already resolved
            ApprovalTimeoutError: If timeout expires
            ApprovalDeniedError: If approval is denied or cancelled
        """
        request = await self._queue.get_pending(request_id)
        response: ApprovalResponse | None

        if request is not None:
            timeout = (
                timeout_seconds or request.timeout_seconds or self._default_timeout
            )
            response = await self._queue.wait_for_response(request_id, timeout)
        else:
            # The request may have been decided before the caller started
            # waiting (e.g. by a notification handler that approves
            # immediately); use the recorded decision instead of failing.
            timeout = timeout_seconds or self._default_timeout
            response = await self._queue.get_completed(request_id)
            if response is None:
                raise ApprovalRequiredError(
                    f"Approval request not found: {request_id}",
                    approval_id=request_id,
                )

        if response is None:
            # Timeout - default deny
            timeout_response = ApprovalResponse(
                request_id=request_id,
                status=ApprovalStatus.TIMEOUT,
                reason="Request timed out (default deny)",
            )
            if not await self._queue.complete(timeout_response):
                # A decision was recorded between the wait timing out and
                # this point; honour it so caller and approver agree.
                response = await self._queue.get_completed(request_id)

            if response is None:
                raise ApprovalTimeoutError(
                    "Approval request timed out",
                    approval_id=request_id,
                    timeout_seconds=timeout,
                )

        if response.status == ApprovalStatus.DENIED:
            raise ApprovalDeniedError(
                response.reason or "Approval denied",
                approval_id=request_id,
                denied_by=response.decided_by,
                denial_reason=response.reason,
            )

        if response.status == ApprovalStatus.TIMEOUT:
            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.CANCELLED:
            raise ApprovalDeniedError(
                "Approval request was cancelled",
                approval_id=request_id,
                denial_reason="Cancelled",
            )

        return response

    async def approve(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
        modifications: dict[str, Any] | None = None,
    ) -> bool:
        """
        Approve a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who approved
            reason: Optional approval reason
            modifications: Optional modifications to the action

        Returns:
            True if approval was recorded, False if the request was not
            pending
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.APPROVED,
            decided_by=decided_by,
            reason=reason,
            modifications=modifications,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval granted: %s by %s",
                request_id,
                decided_by,
            )
            await self._notify_handlers("approval_granted", response)

        return success

    async def deny(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
    ) -> bool:
        """
        Deny a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who denied
            reason: Denial reason

        Returns:
            True if denial was recorded, False if the request was not
            pending
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.DENIED,
            decided_by=decided_by,
            reason=reason,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval denied: %s by %s - %s",
                request_id,
                decided_by,
                reason,
            )
            await self._notify_handlers("approval_denied", response)

        return success

    async def cancel(self, request_id: str) -> bool:
        """
        Cancel a pending request.

        Args:
            request_id: ID of the approval request

        Returns:
            True if request was cancelled
        """
        success = await self._queue.cancel(request_id)

        if success:
            logger.info("Approval request cancelled: %s", request_id)

        return success

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending approval requests."""
        return await self._queue.list_pending()

    async def get_request(self, request_id: str) -> ApprovalRequest | None:
        """Get a pending approval request by ID (resolved requests return None)."""
        return await self._queue.get_pending(request_id)

    def add_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Add a notification handler, called as ``handler(event_type, data)``."""
        self._notification_handlers.append(handler)

    def remove_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Remove a notification handler (no-op if it is not registered)."""
        if handler in self._notification_handlers:
            self._notification_handlers.remove(handler)

    async def _notify_handlers(
        self,
        event_type: str,
        data: Any,
    ) -> None:
        """
        Notify all handlers of an event.

        A failing handler is logged and does not stop the remaining
        handlers from being called.
        """
        # Iterate over a snapshot so a handler may add or remove handlers
        # (including itself) while being notified.
        for handler in list(self._notification_handlers):
            try:
                outcome = handler(event_type, data)
                # Await whatever is awaitable; this also covers async
                # callables that iscoroutinefunction() does not recognise.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.exception("Error in notification handler: %s", e)

    async def _periodic_cleanup(self) -> None:
        """Background task for cleaning up expired requests."""
        while self._running:
            try:
                await asyncio.sleep(30)  # Check every 30 seconds
                count = await self._queue.cleanup_expired()
                if count:
                    logger.debug("Cleaned up %d expired approval requests", count)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Error in approval cleanup: %s", e)


# ---------------------------------------------------------------------------
# File: backend/app/services/safety/rollback/__init__.py
# ---------------------------------------------------------------------------
"""Rollback management for agent actions."""

from .manager import RollbackManager, TransactionContext

__all__ = ["RollbackManager", "TransactionContext"]


# ---------------------------------------------------------------------------
# File: backend/app/services/safety/rollback/manager.py
# ---------------------------------------------------------------------------
"""
Rollback Manager

Manages checkpoints and rollback operations for agent actions.
"""

import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from uuid import uuid4

from ..config import get_safety_config
from ..exceptions import RollbackError
from ..models import (
    ActionRequest,
    Checkpoint,
    CheckpointType,
    RollbackResult,
)

logger = logging.getLogger(__name__)


class FileCheckpoint:
    """Snapshot of one file's state, held in memory for a later rollback."""

    def __init__(
        self,
        checkpoint_id: str,
        file_path: str,
        original_content: bytes | None,
        existed: bool,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.file_path = file_path
        # None when the file did not exist at snapshot time.
        self.original_content = original_content
        self.existed = existed
        self.created_at = datetime.utcnow()


class RollbackManager:
    """
    Manages checkpoints and rollback operations.

    Implemented here:
    - File snapshots grouped under a checkpoint
    - Rollback of every snapshotted file to its recorded state
    - Checkpoint expiration and cleanup

    File contents are kept in memory. The checkpoint directory is created
    on construction but is not otherwise used by the code in this module.
    """

    def __init__(
        self,
        checkpoint_dir: str | None = None,
        retention_hours: int | None = None,
    ) -> None:
        """
        Initialize the RollbackManager.

        Args:
            checkpoint_dir: Directory for storing checkpoint data; a falsy
                value falls back to ``config.checkpoint_dir``
            retention_hours: Hours to retain checkpoints; a falsy value
                falls back to ``config.checkpoint_retention_hours``
        """
        config = get_safety_config()

        self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
        self._retention_hours = retention_hours or config.checkpoint_retention_hours

        self._checkpoints: dict[str, Checkpoint] = {}
        self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
        self._lock = asyncio.Lock()

        # Ensure checkpoint directory exists
        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

    async def create_checkpoint(
        self,
        action: ActionRequest,
        checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
        description: str | None = None,
    ) -> Checkpoint:
        """
        Create a checkpoint before an action.

        Args:
            action: The action to checkpoint for
            checkpoint_type: Type of checkpoint
            description: Optional description

        Returns:
            The created checkpoint
        """
        checkpoint_id = str(uuid4())
        # Read the clock once so expires_at is exactly created_at + retention.
        now = datetime.utcnow()

        checkpoint = Checkpoint(
            id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            action_id=action.id,
            created_at=now,
            expires_at=now + timedelta(hours=self._retention_hours),
            data={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
            },
            description=description or f"Checkpoint for {action.tool_name}",
        )

        async with self._lock:
            self._checkpoints[checkpoint_id] = checkpoint
            self._file_checkpoints[checkpoint_id] = []

        logger.info(
            "Created checkpoint %s for action %s",
            checkpoint_id,
            action.id,
        )

        return checkpoint

    async def checkpoint_file(
        self,
        checkpoint_id: str,
        file_path: str,
    ) -> None:
        """
        Store current state of a file for checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint
            file_path: Path to the file
        """
        path = Path(file_path)

        # Read directly instead of exists()-then-read, so a file removed
        # between the two calls is recorded as absent rather than raising.
        content: bytes | None
        try:
            content = path.read_bytes()
            existed = True
        except (FileNotFoundError, NotADirectoryError):
            content = None
            existed = False

        file_checkpoint = FileCheckpoint(
            checkpoint_id=checkpoint_id,
            file_path=file_path,
            original_content=content,
            existed=existed,
        )

        async with self._lock:
            self._file_checkpoints.setdefault(checkpoint_id, []).append(
                file_checkpoint
            )

        logger.debug(
            "Stored file state for checkpoint %s: %s (existed=%s)",
            checkpoint_id,
            file_path,
            existed,
        )

    async def checkpoint_files(
        self,
        checkpoint_id: str,
        file_paths: list[str],
    ) -> None:
        """
        Store current state of multiple files.

        Args:
            checkpoint_id: ID of the checkpoint
            file_paths: Paths to the files
        """
        for path in file_paths:
            await self.checkpoint_file(checkpoint_id, path)

    async def rollback(
        self,
        checkpoint_id: str,
    ) -> RollbackResult:
        """
        Rollback to a checkpoint.

        Per-file failures are collected in the result rather than raised.
        The checkpoint is marked invalid afterwards, whether or not every
        file was restored.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            Result of the rollback operation

        Raises:
            RollbackError: If the checkpoint is unknown or no longer valid
        """
        async with self._lock:
            checkpoint = self._checkpoints.get(checkpoint_id)
            if not checkpoint:
                raise RollbackError(
                    f"Checkpoint not found: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            if not checkpoint.is_valid:
                raise RollbackError(
                    f"Checkpoint is no longer valid: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            # Copy so snapshots added concurrently do not affect this run.
            file_checkpoints = list(self._file_checkpoints.get(checkpoint_id, []))

        actions_rolled_back: list[str] = []
        failed_actions: list[str] = []

        # Undo newest-first: if a file was snapshotted more than once, the
        # earliest snapshot is applied last and the file ends up in its
        # original pre-action state.
        for fc in reversed(file_checkpoints):
            try:
                await self._rollback_file(fc)
                actions_rolled_back.append(f"file:{fc.file_path}")
            except Exception as e:
                logger.error("Failed to rollback file %s: %s", fc.file_path, e)
                failed_actions.append(f"file:{fc.file_path}")

        success = not failed_actions

        # Mark checkpoint as used
        async with self._lock:
            if checkpoint_id in self._checkpoints:
                self._checkpoints[checkpoint_id].is_valid = False

        result = RollbackResult(
            checkpoint_id=checkpoint_id,
            success=success,
            actions_rolled_back=actions_rolled_back,
            failed_actions=failed_actions,
            error=None if success else f"Failed to rollback {len(failed_actions)} items",
        )

        if success:
            logger.info("Rollback successful for checkpoint %s", checkpoint_id)
        else:
            logger.error(
                "Rollback partially failed for checkpoint %s: %d failures",
                checkpoint_id,
                len(failed_actions),
            )

        return result

    async def discard_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Discard a checkpoint without rolling back.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            True if checkpoint was found and discarded
        """
        async with self._lock:
            if checkpoint_id not in self._checkpoints:
                return False

            del self._checkpoints[checkpoint_id]
            self._file_checkpoints.pop(checkpoint_id, None)
            logger.debug("Discarded checkpoint %s", checkpoint_id)
            return True

    async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
        """Get a checkpoint by ID."""
        async with self._lock:
            return self._checkpoints.get(checkpoint_id)

    async def list_checkpoints(
        self,
        action_id: str | None = None,
        include_expired: bool = False,
    ) -> list[Checkpoint]:
        """
        List checkpoints.

        Args:
            action_id: Optional filter by action ID
            include_expired: Include expired checkpoints

        Returns:
            List of checkpoints
        """
        now = datetime.utcnow()

        async with self._lock:
            checkpoints = list(self._checkpoints.values())

        if action_id:
            checkpoints = [c for c in checkpoints if c.action_id == action_id]

        if not include_expired:
            checkpoints = [
                c for c in checkpoints if c.expires_at is None or c.expires_at > now
            ]

        return checkpoints

    async def cleanup_expired(self) -> int:
        """
        Clean up expired checkpoints.

        Returns:
            Number of checkpoints cleaned up
        """
        now = datetime.utcnow()

        async with self._lock:
            to_remove = [
                checkpoint_id
                for checkpoint_id, checkpoint in self._checkpoints.items()
                if checkpoint.expires_at and checkpoint.expires_at < now
            ]

            for checkpoint_id in to_remove:
                del self._checkpoints[checkpoint_id]
                self._file_checkpoints.pop(checkpoint_id, None)

        if to_remove:
            logger.info("Cleaned up %d expired checkpoints", len(to_remove))

        return len(to_remove)

    async def _rollback_file(self, fc: FileCheckpoint) -> None:
        """Rollback a single file to its checkpoint state."""
        path = Path(fc.file_path)

        if fc.existed:
            # Restore original content
            if fc.original_content is not None:
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(fc.original_content)
                logger.debug("Restored file: %s", fc.file_path)
        else:
            # File didn't exist before - delete it
            if path.exists():
                path.unlink()
                logger.debug("Deleted file (didn't exist before): %s", fc.file_path)


class TransactionContext:
    """
    Context manager for transactional action execution.

    Usage:
        async with TransactionContext(rollback_manager, action) as tx:
            await tx.checkpoint_file("/path/to/file")
            # Do work...
            tx.commit()
        # If an exception escapes the block before commit(), the
        # checkpoint is rolled back automatically (when auto_rollback=True).

    On a clean exit without ``commit()`` nothing is rolled back and the
    checkpoint is left in the manager.
    """

    def __init__(
        self,
        manager: RollbackManager,
        action: ActionRequest,
        auto_rollback: bool = True,
    ) -> None:
        self._manager = manager
        self._action = action
        self._auto_rollback = auto_rollback
        self._checkpoint: Checkpoint | None = None
        self._committed = False

    async def __aenter__(self) -> "TransactionContext":
        self._checkpoint = await self._manager.create_checkpoint(self._action)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        if exc_val is not None and self._auto_rollback and not self._committed:
            # Exception occurred - rollback
            if self._checkpoint:
                try:
                    result = await self._manager.rollback(self._checkpoint.id)
                except Exception as e:
                    logger.exception("Auto-rollback failed: %s", e)
                else:
                    # rollback() reports per-file failures in the result
                    # instead of raising, so check before claiming success.
                    if result.success:
                        logger.info(
                            "Auto-rollback completed for action %s",
                            self._action.id,
                        )
                    else:
                        logger.error(
                            "Auto-rollback incomplete for action %s: %s",
                            self._action.id,
                            result.error,
                        )
        elif self._committed and self._checkpoint:
            # Committed - discard checkpoint
            await self._manager.discard_checkpoint(self._checkpoint.id)

        return False  # Don't suppress the exception

    @property
    def checkpoint_id(self) -> str | None:
        """Get the checkpoint ID (None before the context is entered)."""
        return self._checkpoint.id if self._checkpoint else None

    async def checkpoint_file(self, file_path: str) -> None:
        """Checkpoint a file for this transaction."""
        if self._checkpoint:
            await self._manager.checkpoint_file(self._checkpoint.id, file_path)

    async def checkpoint_files(self, file_paths: list[str]) -> None:
        """Checkpoint multiple files for this transaction."""
        if self._checkpoint:
            await self._manager.checkpoint_files(self._checkpoint.id, file_paths)

    def commit(self) -> None:
        """Mark transaction as committed (no rollback on exit)."""
        self._committed = True

    async def rollback(self) -> RollbackResult | None:
        """Manually trigger rollback; returns None if no checkpoint exists."""
        if self._checkpoint:
            return await self._manager.rollback(self._checkpoint.id)
        return None