Files
syndarix/backend/tests/services/safety/test_audit.py
Felipe Cardoso 60ebeaa582 test(safety): add comprehensive tests for safety framework modules
Add tests to improve backend coverage from 85% to 93%:

- test_audit.py: 60 tests for AuditLogger (20% -> 99%)
  - Hash chain integrity, sanitization, retention, handlers
  - Fixed bug: hash chain modification after event creation
  - Fixed bug: verification not using correct prev_hash

- test_hitl.py: Tests for HITL manager (0% -> 100%)
- test_permissions.py: Tests for permissions manager (0% -> 99%)
- test_rollback.py: Tests for rollback manager (0% -> 100%)
- test_metrics.py: Tests for metrics collector (0% -> 100%)
- test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%)
- test_validation.py: Additional cache and edge case tests (76% -> 100%)
- test_scoring.py: Lock cleanup and edge case tests (78% -> 91%)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 19:41:54 +01:00

990 lines
30 KiB
Python

"""
Tests for Audit Logger.
Tests cover:
- AuditLogger initialization and lifecycle
- Event logging and hash chain
- Query and filtering
- Retention policy enforcement
- Handler management
- Singleton pattern
"""
import asyncio
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
import pytest_asyncio
from app.services.safety.audit.logger import (
AuditLogger,
get_audit_logger,
reset_audit_logger,
shutdown_audit_logger,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
AuditEventType,
AutonomyLevel,
SafetyDecision,
)
class TestAuditLoggerInit:
    """Construction-time state of a freshly created AuditLogger."""

    def test_init_default_values(self):
        """A no-argument logger uses a 10s interval, chains hashes, and is idle."""
        fresh = AuditLogger()

        assert fresh._flush_interval == 10.0
        assert fresh._enable_hash_chain is True
        # Nothing has been hashed yet, so there is no previous chain link.
        assert fresh._last_hash is None
        # The logger is not running until start() is called.
        assert fresh._running is False

    def test_init_custom_values(self):
        """Keyword overrides are reflected in the logger's private settings."""
        overrides = {
            "max_buffer_size": 500,
            "flush_interval_seconds": 5.0,
            "enable_hash_chain": False,
        }
        tuned = AuditLogger(**overrides)

        assert tuned._flush_interval == 5.0
        assert tuned._enable_hash_chain is False
class TestAuditLoggerLifecycle:
    """Tests for AuditLogger start/stop.

    Every test that starts the logger stops it in a ``finally`` block, so a
    failing assertion cannot leave the periodic flush task running into
    later tests.
    """

    @pytest.mark.asyncio
    async def test_start_creates_flush_task(self):
        """start() marks the logger running and creates the periodic flush task."""
        logger = AuditLogger(flush_interval_seconds=1.0)
        await logger.start()
        try:
            assert logger._running is True
            assert logger._flush_task is not None
        finally:
            await logger.stop()

    @pytest.mark.asyncio
    async def test_start_idempotent(self):
        """A second start() reuses the existing flush task instead of adding one."""
        logger = AuditLogger()
        await logger.start()
        try:
            task1 = logger._flush_task
            await logger.start()  # Second start
            task2 = logger._flush_task
            assert task1 is task2
        finally:
            await logger.stop()

    @pytest.mark.asyncio
    async def test_stop_cancels_task_and_flushes(self):
        """stop() clears the running flag and flushes buffered events."""
        logger = AuditLogger()
        await logger.start()
        try:
            await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")
        finally:
            await logger.stop()
        assert logger._running is False
        # The buffered event must have been moved to persisted storage by stop().
        assert len(logger._persisted) == 1

    @pytest.mark.asyncio
    async def test_stop_without_start(self):
        """Stopping a logger that was never started does not raise."""
        logger = AuditLogger()
        await logger.stop()  # Should not raise
