""" Tests for Audit Logger. Tests cover: - AuditLogger initialization and lifecycle - Event logging and hash chain - Query and filtering - Retention policy enforcement - Handler management - Singleton pattern """ import asyncio from datetime import datetime, timedelta from unittest.mock import MagicMock, patch import pytest import pytest_asyncio from app.services.safety.audit.logger import ( AuditLogger, get_audit_logger, reset_audit_logger, shutdown_audit_logger, ) from app.services.safety.models import ( ActionMetadata, ActionRequest, ActionType, AuditEventType, AutonomyLevel, SafetyDecision, ) class TestAuditLoggerInit: """Tests for AuditLogger initialization.""" def test_init_default_values(self): """Test initialization with default values.""" logger = AuditLogger() assert logger._flush_interval == 10.0 assert logger._enable_hash_chain is True assert logger._last_hash is None assert logger._running is False def test_init_custom_values(self): """Test initialization with custom values.""" logger = AuditLogger( max_buffer_size=500, flush_interval_seconds=5.0, enable_hash_chain=False, ) assert logger._flush_interval == 5.0 assert logger._enable_hash_chain is False class TestAuditLoggerLifecycle: """Tests for AuditLogger start/stop.""" @pytest.mark.asyncio async def test_start_creates_flush_task(self): """Test that start creates the periodic flush task.""" logger = AuditLogger(flush_interval_seconds=1.0) await logger.start() assert logger._running is True assert logger._flush_task is not None await logger.stop() @pytest.mark.asyncio async def test_start_idempotent(self): """Test that multiple starts don't create multiple tasks.""" logger = AuditLogger() await logger.start() task1 = logger._flush_task await logger.start() # Second start task2 = logger._flush_task assert task1 is task2 await logger.stop() @pytest.mark.asyncio async def test_stop_cancels_task_and_flushes(self): """Test that stop cancels the task and flushes events.""" logger = AuditLogger() await logger.start() # Add an event await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1") await logger.stop() assert logger._running is False # Event should be flushed assert len(logger._persisted) == 1 @pytest.mark.asyncio async def test_stop_without_start(self): """Test stopping without starting doesn't error.""" logger = AuditLogger() await logger.stop() # Should not raise class TestAuditLoggerLog: """Tests for the log method.""" @pytest_asyncio.fixture async def logger(self): """Create a logger instance.""" return AuditLogger(enable_hash_chain=True) @pytest.mark.asyncio async def test_log_creates_event(self, logger): """Test logging creates an event.""" event = await logger.log( AuditEventType.ACTION_REQUESTED, agent_id="agent-1", project_id="proj-1", ) assert event.event_type == AuditEventType.ACTION_REQUESTED assert event.agent_id == "agent-1" assert event.project_id == "proj-1" assert event.id is not None assert event.timestamp is not None @pytest.mark.asyncio async def test_log_adds_hash_chain(self, logger): """Test logging adds hash chain.""" event = await logger.log( AuditEventType.ACTION_REQUESTED, agent_id="agent-1", ) assert "_hash" in event.details assert "_prev_hash" in event.details assert event.details["_prev_hash"] is None # First event @pytest.mark.asyncio async def test_log_chain_links_events(self, logger): """Test hash chain links events.""" event1 = await logger.log(AuditEventType.ACTION_REQUESTED) event2 = await logger.log(AuditEventType.ACTION_EXECUTED) assert event2.details["_prev_hash"] == event1.details["_hash"] @pytest.mark.asyncio async def test_log_without_hash_chain(self): """Test logging without hash chain.""" logger = AuditLogger(enable_hash_chain=False) event = await logger.log(AuditEventType.ACTION_REQUESTED) assert "_hash" not in event.details assert "_prev_hash" not in event.details @pytest.mark.asyncio async def test_log_with_all_fields(self, logger): """Test logging with all optional fields.""" event = await logger.log( AuditEventType.ACTION_EXECUTED, agent_id="agent-1", action_id="action-1", project_id="proj-1", session_id="sess-1", user_id="user-1", decision=SafetyDecision.ALLOW, details={"custom": "data"}, correlation_id="corr-1", ) assert event.agent_id == "agent-1" assert event.action_id == "action-1" assert event.project_id == "proj-1" assert event.session_id == "sess-1" assert event.user_id == "user-1" assert event.decision == SafetyDecision.ALLOW assert event.details["custom"] == "data" assert event.correlation_id == "corr-1" @pytest.mark.asyncio async def test_log_buffers_event(self, logger): """Test logging adds event to buffer.""" await logger.log(AuditEventType.ACTION_REQUESTED) assert len(logger._buffer) == 1 class TestAuditLoggerConvenienceMethods: """Tests for convenience logging methods.""" @pytest_asyncio.fixture async def logger(self): """Create a logger instance.""" return AuditLogger(enable_hash_chain=False) @pytest_asyncio.fixture def action(self): """Create a test action request.""" metadata = ActionMetadata( agent_id="agent-1", session_id="sess-1", project_id="proj-1", autonomy_level=AutonomyLevel.MILESTONE, user_id="user-1", correlation_id="corr-1", ) return ActionRequest( action_type=ActionType.FILE_WRITE, tool_name="file_write", arguments={"path": "/test.txt"}, resource="/test.txt", metadata=metadata, ) @pytest.mark.asyncio async def test_log_action_request_allowed(self, logger, action): """Test logging allowed action request.""" event = await logger.log_action_request( action, SafetyDecision.ALLOW, reasons=["Within budget"], ) assert event.event_type == AuditEventType.ACTION_VALIDATED assert event.decision == SafetyDecision.ALLOW assert event.details["reasons"] == ["Within budget"] @pytest.mark.asyncio async def test_log_action_request_denied(self, logger, action): """Test logging denied action request.""" event = await logger.log_action_request( action, SafetyDecision.DENY, reasons=["Rate limit exceeded"], ) assert event.event_type == AuditEventType.ACTION_DENIED assert event.decision == SafetyDecision.DENY @pytest.mark.asyncio async def test_log_action_executed_success(self, logger, action): """Test logging successful action execution.""" event = await logger.log_action_executed( action, success=True, execution_time_ms=50.0, ) assert event.event_type == AuditEventType.ACTION_EXECUTED assert event.details["success"] is True assert event.details["execution_time_ms"] == 50.0 assert event.details["error"] is None @pytest.mark.asyncio async def test_log_action_executed_failure(self, logger, action): """Test logging failed action execution.""" event = await logger.log_action_executed( action, success=False, execution_time_ms=100.0, error="File not found", ) assert event.event_type == AuditEventType.ACTION_FAILED assert event.details["success"] is False assert event.details["error"] == "File not found" @pytest.mark.asyncio async def test_log_approval_event(self, logger, action): """Test logging approval event.""" event = await logger.log_approval_event( AuditEventType.APPROVAL_GRANTED, approval_id="approval-1", action=action, decided_by="admin", reason="Approved by admin", ) assert event.event_type == AuditEventType.APPROVAL_GRANTED assert event.details["approval_id"] == "approval-1" assert event.details["decided_by"] == "admin" @pytest.mark.asyncio async def test_log_budget_event(self, logger): """Test logging budget event.""" event = await logger.log_budget_event( AuditEventType.BUDGET_WARNING, agent_id="agent-1", scope="daily", current_usage=8000.0, limit=10000.0, unit="tokens", ) assert event.event_type == AuditEventType.BUDGET_WARNING assert event.details["scope"] == "daily" assert event.details["usage_percent"] == 80.0 @pytest.mark.asyncio async def test_log_budget_event_zero_limit(self, logger): """Test logging budget event with zero limit.""" event = await logger.log_budget_event( AuditEventType.BUDGET_WARNING, agent_id="agent-1", scope="daily", current_usage=100.0, limit=0.0, # Zero limit ) assert event.details["usage_percent"] == 0 @pytest.mark.asyncio async def test_log_emergency_stop(self, logger): """Test logging emergency stop.""" event = await logger.log_emergency_stop( stop_type="global", triggered_by="admin", reason="Security incident", affected_agents=["agent-1", "agent-2"], ) assert event.event_type == AuditEventType.EMERGENCY_STOP assert event.details["stop_type"] == "global" assert event.details["affected_agents"] == ["agent-1", "agent-2"] class TestAuditLoggerFlush: """Tests for flush functionality.""" @pytest.mark.asyncio async def test_flush_persists_events(self): """Test flush moves events to persisted storage.""" logger = AuditLogger(enable_hash_chain=False) await logger.log(AuditEventType.ACTION_REQUESTED) await logger.log(AuditEventType.ACTION_EXECUTED) assert len(logger._buffer) == 2 assert len(logger._persisted) == 0 count = await logger.flush() assert count == 2 assert len(logger._buffer) == 0 assert len(logger._persisted) == 2 @pytest.mark.asyncio async def test_flush_empty_buffer(self): """Test flush with empty buffer.""" logger = AuditLogger() count = await logger.flush() assert count == 0 class TestAuditLoggerQuery: """Tests for query functionality.""" @pytest_asyncio.fixture async def logger_with_events(self): """Create a logger with some test events.""" logger = AuditLogger(enable_hash_chain=False) # Add various events await logger.log( AuditEventType.ACTION_REQUESTED, agent_id="agent-1", project_id="proj-1", ) await logger.log( AuditEventType.ACTION_EXECUTED, agent_id="agent-1", project_id="proj-1", ) await logger.log( AuditEventType.ACTION_DENIED, agent_id="agent-2", project_id="proj-2", ) await logger.log( AuditEventType.BUDGET_WARNING, agent_id="agent-1", project_id="proj-1", user_id="user-1", ) return logger @pytest.mark.asyncio async def test_query_all(self, logger_with_events): """Test querying all events.""" events = await logger_with_events.query() assert len(events) == 4 @pytest.mark.asyncio async def test_query_by_event_type(self, logger_with_events): """Test filtering by event type.""" events = await logger_with_events.query( event_types=[AuditEventType.ACTION_REQUESTED] ) assert len(events) == 1 assert events[0].event_type == AuditEventType.ACTION_REQUESTED @pytest.mark.asyncio async def test_query_by_agent_id(self, logger_with_events): """Test filtering by agent ID.""" events = await logger_with_events.query(agent_id="agent-1") assert len(events) == 3 assert all(e.agent_id == "agent-1" for e in events) @pytest.mark.asyncio async def test_query_by_project_id(self, logger_with_events): """Test filtering by project ID.""" events = await logger_with_events.query(project_id="proj-2") assert len(events) == 1 @pytest.mark.asyncio async def test_query_by_user_id(self, logger_with_events): """Test filtering by user ID.""" events = await logger_with_events.query(user_id="user-1") assert len(events) == 1 assert events[0].event_type == AuditEventType.BUDGET_WARNING @pytest.mark.asyncio async def test_query_with_limit(self, logger_with_events): """Test query with limit.""" events = await logger_with_events.query(limit=2) assert len(events) == 2 @pytest.mark.asyncio async def test_query_with_offset(self, logger_with_events): """Test query with offset.""" all_events = await logger_with_events.query() offset_events = await logger_with_events.query(offset=2) assert len(offset_events) == 2 assert offset_events[0] == all_events[2] @pytest.mark.asyncio async def test_query_by_time_range(self): """Test filtering by time range.""" logger = AuditLogger(enable_hash_chain=False) now = datetime.utcnow() await logger.log(AuditEventType.ACTION_REQUESTED) # Query with start time events = await logger.query( start_time=now - timedelta(seconds=1), end_time=now + timedelta(seconds=1), ) assert len(events) == 1 @pytest.mark.asyncio async def test_query_by_correlation_id(self): """Test filtering by correlation ID.""" logger = AuditLogger(enable_hash_chain=False) await logger.log( AuditEventType.ACTION_REQUESTED, correlation_id="corr-123", ) await logger.log( AuditEventType.ACTION_EXECUTED, correlation_id="corr-456", ) events = await logger.query(correlation_id="corr-123") assert len(events) == 1 assert events[0].correlation_id == "corr-123" @pytest.mark.asyncio async def test_query_combined_filters(self, logger_with_events): """Test combined filters.""" events = await logger_with_events.query( agent_id="agent-1", project_id="proj-1", event_types=[ AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED, ], ) assert len(events) == 2 @pytest.mark.asyncio async def test_get_action_history(self, logger_with_events): """Test get_action_history method.""" events = await logger_with_events.get_action_history("agent-1") # Should only return action-related events assert len(events) == 2 assert all( e.event_type in {AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED} for e in events ) class TestAuditLoggerIntegrity: """Tests for hash chain integrity verification.""" @pytest.mark.asyncio async def test_verify_integrity_valid(self): """Test integrity verification with valid chain.""" logger = AuditLogger(enable_hash_chain=True) await logger.log(AuditEventType.ACTION_REQUESTED) await logger.log(AuditEventType.ACTION_EXECUTED) is_valid, issues = await logger.verify_integrity() assert is_valid is True assert len(issues) == 0 @pytest.mark.asyncio async def test_verify_integrity_disabled(self): """Test integrity verification when hash chain disabled.""" logger = AuditLogger(enable_hash_chain=False) await logger.log(AuditEventType.ACTION_REQUESTED) is_valid, issues = await logger.verify_integrity() assert is_valid is True assert len(issues) == 0 @pytest.mark.asyncio async def test_verify_integrity_broken_chain(self): """Test integrity verification with broken chain.""" logger = AuditLogger(enable_hash_chain=True) event1 = await logger.log(AuditEventType.ACTION_REQUESTED) await logger.log(AuditEventType.ACTION_EXECUTED) # Tamper with first event's hash event1.details["_hash"] = "tampered_hash" is_valid, issues = await logger.verify_integrity() assert is_valid is False assert len(issues) > 0 class TestAuditLoggerHandlers: """Tests for event handler management.""" @pytest.mark.asyncio async def test_add_sync_handler(self): """Test adding synchronous handler.""" logger = AuditLogger(enable_hash_chain=False) events_received = [] def handler(event): events_received.append(event) logger.add_handler(handler) await logger.log(AuditEventType.ACTION_REQUESTED) assert len(events_received) == 1 @pytest.mark.asyncio async def test_add_async_handler(self): """Test adding async handler.""" logger = AuditLogger(enable_hash_chain=False) events_received = [] async def handler(event): events_received.append(event) logger.add_handler(handler) await logger.log(AuditEventType.ACTION_REQUESTED) assert len(events_received) == 1 @pytest.mark.asyncio async def test_remove_handler(self): """Test removing handler.""" logger = AuditLogger(enable_hash_chain=False) events_received = [] def handler(event): events_received.append(event) logger.add_handler(handler) await logger.log(AuditEventType.ACTION_REQUESTED) logger.remove_handler(handler) await logger.log(AuditEventType.ACTION_EXECUTED) assert len(events_received) == 1 @pytest.mark.asyncio async def test_handler_error_caught(self): """Test that handler errors are caught.""" logger = AuditLogger(enable_hash_chain=False) def failing_handler(event): raise ValueError("Handler error") logger.add_handler(failing_handler) # Should not raise event = await logger.log(AuditEventType.ACTION_REQUESTED) assert event is not None class TestAuditLoggerSanitization: """Tests for sensitive data sanitization.""" @pytest.mark.asyncio async def test_sanitize_sensitive_keys(self): """Test sanitization of sensitive keys.""" with patch("app.services.safety.audit.logger.get_safety_config") as mock_config: mock_cfg = MagicMock() mock_cfg.audit_retention_days = 30 mock_cfg.audit_include_sensitive = False mock_config.return_value = mock_cfg logger = AuditLogger(enable_hash_chain=False) event = await logger.log( AuditEventType.ACTION_EXECUTED, details={ "password": "secret123", "api_key": "key123", "token": "token123", "normal_field": "visible", }, ) assert event.details["password"] == "[REDACTED]" assert event.details["api_key"] == "[REDACTED]" assert event.details["token"] == "[REDACTED]" assert event.details["normal_field"] == "visible" @pytest.mark.asyncio async def test_sanitize_nested_dict(self): """Test sanitization of nested dictionaries.""" with patch("app.services.safety.audit.logger.get_safety_config") as mock_config: mock_cfg = MagicMock() mock_cfg.audit_retention_days = 30 mock_cfg.audit_include_sensitive = False mock_config.return_value = mock_cfg logger = AuditLogger(enable_hash_chain=False) event = await logger.log( AuditEventType.ACTION_EXECUTED, details={ "config": { "api_secret": "secret", "name": "test", } }, ) assert event.details["config"]["api_secret"] == "[REDACTED]" assert event.details["config"]["name"] == "test" @pytest.mark.asyncio async def test_include_sensitive_when_enabled(self): """Test sensitive data included when enabled.""" with patch("app.services.safety.audit.logger.get_safety_config") as mock_config: mock_cfg = MagicMock() mock_cfg.audit_retention_days = 30 mock_cfg.audit_include_sensitive = True mock_config.return_value = mock_cfg logger = AuditLogger(enable_hash_chain=False) event = await logger.log( AuditEventType.ACTION_EXECUTED, details={"password": "secret123"}, ) assert event.details["password"] == "secret123" class TestAuditLoggerRetention: """Tests for retention policy enforcement.""" @pytest.mark.asyncio async def test_retention_removes_old_events(self): """Test that retention removes old events.""" with patch("app.services.safety.audit.logger.get_safety_config") as mock_config: mock_cfg = MagicMock() mock_cfg.audit_retention_days = 7 mock_cfg.audit_include_sensitive = False mock_config.return_value = mock_cfg logger = AuditLogger(enable_hash_chain=False) # Add an old event directly to persisted from app.services.safety.models import AuditEvent old_event = AuditEvent( id="old-event", event_type=AuditEventType.ACTION_REQUESTED, timestamp=datetime.utcnow() - timedelta(days=10), details={}, ) logger._persisted.append(old_event) # Add a recent event await logger.log(AuditEventType.ACTION_EXECUTED) # Flush will trigger retention enforcement await logger.flush() # Old event should be removed assert len(logger._persisted) == 1 assert logger._persisted[0].id != "old-event" @pytest.mark.asyncio async def test_retention_keeps_recent_events(self): """Test that retention keeps recent events.""" with patch("app.services.safety.audit.logger.get_safety_config") as mock_config: mock_cfg = MagicMock() mock_cfg.audit_retention_days = 7 mock_cfg.audit_include_sensitive = False mock_config.return_value = mock_cfg logger = AuditLogger(enable_hash_chain=False) await logger.log(AuditEventType.ACTION_REQUESTED) await logger.log(AuditEventType.ACTION_EXECUTED) await logger.flush() assert len(logger._persisted) == 2 class TestAuditLoggerSingleton: """Tests for singleton pattern.""" @pytest.mark.asyncio async def test_get_audit_logger_creates_instance(self): """Test get_audit_logger creates singleton.""" reset_audit_logger() logger1 = await get_audit_logger() logger2 = await get_audit_logger() assert logger1 is logger2 await shutdown_audit_logger() @pytest.mark.asyncio async def test_shutdown_audit_logger(self): """Test shutdown_audit_logger stops and clears singleton.""" import app.services.safety.audit.logger as audit_module reset_audit_logger() _logger = await get_audit_logger() await shutdown_audit_logger() assert audit_module._audit_logger is None def test_reset_audit_logger(self): """Test reset_audit_logger clears singleton.""" import app.services.safety.audit.logger as audit_module audit_module._audit_logger = AuditLogger() reset_audit_logger() assert audit_module._audit_logger is None class TestAuditLoggerPeriodicFlush: """Tests for periodic flush background task.""" @pytest.mark.asyncio async def test_periodic_flush_runs(self): """Test periodic flush runs and flushes events.""" logger = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False) await logger.start() # Log an event await logger.log(AuditEventType.ACTION_REQUESTED) assert len(logger._buffer) == 1 # Wait for periodic flush await asyncio.sleep(0.15) # Event should be flushed assert len(logger._buffer) == 0 assert len(logger._persisted) == 1 await logger.stop() @pytest.mark.asyncio async def test_periodic_flush_handles_errors(self): """Test periodic flush handles errors gracefully.""" logger = AuditLogger(flush_interval_seconds=0.1) await logger.start() # Mock flush to raise an error original_flush = logger.flush async def failing_flush(): raise Exception("Flush error") logger.flush = failing_flush # Wait for flush attempt await asyncio.sleep(0.15) # Should still be running assert logger._running is True logger.flush = original_flush await logger.stop() class TestAuditLoggerLogging: """Tests for standard logger output.""" @pytest.mark.asyncio async def test_log_warning_for_denied(self): """Test warning level for denied events.""" with patch("app.services.safety.audit.logger.logger") as mock_logger: audit_logger = AuditLogger(enable_hash_chain=False) await audit_logger.log( AuditEventType.ACTION_DENIED, agent_id="agent-1", ) mock_logger.warning.assert_called() @pytest.mark.asyncio async def test_log_error_for_failed(self): """Test error level for failed events.""" with patch("app.services.safety.audit.logger.logger") as mock_logger: audit_logger = AuditLogger(enable_hash_chain=False) await audit_logger.log( AuditEventType.ACTION_FAILED, agent_id="agent-1", ) mock_logger.error.assert_called() @pytest.mark.asyncio async def test_log_info_for_normal(self): """Test info level for normal events.""" with patch("app.services.safety.audit.logger.logger") as mock_logger: audit_logger = AuditLogger(enable_hash_chain=False) await audit_logger.log( AuditEventType.ACTION_EXECUTED, agent_id="agent-1", ) mock_logger.info.assert_called() class TestAuditLoggerEdgeCases: """Tests for edge cases.""" @pytest.mark.asyncio async def test_log_with_none_details(self): """Test logging with None details.""" logger = AuditLogger(enable_hash_chain=False) event = await logger.log( AuditEventType.ACTION_REQUESTED, details=None, ) assert event.details == {} @pytest.mark.asyncio async def test_query_with_action_id(self): """Test querying by action ID.""" logger = AuditLogger(enable_hash_chain=False) await logger.log( AuditEventType.ACTION_REQUESTED, action_id="action-1", ) await logger.log( AuditEventType.ACTION_EXECUTED, action_id="action-2", ) events = await logger.query(action_id="action-1") assert len(events) == 1 assert events[0].action_id == "action-1" @pytest.mark.asyncio async def test_query_with_session_id(self): """Test querying by session ID.""" logger = AuditLogger(enable_hash_chain=False) await logger.log( AuditEventType.ACTION_REQUESTED, session_id="sess-1", ) await logger.log( AuditEventType.ACTION_EXECUTED, session_id="sess-2", ) events = await logger.query(session_id="sess-1") assert len(events) == 1 @pytest.mark.asyncio async def test_query_includes_buffer_and_persisted(self): """Test query includes both buffer and persisted events.""" logger = AuditLogger(enable_hash_chain=False) # Add event to buffer await logger.log(AuditEventType.ACTION_REQUESTED) # Flush to persisted await logger.flush() # Add another to buffer await logger.log(AuditEventType.ACTION_EXECUTED) # Query should return both events = await logger.query() assert len(events) == 2 @pytest.mark.asyncio async def test_remove_nonexistent_handler(self): """Test removing handler that doesn't exist.""" logger = AuditLogger() def handler(event): pass # Should not raise logger.remove_handler(handler) @pytest.mark.asyncio async def test_query_time_filter_excludes_events(self): """Test time filters exclude events correctly.""" logger = AuditLogger(enable_hash_chain=False) await logger.log(AuditEventType.ACTION_REQUESTED) # Query with future start time future = datetime.utcnow() + timedelta(hours=1) events = await logger.query(start_time=future) assert len(events) == 0 @pytest.mark.asyncio async def test_query_end_time_filter(self): """Test end time filter.""" logger = AuditLogger(enable_hash_chain=False) await logger.log(AuditEventType.ACTION_REQUESTED) # Query with past end time past = datetime.utcnow() - timedelta(hours=1) events = await logger.query(end_time=past) assert len(events) == 0