forked from cardosofelipe/fast-next-template
Add tests to improve backend coverage from 85% to 93%: - test_audit.py: 60 tests for AuditLogger (20% -> 99%) - Hash chain integrity, sanitization, retention, handlers - Fixed bug: hash chain modification after event creation - Fixed bug: verification not using correct prev_hash - test_hitl.py: Tests for HITL manager (0% -> 100%) - test_permissions.py: Tests for permissions manager (0% -> 99%) - test_rollback.py: Tests for rollback manager (0% -> 100%) - test_metrics.py: Tests for metrics collector (0% -> 100%) - test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%) - test_validation.py: Additional cache and edge case tests (76% -> 100%) - test_scoring.py: Lock cleanup and edge case tests (78% -> 91%) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
824 lines
28 KiB
Python
824 lines
28 KiB
Python
"""Tests for Rollback Manager.
|
|
|
|
Tests cover:
|
|
- FileCheckpoint: state storage
|
|
- RollbackManager: checkpoint, rollback, cleanup
|
|
- TransactionContext: auto-rollback, commit, manual rollback
|
|
- Edge cases: non-existent files, partial failures, expiration
|
|
"""
|
|
|
|
import tempfile
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import pytest_asyncio

from app.services.safety.exceptions import RollbackError
from app.services.safety.models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    CheckpointType,
)
from app.services.safety.rollback.manager import (
    FileCheckpoint,
    RollbackManager,
    TransactionContext,
)
|
|
|
|
# ============================================================================
|
|
# Fixtures
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Build the ActionMetadata shared by the tests in this module.

    Every identifier is a fixed ``test-*`` string so tests can rely on it.
    """
    fields = {
        "agent_id": "test-agent",
        "project_id": "test-project",
        "session_id": "test-session",
    }
    return ActionMetadata(**fields)
|
|
|
|
|
|
@pytest.fixture
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
    """Build a destructive FILE_WRITE request that carries ``action_metadata``."""
    target = "/tmp/test_file.txt"  # noqa: S108
    request = ActionRequest(
        id="action-123",
        action_type=ActionType.FILE_WRITE,
        tool_name="file_write",
        resource=target,
        metadata=action_metadata,
        is_destructive=True,
    )
    return request
|
|
|
|
|
|
@pytest_asyncio.fixture
async def rollback_manager() -> AsyncIterator[RollbackManager]:
    """Yield a RollbackManager that stores checkpoints in a throw-away directory.

    ``get_safety_config`` in the manager module is patched to return a stub
    config (same directory, 24h retention) for as long as the test runs. The
    temporary directory is deleted when the test finishes.

    This is an async *generator* fixture, so its return type is an
    ``AsyncIterator`` of the yielded manager, not the manager itself.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with patch("app.services.safety.rollback.manager.get_safety_config") as mock:
            mock.return_value = MagicMock(
                checkpoint_dir=tmpdir,
                checkpoint_retention_hours=24,
            )
            manager = RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
            # Yield inside both context managers so the patch and the
            # directory stay alive for the whole test.
            yield manager
|
|
|
|
|
|
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory as a Path for file operations.

    This is a generator fixture, so it is typed as an ``Iterator[Path]``. The
    directory and everything in it are removed after the test.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
|
|
|
|
|
|
# ============================================================================
|
|
# FileCheckpoint Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestFileCheckpoint:
    """Unit tests for the FileCheckpoint record."""

    def test_file_checkpoint_creation(self) -> None:
        """A checkpoint of an existing file keeps every field it was given."""
        path = "/path/to/file.txt"
        content = b"original content"
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path=path,
            original_content=content,
            existed=True,
        )

        assert snapshot.checkpoint_id == "cp-123"
        assert snapshot.file_path == path
        assert snapshot.original_content == content
        assert snapshot.existed is True
        # created_at was not passed in, so it must be populated by default.
        assert snapshot.created_at is not None

    def test_file_checkpoint_nonexistent_file(self) -> None:
        """A checkpoint of a missing file records no content and existed=False."""
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/new_file.txt",
            original_content=None,
            existed=False,
        )

        assert snapshot.original_content is None
        assert snapshot.existed is False
