feat(safety): add Phase C advanced controls

- Add rollback manager with file checkpointing and transaction context
- Add HITL manager with approval queues and notification handlers
- Add content filter with PII, secrets, and injection detection
- Add emergency controls with stop/pause/resume capabilities
- Update SafetyConfig with checkpoint_dir setting

Issue #63

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 11:36:24 +01:00
parent 728edd1453
commit ef659cd72d
9 changed files with 2053 additions and 4 deletions

View File

@@ -74,6 +74,10 @@ class SafetyConfig(BaseSettings):
# Rollback settings
rollback_enabled: bool = Field(True, description="Enable rollback capability")
checkpoint_dir: str = Field(
"/tmp/syndarix_checkpoints", # noqa: S108
description="Directory for checkpoint storage",
)
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
auto_checkpoint_destructive: bool = Field(
True, description="Auto-checkpoint destructive actions"

View File

@@ -1 +1,23 @@
"""${dir} module."""
"""Content filtering for safety."""
from .filter import (
ContentCategory,
ContentFilter,
FilterAction,
FilterMatch,
FilterPattern,
FilterResult,
filter_content,
scan_for_secrets,
)
__all__ = [
"ContentCategory",
"ContentFilter",
"FilterAction",
"FilterMatch",
"FilterPattern",
"FilterResult",
"filter_content",
"scan_for_secrets",
]

View File

@@ -0,0 +1,532 @@
"""
Content Filter
Filters and sanitizes content for safety, including PII detection and secret scanning.
"""
import asyncio
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, ClassVar
from ..exceptions import ContentFilterError
logger = logging.getLogger(__name__)
class ContentCategory(str, Enum):
"""Categories of sensitive content."""
PII = "pii"
SECRETS = "secrets"
CREDENTIALS = "credentials"
FINANCIAL = "financial"
HEALTH = "health"
PROFANITY = "profanity"
INJECTION = "injection"
CUSTOM = "custom"
class FilterAction(str, Enum):
"""Actions to take on detected content."""
ALLOW = "allow"
REDACT = "redact"
BLOCK = "block"
WARN = "warn"
@dataclass
class FilterMatch:
"""A match found by a filter."""
category: ContentCategory
pattern_name: str
matched_text: str
start_pos: int
end_pos: int
confidence: float = 1.0
redacted_text: str | None = None
@dataclass
class FilterResult:
"""Result of content filtering."""
original_content: str
filtered_content: str
matches: list[FilterMatch] = field(default_factory=list)
blocked: bool = False
block_reason: str | None = None
warnings: list[str] = field(default_factory=list)
@property
def has_sensitive_content(self) -> bool:
"""Check if any sensitive content was found."""
return len(self.matches) > 0
@dataclass
class FilterPattern:
"""A pattern for detecting sensitive content."""
name: str
category: ContentCategory
pattern: str # Regex pattern
action: FilterAction = FilterAction.REDACT
replacement: str = "[REDACTED]"
confidence: float = 1.0
enabled: bool = True
def __post_init__(self) -> None:
"""Compile the regex pattern."""
self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE)
def find_matches(self, content: str) -> list[FilterMatch]:
"""Find all matches in content."""
matches = []
for match in self._compiled.finditer(content):
matches.append(
FilterMatch(
category=self.category,
pattern_name=self.name,
matched_text=match.group(),
start_pos=match.start(),
end_pos=match.end(),
confidence=self.confidence,
redacted_text=self.replacement,
)
)
return matches
class ContentFilter:
"""
Filters content for sensitive information.
Features:
- PII detection (emails, phones, SSN, etc.)
- Secret scanning (API keys, tokens, passwords)
- Credential detection
- Injection attack prevention
- Custom pattern support
- Configurable actions (allow, redact, block, warn)
"""
# Default patterns for common sensitive data
DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
# PII Patterns
FilterPattern(
name="email",
category=ContentCategory.PII,
pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
action=FilterAction.REDACT,
replacement="[EMAIL]",
),
FilterPattern(
name="phone_us",
category=ContentCategory.PII,
pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
action=FilterAction.REDACT,
replacement="[PHONE]",
),
FilterPattern(
name="ssn",
category=ContentCategory.PII,
pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
action=FilterAction.REDACT,
replacement="[SSN]",
),
FilterPattern(
name="credit_card",
category=ContentCategory.FINANCIAL,
pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
action=FilterAction.REDACT,
replacement="[CREDIT_CARD]",
),
FilterPattern(
name="ip_address",
category=ContentCategory.PII,
pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
action=FilterAction.WARN,
replacement="[IP]",
confidence=0.8,
),
# Secret Patterns
FilterPattern(
name="api_key_generic",
category=ContentCategory.SECRETS,
pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
action=FilterAction.BLOCK,
replacement="[API_KEY]",
),
FilterPattern(
name="aws_access_key",
category=ContentCategory.SECRETS,
pattern=r"\bAKIA[0-9A-Z]{16}\b",
action=FilterAction.BLOCK,
replacement="[AWS_KEY]",
),
FilterPattern(
name="aws_secret_key",
category=ContentCategory.SECRETS,
pattern=r"\b[A-Za-z0-9/+=]{40}\b",
action=FilterAction.WARN,
replacement="[AWS_SECRET]",
confidence=0.6, # Lower confidence - might be false positive
),
FilterPattern(
name="github_token",
category=ContentCategory.SECRETS,
pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
action=FilterAction.BLOCK,
replacement="[GITHUB_TOKEN]",
),
FilterPattern(
name="jwt_token",
category=ContentCategory.SECRETS,
pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
action=FilterAction.BLOCK,
replacement="[JWT]",
),
# Credential Patterns
FilterPattern(
name="password_in_url",
category=ContentCategory.CREDENTIALS,
pattern=r"://[^:]+:([^@]+)@",
action=FilterAction.BLOCK,
replacement="://[REDACTED]@",
),
FilterPattern(
name="password_assignment",
category=ContentCategory.CREDENTIALS,
pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
action=FilterAction.REDACT,
replacement="[PASSWORD]",
),
FilterPattern(
name="private_key",
category=ContentCategory.SECRETS,
pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
action=FilterAction.BLOCK,
replacement="[PRIVATE_KEY]",
),
# Injection Patterns
FilterPattern(
name="sql_injection",
category=ContentCategory.INJECTION,
pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
action=FilterAction.BLOCK,
replacement="[BLOCKED]",
),
FilterPattern(
name="command_injection",
category=ContentCategory.INJECTION,
pattern=r"[;&|`$]|\$\(|\$\{",
action=FilterAction.WARN,
replacement="[CMD]",
confidence=0.5, # Low confidence - common in code
),
]
def __init__(
self,
enable_pii_filter: bool = True,
enable_secret_filter: bool = True,
enable_injection_filter: bool = True,
custom_patterns: list[FilterPattern] | None = None,
default_action: FilterAction = FilterAction.REDACT,
) -> None:
"""
Initialize the ContentFilter.
Args:
enable_pii_filter: Enable PII detection
enable_secret_filter: Enable secret scanning
enable_injection_filter: Enable injection detection
custom_patterns: Additional custom patterns
default_action: Default action for matches
"""
self._patterns: list[FilterPattern] = []
self._default_action = default_action
self._lock = asyncio.Lock()
# Load default patterns based on configuration
for pattern in self.DEFAULT_PATTERNS:
if pattern.category == ContentCategory.PII and not enable_pii_filter:
continue
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
continue
if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
continue
if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
continue
self._patterns.append(pattern)
# Add custom patterns
if custom_patterns:
self._patterns.extend(custom_patterns)
logger.info("ContentFilter initialized with %d patterns", len(self._patterns))
def add_pattern(self, pattern: FilterPattern) -> None:
"""Add a custom pattern."""
self._patterns.append(pattern)
logger.debug("Added pattern: %s", pattern.name)
def remove_pattern(self, pattern_name: str) -> bool:
"""Remove a pattern by name."""
for i, pattern in enumerate(self._patterns):
if pattern.name == pattern_name:
del self._patterns[i]
logger.debug("Removed pattern: %s", pattern_name)
return True
return False
def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
"""Enable or disable a pattern."""
for pattern in self._patterns:
if pattern.name == pattern_name:
pattern.enabled = enabled
return True
return False
async def filter(
self,
content: str,
context: dict[str, Any] | None = None,
raise_on_block: bool = False,
) -> FilterResult:
"""
Filter content for sensitive information.
Args:
content: Content to filter
context: Optional context for filtering decisions
raise_on_block: Raise exception if content is blocked
Returns:
FilterResult with filtered content and match details
Raises:
ContentFilterError: If content is blocked and raise_on_block=True
"""
all_matches: list[FilterMatch] = []
blocked = False
block_reason: str | None = None
warnings: list[str] = []
# Find all matches
for pattern in self._patterns:
if not pattern.enabled:
continue
matches = pattern.find_matches(content)
for match in matches:
all_matches.append(match)
if pattern.action == FilterAction.BLOCK:
blocked = True
block_reason = f"Blocked by pattern: {pattern.name}"
elif pattern.action == FilterAction.WARN:
warnings.append(
f"Warning: {pattern.name} detected at position {match.start_pos}"
)
# Sort matches by position (reverse for replacement)
all_matches.sort(key=lambda m: m.start_pos, reverse=True)
# Apply redactions
filtered_content = content
for match in all_matches:
matched_pattern = self._get_pattern(match.pattern_name)
if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
filtered_content = (
filtered_content[: match.start_pos]
+ (match.redacted_text or "[REDACTED]")
+ filtered_content[match.end_pos :]
)
# Re-sort for result
all_matches.sort(key=lambda m: m.start_pos)
result = FilterResult(
original_content=content,
filtered_content=filtered_content if not blocked else "",
matches=all_matches,
blocked=blocked,
block_reason=block_reason,
warnings=warnings,
)
if blocked:
logger.warning(
"Content blocked: %s (%d matches)",
block_reason,
len(all_matches),
)
if raise_on_block:
raise ContentFilterError(
block_reason or "Content blocked",
detected_category=all_matches[0].category.value if all_matches else "unknown",
pattern_name=all_matches[0].pattern_name if all_matches else None,
)
elif all_matches:
logger.debug(
"Content filtered: %d matches, %d warnings",
len(all_matches),
len(warnings),
)
return result
async def filter_dict(
self,
data: dict[str, Any],
keys_to_filter: list[str] | None = None,
recursive: bool = True,
) -> dict[str, Any]:
"""
Filter string values in a dictionary.
Args:
data: Dictionary to filter
keys_to_filter: Specific keys to filter (None = all)
recursive: Filter nested dictionaries
Returns:
Filtered dictionary
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, str):
if keys_to_filter is None or key in keys_to_filter:
filter_result = await self.filter(value)
result[key] = filter_result.filtered_content
else:
result[key] = value
elif isinstance(value, dict) and recursive:
result[key] = await self.filter_dict(value, keys_to_filter, recursive)
elif isinstance(value, list):
result[key] = [
(await self.filter(item)).filtered_content
if isinstance(item, str)
else item
for item in value
]
else:
result[key] = value
return result
async def scan(
self,
content: str,
categories: list[ContentCategory] | None = None,
) -> list[FilterMatch]:
"""
Scan content without filtering (detection only).
Args:
content: Content to scan
categories: Limit to specific categories
Returns:
List of matches found
"""
all_matches: list[FilterMatch] = []
for pattern in self._patterns:
if not pattern.enabled:
continue
if categories and pattern.category not in categories:
continue
matches = pattern.find_matches(content)
all_matches.extend(matches)
all_matches.sort(key=lambda m: m.start_pos)
return all_matches
async def validate_safe(
self,
content: str,
categories: list[ContentCategory] | None = None,
allow_warnings: bool = True,
) -> tuple[bool, list[str]]:
"""
Validate that content is safe (no blocked patterns).
Args:
content: Content to validate
categories: Limit to specific categories
allow_warnings: Allow content with warnings
Returns:
Tuple of (is_safe, list of issues)
"""
issues: list[str] = []
for pattern in self._patterns:
if not pattern.enabled:
continue
if categories and pattern.category not in categories:
continue
matches = pattern.find_matches(content)
for match in matches:
if pattern.action == FilterAction.BLOCK:
issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
elif pattern.action == FilterAction.WARN and not allow_warnings:
issues.append(f"Warning: {pattern.name} at position {match.start_pos}")
return len(issues) == 0, issues
def _get_pattern(self, name: str) -> FilterPattern | None:
"""Get a pattern by name."""
for pattern in self._patterns:
if pattern.name == name:
return pattern
return None
def get_pattern_stats(self) -> dict[str, Any]:
"""Get statistics about configured patterns."""
by_category: dict[str, int] = {}
by_action: dict[str, int] = {}
for pattern in self._patterns:
cat = pattern.category.value
by_category[cat] = by_category.get(cat, 0) + 1
act = pattern.action.value
by_action[act] = by_action.get(act, 0) + 1
return {
"total_patterns": len(self._patterns),
"enabled_patterns": sum(1 for p in self._patterns if p.enabled),
"by_category": by_category,
"by_action": by_action,
}
# Convenience function for quick filtering
async def filter_content(content: str) -> str:
"""Quick filter content with default settings."""
filter_instance = ContentFilter()
result = await filter_instance.filter(content)
return result.filtered_content
async def scan_for_secrets(content: str) -> list[FilterMatch]:
"""Quick scan for secrets only."""
filter_instance = ContentFilter(
enable_pii_filter=False,
enable_injection_filter=False,
)
return await filter_instance.scan(
content,
categories=[ContentCategory.SECRETS, ContentCategory.CREDENTIALS],
)

View File

@@ -1 +1,23 @@
"""${dir} module."""
"""Emergency controls for agent safety."""
from .controls import (
EmergencyControls,
EmergencyEvent,
EmergencyReason,
EmergencyState,
EmergencyTrigger,
check_emergency_allowed,
emergency_stop_global,
get_emergency_controls,
)
__all__ = [
"EmergencyControls",
"EmergencyEvent",
"EmergencyReason",
"EmergencyState",
"EmergencyTrigger",
"check_emergency_allowed",
"emergency_stop_global",
"get_emergency_controls",
]

View File

@@ -0,0 +1,594 @@
"""
Emergency Controls
Emergency stop and pause functionality for agent safety.
"""
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from ..exceptions import EmergencyStopError
logger = logging.getLogger(__name__)
class EmergencyState(str, Enum):
"""Emergency control states."""
NORMAL = "normal"
PAUSED = "paused"
STOPPED = "stopped"
class EmergencyReason(str, Enum):
"""Reasons for emergency actions."""
MANUAL = "manual"
SAFETY_VIOLATION = "safety_violation"
BUDGET_EXCEEDED = "budget_exceeded"
LOOP_DETECTED = "loop_detected"
RATE_LIMIT = "rate_limit"
CONTENT_VIOLATION = "content_violation"
SYSTEM_ERROR = "system_error"
EXTERNAL_TRIGGER = "external_trigger"
@dataclass
class EmergencyEvent:
"""Record of an emergency action."""
id: str
state: EmergencyState
reason: EmergencyReason
triggered_by: str
message: str
scope: str # "global", "project:<id>", "agent:<id>"
timestamp: datetime = field(default_factory=datetime.utcnow)
metadata: dict[str, Any] = field(default_factory=dict)
resolved_at: datetime | None = None
resolved_by: str | None = None
class EmergencyControls:
"""
Emergency stop and pause controls for agent safety.
Features:
- Global emergency stop
- Per-project/agent emergency controls
- Graceful pause with state preservation
- Automatic triggers from safety violations
- Manual override capabilities
- Event history and audit trail
"""
def __init__(
self,
notification_handlers: list[Callable[..., Any]] | None = None,
) -> None:
"""
Initialize EmergencyControls.
Args:
notification_handlers: Handlers to call on emergency events
"""
self._global_state = EmergencyState.NORMAL
self._scoped_states: dict[str, EmergencyState] = {}
self._events: list[EmergencyEvent] = []
self._notification_handlers = notification_handlers or []
self._lock = asyncio.Lock()
self._event_id_counter = 0
# Callbacks for state changes
self._on_stop_callbacks: list[Callable[..., Any]] = []
self._on_pause_callbacks: list[Callable[..., Any]] = []
self._on_resume_callbacks: list[Callable[..., Any]] = []
def _generate_event_id(self) -> str:
"""Generate a unique event ID."""
self._event_id_counter += 1
return f"emerg-{self._event_id_counter:06d}"
async def emergency_stop(
self,
reason: EmergencyReason,
triggered_by: str,
message: str,
scope: str = "global",
metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
"""
Trigger emergency stop.
Args:
reason: Reason for the stop
triggered_by: Who/what triggered the stop
message: Human-readable message
scope: Scope of the stop (global, project:<id>, agent:<id>)
metadata: Additional context
Returns:
The emergency event record
"""
async with self._lock:
event = EmergencyEvent(
id=self._generate_event_id(),
state=EmergencyState.STOPPED,
reason=reason,
triggered_by=triggered_by,
message=message,
scope=scope,
metadata=metadata or {},
)
if scope == "global":
self._global_state = EmergencyState.STOPPED
else:
self._scoped_states[scope] = EmergencyState.STOPPED
self._events.append(event)
logger.critical(
"EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
scope,
reason.value,
triggered_by,
message,
)
# Execute callbacks
await self._execute_callbacks(self._on_stop_callbacks, event)
await self._notify_handlers("emergency_stop", event)
return event
async def pause(
self,
reason: EmergencyReason,
triggered_by: str,
message: str,
scope: str = "global",
metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
"""
Pause operations (can be resumed).
Args:
reason: Reason for the pause
triggered_by: Who/what triggered the pause
message: Human-readable message
scope: Scope of the pause
metadata: Additional context
Returns:
The emergency event record
"""
async with self._lock:
event = EmergencyEvent(
id=self._generate_event_id(),
state=EmergencyState.PAUSED,
reason=reason,
triggered_by=triggered_by,
message=message,
scope=scope,
metadata=metadata or {},
)
if scope == "global":
self._global_state = EmergencyState.PAUSED
else:
self._scoped_states[scope] = EmergencyState.PAUSED
self._events.append(event)
logger.warning(
"PAUSE: scope=%s, reason=%s, by=%s - %s",
scope,
reason.value,
triggered_by,
message,
)
await self._execute_callbacks(self._on_pause_callbacks, event)
await self._notify_handlers("pause", event)
return event
async def resume(
self,
scope: str = "global",
resumed_by: str = "system",
message: str | None = None,
) -> bool:
"""
Resume operations from paused state.
Args:
scope: Scope to resume
resumed_by: Who/what is resuming
message: Optional message
Returns:
True if resumed, False if not in paused state
"""
async with self._lock:
current_state = self._get_state(scope)
if current_state == EmergencyState.STOPPED:
logger.warning(
"Cannot resume from STOPPED state: %s (requires reset)",
scope,
)
return False
if current_state == EmergencyState.NORMAL:
return True # Already normal
# Find the pause event and mark as resolved
for event in reversed(self._events):
if event.scope == scope and event.state == EmergencyState.PAUSED:
if event.resolved_at is None:
event.resolved_at = datetime.utcnow()
event.resolved_by = resumed_by
break
if scope == "global":
self._global_state = EmergencyState.NORMAL
else:
self._scoped_states[scope] = EmergencyState.NORMAL
logger.info(
"RESUMED: scope=%s, by=%s%s",
scope,
resumed_by,
f" - {message}" if message else "",
)
await self._execute_callbacks(
self._on_resume_callbacks,
{"scope": scope, "resumed_by": resumed_by},
)
await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})
return True
async def reset(
self,
scope: str = "global",
reset_by: str = "admin",
message: str | None = None,
) -> bool:
"""
Reset from stopped state (requires explicit action).
Args:
scope: Scope to reset
reset_by: Who is resetting (should be admin)
message: Optional message
Returns:
True if reset successful
"""
async with self._lock:
current_state = self._get_state(scope)
if current_state == EmergencyState.NORMAL:
return True
# Find the stop event and mark as resolved
for event in reversed(self._events):
if event.scope == scope and event.state == EmergencyState.STOPPED:
if event.resolved_at is None:
event.resolved_at = datetime.utcnow()
event.resolved_by = reset_by
break
if scope == "global":
self._global_state = EmergencyState.NORMAL
else:
self._scoped_states[scope] = EmergencyState.NORMAL
logger.warning(
"EMERGENCY RESET: scope=%s, by=%s%s",
scope,
reset_by,
f" - {message}" if message else "",
)
await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})
return True
async def check_allowed(
self,
scope: str | None = None,
raise_if_blocked: bool = True,
) -> bool:
"""
Check if operations are allowed.
Args:
scope: Specific scope to check (also checks global)
raise_if_blocked: Raise exception if blocked
Returns:
True if operations are allowed
Raises:
EmergencyStopError: If blocked and raise_if_blocked=True
"""
async with self._lock:
# Always check global state
if self._global_state != EmergencyState.NORMAL:
if raise_if_blocked:
raise EmergencyStopError(
f"Global emergency state: {self._global_state.value}",
stop_reason=self._get_last_reason("global"),
triggered_by=self._get_last_triggered_by("global"),
)
return False
# Check specific scope
if scope and scope in self._scoped_states:
state = self._scoped_states[scope]
if state != EmergencyState.NORMAL:
if raise_if_blocked:
raise EmergencyStopError(
f"Emergency state for {scope}: {state.value}",
stop_reason=self._get_last_reason(scope),
triggered_by=self._get_last_triggered_by(scope),
scope=scope,
)
return False
return True
def _get_state(self, scope: str) -> EmergencyState:
"""Get state for a scope."""
if scope == "global":
return self._global_state
return self._scoped_states.get(scope, EmergencyState.NORMAL)
def _get_last_reason(self, scope: str) -> str:
"""Get reason from last event for scope."""
for event in reversed(self._events):
if event.scope == scope and event.resolved_at is None:
return event.reason.value
return "unknown"
def _get_last_triggered_by(self, scope: str) -> str:
"""Get triggered_by from last event for scope."""
for event in reversed(self._events):
if event.scope == scope and event.resolved_at is None:
return event.triggered_by
return "unknown"
async def get_state(self, scope: str = "global") -> EmergencyState:
"""Get current state for a scope."""
async with self._lock:
return self._get_state(scope)
async def get_all_states(self) -> dict[str, EmergencyState]:
"""Get all current states."""
async with self._lock:
states = {"global": self._global_state}
states.update(self._scoped_states)
return states
async def get_active_events(self) -> list[EmergencyEvent]:
"""Get all unresolved emergency events."""
async with self._lock:
return [e for e in self._events if e.resolved_at is None]
async def get_event_history(
self,
scope: str | None = None,
limit: int = 100,
) -> list[EmergencyEvent]:
"""Get emergency event history."""
async with self._lock:
events = list(self._events)
if scope:
events = [e for e in events if e.scope == scope]
return events[-limit:]
def on_stop(self, callback: Callable[..., Any]) -> None:
"""Register callback for stop events."""
self._on_stop_callbacks.append(callback)
def on_pause(self, callback: Callable[..., Any]) -> None:
"""Register callback for pause events."""
self._on_pause_callbacks.append(callback)
def on_resume(self, callback: Callable[..., Any]) -> None:
"""Register callback for resume events."""
self._on_resume_callbacks.append(callback)
def add_notification_handler(self, handler: Callable[..., Any]) -> None:
"""Add a notification handler."""
self._notification_handlers.append(handler)
async def _execute_callbacks(
self,
callbacks: list[Callable[..., Any]],
data: Any,
) -> None:
"""Execute callbacks safely."""
for callback in callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(data)
else:
callback(data)
except Exception as e:
logger.error("Error in callback: %s", e)
async def _notify_handlers(self, event_type: str, data: Any) -> None:
"""Notify all handlers of an event."""
for handler in self._notification_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event_type, data)
else:
handler(event_type, data)
except Exception as e:
logger.error("Error in notification handler: %s", e)
class EmergencyTrigger:
"""
Automatic emergency triggers based on conditions.
"""
def __init__(self, controls: EmergencyControls) -> None:
"""
Initialize EmergencyTrigger.
Args:
controls: EmergencyControls instance to trigger
"""
self._controls = controls
async def trigger_on_safety_violation(
self,
violation_type: str,
details: dict[str, Any],
scope: str = "global",
) -> EmergencyEvent:
"""
Trigger emergency from safety violation.
Args:
violation_type: Type of violation
details: Violation details
scope: Scope for the emergency
Returns:
Emergency event
"""
return await self._controls.emergency_stop(
reason=EmergencyReason.SAFETY_VIOLATION,
triggered_by="safety_system",
message=f"Safety violation: {violation_type}",
scope=scope,
metadata={"violation_type": violation_type, **details},
)
async def trigger_on_budget_exceeded(
self,
budget_type: str,
current: float,
limit: float,
scope: str = "global",
) -> EmergencyEvent:
"""
Trigger emergency from budget exceeded.
Args:
budget_type: Type of budget
current: Current usage
limit: Budget limit
scope: Scope for the emergency
Returns:
Emergency event
"""
return await self._controls.pause(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget_controller",
message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
scope=scope,
metadata={"budget_type": budget_type, "current": current, "limit": limit},
)
async def trigger_on_loop_detected(
self,
loop_type: str,
agent_id: str,
details: dict[str, Any],
) -> EmergencyEvent:
"""
Trigger emergency from loop detection.
Args:
loop_type: Type of loop
agent_id: Agent that's looping
details: Loop details
Returns:
Emergency event
"""
return await self._controls.pause(
reason=EmergencyReason.LOOP_DETECTED,
triggered_by="loop_detector",
message=f"Loop detected: {loop_type} in agent {agent_id}",
scope=f"agent:{agent_id}",
metadata={"loop_type": loop_type, "agent_id": agent_id, **details},
)
async def trigger_on_content_violation(
self,
category: str,
pattern: str,
scope: str = "global",
) -> EmergencyEvent:
"""
Trigger emergency from content violation.
Args:
category: Content category
pattern: Pattern that matched
scope: Scope for the emergency
Returns:
Emergency event
"""
return await self._controls.emergency_stop(
reason=EmergencyReason.CONTENT_VIOLATION,
triggered_by="content_filter",
message=f"Content violation: {category} ({pattern})",
scope=scope,
metadata={"category": category, "pattern": pattern},
)
# Singleton instance
_emergency_controls: EmergencyControls | None = None
_lock = asyncio.Lock()
async def get_emergency_controls() -> EmergencyControls:
"""Get the singleton EmergencyControls instance."""
global _emergency_controls
async with _lock:
if _emergency_controls is None:
_emergency_controls = EmergencyControls()
return _emergency_controls
async def emergency_stop_global(
reason: str,
triggered_by: str = "system",
) -> EmergencyEvent:
"""Quick global emergency stop."""
controls = await get_emergency_controls()
return await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by=triggered_by,
message=reason,
scope="global",
)
async def check_emergency_allowed(scope: str | None = None) -> bool:
"""Quick check if operations are allowed."""
controls = await get_emergency_controls()
return await controls.check_allowed(scope=scope, raise_if_blocked=False)

View File

@@ -1 +1,5 @@
"""${dir} module."""
"""Human-in-the-Loop approval workflows."""
from .manager import ApprovalQueue, HITLManager
__all__ = ["ApprovalQueue", "HITLManager"]

View File

@@ -0,0 +1,449 @@
"""
Human-in-the-Loop (HITL) Manager
Manages approval workflows for actions requiring human oversight.
"""
import asyncio
import logging
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..exceptions import (
ApprovalDeniedError,
ApprovalRequiredError,
ApprovalTimeoutError,
)
from ..models import (
ActionRequest,
ApprovalRequest,
ApprovalResponse,
ApprovalStatus,
)
logger = logging.getLogger(__name__)
class ApprovalQueue:
"""Queue for pending approval requests."""
def __init__(self) -> None:
self._pending: dict[str, ApprovalRequest] = {}
self._completed: dict[str, ApprovalResponse] = {}
self._waiters: dict[str, asyncio.Event] = {}
self._lock = asyncio.Lock()
async def add(self, request: ApprovalRequest) -> None:
"""Add an approval request to the queue."""
async with self._lock:
self._pending[request.id] = request
self._waiters[request.id] = asyncio.Event()
async def get_pending(self, request_id: str) -> ApprovalRequest | None:
"""Get a pending request by ID."""
async with self._lock:
return self._pending.get(request_id)
async def complete(self, response: ApprovalResponse) -> bool:
"""Complete an approval request."""
async with self._lock:
if response.request_id not in self._pending:
return False
del self._pending[response.request_id]
self._completed[response.request_id] = response
# Notify waiters
if response.request_id in self._waiters:
self._waiters[response.request_id].set()
return True
async def wait_for_response(
self,
request_id: str,
timeout_seconds: float,
) -> ApprovalResponse | None:
"""Wait for a response to an approval request."""
async with self._lock:
waiter = self._waiters.get(request_id)
if not waiter:
return self._completed.get(request_id)
try:
await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
except TimeoutError:
return None
async with self._lock:
return self._completed.get(request_id)
async def list_pending(self) -> list[ApprovalRequest]:
"""List all pending requests."""
async with self._lock:
return list(self._pending.values())
async def cancel(self, request_id: str) -> bool:
"""Cancel a pending request."""
async with self._lock:
if request_id not in self._pending:
return False
del self._pending[request_id]
# Create cancelled response
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.CANCELLED,
reason="Cancelled",
)
self._completed[request_id] = response
# Notify waiters
if request_id in self._waiters:
self._waiters[request_id].set()
return True
async def cleanup_expired(self) -> int:
"""Clean up expired requests."""
now = datetime.utcnow()
to_timeout: list[str] = []
async with self._lock:
for request_id, request in self._pending.items():
if request.expires_at and request.expires_at < now:
to_timeout.append(request_id)
count = 0
for request_id in to_timeout:
async with self._lock:
if request_id in self._pending:
del self._pending[request_id]
self._completed[request_id] = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.TIMEOUT,
reason="Request timed out",
)
if request_id in self._waiters:
self._waiters[request_id].set()
count += 1
return count
class HITLManager:
"""
Manages Human-in-the-Loop approval workflows.
Features:
- Approval request queue
- Configurable timeout handling (default deny)
- Approval delegation
- Batch approval for similar actions
- Approval with modifications
- Notification channels
"""
def __init__(
self,
default_timeout: int | None = None,
) -> None:
"""
Initialize the HITLManager.
Args:
default_timeout: Default timeout for approval requests in seconds
"""
config = get_safety_config()
self._default_timeout = default_timeout or config.hitl_default_timeout
self._queue = ApprovalQueue()
self._notification_handlers: list[Callable[..., Any]] = []
self._running = False
self._cleanup_task: asyncio.Task[None] | None = None
async def start(self) -> None:
"""Start the HITL manager background tasks."""
if self._running:
return
self._running = True
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
logger.info("HITL Manager started")
async def stop(self) -> None:
"""Stop the HITL manager."""
self._running = False
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("HITL Manager stopped")
async def request_approval(
self,
action: ActionRequest,
reason: str,
timeout_seconds: int | None = None,
urgency: str = "normal",
context: dict[str, Any] | None = None,
) -> ApprovalRequest:
"""
Create an approval request for an action.
Args:
action: The action requiring approval
reason: Why approval is required
timeout_seconds: Timeout for this request
urgency: Urgency level (low, normal, high, critical)
context: Additional context for the approver
Returns:
The created approval request
"""
timeout = timeout_seconds or self._default_timeout
expires_at = datetime.utcnow() + timedelta(seconds=timeout)
request = ApprovalRequest(
id=str(uuid4()),
action=action,
reason=reason,
urgency=urgency,
timeout_seconds=timeout,
expires_at=expires_at,
context=context or {},
)
await self._queue.add(request)
# Notify handlers
await self._notify_handlers("approval_requested", request)
logger.info(
"Approval requested: %s for action %s (timeout: %ds)",
request.id,
action.id,
timeout,
)
return request
async def wait_for_approval(
self,
request_id: str,
timeout_seconds: int | None = None,
) -> ApprovalResponse:
"""
Wait for an approval decision.
Args:
request_id: ID of the approval request
timeout_seconds: Override timeout
Returns:
The approval response
Raises:
ApprovalTimeoutError: If timeout expires
ApprovalDeniedError: If approval is denied
"""
request = await self._queue.get_pending(request_id)
if not request:
raise ApprovalRequiredError(
f"Approval request not found: {request_id}",
approval_id=request_id,
)
timeout = timeout_seconds or request.timeout_seconds or self._default_timeout
response = await self._queue.wait_for_response(request_id, timeout)
if response is None:
# Timeout - default deny
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.TIMEOUT,
reason="Request timed out (default deny)",
)
await self._queue.complete(response)
raise ApprovalTimeoutError(
"Approval request timed out",
approval_id=request_id,
timeout_seconds=timeout,
)
if response.status == ApprovalStatus.DENIED:
raise ApprovalDeniedError(
response.reason or "Approval denied",
approval_id=request_id,
denied_by=response.decided_by,
denial_reason=response.reason,
)
if response.status == ApprovalStatus.TIMEOUT:
raise ApprovalTimeoutError(
"Approval request timed out",
approval_id=request_id,
timeout_seconds=timeout,
)
if response.status == ApprovalStatus.CANCELLED:
raise ApprovalDeniedError(
"Approval request was cancelled",
approval_id=request_id,
denial_reason="Cancelled",
)
return response
async def approve(
self,
request_id: str,
decided_by: str,
reason: str | None = None,
modifications: dict[str, Any] | None = None,
) -> bool:
"""
Approve a pending request.
Args:
request_id: ID of the approval request
decided_by: Who approved
reason: Optional approval reason
modifications: Optional modifications to the action
Returns:
True if approval was recorded
"""
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.APPROVED,
decided_by=decided_by,
reason=reason,
modifications=modifications,
)
success = await self._queue.complete(response)
if success:
logger.info(
"Approval granted: %s by %s",
request_id,
decided_by,
)
await self._notify_handlers("approval_granted", response)
return success
async def deny(
self,
request_id: str,
decided_by: str,
reason: str | None = None,
) -> bool:
"""
Deny a pending request.
Args:
request_id: ID of the approval request
decided_by: Who denied
reason: Denial reason
Returns:
True if denial was recorded
"""
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.DENIED,
decided_by=decided_by,
reason=reason,
)
success = await self._queue.complete(response)
if success:
logger.info(
"Approval denied: %s by %s - %s",
request_id,
decided_by,
reason,
)
await self._notify_handlers("approval_denied", response)
return success
async def cancel(self, request_id: str) -> bool:
"""
Cancel a pending request.
Args:
request_id: ID of the approval request
Returns:
True if request was cancelled
"""
success = await self._queue.cancel(request_id)
if success:
logger.info("Approval request cancelled: %s", request_id)
return success
async def list_pending(self) -> list[ApprovalRequest]:
"""List all pending approval requests."""
return await self._queue.list_pending()
async def get_request(self, request_id: str) -> ApprovalRequest | None:
"""Get an approval request by ID."""
return await self._queue.get_pending(request_id)
def add_notification_handler(
self,
handler: Callable[..., Any],
) -> None:
"""Add a notification handler."""
self._notification_handlers.append(handler)
def remove_notification_handler(
self,
handler: Callable[..., Any],
) -> None:
"""Remove a notification handler."""
if handler in self._notification_handlers:
self._notification_handlers.remove(handler)
async def _notify_handlers(
self,
event_type: str,
data: Any,
) -> None:
"""Notify all handlers of an event."""
for handler in self._notification_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event_type, data)
else:
handler(event_type, data)
except Exception as e:
logger.error("Error in notification handler: %s", e)
async def _periodic_cleanup(self) -> None:
"""Background task for cleaning up expired requests."""
while self._running:
try:
await asyncio.sleep(30) # Check every 30 seconds
count = await self._queue.cleanup_expired()
if count:
logger.debug("Cleaned up %d expired approval requests", count)
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in approval cleanup: %s", e)

View File

@@ -1 +1,5 @@
"""${dir} module."""
"""Rollback management for agent actions."""
from .manager import RollbackManager, TransactionContext
__all__ = ["RollbackManager", "TransactionContext"]

View File

@@ -0,0 +1,418 @@
"""
Rollback Manager
Manages checkpoints and rollback operations for agent actions.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..exceptions import RollbackError
from ..models import (
ActionRequest,
Checkpoint,
CheckpointType,
RollbackResult,
)
logger = logging.getLogger(__name__)
class FileCheckpoint:
"""Stores file state for rollback."""
def __init__(
self,
checkpoint_id: str,
file_path: str,
original_content: bytes | None,
existed: bool,
) -> None:
self.checkpoint_id = checkpoint_id
self.file_path = file_path
self.original_content = original_content
self.existed = existed
self.created_at = datetime.utcnow()
class RollbackManager:
"""
Manages checkpoints and rollback operations.
Features:
- File system checkpoints
- Transaction wrapping for actions
- Automatic checkpoint for destructive actions
- Rollback triggers on failure
- Checkpoint expiration and cleanup
"""
def __init__(
self,
checkpoint_dir: str | None = None,
retention_hours: int | None = None,
) -> None:
"""
Initialize the RollbackManager.
Args:
checkpoint_dir: Directory for storing checkpoint data
retention_hours: Hours to retain checkpoints
"""
config = get_safety_config()
self._checkpoint_dir = Path(
checkpoint_dir or config.checkpoint_dir
)
self._retention_hours = retention_hours or config.checkpoint_retention_hours
self._checkpoints: dict[str, Checkpoint] = {}
self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
self._lock = asyncio.Lock()
# Ensure checkpoint directory exists
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
async def create_checkpoint(
self,
action: ActionRequest,
checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
description: str | None = None,
) -> Checkpoint:
"""
Create a checkpoint before an action.
Args:
action: The action to checkpoint for
checkpoint_type: Type of checkpoint
description: Optional description
Returns:
The created checkpoint
"""
checkpoint_id = str(uuid4())
checkpoint = Checkpoint(
id=checkpoint_id,
checkpoint_type=checkpoint_type,
action_id=action.id,
created_at=datetime.utcnow(),
expires_at=datetime.utcnow() + timedelta(hours=self._retention_hours),
data={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"resource": action.resource,
},
description=description or f"Checkpoint for {action.tool_name}",
)
async with self._lock:
self._checkpoints[checkpoint_id] = checkpoint
self._file_checkpoints[checkpoint_id] = []
logger.info(
"Created checkpoint %s for action %s",
checkpoint_id,
action.id,
)
return checkpoint
async def checkpoint_file(
self,
checkpoint_id: str,
file_path: str,
) -> None:
"""
Store current state of a file for checkpoint.
Args:
checkpoint_id: ID of the checkpoint
file_path: Path to the file
"""
path = Path(file_path)
if path.exists():
content = path.read_bytes()
existed = True
else:
content = None
existed = False
file_checkpoint = FileCheckpoint(
checkpoint_id=checkpoint_id,
file_path=file_path,
original_content=content,
existed=existed,
)
async with self._lock:
if checkpoint_id not in self._file_checkpoints:
self._file_checkpoints[checkpoint_id] = []
self._file_checkpoints[checkpoint_id].append(file_checkpoint)
logger.debug(
"Stored file state for checkpoint %s: %s (existed=%s)",
checkpoint_id,
file_path,
existed,
)
async def checkpoint_files(
self,
checkpoint_id: str,
file_paths: list[str],
) -> None:
"""
Store current state of multiple files.
Args:
checkpoint_id: ID of the checkpoint
file_paths: Paths to the files
"""
for path in file_paths:
await self.checkpoint_file(checkpoint_id, path)
async def rollback(
self,
checkpoint_id: str,
) -> RollbackResult:
"""
Rollback to a checkpoint.
Args:
checkpoint_id: ID of the checkpoint
Returns:
Result of the rollback operation
"""
async with self._lock:
checkpoint = self._checkpoints.get(checkpoint_id)
if not checkpoint:
raise RollbackError(
f"Checkpoint not found: {checkpoint_id}",
checkpoint_id=checkpoint_id,
)
if not checkpoint.is_valid:
raise RollbackError(
f"Checkpoint is no longer valid: {checkpoint_id}",
checkpoint_id=checkpoint_id,
)
file_checkpoints = self._file_checkpoints.get(checkpoint_id, [])
actions_rolled_back: list[str] = []
failed_actions: list[str] = []
# Rollback file changes
for fc in file_checkpoints:
try:
await self._rollback_file(fc)
actions_rolled_back.append(f"file:{fc.file_path}")
except Exception as e:
logger.error("Failed to rollback file %s: %s", fc.file_path, e)
failed_actions.append(f"file:{fc.file_path}")
success = len(failed_actions) == 0
# Mark checkpoint as used
async with self._lock:
if checkpoint_id in self._checkpoints:
self._checkpoints[checkpoint_id].is_valid = False
result = RollbackResult(
checkpoint_id=checkpoint_id,
success=success,
actions_rolled_back=actions_rolled_back,
failed_actions=failed_actions,
error=None if success else f"Failed to rollback {len(failed_actions)} items",
)
if success:
logger.info("Rollback successful for checkpoint %s", checkpoint_id)
else:
logger.error(
"Rollback partially failed for checkpoint %s: %d failures",
checkpoint_id,
len(failed_actions),
)
return result
async def discard_checkpoint(self, checkpoint_id: str) -> bool:
"""
Discard a checkpoint without rolling back.
Args:
checkpoint_id: ID of the checkpoint
Returns:
True if checkpoint was found and discarded
"""
async with self._lock:
if checkpoint_id in self._checkpoints:
del self._checkpoints[checkpoint_id]
if checkpoint_id in self._file_checkpoints:
del self._file_checkpoints[checkpoint_id]
logger.debug("Discarded checkpoint %s", checkpoint_id)
return True
return False
async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
"""Get a checkpoint by ID."""
async with self._lock:
return self._checkpoints.get(checkpoint_id)
async def list_checkpoints(
self,
action_id: str | None = None,
include_expired: bool = False,
) -> list[Checkpoint]:
"""
List checkpoints.
Args:
action_id: Optional filter by action ID
include_expired: Include expired checkpoints
Returns:
List of checkpoints
"""
now = datetime.utcnow()
async with self._lock:
checkpoints = list(self._checkpoints.values())
if action_id:
checkpoints = [c for c in checkpoints if c.action_id == action_id]
if not include_expired:
checkpoints = [
c for c in checkpoints
if c.expires_at is None or c.expires_at > now
]
return checkpoints
async def cleanup_expired(self) -> int:
"""
Clean up expired checkpoints.
Returns:
Number of checkpoints cleaned up
"""
now = datetime.utcnow()
to_remove: list[str] = []
async with self._lock:
for checkpoint_id, checkpoint in self._checkpoints.items():
if checkpoint.expires_at and checkpoint.expires_at < now:
to_remove.append(checkpoint_id)
for checkpoint_id in to_remove:
del self._checkpoints[checkpoint_id]
if checkpoint_id in self._file_checkpoints:
del self._file_checkpoints[checkpoint_id]
if to_remove:
logger.info("Cleaned up %d expired checkpoints", len(to_remove))
return len(to_remove)
async def _rollback_file(self, fc: FileCheckpoint) -> None:
"""Rollback a single file to its checkpoint state."""
path = Path(fc.file_path)
if fc.existed:
# Restore original content
if fc.original_content is not None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(fc.original_content)
logger.debug("Restored file: %s", fc.file_path)
else:
# File didn't exist before - delete it
if path.exists():
path.unlink()
logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
class TransactionContext:
"""
Context manager for transactional action execution.
Usage:
async with TransactionContext(rollback_manager, action) as tx:
tx.checkpoint_file("/path/to/file")
# Do work...
# If exception occurs, automatic rollback
"""
def __init__(
self,
manager: RollbackManager,
action: ActionRequest,
auto_rollback: bool = True,
) -> None:
self._manager = manager
self._action = action
self._auto_rollback = auto_rollback
self._checkpoint: Checkpoint | None = None
self._committed = False
async def __aenter__(self) -> "TransactionContext":
self._checkpoint = await self._manager.create_checkpoint(self._action)
return self
async def __aexit__(
self,
exc_type: type | None,
exc_val: Exception | None,
exc_tb: Any,
) -> bool:
if exc_val is not None and self._auto_rollback and not self._committed:
# Exception occurred - rollback
if self._checkpoint:
try:
await self._manager.rollback(self._checkpoint.id)
logger.info(
"Auto-rollback completed for action %s",
self._action.id,
)
except Exception as e:
logger.error("Auto-rollback failed: %s", e)
elif self._committed and self._checkpoint:
# Committed - discard checkpoint
await self._manager.discard_checkpoint(self._checkpoint.id)
return False # Don't suppress the exception
@property
def checkpoint_id(self) -> str | None:
"""Get the checkpoint ID."""
return self._checkpoint.id if self._checkpoint else None
async def checkpoint_file(self, file_path: str) -> None:
"""Checkpoint a file for this transaction."""
if self._checkpoint:
await self._manager.checkpoint_file(self._checkpoint.id, file_path)
async def checkpoint_files(self, file_paths: list[str]) -> None:
"""Checkpoint multiple files for this transaction."""
if self._checkpoint:
await self._manager.checkpoint_files(self._checkpoint.id, file_paths)
def commit(self) -> None:
"""Mark transaction as committed (no rollback on exit)."""
self._committed = True
async def rollback(self) -> RollbackResult | None:
"""Manually trigger rollback."""
if self._checkpoint:
return await self._manager.rollback(self._checkpoint.id)
return None