"""
|
|
Audit Logger
|
|
|
|
Comprehensive audit logging for all safety-related events.
|
|
Provides tamper detection, structured logging, and compliance support.
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from collections import deque
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
from ..config import get_safety_config
|
|
from ..models import (
|
|
ActionRequest,
|
|
AuditEvent,
|
|
AuditEventType,
|
|
SafetyDecision,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
|
|
_UNSET = object()
|
|
|
|
|
|
class AuditLogger:
|
|
"""
|
|
Audit logger for safety events.
|
|
|
|
Features:
|
|
- Structured event logging
|
|
- In-memory buffer with async flush
|
|
- Tamper detection via hash chains
|
|
- Query/search capability
|
|
- Retention policy enforcement
|
|
"""

    def __init__(
        self,
        max_buffer_size: int = 1000,
        flush_interval_seconds: float = 10.0,
        enable_hash_chain: bool = True,
    ) -> None:
        """
        Initialize the audit logger.

        Args:
            max_buffer_size: Maximum events to buffer before auto-flush
            flush_interval_seconds: Interval for periodic flush
            enable_hash_chain: Enable tamper detection via hash chain
        """
        self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
        self._persisted: list[AuditEvent] = []
        self._flush_interval = flush_interval_seconds
        self._enable_hash_chain = enable_hash_chain
        self._last_hash: str | None = None
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task[None] | None = None
        self._running = False

        # Event handlers for real-time processing
        self._handlers: list[Any] = []

        config = get_safety_config()
        self._retention_days = config.audit_retention_days
        self._include_sensitive = config.audit_include_sensitive

    async def start(self) -> None:
        """Start the audit logger background tasks."""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._periodic_flush())
        logger.info("Audit logger started")

    async def stop(self) -> None:
        """Stop the audit logger and flush remaining events."""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Final flush
        await self.flush()
        logger.info("Audit logger stopped")

    async def log(
        self,
        event_type: AuditEventType,
        *,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        decision: SafetyDecision | None = None,
        details: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> AuditEvent:
        """
        Log an audit event.

        Args:
            event_type: Type of audit event
            agent_id: Agent ID if applicable
            action_id: Action ID if applicable
            project_id: Project ID if applicable
            session_id: Session ID if applicable
            user_id: User ID if applicable
            decision: Safety decision if applicable
            details: Additional event details
            correlation_id: Correlation ID for tracing

        Returns:
            The created audit event
        """
        # Sanitize sensitive data if needed
        sanitized_details = self._sanitize_details(details) if details else {}

        event = AuditEvent(
            id=str(uuid4()),
            event_type=event_type,
            timestamp=datetime.utcnow(),
            agent_id=agent_id,
            action_id=action_id,
            project_id=project_id,
            session_id=session_id,
            user_id=user_id,
            decision=decision,
            details=sanitized_details,
            correlation_id=correlation_id,
        )

        async with self._lock:
            # Add hash chain for tamper detection
            if self._enable_hash_chain:
                event_hash = self._compute_hash(event)
                # Modify event.details directly (not sanitized_details)
                # to ensure the hash is stored on the actual event
                event.details["_hash"] = event_hash
                event.details["_prev_hash"] = self._last_hash
                self._last_hash = event_hash

            self._buffer.append(event)

        # Notify handlers
        await self._notify_handlers(event)

        # Log to standard logger as well
        self._log_to_logger(event)

        return event
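
    # Usage sketch (illustrative only): the event type, keyword-only fields,
    # and SafetyDecision come from the imports above; the "agent-1" and
    # "search_docs" values are made-up placeholders.
    #
    #     audit = AuditLogger(enable_hash_chain=True)
    #     event = await audit.log(
    #         AuditEventType.ACTION_VALIDATED,
    #         agent_id="agent-1",
    #         decision=SafetyDecision.ALLOW,
    #         details={"tool_name": "search_docs", "api_key": "sk-..."},
    #     )
    #     # With audit_include_sensitive disabled, details["api_key"] is stored
    #     # as "[REDACTED]"; "_hash" and "_prev_hash" are added to the event's
    #     # details when the hash chain is enabled.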

    async def log_action_request(
        self,
        action: ActionRequest,
        decision: SafetyDecision,
        reasons: list[str] | None = None,
    ) -> AuditEvent:
        """Log an action request with its validation decision."""
        event_type = (
            AuditEventType.ACTION_DENIED
            if decision == SafetyDecision.DENY
            else AuditEventType.ACTION_VALIDATED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=action.metadata.user_id,
            decision=decision,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
                "is_destructive": action.is_destructive,
                "reasons": reasons or [],
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_action_executed(
        self,
        action: ActionRequest,
        success: bool,
        execution_time_ms: float,
        error: str | None = None,
    ) -> AuditEvent:
        """Log an action execution result."""
        event_type = (
            AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "success": success,
                "execution_time_ms": execution_time_ms,
                "error": error,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_approval_event(
        self,
        event_type: AuditEventType,
        approval_id: str,
        action: ActionRequest,
        decided_by: str | None = None,
        reason: str | None = None,
    ) -> AuditEvent:
        """Log an approval-related event."""
        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=decided_by,
            details={
                "approval_id": approval_id,
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "decided_by": decided_by,
                "reason": reason,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_budget_event(
        self,
        event_type: AuditEventType,
        agent_id: str,
        scope: str,
        current_usage: float,
        limit: float,
        unit: str = "tokens",
    ) -> AuditEvent:
        """Log a budget-related event."""
        return await self.log(
            event_type,
            agent_id=agent_id,
            details={
                "scope": scope,
                "current_usage": current_usage,
                "limit": limit,
                "unit": unit,
                "usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
            },
        )
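
    # Example (illustrative values): logging 7500 of 10000 tokens yields
    # usage_percent = 7500 / 10000 * 100 = 75.0; a zero or negative limit is
    # recorded as 0 rather than raising ZeroDivisionError. The event type
    # member shown here (BUDGET_WARNING) is assumed for illustration.
    #
    #     await audit.log_budget_event(
    #         AuditEventType.BUDGET_WARNING,
    #         agent_id="agent-1",
    #         scope="session",
    #         current_usage=7500,
    #         limit=10000,
    #     )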

    async def log_emergency_stop(
        self,
        stop_type: str,
        triggered_by: str,
        reason: str,
        affected_agents: list[str] | None = None,
    ) -> AuditEvent:
        """Log an emergency stop event."""
        return await self.log(
            AuditEventType.EMERGENCY_STOP,
            user_id=triggered_by,
            details={
                "stop_type": stop_type,
                "triggered_by": triggered_by,
                "reason": reason,
                "affected_agents": affected_agents or [],
            },
        )

    async def flush(self) -> int:
        """
        Flush buffered events to persistent storage.

        Returns:
            Number of events flushed
        """
        async with self._lock:
            if not self._buffer:
                return 0

            events = list(self._buffer)
            self._buffer.clear()

            # Persist events (in production, this would go to database/storage)
            self._persisted.extend(events)

            # Enforce retention
            self._enforce_retention()

            logger.debug("Flushed %d audit events", len(events))
            return len(events)

    async def query(
        self,
        *,
        event_types: list[AuditEventType] | None = None,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        correlation_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEvent]:
        """
        Query audit events with filters.

        Args:
            event_types: Filter by event types
            agent_id: Filter by agent ID
            action_id: Filter by action ID
            project_id: Filter by project ID
            session_id: Filter by session ID
            user_id: Filter by user ID
            start_time: Filter events after this time
            end_time: Filter events before this time
            correlation_id: Filter by correlation ID
            limit: Maximum results to return
            offset: Result offset for pagination

        Returns:
            List of matching audit events
        """
        # Combine buffer and persisted for query
        all_events = list(self._persisted) + list(self._buffer)

        results = []
        for event in all_events:
            if event_types and event.event_type not in event_types:
                continue
            if agent_id and event.agent_id != agent_id:
                continue
            if action_id and event.action_id != action_id:
                continue
            if project_id and event.project_id != project_id:
                continue
            if session_id and event.session_id != session_id:
                continue
            if user_id and event.user_id != user_id:
                continue
            if start_time and event.timestamp < start_time:
                continue
            if end_time and event.timestamp > end_time:
                continue
            if correlation_id and event.correlation_id != correlation_id:
                continue

            results.append(event)

        # Sort by timestamp descending
        results.sort(key=lambda e: e.timestamp, reverse=True)

        # Apply pagination
        return results[offset : offset + limit]
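
    # Query sketch (illustrative): fetch the most recent denied actions for
    # one agent. Filters combine with AND, and results are newest-first.
    #
    #     denied = await audit.query(
    #         event_types=[AuditEventType.ACTION_DENIED],
    #         agent_id="agent-1",
    #         limit=20,
    #     )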

    async def get_action_history(
        self,
        agent_id: str,
        limit: int = 100,
    ) -> list[AuditEvent]:
        """Get action history for an agent."""
        return await self.query(
            agent_id=agent_id,
            event_types=[
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_VALIDATED,
                AuditEventType.ACTION_DENIED,
                AuditEventType.ACTION_EXECUTED,
                AuditEventType.ACTION_FAILED,
            ],
            limit=limit,
        )

    async def verify_integrity(self) -> tuple[bool, list[str]]:
        """
        Verify audit log integrity using hash chain.

        Returns:
            Tuple of (is_valid, list of issues found)
        """
        if not self._enable_hash_chain:
            return True, []

        issues: list[str] = []
        all_events = list(self._persisted) + list(self._buffer)

        prev_hash: str | None = None
        for event in sorted(all_events, key=lambda e: e.timestamp):
            stored_prev = event.details.get("_prev_hash")
            stored_hash = event.details.get("_hash")

            if stored_prev != prev_hash:
                issues.append(
                    f"Hash chain broken at event {event.id}: "
                    f"expected prev_hash={prev_hash}, got {stored_prev}"
                )

            if stored_hash:
                # Pass prev_hash to compute hash with correct chain position
                computed = self._compute_hash(event, prev_hash=prev_hash)
                if computed != stored_hash:
                    issues.append(
                        f"Hash mismatch at event {event.id}: "
                        f"expected {computed}, got {stored_hash}"
                    )

            prev_hash = stored_hash

        return len(issues) == 0, issues
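
    # Tamper-detection sketch (illustrative, test-style use of private state):
    # editing any hashed field of a logged event breaks verification for that
    # event, and because each event's hash feeds the next event's _prev_hash,
    # for the rest of the chain as well.
    #
    #     await audit.log(AuditEventType.ACTION_VALIDATED, agent_id="agent-1")
    #     ok, issues = await audit.verify_integrity()   # ok is True
    #     audit._buffer[0].agent_id = "someone-else"    # simulate tampering
    #     ok, issues = await audit.verify_integrity()   # ok is False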

    def add_handler(self, handler: Any) -> None:
        """Add a real-time event handler."""
        self._handlers.append(handler)

    def remove_handler(self, handler: Any) -> None:
        """Remove an event handler."""
        if handler in self._handlers:
            self._handlers.remove(handler)

    def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
        """Sanitize sensitive data from details."""
        if self._include_sensitive:
            return details

        sanitized: dict[str, Any] = {}
        sensitive_keys = {
            "password",
            "secret",
            "token",
            "api_key",
            "apikey",
            "auth",
            "credential",
        }

        for key, value in details.items():
            lower_key = key.lower()
            if any(s in lower_key for s in sensitive_keys):
                sanitized[key] = "[REDACTED]"
            elif isinstance(value, dict):
                sanitized[key] = self._sanitize_details(value)
            else:
                sanitized[key] = value

        return sanitized
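
    # Redaction example (illustrative): matching is substring-based on the
    # lowercased key, and nested dicts are sanitized recursively.
    #
    #     self._sanitize_details({"API_KEY": "sk-1", "config": {"db_password": "x"}})
    #     # -> {"API_KEY": "[REDACTED]", "config": {"db_password": "[REDACTED]"}}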

    def _compute_hash(
        self, event: AuditEvent, prev_hash: str | None | object = _UNSET
    ) -> str:
        """Compute hash for an event (excluding hash fields).

        Args:
            event: The audit event to hash.
            prev_hash: Optional previous hash to use instead of self._last_hash.
                Pass this during verification to use the correct chain.
                Use None explicitly to indicate no previous hash.
        """
        # Use passed prev_hash if explicitly provided, otherwise use instance state
        effective_prev: str | None = (
            self._last_hash if prev_hash is _UNSET else prev_hash  # type: ignore[assignment]
        )

        data: dict[str, str | dict[str, str] | None] = {
            "id": event.id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "project_id": event.project_id,
            "session_id": event.session_id,
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            "details": {
                k: v for k, v in event.details.items() if not k.startswith("_")
            },
            "correlation_id": event.correlation_id,
        }

        if effective_prev:
            data["_prev_hash"] = effective_prev

        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()
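
    # Chain-position sketch (illustrative): the same event hashes differently
    # depending on which prev_hash is folded in, which is what links event N
    # to event N-1.
    #
    #     h1 = self._compute_hash(event, prev_hash=None)         # first in chain
    #     h2 = self._compute_hash(event, prev_hash="abc123...")  # later in chain
    #     # h1 != h2. Keys starting with "_" (the stored hashes) are excluded
    #     # from the hashed payload, so re-hashing a stored event is stable.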

    def _log_to_logger(self, event: AuditEvent) -> None:
        """Log event to standard Python logger."""
        log_data = {
            "audit_event": event.event_type.value,
            "event_id": event.id,
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "decision": event.decision.value if event.decision else None,
        }

        # Use appropriate log level based on event type
        if event.event_type in {
            AuditEventType.ACTION_DENIED,
            AuditEventType.POLICY_VIOLATION,
            AuditEventType.EMERGENCY_STOP,
        }:
            logger.warning("Audit: %s", log_data)
        elif event.event_type in {
            AuditEventType.ACTION_FAILED,
            AuditEventType.ROLLBACK_FAILED,
        }:
            logger.error("Audit: %s", log_data)
        else:
            logger.info("Audit: %s", log_data)

    def _enforce_retention(self) -> None:
        """Enforce retention policy on persisted events."""
        if not self._retention_days:
            return

        cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
        before_count = len(self._persisted)

        self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]

        removed = before_count - len(self._persisted)
        if removed > 0:
            logger.info("Removed %d expired audit events", removed)

    async def _periodic_flush(self) -> None:
        """Background task for periodic flushing."""
        while self._running:
            try:
                await asyncio.sleep(self._flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Error in periodic audit flush: %s", e)

    async def _notify_handlers(self, event: AuditEvent) -> None:
        """Notify all registered handlers of a new event."""
        for handler in self._handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(event)
                else:
                    handler(event)
            except Exception as e:
                logger.error("Error in audit event handler: %s", e)


# Singleton instance
_audit_logger: AuditLogger | None = None
_audit_lock = asyncio.Lock()


async def get_audit_logger() -> AuditLogger:
    """Get the global audit logger instance."""
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            _audit_logger = AuditLogger()
            await _audit_logger.start()

    return _audit_logger


async def shutdown_audit_logger() -> None:
    """Shutdown the global audit logger."""
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is not None:
            await _audit_logger.stop()
            _audit_logger = None


def reset_audit_logger() -> None:
    """Reset the audit logger (for testing)."""
    global _audit_logger
    _audit_logger = None
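

# Minimal end-to-end sketch of the module-level lifecycle. Illustrative only:
# it assumes the relative imports above resolve (i.e. this runs inside the
# package) and uses AuditEventType.ACTION_VALIDATED as imported above.
#
#     import asyncio
#
#     async def _demo() -> None:
#         audit = await get_audit_logger()          # starts the periodic flush task
#         await audit.log(AuditEventType.ACTION_VALIDATED, agent_id="agent-1")
#         await audit.flush()                       # move buffer -> persisted
#         ok, issues = await audit.verify_integrity()
#         assert ok, issues
#         await shutdown_audit_logger()             # cancels task, final flush
#
#     asyncio.run(_demo())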