"""
Audit Logger

Comprehensive audit logging for all safety-related events.
Provides tamper detection, structured logging, and compliance support.
"""

import asyncio
import hashlib
import inspect
import json
import logging
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import uuid4

from ..config import get_safety_config
from ..models import (
    ActionRequest,
    AuditEvent,
    AuditEventType,
    SafetyDecision,
)

logger = logging.getLogger(__name__)

# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
_UNSET = object()

# Keys the hash chain itself writes into ``AuditEvent.details``. They are
# excluded from the hashed payload because they are derived from it; every
# other key (including caller keys that start with "_") is covered by the hash.
_HASH_KEYS = frozenset({"_hash", "_prev_hash"})

# Substrings that mark a details key as sensitive (matched case-insensitively).
_SENSITIVE_KEY_MARKERS = (
    "password",
    "secret",
    "token",
    "api_key",
    "apikey",
    "auth",
    "credential",
)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()``; kept naive to match the
    timestamps this module stores and compares against.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class AuditLogger:
    """
    Audit logger for safety events.

    Features:
    - Structured event logging
    - In-memory buffer with async flush (automatic once the buffer fills)
    - Tamper detection via hash chains
    - Query/search capability
    - Retention policy enforcement
    """

    def __init__(
        self,
        max_buffer_size: int = 1000,
        flush_interval_seconds: float = 10.0,
        enable_hash_chain: bool = True,
    ) -> None:
        """
        Initialize the audit logger.

        Args:
            max_buffer_size: Maximum events to buffer before auto-flush
                (values below 1 cause a flush after every event)
            flush_interval_seconds: Interval for periodic flush
            enable_hash_chain: Enable tamper detection via hash chain
        """
        # Unbounded on purpose: the size limit is enforced by flushing in
        # log(). A deque maxlen would silently evict the oldest events, losing
        # audit records and breaking the hash chain.
        self._buffer: deque[AuditEvent] = deque()
        self._max_buffer_size = max(1, max_buffer_size)
        self._persisted: list[AuditEvent] = []
        self._flush_interval = flush_interval_seconds
        self._enable_hash_chain = enable_hash_chain
        self._last_hash: str | None = None
        # Hash of the newest event pruned by retention. Verification of the
        # surviving chain starts from this value instead of None.
        self._chain_anchor: str | None = None
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task[None] | None = None
        self._running = False

        # Event handlers for real-time processing
        self._handlers: list[Any] = []

        config = get_safety_config()
        self._retention_days = config.audit_retention_days
        self._include_sensitive = config.audit_include_sensitive

    async def start(self) -> None:
        """Start the periodic flush task (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._periodic_flush())
        logger.info("Audit logger started")

    async def stop(self) -> None:
        """Cancel the periodic flush task and flush any remaining events."""
        self._running = False
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        # Final flush
        await self.flush()
        logger.info("Audit logger stopped")

    async def log(
        self,
        event_type: AuditEventType,
        *,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        decision: SafetyDecision | None = None,
        details: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> AuditEvent:
        """
        Log an audit event.

        The event is chained into the tamper-detection hash chain (if enabled)
        and buffered. The buffer is flushed automatically once it holds
        ``max_buffer_size`` events. Registered handlers are then notified and
        the event is echoed to the module logger.

        Args:
            event_type: Type of audit event
            agent_id: Agent ID if applicable
            action_id: Action ID if applicable
            project_id: Project ID if applicable
            session_id: Session ID if applicable
            user_id: User ID if applicable
            decision: Safety decision if applicable
            details: Additional event details (not mutated by this call)
            correlation_id: Correlation ID for tracing

        Returns:
            The created audit event
        """
        # Sanitize sensitive data if needed. Always take a private copy:
        # _sanitize_details returns its input unchanged when sensitive data is
        # included, and the hash-chain fields below must not leak into the
        # caller's dict.
        sanitized_details = dict(self._sanitize_details(details)) if details else {}

        event = AuditEvent(
            id=str(uuid4()),
            event_type=event_type,
            timestamp=_utcnow(),
            agent_id=agent_id,
            action_id=action_id,
            project_id=project_id,
            session_id=session_id,
            user_id=user_id,
            decision=decision,
            details=sanitized_details,
            correlation_id=correlation_id,
        )

        async with self._lock:
            # Add hash chain for tamper detection
            if self._enable_hash_chain:
                event_hash = self._compute_hash(event)
                # Modify event.details directly (not sanitized_details)
                # to ensure the hash is stored on the actual event
                event.details["_hash"] = event_hash
                event.details["_prev_hash"] = self._last_hash
                self._last_hash = event_hash

            self._buffer.append(event)

            # Auto-flush when full so no event is ever dropped.
            if len(self._buffer) >= self._max_buffer_size:
                self._flush_locked()

        # Notify handlers
        await self._notify_handlers(event)

        # Log to standard logger as well
        self._log_to_logger(event)

        return event

    async def log_action_request(
        self,
        action: ActionRequest,
        decision: SafetyDecision,
        reasons: list[str] | None = None,
    ) -> AuditEvent:
        """Log an action request with its validation decision."""
        event_type = (
            AuditEventType.ACTION_DENIED
            if decision == SafetyDecision.DENY
            else AuditEventType.ACTION_VALIDATED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=action.metadata.user_id,
            decision=decision,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
                "is_destructive": action.is_destructive,
                "reasons": reasons or [],
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_action_executed(
        self,
        action: ActionRequest,
        success: bool,
        execution_time_ms: float,
        error: str | None = None,
    ) -> AuditEvent:
        """Log an action execution result (attributed like log_action_request)."""
        event_type = (
            AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            # Record the requesting user so execution events are found by
            # user_id queries, consistent with log_action_request.
            user_id=action.metadata.user_id,
            decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "success": success,
                "execution_time_ms": execution_time_ms,
                "error": error,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_approval_event(
        self,
        event_type: AuditEventType,
        approval_id: str,
        action: ActionRequest,
        decided_by: str | None = None,
        reason: str | None = None,
    ) -> AuditEvent:
        """Log an approval-related event; ``decided_by`` is recorded as user_id."""
        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=decided_by,
            details={
                "approval_id": approval_id,
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "decided_by": decided_by,
                "reason": reason,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_budget_event(
        self,
        event_type: AuditEventType,
        agent_id: str,
        scope: str,
        current_usage: float,
        limit: float,
        unit: str = "tokens",
    ) -> AuditEvent:
        """Log a budget-related event (usage_percent is 0 when limit <= 0)."""
        return await self.log(
            event_type,
            agent_id=agent_id,
            details={
                "scope": scope,
                "current_usage": current_usage,
                "limit": limit,
                "unit": unit,
                "usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
            },
        )

    async def log_emergency_stop(
        self,
        stop_type: str,
        triggered_by: str,
        reason: str,
        affected_agents: list[str] | None = None,
    ) -> AuditEvent:
        """Log an emergency stop event; ``triggered_by`` is recorded as user_id."""
        return await self.log(
            AuditEventType.EMERGENCY_STOP,
            user_id=triggered_by,
            details={
                "stop_type": stop_type,
                "triggered_by": triggered_by,
                "reason": reason,
                "affected_agents": affected_agents or [],
            },
        )

    async def flush(self) -> int:
        """
        Flush buffered events to persistent storage.

        Returns:
            Number of events flushed
        """
        async with self._lock:
            return self._flush_locked()

    def _flush_locked(self) -> int:
        """Move buffered events to storage and apply retention.

        The caller must hold ``self._lock``. Returns the number of events moved.
        """
        if not self._buffer:
            return 0

        events = list(self._buffer)
        self._buffer.clear()

        # Persist events (in production, this would go to database/storage)
        self._persisted.extend(events)

        # Enforce retention
        self._enforce_retention()

        logger.debug("Flushed %d audit events", len(events))
        return len(events)

    async def query(
        self,
        *,
        event_types: list[AuditEventType] | None = None,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        correlation_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEvent]:
        """
        Query audit events with filters.

        Falsy filter values (None, empty string, empty list) disable that filter.
        Both persisted and still-buffered events are searched.

        Args:
            event_types: Filter by event types
            agent_id: Filter by agent ID
            action_id: Filter by action ID
            project_id: Filter by project ID
            session_id: Filter by session ID
            user_id: Filter by user ID
            start_time: Filter events after this time
            end_time: Filter events before this time
            correlation_id: Filter by correlation ID
            limit: Maximum results to return
            offset: Result offset for pagination

        Returns:
            List of matching audit events, newest first
        """
        # Combine buffer and persisted for query
        all_events = list(self._persisted) + list(self._buffer)

        results = []
        for event in all_events:
            if event_types and event.event_type not in event_types:
                continue
            if agent_id and event.agent_id != agent_id:
                continue
            if action_id and event.action_id != action_id:
                continue
            if project_id and event.project_id != project_id:
                continue
            if session_id and event.session_id != session_id:
                continue
            if user_id and event.user_id != user_id:
                continue
            if start_time and event.timestamp < start_time:
                continue
            if end_time and event.timestamp > end_time:
                continue
            if correlation_id and event.correlation_id != correlation_id:
                continue

            results.append(event)

        # Sort by timestamp descending
        results.sort(key=lambda e: e.timestamp, reverse=True)

        # Apply pagination
        return results[offset : offset + limit]

    async def get_action_history(
        self,
        agent_id: str,
        limit: int = 100,
    ) -> list[AuditEvent]:
        """Get action history for an agent."""
        return await self.query(
            agent_id=agent_id,
            event_types=[
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_VALIDATED,
                AuditEventType.ACTION_DENIED,
                AuditEventType.ACTION_EXECUTED,
                AuditEventType.ACTION_FAILED,
            ],
            limit=limit,
        )

    async def verify_integrity(self) -> tuple[bool, list[str]]:
        """
        Verify audit log integrity using hash chain.

        Walks events in chain order and checks that each event's stored
        ``_prev_hash`` links to its predecessor and that its stored ``_hash``
        matches a fresh recomputation. Events pruned by retention are accounted
        for via the recorded chain anchor.

        Returns:
            Tuple of (is_valid, list of issues found)
        """
        if not self._enable_hash_chain:
            return True, []

        issues: list[str] = []
        # Storage order IS chain order: log() chains and appends under the lock,
        # and flush moves events in order. Sorting by timestamp instead would
        # report false breaks if the wall clock ever stepped backwards.
        all_events = list(self._persisted) + list(self._buffer)

        prev_hash: str | None = self._chain_anchor
        for event in all_events:
            stored_prev = event.details.get("_prev_hash")
            stored_hash = event.details.get("_hash")

            if stored_prev != prev_hash:
                issues.append(
                    f"Hash chain broken at event {event.id}: "
                    f"expected prev_hash={prev_hash}, got {stored_prev}"
                )

            if stored_hash:
                # Pass prev_hash to compute hash with correct chain position
                computed = self._compute_hash(event, prev_hash=prev_hash)
                if computed != stored_hash:
                    issues.append(
                        f"Hash mismatch at event {event.id}: "
                        f"expected {computed}, got {stored_hash}"
                    )
            else:
                # Stripping the hash must not be a way to hide edits to an
                # event (notably the last one, which has no successor).
                issues.append(f"Missing hash at event {event.id}")

            prev_hash = stored_hash

        return len(issues) == 0, issues

    def add_handler(self, handler: Any) -> None:
        """Add a real-time event handler (sync callable or async callable)."""
        self._handlers.append(handler)

    def remove_handler(self, handler: Any) -> None:
        """Remove an event handler (no-op if it is not registered)."""
        if handler in self._handlers:
            self._handlers.remove(handler)

    def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
        """Redact values whose keys look sensitive.

        Matching is a case-insensitive substring test against
        ``_SENSITIVE_KEY_MARKERS``. Nested dicts, including dicts inside lists
        and tuples, are sanitized recursively. When the config allows sensitive
        data, ``details`` is returned unchanged (the same object).
        """
        if self._include_sensitive:
            return details

        sanitized: dict[str, Any] = {}
        for key, value in details.items():
            lower_key = key.lower()
            if any(marker in lower_key for marker in _SENSITIVE_KEY_MARKERS):
                sanitized[key] = "[REDACTED]"
            else:
                sanitized[key] = self._sanitize_value(value)

        return sanitized

    def _sanitize_value(self, value: Any) -> Any:
        """Sanitize a details value, recursing into dicts, lists and tuples."""
        if isinstance(value, dict):
            return self._sanitize_details(value)
        if isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        if type(value) is tuple:
            # Only plain tuples; namedtuples and other subclasses are left
            # as-is to avoid changing their type.
            return tuple(self._sanitize_value(item) for item in value)
        return value

    def _compute_hash(
        self, event: AuditEvent, prev_hash: str | None | object = _UNSET
    ) -> str:
        """Compute hash for an event (excluding hash fields).

        Args:
            event: The audit event to hash.
            prev_hash: Optional previous hash to use instead of self._last_hash.
                Pass this during verification to use the correct chain.
                Use None explicitly to indicate no previous hash.

        Returns:
            Hex-encoded SHA-256 of the canonical JSON form of the event.
        """
        # Use passed prev_hash if explicitly provided, otherwise use instance state
        effective_prev: str | None = (
            self._last_hash if prev_hash is _UNSET else prev_hash  # type: ignore[assignment]
        )

        data: dict[str, str | dict[str, str] | None] = {
            "id": event.id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "project_id": event.project_id,
            "session_id": event.session_id,
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            # Only the chain's own bookkeeping keys are excluded, so tampering
            # with any caller-supplied key (even "_"-prefixed) is detected.
            "details": {
                k: v for k, v in event.details.items() if k not in _HASH_KEYS
            },
            "correlation_id": event.correlation_id,
        }

        if effective_prev:
            data["_prev_hash"] = effective_prev

        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def _log_to_logger(self, event: AuditEvent) -> None:
        """Log event to standard Python logger."""
        log_data = {
            "audit_event": event.event_type.value,
            "event_id": event.id,
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "decision": event.decision.value if event.decision else None,
        }

        # Use appropriate log level based on event type
        if event.event_type in {
            AuditEventType.ACTION_DENIED,
            AuditEventType.POLICY_VIOLATION,
            AuditEventType.EMERGENCY_STOP,
        }:
            logger.warning("Audit: %s", log_data)
        elif event.event_type in {
            AuditEventType.ACTION_FAILED,
            AuditEventType.ROLLBACK_FAILED,
        }:
            logger.error("Audit: %s", log_data)
        else:
            logger.info("Audit: %s", log_data)

    def _enforce_retention(self) -> None:
        """Drop persisted events older than the retention window.

        Disabled when ``audit_retention_days`` is falsy. Records the hash of the
        newest pruned event so verify_integrity can validate the surviving chain.
        """
        if not self._retention_days:
            return

        cutoff = _utcnow() - timedelta(days=self._retention_days)

        kept: list[AuditEvent] = []
        last_removed: AuditEvent | None = None
        for event in self._persisted:
            if event.timestamp >= cutoff:
                kept.append(event)
            else:
                last_removed = event

        removed = len(self._persisted) - len(kept)
        self._persisted = kept

        if last_removed is not None:
            # The first surviving event's _prev_hash points at this hash.
            self._chain_anchor = last_removed.details.get("_hash")

        if removed > 0:
            logger.info("Removed %d expired audit events", removed)

    async def _periodic_flush(self) -> None:
        """Background task for periodic flushing.

        Runs until ``_running`` is cleared. Cancellation propagates (stop()
        expects and handles it); other errors are logged and the loop continues.
        """
        while self._running:
            try:
                await asyncio.sleep(self._flush_interval)
                await self.flush()
            except Exception as e:
                # CancelledError is a BaseException and is deliberately not
                # caught here.
                logger.exception("Error in periodic audit flush: %s", e)

    async def _notify_handlers(self, event: AuditEvent) -> None:
        """Notify all registered handlers of a new event.

        Each handler is called with the event. If it returns an awaitable
        (async functions, partials of them, objects with async __call__), that
        result is awaited. Handler errors are logged and do not stop dispatch.
        """
        # Iterate over a snapshot so handlers may add/remove handlers safely.
        for handler in list(self._handlers):
            try:
                result = handler(event)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Error in audit event handler: %s", e)


# Singleton instance
_audit_logger: AuditLogger | None = None
_audit_lock = asyncio.Lock()


async def get_audit_logger() -> AuditLogger:
    """Get the global audit logger instance, creating and starting it on first use."""
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            _audit_logger = AuditLogger()
            await _audit_logger.start()

        return _audit_logger


async def shutdown_audit_logger() -> None:
    """Shutdown the global audit logger (stop, final flush, then clear it)."""
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is not None:
            await _audit_logger.stop()
            _audit_logger = None


def reset_audit_logger() -> None:
    """Reset the audit logger (for testing).

    Only clears the global reference; it does not stop a running instance.
    """
    global _audit_logger
    _audit_logger = None