class TestAuditLoggerLog:
    """Behavior of AuditLogger.log()."""

    @pytest_asyncio.fixture
    async def logger(self):
        """Hash-chaining logger shared by the tests in this class."""
        chained = AuditLogger(enable_hash_chain=True)
        return chained

    @pytest.mark.asyncio
    async def test_log_creates_event(self, logger):
        """log() returns an event carrying the given type and identifiers."""
        created = await logger.log(
            AuditEventType.ACTION_REQUESTED,
            agent_id="agent-1",
            project_id="proj-1",
        )

        assert created.event_type == AuditEventType.ACTION_REQUESTED
        assert (created.agent_id, created.project_id) == ("agent-1", "proj-1")
        # Identity and time are assigned by the logger itself.
        assert created.id is not None
        assert created.timestamp is not None

    @pytest.mark.asyncio
    async def test_log_adds_hash_chain(self, logger):
        """With chaining on, details gain _hash and _prev_hash entries."""
        first = await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")
        chain_fields = first.details

        assert "_hash" in chain_fields
        assert "_prev_hash" in chain_fields
        # Nothing precedes the very first event, so its back-link is empty.
        assert chain_fields["_prev_hash"] is None

    @pytest.mark.asyncio
    async def test_log_chain_links_events(self, logger):
        """Each event's _prev_hash equals the preceding event's _hash."""
        earlier = await logger.log(AuditEventType.ACTION_REQUESTED)
        later = await logger.log(AuditEventType.ACTION_EXECUTED)

        assert later.details["_prev_hash"] == earlier.details["_hash"]

    @pytest.mark.asyncio
    async def test_log_without_hash_chain(self):
        """With chaining off, no chain fields are written into details."""
        unchained = AuditLogger(enable_hash_chain=False)
        entry = await unchained.log(AuditEventType.ACTION_REQUESTED)

        for chain_key in ("_hash", "_prev_hash"):
            assert chain_key not in entry.details

    @pytest.mark.asyncio
    async def test_log_with_all_fields(self, logger):
        """Every optional keyword is carried through onto the event."""
        identifiers = {
            "agent_id": "agent-1",
            "action_id": "action-1",
            "project_id": "proj-1",
            "session_id": "sess-1",
            "user_id": "user-1",
            "correlation_id": "corr-1",
        }
        entry = await logger.log(
            AuditEventType.ACTION_EXECUTED,
            decision=SafetyDecision.ALLOW,
            details={"custom": "data"},
            **identifiers,
        )

        for field_name, expected in identifiers.items():
            assert getattr(entry, field_name) == expected
        assert entry.decision == SafetyDecision.ALLOW
        assert entry.details["custom"] == "data"

    @pytest.mark.asyncio
    async def test_log_buffers_event(self, logger):
        """A logged event lands in the in-memory buffer."""
        await logger.log(AuditEventType.ACTION_REQUESTED)
        buffered = logger._buffer
        assert len(buffered) == 1
