forked from cardosofelipe/fast-next-template
feat(safety): add Phase C advanced controls
- Add rollback manager with file checkpointing and transaction context - Add HITL manager with approval queues and notification handlers - Add content filter with PII, secrets, and injection detection - Add emergency controls with stop/pause/resume capabilities - Update SafetyConfig with checkpoint_dir setting Issue #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -74,6 +74,10 @@ class SafetyConfig(BaseSettings):
|
||||
|
||||
# Rollback settings
|
||||
rollback_enabled: bool = Field(True, description="Enable rollback capability")
|
||||
checkpoint_dir: str = Field(
|
||||
"/tmp/syndarix_checkpoints", # noqa: S108
|
||||
description="Directory for checkpoint storage",
|
||||
)
|
||||
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
|
||||
auto_checkpoint_destructive: bool = Field(
|
||||
True, description="Auto-checkpoint destructive actions"
|
||||
|
||||
@@ -1 +1,23 @@
|
||||
"""${dir} module."""
|
||||
"""Content filtering for safety."""
|
||||
|
||||
from .filter import (
|
||||
ContentCategory,
|
||||
ContentFilter,
|
||||
FilterAction,
|
||||
FilterMatch,
|
||||
FilterPattern,
|
||||
FilterResult,
|
||||
filter_content,
|
||||
scan_for_secrets,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContentCategory",
|
||||
"ContentFilter",
|
||||
"FilterAction",
|
||||
"FilterMatch",
|
||||
"FilterPattern",
|
||||
"FilterResult",
|
||||
"filter_content",
|
||||
"scan_for_secrets",
|
||||
]
|
||||
|
||||
532
backend/app/services/safety/content/filter.py
Normal file
532
backend/app/services/safety/content/filter.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""
|
||||
Content Filter
|
||||
|
||||
Filters and sanitizes content for safety, including PII detection and secret scanning.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..exceptions import ContentFilterError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
|
||||
"""Categories of sensitive content."""
|
||||
|
||||
PII = "pii"
|
||||
SECRETS = "secrets"
|
||||
CREDENTIALS = "credentials"
|
||||
FINANCIAL = "financial"
|
||||
HEALTH = "health"
|
||||
PROFANITY = "profanity"
|
||||
INJECTION = "injection"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class FilterAction(str, Enum):
|
||||
"""Actions to take on detected content."""
|
||||
|
||||
ALLOW = "allow"
|
||||
REDACT = "redact"
|
||||
BLOCK = "block"
|
||||
WARN = "warn"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterMatch:
|
||||
"""A match found by a filter."""
|
||||
|
||||
category: ContentCategory
|
||||
pattern_name: str
|
||||
matched_text: str
|
||||
start_pos: int
|
||||
end_pos: int
|
||||
confidence: float = 1.0
|
||||
redacted_text: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterResult:
|
||||
"""Result of content filtering."""
|
||||
|
||||
original_content: str
|
||||
filtered_content: str
|
||||
matches: list[FilterMatch] = field(default_factory=list)
|
||||
blocked: bool = False
|
||||
block_reason: str | None = None
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def has_sensitive_content(self) -> bool:
|
||||
"""Check if any sensitive content was found."""
|
||||
return len(self.matches) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterPattern:
|
||||
"""A pattern for detecting sensitive content."""
|
||||
|
||||
name: str
|
||||
category: ContentCategory
|
||||
pattern: str # Regex pattern
|
||||
action: FilterAction = FilterAction.REDACT
|
||||
replacement: str = "[REDACTED]"
|
||||
confidence: float = 1.0
|
||||
enabled: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Compile the regex pattern."""
|
||||
self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
def find_matches(self, content: str) -> list[FilterMatch]:
|
||||
"""Find all matches in content."""
|
||||
matches = []
|
||||
for match in self._compiled.finditer(content):
|
||||
matches.append(
|
||||
FilterMatch(
|
||||
category=self.category,
|
||||
pattern_name=self.name,
|
||||
matched_text=match.group(),
|
||||
start_pos=match.start(),
|
||||
end_pos=match.end(),
|
||||
confidence=self.confidence,
|
||||
redacted_text=self.replacement,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
|
||||
class ContentFilter:
|
||||
"""
|
||||
Filters content for sensitive information.
|
||||
|
||||
Features:
|
||||
- PII detection (emails, phones, SSN, etc.)
|
||||
- Secret scanning (API keys, tokens, passwords)
|
||||
- Credential detection
|
||||
- Injection attack prevention
|
||||
- Custom pattern support
|
||||
- Configurable actions (allow, redact, block, warn)
|
||||
"""
|
||||
|
||||
# Default patterns for common sensitive data
|
||||
DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
|
||||
# PII Patterns
|
||||
FilterPattern(
|
||||
name="email",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[EMAIL]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="phone_us",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[PHONE]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="ssn",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[SSN]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="credit_card",
|
||||
category=ContentCategory.FINANCIAL,
|
||||
pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[CREDIT_CARD]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="ip_address",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[IP]",
|
||||
confidence=0.8,
|
||||
),
|
||||
# Secret Patterns
|
||||
FilterPattern(
|
||||
name="api_key_generic",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[API_KEY]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="aws_access_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\bAKIA[0-9A-Z]{16}\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[AWS_KEY]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="aws_secret_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b[A-Za-z0-9/+=]{40}\b",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[AWS_SECRET]",
|
||||
confidence=0.6, # Lower confidence - might be false positive
|
||||
),
|
||||
FilterPattern(
|
||||
name="github_token",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[GITHUB_TOKEN]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="jwt_token",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[JWT]",
|
||||
),
|
||||
# Credential Patterns
|
||||
FilterPattern(
|
||||
name="password_in_url",
|
||||
category=ContentCategory.CREDENTIALS,
|
||||
pattern=r"://[^:]+:([^@]+)@",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="://[REDACTED]@",
|
||||
),
|
||||
FilterPattern(
|
||||
name="password_assignment",
|
||||
category=ContentCategory.CREDENTIALS,
|
||||
pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[PASSWORD]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="private_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[PRIVATE_KEY]",
|
||||
),
|
||||
# Injection Patterns
|
||||
FilterPattern(
|
||||
name="sql_injection",
|
||||
category=ContentCategory.INJECTION,
|
||||
pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[BLOCKED]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="command_injection",
|
||||
category=ContentCategory.INJECTION,
|
||||
pattern=r"[;&|`$]|\$\(|\$\{",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[CMD]",
|
||||
confidence=0.5, # Low confidence - common in code
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_pii_filter: bool = True,
|
||||
enable_secret_filter: bool = True,
|
||||
enable_injection_filter: bool = True,
|
||||
custom_patterns: list[FilterPattern] | None = None,
|
||||
default_action: FilterAction = FilterAction.REDACT,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the ContentFilter.
|
||||
|
||||
Args:
|
||||
enable_pii_filter: Enable PII detection
|
||||
enable_secret_filter: Enable secret scanning
|
||||
enable_injection_filter: Enable injection detection
|
||||
custom_patterns: Additional custom patterns
|
||||
default_action: Default action for matches
|
||||
"""
|
||||
self._patterns: list[FilterPattern] = []
|
||||
self._default_action = default_action
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load default patterns based on configuration
|
||||
for pattern in self.DEFAULT_PATTERNS:
|
||||
if pattern.category == ContentCategory.PII and not enable_pii_filter:
|
||||
continue
|
||||
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
|
||||
continue
|
||||
if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
|
||||
continue
|
||||
if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
|
||||
continue
|
||||
self._patterns.append(pattern)
|
||||
|
||||
# Add custom patterns
|
||||
if custom_patterns:
|
||||
self._patterns.extend(custom_patterns)
|
||||
|
||||
logger.info("ContentFilter initialized with %d patterns", len(self._patterns))
|
||||
|
||||
def add_pattern(self, pattern: FilterPattern) -> None:
|
||||
"""Add a custom pattern."""
|
||||
self._patterns.append(pattern)
|
||||
logger.debug("Added pattern: %s", pattern.name)
|
||||
|
||||
def remove_pattern(self, pattern_name: str) -> bool:
|
||||
"""Remove a pattern by name."""
|
||||
for i, pattern in enumerate(self._patterns):
|
||||
if pattern.name == pattern_name:
|
||||
del self._patterns[i]
|
||||
logger.debug("Removed pattern: %s", pattern_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
|
||||
"""Enable or disable a pattern."""
|
||||
for pattern in self._patterns:
|
||||
if pattern.name == pattern_name:
|
||||
pattern.enabled = enabled
|
||||
return True
|
||||
return False
|
||||
|
||||
async def filter(
|
||||
self,
|
||||
content: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
raise_on_block: bool = False,
|
||||
) -> FilterResult:
|
||||
"""
|
||||
Filter content for sensitive information.
|
||||
|
||||
Args:
|
||||
content: Content to filter
|
||||
context: Optional context for filtering decisions
|
||||
raise_on_block: Raise exception if content is blocked
|
||||
|
||||
Returns:
|
||||
FilterResult with filtered content and match details
|
||||
|
||||
Raises:
|
||||
ContentFilterError: If content is blocked and raise_on_block=True
|
||||
"""
|
||||
all_matches: list[FilterMatch] = []
|
||||
blocked = False
|
||||
block_reason: str | None = None
|
||||
warnings: list[str] = []
|
||||
|
||||
# Find all matches
|
||||
for pattern in self._patterns:
|
||||
if not pattern.enabled:
|
||||
continue
|
||||
|
||||
matches = pattern.find_matches(content)
|
||||
for match in matches:
|
||||
all_matches.append(match)
|
||||
|
||||
if pattern.action == FilterAction.BLOCK:
|
||||
blocked = True
|
||||
block_reason = f"Blocked by pattern: {pattern.name}"
|
||||
elif pattern.action == FilterAction.WARN:
|
||||
warnings.append(
|
||||
f"Warning: {pattern.name} detected at position {match.start_pos}"
|
||||
)
|
||||
|
||||
# Sort matches by position (reverse for replacement)
|
||||
all_matches.sort(key=lambda m: m.start_pos, reverse=True)
|
||||
|
||||
# Apply redactions
|
||||
filtered_content = content
|
||||
for match in all_matches:
|
||||
matched_pattern = self._get_pattern(match.pattern_name)
|
||||
if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
|
||||
filtered_content = (
|
||||
filtered_content[: match.start_pos]
|
||||
+ (match.redacted_text or "[REDACTED]")
|
||||
+ filtered_content[match.end_pos :]
|
||||
)
|
||||
|
||||
# Re-sort for result
|
||||
all_matches.sort(key=lambda m: m.start_pos)
|
||||
|
||||
result = FilterResult(
|
||||
original_content=content,
|
||||
filtered_content=filtered_content if not blocked else "",
|
||||
matches=all_matches,
|
||||
blocked=blocked,
|
||||
block_reason=block_reason,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
if blocked:
|
||||
logger.warning(
|
||||
"Content blocked: %s (%d matches)",
|
||||
block_reason,
|
||||
len(all_matches),
|
||||
)
|
||||
if raise_on_block:
|
||||
raise ContentFilterError(
|
||||
block_reason or "Content blocked",
|
||||
detected_category=all_matches[0].category.value if all_matches else "unknown",
|
||||
pattern_name=all_matches[0].pattern_name if all_matches else None,
|
||||
)
|
||||
elif all_matches:
|
||||
logger.debug(
|
||||
"Content filtered: %d matches, %d warnings",
|
||||
len(all_matches),
|
||||
len(warnings),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def filter_dict(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
keys_to_filter: list[str] | None = None,
|
||||
recursive: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Filter string values in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary to filter
|
||||
keys_to_filter: Specific keys to filter (None = all)
|
||||
recursive: Filter nested dictionaries
|
||||
|
||||
Returns:
|
||||
Filtered dictionary
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
if keys_to_filter is None or key in keys_to_filter:
|
||||
filter_result = await self.filter(value)
|
||||
result[key] = filter_result.filtered_content
|
||||
else:
|
||||
result[key] = value
|
||||
elif isinstance(value, dict) and recursive:
|
||||
result[key] = await self.filter_dict(value, keys_to_filter, recursive)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [
|
||||
(await self.filter(item)).filtered_content
|
||||
if isinstance(item, str)
|
||||
else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
content: str,
|
||||
categories: list[ContentCategory] | None = None,
|
||||
) -> list[FilterMatch]:
|
||||
"""
|
||||
Scan content without filtering (detection only).
|
||||
|
||||
Args:
|
||||
content: Content to scan
|
||||
categories: Limit to specific categories
|
||||
|
||||
Returns:
|
||||
List of matches found
|
||||
"""
|
||||
all_matches: list[FilterMatch] = []
|
||||
|
||||
for pattern in self._patterns:
|
||||
if not pattern.enabled:
|
||||
continue
|
||||
if categories and pattern.category not in categories:
|
||||
continue
|
||||
|
||||
matches = pattern.find_matches(content)
|
||||
all_matches.extend(matches)
|
||||
|
||||
all_matches.sort(key=lambda m: m.start_pos)
|
||||
return all_matches
|
||||
|
||||
async def validate_safe(
|
||||
self,
|
||||
content: str,
|
||||
categories: list[ContentCategory] | None = None,
|
||||
allow_warnings: bool = True,
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Validate that content is safe (no blocked patterns).
|
||||
|
||||
Args:
|
||||
content: Content to validate
|
||||
categories: Limit to specific categories
|
||||
allow_warnings: Allow content with warnings
|
||||
|
||||
Returns:
|
||||
Tuple of (is_safe, list of issues)
|
||||
"""
|
||||
issues: list[str] = []
|
||||
|
||||
for pattern in self._patterns:
|
||||
if not pattern.enabled:
|
||||
continue
|
||||
if categories and pattern.category not in categories:
|
||||
continue
|
||||
|
||||
matches = pattern.find_matches(content)
|
||||
for match in matches:
|
||||
if pattern.action == FilterAction.BLOCK:
|
||||
issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
|
||||
elif pattern.action == FilterAction.WARN and not allow_warnings:
|
||||
issues.append(f"Warning: {pattern.name} at position {match.start_pos}")
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
def _get_pattern(self, name: str) -> FilterPattern | None:
|
||||
"""Get a pattern by name."""
|
||||
for pattern in self._patterns:
|
||||
if pattern.name == name:
|
||||
return pattern
|
||||
return None
|
||||
|
||||
def get_pattern_stats(self) -> dict[str, Any]:
|
||||
"""Get statistics about configured patterns."""
|
||||
by_category: dict[str, int] = {}
|
||||
by_action: dict[str, int] = {}
|
||||
|
||||
for pattern in self._patterns:
|
||||
cat = pattern.category.value
|
||||
by_category[cat] = by_category.get(cat, 0) + 1
|
||||
|
||||
act = pattern.action.value
|
||||
by_action[act] = by_action.get(act, 0) + 1
|
||||
|
||||
return {
|
||||
"total_patterns": len(self._patterns),
|
||||
"enabled_patterns": sum(1 for p in self._patterns if p.enabled),
|
||||
"by_category": by_category,
|
||||
"by_action": by_action,
|
||||
}
|
||||
|
||||
|
||||
# Convenience function for quick filtering
|
||||
async def filter_content(content: str) -> str:
|
||||
"""Quick filter content with default settings."""
|
||||
filter_instance = ContentFilter()
|
||||
result = await filter_instance.filter(content)
|
||||
return result.filtered_content
|
||||
|
||||
|
||||
async def scan_for_secrets(content: str) -> list[FilterMatch]:
|
||||
"""Quick scan for secrets only."""
|
||||
filter_instance = ContentFilter(
|
||||
enable_pii_filter=False,
|
||||
enable_injection_filter=False,
|
||||
)
|
||||
return await filter_instance.scan(
|
||||
content,
|
||||
categories=[ContentCategory.SECRETS, ContentCategory.CREDENTIALS],
|
||||
)
|
||||
@@ -1 +1,23 @@
|
||||
"""${dir} module."""
|
||||
"""Emergency controls for agent safety."""
|
||||
|
||||
from .controls import (
|
||||
EmergencyControls,
|
||||
EmergencyEvent,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
check_emergency_allowed,
|
||||
emergency_stop_global,
|
||||
get_emergency_controls,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmergencyControls",
|
||||
"EmergencyEvent",
|
||||
"EmergencyReason",
|
||||
"EmergencyState",
|
||||
"EmergencyTrigger",
|
||||
"check_emergency_allowed",
|
||||
"emergency_stop_global",
|
||||
"get_emergency_controls",
|
||||
]
|
||||
|
||||
594
backend/app/services/safety/emergency/controls.py
Normal file
594
backend/app/services/safety/emergency/controls.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
Emergency Controls
|
||||
|
||||
Emergency stop and pause functionality for agent safety.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import EmergencyStopError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmergencyState(str, Enum):
|
||||
"""Emergency control states."""
|
||||
|
||||
NORMAL = "normal"
|
||||
PAUSED = "paused"
|
||||
STOPPED = "stopped"
|
||||
|
||||
|
||||
class EmergencyReason(str, Enum):
|
||||
"""Reasons for emergency actions."""
|
||||
|
||||
MANUAL = "manual"
|
||||
SAFETY_VIOLATION = "safety_violation"
|
||||
BUDGET_EXCEEDED = "budget_exceeded"
|
||||
LOOP_DETECTED = "loop_detected"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
CONTENT_VIOLATION = "content_violation"
|
||||
SYSTEM_ERROR = "system_error"
|
||||
EXTERNAL_TRIGGER = "external_trigger"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmergencyEvent:
|
||||
"""Record of an emergency action."""
|
||||
|
||||
id: str
|
||||
state: EmergencyState
|
||||
reason: EmergencyReason
|
||||
triggered_by: str
|
||||
message: str
|
||||
scope: str # "global", "project:<id>", "agent:<id>"
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
resolved_at: datetime | None = None
|
||||
resolved_by: str | None = None
|
||||
|
||||
|
||||
class EmergencyControls:
|
||||
"""
|
||||
Emergency stop and pause controls for agent safety.
|
||||
|
||||
Features:
|
||||
- Global emergency stop
|
||||
- Per-project/agent emergency controls
|
||||
- Graceful pause with state preservation
|
||||
- Automatic triggers from safety violations
|
||||
- Manual override capabilities
|
||||
- Event history and audit trail
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notification_handlers: list[Callable[..., Any]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize EmergencyControls.
|
||||
|
||||
Args:
|
||||
notification_handlers: Handlers to call on emergency events
|
||||
"""
|
||||
self._global_state = EmergencyState.NORMAL
|
||||
self._scoped_states: dict[str, EmergencyState] = {}
|
||||
self._events: list[EmergencyEvent] = []
|
||||
self._notification_handlers = notification_handlers or []
|
||||
self._lock = asyncio.Lock()
|
||||
self._event_id_counter = 0
|
||||
|
||||
# Callbacks for state changes
|
||||
self._on_stop_callbacks: list[Callable[..., Any]] = []
|
||||
self._on_pause_callbacks: list[Callable[..., Any]] = []
|
||||
self._on_resume_callbacks: list[Callable[..., Any]] = []
|
||||
|
||||
def _generate_event_id(self) -> str:
|
||||
"""Generate a unique event ID."""
|
||||
self._event_id_counter += 1
|
||||
return f"emerg-{self._event_id_counter:06d}"
|
||||
|
||||
async def emergency_stop(
|
||||
self,
|
||||
reason: EmergencyReason,
|
||||
triggered_by: str,
|
||||
message: str,
|
||||
scope: str = "global",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> EmergencyEvent:
|
||||
"""
|
||||
Trigger emergency stop.
|
||||
|
||||
Args:
|
||||
reason: Reason for the stop
|
||||
triggered_by: Who/what triggered the stop
|
||||
message: Human-readable message
|
||||
scope: Scope of the stop (global, project:<id>, agent:<id>)
|
||||
metadata: Additional context
|
||||
|
||||
Returns:
|
||||
The emergency event record
|
||||
"""
|
||||
async with self._lock:
|
||||
event = EmergencyEvent(
|
||||
id=self._generate_event_id(),
|
||||
state=EmergencyState.STOPPED,
|
||||
reason=reason,
|
||||
triggered_by=triggered_by,
|
||||
message=message,
|
||||
scope=scope,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
if scope == "global":
|
||||
self._global_state = EmergencyState.STOPPED
|
||||
else:
|
||||
self._scoped_states[scope] = EmergencyState.STOPPED
|
||||
|
||||
self._events.append(event)
|
||||
|
||||
logger.critical(
|
||||
"EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
|
||||
scope,
|
||||
reason.value,
|
||||
triggered_by,
|
||||
message,
|
||||
)
|
||||
|
||||
# Execute callbacks
|
||||
await self._execute_callbacks(self._on_stop_callbacks, event)
|
||||
await self._notify_handlers("emergency_stop", event)
|
||||
|
||||
return event
|
||||
|
||||
async def pause(
|
||||
self,
|
||||
reason: EmergencyReason,
|
||||
triggered_by: str,
|
||||
message: str,
|
||||
scope: str = "global",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> EmergencyEvent:
|
||||
"""
|
||||
Pause operations (can be resumed).
|
||||
|
||||
Args:
|
||||
reason: Reason for the pause
|
||||
triggered_by: Who/what triggered the pause
|
||||
message: Human-readable message
|
||||
scope: Scope of the pause
|
||||
metadata: Additional context
|
||||
|
||||
Returns:
|
||||
The emergency event record
|
||||
"""
|
||||
async with self._lock:
|
||||
event = EmergencyEvent(
|
||||
id=self._generate_event_id(),
|
||||
state=EmergencyState.PAUSED,
|
||||
reason=reason,
|
||||
triggered_by=triggered_by,
|
||||
message=message,
|
||||
scope=scope,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
if scope == "global":
|
||||
self._global_state = EmergencyState.PAUSED
|
||||
else:
|
||||
self._scoped_states[scope] = EmergencyState.PAUSED
|
||||
|
||||
self._events.append(event)
|
||||
|
||||
logger.warning(
|
||||
"PAUSE: scope=%s, reason=%s, by=%s - %s",
|
||||
scope,
|
||||
reason.value,
|
||||
triggered_by,
|
||||
message,
|
||||
)
|
||||
|
||||
await self._execute_callbacks(self._on_pause_callbacks, event)
|
||||
await self._notify_handlers("pause", event)
|
||||
|
||||
return event
|
||||
|
||||
async def resume(
|
||||
self,
|
||||
scope: str = "global",
|
||||
resumed_by: str = "system",
|
||||
message: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Resume operations from paused state.
|
||||
|
||||
Args:
|
||||
scope: Scope to resume
|
||||
resumed_by: Who/what is resuming
|
||||
message: Optional message
|
||||
|
||||
Returns:
|
||||
True if resumed, False if not in paused state
|
||||
"""
|
||||
async with self._lock:
|
||||
current_state = self._get_state(scope)
|
||||
|
||||
if current_state == EmergencyState.STOPPED:
|
||||
logger.warning(
|
||||
"Cannot resume from STOPPED state: %s (requires reset)",
|
||||
scope,
|
||||
)
|
||||
return False
|
||||
|
||||
if current_state == EmergencyState.NORMAL:
|
||||
return True # Already normal
|
||||
|
||||
# Find the pause event and mark as resolved
|
||||
for event in reversed(self._events):
|
||||
if event.scope == scope and event.state == EmergencyState.PAUSED:
|
||||
if event.resolved_at is None:
|
||||
event.resolved_at = datetime.utcnow()
|
||||
event.resolved_by = resumed_by
|
||||
break
|
||||
|
||||
if scope == "global":
|
||||
self._global_state = EmergencyState.NORMAL
|
||||
else:
|
||||
self._scoped_states[scope] = EmergencyState.NORMAL
|
||||
|
||||
logger.info(
|
||||
"RESUMED: scope=%s, by=%s%s",
|
||||
scope,
|
||||
resumed_by,
|
||||
f" - {message}" if message else "",
|
||||
)
|
||||
|
||||
await self._execute_callbacks(
|
||||
self._on_resume_callbacks,
|
||||
{"scope": scope, "resumed_by": resumed_by},
|
||||
)
|
||||
await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})
|
||||
|
||||
return True
|
||||
|
||||
async def reset(
|
||||
self,
|
||||
scope: str = "global",
|
||||
reset_by: str = "admin",
|
||||
message: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Reset from stopped state (requires explicit action).
|
||||
|
||||
Args:
|
||||
scope: Scope to reset
|
||||
reset_by: Who is resetting (should be admin)
|
||||
message: Optional message
|
||||
|
||||
Returns:
|
||||
True if reset successful
|
||||
"""
|
||||
async with self._lock:
|
||||
current_state = self._get_state(scope)
|
||||
|
||||
if current_state == EmergencyState.NORMAL:
|
||||
return True
|
||||
|
||||
# Find the stop event and mark as resolved
|
||||
for event in reversed(self._events):
|
||||
if event.scope == scope and event.state == EmergencyState.STOPPED:
|
||||
if event.resolved_at is None:
|
||||
event.resolved_at = datetime.utcnow()
|
||||
event.resolved_by = reset_by
|
||||
break
|
||||
|
||||
if scope == "global":
|
||||
self._global_state = EmergencyState.NORMAL
|
||||
else:
|
||||
self._scoped_states[scope] = EmergencyState.NORMAL
|
||||
|
||||
logger.warning(
|
||||
"EMERGENCY RESET: scope=%s, by=%s%s",
|
||||
scope,
|
||||
reset_by,
|
||||
f" - {message}" if message else "",
|
||||
)
|
||||
|
||||
await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})
|
||||
|
||||
return True
|
||||
|
||||
async def check_allowed(
|
||||
self,
|
||||
scope: str | None = None,
|
||||
raise_if_blocked: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if operations are allowed.
|
||||
|
||||
Args:
|
||||
scope: Specific scope to check (also checks global)
|
||||
raise_if_blocked: Raise exception if blocked
|
||||
|
||||
Returns:
|
||||
True if operations are allowed
|
||||
|
||||
Raises:
|
||||
EmergencyStopError: If blocked and raise_if_blocked=True
|
||||
"""
|
||||
async with self._lock:
|
||||
# Always check global state
|
||||
if self._global_state != EmergencyState.NORMAL:
|
||||
if raise_if_blocked:
|
||||
raise EmergencyStopError(
|
||||
f"Global emergency state: {self._global_state.value}",
|
||||
stop_reason=self._get_last_reason("global"),
|
||||
triggered_by=self._get_last_triggered_by("global"),
|
||||
)
|
||||
return False
|
||||
|
||||
# Check specific scope
|
||||
if scope and scope in self._scoped_states:
|
||||
state = self._scoped_states[scope]
|
||||
if state != EmergencyState.NORMAL:
|
||||
if raise_if_blocked:
|
||||
raise EmergencyStopError(
|
||||
f"Emergency state for {scope}: {state.value}",
|
||||
stop_reason=self._get_last_reason(scope),
|
||||
triggered_by=self._get_last_triggered_by(scope),
|
||||
scope=scope,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_state(self, scope: str) -> EmergencyState:
|
||||
"""Get state for a scope."""
|
||||
if scope == "global":
|
||||
return self._global_state
|
||||
return self._scoped_states.get(scope, EmergencyState.NORMAL)
|
||||
|
||||
def _get_last_reason(self, scope: str) -> str:
|
||||
"""Get reason from last event for scope."""
|
||||
for event in reversed(self._events):
|
||||
if event.scope == scope and event.resolved_at is None:
|
||||
return event.reason.value
|
||||
return "unknown"
|
||||
|
||||
def _get_last_triggered_by(self, scope: str) -> str:
|
||||
"""Get triggered_by from last event for scope."""
|
||||
for event in reversed(self._events):
|
||||
if event.scope == scope and event.resolved_at is None:
|
||||
return event.triggered_by
|
||||
return "unknown"
|
||||
|
||||
async def get_state(self, scope: str = "global") -> EmergencyState:
|
||||
"""Get current state for a scope."""
|
||||
async with self._lock:
|
||||
return self._get_state(scope)
|
||||
|
||||
async def get_all_states(self) -> dict[str, EmergencyState]:
|
||||
"""Get all current states."""
|
||||
async with self._lock:
|
||||
states = {"global": self._global_state}
|
||||
states.update(self._scoped_states)
|
||||
return states
|
||||
|
||||
async def get_active_events(self) -> list[EmergencyEvent]:
|
||||
"""Get all unresolved emergency events."""
|
||||
async with self._lock:
|
||||
return [e for e in self._events if e.resolved_at is None]
|
||||
|
||||
async def get_event_history(
|
||||
self,
|
||||
scope: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[EmergencyEvent]:
|
||||
"""Get emergency event history."""
|
||||
async with self._lock:
|
||||
events = list(self._events)
|
||||
|
||||
if scope:
|
||||
events = [e for e in events if e.scope == scope]
|
||||
|
||||
return events[-limit:]
|
||||
|
||||
def on_stop(self, callback: Callable[..., Any]) -> None:
|
||||
"""Register callback for stop events."""
|
||||
self._on_stop_callbacks.append(callback)
|
||||
|
||||
def on_pause(self, callback: Callable[..., Any]) -> None:
|
||||
"""Register callback for pause events."""
|
||||
self._on_pause_callbacks.append(callback)
|
||||
|
||||
def on_resume(self, callback: Callable[..., Any]) -> None:
|
||||
"""Register callback for resume events."""
|
||||
self._on_resume_callbacks.append(callback)
|
||||
|
||||
def add_notification_handler(self, handler: Callable[..., Any]) -> None:
|
||||
"""Add a notification handler."""
|
||||
self._notification_handlers.append(handler)
|
||||
|
||||
async def _execute_callbacks(
|
||||
self,
|
||||
callbacks: list[Callable[..., Any]],
|
||||
data: Any,
|
||||
) -> None:
|
||||
"""Execute callbacks safely."""
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(data)
|
||||
else:
|
||||
callback(data)
|
||||
except Exception as e:
|
||||
logger.error("Error in callback: %s", e)
|
||||
|
||||
async def _notify_handlers(self, event_type: str, data: Any) -> None:
|
||||
"""Notify all handlers of an event."""
|
||||
for handler in self._notification_handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(event_type, data)
|
||||
else:
|
||||
handler(event_type, data)
|
||||
except Exception as e:
|
||||
logger.error("Error in notification handler: %s", e)
|
||||
|
||||
|
||||
class EmergencyTrigger:
|
||||
"""
|
||||
Automatic emergency triggers based on conditions.
|
||||
"""
|
||||
|
||||
def __init__(self, controls: EmergencyControls) -> None:
|
||||
"""
|
||||
Initialize EmergencyTrigger.
|
||||
|
||||
Args:
|
||||
controls: EmergencyControls instance to trigger
|
||||
"""
|
||||
self._controls = controls
|
||||
|
||||
async def trigger_on_safety_violation(
|
||||
self,
|
||||
violation_type: str,
|
||||
details: dict[str, Any],
|
||||
scope: str = "global",
|
||||
) -> EmergencyEvent:
|
||||
"""
|
||||
Trigger emergency from safety violation.
|
||||
|
||||
Args:
|
||||
violation_type: Type of violation
|
||||
details: Violation details
|
||||
scope: Scope for the emergency
|
||||
|
||||
Returns:
|
||||
Emergency event
|
||||
"""
|
||||
return await self._controls.emergency_stop(
|
||||
reason=EmergencyReason.SAFETY_VIOLATION,
|
||||
triggered_by="safety_system",
|
||||
message=f"Safety violation: {violation_type}",
|
||||
scope=scope,
|
||||
metadata={"violation_type": violation_type, **details},
|
||||
)
|
||||
|
||||
async def trigger_on_budget_exceeded(
|
||||
self,
|
||||
budget_type: str,
|
||||
current: float,
|
||||
limit: float,
|
||||
scope: str = "global",
|
||||
) -> EmergencyEvent:
|
||||
"""
|
||||
Trigger emergency from budget exceeded.
|
||||
|
||||
Args:
|
||||
budget_type: Type of budget
|
||||
current: Current usage
|
||||
limit: Budget limit
|
||||
scope: Scope for the emergency
|
||||
|
||||
Returns:
|
||||
Emergency event
|
||||
"""
|
||||
return await self._controls.pause(
|
||||
reason=EmergencyReason.BUDGET_EXCEEDED,
|
||||
triggered_by="budget_controller",
|
||||
message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
|
||||
scope=scope,
|
||||
metadata={"budget_type": budget_type, "current": current, "limit": limit},
|
||||
)
|
||||
|
||||
async def trigger_on_loop_detected(
|
||||
self,
|
||||
loop_type: str,
|
||||
agent_id: str,
|
||||
details: dict[str, Any],
|
||||
) -> EmergencyEvent:
|
||||
"""
|
||||
Trigger emergency from loop detection.
|
||||
|
||||
Args:
|
||||
loop_type: Type of loop
|
||||
agent_id: Agent that's looping
|
||||
details: Loop details
|
||||
|
||||
Returns:
|
||||
Emergency event
|
||||
"""
|
||||
return await self._controls.pause(
|
||||
reason=EmergencyReason.LOOP_DETECTED,
|
||||
triggered_by="loop_detector",
|
||||
message=f"Loop detected: {loop_type} in agent {agent_id}",
|
||||
scope=f"agent:{agent_id}",
|
||||
metadata={"loop_type": loop_type, "agent_id": agent_id, **details},
|
||||
)
|
||||
|
||||
async def trigger_on_content_violation(
|
||||
self,
|
||||
category: str,
|
||||
pattern: str,
|
||||
scope: str = "global",
|
||||
) -> EmergencyEvent:
|
||||
"""
|
||||
Trigger emergency from content violation.
|
||||
|
||||
Args:
|
||||
category: Content category
|
||||
pattern: Pattern that matched
|
||||
scope: Scope for the emergency
|
||||
|
||||
Returns:
|
||||
Emergency event
|
||||
"""
|
||||
return await self._controls.emergency_stop(
|
||||
reason=EmergencyReason.CONTENT_VIOLATION,
|
||||
triggered_by="content_filter",
|
||||
message=f"Content violation: {category} ({pattern})",
|
||||
scope=scope,
|
||||
metadata={"category": category, "pattern": pattern},
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_emergency_controls: EmergencyControls | None = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_emergency_controls() -> EmergencyControls:
|
||||
"""Get the singleton EmergencyControls instance."""
|
||||
global _emergency_controls
|
||||
|
||||
async with _lock:
|
||||
if _emergency_controls is None:
|
||||
_emergency_controls = EmergencyControls()
|
||||
return _emergency_controls
|
||||
|
||||
|
||||
async def emergency_stop_global(
|
||||
reason: str,
|
||||
triggered_by: str = "system",
|
||||
) -> EmergencyEvent:
|
||||
"""Quick global emergency stop."""
|
||||
controls = await get_emergency_controls()
|
||||
return await controls.emergency_stop(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by=triggered_by,
|
||||
message=reason,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
|
||||
async def check_emergency_allowed(scope: str | None = None) -> bool:
|
||||
"""Quick check if operations are allowed."""
|
||||
controls = await get_emergency_controls()
|
||||
return await controls.check_allowed(scope=scope, raise_if_blocked=False)
|
||||
@@ -1 +1,5 @@
|
||||
"""${dir} module."""
|
||||
"""Human-in-the-Loop approval workflows."""
|
||||
|
||||
from .manager import ApprovalQueue, HITLManager
|
||||
|
||||
__all__ = ["ApprovalQueue", "HITLManager"]
|
||||
|
||||
449
backend/app/services/safety/hitl/manager.py
Normal file
449
backend/app/services/safety/hitl/manager.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Human-in-the-Loop (HITL) Manager
|
||||
|
||||
Manages approval workflows for actions requiring human oversight.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
)
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApprovalQueue:
|
||||
"""Queue for pending approval requests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending: dict[str, ApprovalRequest] = {}
|
||||
self._completed: dict[str, ApprovalResponse] = {}
|
||||
self._waiters: dict[str, asyncio.Event] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add(self, request: ApprovalRequest) -> None:
|
||||
"""Add an approval request to the queue."""
|
||||
async with self._lock:
|
||||
self._pending[request.id] = request
|
||||
self._waiters[request.id] = asyncio.Event()
|
||||
|
||||
async def get_pending(self, request_id: str) -> ApprovalRequest | None:
|
||||
"""Get a pending request by ID."""
|
||||
async with self._lock:
|
||||
return self._pending.get(request_id)
|
||||
|
||||
async def complete(self, response: ApprovalResponse) -> bool:
|
||||
"""Complete an approval request."""
|
||||
async with self._lock:
|
||||
if response.request_id not in self._pending:
|
||||
return False
|
||||
|
||||
del self._pending[response.request_id]
|
||||
self._completed[response.request_id] = response
|
||||
|
||||
# Notify waiters
|
||||
if response.request_id in self._waiters:
|
||||
self._waiters[response.request_id].set()
|
||||
|
||||
return True
|
||||
|
||||
async def wait_for_response(
|
||||
self,
|
||||
request_id: str,
|
||||
timeout_seconds: float,
|
||||
) -> ApprovalResponse | None:
|
||||
"""Wait for a response to an approval request."""
|
||||
async with self._lock:
|
||||
waiter = self._waiters.get(request_id)
|
||||
if not waiter:
|
||||
return self._completed.get(request_id)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
|
||||
except TimeoutError:
|
||||
return None
|
||||
|
||||
async with self._lock:
|
||||
return self._completed.get(request_id)
|
||||
|
||||
async def list_pending(self) -> list[ApprovalRequest]:
|
||||
"""List all pending requests."""
|
||||
async with self._lock:
|
||||
return list(self._pending.values())
|
||||
|
||||
async def cancel(self, request_id: str) -> bool:
|
||||
"""Cancel a pending request."""
|
||||
async with self._lock:
|
||||
if request_id not in self._pending:
|
||||
return False
|
||||
|
||||
del self._pending[request_id]
|
||||
|
||||
# Create cancelled response
|
||||
response = ApprovalResponse(
|
||||
request_id=request_id,
|
||||
status=ApprovalStatus.CANCELLED,
|
||||
reason="Cancelled",
|
||||
)
|
||||
self._completed[request_id] = response
|
||||
|
||||
# Notify waiters
|
||||
if request_id in self._waiters:
|
||||
self._waiters[request_id].set()
|
||||
|
||||
return True
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""Clean up expired requests."""
|
||||
now = datetime.utcnow()
|
||||
to_timeout: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
for request_id, request in self._pending.items():
|
||||
if request.expires_at and request.expires_at < now:
|
||||
to_timeout.append(request_id)
|
||||
|
||||
count = 0
|
||||
for request_id in to_timeout:
|
||||
async with self._lock:
|
||||
if request_id in self._pending:
|
||||
del self._pending[request_id]
|
||||
self._completed[request_id] = ApprovalResponse(
|
||||
request_id=request_id,
|
||||
status=ApprovalStatus.TIMEOUT,
|
||||
reason="Request timed out",
|
||||
)
|
||||
if request_id in self._waiters:
|
||||
self._waiters[request_id].set()
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class HITLManager:
|
||||
"""
|
||||
Manages Human-in-the-Loop approval workflows.
|
||||
|
||||
Features:
|
||||
- Approval request queue
|
||||
- Configurable timeout handling (default deny)
|
||||
- Approval delegation
|
||||
- Batch approval for similar actions
|
||||
- Approval with modifications
|
||||
- Notification channels
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_timeout: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the HITLManager.
|
||||
|
||||
Args:
|
||||
default_timeout: Default timeout for approval requests in seconds
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._default_timeout = default_timeout or config.hitl_default_timeout
|
||||
self._queue = ApprovalQueue()
|
||||
self._notification_handlers: list[Callable[..., Any]] = []
|
||||
self._running = False
|
||||
self._cleanup_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the HITL manager background tasks."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
|
||||
logger.info("HITL Manager started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the HITL manager."""
|
||||
self._running = False
|
||||
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("HITL Manager stopped")
|
||||
|
||||
async def request_approval(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
reason: str,
|
||||
timeout_seconds: int | None = None,
|
||||
urgency: str = "normal",
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> ApprovalRequest:
|
||||
"""
|
||||
Create an approval request for an action.
|
||||
|
||||
Args:
|
||||
action: The action requiring approval
|
||||
reason: Why approval is required
|
||||
timeout_seconds: Timeout for this request
|
||||
urgency: Urgency level (low, normal, high, critical)
|
||||
context: Additional context for the approver
|
||||
|
||||
Returns:
|
||||
The created approval request
|
||||
"""
|
||||
timeout = timeout_seconds or self._default_timeout
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=timeout)
|
||||
|
||||
request = ApprovalRequest(
|
||||
id=str(uuid4()),
|
||||
action=action,
|
||||
reason=reason,
|
||||
urgency=urgency,
|
||||
timeout_seconds=timeout,
|
||||
expires_at=expires_at,
|
||||
context=context or {},
|
||||
)
|
||||
|
||||
await self._queue.add(request)
|
||||
|
||||
# Notify handlers
|
||||
await self._notify_handlers("approval_requested", request)
|
||||
|
||||
logger.info(
|
||||
"Approval requested: %s for action %s (timeout: %ds)",
|
||||
request.id,
|
||||
action.id,
|
||||
timeout,
|
||||
)
|
||||
|
||||
return request
|
||||
|
||||
async def wait_for_approval(
|
||||
self,
|
||||
request_id: str,
|
||||
timeout_seconds: int | None = None,
|
||||
) -> ApprovalResponse:
|
||||
"""
|
||||
Wait for an approval decision.
|
||||
|
||||
Args:
|
||||
request_id: ID of the approval request
|
||||
timeout_seconds: Override timeout
|
||||
|
||||
Returns:
|
||||
The approval response
|
||||
|
||||
Raises:
|
||||
ApprovalTimeoutError: If timeout expires
|
||||
ApprovalDeniedError: If approval is denied
|
||||
"""
|
||||
request = await self._queue.get_pending(request_id)
|
||||
if not request:
|
||||
raise ApprovalRequiredError(
|
||||
f"Approval request not found: {request_id}",
|
||||
approval_id=request_id,
|
||||
)
|
||||
|
||||
timeout = timeout_seconds or request.timeout_seconds or self._default_timeout
|
||||
response = await self._queue.wait_for_response(request_id, timeout)
|
||||
|
||||
if response is None:
|
||||
# Timeout - default deny
|
||||
response = ApprovalResponse(
|
||||
request_id=request_id,
|
||||
status=ApprovalStatus.TIMEOUT,
|
||||
reason="Request timed out (default deny)",
|
||||
)
|
||||
await self._queue.complete(response)
|
||||
|
||||
raise ApprovalTimeoutError(
|
||||
"Approval request timed out",
|
||||
approval_id=request_id,
|
||||
timeout_seconds=timeout,
|
||||
)
|
||||
|
||||
if response.status == ApprovalStatus.DENIED:
|
||||
raise ApprovalDeniedError(
|
||||
response.reason or "Approval denied",
|
||||
approval_id=request_id,
|
||||
denied_by=response.decided_by,
|
||||
denial_reason=response.reason,
|
||||
)
|
||||
|
||||
if response.status == ApprovalStatus.TIMEOUT:
|
||||
raise ApprovalTimeoutError(
|
||||
"Approval request timed out",
|
||||
approval_id=request_id,
|
||||
timeout_seconds=timeout,
|
||||
)
|
||||
|
||||
if response.status == ApprovalStatus.CANCELLED:
|
||||
raise ApprovalDeniedError(
|
||||
"Approval request was cancelled",
|
||||
approval_id=request_id,
|
||||
denial_reason="Cancelled",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def approve(
|
||||
self,
|
||||
request_id: str,
|
||||
decided_by: str,
|
||||
reason: str | None = None,
|
||||
modifications: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Approve a pending request.
|
||||
|
||||
Args:
|
||||
request_id: ID of the approval request
|
||||
decided_by: Who approved
|
||||
reason: Optional approval reason
|
||||
modifications: Optional modifications to the action
|
||||
|
||||
Returns:
|
||||
True if approval was recorded
|
||||
"""
|
||||
response = ApprovalResponse(
|
||||
request_id=request_id,
|
||||
status=ApprovalStatus.APPROVED,
|
||||
decided_by=decided_by,
|
||||
reason=reason,
|
||||
modifications=modifications,
|
||||
)
|
||||
|
||||
success = await self._queue.complete(response)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Approval granted: %s by %s",
|
||||
request_id,
|
||||
decided_by,
|
||||
)
|
||||
await self._notify_handlers("approval_granted", response)
|
||||
|
||||
return success
|
||||
|
||||
async def deny(
|
||||
self,
|
||||
request_id: str,
|
||||
decided_by: str,
|
||||
reason: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Deny a pending request.
|
||||
|
||||
Args:
|
||||
request_id: ID of the approval request
|
||||
decided_by: Who denied
|
||||
reason: Denial reason
|
||||
|
||||
Returns:
|
||||
True if denial was recorded
|
||||
"""
|
||||
response = ApprovalResponse(
|
||||
request_id=request_id,
|
||||
status=ApprovalStatus.DENIED,
|
||||
decided_by=decided_by,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
success = await self._queue.complete(response)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Approval denied: %s by %s - %s",
|
||||
request_id,
|
||||
decided_by,
|
||||
reason,
|
||||
)
|
||||
await self._notify_handlers("approval_denied", response)
|
||||
|
||||
return success
|
||||
|
||||
async def cancel(self, request_id: str) -> bool:
|
||||
"""
|
||||
Cancel a pending request.
|
||||
|
||||
Args:
|
||||
request_id: ID of the approval request
|
||||
|
||||
Returns:
|
||||
True if request was cancelled
|
||||
"""
|
||||
success = await self._queue.cancel(request_id)
|
||||
|
||||
if success:
|
||||
logger.info("Approval request cancelled: %s", request_id)
|
||||
|
||||
return success
|
||||
|
||||
async def list_pending(self) -> list[ApprovalRequest]:
|
||||
"""List all pending approval requests."""
|
||||
return await self._queue.list_pending()
|
||||
|
||||
async def get_request(self, request_id: str) -> ApprovalRequest | None:
|
||||
"""Get an approval request by ID."""
|
||||
return await self._queue.get_pending(request_id)
|
||||
|
||||
def add_notification_handler(
|
||||
self,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Add a notification handler."""
|
||||
self._notification_handlers.append(handler)
|
||||
|
||||
def remove_notification_handler(
|
||||
self,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Remove a notification handler."""
|
||||
if handler in self._notification_handlers:
|
||||
self._notification_handlers.remove(handler)
|
||||
|
||||
async def _notify_handlers(
|
||||
self,
|
||||
event_type: str,
|
||||
data: Any,
|
||||
) -> None:
|
||||
"""Notify all handlers of an event."""
|
||||
for handler in self._notification_handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(event_type, data)
|
||||
else:
|
||||
handler(event_type, data)
|
||||
except Exception as e:
|
||||
logger.error("Error in notification handler: %s", e)
|
||||
|
||||
async def _periodic_cleanup(self) -> None:
|
||||
"""Background task for cleaning up expired requests."""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
count = await self._queue.cleanup_expired()
|
||||
if count:
|
||||
logger.debug("Cleaned up %d expired approval requests", count)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in approval cleanup: %s", e)
|
||||
@@ -1 +1,5 @@
|
||||
"""${dir} module."""
|
||||
"""Rollback management for agent actions."""
|
||||
|
||||
from .manager import RollbackManager, TransactionContext
|
||||
|
||||
__all__ = ["RollbackManager", "TransactionContext"]
|
||||
|
||||
418
backend/app/services/safety/rollback/manager.py
Normal file
418
backend/app/services/safety/rollback/manager.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Rollback Manager
|
||||
|
||||
Manages checkpoints and rollback operations for agent actions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RollbackError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
RollbackResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCheckpoint:
|
||||
"""Stores file state for rollback."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_path: str,
|
||||
original_content: bytes | None,
|
||||
existed: bool,
|
||||
) -> None:
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.file_path = file_path
|
||||
self.original_content = original_content
|
||||
self.existed = existed
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class RollbackManager:
|
||||
"""
|
||||
Manages checkpoints and rollback operations.
|
||||
|
||||
Features:
|
||||
- File system checkpoints
|
||||
- Transaction wrapping for actions
|
||||
- Automatic checkpoint for destructive actions
|
||||
- Rollback triggers on failure
|
||||
- Checkpoint expiration and cleanup
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: str | None = None,
|
||||
retention_hours: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the RollbackManager.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Directory for storing checkpoint data
|
||||
retention_hours: Hours to retain checkpoints
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._checkpoint_dir = Path(
|
||||
checkpoint_dir or config.checkpoint_dir
|
||||
)
|
||||
self._retention_hours = retention_hours or config.checkpoint_retention_hours
|
||||
|
||||
self._checkpoints: dict[str, Checkpoint] = {}
|
||||
self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Ensure checkpoint directory exists
|
||||
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def create_checkpoint(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
|
||||
description: str | None = None,
|
||||
) -> Checkpoint:
|
||||
"""
|
||||
Create a checkpoint before an action.
|
||||
|
||||
Args:
|
||||
action: The action to checkpoint for
|
||||
checkpoint_type: Type of checkpoint
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
The created checkpoint
|
||||
"""
|
||||
checkpoint_id = str(uuid4())
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=checkpoint_id,
|
||||
checkpoint_type=checkpoint_type,
|
||||
action_id=action.id,
|
||||
created_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(hours=self._retention_hours),
|
||||
data={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"resource": action.resource,
|
||||
},
|
||||
description=description or f"Checkpoint for {action.tool_name}",
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._checkpoints[checkpoint_id] = checkpoint
|
||||
self._file_checkpoints[checkpoint_id] = []
|
||||
|
||||
logger.info(
|
||||
"Created checkpoint %s for action %s",
|
||||
checkpoint_id,
|
||||
action.id,
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
async def checkpoint_file(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Store current state of a file for checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
file_path: Path to the file
|
||||
"""
|
||||
path = Path(file_path)
|
||||
|
||||
if path.exists():
|
||||
content = path.read_bytes()
|
||||
existed = True
|
||||
else:
|
||||
content = None
|
||||
existed = False
|
||||
|
||||
file_checkpoint = FileCheckpoint(
|
||||
checkpoint_id=checkpoint_id,
|
||||
file_path=file_path,
|
||||
original_content=content,
|
||||
existed=existed,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
if checkpoint_id not in self._file_checkpoints:
|
||||
self._file_checkpoints[checkpoint_id] = []
|
||||
self._file_checkpoints[checkpoint_id].append(file_checkpoint)
|
||||
|
||||
logger.debug(
|
||||
"Stored file state for checkpoint %s: %s (existed=%s)",
|
||||
checkpoint_id,
|
||||
file_path,
|
||||
existed,
|
||||
)
|
||||
|
||||
async def checkpoint_files(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_paths: list[str],
|
||||
) -> None:
|
||||
"""
|
||||
Store current state of multiple files.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
file_paths: Paths to the files
|
||||
"""
|
||||
for path in file_paths:
|
||||
await self.checkpoint_file(checkpoint_id, path)
|
||||
|
||||
async def rollback(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
) -> RollbackResult:
|
||||
"""
|
||||
Rollback to a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
|
||||
Returns:
|
||||
Result of the rollback operation
|
||||
"""
|
||||
async with self._lock:
|
||||
checkpoint = self._checkpoints.get(checkpoint_id)
|
||||
if not checkpoint:
|
||||
raise RollbackError(
|
||||
f"Checkpoint not found: {checkpoint_id}",
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
if not checkpoint.is_valid:
|
||||
raise RollbackError(
|
||||
f"Checkpoint is no longer valid: {checkpoint_id}",
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
file_checkpoints = self._file_checkpoints.get(checkpoint_id, [])
|
||||
|
||||
actions_rolled_back: list[str] = []
|
||||
failed_actions: list[str] = []
|
||||
|
||||
# Rollback file changes
|
||||
for fc in file_checkpoints:
|
||||
try:
|
||||
await self._rollback_file(fc)
|
||||
actions_rolled_back.append(f"file:{fc.file_path}")
|
||||
except Exception as e:
|
||||
logger.error("Failed to rollback file %s: %s", fc.file_path, e)
|
||||
failed_actions.append(f"file:{fc.file_path}")
|
||||
|
||||
success = len(failed_actions) == 0
|
||||
|
||||
# Mark checkpoint as used
|
||||
async with self._lock:
|
||||
if checkpoint_id in self._checkpoints:
|
||||
self._checkpoints[checkpoint_id].is_valid = False
|
||||
|
||||
result = RollbackResult(
|
||||
checkpoint_id=checkpoint_id,
|
||||
success=success,
|
||||
actions_rolled_back=actions_rolled_back,
|
||||
failed_actions=failed_actions,
|
||||
error=None if success else f"Failed to rollback {len(failed_actions)} items",
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Rollback successful for checkpoint %s", checkpoint_id)
|
||||
else:
|
||||
logger.error(
|
||||
"Rollback partially failed for checkpoint %s: %d failures",
|
||||
checkpoint_id,
|
||||
len(failed_actions),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def discard_checkpoint(self, checkpoint_id: str) -> bool:
|
||||
"""
|
||||
Discard a checkpoint without rolling back.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
|
||||
Returns:
|
||||
True if checkpoint was found and discarded
|
||||
"""
|
||||
async with self._lock:
|
||||
if checkpoint_id in self._checkpoints:
|
||||
del self._checkpoints[checkpoint_id]
|
||||
if checkpoint_id in self._file_checkpoints:
|
||||
del self._file_checkpoints[checkpoint_id]
|
||||
logger.debug("Discarded checkpoint %s", checkpoint_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
|
||||
"""Get a checkpoint by ID."""
|
||||
async with self._lock:
|
||||
return self._checkpoints.get(checkpoint_id)
|
||||
|
||||
async def list_checkpoints(
|
||||
self,
|
||||
action_id: str | None = None,
|
||||
include_expired: bool = False,
|
||||
) -> list[Checkpoint]:
|
||||
"""
|
||||
List checkpoints.
|
||||
|
||||
Args:
|
||||
action_id: Optional filter by action ID
|
||||
include_expired: Include expired checkpoints
|
||||
|
||||
Returns:
|
||||
List of checkpoints
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self._lock:
|
||||
checkpoints = list(self._checkpoints.values())
|
||||
|
||||
if action_id:
|
||||
checkpoints = [c for c in checkpoints if c.action_id == action_id]
|
||||
|
||||
if not include_expired:
|
||||
checkpoints = [
|
||||
c for c in checkpoints
|
||||
if c.expires_at is None or c.expires_at > now
|
||||
]
|
||||
|
||||
return checkpoints
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Clean up expired checkpoints.
|
||||
|
||||
Returns:
|
||||
Number of checkpoints cleaned up
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
to_remove: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
for checkpoint_id, checkpoint in self._checkpoints.items():
|
||||
if checkpoint.expires_at and checkpoint.expires_at < now:
|
||||
to_remove.append(checkpoint_id)
|
||||
|
||||
for checkpoint_id in to_remove:
|
||||
del self._checkpoints[checkpoint_id]
|
||||
if checkpoint_id in self._file_checkpoints:
|
||||
del self._file_checkpoints[checkpoint_id]
|
||||
|
||||
if to_remove:
|
||||
logger.info("Cleaned up %d expired checkpoints", len(to_remove))
|
||||
|
||||
return len(to_remove)
|
||||
|
||||
async def _rollback_file(self, fc: FileCheckpoint) -> None:
|
||||
"""Rollback a single file to its checkpoint state."""
|
||||
path = Path(fc.file_path)
|
||||
|
||||
if fc.existed:
|
||||
# Restore original content
|
||||
if fc.original_content is not None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(fc.original_content)
|
||||
logger.debug("Restored file: %s", fc.file_path)
|
||||
else:
|
||||
# File didn't exist before - delete it
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
|
||||
|
||||
|
||||
class TransactionContext:
|
||||
"""
|
||||
Context manager for transactional action execution.
|
||||
|
||||
Usage:
|
||||
async with TransactionContext(rollback_manager, action) as tx:
|
||||
tx.checkpoint_file("/path/to/file")
|
||||
# Do work...
|
||||
# If exception occurs, automatic rollback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: RollbackManager,
|
||||
action: ActionRequest,
|
||||
auto_rollback: bool = True,
|
||||
) -> None:
|
||||
self._manager = manager
|
||||
self._action = action
|
||||
self._auto_rollback = auto_rollback
|
||||
self._checkpoint: Checkpoint | None = None
|
||||
self._committed = False
|
||||
|
||||
async def __aenter__(self) -> "TransactionContext":
|
||||
self._checkpoint = await self._manager.create_checkpoint(self._action)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type | None,
|
||||
exc_val: Exception | None,
|
||||
exc_tb: Any,
|
||||
) -> bool:
|
||||
if exc_val is not None and self._auto_rollback and not self._committed:
|
||||
# Exception occurred - rollback
|
||||
if self._checkpoint:
|
||||
try:
|
||||
await self._manager.rollback(self._checkpoint.id)
|
||||
logger.info(
|
||||
"Auto-rollback completed for action %s",
|
||||
self._action.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Auto-rollback failed: %s", e)
|
||||
elif self._committed and self._checkpoint:
|
||||
# Committed - discard checkpoint
|
||||
await self._manager.discard_checkpoint(self._checkpoint.id)
|
||||
|
||||
return False # Don't suppress the exception
|
||||
|
||||
@property
|
||||
def checkpoint_id(self) -> str | None:
|
||||
"""Get the checkpoint ID."""
|
||||
return self._checkpoint.id if self._checkpoint else None
|
||||
|
||||
async def checkpoint_file(self, file_path: str) -> None:
|
||||
"""Checkpoint a file for this transaction."""
|
||||
if self._checkpoint:
|
||||
await self._manager.checkpoint_file(self._checkpoint.id, file_path)
|
||||
|
||||
async def checkpoint_files(self, file_paths: list[str]) -> None:
|
||||
"""Checkpoint multiple files for this transaction."""
|
||||
if self._checkpoint:
|
||||
await self._manager.checkpoint_files(self._checkpoint.id, file_paths)
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Mark transaction as committed (no rollback on exit)."""
|
||||
self._committed = True
|
||||
|
||||
async def rollback(self) -> RollbackResult | None:
|
||||
"""Manually trigger rollback."""
|
||||
if self._checkpoint:
|
||||
return await self._manager.rollback(self._checkpoint.id)
|
||||
return None
|
||||
Reference in New Issue
Block a user