forked from cardosofelipe/fast-next-template
Add tests to improve backend coverage from 85% to 93%: - test_audit.py: 60 tests for AuditLogger (20% -> 99%) - Hash chain integrity, sanitization, retention, handlers - Fixed bug: hash chain modification after event creation - Fixed bug: verification not using correct prev_hash - test_hitl.py: Tests for HITL manager (0% -> 100%) - test_permissions.py: Tests for permissions manager (0% -> 99%) - test_rollback.py: Tests for rollback manager (0% -> 100%) - test_metrics.py: Tests for metrics collector (0% -> 100%) - test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%) - test_validation.py: Additional cache and edge case tests (76% -> 100%) - test_scoring.py: Lock cleanup and edge case tests (78% -> 91%) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
990 lines
30 KiB
Python
990 lines
30 KiB
Python
"""
Tests for Audit Logger.

Tests cover:
- AuditLogger initialization and lifecycle
- Event logging and hash chain
- Query and filtering
- Retention policy enforcement
- Handler management
- Singleton pattern
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from app.services.safety.audit.logger import (
|
|
AuditLogger,
|
|
get_audit_logger,
|
|
reset_audit_logger,
|
|
shutdown_audit_logger,
|
|
)
|
|
from app.services.safety.models import (
|
|
ActionMetadata,
|
|
ActionRequest,
|
|
ActionType,
|
|
AuditEventType,
|
|
AutonomyLevel,
|
|
SafetyDecision,
|
|
)
|
|
|
|
|
|
class TestAuditLoggerInit:
    """Constructor behaviour of AuditLogger."""

    def test_init_default_values(self):
        """A logger built without arguments exposes the expected defaults."""
        audit = AuditLogger()

        # Freshly constructed: not running and no hash recorded yet.
        assert audit._running is False
        assert audit._last_hash is None
        assert audit._enable_hash_chain is True
        assert audit._flush_interval == 10.0

    def test_init_custom_values(self):
        """Keyword arguments passed to the constructor override the defaults."""
        overrides = {
            "max_buffer_size": 500,
            "flush_interval_seconds": 5.0,
            "enable_hash_chain": False,
        }
        audit = AuditLogger(**overrides)

        assert audit._enable_hash_chain is False
        assert audit._flush_interval == 5.0
|
|
|
|
|
|
class TestAuditLoggerLifecycle:
    """Tests for AuditLogger start/stop.

    Tests that start the logger stop it again in a ``finally`` clause so
    that a failing assertion cannot leave the flush task running after the
    test has ended.
    """

    @pytest.mark.asyncio
    async def test_start_creates_flush_task(self):
        """Test that start creates the periodic flush task."""
        logger = AuditLogger(flush_interval_seconds=1.0)

        await logger.start()
        try:
            assert logger._running is True
            assert logger._flush_task is not None
        finally:
            await logger.stop()

    @pytest.mark.asyncio
    async def test_start_idempotent(self):
        """Test that multiple starts don't create multiple tasks."""
        logger = AuditLogger()

        await logger.start()
        try:
            task1 = logger._flush_task

            await logger.start()  # Second start
            task2 = logger._flush_task

            # The very same task object must still be installed.
            assert task1 is task2
        finally:
            await logger.stop()

    @pytest.mark.asyncio
    async def test_stop_cancels_task_and_flushes(self):
        """Test that stop clears the running flag and flushes buffered events."""
        logger = AuditLogger()

        await logger.start()

        # Add an event
        await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")

        await logger.stop()

        assert logger._running is False
        # Event should be flushed
        assert len(logger._persisted) == 1

    @pytest.mark.asyncio
    async def test_stop_without_start(self):
        """Test stopping without starting doesn't error."""
        logger = AuditLogger()
        await logger.stop()  # Should not raise
