Files
syndarix/backend/app/services/safety/rollback/manager.py
Felipe Cardoso 520c06175e refactor(safety): apply consistent formatting across services and tests
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
2026-01-03 16:23:39 +01:00

418 lines
12 KiB
Python

"""
Rollback Manager
Manages checkpoints and rollback operations for agent actions.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..exceptions import RollbackError
from ..models import (
ActionRequest,
Checkpoint,
CheckpointType,
RollbackResult,
)
# Module-level logger named after this module (standard logging.getLogger(__name__)
# pattern), so its output can be configured/filtered per module.
logger = logging.getLogger(__name__)
class FileCheckpoint:
"""Stores file state for rollback."""
def __init__(
self,
checkpoint_id: str,
file_path: str,
original_content: bytes | None,
existed: bool,
) -> None:
self.checkpoint_id = checkpoint_id
self.file_path = file_path
self.original_content = original_content
self.existed = existed
self.created_at = datetime.utcnow()
class RollbackManager:
    """
    Manages checkpoints and rollback operations.

    Features:
    - File system checkpoints
    - Transaction wrapping for actions
    - Automatic checkpoint for destructive actions
    - Rollback triggers on failure
    - Checkpoint expiration and cleanup

    Checkpoint metadata and file snapshots are held in memory on this
    instance. ``checkpoint_dir`` is created on construction, but nothing in
    this class writes to it.
    """

    def __init__(
        self,
        checkpoint_dir: str | None = None,
        retention_hours: int | None = None,
    ) -> None:
        """
        Initialize the RollbackManager.

        Args:
            checkpoint_dir: Directory for storing checkpoint data. Falls back
                to the safety config's ``checkpoint_dir`` when not given.
            retention_hours: Hours to retain checkpoints. Falls back to the
                safety config's ``checkpoint_retention_hours`` only when
                None; an explicit 0 is honoured.
        """
        config = get_safety_config()
        self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
        # `is not None` (not `or`) so an explicit 0 is not replaced by config.
        self._retention_hours = (
            retention_hours
            if retention_hours is not None
            else config.checkpoint_retention_hours
        )
        self._checkpoints: dict[str, Checkpoint] = {}
        # File snapshots per checkpoint ID, in the order they were taken.
        self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
        # Serializes access to the two dicts above across coroutines.
        self._lock = asyncio.Lock()

        # Ensure checkpoint directory exists
        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

    async def create_checkpoint(
        self,
        action: ActionRequest,
        checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
        description: str | None = None,
    ) -> Checkpoint:
        """
        Create a checkpoint before an action.

        Args:
            action: The action to checkpoint for
            checkpoint_type: Type of checkpoint
            description: Optional description (defaults to
                "Checkpoint for <tool_name>")

        Returns:
            The created checkpoint; it is also registered with an empty
            list of file snapshots.
        """
        # One timestamp for both fields so expires_at - created_at is exactly
        # the retention period.
        now = datetime.utcnow()
        checkpoint_id = str(uuid4())
        checkpoint = Checkpoint(
            id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            action_id=action.id,
            created_at=now,
            expires_at=now + timedelta(hours=self._retention_hours),
            data={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
            },
            description=description or f"Checkpoint for {action.tool_name}",
        )

        async with self._lock:
            self._checkpoints[checkpoint_id] = checkpoint
            self._file_checkpoints[checkpoint_id] = []

        logger.info(
            "Created checkpoint %s for action %s",
            checkpoint_id,
            action.id,
        )

        return checkpoint

    @staticmethod
    def _read_if_exists(path: Path) -> bytes | None:
        """Return the bytes of *path*, or None when it does not exist."""
        return path.read_bytes() if path.exists() else None

    async def checkpoint_file(
        self,
        checkpoint_id: str,
        file_path: str,
    ) -> None:
        """
        Store current state of a file for checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint
            file_path: Path to the file

        Raises:
            OSError: If the path exists but cannot be read.
        """
        # Blocking disk read runs in a worker thread to keep the loop responsive.
        content = await asyncio.to_thread(self._read_if_exists, Path(file_path))
        existed = content is not None

        file_checkpoint = FileCheckpoint(
            checkpoint_id=checkpoint_id,
            file_path=file_path,
            original_content=content,
            existed=existed,
        )

        async with self._lock:
            self._file_checkpoints.setdefault(checkpoint_id, []).append(
                file_checkpoint
            )

        logger.debug(
            "Stored file state for checkpoint %s: %s (existed=%s)",
            checkpoint_id,
            file_path,
            existed,
        )

    async def checkpoint_files(
        self,
        checkpoint_id: str,
        file_paths: list[str],
    ) -> None:
        """
        Store current state of multiple files.

        Args:
            checkpoint_id: ID of the checkpoint
            file_paths: Paths to the files
        """
        for path in file_paths:
            await self.checkpoint_file(checkpoint_id, path)

    async def rollback(
        self,
        checkpoint_id: str,
    ) -> RollbackResult:
        """
        Rollback to a checkpoint.

        File restores are best-effort: a failure on one file is logged and
        recorded in ``failed_actions`` while the remaining files are still
        processed. The checkpoint is invalidated either way.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            Result of the rollback operation

        Raises:
            RollbackError: If the checkpoint is unknown or no longer valid.
        """
        async with self._lock:
            checkpoint = self._checkpoints.get(checkpoint_id)

            if checkpoint is None:
                raise RollbackError(
                    f"Checkpoint not found: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            if not checkpoint.is_valid:
                raise RollbackError(
                    f"Checkpoint is no longer valid: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            # Claim the checkpoint while holding the lock so a concurrent
            # rollback of the same ID fails the validity check above instead
            # of replaying the same restores.
            checkpoint.is_valid = False
            # Copy: checkpoint_file() may append to the live list while the
            # restores below are awaited.
            file_checkpoints = list(self._file_checkpoints.get(checkpoint_id, []))

        actions_rolled_back: list[str] = []
        failed_actions: list[str] = []

        # Restore newest snapshot first. If a file was checkpointed more than
        # once, its oldest snapshot (the true original state) is applied last.
        for fc in reversed(file_checkpoints):
            try:
                await self._rollback_file(fc)
            except Exception as e:
                logger.exception("Failed to rollback file %s: %s", fc.file_path, e)
                failed_actions.append(f"file:{fc.file_path}")
            else:
                actions_rolled_back.append(f"file:{fc.file_path}")

        success = not failed_actions

        result = RollbackResult(
            checkpoint_id=checkpoint_id,
            success=success,
            actions_rolled_back=actions_rolled_back,
            failed_actions=failed_actions,
            error=None
            if success
            else f"Failed to rollback {len(failed_actions)} items",
        )

        if success:
            logger.info("Rollback successful for checkpoint %s", checkpoint_id)
        else:
            logger.error(
                "Rollback partially failed for checkpoint %s: %d failures",
                checkpoint_id,
                len(failed_actions),
            )

        return result

    async def discard_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Discard a checkpoint without rolling back.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            True if checkpoint was found and discarded
        """
        async with self._lock:
            found = self._checkpoints.pop(checkpoint_id, None) is not None
            # Always drop snapshots, including ones checkpoint_file() stored
            # under an ID with no checkpoint, so they do not leak.
            self._file_checkpoints.pop(checkpoint_id, None)

        if found:
            logger.debug("Discarded checkpoint %s", checkpoint_id)
        return found

    async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
        """Get a checkpoint by ID."""
        async with self._lock:
            return self._checkpoints.get(checkpoint_id)

    async def list_checkpoints(
        self,
        action_id: str | None = None,
        include_expired: bool = False,
    ) -> list[Checkpoint]:
        """
        List checkpoints.

        Args:
            action_id: Optional filter by action ID
            include_expired: Include expired checkpoints

        Returns:
            List of checkpoints
        """
        now = datetime.utcnow()

        async with self._lock:
            checkpoints = list(self._checkpoints.values())

        if action_id:
            checkpoints = [c for c in checkpoints if c.action_id == action_id]

        if not include_expired:
            checkpoints = [
                c for c in checkpoints if c.expires_at is None or c.expires_at > now
            ]

        return checkpoints

    async def cleanup_expired(self) -> int:
        """
        Clean up expired checkpoints.

        A checkpoint counts as expired when ``expires_at <= now``, matching
        the filter in ``list_checkpoints``. File snapshots whose checkpoint
        no longer exists are dropped as well; they are not counted.

        Returns:
            Number of checkpoints cleaned up
        """
        now = datetime.utcnow()

        async with self._lock:
            expired = [
                cid
                for cid, checkpoint in self._checkpoints.items()
                if checkpoint.expires_at is not None and checkpoint.expires_at <= now
            ]
            for cid in expired:
                del self._checkpoints[cid]
                self._file_checkpoints.pop(cid, None)

            # Snapshots without a checkpoint can never be rolled back.
            orphaned = [
                cid for cid in self._file_checkpoints if cid not in self._checkpoints
            ]
            for cid in orphaned:
                del self._file_checkpoints[cid]

        if expired:
            logger.info("Cleaned up %d expired checkpoints", len(expired))

        return len(expired)

    @staticmethod
    def _restore_file_sync(fc: FileCheckpoint) -> None:
        """Blocking restore of one file; runs in a worker thread."""
        path = Path(fc.file_path)

        if fc.existed:
            # Restore original content
            if fc.original_content is not None:
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(fc.original_content)
                logger.debug("Restored file: %s", fc.file_path)
        elif path.exists():
            # File didn't exist before - delete it
            path.unlink()
            logger.debug("Deleted file (didn't exist before): %s", fc.file_path)

    async def _rollback_file(self, fc: FileCheckpoint) -> None:
        """Rollback a single file to its checkpoint state."""
        # Disk writes/unlinks run off the event loop; exceptions propagate.
        await asyncio.to_thread(self._restore_file_sync, fc)