class TestAuditLoggerConvenienceMethods:
    """Tests for the convenience wrappers built on top of ``log()``."""

    @pytest_asyncio.fixture
    async def logger(self):
        """Create a logger with hash chaining off (no _hash/_prev_hash in details)."""
        return AuditLogger(enable_hash_chain=False)

    @pytest.fixture
    def action(self):
        """Create a FILE_WRITE action request with fully populated metadata.

        This fixture is synchronous, so it uses plain ``pytest.fixture``;
        ``pytest_asyncio.fixture`` is meant for coroutine fixtures.
        """
        metadata = ActionMetadata(
            agent_id="agent-1",
            session_id="sess-1",
            project_id="proj-1",
            autonomy_level=AutonomyLevel.MILESTONE,
            user_id="user-1",
            correlation_id="corr-1",
        )
        return ActionRequest(
            action_type=ActionType.FILE_WRITE,
            tool_name="file_write",
            arguments={"path": "/test.txt"},
            resource="/test.txt",
            metadata=metadata,
        )

    @pytest.mark.asyncio
    async def test_log_action_request_allowed(self, logger, action):
        """An ALLOW decision is logged as ACTION_VALIDATED with its reasons."""
        event = await logger.log_action_request(
            action,
            SafetyDecision.ALLOW,
            reasons=["Within budget"],
        )
        assert event.event_type == AuditEventType.ACTION_VALIDATED
        assert event.decision == SafetyDecision.ALLOW
        assert event.details["reasons"] == ["Within budget"]

    @pytest.mark.asyncio
    async def test_log_action_request_denied(self, logger, action):
        """A DENY decision is logged as ACTION_DENIED with its reasons."""
        event = await logger.log_action_request(
            action,
            SafetyDecision.DENY,
            reasons=["Rate limit exceeded"],
        )
        assert event.event_type == AuditEventType.ACTION_DENIED
        assert event.decision == SafetyDecision.DENY
        # Denial reasons must be recorded just like the allowed case.
        assert event.details["reasons"] == ["Rate limit exceeded"]

    @pytest.mark.asyncio
    async def test_log_action_executed_success(self, logger, action):
        """A successful execution is ACTION_EXECUTED with no error recorded."""
        event = await logger.log_action_executed(
            action,
            success=True,
            execution_time_ms=50.0,
        )
        assert event.event_type == AuditEventType.ACTION_EXECUTED
        assert event.details["success"] is True
        assert event.details["execution_time_ms"] == 50.0
        assert event.details["error"] is None

    @pytest.mark.asyncio
    async def test_log_action_executed_failure(self, logger, action):
        """A failed execution is ACTION_FAILED and carries the error text."""
        event = await logger.log_action_executed(
            action,
            success=False,
            execution_time_ms=100.0,
            error="File not found",
        )
        assert event.event_type == AuditEventType.ACTION_FAILED
        assert event.details["success"] is False
        assert event.details["error"] == "File not found"

    @pytest.mark.asyncio
    async def test_log_approval_event(self, logger, action):
        """Approval events record the approval id and who decided."""
        event = await logger.log_approval_event(
            AuditEventType.APPROVAL_GRANTED,
            approval_id="approval-1",
            action=action,
            decided_by="admin",
            reason="Approved by admin",
        )
        assert event.event_type == AuditEventType.APPROVAL_GRANTED
        assert event.details["approval_id"] == "approval-1"
        assert event.details["decided_by"] == "admin"

    @pytest.mark.asyncio
    async def test_log_budget_event(self, logger):
        """Budget events include scope and the computed usage percentage."""
        event = await logger.log_budget_event(
            AuditEventType.BUDGET_WARNING,
            agent_id="agent-1",
            scope="daily",
            current_usage=8000.0,
            limit=10000.0,
            unit="tokens",
        )
        assert event.event_type == AuditEventType.BUDGET_WARNING
        assert event.details["scope"] == "daily"
        assert event.details["usage_percent"] == 80.0

    @pytest.mark.asyncio
    async def test_log_budget_event_zero_limit(self, logger):
        """A zero limit yields 0% usage instead of dividing by zero."""
        event = await logger.log_budget_event(
            AuditEventType.BUDGET_WARNING,
            agent_id="agent-1",
            scope="daily",
            current_usage=100.0,
            limit=0.0,  # Zero limit
        )
        assert event.details["usage_percent"] == 0

    @pytest.mark.asyncio
    async def test_log_emergency_stop(self, logger):
        """Emergency stops record the stop type and affected agents."""
        event = await logger.log_emergency_stop(
            stop_type="global",
            triggered_by="admin",
            reason="Security incident",
            affected_agents=["agent-1", "agent-2"],
        )
        assert event.event_type == AuditEventType.EMERGENCY_STOP
        assert event.details["stop_type"] == "global"
        assert event.details["affected_agents"] == ["agent-1", "agent-2"]
class TestAuditLoggerFlush:
    """flush() moves buffered events into persisted storage."""

    @pytest.mark.asyncio
    async def test_flush_persists_events(self):
        """Both buffered events are persisted and the buffer is drained."""
        audit = AuditLogger(enable_hash_chain=False)
        for kind in (AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED):
            await audit.log(kind)

        # Before flushing, everything sits in the in-memory buffer.
        assert (len(audit._buffer), len(audit._persisted)) == (2, 0)

        flushed = await audit.flush()

        assert flushed == 2
        assert (len(audit._buffer), len(audit._persisted)) == (0, 2)

    @pytest.mark.asyncio
    async def test_flush_empty_buffer(self):
        """Flushing with nothing buffered reports zero events."""
        audit = AuditLogger()
        assert await audit.flush() == 0