|
|
|
|
|
|
class TestAuditLoggerLog:
    """Behaviour of ``AuditLogger.log``."""

    @pytest_asyncio.fixture
    async def logger(self):
        """Provide a fresh AuditLogger with hash chaining switched on."""
        return AuditLogger(enable_hash_chain=True)

    @pytest.mark.asyncio
    async def test_log_creates_event(self, logger):
        """The returned event carries the supplied fields plus an id and timestamp."""
        recorded = await logger.log(
            AuditEventType.ACTION_REQUESTED,
            agent_id="agent-1",
            project_id="proj-1",
        )

        assert recorded.id is not None
        assert recorded.timestamp is not None
        assert recorded.event_type == AuditEventType.ACTION_REQUESTED
        assert recorded.agent_id == "agent-1"
        assert recorded.project_id == "proj-1"

    @pytest.mark.asyncio
    async def test_log_adds_hash_chain(self, logger):
        """With chaining on, the event details gain the hash bookkeeping keys."""
        recorded = await logger.log(
            AuditEventType.ACTION_REQUESTED,
            agent_id="agent-1",
        )

        details = recorded.details
        assert "_hash" in details
        assert "_prev_hash" in details
        # Nothing precedes the first event in the chain.
        assert details["_prev_hash"] is None

    @pytest.mark.asyncio
    async def test_log_chain_links_events(self, logger):
        """The second event's previous-hash equals the first event's hash."""
        first = await logger.log(AuditEventType.ACTION_REQUESTED)
        second = await logger.log(AuditEventType.ACTION_EXECUTED)

        assert first.details["_hash"] == second.details["_prev_hash"]

    @pytest.mark.asyncio
    async def test_log_without_hash_chain(self):
        """With chaining off, no hash keys are written into the details."""
        unchained = AuditLogger(enable_hash_chain=False)

        recorded = await unchained.log(AuditEventType.ACTION_REQUESTED)

        for key in ("_hash", "_prev_hash"):
            assert key not in recorded.details

    @pytest.mark.asyncio
    async def test_log_with_all_fields(self, logger):
        """Every optional field passed to ``log`` is reflected on the event."""
        fields = {
            "agent_id": "agent-1",
            "action_id": "action-1",
            "project_id": "proj-1",
            "session_id": "sess-1",
            "user_id": "user-1",
            "decision": SafetyDecision.ALLOW,
            "correlation_id": "corr-1",
        }

        recorded = await logger.log(
            AuditEventType.ACTION_EXECUTED,
            details={"custom": "data"},
            **fields,
        )

        for name, expected in fields.items():
            assert getattr(recorded, name) == expected
        assert recorded.details["custom"] == "data"

    @pytest.mark.asyncio
    async def test_log_buffers_event(self, logger):
        """A logged event lands in the in-memory buffer."""
        await logger.log(AuditEventType.ACTION_REQUESTED)

        assert len(logger._buffer) == 1