class TransactionContext:
    """
    Context manager for transactional action execution.

    A checkpoint is created on entry. On exit:
    - if an exception escaped the body, ``auto_rollback`` is enabled and the
      transaction was not committed, the checkpoint is rolled back
      (best-effort: rollback errors are logged, never raised);
    - if the transaction was committed, the checkpoint is discarded;
    - otherwise the checkpoint is left in the manager.
    Exceptions raised by the body are never suppressed.

    Usage:
        async with TransactionContext(rollback_manager, action) as tx:
            await tx.checkpoint_file("/path/to/file")
            # Do work...
            tx.commit()
        # If an exception occurs before commit(), automatic rollback
    """

    def __init__(
        self,
        manager: RollbackManager,
        action: ActionRequest,
        auto_rollback: bool = True,
    ) -> None:
        self._manager = manager
        self._action = action
        self._auto_rollback = auto_rollback
        # Set in __aenter__; None until the context has been entered.
        self._checkpoint: Checkpoint | None = None
        self._committed = False

    async def __aenter__(self) -> "TransactionContext":
        self._checkpoint = await self._manager.create_checkpoint(self._action)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        if exc_val is not None and self._auto_rollback and not self._committed:
            # Exception occurred - rollback
            if self._checkpoint is not None:
                try:
                    result = await self._manager.rollback(self._checkpoint.id)
                except Exception as e:
                    # Best effort: never mask the body's original exception.
                    logger.error("Auto-rollback failed: %s", e)
                else:
                    if result.success:
                        logger.info(
                            "Auto-rollback completed for action %s",
                            self._action.id,
                        )
                    else:
                        # Some files could not be restored; do not report success.
                        logger.warning(
                            "Auto-rollback incomplete for action %s: %s",
                            self._action.id,
                            result.error,
                        )
        elif self._committed and self._checkpoint is not None:
            # Committed - discard checkpoint
            await self._manager.discard_checkpoint(self._checkpoint.id)

        return False  # Don't suppress the exception

    @property
    def checkpoint_id(self) -> str | None:
        """Get the checkpoint ID (None before the context is entered)."""
        return self._checkpoint.id if self._checkpoint is not None else None

    async def checkpoint_file(self, file_path: str) -> None:
        """Checkpoint a file for this transaction (no-op before entry)."""
        if self._checkpoint is not None:
            await self._manager.checkpoint_file(self._checkpoint.id, file_path)

    async def checkpoint_files(self, file_paths: list[str]) -> None:
        """Checkpoint multiple files for this transaction (no-op before entry)."""
        if self._checkpoint is not None:
            await self._manager.checkpoint_files(self._checkpoint.id, file_paths)

    def commit(self) -> None:
        """Mark transaction as committed (no rollback on exit)."""
        self._committed = True

    async def rollback(self) -> RollbackResult | None:
        """Manually trigger rollback; returns None before the context is entered."""
        if self._checkpoint is not None:
            return await self._manager.rollback(self._checkpoint.id)
        return None