Files
fast-next-template/backend/tests/services/safety/test_rollback.py
Felipe Cardoso 60ebeaa582 test(safety): add comprehensive tests for safety framework modules
Add tests to improve backend coverage from 85% to 93%:

- test_audit.py: 60 tests for AuditLogger (20% -> 99%)
  - Hash chain integrity, sanitization, retention, handlers
  - Fixed bug: hash chain modification after event creation
  - Fixed bug: verification not using correct prev_hash

- test_hitl.py: Tests for HITL manager (0% -> 100%)
- test_permissions.py: Tests for permissions manager (0% -> 99%)
- test_rollback.py: Tests for rollback manager (0% -> 100%)
- test_metrics.py: Tests for metrics collector (0% -> 100%)
- test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%)
- test_validation.py: Additional cache and edge case tests (76% -> 100%)
- test_scoring.py: Lock cleanup and edge case tests (78% -> 91%)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 19:41:54 +01:00

824 lines
28 KiB
Python

"""Tests for Rollback Manager.
Tests cover:
- FileCheckpoint: state storage
- RollbackManager: checkpoint, rollback, cleanup
- TransactionContext: auto-rollback, commit, manual rollback
- Edge cases: non-existent files, partial failures, expiration
"""
import tempfile
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import pytest_asyncio

from app.services.safety.exceptions import RollbackError
from app.services.safety.models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    CheckpointType,
)
from app.services.safety.rollback.manager import (
    FileCheckpoint,
    RollbackManager,
    TransactionContext,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Provide the ActionMetadata shared by the action requests in these tests."""
    metadata = ActionMetadata(
        session_id="test-session",
        project_id="test-project",
        agent_id="test-agent",
    )
    return metadata
@pytest.fixture
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
    """Provide a destructive FILE_WRITE request with the fixed id ``action-123``."""
    request = ActionRequest(
        id="action-123",
        metadata=action_metadata,
        tool_name="file_write",
        action_type=ActionType.FILE_WRITE,
        is_destructive=True,
        resource="/tmp/test_file.txt",  # noqa: S108
    )
    return request
@pytest_asyncio.fixture
async def rollback_manager() -> AsyncIterator[RollbackManager]:
    """Yield a RollbackManager backed by a throwaway checkpoint directory.

    ``get_safety_config`` in the manager module is patched for the fixture's
    whole lifetime (construction *and* the test body). Any call to it from that
    module therefore sees the temporary directory and a 24-hour retention,
    instead of the real application settings. The directory is deleted on
    teardown.

    Yields:
        A freshly constructed RollbackManager.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with patch(
            "app.services.safety.rollback.manager.get_safety_config"
        ) as mock_config:
            mock_config.return_value = MagicMock(
                checkpoint_dir=tmpdir,
                checkpoint_retention_hours=24,
            )
            manager = RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
            # Yield inside both context managers so the patch and the
            # directory outlive the test body.
            yield manager
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory for test files.

    This directory is separate from the rollback manager's checkpoint directory.
    It is removed, together with its contents, on teardown.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
# ============================================================================
# FileCheckpoint Tests
# ============================================================================
class TestFileCheckpoint:
    """Unit tests for the FileCheckpoint record."""

    def test_file_checkpoint_creation(self) -> None:
        """A checkpoint of an existing file keeps its id, path and raw bytes."""
        content = b"original content"
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/file.txt",
            original_content=content,
            existed=True,
        )

        assert snapshot.checkpoint_id == "cp-123"
        assert snapshot.file_path == "/path/to/file.txt"
        assert snapshot.original_content == content
        assert snapshot.existed is True
        assert snapshot.created_at is not None

    def test_file_checkpoint_nonexistent_file(self) -> None:
        """A checkpoint of a missing file stores no content and existed=False."""
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/new_file.txt",
            original_content=None,
            existed=False,
        )

        assert snapshot.original_content is None
        assert snapshot.existed is False