|
|
|
|
|
|
class TestAuditLoggerConvenienceMethods:
    """Tests for convenience logging methods."""

    @pytest_asyncio.fixture
    async def logger(self):
        """Create a logger instance with the hash chain disabled."""
        return AuditLogger(enable_hash_chain=False)

    # Plain pytest fixture: this function is synchronous, so the asyncio
    # fixture decorator is not the right tool for it.
    @pytest.fixture
    def action(self):
        """Create a test action request (a file write to ``/test.txt``)."""
        metadata = ActionMetadata(
            agent_id="agent-1",
            session_id="sess-1",
            project_id="proj-1",
            autonomy_level=AutonomyLevel.MILESTONE,
            user_id="user-1",
            correlation_id="corr-1",
        )

        return ActionRequest(
            action_type=ActionType.FILE_WRITE,
            tool_name="file_write",
            arguments={"path": "/test.txt"},
            resource="/test.txt",
            metadata=metadata,
        )

    @pytest.mark.asyncio
    async def test_log_action_request_allowed(self, logger, action):
        """Test logging allowed action request."""
        event = await logger.log_action_request(
            action,
            SafetyDecision.ALLOW,
            reasons=["Within budget"],
        )

        assert event.event_type == AuditEventType.ACTION_VALIDATED
        assert event.decision == SafetyDecision.ALLOW
        assert event.details["reasons"] == ["Within budget"]

    @pytest.mark.asyncio
    async def test_log_action_request_denied(self, logger, action):
        """Test logging denied action request."""
        event = await logger.log_action_request(
            action,
            SafetyDecision.DENY,
            reasons=["Rate limit exceeded"],
        )

        assert event.event_type == AuditEventType.ACTION_DENIED
        assert event.decision == SafetyDecision.DENY

    @pytest.mark.asyncio
    async def test_log_action_executed_success(self, logger, action):
        """Test logging successful action execution."""
        event = await logger.log_action_executed(
            action,
            success=True,
            execution_time_ms=50.0,
        )

        assert event.event_type == AuditEventType.ACTION_EXECUTED
        assert event.details["success"] is True
        assert event.details["execution_time_ms"] == 50.0
        assert event.details["error"] is None

    @pytest.mark.asyncio
    async def test_log_action_executed_failure(self, logger, action):
        """Test logging failed action execution."""
        event = await logger.log_action_executed(
            action,
            success=False,
            execution_time_ms=100.0,
            error="File not found",
        )

        assert event.event_type == AuditEventType.ACTION_FAILED
        assert event.details["success"] is False
        assert event.details["error"] == "File not found"

    @pytest.mark.asyncio
    async def test_log_approval_event(self, logger, action):
        """Test logging approval event."""
        event = await logger.log_approval_event(
            AuditEventType.APPROVAL_GRANTED,
            approval_id="approval-1",
            action=action,
            decided_by="admin",
            reason="Approved by admin",
        )

        assert event.event_type == AuditEventType.APPROVAL_GRANTED
        assert event.details["approval_id"] == "approval-1"
        assert event.details["decided_by"] == "admin"

    @pytest.mark.asyncio
    async def test_log_budget_event(self, logger):
        """Test logging budget event."""
        event = await logger.log_budget_event(
            AuditEventType.BUDGET_WARNING,
            agent_id="agent-1",
            scope="daily",
            current_usage=8000.0,
            limit=10000.0,
            unit="tokens",
        )

        assert event.event_type == AuditEventType.BUDGET_WARNING
        assert event.details["scope"] == "daily"
        # 8000 of 10000 is expected to be reported as 80 percent.
        assert event.details["usage_percent"] == 80.0

    @pytest.mark.asyncio
    async def test_log_budget_event_zero_limit(self, logger):
        """Test logging budget event with zero limit."""
        event = await logger.log_budget_event(
            AuditEventType.BUDGET_WARNING,
            agent_id="agent-1",
            scope="daily",
            current_usage=100.0,
            limit=0.0,  # Zero limit
        )

        # A zero limit is expected to yield 0 percent.
        assert event.details["usage_percent"] == 0

    @pytest.mark.asyncio
    async def test_log_emergency_stop(self, logger):
        """Test logging emergency stop."""
        event = await logger.log_emergency_stop(
            stop_type="global",
            triggered_by="admin",
            reason="Security incident",
            affected_agents=["agent-1", "agent-2"],
        )

        assert event.event_type == AuditEventType.EMERGENCY_STOP
        assert event.details["stop_type"] == "global"
        assert event.details["affected_agents"] == ["agent-1", "agent-2"]
|
|
|
|
|
|
class TestAuditLoggerFlush:
    """Behaviour of ``AuditLogger.flush``."""

    @pytest.mark.asyncio
    async def test_flush_persists_events(self):
        """Flushing empties the buffer into the persisted list and returns the count."""
        audit = AuditLogger(enable_hash_chain=False)

        for kind in (AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED):
            await audit.log(kind)

        # Before the flush everything is still buffered.
        assert len(audit._buffer) == 2
        assert len(audit._persisted) == 0

        flushed = await audit.flush()

        assert flushed == 2
        assert len(audit._buffer) == 0
        assert len(audit._persisted) == 2

    @pytest.mark.asyncio
    async def test_flush_empty_buffer(self):
        """Flushing a logger that has logged nothing reports zero events."""
        audit = AuditLogger()

        flushed = await audit.flush()

        assert flushed == 0
