"""
Rollback Manager

Manages checkpoints and rollback operations for agent actions.
"""

import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from uuid import uuid4

from ..config import get_safety_config
from ..exceptions import RollbackError
from ..models import (
    ActionRequest,
    Checkpoint,
    CheckpointType,
    RollbackResult,
)

logger = logging.getLogger(__name__)


class FileCheckpoint:
    """
    Snapshot of a single file's state, recorded so it can be restored later.

    Attributes:
        checkpoint_id: ID of the checkpoint this snapshot belongs to.
        file_path: Path of the snapshotted file, exactly as the caller gave it.
        original_content: File bytes at snapshot time, or None when the file
            did not exist.
        existed: Whether the file existed at snapshot time.
        created_at: Naive UTC timestamp (``datetime.utcnow()``) of the snapshot.
    """

    def __init__(
        self,
        checkpoint_id: str,
        file_path: str,
        original_content: bytes | None,
        existed: bool,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.file_path = file_path
        self.original_content = original_content
        self.existed = existed
        self.created_at = datetime.utcnow()


class RollbackManager:
    """
    Manages checkpoints and rollback operations.

    Features:
    - File system checkpoints
    - Transaction wrapping for actions
    - Automatic checkpoint for destructive actions
    - Rollback triggers on failure
    - Checkpoint expiration and cleanup

    Checkpoint metadata and file snapshots are held in memory. The checkpoint
    directory is created on construction but is not written to by this class.
    All mutations of the in-memory state happen under an ``asyncio.Lock``.
    """

    def __init__(
        self,
        checkpoint_dir: str | None = None,
        retention_hours: int | None = None,
    ) -> None:
        """
        Initialize the RollbackManager.

        Args:
            checkpoint_dir: Directory for storing checkpoint data. Falls back
                to the safety config's ``checkpoint_dir`` when not given.
            retention_hours: Hours to retain checkpoints. Falls back to the
                safety config's ``checkpoint_retention_hours`` when None.
        """
        config = get_safety_config()
        self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
        # Explicit None check so an explicit 0 is honored instead of being
        # silently replaced by the configured default.
        self._retention_hours = (
            retention_hours
            if retention_hours is not None
            else config.checkpoint_retention_hours
        )
        self._checkpoints: dict[str, Checkpoint] = {}
        # File snapshots per checkpoint ID, in the order they were recorded.
        self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
        self._lock = asyncio.Lock()

        # Ensure checkpoint directory exists
        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

    async def create_checkpoint(
        self,
        action: ActionRequest,
        checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
        description: str | None = None,
    ) -> Checkpoint:
        """
        Create a checkpoint before an action.

        Args:
            action: The action to checkpoint for
            checkpoint_type: Type of checkpoint
            description: Optional description; defaults to
                "Checkpoint for <tool_name>".

        Returns:
            The created checkpoint, which expires ``retention_hours`` after
            its creation time.
        """
        checkpoint_id = str(uuid4())
        # Single timestamp so expires_at is exactly created_at + retention.
        now = datetime.utcnow()

        checkpoint = Checkpoint(
            id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            action_id=action.id,
            created_at=now,
            expires_at=now + timedelta(hours=self._retention_hours),
            data={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
            },
            description=description or f"Checkpoint for {action.tool_name}",
        )

        async with self._lock:
            self._checkpoints[checkpoint_id] = checkpoint
            self._file_checkpoints[checkpoint_id] = []

        logger.info(
            "Created checkpoint %s for action %s",
            checkpoint_id,
            action.id,
        )

        return checkpoint

    async def checkpoint_file(
        self,
        checkpoint_id: str,
        file_path: str,
    ) -> None:
        """
        Store current state of a file for checkpoint.

        The file's bytes are read into memory, or its absence is recorded
        when it does not exist. A file may be snapshotted several times under
        the same checkpoint. Rollback restores the earliest snapshot.

        Args:
            checkpoint_id: ID of the checkpoint
            file_path: Path to the file
        """
        path = Path(file_path)

        # NOTE: synchronous read; large files will block the event loop.
        if path.exists():
            content = path.read_bytes()
            existed = True
        else:
            content = None
            existed = False

        file_checkpoint = FileCheckpoint(
            checkpoint_id=checkpoint_id,
            file_path=file_path,
            original_content=content,
            existed=existed,
        )

        async with self._lock:
            if checkpoint_id not in self._checkpoints:
                # Still recorded (preserves existing behavior), but the
                # snapshot can only be removed via discard_checkpoint().
                logger.warning(
                    "Storing file state for unknown checkpoint %s: %s",
                    checkpoint_id,
                    file_path,
                )
            self._file_checkpoints.setdefault(checkpoint_id, []).append(
                file_checkpoint
            )

        logger.debug(
            "Stored file state for checkpoint %s: %s (existed=%s)",
            checkpoint_id,
            file_path,
            existed,
        )

    async def checkpoint_files(
        self,
        checkpoint_id: str,
        file_paths: list[str],
    ) -> None:
        """
        Store current state of multiple files, one after another.

        Args:
            checkpoint_id: ID of the checkpoint
            file_paths: Paths to the files
        """
        for path in file_paths:
            await self.checkpoint_file(checkpoint_id, path)

    async def rollback(
        self,
        checkpoint_id: str,
    ) -> RollbackResult:
        """
        Rollback to a checkpoint.

        The checkpoint is marked invalid before any file is touched, so it
        can be rolled back at most once, even under concurrent calls.
        Per-file failures are logged and reported in the result rather than
        raised.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            Result of the rollback operation

        Raises:
            RollbackError: If the checkpoint is unknown or no longer valid.
        """
        async with self._lock:
            checkpoint = self._checkpoints.get(checkpoint_id)
            if checkpoint is None:
                raise RollbackError(
                    f"Checkpoint not found: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            if not checkpoint.is_valid:
                raise RollbackError(
                    f"Checkpoint is no longer valid: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            # Claim the checkpoint while still holding the lock, so a
            # concurrent rollback of the same ID fails the validity check
            # instead of restoring files a second time.
            checkpoint.is_valid = False

            # Copy so concurrent checkpoint_file() calls cannot mutate the
            # list we iterate below.
            file_checkpoints = list(self._file_checkpoints.get(checkpoint_id, []))

        actions_rolled_back: list[str] = []
        failed_actions: list[str] = []

        # Restore newest-first. When a file was snapshotted more than once,
        # its earliest snapshot is applied last and therefore wins, which
        # returns the file to its state before the first snapshot.
        for fc in reversed(file_checkpoints):
            try:
                await self._rollback_file(fc)
                actions_rolled_back.append(f"file:{fc.file_path}")
            except Exception as e:
                logger.error("Failed to rollback file %s: %s", fc.file_path, e)
                failed_actions.append(f"file:{fc.file_path}")

        success = not failed_actions

        result = RollbackResult(
            checkpoint_id=checkpoint_id,
            success=success,
            actions_rolled_back=actions_rolled_back,
            failed_actions=failed_actions,
            error=None if success else f"Failed to rollback {len(failed_actions)} items",
        )

        if success:
            logger.info("Rollback successful for checkpoint %s", checkpoint_id)
        else:
            logger.error(
                "Rollback partially failed for checkpoint %s: %d failures",
                checkpoint_id,
                len(failed_actions),
            )

        return result

    async def discard_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Discard a checkpoint without rolling back.

        Any file snapshots stored under ``checkpoint_id`` are dropped as well.
        This includes snapshots recorded against an ID that has no checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            True if checkpoint was found and discarded
        """
        async with self._lock:
            found = self._checkpoints.pop(checkpoint_id, None) is not None
            self._file_checkpoints.pop(checkpoint_id, None)

        if found:
            logger.debug("Discarded checkpoint %s", checkpoint_id)
        return found

    async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
        """Get a checkpoint by ID, or None if it is unknown."""
        async with self._lock:
            return self._checkpoints.get(checkpoint_id)

    async def list_checkpoints(
        self,
        action_id: str | None = None,
        include_expired: bool = False,
    ) -> list[Checkpoint]:
        """
        List checkpoints.

        Args:
            action_id: Optional filter by action ID (None means no filter)
            include_expired: Include expired checkpoints. A checkpoint counts
                as expired once ``expires_at <= now``.

        Returns:
            List of checkpoints
        """
        now = datetime.utcnow()

        async with self._lock:
            checkpoints = list(self._checkpoints.values())

        if action_id is not None:
            checkpoints = [c for c in checkpoints if c.action_id == action_id]

        if not include_expired:
            checkpoints = [
                c for c in checkpoints if c.expires_at is None or c.expires_at > now
            ]

        return checkpoints

    async def cleanup_expired(self) -> int:
        """
        Clean up expired checkpoints and their file snapshots.

        Uses the same expiry rule as ``list_checkpoints``: a checkpoint
        expires once ``expires_at <= now``.

        Returns:
            Number of checkpoints cleaned up
        """
        now = datetime.utcnow()

        async with self._lock:
            to_remove = [
                checkpoint_id
                for checkpoint_id, checkpoint in self._checkpoints.items()
                if checkpoint.expires_at is not None and checkpoint.expires_at <= now
            ]

            for checkpoint_id in to_remove:
                del self._checkpoints[checkpoint_id]
                self._file_checkpoints.pop(checkpoint_id, None)

        if to_remove:
            logger.info("Cleaned up %d expired checkpoints", len(to_remove))

        return len(to_remove)

    async def _rollback_file(self, fc: FileCheckpoint) -> None:
        """
        Rollback a single file to its checkpoint state.

        If the file existed, its bytes are rewritten and missing parent
        directories are recreated. If it did not exist, any file now at that
        path is deleted.
        """
        path = Path(fc.file_path)

        if fc.existed:
            # Restore original content
            if fc.original_content is not None:
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(fc.original_content)
                logger.debug("Restored file: %s", fc.file_path)
        else:
            # File didn't exist before - delete it
            if path.exists():
                path.unlink()
                logger.debug("Deleted file (didn't exist before): %s", fc.file_path)


class TransactionContext:
    """
    Context manager for transactional action execution.

    Usage:
        async with TransactionContext(rollback_manager, action) as tx:
            await tx.checkpoint_file("/path/to/file")
            # Do work...
            tx.commit()
        # If an exception escapes before commit(), the checkpoint is rolled
        # back automatically (when auto_rollback is True).

    On exit:
    - exception raised, not committed, auto_rollback on: rollback is attempted;
      failures are logged, never raised.
    - committed: the checkpoint is discarded.
    - otherwise: the checkpoint is left in the manager (e.g. for a manual
      rollback) until it is discarded or cleaned up after expiry.
    The original exception is never suppressed.
    """

    def __init__(
        self,
        manager: RollbackManager,
        action: ActionRequest,
        auto_rollback: bool = True,
    ) -> None:
        self._manager = manager
        self._action = action
        self._auto_rollback = auto_rollback
        self._checkpoint: Checkpoint | None = None
        self._committed = False

    async def __aenter__(self) -> "TransactionContext":
        self._checkpoint = await self._manager.create_checkpoint(self._action)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        if exc_val is not None and self._auto_rollback and not self._committed:
            # Exception occurred - rollback
            if self._checkpoint is not None:
                try:
                    result = await self._manager.rollback(self._checkpoint.id)
                except Exception as e:
                    logger.error("Auto-rollback failed: %s", e)
                else:
                    # rollback() reports per-file failures in the result
                    # instead of raising, so check before claiming success.
                    if result.success:
                        logger.info(
                            "Auto-rollback completed for action %s",
                            self._action.id,
                        )
                    else:
                        logger.error(
                            "Auto-rollback incomplete for action %s: %s",
                            self._action.id,
                            result.error,
                        )
        elif self._committed and self._checkpoint is not None:
            # Committed - discard checkpoint
            await self._manager.discard_checkpoint(self._checkpoint.id)

        return False  # Don't suppress the exception

    @property
    def checkpoint_id(self) -> str | None:
        """Get the checkpoint ID (None before the context is entered)."""
        return self._checkpoint.id if self._checkpoint else None

    async def checkpoint_file(self, file_path: str) -> None:
        """Checkpoint a file for this transaction (no-op before entering)."""
        if self._checkpoint:
            await self._manager.checkpoint_file(self._checkpoint.id, file_path)

    async def checkpoint_files(self, file_paths: list[str]) -> None:
        """Checkpoint multiple files for this transaction (no-op before entering)."""
        if self._checkpoint:
            await self._manager.checkpoint_files(self._checkpoint.id, file_paths)

    def commit(self) -> None:
        """Mark transaction as committed (no rollback on exit)."""
        self._committed = True

    async def rollback(self) -> RollbackResult | None:
        """Manually trigger rollback; returns None before the context is entered."""
        if self._checkpoint:
            return await self._manager.rollback(self._checkpoint.id)
        return None