class TestAuditLoggerQuery:
    """Filtering, pagination and history lookups via query()."""

    @pytest_asyncio.fixture
    async def logger_with_events(self):
        """Logger seeded with four events spread over two agents/projects."""
        seeded = AuditLogger(enable_hash_chain=False)
        seed_events = [
            (
                AuditEventType.ACTION_REQUESTED,
                {"agent_id": "agent-1", "project_id": "proj-1"},
            ),
            (
                AuditEventType.ACTION_EXECUTED,
                {"agent_id": "agent-1", "project_id": "proj-1"},
            ),
            (
                AuditEventType.ACTION_DENIED,
                {"agent_id": "agent-2", "project_id": "proj-2"},
            ),
            (
                AuditEventType.BUDGET_WARNING,
                {"agent_id": "agent-1", "project_id": "proj-1", "user_id": "user-1"},
            ),
        ]
        for event_type, fields in seed_events:
            await seeded.log(event_type, **fields)
        return seeded

    @pytest.mark.asyncio
    async def test_query_all(self, logger_with_events):
        """An unfiltered query returns every seeded event."""
        assert len(await logger_with_events.query()) == 4

    @pytest.mark.asyncio
    async def test_query_by_event_type(self, logger_with_events):
        """event_types restricts results to the listed types."""
        wanted = AuditEventType.ACTION_REQUESTED
        matches = await logger_with_events.query(event_types=[wanted])
        assert len(matches) == 1
        assert matches[0].event_type == wanted

    @pytest.mark.asyncio
    async def test_query_by_agent_id(self, logger_with_events):
        """agent_id keeps only that agent's events."""
        matches = await logger_with_events.query(agent_id="agent-1")
        assert len(matches) == 3
        assert {m.agent_id for m in matches} == {"agent-1"}

    @pytest.mark.asyncio
    async def test_query_by_project_id(self, logger_with_events):
        """project_id keeps only that project's events."""
        matches = await logger_with_events.query(project_id="proj-2")
        assert len(matches) == 1

    @pytest.mark.asyncio
    async def test_query_by_user_id(self, logger_with_events):
        """user_id selects the single event that carries a user."""
        matches = await logger_with_events.query(user_id="user-1")
        assert len(matches) == 1
        assert matches[0].event_type == AuditEventType.BUDGET_WARNING

    @pytest.mark.asyncio
    async def test_query_with_limit(self, logger_with_events):
        """limit caps the number of returned events."""
        page = await logger_with_events.query(limit=2)
        assert len(page) == 2

    @pytest.mark.asyncio
    async def test_query_with_offset(self, logger_with_events):
        """offset skips the leading events of the unfiltered ordering."""
        everything = await logger_with_events.query()
        tail = await logger_with_events.query(offset=2)
        assert len(tail) == 2
        assert tail[0] == everything[2]

    @pytest.mark.asyncio
    async def test_query_by_time_range(self):
        """An event logged inside [start, end] is returned."""
        audit = AuditLogger(enable_hash_chain=False)
        reference = datetime.utcnow()
        await audit.log(AuditEventType.ACTION_REQUESTED)

        slack = timedelta(seconds=1)
        hits = await audit.query(
            start_time=reference - slack,
            end_time=reference + slack,
        )
        assert len(hits) == 1

    @pytest.mark.asyncio
    async def test_query_by_correlation_id(self):
        """correlation_id keeps only events with that correlation."""
        audit = AuditLogger(enable_hash_chain=False)
        for event_type, corr in (
            (AuditEventType.ACTION_REQUESTED, "corr-123"),
            (AuditEventType.ACTION_EXECUTED, "corr-456"),
        ):
            await audit.log(event_type, correlation_id=corr)

        hits = await audit.query(correlation_id="corr-123")
        assert len(hits) == 1
        assert hits[0].correlation_id == "corr-123"

    @pytest.mark.asyncio
    async def test_query_combined_filters(self, logger_with_events):
        """Multiple filters are combined conjunctively."""
        action_types = [
            AuditEventType.ACTION_REQUESTED,
            AuditEventType.ACTION_EXECUTED,
        ]
        hits = await logger_with_events.query(
            agent_id="agent-1",
            project_id="proj-1",
            event_types=action_types,
        )
        assert len(hits) == 2

    @pytest.mark.asyncio
    async def test_get_action_history(self, logger_with_events):
        """Action history omits non-action events such as budget warnings."""
        history = await logger_with_events.get_action_history("agent-1")
        action_kinds = {
            AuditEventType.ACTION_REQUESTED,
            AuditEventType.ACTION_EXECUTED,
        }
        assert len(history) == 2
        assert all(entry.event_type in action_kinds for entry in history)