|
|
|
|
|
|
class TestAuditLoggerQuery:
    """Behaviour of ``AuditLogger.query`` and ``get_action_history``."""

    @pytest_asyncio.fixture
    async def logger_with_events(self):
        """Provide a logger pre-loaded with four events across two agents."""
        populated = AuditLogger(enable_hash_chain=False)

        # (event type, keyword fields) in the order they are logged.
        seed = (
            (
                AuditEventType.ACTION_REQUESTED,
                {"agent_id": "agent-1", "project_id": "proj-1"},
            ),
            (
                AuditEventType.ACTION_EXECUTED,
                {"agent_id": "agent-1", "project_id": "proj-1"},
            ),
            (
                AuditEventType.ACTION_DENIED,
                {"agent_id": "agent-2", "project_id": "proj-2"},
            ),
            (
                AuditEventType.BUDGET_WARNING,
                {"agent_id": "agent-1", "project_id": "proj-1", "user_id": "user-1"},
            ),
        )
        for kind, fields in seed:
            await populated.log(kind, **fields)

        return populated

    @pytest.mark.asyncio
    async def test_query_all(self, logger_with_events):
        """An unfiltered query returns every seeded event."""
        found = await logger_with_events.query()

        assert len(found) == 4

    @pytest.mark.asyncio
    async def test_query_by_event_type(self, logger_with_events):
        """Filtering on a single event type returns only that type."""
        wanted = [AuditEventType.ACTION_REQUESTED]
        found = await logger_with_events.query(event_types=wanted)

        assert len(found) == 1
        assert found[0].event_type == AuditEventType.ACTION_REQUESTED

    @pytest.mark.asyncio
    async def test_query_by_agent_id(self, logger_with_events):
        """Filtering on agent ID returns that agent's three events."""
        found = await logger_with_events.query(agent_id="agent-1")

        assert len(found) == 3
        assert not any(item.agent_id != "agent-1" for item in found)

    @pytest.mark.asyncio
    async def test_query_by_project_id(self, logger_with_events):
        """Filtering on project ID returns the single event of that project."""
        found = await logger_with_events.query(project_id="proj-2")

        assert len(found) == 1

    @pytest.mark.asyncio
    async def test_query_by_user_id(self, logger_with_events):
        """Filtering on user ID returns the one event that carries a user."""
        found = await logger_with_events.query(user_id="user-1")

        assert len(found) == 1
        assert found[0].event_type == AuditEventType.BUDGET_WARNING

    @pytest.mark.asyncio
    async def test_query_with_limit(self, logger_with_events):
        """``limit`` caps the number of returned events."""
        found = await logger_with_events.query(limit=2)

        assert len(found) == 2

    @pytest.mark.asyncio
    async def test_query_with_offset(self, logger_with_events):
        """``offset`` skips the leading events of the unfiltered result."""
        everything = await logger_with_events.query()
        skipped = await logger_with_events.query(offset=2)

        assert len(skipped) == 2
        assert skipped[0] == everything[2]

    @pytest.mark.asyncio
    async def test_query_by_time_range(self):
        """An event logged now is found by a window around the current time."""
        audit = AuditLogger(enable_hash_chain=False)

        reference = datetime.utcnow()
        await audit.log(AuditEventType.ACTION_REQUESTED)

        # One second of slack on either side of the reference instant.
        slack = timedelta(seconds=1)
        found = await audit.query(
            start_time=reference - slack,
            end_time=reference + slack,
        )

        assert len(found) == 1

    @pytest.mark.asyncio
    async def test_query_by_correlation_id(self):
        """Filtering on correlation ID returns only the matching event."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED, correlation_id="corr-123")
        await audit.log(AuditEventType.ACTION_EXECUTED, correlation_id="corr-456")

        found = await audit.query(correlation_id="corr-123")

        assert len(found) == 1
        assert found[0].correlation_id == "corr-123"

    @pytest.mark.asyncio
    async def test_query_combined_filters(self, logger_with_events):
        """Agent, project and event-type filters can be combined in one query."""
        wanted_types = [
            AuditEventType.ACTION_REQUESTED,
            AuditEventType.ACTION_EXECUTED,
        ]
        found = await logger_with_events.query(
            agent_id="agent-1",
            project_id="proj-1",
            event_types=wanted_types,
        )

        assert len(found) == 2

    @pytest.mark.asyncio
    async def test_get_action_history(self, logger_with_events):
        """``get_action_history`` yields only the agent's action events."""
        history = await logger_with_events.get_action_history("agent-1")

        # The budget warning for the same agent must not be part of the history.
        assert len(history) == 2
        action_types = {
            AuditEventType.ACTION_REQUESTED,
            AuditEventType.ACTION_EXECUTED,
        }
        assert {entry.event_type for entry in history} <= action_types
