feat(safety): add Phase C advanced controls
- Add rollback manager with file checkpointing and transaction context - Add HITL manager with approval queues and notification handlers - Add content filter with PII, secrets, and injection detection - Add emergency controls with stop/pause/resume capabilities - Update SafetyConfig with checkpoint_dir setting Issue #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
418
backend/app/services/safety/rollback/manager.py
Normal file
418
backend/app/services/safety/rollback/manager.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Rollback Manager
|
||||
|
||||
Manages checkpoints and rollback operations for agent actions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RollbackError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
RollbackResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCheckpoint:
|
||||
"""Stores file state for rollback."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_path: str,
|
||||
original_content: bytes | None,
|
||||
existed: bool,
|
||||
) -> None:
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.file_path = file_path
|
||||
self.original_content = original_content
|
||||
self.existed = existed
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class RollbackManager:
|
||||
"""
|
||||
Manages checkpoints and rollback operations.
|
||||
|
||||
Features:
|
||||
- File system checkpoints
|
||||
- Transaction wrapping for actions
|
||||
- Automatic checkpoint for destructive actions
|
||||
- Rollback triggers on failure
|
||||
- Checkpoint expiration and cleanup
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: str | None = None,
|
||||
retention_hours: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the RollbackManager.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Directory for storing checkpoint data
|
||||
retention_hours: Hours to retain checkpoints
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._checkpoint_dir = Path(
|
||||
checkpoint_dir or config.checkpoint_dir
|
||||
)
|
||||
self._retention_hours = retention_hours or config.checkpoint_retention_hours
|
||||
|
||||
self._checkpoints: dict[str, Checkpoint] = {}
|
||||
self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Ensure checkpoint directory exists
|
||||
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def create_checkpoint(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
|
||||
description: str | None = None,
|
||||
) -> Checkpoint:
|
||||
"""
|
||||
Create a checkpoint before an action.
|
||||
|
||||
Args:
|
||||
action: The action to checkpoint for
|
||||
checkpoint_type: Type of checkpoint
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
The created checkpoint
|
||||
"""
|
||||
checkpoint_id = str(uuid4())
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=checkpoint_id,
|
||||
checkpoint_type=checkpoint_type,
|
||||
action_id=action.id,
|
||||
created_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(hours=self._retention_hours),
|
||||
data={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"resource": action.resource,
|
||||
},
|
||||
description=description or f"Checkpoint for {action.tool_name}",
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._checkpoints[checkpoint_id] = checkpoint
|
||||
self._file_checkpoints[checkpoint_id] = []
|
||||
|
||||
logger.info(
|
||||
"Created checkpoint %s for action %s",
|
||||
checkpoint_id,
|
||||
action.id,
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
async def checkpoint_file(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Store current state of a file for checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
file_path: Path to the file
|
||||
"""
|
||||
path = Path(file_path)
|
||||
|
||||
if path.exists():
|
||||
content = path.read_bytes()
|
||||
existed = True
|
||||
else:
|
||||
content = None
|
||||
existed = False
|
||||
|
||||
file_checkpoint = FileCheckpoint(
|
||||
checkpoint_id=checkpoint_id,
|
||||
file_path=file_path,
|
||||
original_content=content,
|
||||
existed=existed,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
if checkpoint_id not in self._file_checkpoints:
|
||||
self._file_checkpoints[checkpoint_id] = []
|
||||
self._file_checkpoints[checkpoint_id].append(file_checkpoint)
|
||||
|
||||
logger.debug(
|
||||
"Stored file state for checkpoint %s: %s (existed=%s)",
|
||||
checkpoint_id,
|
||||
file_path,
|
||||
existed,
|
||||
)
|
||||
|
||||
async def checkpoint_files(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_paths: list[str],
|
||||
) -> None:
|
||||
"""
|
||||
Store current state of multiple files.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
file_paths: Paths to the files
|
||||
"""
|
||||
for path in file_paths:
|
||||
await self.checkpoint_file(checkpoint_id, path)
|
||||
|
||||
async def rollback(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
) -> RollbackResult:
|
||||
"""
|
||||
Rollback to a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
|
||||
Returns:
|
||||
Result of the rollback operation
|
||||
"""
|
||||
async with self._lock:
|
||||
checkpoint = self._checkpoints.get(checkpoint_id)
|
||||
if not checkpoint:
|
||||
raise RollbackError(
|
||||
f"Checkpoint not found: {checkpoint_id}",
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
if not checkpoint.is_valid:
|
||||
raise RollbackError(
|
||||
f"Checkpoint is no longer valid: {checkpoint_id}",
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
file_checkpoints = self._file_checkpoints.get(checkpoint_id, [])
|
||||
|
||||
actions_rolled_back: list[str] = []
|
||||
failed_actions: list[str] = []
|
||||
|
||||
# Rollback file changes
|
||||
for fc in file_checkpoints:
|
||||
try:
|
||||
await self._rollback_file(fc)
|
||||
actions_rolled_back.append(f"file:{fc.file_path}")
|
||||
except Exception as e:
|
||||
logger.error("Failed to rollback file %s: %s", fc.file_path, e)
|
||||
failed_actions.append(f"file:{fc.file_path}")
|
||||
|
||||
success = len(failed_actions) == 0
|
||||
|
||||
# Mark checkpoint as used
|
||||
async with self._lock:
|
||||
if checkpoint_id in self._checkpoints:
|
||||
self._checkpoints[checkpoint_id].is_valid = False
|
||||
|
||||
result = RollbackResult(
|
||||
checkpoint_id=checkpoint_id,
|
||||
success=success,
|
||||
actions_rolled_back=actions_rolled_back,
|
||||
failed_actions=failed_actions,
|
||||
error=None if success else f"Failed to rollback {len(failed_actions)} items",
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Rollback successful for checkpoint %s", checkpoint_id)
|
||||
else:
|
||||
logger.error(
|
||||
"Rollback partially failed for checkpoint %s: %d failures",
|
||||
checkpoint_id,
|
||||
len(failed_actions),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def discard_checkpoint(self, checkpoint_id: str) -> bool:
|
||||
"""
|
||||
Discard a checkpoint without rolling back.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint
|
||||
|
||||
Returns:
|
||||
True if checkpoint was found and discarded
|
||||
"""
|
||||
async with self._lock:
|
||||
if checkpoint_id in self._checkpoints:
|
||||
del self._checkpoints[checkpoint_id]
|
||||
if checkpoint_id in self._file_checkpoints:
|
||||
del self._file_checkpoints[checkpoint_id]
|
||||
logger.debug("Discarded checkpoint %s", checkpoint_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
|
||||
"""Get a checkpoint by ID."""
|
||||
async with self._lock:
|
||||
return self._checkpoints.get(checkpoint_id)
|
||||
|
||||
async def list_checkpoints(
|
||||
self,
|
||||
action_id: str | None = None,
|
||||
include_expired: bool = False,
|
||||
) -> list[Checkpoint]:
|
||||
"""
|
||||
List checkpoints.
|
||||
|
||||
Args:
|
||||
action_id: Optional filter by action ID
|
||||
include_expired: Include expired checkpoints
|
||||
|
||||
Returns:
|
||||
List of checkpoints
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
async with self._lock:
|
||||
checkpoints = list(self._checkpoints.values())
|
||||
|
||||
if action_id:
|
||||
checkpoints = [c for c in checkpoints if c.action_id == action_id]
|
||||
|
||||
if not include_expired:
|
||||
checkpoints = [
|
||||
c for c in checkpoints
|
||||
if c.expires_at is None or c.expires_at > now
|
||||
]
|
||||
|
||||
return checkpoints
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Clean up expired checkpoints.
|
||||
|
||||
Returns:
|
||||
Number of checkpoints cleaned up
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
to_remove: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
for checkpoint_id, checkpoint in self._checkpoints.items():
|
||||
if checkpoint.expires_at and checkpoint.expires_at < now:
|
||||
to_remove.append(checkpoint_id)
|
||||
|
||||
for checkpoint_id in to_remove:
|
||||
del self._checkpoints[checkpoint_id]
|
||||
if checkpoint_id in self._file_checkpoints:
|
||||
del self._file_checkpoints[checkpoint_id]
|
||||
|
||||
if to_remove:
|
||||
logger.info("Cleaned up %d expired checkpoints", len(to_remove))
|
||||
|
||||
return len(to_remove)
|
||||
|
||||
async def _rollback_file(self, fc: FileCheckpoint) -> None:
|
||||
"""Rollback a single file to its checkpoint state."""
|
||||
path = Path(fc.file_path)
|
||||
|
||||
if fc.existed:
|
||||
# Restore original content
|
||||
if fc.original_content is not None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(fc.original_content)
|
||||
logger.debug("Restored file: %s", fc.file_path)
|
||||
else:
|
||||
# File didn't exist before - delete it
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
|
||||
|
||||
|
||||
class TransactionContext:
|
||||
"""
|
||||
Context manager for transactional action execution.
|
||||
|
||||
Usage:
|
||||
async with TransactionContext(rollback_manager, action) as tx:
|
||||
tx.checkpoint_file("/path/to/file")
|
||||
# Do work...
|
||||
# If exception occurs, automatic rollback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: RollbackManager,
|
||||
action: ActionRequest,
|
||||
auto_rollback: bool = True,
|
||||
) -> None:
|
||||
self._manager = manager
|
||||
self._action = action
|
||||
self._auto_rollback = auto_rollback
|
||||
self._checkpoint: Checkpoint | None = None
|
||||
self._committed = False
|
||||
|
||||
async def __aenter__(self) -> "TransactionContext":
|
||||
self._checkpoint = await self._manager.create_checkpoint(self._action)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type | None,
|
||||
exc_val: Exception | None,
|
||||
exc_tb: Any,
|
||||
) -> bool:
|
||||
if exc_val is not None and self._auto_rollback and not self._committed:
|
||||
# Exception occurred - rollback
|
||||
if self._checkpoint:
|
||||
try:
|
||||
await self._manager.rollback(self._checkpoint.id)
|
||||
logger.info(
|
||||
"Auto-rollback completed for action %s",
|
||||
self._action.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Auto-rollback failed: %s", e)
|
||||
elif self._committed and self._checkpoint:
|
||||
# Committed - discard checkpoint
|
||||
await self._manager.discard_checkpoint(self._checkpoint.id)
|
||||
|
||||
return False # Don't suppress the exception
|
||||
|
||||
@property
|
||||
def checkpoint_id(self) -> str | None:
|
||||
"""Get the checkpoint ID."""
|
||||
return self._checkpoint.id if self._checkpoint else None
|
||||
|
||||
async def checkpoint_file(self, file_path: str) -> None:
|
||||
"""Checkpoint a file for this transaction."""
|
||||
if self._checkpoint:
|
||||
await self._manager.checkpoint_file(self._checkpoint.id, file_path)
|
||||
|
||||
async def checkpoint_files(self, file_paths: list[str]) -> None:
|
||||
"""Checkpoint multiple files for this transaction."""
|
||||
if self._checkpoint:
|
||||
await self._manager.checkpoint_files(self._checkpoint.id, file_paths)
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Mark transaction as committed (no rollback on exit)."""
|
||||
self._committed = True
|
||||
|
||||
async def rollback(self) -> RollbackResult | None:
|
||||
"""Manually trigger rollback."""
|
||||
if self._checkpoint:
|
||||
return await self._manager.rollback(self._checkpoint.id)
|
||||
return None
|
||||
Reference in New Issue
Block a user