class TestAuditLoggerIntegrity:
    """verify_integrity() over the event hash chain."""

    @staticmethod
    async def _log_pair(audit):
        """Log a requested/executed pair and return the first event."""
        head = await audit.log(AuditEventType.ACTION_REQUESTED)
        await audit.log(AuditEventType.ACTION_EXECUTED)
        return head

    @pytest.mark.asyncio
    async def test_verify_integrity_valid(self):
        """An untouched chain verifies cleanly."""
        audit = AuditLogger(enable_hash_chain=True)
        await self._log_pair(audit)

        ok, problems = await audit.verify_integrity()
        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_disabled(self):
        """With chaining disabled, verification reports a clean result."""
        audit = AuditLogger(enable_hash_chain=False)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        ok, problems = await audit.verify_integrity()
        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_broken_chain(self):
        """Tampering with a stored hash is detected."""
        audit = AuditLogger(enable_hash_chain=True)
        head = await self._log_pair(audit)

        # Corrupt the first event's stored hash in place; the chain linking it
        # to the second event should no longer verify.
        head.details["_hash"] = "tampered_hash"

        ok, problems = await audit.verify_integrity()
        assert ok is False
        assert len(problems) > 0
class TestAuditLoggerHandlers:
    """Registration, removal and fault isolation of event handlers."""

    @pytest.mark.asyncio
    async def test_add_sync_handler(self):
        """A plain function handler receives each logged event."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        def record(event):
            seen.append(event)

        audit.add_handler(record)
        await audit.log(AuditEventType.ACTION_REQUESTED)
        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_add_async_handler(self):
        """A coroutine handler receives each logged event."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        async def record(event):
            seen.append(event)

        audit.add_handler(record)
        await audit.log(AuditEventType.ACTION_REQUESTED)
        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_remove_handler(self):
        """A removed handler no longer receives events."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        def record(event):
            seen.append(event)

        audit.add_handler(record)
        await audit.log(AuditEventType.ACTION_REQUESTED)  # delivered
        audit.remove_handler(record)
        await audit.log(AuditEventType.ACTION_EXECUTED)  # not delivered
        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_handler_error_caught(self):
        """An exception inside a handler does not escape log()."""
        audit = AuditLogger(enable_hash_chain=False)

        def explode(event):
            raise ValueError("Handler error")

        audit.add_handler(explode)
        result = await audit.log(AuditEventType.ACTION_REQUESTED)
        assert result is not None
class TestAuditLoggerSanitization:
    """Redaction of sensitive keys in event details."""

    @staticmethod
    def _config_patch(*, include_sensitive):
        """Patch the logger's config lookup with a 30-day retention config."""
        settings = MagicMock()
        settings.audit_retention_days = 30
        settings.audit_include_sensitive = include_sensitive
        return patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        )

    @pytest.mark.asyncio
    async def test_sanitize_sensitive_keys(self):
        """Top-level secret-looking keys are replaced with [REDACTED]."""
        with self._config_patch(include_sensitive=False):
            audit = AuditLogger(enable_hash_chain=False)
            recorded = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={
                    "password": "secret123",
                    "api_key": "key123",
                    "token": "token123",
                    "normal_field": "visible",
                },
            )
            for secret_key in ("password", "api_key", "token"):
                assert recorded.details[secret_key] == "[REDACTED]"
            assert recorded.details["normal_field"] == "visible"

    @pytest.mark.asyncio
    async def test_sanitize_nested_dict(self):
        """Redaction also applies to keys inside nested dictionaries."""
        with self._config_patch(include_sensitive=False):
            audit = AuditLogger(enable_hash_chain=False)
            recorded = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={"config": {"api_secret": "secret", "name": "test"}},
            )
            nested = recorded.details["config"]
            assert nested["api_secret"] == "[REDACTED]"
            assert nested["name"] == "test"

    @pytest.mark.asyncio
    async def test_include_sensitive_when_enabled(self):
        """With audit_include_sensitive on, secrets are kept verbatim."""
        with self._config_patch(include_sensitive=True):
            audit = AuditLogger(enable_hash_chain=False)
            recorded = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={"password": "secret123"},
            )
            assert recorded.details["password"] == "secret123"
