feat(safety): add Phase C advanced controls
- Add rollback manager with file checkpointing and transaction context - Add HITL manager with approval queues and notification handlers - Add content filter with PII, secrets, and injection detection - Add emergency controls with stop/pause/resume capabilities - Update SafetyConfig with checkpoint_dir setting Issue #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -74,6 +74,10 @@ class SafetyConfig(BaseSettings):
|
|||||||
|
|
||||||
# Rollback settings
|
# Rollback settings
|
||||||
rollback_enabled: bool = Field(True, description="Enable rollback capability")
|
rollback_enabled: bool = Field(True, description="Enable rollback capability")
|
||||||
|
checkpoint_dir: str = Field(
|
||||||
|
"/tmp/syndarix_checkpoints", # noqa: S108
|
||||||
|
description="Directory for checkpoint storage",
|
||||||
|
)
|
||||||
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
|
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
|
||||||
auto_checkpoint_destructive: bool = Field(
|
auto_checkpoint_destructive: bool = Field(
|
||||||
True, description="Auto-checkpoint destructive actions"
|
True, description="Auto-checkpoint destructive actions"
|
||||||
|
|||||||
@@ -1 +1,23 @@
|
|||||||
"""${dir} module."""
|
"""Content filtering for safety."""
|
||||||
|
|
||||||
|
from .filter import (
|
||||||
|
ContentCategory,
|
||||||
|
ContentFilter,
|
||||||
|
FilterAction,
|
||||||
|
FilterMatch,
|
||||||
|
FilterPattern,
|
||||||
|
FilterResult,
|
||||||
|
filter_content,
|
||||||
|
scan_for_secrets,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContentCategory",
|
||||||
|
"ContentFilter",
|
||||||
|
"FilterAction",
|
||||||
|
"FilterMatch",
|
||||||
|
"FilterPattern",
|
||||||
|
"FilterResult",
|
||||||
|
"filter_content",
|
||||||
|
"scan_for_secrets",
|
||||||
|
]
|
||||||
|
|||||||
532
backend/app/services/safety/content/filter.py
Normal file
532
backend/app/services/safety/content/filter.py
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
"""
|
||||||
|
Content Filter
|
||||||
|
|
||||||
|
Filters and sanitizes content for safety, including PII detection and secret scanning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..exceptions import ContentFilterError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentCategory(str, Enum):
    """Kinds of sensitive content a filter pattern can be tagged with.

    Members are also plain strings (``str`` mixin), so they compare equal to
    their lowercase value.
    """

    # Personal data
    PII = "pii"
    FINANCIAL = "financial"
    HEALTH = "health"

    # Authentication material
    SECRETS = "secrets"
    CREDENTIALS = "credentials"

    # Abusive or hostile input
    PROFANITY = "profanity"
    INJECTION = "injection"

    # Anything supplied by the caller
    CUSTOM = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
class FilterAction(str, Enum):
    """What the filter does when a pattern matches.

    Members are also plain strings (``str`` mixin).
    """

    ALLOW = "allow"  # match is recorded, content untouched
    REDACT = "redact"  # matched text is replaced
    BLOCK = "block"  # the whole content is rejected
    WARN = "warn"  # a warning is recorded, content untouched
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FilterMatch:
    """One occurrence of a pattern inside scanned content."""

    # Category and name of the pattern that produced this match.
    category: ContentCategory
    pattern_name: str
    # The exact text that matched, with its [start_pos, end_pos) offsets
    # into the scanned content.
    matched_text: str
    start_pos: int
    end_pos: int
    # Confidence copied from the pattern (1.0 = certain).
    confidence: float = 1.0
    # Replacement text to use if the match is redacted.
    redacted_text: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FilterResult:
    """Outcome of running content through the filter."""

    # Content as received and as returned after redaction.
    original_content: str
    filtered_content: str
    # Every match that was found.
    matches: list[FilterMatch] = field(default_factory=list)
    # Whether the content was rejected, and why.
    blocked: bool = False
    block_reason: str | None = None
    # Messages for matches whose action is WARN.
    warnings: list[str] = field(default_factory=list)

    @property
    def has_sensitive_content(self) -> bool:
        """Return True when at least one match was recorded."""
        return bool(self.matches)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FilterPattern:
    """A named regular expression plus the action to take when it matches."""

    name: str
    category: ContentCategory
    pattern: str  # Regex pattern
    action: FilterAction = FilterAction.REDACT
    replacement: str = "[REDACTED]"
    confidence: float = 1.0
    enabled: bool = True

    def __post_init__(self) -> None:
        """Compile ``pattern`` once, case-insensitive and multi-line."""
        flags = re.IGNORECASE | re.MULTILINE
        self._compiled = re.compile(self.pattern, flags)

    def find_matches(self, content: str) -> list[FilterMatch]:
        """Return a FilterMatch for every occurrence of the pattern in ``content``."""
        return [
            FilterMatch(
                category=self.category,
                pattern_name=self.name,
                matched_text=hit.group(),
                start_pos=hit.start(),
                end_pos=hit.end(),
                confidence=self.confidence,
                redacted_text=self.replacement,
            )
            for hit in self._compiled.finditer(content)
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class ContentFilter:
    """
    Filters content for sensitive information.

    Features:
    - PII detection (emails, phones, SSN, etc.)
    - Secret scanning (API keys, tokens, passwords)
    - Credential detection
    - Injection attack prevention
    - Custom pattern support
    - Configurable actions (allow, redact, block, warn)
    """

    # Default patterns for common sensitive data.  These instances are
    # templates: each ContentFilter works on its own copies (see __init__).
    DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
        # PII Patterns
        FilterPattern(
            name="email",
            category=ContentCategory.PII,
            # TLD class is [A-Za-z]; a "|" inside a character class is a
            # literal pipe, not alternation.
            pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
            action=FilterAction.REDACT,
            replacement="[EMAIL]",
        ),
        FilterPattern(
            name="phone_us",
            category=ContentCategory.PII,
            pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
            action=FilterAction.REDACT,
            replacement="[PHONE]",
        ),
        FilterPattern(
            name="ssn",
            category=ContentCategory.PII,
            pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
            action=FilterAction.REDACT,
            replacement="[SSN]",
        ),
        FilterPattern(
            name="credit_card",
            category=ContentCategory.FINANCIAL,
            pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
            action=FilterAction.REDACT,
            replacement="[CREDIT_CARD]",
        ),
        FilterPattern(
            name="ip_address",
            category=ContentCategory.PII,
            pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
            action=FilterAction.WARN,
            replacement="[IP]",
            confidence=0.8,
        ),
        # Secret Patterns
        FilterPattern(
            name="api_key_generic",
            category=ContentCategory.SECRETS,
            pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
            action=FilterAction.BLOCK,
            replacement="[API_KEY]",
        ),
        FilterPattern(
            name="aws_access_key",
            category=ContentCategory.SECRETS,
            pattern=r"\bAKIA[0-9A-Z]{16}\b",
            action=FilterAction.BLOCK,
            replacement="[AWS_KEY]",
        ),
        FilterPattern(
            name="aws_secret_key",
            category=ContentCategory.SECRETS,
            pattern=r"\b[A-Za-z0-9/+=]{40}\b",
            action=FilterAction.WARN,
            replacement="[AWS_SECRET]",
            confidence=0.6,  # Lower confidence - might be false positive
        ),
        FilterPattern(
            name="github_token",
            category=ContentCategory.SECRETS,
            pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
            action=FilterAction.BLOCK,
            replacement="[GITHUB_TOKEN]",
        ),
        FilterPattern(
            name="jwt_token",
            category=ContentCategory.SECRETS,
            pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
            action=FilterAction.BLOCK,
            replacement="[JWT]",
        ),
        # Credential Patterns
        FilterPattern(
            name="password_in_url",
            category=ContentCategory.CREDENTIALS,
            pattern=r"://[^:]+:([^@]+)@",
            action=FilterAction.BLOCK,
            replacement="://[REDACTED]@",
        ),
        FilterPattern(
            name="password_assignment",
            category=ContentCategory.CREDENTIALS,
            pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
            action=FilterAction.REDACT,
            replacement="[PASSWORD]",
        ),
        FilterPattern(
            name="private_key",
            category=ContentCategory.SECRETS,
            pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
            action=FilterAction.BLOCK,
            replacement="[PRIVATE_KEY]",
        ),
        # Injection Patterns
        FilterPattern(
            name="sql_injection",
            category=ContentCategory.INJECTION,
            pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
            action=FilterAction.BLOCK,
            replacement="[BLOCKED]",
        ),
        FilterPattern(
            name="command_injection",
            category=ContentCategory.INJECTION,
            pattern=r"[;&|`$]|\$\(|\$\{",
            action=FilterAction.WARN,
            replacement="[CMD]",
            confidence=0.5,  # Low confidence - common in code
        ),
    ]

    def __init__(
        self,
        enable_pii_filter: bool = True,
        enable_secret_filter: bool = True,
        enable_injection_filter: bool = True,
        custom_patterns: list[FilterPattern] | None = None,
        default_action: FilterAction = FilterAction.REDACT,
    ) -> None:
        """
        Initialize the ContentFilter.

        Args:
            enable_pii_filter: Enable PII detection
            enable_secret_filter: Enable secret scanning (SECRETS and
                CREDENTIALS categories)
            enable_injection_filter: Enable injection detection
            custom_patterns: Additional custom patterns
            default_action: Default action for matches
        """
        # Local import: the module only imports `dataclass` and `field`.
        from dataclasses import replace

        self._patterns: list[FilterPattern] = []
        self._default_action = default_action
        self._lock = asyncio.Lock()

        # Categories switched off by the constructor flags.  Categories
        # without a flag (e.g. FINANCIAL) are always loaded.
        disabled: set[ContentCategory] = set()
        if not enable_pii_filter:
            disabled.add(ContentCategory.PII)
        if not enable_secret_filter:
            disabled.update((ContentCategory.SECRETS, ContentCategory.CREDENTIALS))
        if not enable_injection_filter:
            disabled.add(ContentCategory.INJECTION)

        # Copy each default so that enable_pattern() on this filter cannot
        # flip `enabled` on the class-level templates shared by every other
        # ContentFilter instance.
        self._patterns.extend(
            replace(pattern)
            for pattern in self.DEFAULT_PATTERNS
            if pattern.category not in disabled
        )

        # Add custom patterns (kept by reference; the caller owns them)
        if custom_patterns:
            self._patterns.extend(custom_patterns)

        logger.info("ContentFilter initialized with %d patterns", len(self._patterns))

    def add_pattern(self, pattern: FilterPattern) -> None:
        """Add a custom pattern."""
        self._patterns.append(pattern)
        logger.debug("Added pattern: %s", pattern.name)

    def remove_pattern(self, pattern_name: str) -> bool:
        """Remove the first pattern with the given name.

        Returns:
            True if a pattern was removed, False if none had that name.
        """
        for i, pattern in enumerate(self._patterns):
            if pattern.name == pattern_name:
                del self._patterns[i]
                logger.debug("Removed pattern: %s", pattern_name)
                return True
        return False

    def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
        """Enable or disable the first pattern with the given name.

        Returns:
            True if a pattern was found, False otherwise.
        """
        for pattern in self._patterns:
            if pattern.name == pattern_name:
                pattern.enabled = enabled
                return True
        return False

    @staticmethod
    def _apply_redactions(
        content: str,
        found: list[tuple[FilterMatch, FilterPattern]],
    ) -> str:
        """Return ``content`` with every REDACT/BLOCK match replaced.

        Matches from different patterns may overlap.  Spans are processed
        left to right against the original offsets; a span that overlaps an
        already-redacted region is folded into it (the earlier replacement
        wins and the covered region is extended), so offsets never go stale
        and no part of an overlapping match survives.
        """
        spans = sorted(
            (
                (match.start_pos, match.end_pos, match.redacted_text or "[REDACTED]")
                for match, pattern in found
                if pattern.action in (FilterAction.REDACT, FilterAction.BLOCK)
            ),
            # Earliest first; on equal starts the longest span wins.
            key=lambda span: (span[0], -span[1]),
        )

        pieces: list[str] = []
        cursor = 0  # end of the region already emitted or redacted
        for start, end, replacement in spans:
            if start < cursor:
                # Overlaps the previous redaction: swallow any tail.
                cursor = max(cursor, end)
                continue
            pieces.append(content[cursor:start])
            pieces.append(replacement)
            cursor = end
        pieces.append(content[cursor:])
        return "".join(pieces)

    async def filter(
        self,
        content: str,
        context: dict[str, Any] | None = None,
        raise_on_block: bool = False,
    ) -> FilterResult:
        """
        Filter content for sensitive information.

        Args:
            content: Content to filter
            context: Optional context for filtering decisions (currently unused)
            raise_on_block: Raise exception if content is blocked

        Returns:
            FilterResult with filtered content and match details.  When the
            content is blocked, ``filtered_content`` is the empty string.

        Raises:
            ContentFilterError: If content is blocked and raise_on_block=True
        """
        # Each match is paired with the pattern object that produced it, so
        # the action is never looked up by (possibly duplicated) name.
        found: list[tuple[FilterMatch, FilterPattern]] = []
        blocked = False
        block_reason: str | None = None
        blocking_match: FilterMatch | None = None
        warnings: list[str] = []

        # Find all matches
        for pattern in self._patterns:
            if not pattern.enabled:
                continue

            for match in pattern.find_matches(content):
                found.append((match, pattern))

                if pattern.action == FilterAction.BLOCK:
                    blocked = True
                    block_reason = f"Blocked by pattern: {pattern.name}"
                    blocking_match = match
                elif pattern.action == FilterAction.WARN:
                    warnings.append(
                        f"Warning: {pattern.name} detected at position {match.start_pos}"
                    )

        all_matches = sorted((match for match, _ in found), key=lambda m: m.start_pos)

        # Blocked content is dropped entirely, so redaction is only needed
        # on the non-blocked path.
        filtered_content = "" if blocked else self._apply_redactions(content, found)

        result = FilterResult(
            original_content=content,
            filtered_content=filtered_content,
            matches=all_matches,
            blocked=blocked,
            block_reason=block_reason,
            warnings=warnings,
        )

        if blocked:
            logger.warning(
                "Content blocked: %s (%d matches)",
                block_reason,
                len(all_matches),
            )
            if raise_on_block:
                # Report the match that actually blocked, not whichever
                # match happens to come first by position.
                raise ContentFilterError(
                    block_reason or "Content blocked",
                    detected_category=(
                        blocking_match.category.value if blocking_match else "unknown"
                    ),
                    pattern_name=blocking_match.pattern_name if blocking_match else None,
                )
        elif all_matches:
            logger.debug(
                "Content filtered: %d matches, %d warnings",
                len(all_matches),
                len(warnings),
            )

        return result

    async def _filter_list(
        self,
        items: list[Any],
        keys_to_filter: list[str] | None,
        recursive: bool,
    ) -> list[Any]:
        """Filter the string items of a list; recurse into containers if asked."""
        filtered: list[Any] = []
        for item in items:
            if isinstance(item, str):
                filtered.append((await self.filter(item)).filtered_content)
            elif recursive and isinstance(item, dict):
                filtered.append(await self.filter_dict(item, keys_to_filter, recursive))
            elif recursive and isinstance(item, list):
                filtered.append(await self._filter_list(item, keys_to_filter, recursive))
            else:
                filtered.append(item)
        return filtered

    async def filter_dict(
        self,
        data: dict[str, Any],
        keys_to_filter: list[str] | None = None,
        recursive: bool = True,
    ) -> dict[str, Any]:
        """
        Filter string values in a dictionary.

        String items inside list values are always filtered.  With
        ``recursive=True`` dictionaries and lists nested inside lists are
        filtered as well, so secrets in e.g. ``{"items": [{"token": ...}]}``
        do not pass through untouched.

        Args:
            data: Dictionary to filter
            keys_to_filter: Specific keys to filter (None = all)
            recursive: Filter nested dictionaries

        Returns:
            A new, filtered dictionary; ``data`` itself is not modified.
        """
        result: dict[str, Any] = {}

        for key, value in data.items():
            if isinstance(value, str):
                if keys_to_filter is None or key in keys_to_filter:
                    result[key] = (await self.filter(value)).filtered_content
                else:
                    result[key] = value
            elif isinstance(value, dict) and recursive:
                result[key] = await self.filter_dict(value, keys_to_filter, recursive)
            elif isinstance(value, list):
                result[key] = await self._filter_list(value, keys_to_filter, recursive)
            else:
                result[key] = value

        return result

    async def scan(
        self,
        content: str,
        categories: list[ContentCategory] | None = None,
    ) -> list[FilterMatch]:
        """
        Scan content without filtering (detection only).

        Args:
            content: Content to scan
            categories: Limit to specific categories (None or empty = all)

        Returns:
            List of matches found, ordered by start position
        """
        all_matches: list[FilterMatch] = []

        for pattern in self._patterns:
            if not pattern.enabled:
                continue
            if categories and pattern.category not in categories:
                continue
            all_matches.extend(pattern.find_matches(content))

        all_matches.sort(key=lambda m: m.start_pos)
        return all_matches

    async def validate_safe(
        self,
        content: str,
        categories: list[ContentCategory] | None = None,
        allow_warnings: bool = True,
    ) -> tuple[bool, list[str]]:
        """
        Validate that content is safe (no blocked patterns).

        Args:
            content: Content to validate
            categories: Limit to specific categories (None or empty = all)
            allow_warnings: Allow content with warnings

        Returns:
            Tuple of (is_safe, list of issues)
        """
        issues: list[str] = []

        for pattern in self._patterns:
            if not pattern.enabled:
                continue
            if categories and pattern.category not in categories:
                continue

            for match in pattern.find_matches(content):
                if pattern.action == FilterAction.BLOCK:
                    issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
                elif pattern.action == FilterAction.WARN and not allow_warnings:
                    issues.append(f"Warning: {pattern.name} at position {match.start_pos}")

        return len(issues) == 0, issues

    def _get_pattern(self, name: str) -> FilterPattern | None:
        """Return the first pattern with the given name, or None."""
        for pattern in self._patterns:
            if pattern.name == name:
                return pattern
        return None

    def get_pattern_stats(self) -> dict[str, Any]:
        """Return pattern counts: total, enabled, per category and per action."""
        by_category: dict[str, int] = {}
        by_action: dict[str, int] = {}

        for pattern in self._patterns:
            cat = pattern.category.value
            by_category[cat] = by_category.get(cat, 0) + 1

            act = pattern.action.value
            by_action[act] = by_action.get(act, 0) + 1

        return {
            "total_patterns": len(self._patterns),
            "enabled_patterns": sum(1 for p in self._patterns if p.enabled),
            "by_category": by_category,
            "by_action": by_action,
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience function for quick filtering
|
||||||
|
async def filter_content(content: str) -> str:
    """Run ``content`` through a default-configured ContentFilter.

    Returns only the filtered text; match details are discarded.
    """
    outcome = await ContentFilter().filter(content)
    return outcome.filtered_content
|
||||||
|
|
||||||
|
|
||||||
|
async def scan_for_secrets(content: str) -> list[FilterMatch]:
    """Scan ``content`` for secrets and credentials only (no redaction)."""
    secret_categories = [ContentCategory.SECRETS, ContentCategory.CREDENTIALS]
    # PII and injection patterns are left out of this scanner.
    scanner = ContentFilter(enable_pii_filter=False, enable_injection_filter=False)
    return await scanner.scan(content, categories=secret_categories)
|
||||||
@@ -1 +1,23 @@
|
|||||||
"""${dir} module."""
|
"""Emergency controls for agent safety."""
|
||||||
|
|
||||||
|
from .controls import (
|
||||||
|
EmergencyControls,
|
||||||
|
EmergencyEvent,
|
||||||
|
EmergencyReason,
|
||||||
|
EmergencyState,
|
||||||
|
EmergencyTrigger,
|
||||||
|
check_emergency_allowed,
|
||||||
|
emergency_stop_global,
|
||||||
|
get_emergency_controls,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EmergencyControls",
|
||||||
|
"EmergencyEvent",
|
||||||
|
"EmergencyReason",
|
||||||
|
"EmergencyState",
|
||||||
|
"EmergencyTrigger",
|
||||||
|
"check_emergency_allowed",
|
||||||
|
"emergency_stop_global",
|
||||||
|
"get_emergency_controls",
|
||||||
|
]
|
||||||
|
|||||||
594
backend/app/services/safety/emergency/controls.py
Normal file
594
backend/app/services/safety/emergency/controls.py
Normal file
@@ -0,0 +1,594 @@
|
|||||||
|
"""
|
||||||
|
Emergency Controls
|
||||||
|
|
||||||
|
Emergency stop and pause functionality for agent safety.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..exceptions import EmergencyStopError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmergencyState(str, Enum):
    """Operating state of a scope under emergency control.

    Members are also plain strings (``str`` mixin).
    """

    NORMAL = "normal"  # operations allowed
    PAUSED = "paused"  # halted; can be resumed
    STOPPED = "stopped"  # halted; requires an explicit reset
|
||||||
|
|
||||||
|
|
||||||
|
class EmergencyReason(str, Enum):
    """Why an emergency action was taken.

    Members are also plain strings (``str`` mixin).
    """

    # Triggered by a person or an outside system
    MANUAL = "manual"
    EXTERNAL_TRIGGER = "external_trigger"

    # Triggered by a safety check
    SAFETY_VIOLATION = "safety_violation"
    CONTENT_VIOLATION = "content_violation"
    LOOP_DETECTED = "loop_detected"

    # Triggered by resource limits
    BUDGET_EXCEEDED = "budget_exceeded"
    RATE_LIMIT = "rate_limit"

    # Triggered by a failure
    SYSTEM_ERROR = "system_error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EmergencyEvent:
    """Audit record of a single emergency action."""

    id: str
    # State the action put the scope into, and why.
    state: EmergencyState
    reason: EmergencyReason
    # Who or what triggered it, with a human-readable explanation.
    triggered_by: str
    message: str
    scope: str  # "global", "project:<id>", "agent:<id>"
    # Creation time (naive UTC, from datetime.utcnow).
    timestamp: datetime = field(default_factory=datetime.utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)
    # Filled in when the event is resolved; None while it is still open.
    resolved_at: datetime | None = None
    resolved_by: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmergencyControls:
|
||||||
|
"""
|
||||||
|
Emergency stop and pause controls for agent safety.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Global emergency stop
|
||||||
|
- Per-project/agent emergency controls
|
||||||
|
- Graceful pause with state preservation
|
||||||
|
- Automatic triggers from safety violations
|
||||||
|
- Manual override capabilities
|
||||||
|
- Event history and audit trail
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    notification_handlers: list[Callable[..., Any]] | None = None,
) -> None:
    """
    Set up emergency controls with everything in the NORMAL state.

    Args:
        notification_handlers: Handlers to call on emergency events
    """
    # State: one global state plus per-scope overrides
    self._global_state = EmergencyState.NORMAL
    self._scoped_states: dict[str, EmergencyState] = {}

    # Audit trail and the counter used to number its entries
    self._events: list[EmergencyEvent] = []
    self._event_id_counter = 0

    # Lock used by the public coroutines around state changes
    self._lock = asyncio.Lock()

    self._notification_handlers = notification_handlers or []

    # Callbacks for state changes
    self._on_stop_callbacks: list[Callable[..., Any]] = []
    self._on_pause_callbacks: list[Callable[..., Any]] = []
    self._on_resume_callbacks: list[Callable[..., Any]] = []
|
||||||
|
|
||||||
|
def _generate_event_id(self) -> str:
|
||||||
|
"""Generate a unique event ID."""
|
||||||
|
self._event_id_counter += 1
|
||||||
|
return f"emerg-{self._event_id_counter:06d}"
|
||||||
|
|
||||||
|
async def emergency_stop(
    self,
    reason: EmergencyReason,
    triggered_by: str,
    message: str,
    scope: str = "global",
    metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
    """
    Trigger emergency stop.

    The state change and the audit record are made under the lock.  The
    stop callbacks and notification handlers are awaited after the lock is
    released: asyncio.Lock is not reentrant, so a callback that calls back
    into this object (e.g. ``check_allowed``) would otherwise deadlock.

    Args:
        reason: Reason for the stop
        triggered_by: Who/what triggered the stop
        message: Human-readable message
        scope: Scope of the stop (global, project:<id>, agent:<id>)
        metadata: Additional context

    Returns:
        The emergency event record
    """
    async with self._lock:
        event = EmergencyEvent(
            id=self._generate_event_id(),
            state=EmergencyState.STOPPED,
            reason=reason,
            triggered_by=triggered_by,
            message=message,
            scope=scope,
            metadata=metadata or {},
        )

        if scope == "global":
            self._global_state = EmergencyState.STOPPED
        else:
            self._scoped_states[scope] = EmergencyState.STOPPED

        self._events.append(event)

        logger.critical(
            "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

    # Execute callbacks (lock released; the stop is already in effect)
    await self._execute_callbacks(self._on_stop_callbacks, event)
    await self._notify_handlers("emergency_stop", event)

    return event
|
||||||
|
|
||||||
|
async def pause(
    self,
    reason: EmergencyReason,
    triggered_by: str,
    message: str,
    scope: str = "global",
    metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
    """
    Pause operations (can be resumed).

    A scope that is already STOPPED is never downgraded to PAUSED:
    otherwise a later ``resume()`` would lift an emergency stop without
    the explicit ``reset()`` it requires.  In that case the request is
    still recorded for the audit trail, as an already-resolved event, and
    the pause callbacks and handlers are not invoked.

    Callbacks and handlers run after the lock is released because
    asyncio.Lock is not reentrant; a callback calling back into this
    object would otherwise deadlock.

    Args:
        reason: Reason for the pause
        triggered_by: Who/what triggered the pause
        message: Human-readable message
        scope: Scope of the pause
        metadata: Additional context

    Returns:
        The emergency event record
    """
    async with self._lock:
        event = EmergencyEvent(
            id=self._generate_event_id(),
            state=EmergencyState.PAUSED,
            reason=reason,
            triggered_by=triggered_by,
            message=message,
            scope=scope,
            metadata=metadata or {},
        )

        already_stopped = self._get_state(scope) == EmergencyState.STOPPED

        if already_stopped:
            # Keep STOPPED in force.  Close the event right away so it is
            # not mistaken for the open event that explains the stop.
            event.resolved_at = event.timestamp
            event.resolved_by = "system"
        elif scope == "global":
            self._global_state = EmergencyState.PAUSED
        else:
            self._scoped_states[scope] = EmergencyState.PAUSED

        self._events.append(event)

        if already_stopped:
            logger.warning(
                "PAUSE ignored (scope already STOPPED): scope=%s, reason=%s, by=%s - %s",
                scope,
                reason.value,
                triggered_by,
                message,
            )
        else:
            logger.warning(
                "PAUSE: scope=%s, reason=%s, by=%s - %s",
                scope,
                reason.value,
                triggered_by,
                message,
            )

    if not already_stopped:
        await self._execute_callbacks(self._on_pause_callbacks, event)
        await self._notify_handlers("pause", event)

    return event
|
||||||
|
|
||||||
|
async def resume(
    self,
    scope: str = "global",
    resumed_by: str = "system",
    message: str | None = None,
) -> bool:
    """
    Resume operations from paused state.

    Every still-open pause event for the scope is marked resolved, not
    just the most recent one, so repeated ``pause()`` calls do not leave
    stale open events behind.  Callbacks and handlers run after the lock
    is released because asyncio.Lock is not reentrant.

    Args:
        scope: Scope to resume
        resumed_by: Who/what is resuming
        message: Optional message

    Returns:
        True if the scope is NORMAL afterwards (resumed, or was already
        normal); False if the scope is STOPPED, which requires ``reset()``
    """
    async with self._lock:
        current_state = self._get_state(scope)

        if current_state == EmergencyState.STOPPED:
            logger.warning(
                "Cannot resume from STOPPED state: %s (requires reset)",
                scope,
            )
            return False

        if current_state == EmergencyState.NORMAL:
            return True  # Already normal

        # Mark all open pause events for this scope as resolved
        resolved_at = datetime.utcnow()
        for event in self._events:
            if (
                event.scope == scope
                and event.state == EmergencyState.PAUSED
                and event.resolved_at is None
            ):
                event.resolved_at = resolved_at
                event.resolved_by = resumed_by

        if scope == "global":
            self._global_state = EmergencyState.NORMAL
        else:
            self._scoped_states[scope] = EmergencyState.NORMAL

        logger.info(
            "RESUMED: scope=%s, by=%s%s",
            scope,
            resumed_by,
            f" - {message}" if message else "",
        )

    await self._execute_callbacks(
        self._on_resume_callbacks,
        {"scope": scope, "resumed_by": resumed_by},
    )
    await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})

    return True
|
||||||
|
|
||||||
|
async def reset(
    self,
    scope: str = "global",
    reset_by: str = "admin",
    message: str | None = None,
) -> bool:
    """
    Reset a scope to NORMAL (requires explicit action).

    This clears any non-normal state, STOPPED or PAUSED, so every
    still-open event for the scope is marked resolved; otherwise stale
    open events would keep describing an emergency that is over.
    Handlers are notified after the lock is released because asyncio.Lock
    is not reentrant.

    Args:
        scope: Scope to reset
        reset_by: Who is resetting (should be admin)
        message: Optional message

    Returns:
        True if reset successful
    """
    async with self._lock:
        current_state = self._get_state(scope)

        if current_state == EmergencyState.NORMAL:
            return True

        # Mark all open events for this scope as resolved
        resolved_at = datetime.utcnow()
        for event in self._events:
            if event.scope == scope and event.resolved_at is None:
                event.resolved_at = resolved_at
                event.resolved_by = reset_by

        if scope == "global":
            self._global_state = EmergencyState.NORMAL
        else:
            self._scoped_states[scope] = EmergencyState.NORMAL

        logger.warning(
            "EMERGENCY RESET: scope=%s, by=%s%s",
            scope,
            reset_by,
            f" - {message}" if message else "",
        )

    await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})

    return True
|
||||||
|
|
||||||
|
async def check_allowed(
    self,
    scope: str | None = None,
    raise_if_blocked: bool = True,
) -> bool:
    """
    Report whether operations may proceed.

    The global state is consulted first and blocks everything when it is
    not NORMAL. A specific scope is consulted only when one is given and
    it has an entry in the scoped-state table.

    Args:
        scope: Specific scope to check (also checks global)
        raise_if_blocked: Raise exception if blocked

    Returns:
        True if operations are allowed, False if blocked and
        raise_if_blocked is False

    Raises:
        EmergencyStopError: If blocked and raise_if_blocked=True
    """
    async with self._lock:
        # Global state always wins.
        global_state = self._global_state
        if global_state != EmergencyState.NORMAL:
            if not raise_if_blocked:
                return False
            raise EmergencyStopError(
                f"Global emergency state: {global_state.value}",
                stop_reason=self._get_last_reason("global"),
                triggered_by=self._get_last_triggered_by("global"),
            )

        # No scope requested, or nothing recorded for it: allowed.
        if not scope or scope not in self._scoped_states:
            return True

        scoped_state = self._scoped_states[scope]
        if scoped_state == EmergencyState.NORMAL:
            return True
        if not raise_if_blocked:
            return False
        raise EmergencyStopError(
            f"Emergency state for {scope}: {scoped_state.value}",
            stop_reason=self._get_last_reason(scope),
            triggered_by=self._get_last_triggered_by(scope),
            scope=scope,
        )
|
||||||
|
|
||||||
|
def _get_state(self, scope: str) -> EmergencyState:
    """Return the state of a scope; scopes without an entry count as NORMAL."""
    return (
        self._global_state
        if scope == "global"
        else self._scoped_states.get(scope, EmergencyState.NORMAL)
    )
|
||||||
|
|
||||||
|
def _get_last_reason(self, scope: str) -> str:
    """Return the reason value of the newest unresolved event for a scope, or "unknown"."""
    unresolved = (
        evt
        for evt in reversed(self._events)
        if evt.scope == scope and evt.resolved_at is None
    )
    newest = next(unresolved, None)
    return "unknown" if newest is None else newest.reason.value
|
||||||
|
|
||||||
|
def _get_last_triggered_by(self, scope: str) -> str:
    """Return who triggered the newest unresolved event for a scope, or "unknown"."""
    unresolved = (
        evt
        for evt in reversed(self._events)
        if evt.scope == scope and evt.resolved_at is None
    )
    newest = next(unresolved, None)
    return "unknown" if newest is None else newest.triggered_by
|
||||||
|
|
||||||
|
async def get_state(self, scope: str = "global") -> EmergencyState:
    """Look up the current state of a scope while holding the instance lock."""
    async with self._lock:
        current = self._get_state(scope)
    return current
|
||||||
|
|
||||||
|
async def get_all_states(self) -> dict[str, EmergencyState]:
    """Return a fresh mapping of "global" plus every recorded scope to its state."""
    async with self._lock:
        # Later keys win, exactly like dict.update on the scoped table.
        return {"global": self._global_state, **self._scoped_states}
|
||||||
|
|
||||||
|
async def get_active_events(self) -> list[EmergencyEvent]:
    """Return, in stored order, every event that has no resolution timestamp."""
    async with self._lock:
        active: list[EmergencyEvent] = []
        for evt in self._events:
            if evt.resolved_at is None:
                active.append(evt)
        return active
|
||||||
|
|
||||||
|
async def get_event_history(
    self,
    scope: str | None = None,
    limit: int = 100,
) -> list[EmergencyEvent]:
    """
    Get emergency event history.

    Args:
        scope: When given (and non-empty), only events of this scope
        limit: Maximum number of events to return; the newest ones are
            kept. Zero or a negative value yields an empty list.

    Returns:
        Up to ``limit`` of the most recent matching events, in stored order
    """
    async with self._lock:
        events = list(self._events)

    if scope:
        events = [e for e in events if e.scope == scope]

    # events[-0:] is the WHOLE list, so a non-positive limit must be
    # handled explicitly instead of being fed to the slice.
    if limit <= 0:
        return []

    return events[-limit:]
|
||||||
|
|
||||||
|
def on_stop(self, callback: Callable[..., Any]) -> None:
    """Append a callable to the list of stop callbacks."""
    stop_callbacks = self._on_stop_callbacks
    stop_callbacks.append(callback)
|
||||||
|
|
||||||
|
def on_pause(self, callback: Callable[..., Any]) -> None:
    """Append a callable to the list of pause callbacks."""
    pause_callbacks = self._on_pause_callbacks
    pause_callbacks.append(callback)
|
||||||
|
|
||||||
|
def on_resume(self, callback: Callable[..., Any]) -> None:
    """Append a callable to the list of resume callbacks."""
    resume_callbacks = self._on_resume_callbacks
    resume_callbacks.append(callback)
|
||||||
|
|
||||||
|
def add_notification_handler(self, handler: Callable[..., Any]) -> None:
    """Append a callable to the list of notification handlers."""
    handlers = self._notification_handlers
    handlers.append(handler)
|
||||||
|
|
||||||
|
async def _execute_callbacks(
    self,
    callbacks: list[Callable[..., Any]],
    data: Any,
) -> None:
    """
    Execute callbacks safely.

    Every callback is called with ``data``. If the call returns an
    awaitable it is awaited, so plain ``async def`` functions as well as
    sync wrappers (lambdas, objects with an async ``__call__``) that
    hand back a coroutine are all run to completion. An exception raised
    by one callback is logged with its traceback and does not stop the
    remaining callbacks.

    Args:
        callbacks: Callables to invoke, in order
        data: Single positional argument passed to each callback
    """
    import inspect  # local import: only this helper needs it

    # Iterate over a snapshot so a callback that registers or removes
    # callbacks cannot disturb the iteration.
    for callback in tuple(callbacks):
        try:
            outcome = callback(data)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as e:
            logger.error("Error in callback: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
async def _notify_handlers(self, event_type: str, data: Any) -> None:
    """
    Notify all handlers of an event.

    Every registered handler is called with ``(event_type, data)``. If
    the call returns an awaitable it is awaited, which covers
    ``async def`` handlers and sync wrappers that return a coroutine.
    A failing handler is logged with its traceback and does not prevent
    the remaining handlers from running.

    Args:
        event_type: Short event name passed as first argument
        data: Event payload passed as second argument
    """
    import inspect  # local import: only this helper needs it

    # Snapshot so handlers may (un)register handlers while being notified.
    for handler in tuple(self._notification_handlers):
        try:
            outcome = handler(event_type, data)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as e:
            logger.error("Error in notification handler: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
class EmergencyTrigger:
    """
    Automatic emergency triggers based on conditions.

    Each method turns one kind of detected condition into a call on the
    wrapped EmergencyControls instance: safety and content violations
    request an emergency stop, budget overruns and detected loops
    request a pause.
    """

    def __init__(self, controls: EmergencyControls) -> None:
        """
        Initialize EmergencyTrigger.

        Args:
            controls: EmergencyControls instance to trigger
        """
        self._controls = controls

    async def trigger_on_safety_violation(
        self,
        violation_type: str,
        details: dict[str, Any],
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency from safety violation.

        Args:
            violation_type: Type of violation
            details: Violation details, merged into the event metadata
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        # Canonical keys go last so a same-named key inside ``details``
        # cannot make the metadata contradict the message.
        metadata = {**details, "violation_type": violation_type}
        return await self._controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety_system",
            message=f"Safety violation: {violation_type}",
            scope=scope,
            metadata=metadata,
        )

    async def trigger_on_budget_exceeded(
        self,
        budget_type: str,
        current: float,
        limit: float,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency from budget exceeded.

        Args:
            budget_type: Type of budget
            current: Current usage
            limit: Budget limit
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
            scope=scope,
            metadata={"budget_type": budget_type, "current": current, "limit": limit},
        )

    async def trigger_on_loop_detected(
        self,
        loop_type: str,
        agent_id: str,
        details: dict[str, Any],
    ) -> EmergencyEvent:
        """
        Trigger emergency from loop detection.

        The pause is scoped to the looping agent (``agent:<agent_id>``).

        Args:
            loop_type: Type of loop
            agent_id: Agent that's looping
            details: Loop details, merged into the event metadata

        Returns:
            Emergency event
        """
        # Canonical keys go last so ``details`` cannot overwrite them and
        # leave metadata that disagrees with the scope and message.
        metadata = {**details, "loop_type": loop_type, "agent_id": agent_id}
        return await self._controls.pause(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="loop_detector",
            message=f"Loop detected: {loop_type} in agent {agent_id}",
            scope=f"agent:{agent_id}",
            metadata=metadata,
        )

    async def trigger_on_content_violation(
        self,
        category: str,
        pattern: str,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger emergency from content violation.

        Args:
            category: Content category
            pattern: Pattern that matched
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.emergency_stop(
            reason=EmergencyReason.CONTENT_VIOLATION,
            triggered_by="content_filter",
            message=f"Content violation: {category} ({pattern})",
            scope=scope,
            metadata={"category": category, "pattern": pattern},
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide EmergencyControls instance, created on first request.
_emergency_controls: EmergencyControls | None = None
# Serialises the lazy creation below.
_lock = asyncio.Lock()


async def get_emergency_controls() -> EmergencyControls:
    """Return the shared EmergencyControls, creating it on the first call."""
    global _emergency_controls

    async with _lock:
        instance = _emergency_controls
        if instance is None:
            instance = _emergency_controls = EmergencyControls()
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
async def emergency_stop_global(
    reason: str,
    triggered_by: str = "system",
) -> EmergencyEvent:
    """
    Quick global emergency stop.

    The stop is always recorded with the MANUAL reason; the ``reason``
    text given here is forwarded as the event message.
    """
    controls = await get_emergency_controls()
    stop_kwargs = {
        "reason": EmergencyReason.MANUAL,
        "triggered_by": triggered_by,
        "message": reason,
        "scope": "global",
    }
    return await controls.emergency_stop(**stop_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_emergency_allowed(scope: str | None = None) -> bool:
    """Quick, non-raising check on the shared controls: True if operations are allowed."""
    shared = await get_emergency_controls()
    allowed = await shared.check_allowed(scope=scope, raise_if_blocked=False)
    return allowed
|
||||||
@@ -1 +1,5 @@
|
|||||||
"""${dir} module."""
|
"""Human-in-the-Loop approval workflows."""
|
||||||
|
|
||||||
|
from .manager import ApprovalQueue, HITLManager
|
||||||
|
|
||||||
|
__all__ = ["ApprovalQueue", "HITLManager"]
|
||||||
|
|||||||
449
backend/app/services/safety/hitl/manager.py
Normal file
449
backend/app/services/safety/hitl/manager.py
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
"""
|
||||||
|
Human-in-the-Loop (HITL) Manager
|
||||||
|
|
||||||
|
Manages approval workflows for actions requiring human oversight.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from ..config import get_safety_config
|
||||||
|
from ..exceptions import (
|
||||||
|
ApprovalDeniedError,
|
||||||
|
ApprovalRequiredError,
|
||||||
|
ApprovalTimeoutError,
|
||||||
|
)
|
||||||
|
from ..models import (
|
||||||
|
ActionRequest,
|
||||||
|
ApprovalRequest,
|
||||||
|
ApprovalResponse,
|
||||||
|
ApprovalStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalQueue:
    """
    Queue for pending approval requests.

    Requests live in ``_pending`` until they are completed, cancelled or
    expired; the resulting response is then kept in ``_completed``. Each
    pending request has an ``asyncio.Event`` in ``_waiters`` that is set
    (and dropped from the table) when the request is finished.
    Completed responses are never evicted by this class.
    """

    def __init__(self) -> None:
        self._pending: dict[str, ApprovalRequest] = {}
        self._completed: dict[str, ApprovalResponse] = {}
        self._waiters: dict[str, asyncio.Event] = {}
        self._lock = asyncio.Lock()

    def _finish_locked(self, response: ApprovalResponse) -> bool:
        """
        Move a request from pending to completed and wake its waiters.

        The caller must already hold ``self._lock``.

        Returns:
            False if the request is not pending, True otherwise
        """
        request_id = response.request_id
        if request_id not in self._pending:
            return False

        del self._pending[request_id]
        self._completed[request_id] = response

        # Drop the event once it has fired; coroutines already waiting
        # hold their own reference, late callers read ``_completed``.
        waiter = self._waiters.pop(request_id, None)
        if waiter is not None:
            waiter.set()

        return True

    async def add(self, request: ApprovalRequest) -> None:
        """Add an approval request to the queue."""
        async with self._lock:
            self._pending[request.id] = request
            self._waiters[request.id] = asyncio.Event()

    async def get_pending(self, request_id: str) -> ApprovalRequest | None:
        """Get a pending request by ID, or None if it is not pending."""
        async with self._lock:
            return self._pending.get(request_id)

    async def complete(self, response: ApprovalResponse) -> bool:
        """
        Complete an approval request.

        Returns:
            True if the request was pending and is now completed
        """
        async with self._lock:
            return self._finish_locked(response)

    async def wait_for_response(
        self,
        request_id: str,
        timeout_seconds: float,
    ) -> ApprovalResponse | None:
        """
        Wait for a response to an approval request.

        Args:
            request_id: ID of the request to wait for
            timeout_seconds: Maximum time to wait

        Returns:
            The stored response, or None on timeout or when the ID is
            neither pending nor completed
        """
        async with self._lock:
            waiter = self._waiters.get(request_id)
            if not waiter:
                return self._completed.get(request_id)

        try:
            await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            # Same class as the builtin TimeoutError on Python 3.11+,
            # but a distinct class on older interpreters.
            return None

        async with self._lock:
            return self._completed.get(request_id)

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending requests."""
        async with self._lock:
            return list(self._pending.values())

    async def cancel(self, request_id: str) -> bool:
        """
        Cancel a pending request.

        Returns:
            True if the request was pending and is now recorded as cancelled
        """
        async with self._lock:
            if request_id not in self._pending:
                return False

            return self._finish_locked(
                ApprovalResponse(
                    request_id=request_id,
                    status=ApprovalStatus.CANCELLED,
                    reason="Cancelled",
                )
            )

    async def cleanup_expired(self) -> int:
        """
        Clean up expired requests.

        Every pending request whose ``expires_at`` lies in the past is
        recorded as timed out and its waiters are woken.

        Returns:
            Number of requests that were timed out
        """
        now = datetime.utcnow()

        async with self._lock:
            expired = [
                request_id
                for request_id, request in self._pending.items()
                if request.expires_at and request.expires_at < now
            ]
            for request_id in expired:
                self._finish_locked(
                    ApprovalResponse(
                        request_id=request_id,
                        status=ApprovalStatus.TIMEOUT,
                        reason="Request timed out",
                    )
                )

        return len(expired)
|
||||||
|
|
||||||
|
|
||||||
|
class HITLManager:
    """
    Manages Human-in-the-Loop approval workflows.

    Features:
    - Approval request queue
    - Configurable timeout handling (default deny)
    - Approval delegation
    - Batch approval for similar actions
    - Approval with modifications
    - Notification channels
    """

    def __init__(
        self,
        default_timeout: int | None = None,
    ) -> None:
        """
        Initialize the HITLManager.

        Args:
            default_timeout: Default timeout for approval requests in
                seconds; a falsy value falls back to the safety config
        """
        config = get_safety_config()

        self._default_timeout = default_timeout or config.hitl_default_timeout
        self._queue = ApprovalQueue()
        self._notification_handlers: list[Callable[..., Any]] = []
        self._running = False
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start(self) -> None:
        """Start the HITL manager background tasks (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("HITL Manager started")

    async def stop(self) -> None:
        """Stop the HITL manager and cancel the background cleanup task."""
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Forget the finished task so repeated stop() calls do not
            # touch it again.
            self._cleanup_task = None

        logger.info("HITL Manager stopped")

    async def request_approval(
        self,
        action: ActionRequest,
        reason: str,
        timeout_seconds: int | None = None,
        urgency: str = "normal",
        context: dict[str, Any] | None = None,
    ) -> ApprovalRequest:
        """
        Create an approval request for an action.

        The request is queued and handlers are notified with
        ``"approval_requested"``.

        Args:
            action: The action requiring approval
            reason: Why approval is required
            timeout_seconds: Timeout for this request
            urgency: Urgency level (low, normal, high, critical)
            context: Additional context for the approver

        Returns:
            The created approval request
        """
        timeout = timeout_seconds or self._default_timeout
        expires_at = datetime.utcnow() + timedelta(seconds=timeout)

        request = ApprovalRequest(
            id=str(uuid4()),
            action=action,
            reason=reason,
            urgency=urgency,
            timeout_seconds=timeout,
            expires_at=expires_at,
            context=context or {},
        )

        await self._queue.add(request)

        # Notify handlers
        await self._notify_handlers("approval_requested", request)

        logger.info(
            "Approval requested: %s for action %s (timeout: %ds)",
            request.id,
            action.id,
            timeout,
        )

        return request

    async def wait_for_approval(
        self,
        request_id: str,
        timeout_seconds: int | None = None,
    ) -> ApprovalResponse:
        """
        Wait for an approval decision.

        A request that was already decided before this method is called
        is answered from the queue's stored response instead of being
        reported as missing.

        Args:
            request_id: ID of the approval request
            timeout_seconds: Override timeout

        Returns:
            The approval response

        Raises:
            ApprovalRequiredError: If the request is neither pending nor
                completed
            ApprovalTimeoutError: If timeout expires
            ApprovalDeniedError: If approval is denied or cancelled
        """
        request = await self._queue.get_pending(request_id)

        timeout = (
            timeout_seconds
            or (request.timeout_seconds if request else None)
            or self._default_timeout
        )
        # For a request that is no longer pending this returns the
        # stored response (or None for an unknown ID) without waiting.
        response = await self._queue.wait_for_response(request_id, timeout)

        if response is None:
            if not request:
                raise ApprovalRequiredError(
                    f"Approval request not found: {request_id}",
                    approval_id=request_id,
                )

            # Timeout - default deny
            response = ApprovalResponse(
                request_id=request_id,
                status=ApprovalStatus.TIMEOUT,
                reason="Request timed out (default deny)",
            )
            await self._queue.complete(response)

            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.DENIED:
            raise ApprovalDeniedError(
                response.reason or "Approval denied",
                approval_id=request_id,
                denied_by=response.decided_by,
                denial_reason=response.reason,
            )

        if response.status == ApprovalStatus.TIMEOUT:
            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.CANCELLED:
            raise ApprovalDeniedError(
                "Approval request was cancelled",
                approval_id=request_id,
                denial_reason="Cancelled",
            )

        return response

    async def approve(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
        modifications: dict[str, Any] | None = None,
    ) -> bool:
        """
        Approve a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who approved
            reason: Optional approval reason
            modifications: Optional modifications to the action

        Returns:
            True if approval was recorded
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.APPROVED,
            decided_by=decided_by,
            reason=reason,
            modifications=modifications,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval granted: %s by %s",
                request_id,
                decided_by,
            )
            await self._notify_handlers("approval_granted", response)

        return success

    async def deny(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
    ) -> bool:
        """
        Deny a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who denied
            reason: Denial reason

        Returns:
            True if denial was recorded
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.DENIED,
            decided_by=decided_by,
            reason=reason,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval denied: %s by %s - %s",
                request_id,
                decided_by,
                reason,
            )
            await self._notify_handlers("approval_denied", response)

        return success

    async def cancel(self, request_id: str) -> bool:
        """
        Cancel a pending request.

        Args:
            request_id: ID of the approval request

        Returns:
            True if request was cancelled
        """
        success = await self._queue.cancel(request_id)

        if success:
            logger.info("Approval request cancelled: %s", request_id)

        return success

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending approval requests."""
        return await self._queue.list_pending()

    async def get_request(self, request_id: str) -> ApprovalRequest | None:
        """Get a pending approval request by ID (None once it is no longer pending)."""
        return await self._queue.get_pending(request_id)

    def add_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Add a notification handler."""
        self._notification_handlers.append(handler)

    def remove_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Remove a notification handler (no-op if it is not registered)."""
        if handler in self._notification_handlers:
            self._notification_handlers.remove(handler)

    async def _notify_handlers(
        self,
        event_type: str,
        data: Any,
    ) -> None:
        """
        Notify all handlers of an event.

        Each handler is called with ``(event_type, data)``; an awaitable
        result is awaited, so async handlers and sync wrappers returning
        a coroutine both run to completion. A failing handler is logged
        with its traceback and does not stop the remaining handlers.
        """
        import inspect  # local import: only this helper needs it

        # Snapshot so handlers may add/remove handlers while notified.
        for handler in tuple(self._notification_handlers):
            try:
                outcome = handler(event_type, data)
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.error("Error in notification handler: %s", e, exc_info=True)

    async def _periodic_cleanup(self) -> None:
        """Background task for cleaning up expired requests."""
        while self._running:
            try:
                await asyncio.sleep(30)  # Check every 30 seconds
                count = await self._queue.cleanup_expired()
                if count:
                    logger.debug("Cleaned up %d expired approval requests", count)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Error in approval cleanup: %s", e)
|
||||||
@@ -1 +1,5 @@
|
|||||||
"""${dir} module."""
|
"""Rollback management for agent actions."""
|
||||||
|
|
||||||
|
from .manager import RollbackManager, TransactionContext
|
||||||
|
|
||||||
|
__all__ = ["RollbackManager", "TransactionContext"]
|
||||||
|
|||||||
418
backend/app/services/safety/rollback/manager.py
Normal file
418
backend/app/services/safety/rollback/manager.py
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
"""
|
||||||
|
Rollback Manager
|
||||||
|
|
||||||
|
Manages checkpoints and rollback operations for agent actions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from ..config import get_safety_config
|
||||||
|
from ..exceptions import RollbackError
|
||||||
|
from ..models import (
|
||||||
|
ActionRequest,
|
||||||
|
Checkpoint,
|
||||||
|
CheckpointType,
|
||||||
|
RollbackResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileCheckpoint:
    """
    Stores file state for rollback.

    Plain record of one file as captured for a checkpoint: the path, the
    bytes it held (None when nothing was captured), whether it existed,
    and the naive UTC time at which the record was created.
    """

    def __init__(
        self,
        checkpoint_id: str,
        file_path: str,
        original_content: bytes | None,
        existed: bool,
    ) -> None:
        self.checkpoint_id, self.file_path = checkpoint_id, file_path
        self.original_content, self.existed = original_content, existed
        # Stamped at construction time.
        self.created_at = datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class RollbackManager:
|
||||||
|
"""
|
||||||
|
Manages checkpoints and rollback operations.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- File system checkpoints
|
||||||
|
- Transaction wrapping for actions
|
||||||
|
- Automatic checkpoint for destructive actions
|
||||||
|
- Rollback triggers on failure
|
||||||
|
- Checkpoint expiration and cleanup
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    checkpoint_dir: str | None = None,
    retention_hours: int | None = None,
) -> None:
    """
    Initialize the RollbackManager.

    Falsy arguments fall back to the values of the safety config. The
    checkpoint directory is created (with parents) if it is missing.

    Args:
        checkpoint_dir: Directory for storing checkpoint data
        retention_hours: Hours to retain checkpoints
    """
    config = get_safety_config()

    storage_root = checkpoint_dir or config.checkpoint_dir
    self._checkpoint_dir = Path(storage_root)
    self._retention_hours = retention_hours or config.checkpoint_retention_hours

    # In-memory registries, both keyed by checkpoint id.
    self._checkpoints: dict[str, Checkpoint] = {}
    self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
    self._lock = asyncio.Lock()

    # Make sure the storage location is there before first use.
    self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
async def create_checkpoint(
    self,
    action: ActionRequest,
    checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
    description: str | None = None,
) -> Checkpoint:
    """
    Create a checkpoint before an action.

    The checkpoint is registered together with an empty list of file
    snapshots; file state is added separately.

    Args:
        action: The action to checkpoint for
        checkpoint_type: Type of checkpoint
        description: Optional description; defaults to a text naming the
            action's tool

    Returns:
        The created checkpoint
    """
    checkpoint_id = str(uuid4())

    # One clock reading for both fields so that expires_at - created_at
    # is exactly the retention window.
    created_at = datetime.utcnow()

    checkpoint = Checkpoint(
        id=checkpoint_id,
        checkpoint_type=checkpoint_type,
        action_id=action.id,
        created_at=created_at,
        expires_at=created_at + timedelta(hours=self._retention_hours),
        data={
            "action_type": action.action_type.value,
            "tool_name": action.tool_name,
            "resource": action.resource,
        },
        description=description or f"Checkpoint for {action.tool_name}",
    )

    async with self._lock:
        self._checkpoints[checkpoint_id] = checkpoint
        self._file_checkpoints[checkpoint_id] = []

    logger.info(
        "Created checkpoint %s for action %s",
        checkpoint_id,
        action.id,
    )

    return checkpoint
|
||||||
|
|
||||||
|
async def checkpoint_file(
    self,
    checkpoint_id: str,
    file_path: str,
) -> None:
    """
    Store current state of a file for checkpoint.

    The file is read in a worker thread so the event loop is not blocked
    by disk I/O. A file that does not exist is recorded as such (no
    content, ``existed=False``). The snapshot is appended to the list
    kept for ``checkpoint_id``, creating that list if necessary.

    Args:
        checkpoint_id: ID of the checkpoint
        file_path: Path to the file
    """
    path = Path(file_path)

    # Read first, handle absence afterwards: a separate exists() check
    # would race with the file being removed before the read.
    try:
        content: bytes | None = await asyncio.to_thread(path.read_bytes)
        existed = True
    except (FileNotFoundError, NotADirectoryError):
        content = None
        existed = False

    file_checkpoint = FileCheckpoint(
        checkpoint_id=checkpoint_id,
        file_path=file_path,
        original_content=content,
        existed=existed,
    )

    async with self._lock:
        self._file_checkpoints.setdefault(checkpoint_id, []).append(file_checkpoint)

    logger.debug(
        "Stored file state for checkpoint %s: %s (existed=%s)",
        checkpoint_id,
        file_path,
        existed,
    )
|
||||||
|
|
||||||
|
async def checkpoint_files(
    self,
    checkpoint_id: str,
    file_paths: list[str],
) -> None:
    """
    Store current state of multiple files.

    Files are snapshotted one after another, in the order given.

    Args:
        checkpoint_id: ID of the checkpoint
        file_paths: Paths to the files
    """
    snapshot_one = self.checkpoint_file
    for target in file_paths:
        await snapshot_one(checkpoint_id, target)
|
||||||
|
|
||||||
|
async def rollback(
    self,
    checkpoint_id: str,
) -> RollbackResult:
    """
    Rollback to a checkpoint.

    File snapshots are undone newest-first, so when the same file was
    captured more than once under this checkpoint the earliest capture
    is applied last. Individual failures are logged and collected; the
    remaining files are still processed. The checkpoint is marked
    invalid afterwards regardless of the outcome.

    Args:
        checkpoint_id: ID of the checkpoint

    Returns:
        Result of the rollback operation

    Raises:
        RollbackError: If the checkpoint is unknown or no longer valid
    """
    async with self._lock:
        checkpoint = self._checkpoints.get(checkpoint_id)
        if not checkpoint:
            raise RollbackError(
                f"Checkpoint not found: {checkpoint_id}",
                checkpoint_id=checkpoint_id,
            )

        if not checkpoint.is_valid:
            raise RollbackError(
                f"Checkpoint is no longer valid: {checkpoint_id}",
                checkpoint_id=checkpoint_id,
            )

        # Copy while holding the lock (the live list may be appended to
        # concurrently) and reverse it so the oldest snapshot wins.
        # NOTE(review): assumes _rollback_file restores the captured
        # content of one snapshot -- confirm against its definition.
        file_checkpoints = list(
            reversed(self._file_checkpoints.get(checkpoint_id, []))
        )

    actions_rolled_back: list[str] = []
    failed_actions: list[str] = []

    # Rollback file changes
    for fc in file_checkpoints:
        try:
            await self._rollback_file(fc)
            actions_rolled_back.append(f"file:{fc.file_path}")
        except Exception as e:
            logger.error("Failed to rollback file %s: %s", fc.file_path, e)
            failed_actions.append(f"file:{fc.file_path}")

    success = not failed_actions

    # Mark checkpoint as used
    async with self._lock:
        if checkpoint_id in self._checkpoints:
            self._checkpoints[checkpoint_id].is_valid = False

    result = RollbackResult(
        checkpoint_id=checkpoint_id,
        success=success,
        actions_rolled_back=actions_rolled_back,
        failed_actions=failed_actions,
        error=None if success else f"Failed to rollback {len(failed_actions)} items",
    )

    if success:
        logger.info("Rollback successful for checkpoint %s", checkpoint_id)
    else:
        logger.error(
            "Rollback partially failed for checkpoint %s: %d failures",
            checkpoint_id,
            len(failed_actions),
        )

    return result
|
||||||
|
|
||||||
|
async def discard_checkpoint(self, checkpoint_id: str) -> bool:
    """
    Drop a checkpoint and its file snapshots without restoring anything.

    Args:
        checkpoint_id: ID of the checkpoint to drop

    Returns:
        True if the checkpoint existed and was removed, False otherwise
    """
    async with self._lock:
        if checkpoint_id not in self._checkpoints:
            return False

        del self._checkpoints[checkpoint_id]
        # File snapshots are optional; remove them only if any were stored.
        self._file_checkpoints.pop(checkpoint_id, None)
        logger.debug("Discarded checkpoint %s", checkpoint_id)
        return True
|
||||||
|
|
||||||
|
async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
    """Look up a checkpoint by ID, returning None when it is not stored."""
    async with self._lock:
        found = self._checkpoints.get(checkpoint_id)
    return found
|
||||||
|
|
||||||
|
async def list_checkpoints(
    self,
    action_id: str | None = None,
    include_expired: bool = False,
) -> list[Checkpoint]:
    """
    Return stored checkpoints, optionally narrowed by action and expiry.

    Args:
        action_id: When given, keep only checkpoints for this action ID
        include_expired: When False, drop checkpoints whose ``expires_at``
            is set and not later than the current UTC time

    Returns:
        Matching checkpoints, in the order they are stored
    """
    now = datetime.utcnow()

    # Copy the values under the lock; filtering happens on the copy.
    async with self._lock:
        stored = list(self._checkpoints.values())

    def _wanted(cp: Checkpoint) -> bool:
        if action_id and not cp.action_id == action_id:
            return False
        if include_expired:
            return True
        return cp.expires_at is None or cp.expires_at > now

    return [cp for cp in stored if _wanted(cp)]
|
||||||
|
|
||||||
|
async def cleanup_expired(self) -> int:
    """
    Remove every checkpoint whose expiry time has passed.

    A checkpoint is removed when ``expires_at`` is set and earlier than
    the current UTC time; its file snapshots are removed with it.

    Returns:
        Number of checkpoints removed
    """
    cutoff = datetime.utcnow()

    async with self._lock:
        # Collect IDs first so the dict is not mutated while iterating it.
        expired_ids: list[str] = [
            cid
            for cid, cp in self._checkpoints.items()
            if cp.expires_at and cp.expires_at < cutoff
        ]

        for cid in expired_ids:
            del self._checkpoints[cid]
            self._file_checkpoints.pop(cid, None)

    removed = len(expired_ids)
    if removed:
        logger.info("Cleaned up %d expired checkpoints", removed)

    return removed
|
||||||
|
|
||||||
|
async def _rollback_file(self, fc: FileCheckpoint) -> None:
    """Put one file back into the state recorded in its snapshot."""
    target = Path(fc.file_path)

    if not fc.existed:
        # The file was absent at snapshot time, so remove it if present now.
        if target.exists():
            target.unlink()
            logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
        return

    saved = fc.original_content
    if saved is None:
        # Nothing was captured for this file, so there is nothing to write.
        return

    # Recreate any missing parent directories before writing the bytes back.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(saved)
    logger.debug("Restored file: %s", fc.file_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionContext:
    """
    Context manager for transactional action execution.

    On entry a checkpoint is created for the action. On exit:

    - if an exception is propagating, ``auto_rollback`` is enabled and
      ``commit()`` was not called, the checkpoint is rolled back;
    - if ``commit()`` was called, the checkpoint is discarded;
    - otherwise the checkpoint is left in place.

    Exceptions raised inside the block are never suppressed.

    Usage:
        async with TransactionContext(rollback_manager, action) as tx:
            await tx.checkpoint_file("/path/to/file")
            # Do work...
            tx.commit()
            # If exception occurs before commit, automatic rollback
    """

    def __init__(
        self,
        manager: RollbackManager,
        action: ActionRequest,
        auto_rollback: bool = True,
    ) -> None:
        """
        Bind the transaction to a manager and an action.

        Args:
            manager: Rollback manager that owns the checkpoint
            action: Action the checkpoint is created for
            auto_rollback: Roll back automatically when the block raises
        """
        self._manager = manager
        self._action = action
        self._auto_rollback = auto_rollback
        self._checkpoint: Checkpoint | None = None
        self._committed = False

    async def __aenter__(self) -> "TransactionContext":
        """Create the checkpoint for this transaction and return self."""
        self._checkpoint = await self._manager.create_checkpoint(self._action)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        """Roll back or discard the checkpoint; never suppress exceptions."""
        if exc_val is not None and self._auto_rollback and not self._committed:
            # Exception occurred - rollback
            if self._checkpoint:
                try:
                    result = await self._manager.rollback(self._checkpoint.id)
                except Exception as e:
                    # A failed rollback must not mask the original exception.
                    logger.error("Auto-rollback failed: %s", e)
                else:
                    # rollback() reports partial failures through the result
                    # instead of raising, so check it before claiming success.
                    if result.success:
                        logger.info(
                            "Auto-rollback completed for action %s",
                            self._action.id,
                        )
                    else:
                        logger.error(
                            "Auto-rollback incomplete for action %s: %s",
                            self._action.id,
                            result.error,
                        )
        elif self._committed and self._checkpoint:
            # Committed - discard checkpoint
            await self._manager.discard_checkpoint(self._checkpoint.id)

        return False  # Don't suppress the exception

    @property
    def checkpoint_id(self) -> str | None:
        """Get the checkpoint ID, or None before the context is entered."""
        return self._checkpoint.id if self._checkpoint else None

    async def checkpoint_file(self, file_path: str) -> None:
        """Checkpoint a file for this transaction (no-op before entry)."""
        if self._checkpoint:
            await self._manager.checkpoint_file(self._checkpoint.id, file_path)

    async def checkpoint_files(self, file_paths: list[str]) -> None:
        """Checkpoint multiple files for this transaction (no-op before entry)."""
        if self._checkpoint:
            await self._manager.checkpoint_files(self._checkpoint.id, file_paths)

    def commit(self) -> None:
        """Mark transaction as committed (no rollback on exit)."""
        self._committed = True

    async def rollback(self) -> RollbackResult | None:
        """Manually trigger rollback; returns None if no checkpoint exists."""
        if self._checkpoint:
            return await self._manager.rollback(self._checkpoint.id)
        return None
|
||||||
Reference in New Issue
Block a user