|
|
|
|
|
|
class TestAuditLoggerIntegrity:
    """Behaviour of ``AuditLogger.verify_integrity``."""

    @pytest.mark.asyncio
    async def test_verify_integrity_valid(self):
        """An untouched two-event chain verifies with no reported issues."""
        audit = AuditLogger(enable_hash_chain=True)

        for kind in (AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED):
            await audit.log(kind)

        ok, problems = await audit.verify_integrity()

        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_disabled(self):
        """With the hash chain off, verification reports success and no issues."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED)

        ok, problems = await audit.verify_integrity()

        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_broken_chain(self):
        """Overwriting a stored hash makes verification fail with issues listed."""
        audit = AuditLogger(enable_hash_chain=True)

        first = await audit.log(AuditEventType.ACTION_REQUESTED)
        await audit.log(AuditEventType.ACTION_EXECUTED)

        # Corrupt the hash recorded on the first event after the fact.
        first.details["_hash"] = "tampered_hash"

        ok, problems = await audit.verify_integrity()

        assert ok is False
        assert len(problems) > 0
|
|
|
|
|
|
class TestAuditLoggerHandlers:
    """Registration, removal and failure handling of event handlers."""

    @pytest.mark.asyncio
    async def test_add_sync_handler(self):
        """A plain function handler is invoked once per logged event."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        def collect(event):
            seen.append(event)

        audit.add_handler(collect)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_add_async_handler(self):
        """A coroutine function handler is invoked once per logged event."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        async def collect(event):
            seen.append(event)

        audit.add_handler(collect)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_remove_handler(self):
        """After removal a handler no longer receives events."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        def collect(event):
            seen.append(event)

        audit.add_handler(collect)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        audit.remove_handler(collect)
        await audit.log(AuditEventType.ACTION_EXECUTED)

        # Only the event logged while the handler was registered was seen.
        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_handler_error_caught(self):
        """An exception raised by a handler does not escape from ``log``."""
        audit = AuditLogger(enable_hash_chain=False)

        def explode(event):
            raise ValueError("Handler error")

        audit.add_handler(explode)

        # The call must complete and still hand back the event.
        recorded = await audit.log(AuditEventType.ACTION_REQUESTED)
        assert recorded is not None
|
|
|
|
|
|
class TestAuditLoggerSanitization:
    """Redaction of sensitive values in event details."""

    @pytest.mark.asyncio
    async def test_sanitize_sensitive_keys(self):
        """Top-level sensitive keys are redacted while other keys pass through."""
        settings = MagicMock(audit_retention_days=30, audit_include_sensitive=False)
        with patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        ):
            audit = AuditLogger(enable_hash_chain=False)

            payload = {
                "password": "secret123",
                "api_key": "key123",
                "token": "token123",
                "normal_field": "visible",
            }
            recorded = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details=payload,
            )

            for sensitive in ("password", "api_key", "token"):
                assert recorded.details[sensitive] == "[REDACTED]"
            assert recorded.details["normal_field"] == "visible"

    @pytest.mark.asyncio
    async def test_sanitize_nested_dict(self):
        """Sensitive keys inside a nested dictionary are redacted as well."""
        settings = MagicMock(audit_retention_days=30, audit_include_sensitive=False)
        with patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        ):
            audit = AuditLogger(enable_hash_chain=False)

            nested = {"api_secret": "secret", "name": "test"}
            recorded = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={"config": nested},
            )

            logged_config = recorded.details["config"]
            assert logged_config["api_secret"] == "[REDACTED]"
            assert logged_config["name"] == "test"

    @pytest.mark.asyncio
    async def test_include_sensitive_when_enabled(self):
        """With ``audit_include_sensitive`` on, sensitive values are kept as-is."""
        settings = MagicMock(audit_retention_days=30, audit_include_sensitive=True)
        with patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        ):
            audit = AuditLogger(enable_hash_chain=False)

            recorded = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={"password": "secret123"},
            )

            assert recorded.details["password"] == "secret123"