class TestAuditLoggerRetention:
    """Retention-window pruning of persisted events."""

    @staticmethod
    def _seven_day_retention():
        """Patch the logger's config lookup with a 7-day retention window."""
        settings = MagicMock()
        settings.audit_retention_days = 7
        settings.audit_include_sensitive = False
        return patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        )

    @pytest.mark.asyncio
    async def test_retention_removes_old_events(self):
        """Events older than the retention window are dropped on flush."""
        from app.services.safety.models import AuditEvent

        with self._seven_day_retention():
            audit = AuditLogger(enable_hash_chain=False)
            stale = AuditEvent(
                id="old-event",
                event_type=AuditEventType.ACTION_REQUESTED,
                timestamp=datetime.utcnow() - timedelta(days=10),
                details={},
            )
            # Inject straight into persisted storage, bypassing log(), so the
            # event can be backdated past the retention window.
            audit._persisted.append(stale)

            await audit.log(AuditEventType.ACTION_EXECUTED)
            # Retention is enforced as part of flush().
            await audit.flush()

            assert len(audit._persisted) == 1
            assert audit._persisted[0].id != "old-event"

    @pytest.mark.asyncio
    async def test_retention_keeps_recent_events(self):
        """Events inside the retention window survive a flush."""
        with self._seven_day_retention():
            audit = AuditLogger(enable_hash_chain=False)
            for kind in (
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_EXECUTED,
            ):
                await audit.log(kind)
            await audit.flush()
            assert len(audit._persisted) == 2
class TestAuditLoggerSingleton:
    """Tests for the module-level singleton helpers.

    Each test that creates the singleton shuts it down in a ``finally`` block,
    so a failing assertion cannot leak a started logger into other tests.
    """

    @pytest.mark.asyncio
    async def test_get_audit_logger_creates_instance(self):
        """Repeated get_audit_logger() calls return the same instance."""
        reset_audit_logger()
        logger1 = await get_audit_logger()
        try:
            logger2 = await get_audit_logger()
            assert logger1 is logger2
        finally:
            await shutdown_audit_logger()

    @pytest.mark.asyncio
    async def test_shutdown_audit_logger(self):
        """shutdown_audit_logger() stops and clears the singleton."""
        import app.services.safety.audit.logger as audit_module

        reset_audit_logger()
        created = await get_audit_logger()
        try:
            # Guard against a vacuous pass: the singleton must actually be set
            # before shutdown can be shown to clear it.
            assert audit_module._audit_logger is created
        finally:
            await shutdown_audit_logger()
        assert audit_module._audit_logger is None

    def test_reset_audit_logger(self):
        """reset_audit_logger() clears the singleton without starting anything."""
        import app.services.safety.audit.logger as audit_module

        audit_module._audit_logger = AuditLogger()
        reset_audit_logger()
        assert audit_module._audit_logger is None
