"""Tests for Rollback Manager.

Tests cover:
- FileCheckpoint: state storage
- RollbackManager: checkpoint, rollback, cleanup
- TransactionContext: auto-rollback, commit, manual rollback
- Edge cases: non-existent files, partial failures, expiration
"""

import tempfile
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import pytest_asyncio

from app.services.safety.exceptions import RollbackError
from app.services.safety.models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    CheckpointType,
)
from app.services.safety.rollback.manager import (
    FileCheckpoint,
    RollbackManager,
    TransactionContext,
)

# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Create standard action metadata for tests."""
    return ActionMetadata(
        agent_id="test-agent",
        project_id="test-project",
        session_id="test-session",
    )


@pytest.fixture
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
    """Create a standard action request for tests."""
    return ActionRequest(
        id="action-123",
        action_type=ActionType.FILE_WRITE,
        tool_name="file_write",
        resource="/tmp/test_file.txt",  # noqa: S108
        metadata=action_metadata,
        is_destructive=True,
    )


@pytest_asyncio.fixture
async def rollback_manager() -> AsyncIterator[RollbackManager]:
    """Yield a RollbackManager whose checkpoint dir is a throwaway temp dir.

    ``get_safety_config`` in the manager module is patched for the whole
    lifetime of the fixture. Any lookup made through it therefore sees test
    values (checkpoint_dir=tmpdir, 24h retention) rather than the real
    configuration. The temporary directory is removed after the test.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with patch("app.services.safety.rollback.manager.get_safety_config") as mock:
            mock.return_value = MagicMock(
                checkpoint_dir=tmpdir,
                checkpoint_retention_hours=24,
            )
            manager = RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
            # Yield inside both contexts so the patch and the directory stay
            # alive until the test using the manager has finished.
            yield manager


@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a temporary directory for file operations; removed afterwards."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
# ============================================================================
# FileCheckpoint Tests
# ============================================================================


class TestFileCheckpoint:
    """Tests for the FileCheckpoint class."""

    def test_file_checkpoint_creation(self) -> None:
        """Test creating a file checkpoint."""
        fc = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/file.txt",
            original_content=b"original content",
            existed=True,
        )

        assert fc.checkpoint_id == "cp-123"
        assert fc.file_path == "/path/to/file.txt"
        assert fc.original_content == b"original content"
        assert fc.existed is True
        assert fc.created_at is not None

    def test_file_checkpoint_nonexistent_file(self) -> None:
        """Test checkpoint for non-existent file."""
        fc = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/new_file.txt",
            original_content=None,
            existed=False,
        )

        assert fc.original_content is None
        assert fc.existed is False


# ============================================================================
# RollbackManager Tests
# ============================================================================


class TestRollbackManager:
    """Tests for the RollbackManager class."""

    @pytest.mark.asyncio
    async def test_create_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test creating a checkpoint."""
        checkpoint = await rollback_manager.create_checkpoint(
            action=action_request,
            checkpoint_type=CheckpointType.FILE,
            description="Test checkpoint",
        )

        assert checkpoint.id is not None
        assert checkpoint.action_id == action_request.id
        assert checkpoint.checkpoint_type == CheckpointType.FILE
        assert checkpoint.description == "Test checkpoint"
        assert checkpoint.expires_at is not None
        assert checkpoint.is_valid is True

    @pytest.mark.asyncio
    async def test_create_checkpoint_default_description(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test checkpoint with default description."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        assert "file_write" in checkpoint.description

    @pytest.mark.asyncio
    async def test_checkpoint_file_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing an existing file."""
        # Create a file
        test_file = temp_dir / "test.txt"
        test_file.write_text("original content")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Verify checkpoint was stored
        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 1
            assert file_checkpoints[0].existed is True
            assert file_checkpoints[0].original_content == b"original content"

    @pytest.mark.asyncio
    async def test_checkpoint_file_not_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing a non-existent file."""
        test_file = temp_dir / "new_file.txt"
        assert not test_file.exists()

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Verify checkpoint was stored
        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 1
            assert file_checkpoints[0].existed is False
            assert file_checkpoints[0].original_content is None

    @pytest.mark.asyncio
    async def test_checkpoint_files_multiple(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing multiple files."""
        # Create files
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_files(
            checkpoint.id,
            [str(file1), str(file2)],
        )

        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 2

    @pytest.mark.asyncio
    async def test_rollback_restore_modified_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback restores modified file content."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original content")

        # Create checkpoint
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Modify file
        test_file.write_text("modified content")
        assert test_file.read_text() == "modified content"

        # Rollback
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert len(result.actions_rolled_back) == 1
        assert test_file.read_text() == "original content"

    @pytest.mark.asyncio
    async def test_rollback_delete_new_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback deletes file that didn't exist before."""
        test_file = temp_dir / "new_file.txt"
        assert not test_file.exists()

        # Create checkpoint before file exists
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Create the file
        test_file.write_text("new content")
        assert test_file.exists()

        # Rollback
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert not test_file.exists()

    @pytest.mark.asyncio
    async def test_rollback_not_found(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test rollback with non-existent checkpoint."""
        with pytest.raises(RollbackError) as exc_info:
            await rollback_manager.rollback("nonexistent-id")

        assert "not found" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_rollback_invalid_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback with invalidated checkpoint."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Rollback once (invalidates checkpoint). Confirm the first rollback
        # really succeeded, so the error below is about reuse rather than
        # the fallout of a failed first attempt.
        first_result = await rollback_manager.rollback(checkpoint.id)
        assert first_result.success is True

        # Try to rollback again
        with pytest.raises(RollbackError) as exc_info:
            await rollback_manager.rollback(checkpoint.id)

        assert "no longer valid" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_discard_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test discarding a checkpoint."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        result = await rollback_manager.discard_checkpoint(checkpoint.id)
        assert result is True

        # Verify it's gone
        cp = await rollback_manager.get_checkpoint(checkpoint.id)
        assert cp is None

    @pytest.mark.asyncio
    async def test_discard_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test discarding a non-existent checkpoint."""
        result = await rollback_manager.discard_checkpoint("nonexistent-id")
        assert result is False

    @pytest.mark.asyncio
    async def test_get_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test getting a checkpoint by ID."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        retrieved = await rollback_manager.get_checkpoint(checkpoint.id)
        assert retrieved is not None
        assert retrieved.id == checkpoint.id

    @pytest.mark.asyncio
    async def test_get_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test getting a non-existent checkpoint."""
        retrieved = await rollback_manager.get_checkpoint("nonexistent-id")
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_list_checkpoints(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test listing checkpoints."""
        await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.create_checkpoint(action=action_request)

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 2

    @pytest.mark.asyncio
    async def test_list_checkpoints_by_action(
        self,
        rollback_manager: RollbackManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """Test listing checkpoints filtered by action."""
        action1 = ActionRequest(
            id="action-1",
            action_type=ActionType.FILE_WRITE,
            metadata=action_metadata,
        )
        action2 = ActionRequest(
            id="action-2",
            action_type=ActionType.FILE_WRITE,
            metadata=action_metadata,
        )

        await rollback_manager.create_checkpoint(action=action1)
        await rollback_manager.create_checkpoint(action=action2)

        checkpoints = await rollback_manager.list_checkpoints(action_id="action-1")
        assert len(checkpoints) == 1
        assert checkpoints[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_list_checkpoints_excludes_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test list_checkpoints excludes expired by default."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        # Manually expire it
        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = (
                datetime.utcnow() - timedelta(hours=1)
            )

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 0

        # With include_expired=True
        checkpoints = await rollback_manager.list_checkpoints(include_expired=True)
        assert len(checkpoints) == 1

    @pytest.mark.asyncio
    async def test_cleanup_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test cleaning up expired checkpoints."""
        # Create checkpoints
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        test_file = temp_dir / "test.txt"
        test_file.write_text("content")
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Expire it
        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = (
                datetime.utcnow() - timedelta(hours=1)
            )

        # Cleanup
        count = await rollback_manager.cleanup_expired()
        assert count == 1

        # Verify it's gone
        async with rollback_manager._lock:
            assert checkpoint.id not in rollback_manager._checkpoints
            assert checkpoint.id not in rollback_manager._file_checkpoints


# ============================================================================
# TransactionContext Tests
# ============================================================================


class TestTransactionContext:
    """Tests for the TransactionContext class."""

    @pytest.mark.asyncio
    async def test_context_creates_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test that entering context creates a checkpoint."""
        async with TransactionContext(rollback_manager, action_request) as tx:
            assert tx.checkpoint_id is not None

            # Verify checkpoint exists
            cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
            assert cp is not None

    @pytest.mark.asyncio
    async def test_context_checkpoint_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing files through context."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))

            # Modify file
            test_file.write_text("modified")

            # Manual rollback
            result = await tx.rollback()
            assert result is not None
            assert result.success is True

        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_checkpoint_files(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing multiple files through context."""
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_files([str(file1), str(file2)])

            cp_id = tx.checkpoint_id
            async with rollback_manager._lock:
                file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
                assert len(file_cps) == 2

            tx.commit()

    @pytest.mark.asyncio
    async def test_context_auto_rollback_on_exception(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test auto-rollback when exception occurs."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        # Should have been rolled back
        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_commit_prevents_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that commit prevents auto-rollback."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                tx.commit()
                raise ValueError("Simulated error after commit")

        # Should NOT have been rolled back
        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_context_discards_checkpoint_on_commit(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test that checkpoint is discarded after successful commit."""
        checkpoint_id = None

        async with TransactionContext(rollback_manager, action_request) as tx:
            checkpoint_id = tx.checkpoint_id
            tx.commit()

        # Guard against a vacuous pass: get_checkpoint(None) would also
        # return None, hiding a context that never created a checkpoint.
        assert checkpoint_id is not None

        # Checkpoint should be discarded
        cp = await rollback_manager.get_checkpoint(checkpoint_id)
        assert cp is None

    @pytest.mark.asyncio
    async def test_context_no_auto_rollback_when_disabled(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that auto_rollback=False disables auto-rollback."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with pytest.raises(ValueError):
            async with TransactionContext(
                rollback_manager,
                action_request,
                auto_rollback=False,
            ) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        # Should NOT have been rolled back
        assert test_file.read_text() == "modified"

    @pytest.mark.asyncio
    async def test_context_manual_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test manual rollback within context."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))
            test_file.write_text("modified")

            # Manual rollback
            result = await tx.rollback()
            assert result is not None
            assert result.success is True

        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_context_rollback_without_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test rollback when checkpoint is None."""
        tx = TransactionContext(rollback_manager, action_request)
        # Don't enter context, so _checkpoint is None

        result = await tx.rollback()
        assert result is None

    @pytest.mark.asyncio
    async def test_context_checkpoint_file_without_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpoint_file when checkpoint is None (no-op)."""
        tx = TransactionContext(rollback_manager, action_request)
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")

        # Should not raise - just a no-op
        await tx.checkpoint_file(str(test_file))
        await tx.checkpoint_files([str(test_file)])

        # A real no-op records nothing on the (fresh) manager.
        async with rollback_manager._lock:
            assert not rollback_manager._file_checkpoints


# ============================================================================
# Edge Cases
# ============================================================================


class TestRollbackEdgeCases:
    """Edge cases that could reveal hidden bugs."""

    @pytest.mark.asyncio
    async def test_checkpoint_file_for_unknown_checkpoint(
        self,
        rollback_manager: RollbackManager,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing file for non-existent checkpoint."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")

        # Should create the list if it doesn't exist
        await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))

        async with rollback_manager._lock:
            assert "unknown-checkpoint" in rollback_manager._file_checkpoints

    @pytest.mark.asyncio
    async def test_rollback_with_partial_failure(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback when some files fail to restore."""
        file1 = temp_dir / "file1.txt"
        file1.write_text("original 1")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(file1))

        # Target a path nested under a regular *file*. Neither creating its
        # parent directory nor writing the file can succeed, even when the
        # suite runs as root. An absolute path such as /nonexistent/... would
        # simply be created there, because rollback recreates missing parents.
        blocker = temp_dir / "not_a_directory"
        blocker.write_text("blocks directory creation")
        unrestorable_path = blocker / "file.txt"

        # Add a file checkpoint with a path that will fail
        async with rollback_manager._lock:
            bad_fc = FileCheckpoint(
                checkpoint_id=checkpoint.id,
                file_path=str(unrestorable_path),
                original_content=b"content",
                existed=True,
            )
            rollback_manager._file_checkpoints[checkpoint.id].append(bad_fc)

        # Rollback - partial failure expected
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is False
        assert len(result.actions_rolled_back) == 1
        assert len(result.failed_actions) == 1
        assert "Failed to rollback" in result.error

    @pytest.mark.asyncio
    async def test_rollback_file_creates_parent_dirs(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that rollback creates parent directories if needed."""
        nested_file = temp_dir / "subdir" / "nested" / "file.txt"
        nested_file.parent.mkdir(parents=True)
        nested_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))

        # Delete the entire directory structure
        nested_file.unlink()
        (temp_dir / "subdir" / "nested").rmdir()
        (temp_dir / "subdir").rmdir()

        # Rollback should recreate
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert nested_file.exists()
        assert nested_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_rollback_file_already_correct(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback when file already has correct content."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Don't modify file - rollback should still succeed
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert test_file.read_text() == "original"

    @pytest.mark.asyncio
    async def test_checkpoint_with_none_expires_at(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test list_checkpoints handles None expires_at."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        # Set expires_at to None
        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = None

        # Should still be listed
        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 1

    @pytest.mark.asyncio
    async def test_auto_rollback_failure_logged(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that auto-rollback failure is logged, not raised."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        with patch.object(
            rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
        ):
            with patch("app.services.safety.rollback.manager.logger") as mock_logger:
                with pytest.raises(ValueError):
                    async with TransactionContext(
                        rollback_manager, action_request
                    ) as tx:
                        await tx.checkpoint_file(str(test_file))
                        test_file.write_text("modified")
                        raise ValueError("Original error")

                # Rollback error should be logged
                mock_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_multiple_checkpoints_same_action(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test creating multiple checkpoints for the same action."""
        cp1 = await rollback_manager.create_checkpoint(action=action_request)
        cp2 = await rollback_manager.create_checkpoint(action=action_request)

        assert cp1.id != cp2.id

        checkpoints = await rollback_manager.list_checkpoints(
            action_id=action_request.id
        )
        assert len(checkpoints) == 2

    @pytest.mark.asyncio
    async def test_cleanup_expired_with_no_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test cleanup when no checkpoints are expired."""
        await rollback_manager.create_checkpoint(action=action_request)

        count = await rollback_manager.cleanup_expired()
        assert count == 0

        # Checkpoint should still exist
        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 1