|
|
|
|
|
|
# ============================================================================
|
|
# RollbackManager Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestRollbackManager:
    """Tests for the RollbackManager class."""

    @staticmethod
    def _an_hour_ago() -> datetime:
        """Return a naive UTC timestamp one hour in the past.

        Equivalent to ``datetime.utcnow() - timedelta(hours=1)``, which these
        tests previously used, but avoids ``utcnow()`` (deprecated since
        Python 3.12). Kept naive to match those original values. Presumably the
        manager stores naive UTC ``expires_at`` values; confirm if that changes.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)

    @pytest.mark.asyncio
    async def test_create_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test creating a checkpoint."""
        checkpoint = await rollback_manager.create_checkpoint(
            action=action_request,
            checkpoint_type=CheckpointType.FILE,
            description="Test checkpoint",
        )

        assert checkpoint.id is not None
        assert checkpoint.action_id == action_request.id
        assert checkpoint.checkpoint_type == CheckpointType.FILE
        assert checkpoint.description == "Test checkpoint"
        assert checkpoint.expires_at is not None
        assert checkpoint.is_valid is True

    @pytest.mark.asyncio
    async def test_create_checkpoint_default_description(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test checkpoint with default description."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        assert "file_write" in checkpoint.description

    @pytest.mark.asyncio
    async def test_checkpoint_file_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing an existing file."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original content")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Inspect the stored file snapshots under the manager's lock.
        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 1
            assert file_checkpoints[0].existed is True
            assert file_checkpoints[0].original_content == b"original content"

    @pytest.mark.asyncio
    async def test_checkpoint_file_not_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing a non-existent file."""
        test_file = temp_dir / "new_file.txt"
        assert not test_file.exists()

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 1
            assert file_checkpoints[0].existed is False
            assert file_checkpoints[0].original_content is None

    @pytest.mark.asyncio
    async def test_checkpoint_files_multiple(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing multiple files."""
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_files(
            checkpoint.id,
            [str(file1), str(file2)],
        )

        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 2

    @pytest.mark.asyncio
    async def test_rollback_restore_modified_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback restores modified file content."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original content")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        test_file.write_text("modified content")
        assert test_file.read_text() == "modified content"

        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert len(result.actions_rolled_back) == 1
        assert test_file.read_text() == "original content"

    @pytest.mark.asyncio
    async def test_rollback_delete_new_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback deletes file that didn't exist before."""
        test_file = temp_dir / "new_file.txt"
        assert not test_file.exists()

        # Snapshot the path while the file is still absent.
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        test_file.write_text("new content")
        assert test_file.exists()

        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert not test_file.exists()

    @pytest.mark.asyncio
    async def test_rollback_not_found(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test rollback with non-existent checkpoint."""
        with pytest.raises(RollbackError) as exc_info:
            await rollback_manager.rollback("nonexistent-id")

        assert "not found" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_rollback_invalid_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback with invalidated checkpoint."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # The first rollback must succeed (it is what invalidates the checkpoint).
        first = await rollback_manager.rollback(checkpoint.id)
        assert first.success is True

        # A second rollback of the same checkpoint is rejected.
        with pytest.raises(RollbackError) as exc_info:
            await rollback_manager.rollback(checkpoint.id)

        assert "no longer valid" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_discard_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test discarding a checkpoint."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        result = await rollback_manager.discard_checkpoint(checkpoint.id)
        assert result is True

        cp = await rollback_manager.get_checkpoint(checkpoint.id)
        assert cp is None

    @pytest.mark.asyncio
    async def test_discard_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test discarding a non-existent checkpoint."""
        result = await rollback_manager.discard_checkpoint("nonexistent-id")
        assert result is False

    @pytest.mark.asyncio
    async def test_get_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test getting a checkpoint by ID."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        retrieved = await rollback_manager.get_checkpoint(checkpoint.id)
        assert retrieved is not None
        assert retrieved.id == checkpoint.id

    @pytest.mark.asyncio
    async def test_get_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test getting a non-existent checkpoint."""
        retrieved = await rollback_manager.get_checkpoint("nonexistent-id")
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_list_checkpoints(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test listing checkpoints."""
        await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.create_checkpoint(action=action_request)

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 2

    @pytest.mark.asyncio
    async def test_list_checkpoints_by_action(
        self,
        rollback_manager: RollbackManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """Test listing checkpoints filtered by action."""
        action1 = ActionRequest(
            id="action-1",
            action_type=ActionType.FILE_WRITE,
            metadata=action_metadata,
        )
        action2 = ActionRequest(
            id="action-2",
            action_type=ActionType.FILE_WRITE,
            metadata=action_metadata,
        )

        await rollback_manager.create_checkpoint(action=action1)
        await rollback_manager.create_checkpoint(action=action2)

        checkpoints = await rollback_manager.list_checkpoints(action_id="action-1")
        assert len(checkpoints) == 1
        assert checkpoints[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_list_checkpoints_excludes_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test list_checkpoints excludes expired by default."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        # Force the checkpoint into the past.
        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = (
                self._an_hour_ago()
            )

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 0

        checkpoints = await rollback_manager.list_checkpoints(include_expired=True)
        assert len(checkpoints) == 1

    @pytest.mark.asyncio
    async def test_cleanup_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test cleaning up expired checkpoints."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = (
                self._an_hour_ago()
            )

        count = await rollback_manager.cleanup_expired()
        assert count == 1

        # Both the checkpoint and its file snapshots must be gone.
        async with rollback_manager._lock:
            assert checkpoint.id not in rollback_manager._checkpoints
            assert checkpoint.id not in rollback_manager._file_checkpoints
|
|
|
|
|
|
# ============================================================================
|
|
# TransactionContext Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionContext:
    """Tests for the TransactionContext class."""

    @pytest.mark.asyncio
    async def test_context_creates_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test that entering context creates a checkpoint."""
        async with TransactionContext(rollback_manager, action_request) as tx:
            assert tx.checkpoint_id is not None

            cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
            assert cp is not None

    @pytest.mark.asyncio
    async def test_context_checkpoint_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing files through context."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))

            test_file.write_text("modified")

            result = await tx.rollback()
            assert result is not None
            assert result.success is True

        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_checkpoint_files(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing multiple files through context."""
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_files([str(file1), str(file2)])

            cp_id = tx.checkpoint_id
            # Guard against looking up snapshots under a None key.
            assert cp_id is not None
            async with rollback_manager._lock:
                file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
                assert len(file_cps) == 2

            tx.commit()

    @pytest.mark.asyncio
    async def test_context_auto_rollback_on_exception(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test auto-rollback when exception occurs."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        # Leaving the context via an exception restored the original content.
        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_commit_prevents_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that commit prevents auto-rollback."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                tx.commit()
                raise ValueError("Simulated error after commit")

        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_context_discards_checkpoint_on_commit(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test that checkpoint is discarded after successful commit."""
        checkpoint_id = None

        async with TransactionContext(rollback_manager, action_request) as tx:
            checkpoint_id = tx.checkpoint_id
            tx.commit()

        # Without this guard, get_checkpoint(None) returning None would make
        # the final assertion pass even if no checkpoint was ever created.
        assert checkpoint_id is not None

        cp = await rollback_manager.get_checkpoint(checkpoint_id)
        assert cp is None

    @pytest.mark.asyncio
    async def test_context_no_auto_rollback_when_disabled(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that auto_rollback=False disables auto-rollback."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(
                rollback_manager,
                action_request,
                auto_rollback=False,
            ) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_context_manual_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test manual rollback within context."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))
            test_file.write_text("modified")

            result = await tx.rollback()
            assert result is not None
            assert result.success is True

        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_rollback_without_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test rollback when checkpoint is None."""
        tx = TransactionContext(rollback_manager, action_request)
        # The context is never entered, so no checkpoint exists.
        result = await tx.rollback()
        assert result is None

    @pytest.mark.asyncio
    async def test_context_checkpoint_file_without_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpoint_file when checkpoint is None (no-op)."""
        tx = TransactionContext(rollback_manager, action_request)
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")

        # Both calls must be silent no-ops rather than raising.
        await tx.checkpoint_file(str(test_file))
        await tx.checkpoint_files([str(test_file)])
|
|
|
|
|
|
# ============================================================================
|
|
# Edge Cases
|
|
# ============================================================================
|
|
|
|
|
|
class TestRollbackEdgeCases:
    """Edge cases that could reveal hidden bugs."""

    @pytest.mark.asyncio
    async def test_checkpoint_file_for_unknown_checkpoint(
        self,
        rollback_manager: RollbackManager,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing file for non-existent checkpoint."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")

        # The per-checkpoint snapshot list is created on demand.
        await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))

        async with rollback_manager._lock:
            assert "unknown-checkpoint" in rollback_manager._file_checkpoints

    @pytest.mark.asyncio
    async def test_rollback_with_partial_failure(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback when some files fail to restore."""
        file1 = temp_dir / "file1.txt"
        file1.write_text("original 1")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(file1))

        # Build a target whose "parent directory" is a regular file. Neither
        # writing it nor creating its parents can succeed, even as root. An
        # absolute path like /nonexistent/... could be created by root,
        # polluting the host and making this test fail.
        blocker = temp_dir / "not_a_directory"
        blocker.write_text("blocks directory creation")
        bad_path = blocker / "file.txt"

        async with rollback_manager._lock:
            bad_fc = FileCheckpoint(
                checkpoint_id=checkpoint.id,
                file_path=str(bad_path),
                original_content=b"content",
                existed=True,
            )
            rollback_manager._file_checkpoints[checkpoint.id].append(bad_fc)

        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is False
        assert len(result.actions_rolled_back) == 1
        assert len(result.failed_actions) == 1
        assert "Failed to rollback" in result.error

    @pytest.mark.asyncio
    async def test_rollback_file_creates_parent_dirs(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that rollback creates parent directories if needed."""
        nested_file = temp_dir / "subdir" / "nested" / "file.txt"
        nested_file.parent.mkdir(parents=True)
        nested_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))

        # Remove the file and its whole directory chain.
        nested_file.unlink()
        (temp_dir / "subdir" / "nested").rmdir()
        (temp_dir / "subdir").rmdir()

        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert nested_file.exists()
        assert nested_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_rollback_file_already_correct(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback when file already has correct content."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # File is left untouched; rollback should still report success.
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_checkpoint_with_none_expires_at(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test list_checkpoints handles None expires_at."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = None

        # A checkpoint without an expiry is never filtered out as expired.
        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 1

    @pytest.mark.asyncio
    async def test_auto_rollback_failure_logged(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that auto-rollback failure is logged, not raised."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        # Make the manager's rollback blow up so the context's error path runs.
        with patch.object(
            rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
        ):
            with patch("app.services.safety.rollback.manager.logger") as mock_logger:
                # The original ValueError must propagate, not the rollback error.
                with pytest.raises(ValueError):
                    async with TransactionContext(
                        rollback_manager, action_request
                    ) as tx:
                        await tx.checkpoint_file(str(test_file))
                        test_file.write_text("modified")
                        raise ValueError("Original error")

                mock_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_multiple_checkpoints_same_action(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test creating multiple checkpoints for the same action."""
        cp1 = await rollback_manager.create_checkpoint(action=action_request)
        cp2 = await rollback_manager.create_checkpoint(action=action_request)

        assert cp1.id != cp2.id

        checkpoints = await rollback_manager.list_checkpoints(
            action_id=action_request.id
        )
        assert len(checkpoints) == 2

    @pytest.mark.asyncio
    async def test_cleanup_expired_with_no_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test cleanup when no checkpoints are expired."""
        await rollback_manager.create_checkpoint(action=action_request)

        count = await rollback_manager.cleanup_expired()
        assert count == 0

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 1
|