class TestAuditLoggerPeriodicFlush:
    """Tests for the periodic flush background task.

    These tests poll for the expected state with a timeout instead of a single
    fixed sleep. They also always stop the logger in ``finally``, so the
    background task never outlives the test.
    """

    @staticmethod
    async def _wait_until(predicate, timeout=2.0, step=0.01):
        """Poll ``predicate`` until truthy; return False if ``timeout`` elapses."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(step)
        return True

    @pytest.mark.asyncio
    async def test_periodic_flush_runs(self):
        """The background task flushes buffered events without a manual flush."""
        logger = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False)
        await logger.start()
        try:
            await logger.log(AuditEventType.ACTION_REQUESTED)
            assert len(logger._buffer) == 1

            # A single 0.15s sleep against a 0.1s interval is flaky under load;
            # wait (bounded) for the buffer to drain instead.
            flushed = await self._wait_until(lambda: not logger._buffer)
            assert flushed, "periodic flush did not run within the timeout"
            assert len(logger._persisted) == 1
        finally:
            await logger.stop()

    @pytest.mark.asyncio
    async def test_periodic_flush_handles_errors(self):
        """A failing flush is absorbed and the loop keeps trying."""
        logger = AuditLogger(flush_interval_seconds=0.1)
        await logger.start()

        original_flush = logger.flush
        attempts = 0

        async def failing_flush():
            nonlocal attempts
            attempts += 1
            raise Exception("Flush error")

        # Shadow the bound method on the instance so the background loop's
        # flush call hits the failure.
        logger.flush = failing_flush
        try:
            # Require at least two attempts: the first proves the failing path
            # was exercised, the second proves the loop survived the error.
            retried = await self._wait_until(lambda: attempts >= 2)
            assert retried, f"expected repeated flush attempts, saw {attempts}"
            assert logger._running is True
            assert not logger._flush_task.done()
        finally:
            logger.flush = original_flush
            await logger.stop()
class TestAuditLoggerLogging:
    """Which standard-logger level each audit event type is emitted at."""

    @staticmethod
    async def _emit(event_type):
        """Log one event for agent-1 with the module logger mocked; return the mock."""
        with patch("app.services.safety.audit.logger.logger") as fake_logger:
            audit_logger = AuditLogger(enable_hash_chain=False)
            await audit_logger.log(event_type, agent_id="agent-1")
        return fake_logger

    @pytest.mark.asyncio
    async def test_log_warning_for_denied(self):
        """Denied actions are reported at warning level."""
        fake_logger = await self._emit(AuditEventType.ACTION_DENIED)
        fake_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_log_error_for_failed(self):
        """Failed actions are reported at error level."""
        fake_logger = await self._emit(AuditEventType.ACTION_FAILED)
        fake_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_log_info_for_normal(self):
        """Routine events are reported at info level."""
        fake_logger = await self._emit(AuditEventType.ACTION_EXECUTED)
        fake_logger.info.assert_called()
class TestAuditLoggerEdgeCases:
    """Boundary inputs and less common query filters."""

    @pytest.mark.asyncio
    async def test_log_with_none_details(self):
        """details=None is stored as an empty dict."""
        audit = AuditLogger(enable_hash_chain=False)
        entry = await audit.log(AuditEventType.ACTION_REQUESTED, details=None)
        assert entry.details == {}

    @pytest.mark.asyncio
    async def test_query_with_action_id(self):
        """action_id selects only the matching action's events."""
        audit = AuditLogger(enable_hash_chain=False)
        for event_type, action_ref in (
            (AuditEventType.ACTION_REQUESTED, "action-1"),
            (AuditEventType.ACTION_EXECUTED, "action-2"),
        ):
            await audit.log(event_type, action_id=action_ref)

        hits = await audit.query(action_id="action-1")
        assert len(hits) == 1
        assert hits[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_query_with_session_id(self):
        """session_id selects only the matching session's events."""
        audit = AuditLogger(enable_hash_chain=False)
        for event_type, session_ref in (
            (AuditEventType.ACTION_REQUESTED, "sess-1"),
            (AuditEventType.ACTION_EXECUTED, "sess-2"),
        ):
            await audit.log(event_type, session_id=session_ref)

        hits = await audit.query(session_id="sess-1")
        assert len(hits) == 1

    @pytest.mark.asyncio
    async def test_query_includes_buffer_and_persisted(self):
        """query() spans both flushed and still-buffered events."""
        audit = AuditLogger(enable_hash_chain=False)
        await audit.log(AuditEventType.ACTION_REQUESTED)  # will be persisted
        await audit.flush()
        await audit.log(AuditEventType.ACTION_EXECUTED)  # stays buffered

        combined = await audit.query()
        assert len(combined) == 2

    @pytest.mark.asyncio
    async def test_remove_nonexistent_handler(self):
        """Removing a handler that was never added is a silent no-op."""
        audit = AuditLogger()

        def never_registered(event):
            pass

        audit.remove_handler(never_registered)

    @pytest.mark.asyncio
    async def test_query_time_filter_excludes_events(self):
        """A start_time in the future excludes every existing event."""
        audit = AuditLogger(enable_hash_chain=False)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        one_hour_ahead = datetime.utcnow() + timedelta(hours=1)
        hits = await audit.query(start_time=one_hour_ahead)
        assert len(hits) == 0

    @pytest.mark.asyncio
    async def test_query_end_time_filter(self):
        """An end_time in the past excludes every existing event."""
        audit = AuditLogger(enable_hash_chain=False)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        one_hour_ago = datetime.utcnow() - timedelta(hours=1)
        hits = await audit.query(end_time=one_hour_ago)
        assert len(hits) == 0