Files
syndarix/backend/app/services/safety/audit/logger.py
Felipe Cardoso 60ebeaa582 test(safety): add comprehensive tests for safety framework modules
Add tests to improve backend coverage from 85% to 93%:

- test_audit.py: 60 tests for AuditLogger (20% -> 99%)
  - Hash chain integrity, sanitization, retention, handlers
  - Fixed bug: hash chain modification after event creation
  - Fixed bug: verification not using correct prev_hash

- test_hitl.py: Tests for HITL manager (0% -> 100%)
- test_permissions.py: Tests for permissions manager (0% -> 99%)
- test_rollback.py: Tests for rollback manager (0% -> 100%)
- test_metrics.py: Tests for metrics collector (0% -> 100%)
- test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%)
- test_validation.py: Additional cache and edge case tests (76% -> 100%)
- test_scoring.py: Lock cleanup and edge case tests (78% -> 91%)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 19:41:54 +01:00

602 lines
19 KiB
Python

"""
Audit Logger
Comprehensive audit logging for all safety-related events.
Provides tamper detection, structured logging, and compliance support.
"""
import asyncio
import hashlib
import inspect
import json
import logging
from collections import deque
from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4

from ..config import get_safety_config
from ..models import (
    ActionRequest,
    AuditEvent,
    AuditEventType,
    SafetyDecision,
)

# Module-level logger; every audit event is also mirrored to it.
logger = logging.getLogger(name=__name__)

# Private sentinel used by ``AuditLogger._compute_hash`` to tell
# "prev_hash was not passed" apart from an explicit ``prev_hash=None``.
_UNSET = object()