|
|
|
|
|
|
class TestAuditLoggerRetention:
    """Retention policy applied to persisted events."""

    @pytest.mark.asyncio
    async def test_retention_removes_old_events(self):
        """An event older than the retention window is gone after a flush."""
        settings = MagicMock(audit_retention_days=7, audit_include_sensitive=False)
        with patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        ):
            audit = AuditLogger(enable_hash_chain=False)

            from app.services.safety.models import AuditEvent

            # Plant a ten-day-old event straight into the persisted list.
            ten_days_ago = datetime.utcnow() - timedelta(days=10)
            stale = AuditEvent(
                id="old-event",
                event_type=AuditEventType.ACTION_REQUESTED,
                timestamp=ten_days_ago,
                details={},
            )
            audit._persisted.append(stale)

            # A current event, logged the normal way.
            await audit.log(AuditEventType.ACTION_EXECUTED)

            # The flush is what triggers retention enforcement.
            await audit.flush()

            # Only the current event is left.
            assert len(audit._persisted) == 1
            assert audit._persisted[0].id != "old-event"

    @pytest.mark.asyncio
    async def test_retention_keeps_recent_events(self):
        """Events inside the retention window survive a flush."""
        settings = MagicMock(audit_retention_days=7, audit_include_sensitive=False)
        with patch(
            "app.services.safety.audit.logger.get_safety_config",
            return_value=settings,
        ):
            audit = AuditLogger(enable_hash_chain=False)

            for kind in (
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_EXECUTED,
            ):
                await audit.log(kind)

            await audit.flush()

            assert len(audit._persisted) == 2
|
|
|
|
|
|
class TestAuditLoggerSingleton:
    """Tests for singleton pattern.

    The module-level singleton is shared state, so the test that creates it
    shuts it down in a ``finally`` clause; otherwise a failing assertion
    would leave the instance in place for later tests.
    """

    @pytest.mark.asyncio
    async def test_get_audit_logger_creates_instance(self):
        """Test get_audit_logger returns the same instance on repeated calls."""
        reset_audit_logger()

        try:
            logger1 = await get_audit_logger()
            logger2 = await get_audit_logger()

            assert logger1 is logger2
        finally:
            await shutdown_audit_logger()

    @pytest.mark.asyncio
    async def test_shutdown_audit_logger(self):
        """Test shutdown_audit_logger clears the module-level singleton."""
        import app.services.safety.audit.logger as audit_module

        reset_audit_logger()

        # Only the side effect of creating the singleton matters here.
        await get_audit_logger()
        await shutdown_audit_logger()

        assert audit_module._audit_logger is None

    def test_reset_audit_logger(self):
        """Test reset_audit_logger clears singleton."""
        import app.services.safety.audit.logger as audit_module

        # Install an instance by hand so there is something to clear.
        audit_module._audit_logger = AuditLogger()
        reset_audit_logger()

        assert audit_module._audit_logger is None
|
|
|
|
|
|
class TestAuditLoggerPeriodicFlush:
    """Tests for periodic flush background task.

    Both tests start the logger, so cleanup (stopping it and undoing the
    ``flush`` monkeypatch) lives in ``finally`` clauses and runs even when an
    assertion fails.
    """

    @pytest.mark.asyncio
    async def test_periodic_flush_runs(self):
        """Test periodic flush runs and flushes events."""
        logger = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False)

        await logger.start()
        try:
            # Log an event
            await logger.log(AuditEventType.ACTION_REQUESTED)
            assert len(logger._buffer) == 1

            # Wait a little longer than one flush interval
            await asyncio.sleep(0.15)

            # Event should be flushed
            assert len(logger._buffer) == 0
            assert len(logger._persisted) == 1
        finally:
            await logger.stop()

    @pytest.mark.asyncio
    async def test_periodic_flush_handles_errors(self):
        """Test the logger is still marked running after a flush attempt raises."""
        logger = AuditLogger(flush_interval_seconds=0.1)

        await logger.start()

        # Replace flush with a coroutine that always raises
        original_flush = logger.flush

        async def failing_flush():
            raise Exception("Flush error")

        logger.flush = failing_flush
        try:
            # Wait a little longer than one flush interval
            await asyncio.sleep(0.15)

            # Should still be running
            assert logger._running is True
        finally:
            # Restore the real flush before stopping the logger.
            logger.flush = original_flush
            await logger.stop()