# ============================================================================
# RollbackManager Tests
# ============================================================================
class TestRollbackManager:
    """Behavioural tests for RollbackManager: checkpoints, rollback and cleanup."""

    @staticmethod
    async def _recorded_files(
        manager: RollbackManager, checkpoint_id: str
    ) -> list[FileCheckpoint]:
        """Return a snapshot of the file checkpoints stored under ``checkpoint_id``.

        The manager's private state is read while holding its lock. An id with
        nothing recorded yields an empty list.
        """
        async with manager._lock:
            return list(manager._file_checkpoints.get(checkpoint_id, []))

    @staticmethod
    async def _force_expire(manager: RollbackManager, checkpoint_id: str) -> None:
        """Move a stored checkpoint's expiry one hour into the past (under lock)."""
        async with manager._lock:
            # NOTE(review): naive utcnow() kept deliberately; presumably the
            # manager compares against naive UTC too. Confirm before switching
            # to timezone-aware datetimes.
            manager._checkpoints[checkpoint_id].expires_at = (
                datetime.utcnow() - timedelta(hours=1)
            )

    @pytest.mark.asyncio
    async def test_create_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """A new checkpoint carries the action id, type, description and expiry."""
        cp = await rollback_manager.create_checkpoint(
            action=action_request,
            checkpoint_type=CheckpointType.FILE,
            description="Test checkpoint",
        )

        assert cp.id is not None
        assert cp.action_id == action_request.id
        assert cp.checkpoint_type == CheckpointType.FILE
        assert cp.description == "Test checkpoint"
        assert cp.expires_at is not None
        assert cp.is_valid is True

    @pytest.mark.asyncio
    async def test_create_checkpoint_default_description(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Without an explicit description, the generated one mentions file_write."""
        cp = await rollback_manager.create_checkpoint(action=action_request)

        assert "file_write" in cp.description

    @pytest.mark.asyncio
    async def test_checkpoint_file_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Checkpointing an existing file records its current bytes."""
        source = temp_dir / "test.txt"
        source.write_text("original content")

        cp = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(cp.id, str(source))

        recorded = await self._recorded_files(rollback_manager, cp.id)
        assert len(recorded) == 1
        assert recorded[0].existed is True
        assert recorded[0].original_content == b"original content"

    @pytest.mark.asyncio
    async def test_checkpoint_file_not_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Checkpointing a missing file records existed=False and no content."""
        missing = temp_dir / "new_file.txt"
        assert not missing.exists()

        cp = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(cp.id, str(missing))

        recorded = await self._recorded_files(rollback_manager, cp.id)
        assert len(recorded) == 1
        assert recorded[0].existed is False
        assert recorded[0].original_content is None

    @pytest.mark.asyncio
    async def test_checkpoint_files_multiple(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """checkpoint_files records one entry per path supplied."""
        contents = {"file1.txt": "content 1", "file2.txt": "content 2"}
        for name, text in contents.items():
            (temp_dir / name).write_text(text)

        cp = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_files(
            cp.id,
            [str(temp_dir / name) for name in contents],
        )

        recorded = await self._recorded_files(rollback_manager, cp.id)
        assert len(recorded) == 2

    @pytest.mark.asyncio
    async def test_rollback_restore_modified_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Rolling back puts a modified file's original content back."""
        target = temp_dir / "test.txt"
        target.write_text("original content")

        cp = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(cp.id, str(target))

        target.write_text("modified content")
        assert target.read_text() == "modified content"

        outcome = await rollback_manager.rollback(cp.id)

        assert outcome.success is True
        assert len(outcome.actions_rolled_back) == 1
        assert target.read_text() == "original content"

    @pytest.mark.asyncio
    async def test_rollback_delete_new_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Rolling back removes a file that did not exist at checkpoint time."""
        target = temp_dir / "new_file.txt"
        assert not target.exists()

        cp = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(cp.id, str(target))

        target.write_text("new content")
        assert target.exists()

        outcome = await rollback_manager.rollback(cp.id)

        assert outcome.success is True
        assert not target.exists()

    @pytest.mark.asyncio
    async def test_rollback_not_found(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Rolling back an unknown checkpoint id raises RollbackError."""
        with pytest.raises(RollbackError, match="not found"):
            await rollback_manager.rollback("nonexistent-id")

    @pytest.mark.asyncio
    async def test_rollback_invalid_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """A checkpoint cannot be rolled back a second time."""
        target = temp_dir / "test.txt"
        target.write_text("original")

        cp = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(cp.id, str(target))

        # The first rollback succeeds and invalidates the checkpoint ...
        await rollback_manager.rollback(cp.id)

        # ... so the second attempt must be refused.
        with pytest.raises(RollbackError, match="no longer valid"):
            await rollback_manager.rollback(cp.id)

    @pytest.mark.asyncio
    async def test_discard_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Discarding a known checkpoint reports True and removes it."""
        cp = await rollback_manager.create_checkpoint(action=action_request)

        discarded = await rollback_manager.discard_checkpoint(cp.id)

        assert discarded is True
        assert await rollback_manager.get_checkpoint(cp.id) is None

    @pytest.mark.asyncio
    async def test_discard_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Discarding an unknown checkpoint reports False."""
        discarded = await rollback_manager.discard_checkpoint("nonexistent-id")

        assert discarded is False

    @pytest.mark.asyncio
    async def test_get_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """A created checkpoint can be fetched back by its id."""
        created = await rollback_manager.create_checkpoint(action=action_request)

        fetched = await rollback_manager.get_checkpoint(created.id)

        assert fetched is not None
        assert fetched.id == created.id

    @pytest.mark.asyncio
    async def test_get_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Fetching an unknown checkpoint id returns None."""
        fetched = await rollback_manager.get_checkpoint("nonexistent-id")

        assert fetched is None

    @pytest.mark.asyncio
    async def test_list_checkpoints(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Every live checkpoint shows up in list_checkpoints."""
        for _ in range(2):
            await rollback_manager.create_checkpoint(action=action_request)

        listed = await rollback_manager.list_checkpoints()

        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_list_checkpoints_by_action(
        self,
        rollback_manager: RollbackManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """list_checkpoints(action_id=...) keeps only that action's checkpoints."""
        requests = [
            ActionRequest(
                id=action_id,
                action_type=ActionType.FILE_WRITE,
                metadata=action_metadata,
            )
            for action_id in ("action-1", "action-2")
        ]
        for request in requests:
            await rollback_manager.create_checkpoint(action=request)

        matching = await rollback_manager.list_checkpoints(action_id="action-1")

        assert len(matching) == 1
        assert matching[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_list_checkpoints_excludes_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Expired checkpoints are hidden unless include_expired=True."""
        cp = await rollback_manager.create_checkpoint(action=action_request)
        await self._force_expire(rollback_manager, cp.id)

        visible = await rollback_manager.list_checkpoints()
        assert len(visible) == 0

        everything = await rollback_manager.list_checkpoints(include_expired=True)
        assert len(everything) == 1

    @pytest.mark.asyncio
    async def test_cleanup_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """cleanup_expired drops expired checkpoints and their file snapshots."""
        cp = await rollback_manager.create_checkpoint(action=action_request)
        target = temp_dir / "test.txt"
        target.write_text("content")
        await rollback_manager.checkpoint_file(cp.id, str(target))
        await self._force_expire(rollback_manager, cp.id)

        removed = await rollback_manager.cleanup_expired()

        assert removed == 1
        async with rollback_manager._lock:
            assert cp.id not in rollback_manager._checkpoints
            assert cp.id not in rollback_manager._file_checkpoints
# ============================================================================
# TransactionContext Tests
# ============================================================================
class TestTransactionContext:
    """Tests for the TransactionContext async context manager."""

    @pytest.mark.asyncio
    async def test_context_creates_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Entering the context creates a checkpoint known to the manager."""
        async with TransactionContext(rollback_manager, action_request) as tx:
            assert tx.checkpoint_id is not None

            cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
            assert cp is not None

    @pytest.mark.asyncio
    async def test_context_checkpoint_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """tx.checkpoint_file records a snapshot under the transaction's checkpoint.

        test_context_manual_rollback covers the rollback behaviour itself. This
        test checks that the snapshot is registered with its original content.
        """
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))

            async with rollback_manager._lock:
                file_cps = list(
                    rollback_manager._file_checkpoints.get(tx.checkpoint_id, [])
                )
            assert len(file_cps) == 1
            assert file_cps[0].existed is True
            assert file_cps[0].original_content == b"original"

            tx.commit()

        # Committed without modification: the file is untouched.
        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_checkpoint_files(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """tx.checkpoint_files records one snapshot per supplied path."""
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_files([str(file1), str(file2)])
            cp_id = tx.checkpoint_id

            async with rollback_manager._lock:
                file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
                assert len(file_cps) == 2

            tx.commit()

    @pytest.mark.asyncio
    async def test_context_auto_rollback_on_exception(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """An exception escaping the context triggers an automatic rollback."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_commit_prevents_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """After commit(), a later exception does not roll the changes back."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                tx.commit()
                raise ValueError("Simulated error after commit")

        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_context_discards_checkpoint_on_commit(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """A committed transaction discards its checkpoint on exit."""
        checkpoint_id = None
        async with TransactionContext(rollback_manager, action_request) as tx:
            checkpoint_id = tx.checkpoint_id
            tx.commit()

        # Guard against a vacuous pass: get_checkpoint(None) would also be None.
        assert checkpoint_id is not None
        cp = await rollback_manager.get_checkpoint(checkpoint_id)
        assert cp is None

    @pytest.mark.asyncio
    async def test_context_no_auto_rollback_when_disabled(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """auto_rollback=False leaves changes in place when an exception escapes."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(
                rollback_manager,
                action_request,
                auto_rollback=False,
            ) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_context_manual_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """tx.rollback() inside the context restores checkpointed files."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))
            test_file.write_text("modified")

            result = await tx.rollback()

            assert result is not None
            assert result.success is True
            assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_rollback_without_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """rollback() on a context that was never entered returns None."""
        tx = TransactionContext(rollback_manager, action_request)
        # The context is not entered, so no checkpoint exists yet.

        result = await tx.rollback()

        assert result is None

    @pytest.mark.asyncio
    async def test_context_checkpoint_file_without_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """checkpoint_file(s) on a never-entered context is a silent no-op."""
        tx = TransactionContext(rollback_manager, action_request)
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")

        # Neither call may raise ...
        await tx.checkpoint_file(str(test_file))
        await tx.checkpoint_files([str(test_file)])

        # ... nor register anything with the (fresh) manager.
        assert await rollback_manager.list_checkpoints() == []
        async with rollback_manager._lock:
            assert not rollback_manager._file_checkpoints
        assert test_file.read_text() == "content"
# ============================================================================
# Edge Cases
# ============================================================================
class TestRollbackEdgeCases:
    """Edge cases that could reveal hidden bugs."""

    @pytest.mark.asyncio
    async def test_checkpoint_file_for_unknown_checkpoint(
        self,
        rollback_manager: RollbackManager,
        temp_dir: Path,
    ) -> None:
        """checkpoint_file with an unknown checkpoint id still records the file."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")

        # The per-checkpoint list is expected to be created on demand.
        await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))

        async with rollback_manager._lock:
            assert "unknown-checkpoint" in rollback_manager._file_checkpoints

    @pytest.mark.asyncio
    async def test_rollback_with_partial_failure(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """A file that cannot be restored fails alone; the others are restored."""
        good_file = temp_dir / "file1.txt"
        good_file.write_text("original 1")

        # A regular file sits where a parent directory is expected. Neither
        # creating that parent nor writing beneath it can succeed, whatever the
        # privileges of the test runner. By contrast, a path under "/" (e.g.
        # /nonexistent/...) would simply be created when running as root.
        blocker = temp_dir / "blocker"
        blocker.write_text("not a directory")
        unrestorable = blocker / "file.txt"

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(good_file))
        good_file.write_text("modified 1")

        async with rollback_manager._lock:
            rollback_manager._file_checkpoints[checkpoint.id].append(
                FileCheckpoint(
                    checkpoint_id=checkpoint.id,
                    file_path=str(unrestorable),
                    original_content=b"content",
                    existed=True,
                )
            )

        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is False
        assert len(result.actions_rolled_back) == 1
        assert len(result.failed_actions) == 1
        assert "Failed to rollback" in result.error
        # The failure of one entry must not prevent restoring the others.
        assert good_file.read_text() == "original 1"

    @pytest.mark.asyncio
    async def test_rollback_file_creates_parent_dirs(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Rollback recreates missing parent directories of a restored file."""
        nested_file = temp_dir / "subdir" / "nested" / "file.txt"
        nested_file.parent.mkdir(parents=True)
        nested_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))

        # Delete the file and its whole directory chain.
        nested_file.unlink()
        (temp_dir / "subdir" / "nested").rmdir()
        (temp_dir / "subdir").rmdir()

        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert nested_file.exists()
        assert nested_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_rollback_file_already_correct(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Rollback succeeds when the file still has its checkpointed content."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # The file is intentionally left unmodified.
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_checkpoint_with_none_expires_at(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """A checkpoint with expires_at=None is still listed."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = None

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 1

    @pytest.mark.asyncio
    async def test_auto_rollback_failure_logged(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """A failing auto-rollback is logged and the original error propagates."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with patch.object(
            rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
        ):
            with patch("app.services.safety.rollback.manager.logger") as mock_logger:
                # match= pins that the rollback error did not replace the
                # user's exception.
                with pytest.raises(ValueError, match="Original error"):
                    async with TransactionContext(
                        rollback_manager, action_request
                    ) as tx:
                        await tx.checkpoint_file(str(test_file))
                        test_file.write_text("modified")
                        raise ValueError("Original error")

                mock_logger.error.assert_called()

        # The (mocked) rollback failed, so the modification is still there.
        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_multiple_checkpoints_same_action(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """One action may own several checkpoints with distinct ids."""
        cp1 = await rollback_manager.create_checkpoint(action=action_request)
        cp2 = await rollback_manager.create_checkpoint(action=action_request)

        assert cp1.id != cp2.id

        checkpoints = await rollback_manager.list_checkpoints(
            action_id=action_request.id
        )
        assert len(checkpoints) == 2

    @pytest.mark.asyncio
    async def test_cleanup_expired_with_no_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """cleanup_expired removes nothing while all checkpoints are live."""
        await rollback_manager.create_checkpoint(action=action_request)

        count = await rollback_manager.cleanup_expired()
        assert count == 0

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 1