class AuditLogger:
    """
    Audit logger for safety events.

    Features:
    - Structured event logging
    - In-memory buffer, flushed periodically and automatically when full
    - Tamper detection via hash chains
    - Query/search capability
    - Retention policy enforcement
    """

    # Case-insensitive substrings that mark a ``details`` key as sensitive.
    # Matching is by substring, so e.g. "db_password" or "Authorization"
    # are redacted as well.
    _SENSITIVE_KEYS = frozenset(
        {"password", "secret", "token", "api_key", "apikey", "auth", "credential"}
    )

    # Keys the hash chain itself writes into ``event.details``. They are set
    # after the hash is computed, so they are excluded from the hashed data.
    _CHAIN_KEYS = frozenset({"_hash", "_prev_hash"})

    def __init__(
        self,
        max_buffer_size: int = 1000,
        flush_interval_seconds: float = 10.0,
        enable_hash_chain: bool = True,
    ) -> None:
        """
        Initialize the audit logger.

        Args:
            max_buffer_size: Maximum events to buffer before auto-flush. When
                the buffer already holds this many events, it is flushed to
                persistent storage before the next event is appended. Values
                <= 0 flush before every event.
            flush_interval_seconds: Interval for periodic flush
            enable_hash_chain: Enable tamper detection via hash chain
        """
        # Deliberately unbounded: a ``deque(maxlen=...)`` would silently drop
        # the oldest *unflushed* audit events once full. The auto-flush in
        # ``log()`` keeps the buffer at or below ``max_buffer_size`` instead.
        self._buffer: deque[AuditEvent] = deque()
        self._max_buffer_size = max_buffer_size
        self._persisted: list[AuditEvent] = []
        self._flush_interval = flush_interval_seconds
        self._enable_hash_chain = enable_hash_chain
        self._last_hash: str | None = None
        # ``_hash`` of the last event removed by retention. It is the expected
        # ``_prev_hash`` of the first surviving event, so verification starts
        # from it rather than from None.
        self._chain_anchor: str | None = None
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task[None] | None = None
        self._running = False

        # Event handlers for real-time processing (sync or async callables)
        self._handlers: list[Any] = []

        config = get_safety_config()
        self._retention_days = config.audit_retention_days
        self._include_sensitive = config.audit_include_sensitive

    async def start(self) -> None:
        """Start the periodic-flush background task (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._periodic_flush())
        logger.info("Audit logger started")

    async def stop(self) -> None:
        """Cancel the periodic-flush task and flush any remaining events."""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Final flush
        await self.flush()
        logger.info("Audit logger stopped")

    async def log(
        self,
        event_type: AuditEventType,
        *,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        decision: SafetyDecision | None = None,
        details: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> AuditEvent:
        """
        Log an audit event.

        The event is hash-chained (if enabled) and appended to the buffer while
        holding the lock. Handlers are then notified and the event is mirrored
        to the module logger, both outside the lock.

        Args:
            event_type: Type of audit event
            agent_id: Agent ID if applicable
            action_id: Action ID if applicable
            project_id: Project ID if applicable
            session_id: Session ID if applicable
            user_id: User ID if applicable
            decision: Safety decision if applicable
            details: Additional event details. The dict is copied (and, unless
                sensitive data is allowed, redacted); the caller's dict is
                never mutated.
            correlation_id: Correlation ID for tracing

        Returns:
            The created audit event
        """
        sanitized_details = self._sanitize_details(details) if details else {}

        event = AuditEvent(
            id=str(uuid4()),
            event_type=event_type,
            # NOTE: naive UTC, consistent with _enforce_retention() and with
            # the start/end times callers pass to query().
            timestamp=datetime.utcnow(),
            agent_id=agent_id,
            action_id=action_id,
            project_id=project_id,
            session_id=session_id,
            user_id=user_id,
            decision=decision,
            details=sanitized_details,
            correlation_id=correlation_id,
        )

        async with self._lock:
            # Add hash chain for tamper detection
            if self._enable_hash_chain:
                event_hash = self._compute_hash(event)
                # Modify event.details directly (not sanitized_details)
                # to ensure the hash is stored on the actual event
                event.details["_hash"] = event_hash
                event.details["_prev_hash"] = self._last_hash
                self._last_hash = event_hash

            # Auto-flush a full buffer before appending, so no event is ever
            # discarded. The lock is already held, so use the lock-free helper
            # (asyncio.Lock is not reentrant).
            if len(self._buffer) >= self._max_buffer_size:
                self._flush_locked()
            self._buffer.append(event)

        # Notify handlers
        await self._notify_handlers(event)

        # Log to standard logger as well
        self._log_to_logger(event)

        return event

    async def log_action_request(
        self,
        action: ActionRequest,
        decision: SafetyDecision,
        reasons: list[str] | None = None,
    ) -> AuditEvent:
        """Log an action request with its validation decision (denied or validated)."""
        event_type = (
            AuditEventType.ACTION_DENIED
            if decision == SafetyDecision.DENY
            else AuditEventType.ACTION_VALIDATED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=action.metadata.user_id,
            decision=decision,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
                "is_destructive": action.is_destructive,
                "reasons": reasons or [],
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_action_executed(
        self,
        action: ActionRequest,
        success: bool,
        execution_time_ms: float,
        error: str | None = None,
    ) -> AuditEvent:
        """Log an action execution result (executed on success, failed otherwise)."""
        event_type = (
            AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            # Same attribution as log_action_request, so query(user_id=...)
            # finds both the request and its execution result.
            user_id=action.metadata.user_id,
            decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "success": success,
                "execution_time_ms": execution_time_ms,
                "error": error,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_approval_event(
        self,
        event_type: AuditEventType,
        approval_id: str,
        action: ActionRequest,
        decided_by: str | None = None,
        reason: str | None = None,
    ) -> AuditEvent:
        """Log an approval-related event, attributed to the deciding user."""
        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=decided_by,
            details={
                "approval_id": approval_id,
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "decided_by": decided_by,
                "reason": reason,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_budget_event(
        self,
        event_type: AuditEventType,
        agent_id: str,
        scope: str,
        current_usage: float,
        limit: float,
        unit: str = "tokens",
    ) -> AuditEvent:
        """Log a budget-related event; ``usage_percent`` is 0 for a non-positive limit."""
        return await self.log(
            event_type,
            agent_id=agent_id,
            details={
                "scope": scope,
                "current_usage": current_usage,
                "limit": limit,
                "unit": unit,
                "usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
            },
        )

    async def log_emergency_stop(
        self,
        stop_type: str,
        triggered_by: str,
        reason: str,
        affected_agents: list[str] | None = None,
    ) -> AuditEvent:
        """Log an emergency stop event, attributed to the user who triggered it."""
        return await self.log(
            AuditEventType.EMERGENCY_STOP,
            user_id=triggered_by,
            details={
                "stop_type": stop_type,
                "triggered_by": triggered_by,
                "reason": reason,
                "affected_agents": affected_agents or [],
            },
        )

    async def flush(self) -> int:
        """
        Flush buffered events to persistent storage.

        Returns:
            Number of events flushed
        """
        async with self._lock:
            return self._flush_locked()

    def _flush_locked(self) -> int:
        """Move all buffered events to storage; caller must hold ``self._lock``."""
        if not self._buffer:
            return 0

        events = list(self._buffer)
        self._buffer.clear()

        # Persist events (in production, this would go to database/storage)
        self._persisted.extend(events)

        # Enforce retention
        self._enforce_retention()

        logger.debug("Flushed %d audit events", len(events))
        return len(events)

    async def query(
        self,
        *,
        event_types: list[AuditEventType] | None = None,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        correlation_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEvent]:
        """
        Query audit events (persisted and still-buffered) with filters.

        Falsy filter values (None, empty string/list) are ignored.

        Args:
            event_types: Filter by event types
            agent_id: Filter by agent ID
            action_id: Filter by action ID
            project_id: Filter by project ID
            session_id: Filter by session ID
            user_id: Filter by user ID
            start_time: Keep events at or after this time (inclusive)
            end_time: Keep events at or before this time (inclusive)
            correlation_id: Filter by correlation ID
            limit: Maximum results to return
            offset: Result offset for pagination

        Returns:
            Matching audit events, newest first
        """
        # Combine buffer and persisted for query
        all_events = list(self._persisted) + list(self._buffer)

        results = []
        for event in all_events:
            if event_types and event.event_type not in event_types:
                continue
            if agent_id and event.agent_id != agent_id:
                continue
            if action_id and event.action_id != action_id:
                continue
            if project_id and event.project_id != project_id:
                continue
            if session_id and event.session_id != session_id:
                continue
            if user_id and event.user_id != user_id:
                continue
            if start_time and event.timestamp < start_time:
                continue
            if end_time and event.timestamp > end_time:
                continue
            if correlation_id and event.correlation_id != correlation_id:
                continue

            results.append(event)

        # Sort by timestamp descending
        results.sort(key=lambda e: e.timestamp, reverse=True)

        # Apply pagination
        return results[offset : offset + limit]

    async def get_action_history(
        self,
        agent_id: str,
        limit: int = 100,
    ) -> list[AuditEvent]:
        """Get the newest action-lifecycle events for an agent."""
        return await self.query(
            agent_id=agent_id,
            event_types=[
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_VALIDATED,
                AuditEventType.ACTION_DENIED,
                AuditEventType.ACTION_EXECUTED,
                AuditEventType.ACTION_FAILED,
            ],
            limit=limit,
        )

    async def verify_integrity(self) -> tuple[bool, list[str]]:
        """
        Verify audit log integrity using the hash chain.

        Events are checked in the order they were chained: persisted events
        first, then the buffer. Timestamps are *not* used for ordering,
        because ``log()`` stamps an event before it acquires the lock, so
        concurrent calls can be chained in a different order than their
        timestamps. The walk starts from the hash of the last event removed
        by retention, if any.

        Returns:
            Tuple of (is_valid, list of issues found)
        """
        if not self._enable_hash_chain:
            return True, []

        issues: list[str] = []
        prev_hash: str | None = self._chain_anchor

        for event in [*self._persisted, *self._buffer]:
            stored_prev = event.details.get("_prev_hash")
            stored_hash = event.details.get("_hash")

            if stored_prev != prev_hash:
                issues.append(
                    f"Hash chain broken at event {event.id}: "
                    f"expected prev_hash={prev_hash}, got {stored_prev}"
                )

            if stored_hash:
                # Pass prev_hash to compute hash with correct chain position
                computed = self._compute_hash(event, prev_hash=prev_hash)
                if computed != stored_hash:
                    issues.append(
                        f"Hash mismatch at event {event.id}: "
                        f"expected {computed}, got {stored_hash}"
                    )

            prev_hash = stored_hash

        return len(issues) == 0, issues

    def add_handler(self, handler: Any) -> None:
        """Add a real-time event handler (sync callable or one returning an awaitable)."""
        self._handlers.append(handler)

    def remove_handler(self, handler: Any) -> None:
        """Remove an event handler; unknown handlers are ignored."""
        if handler in self._handlers:
            self._handlers.remove(handler)

    def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
        """
        Return a sanitized copy of ``details``.

        When sensitive data is allowed, this returns a shallow copy, so the
        hash fields later written onto the event cannot leak into the caller's
        dict. Otherwise, values whose key contains a sensitive fragment are
        replaced with "[REDACTED]", recursing into nested dicts and into
        dicts inside lists/tuples.
        """
        if self._include_sensitive:
            return dict(details)

        sanitized: dict[str, Any] = {}
        for key, value in details.items():
            # str() so non-string keys don't raise AttributeError on .lower()
            lower_key = str(key).lower()
            if any(fragment in lower_key for fragment in self._SENSITIVE_KEYS):
                sanitized[key] = "[REDACTED]"
            else:
                sanitized[key] = self._sanitize_value(value)

        return sanitized

    def _sanitize_value(self, value: Any) -> Any:
        """Recursively sanitize dicts nested anywhere inside a details value."""
        if isinstance(value, dict):
            return self._sanitize_details(value)
        if isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        if isinstance(value, tuple):
            return tuple(self._sanitize_value(item) for item in value)
        return value

    def _compute_hash(
        self, event: AuditEvent, prev_hash: str | None | object = _UNSET
    ) -> str:
        """Compute the SHA-256 chain hash for an event (excluding chain fields).

        Args:
            event: The audit event to hash.
            prev_hash: Optional previous hash to use instead of self._last_hash.
                Pass this during verification to use the correct chain.
                Use None explicitly to indicate no previous hash.

        Returns:
            Hex digest over the event's fields, its non-chain details, and the
            previous hash (when there is one).
        """
        # Use passed prev_hash if explicitly provided, otherwise use instance state
        effective_prev: str | None = (
            self._last_hash if prev_hash is _UNSET else prev_hash  # type: ignore[assignment]
        )

        data: dict[str, Any] = {
            "id": event.id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "project_id": event.project_id,
            "session_id": event.session_id,
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            # Only the chain's own fields are excluded; every caller-supplied
            # key (including ones starting with "_") is tamper-protected.
            "details": {
                k: v for k, v in event.details.items() if k not in self._CHAIN_KEYS
            },
            "correlation_id": event.correlation_id,
        }

        if effective_prev:
            data["_prev_hash"] = effective_prev

        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def _log_to_logger(self, event: AuditEvent) -> None:
        """Mirror an event to the module logger at a level matching its severity."""
        log_data = {
            "audit_event": event.event_type.value,
            "event_id": event.id,
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "decision": event.decision.value if event.decision else None,
        }

        # Use appropriate log level based on event type
        if event.event_type in {
            AuditEventType.ACTION_DENIED,
            AuditEventType.POLICY_VIOLATION,
            AuditEventType.EMERGENCY_STOP,
        }:
            logger.warning("Audit: %s", log_data)
        elif event.event_type in {
            AuditEventType.ACTION_FAILED,
            AuditEventType.ROLLBACK_FAILED,
        }:
            logger.error("Audit: %s", log_data)
        else:
            logger.info("Audit: %s", log_data)

    def _enforce_retention(self) -> None:
        """
        Drop persisted events older than the retention period.

        A falsy ``audit_retention_days`` disables retention. This is called
        from the flush path, where ``self._lock`` is held.
        """
        if not self._retention_days:
            return

        cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
        kept: list[AuditEvent] = []
        removed = 0
        for event in self._persisted:
            if event.timestamp >= cutoff:
                kept.append(event)
            else:
                removed += 1
                # Remember the last dropped link so verify_integrity() can
                # still validate the first surviving event's _prev_hash.
                self._chain_anchor = event.details.get("_hash")
        self._persisted = kept

        if removed > 0:
            logger.info("Removed %d expired audit events", removed)

    async def _periodic_flush(self) -> None:
        """Background task: flush every ``flush_interval_seconds`` while running."""
        while self._running:
            try:
                await asyncio.sleep(self._flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the background loop alive; log with traceback.
                logger.exception("Error in periodic audit flush: %s", e)

    async def _notify_handlers(self, event: AuditEvent) -> None:
        """
        Notify all registered handlers of a new event.

        Each handler is called with the event. If it returns an awaitable
        (async functions, ``async def __call__`` objects, partials of async
        functions), the result is awaited. Handler errors are logged and do
        not stop other handlers.
        """
        # Snapshot: a handler may add/remove handlers while we iterate.
        for handler in list(self._handlers):
            try:
                result = handler(event)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Error in audit event handler: %s", e)
# Singleton instance, created lazily by get_audit_logger()
_audit_logger: AuditLogger | None = None
_audit_lock = asyncio.Lock()


async def get_audit_logger() -> AuditLogger:
    """
    Get the global audit logger instance, creating and starting it on first use.

    The new instance is published to the module-level singleton only after
    ``start()`` returns. If ``start()`` raises, later callers retry creation
    instead of receiving a logger whose background flush was never started.
    """
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            audit_logger = AuditLogger()
            await audit_logger.start()
            _audit_logger = audit_logger
        return _audit_logger
async def shutdown_audit_logger() -> None:
"""Shutdown the global audit logger."""
global _audit_logger
async with _audit_lock:
if _audit_logger is not None:
await _audit_logger.stop()
_audit_logger = None
def reset_audit_logger() -> None:
    """
    Forget the global audit logger without stopping it (for testing).

    Unlike ``shutdown_audit_logger`` this does not call ``stop()``: any
    buffered events are not flushed, and the existing instance is simply
    dropped. The next ``get_audit_logger()`` call creates a fresh one.
    """
    global _audit_logger
    _audit_logger = None