|
|
|
|
|
|
class TestAuditLoggerLogging:
    """Log level used on the module's standard logger per event type."""

    @pytest.mark.asyncio
    async def test_log_warning_for_denied(self):
        """Logging a denied action calls ``warning`` on the standard logger."""
        with patch("app.services.safety.audit.logger.logger") as std_logger:
            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(AuditEventType.ACTION_DENIED, agent_id="agent-1")

            std_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_log_error_for_failed(self):
        """Logging a failed action calls ``error`` on the standard logger."""
        with patch("app.services.safety.audit.logger.logger") as std_logger:
            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(AuditEventType.ACTION_FAILED, agent_id="agent-1")

            std_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_log_info_for_normal(self):
        """Logging an executed action calls ``info`` on the standard logger."""
        with patch("app.services.safety.audit.logger.logger") as std_logger:
            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(AuditEventType.ACTION_EXECUTED, agent_id="agent-1")

            std_logger.info.assert_called()
|
|
|
|
|
|
class TestAuditLoggerEdgeCases:
    """Assorted boundary conditions for logging and querying."""

    @pytest.mark.asyncio
    async def test_log_with_none_details(self):
        """Passing ``details=None`` yields an event with an empty details dict."""
        audit = AuditLogger(enable_hash_chain=False)

        recorded = await audit.log(AuditEventType.ACTION_REQUESTED, details=None)

        assert recorded.details == {}

    @pytest.mark.asyncio
    async def test_query_with_action_id(self):
        """Filtering on action ID returns only the matching event."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED, action_id="action-1")
        await audit.log(AuditEventType.ACTION_EXECUTED, action_id="action-2")

        found = await audit.query(action_id="action-1")

        assert len(found) == 1
        assert found[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_query_with_session_id(self):
        """Filtering on session ID returns a single event."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED, session_id="sess-1")
        await audit.log(AuditEventType.ACTION_EXECUTED, session_id="sess-2")

        found = await audit.query(session_id="sess-1")

        assert len(found) == 1

    @pytest.mark.asyncio
    async def test_query_includes_buffer_and_persisted(self):
        """A query sees flushed events and still-buffered events together."""
        audit = AuditLogger(enable_hash_chain=False)

        # First event: logged, then moved to persisted storage by the flush.
        await audit.log(AuditEventType.ACTION_REQUESTED)
        await audit.flush()

        # Second event: stays in the buffer.
        await audit.log(AuditEventType.ACTION_EXECUTED)

        found = await audit.query()

        assert len(found) == 2

    @pytest.mark.asyncio
    async def test_remove_nonexistent_handler(self):
        """Removing a handler that was never added does not raise."""
        audit = AuditLogger()

        def never_registered(event):
            pass

        # The call itself is the check: it must complete without an exception.
        audit.remove_handler(never_registered)

    @pytest.mark.asyncio
    async def test_query_time_filter_excludes_events(self):
        """A start time in the future filters out an event logged now."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED)

        one_hour_ahead = datetime.utcnow() + timedelta(hours=1)
        found = await audit.query(start_time=one_hour_ahead)

        assert len(found) == 0

    @pytest.mark.asyncio
    async def test_query_end_time_filter(self):
        """An end time in the past filters out an event logged now."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED)

        one_hour_ago = datetime.utcnow() - timedelta(hours=1)
        found = await audit.query(end_time=one_hour_ago)

        assert len(found) == 0
|