forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
418 lines
12 KiB
Python
418 lines
12 KiB
Python
"""
Rollback Manager

Manages checkpoints and rollback operations for agent actions.

Contents:
- FileCheckpoint: snapshot of one file's bytes, or of the fact that the file
  did not exist, taken so that it can be restored later.
- RollbackManager: creates checkpoints for actions, attaches file snapshots
  to them, and restores, discards, lists or expires them.
- TransactionContext: async context manager that opens a checkpoint on entry
  and, when enabled, rolls it back automatically if the wrapped block raises
  before commit().

Checkpoints and file snapshots are kept in process memory.
"""

import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from uuid import uuid4

from ..config import get_safety_config
from ..exceptions import RollbackError
from ..models import (
    ActionRequest,
    Checkpoint,
    CheckpointType,
    RollbackResult,
)

# Module-level logger shared by all classes in this file.
logger = logging.getLogger(__name__)


class FileCheckpoint:
|
|
"""Stores file state for rollback."""
|
|
|
|
def __init__(
|
|
self,
|
|
checkpoint_id: str,
|
|
file_path: str,
|
|
original_content: bytes | None,
|
|
existed: bool,
|
|
) -> None:
|
|
self.checkpoint_id = checkpoint_id
|
|
self.file_path = file_path
|
|
self.original_content = original_content
|
|
self.existed = existed
|
|
self.created_at = datetime.utcnow()
|
|
|
|
|
|
class RollbackManager:
    """
    Manages checkpoints and rollback operations.

    A checkpoint is created for an action with ``create_checkpoint``. File
    snapshots are attached to it with ``checkpoint_file`` /
    ``checkpoint_files`` and can later be restored with ``rollback`` or
    dropped with ``discard_checkpoint``.

    Features:
    - File system checkpoints (a file's bytes, or the fact it did not exist)
    - Single-use rollback: a checkpoint is invalidated once rolled back
    - Checkpoint expiration and cleanup

    Transaction wrapping with automatic rollback on failure is provided by
    ``TransactionContext``, which drives this manager.

    NOTE: checkpoints and file snapshots live in memory only and are lost on
    process restart. ``checkpoint_dir`` is created on construction but is not
    written to by this class.
    """

    def __init__(
        self,
        checkpoint_dir: str | None = None,
        retention_hours: int | None = None,
    ) -> None:
        """
        Initialize the RollbackManager.

        Args:
            checkpoint_dir: Directory for checkpoint data. Defaults to
                ``config.checkpoint_dir``. Created (with parents) if missing.
            retention_hours: Hours from creation until a checkpoint's
                ``expires_at``. Defaults to
                ``config.checkpoint_retention_hours``. An explicit ``0`` is
                honoured rather than replaced by the config default.
        """
        config = get_safety_config()

        self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
        # `is not None` rather than `or`: a caller-supplied 0 is a valid
        # (immediately expiring) retention and must not fall back to config.
        self._retention_hours = (
            retention_hours
            if retention_hours is not None
            else config.checkpoint_retention_hours
        )

        # checkpoint id -> checkpoint metadata
        self._checkpoints: dict[str, Checkpoint] = {}
        # checkpoint id -> file snapshots, in the order they were taken
        self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
        # Guards both dictionaries above.
        self._lock = asyncio.Lock()

        # Ensure checkpoint directory exists
        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

    async def create_checkpoint(
        self,
        action: ActionRequest,
        checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
        description: str | None = None,
    ) -> Checkpoint:
        """
        Create a checkpoint before an action.

        The checkpoint records the action's type, tool name and resource. It
        expires ``retention_hours`` after creation and starts with no file
        snapshots attached.

        Args:
            action: The action to checkpoint for
            checkpoint_type: Type of checkpoint
            description: Optional description. Defaults to
                "Checkpoint for <tool_name>".

        Returns:
            The created checkpoint
        """
        checkpoint_id = str(uuid4())
        # A single timestamp keeps expires_at - created_at exactly equal to
        # the retention window.
        now = datetime.utcnow()

        checkpoint = Checkpoint(
            id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            action_id=action.id,
            created_at=now,
            expires_at=now + timedelta(hours=self._retention_hours),
            data={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
            },
            description=description or f"Checkpoint for {action.tool_name}",
        )

        async with self._lock:
            self._checkpoints[checkpoint_id] = checkpoint
            self._file_checkpoints[checkpoint_id] = []

        logger.info(
            "Created checkpoint %s for action %s",
            checkpoint_id,
            action.id,
        )

        return checkpoint

    async def checkpoint_file(
        self,
        checkpoint_id: str,
        file_path: str,
    ) -> None:
        """
        Store current state of a file for checkpoint.

        Captures the file's full contents if it exists, or records that it is
        absent so that rollback deletes it. Snapshots are appended, so the same
        file may be recorded several times under one checkpoint; rollback
        restores the earliest one. An unknown ``checkpoint_id`` is not
        rejected: the snapshot is stored under it anyway.

        NOTE: the file is read synchronously with ``Path.read_bytes`` inside
        this coroutine, which blocks the event loop for the duration.

        Args:
            checkpoint_id: ID of the checkpoint
            file_path: Path to the file
        """
        path = Path(file_path)

        if path.exists():
            content = path.read_bytes()
            existed = True
        else:
            content = None
            existed = False

        file_checkpoint = FileCheckpoint(
            checkpoint_id=checkpoint_id,
            file_path=file_path,
            original_content=content,
            existed=existed,
        )

        async with self._lock:
            self._file_checkpoints.setdefault(checkpoint_id, []).append(
                file_checkpoint
            )

        logger.debug(
            "Stored file state for checkpoint %s: %s (existed=%s)",
            checkpoint_id,
            file_path,
            existed,
        )

    async def checkpoint_files(
        self,
        checkpoint_id: str,
        file_paths: list[str],
    ) -> None:
        """
        Store current state of multiple files, one after another.

        Args:
            checkpoint_id: ID of the checkpoint
            file_paths: Paths to the files
        """
        for path in file_paths:
            await self.checkpoint_file(checkpoint_id, path)

    async def rollback(
        self,
        checkpoint_id: str,
    ) -> RollbackResult:
        """
        Rollback to a checkpoint.

        Restores every file snapshotted under the checkpoint. When the same
        path was snapshotted more than once, only its earliest snapshot (the
        state before the checkpoint's changes) is restored. Files are restored
        independently: a failure is recorded and the remaining files are
        still processed.

        A checkpoint can be rolled back at most once. It is marked invalid
        before any file is touched, so repeated or concurrent calls raise
        instead of replaying the restore.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            Result of the rollback operation. ``success`` is False if any
            file could not be restored.

        Raises:
            RollbackError: If the checkpoint does not exist or is no longer
                valid.
        """
        async with self._lock:
            checkpoint = self._checkpoints.get(checkpoint_id)
            if not checkpoint:
                raise RollbackError(
                    f"Checkpoint not found: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            if not checkpoint.is_valid:
                raise RollbackError(
                    f"Checkpoint is no longer valid: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            # Claim the checkpoint while still holding the lock, so two
            # concurrent callers cannot both pass the validity check above.
            checkpoint.is_valid = False
            # Copy, so snapshots appended concurrently don't affect this pass.
            file_checkpoints = list(self._file_checkpoints.get(checkpoint_id, []))

        actions_rolled_back: list[str] = []
        failed_actions: list[str] = []

        # Rollback file changes
        for fc in self._earliest_snapshots(file_checkpoints):
            try:
                await self._rollback_file(fc)
                actions_rolled_back.append(f"file:{fc.file_path}")
            except Exception as e:
                # Best effort: keep restoring the remaining files.
                logger.error("Failed to rollback file %s: %s", fc.file_path, e)
                failed_actions.append(f"file:{fc.file_path}")

        success = not failed_actions

        result = RollbackResult(
            checkpoint_id=checkpoint_id,
            success=success,
            actions_rolled_back=actions_rolled_back,
            failed_actions=failed_actions,
            error=None
            if success
            else f"Failed to rollback {len(failed_actions)} items",
        )

        if success:
            logger.info("Rollback successful for checkpoint %s", checkpoint_id)
        else:
            logger.error(
                "Rollback partially failed for checkpoint %s: %d failures",
                checkpoint_id,
                len(failed_actions),
            )

        return result

    async def discard_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Discard a checkpoint without rolling back.

        Any file snapshots stored under ``checkpoint_id`` are dropped as well,
        even when no checkpoint record exists for it.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            True if checkpoint was found and discarded
        """
        async with self._lock:
            # checkpoint_file accepts unknown ids, so snapshot lists can exist
            # without a checkpoint record; never leave them behind.
            self._file_checkpoints.pop(checkpoint_id, None)
            if checkpoint_id not in self._checkpoints:
                return False
            del self._checkpoints[checkpoint_id]
            logger.debug("Discarded checkpoint %s", checkpoint_id)
            return True

    async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
        """Get a checkpoint by ID, or None if it is unknown."""
        async with self._lock:
            return self._checkpoints.get(checkpoint_id)

    async def list_checkpoints(
        self,
        action_id: str | None = None,
        include_expired: bool = False,
    ) -> list[Checkpoint]:
        """
        List checkpoints.

        Args:
            action_id: Optional filter by action ID
            include_expired: Include expired checkpoints (those whose
                ``expires_at`` is at or before the current UTC time)

        Returns:
            List of checkpoints
        """
        now = datetime.utcnow()

        async with self._lock:
            checkpoints = list(self._checkpoints.values())

        if action_id:
            checkpoints = [c for c in checkpoints if c.action_id == action_id]

        if not include_expired:
            checkpoints = [
                c for c in checkpoints if c.expires_at is None or c.expires_at > now
            ]

        return checkpoints

    async def cleanup_expired(self) -> int:
        """
        Clean up expired checkpoints and their file snapshots.

        Uses the same boundary as ``list_checkpoints``: a checkpoint is
        expired once ``expires_at <= now``. Checkpoints without an expiry are
        kept.

        Returns:
            Number of checkpoints cleaned up
        """
        now = datetime.utcnow()

        async with self._lock:
            to_remove = [
                checkpoint_id
                for checkpoint_id, checkpoint in self._checkpoints.items()
                if checkpoint.expires_at is not None and checkpoint.expires_at <= now
            ]

            for checkpoint_id in to_remove:
                del self._checkpoints[checkpoint_id]
                self._file_checkpoints.pop(checkpoint_id, None)

        if to_remove:
            logger.info("Cleaned up %d expired checkpoints", len(to_remove))

        return len(to_remove)

    @staticmethod
    def _earliest_snapshots(
        file_checkpoints: list[FileCheckpoint],
    ) -> list[FileCheckpoint]:
        """Return the first snapshot recorded for each path, in first-seen order."""
        # Path keys collapse trivially different spellings such as "./a" and
        # "a". Symlinks and relative/absolute differences are not resolved.
        earliest: dict[Path, FileCheckpoint] = {}
        for fc in file_checkpoints:
            earliest.setdefault(Path(fc.file_path), fc)
        return list(earliest.values())

    async def _rollback_file(self, fc: FileCheckpoint) -> None:
        """
        Rollback a single file to its checkpoint state.

        If the file existed at snapshot time, its captured bytes are written
        back, recreating missing parent directories. If it did not exist, it
        is deleted when present. Errors propagate to the caller.
        """
        path = Path(fc.file_path)

        if fc.existed:
            # Restore original content (nothing is written if none was captured)
            if fc.original_content is not None:
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(fc.original_content)
                logger.debug("Restored file: %s", fc.file_path)
        else:
            # File didn't exist before - delete it
            if path.exists():
                path.unlink()
                logger.debug("Deleted file (didn't exist before): %s", fc.file_path)


class TransactionContext:
    """
    Context manager for transactional action execution.

    Entering creates a checkpoint for the action via the manager. On exit:

    - committed (``commit()`` was called): the checkpoint is discarded;
    - not committed, the body raised and ``auto_rollback`` is True: the
      checkpoint is rolled back. A failure of that rollback is logged and does
      not replace the body's exception;
    - otherwise: the checkpoint is left in place, e.g. for a manual
      ``rollback()`` or expiry-based cleanup.

    Exceptions raised by the body are never suppressed.

    Usage:
        async with TransactionContext(rollback_manager, action) as tx:
            await tx.checkpoint_file("/path/to/file")
            # Do work...
            tx.commit()
        # If an exception occurs before commit(), rollback is automatic
    """

    def __init__(
        self,
        manager: RollbackManager,
        action: ActionRequest,
        auto_rollback: bool = True,
    ) -> None:
        """
        Args:
            manager: Manager that creates and owns the checkpoint.
            action: Action the checkpoint is created for.
            auto_rollback: Roll back automatically when the body raises before
                ``commit()``.
        """
        self._manager = manager
        self._action = action
        self._auto_rollback = auto_rollback
        # Set by __aenter__; stays None until the context is entered.
        self._checkpoint: Checkpoint | None = None
        self._committed = False

    async def __aenter__(self) -> "TransactionContext":
        """Create the checkpoint for the action and return this context."""
        self._checkpoint = await self._manager.create_checkpoint(self._action)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        """Roll back or discard the checkpoint. Never suppresses exceptions."""
        if exc_val is not None and self._auto_rollback and not self._committed:
            # Exception occurred - rollback
            if self._checkpoint:
                try:
                    # Partial file-restore failures are logged by the manager.
                    await self._manager.rollback(self._checkpoint.id)
                    logger.info(
                        "Auto-rollback completed for action %s",
                        self._action.id,
                    )
                except Exception as e:
                    # Keep the traceback, but swallow the error so the body's
                    # original exception is the one that propagates.
                    logger.exception("Auto-rollback failed: %s", e)
        elif self._committed and self._checkpoint:
            # Committed - discard checkpoint
            await self._manager.discard_checkpoint(self._checkpoint.id)

        return False  # Don't suppress the exception

    @property
    def checkpoint_id(self) -> str | None:
        """Get the checkpoint ID, or None before the context is entered."""
        return self._checkpoint.id if self._checkpoint else None

    async def checkpoint_file(self, file_path: str) -> None:
        """Checkpoint a file for this transaction (no-op before entering)."""
        if self._checkpoint:
            await self._manager.checkpoint_file(self._checkpoint.id, file_path)

    async def checkpoint_files(self, file_paths: list[str]) -> None:
        """Checkpoint multiple files for this transaction (no-op before entering)."""
        if self._checkpoint:
            await self._manager.checkpoint_files(self._checkpoint.id, file_paths)

    def commit(self) -> None:
        """Mark transaction as committed (no rollback on exit; checkpoint discarded)."""
        self._committed = True

    async def rollback(self) -> RollbackResult | None:
        """Manually trigger rollback. Returns None before the context is entered."""
        if self._checkpoint:
            return await self._manager.rollback(self._checkpoint.id)
        return None