feat(backend): add safety framework foundation (Phase A) (#63)
Core safety framework architecture for autonomous agent guardrails:

**Core Components:**
- SafetyGuardian: Main orchestrator for all safety checks
- AuditLogger: Comprehensive audit logging with hash chain tamper detection
- SafetyConfig: Pydantic-based configuration
- Models: Action requests, validation results, policies, checkpoints

**Exception Hierarchy:**
- SafetyError base with context preservation
- Permission, Budget, RateLimit, Loop errors
- Approval workflow errors (Required, Denied, Timeout)
- Rollback, Sandbox, Emergency exceptions

**Safety Policy System:**
- Autonomy level based policies (FULL_CONTROL, MILESTONE, AUTONOMOUS)
- Cost limits, rate limits, permission patterns
- HITL approval requirements per action type
- Configurable loop detection thresholds

**Directory Structure:**
- validation/, costs/, limits/, loops/ - Control subsystems
- permissions/, rollback/, hitl/ - Access and recovery
- content/, sandbox/, emergency/ - Protection systems
- audit/, policies/ - Logging and configuration

Phase A establishes the architecture. Subsystems to be implemented in Phase B-C.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
19
backend/app/services/safety/audit/__init__.py
Normal file
19
backend/app/services/safety/audit/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Audit System
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuditLogger",
|
||||
"get_audit_logger",
|
||||
"reset_audit_logger",
|
||||
"shutdown_audit_logger",
|
||||
]
|
||||
585
backend/app/services/safety/audit/logger.py
Normal file
585
backend/app/services/safety/audit/logger.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
Audit Logger
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
Provides tamper detection, structured logging, and compliance support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditLogger:
    """
    Audit logger for safety events.

    Features:
    - Structured event logging
    - In-memory buffer with async flush
    - Tamper detection via hash chains
    - Query/search capability
    - Retention policy enforcement

    Hash chain layout: when the chain is enabled, every event carries two
    reserved keys in ``details``. ``_prev_hash`` is the hash of the event
    logged immediately before it (``None`` for the very first event) and
    ``_hash`` is a SHA-256 digest over the event's own fields plus
    ``_prev_hash``.
    """

    # Substrings that mark a details key as sensitive (case-insensitive match).
    _SENSITIVE_KEY_MARKERS = (
        "password",
        "secret",
        "token",
        "api_key",
        "apikey",
        "auth",
        "credential",
    )

    def __init__(
        self,
        max_buffer_size: int = 1000,
        flush_interval_seconds: float = 10.0,
        enable_hash_chain: bool = True,
    ) -> None:
        """
        Initialize the audit logger.

        Args:
            max_buffer_size: Maximum events to buffer before auto-flush
            flush_interval_seconds: Interval for periodic flush
            enable_hash_chain: Enable tamper detection via hash chain
        """
        # The buffer is deliberately unbounded: a deque created with maxlen
        # silently discards its oldest entry when full, which would lose audit
        # events and break the hash chain. log() flushes once the limit is hit.
        self._max_buffer_size = max_buffer_size
        self._buffer: deque[AuditEvent] = deque()
        self._persisted: list[AuditEvent] = []
        self._flush_interval = flush_interval_seconds
        self._enable_hash_chain = enable_hash_chain
        self._last_hash: str | None = None
        # Hash the oldest retained event is expected to link back to. It stays
        # None until retention removes the head of the chain.
        self._chain_anchor: str | None = None
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task[None] | None = None
        self._running = False

        # Event handlers for real-time processing
        self._handlers: list[Any] = []

        config = get_safety_config()
        self._retention_days = config.audit_retention_days
        self._include_sensitive = config.audit_include_sensitive

    async def start(self) -> None:
        """Start the audit logger background tasks (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._periodic_flush())
        logger.info("Audit logger started")

    async def stop(self) -> None:
        """Stop the audit logger and flush remaining events."""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        # Final flush
        await self.flush()
        logger.info("Audit logger stopped")

    async def log(
        self,
        event_type: AuditEventType,
        *,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        decision: SafetyDecision | None = None,
        details: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> AuditEvent:
        """
        Log an audit event.

        Args:
            event_type: Type of audit event
            agent_id: Agent ID if applicable
            action_id: Action ID if applicable
            project_id: Project ID if applicable
            session_id: Session ID if applicable
            user_id: User ID if applicable
            decision: Safety decision if applicable
            details: Additional event details
            correlation_id: Correlation ID for tracing

        Returns:
            The created audit event
        """
        # Sanitize sensitive data if needed
        sanitized_details = self._sanitize_details(details) if details else {}

        event = AuditEvent(
            id=str(uuid4()),
            event_type=event_type,
            timestamp=datetime.utcnow(),
            agent_id=agent_id,
            action_id=action_id,
            project_id=project_id,
            session_id=session_id,
            user_id=user_id,
            decision=decision,
            details=sanitized_details,
            correlation_id=correlation_id,
        )

        async with self._lock:
            # Add hash chain for tamper detection. The markers are written to
            # event.details rather than the local dict because the model may
            # have copied its input, in which case the local dict is detached.
            if self._enable_hash_chain:
                event.details["_prev_hash"] = self._last_hash
                event_hash = self._compute_hash(event)
                event.details["_hash"] = event_hash
                self._last_hash = event_hash

            # Auto-flush instead of dropping events when the buffer is full.
            # flush() cannot be called here: asyncio.Lock is not reentrant.
            if len(self._buffer) >= self._max_buffer_size:
                self._flush_locked()
            self._buffer.append(event)

        # Notify handlers outside the lock so a handler that logs or flushes
        # cannot deadlock on it.
        await self._notify_handlers(event)

        # Log to standard logger as well
        self._log_to_logger(event)

        return event

    async def log_action_request(
        self,
        action: ActionRequest,
        decision: SafetyDecision,
        reasons: list[str] | None = None,
    ) -> AuditEvent:
        """Log an action request with its validation decision."""
        event_type = (
            AuditEventType.ACTION_DENIED
            if decision == SafetyDecision.DENY
            else AuditEventType.ACTION_VALIDATED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=action.metadata.user_id,
            decision=decision,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
                "is_destructive": action.is_destructive,
                "reasons": reasons or [],
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_action_executed(
        self,
        action: ActionRequest,
        success: bool,
        execution_time_ms: float,
        error: str | None = None,
    ) -> AuditEvent:
        """Log an action execution result."""
        event_type = (
            AuditEventType.ACTION_EXECUTED
            if success
            else AuditEventType.ACTION_FAILED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            # Recorded like in log_action_request so execution events can be
            # queried by user as well.
            user_id=action.metadata.user_id,
            decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "success": success,
                "execution_time_ms": execution_time_ms,
                "error": error,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_approval_event(
        self,
        event_type: AuditEventType,
        approval_id: str,
        action: ActionRequest,
        decided_by: str | None = None,
        reason: str | None = None,
    ) -> AuditEvent:
        """Log an approval-related event; ``decided_by`` is stored as the user."""
        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=decided_by,
            details={
                "approval_id": approval_id,
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "decided_by": decided_by,
                "reason": reason,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_budget_event(
        self,
        event_type: AuditEventType,
        agent_id: str,
        scope: str,
        current_usage: float,
        limit: float,
        unit: str = "tokens",
    ) -> AuditEvent:
        """Log a budget-related event (usage percent is 0 for a non-positive limit)."""
        return await self.log(
            event_type,
            agent_id=agent_id,
            details={
                "scope": scope,
                "current_usage": current_usage,
                "limit": limit,
                "unit": unit,
                "usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
            },
        )

    async def log_emergency_stop(
        self,
        stop_type: str,
        triggered_by: str,
        reason: str,
        affected_agents: list[str] | None = None,
    ) -> AuditEvent:
        """Log an emergency stop event."""
        return await self.log(
            AuditEventType.EMERGENCY_STOP,
            user_id=triggered_by,
            details={
                "stop_type": stop_type,
                "triggered_by": triggered_by,
                "reason": reason,
                "affected_agents": affected_agents or [],
            },
        )

    async def flush(self) -> int:
        """
        Flush buffered events to persistent storage.

        Returns:
            Number of events flushed
        """
        async with self._lock:
            return self._flush_locked()

    def _flush_locked(self) -> int:
        """Move buffered events to the persisted list; caller must hold the lock."""
        if not self._buffer:
            return 0

        events = list(self._buffer)
        self._buffer.clear()

        # Persist events (in production, this would go to database/storage)
        self._persisted.extend(events)

        # Enforce retention
        self._enforce_retention()

        logger.debug("Flushed %d audit events", len(events))
        return len(events)

    async def query(
        self,
        *,
        event_types: list[AuditEventType] | None = None,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        correlation_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEvent]:
        """
        Query audit events with filters.

        Args:
            event_types: Filter by event types
            agent_id: Filter by agent ID
            action_id: Filter by action ID
            project_id: Filter by project ID
            session_id: Filter by session ID
            user_id: Filter by user ID
            start_time: Filter events after this time
            end_time: Filter events before this time
            correlation_id: Filter by correlation ID
            limit: Maximum results to return
            offset: Result offset for pagination

        Returns:
            List of matching audit events, newest first
        """
        # Combine buffer and persisted for query
        all_events = list(self._persisted) + list(self._buffer)

        results = []
        for event in all_events:
            if event_types and event.event_type not in event_types:
                continue
            if agent_id and event.agent_id != agent_id:
                continue
            if action_id and event.action_id != action_id:
                continue
            if project_id and event.project_id != project_id:
                continue
            if session_id and event.session_id != session_id:
                continue
            if user_id and event.user_id != user_id:
                continue
            if start_time and event.timestamp < start_time:
                continue
            if end_time and event.timestamp > end_time:
                continue
            if correlation_id and event.correlation_id != correlation_id:
                continue

            results.append(event)

        # Sort by timestamp descending
        results.sort(key=lambda e: e.timestamp, reverse=True)

        # Apply pagination. Negative values are clamped because a negative
        # slice bound would count from the end of the list.
        first = max(offset, 0)
        return results[first : first + max(limit, 0)]

    async def get_action_history(
        self,
        agent_id: str,
        limit: int = 100,
    ) -> list[AuditEvent]:
        """Get action history for an agent."""
        return await self.query(
            agent_id=agent_id,
            event_types=[
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_VALIDATED,
                AuditEventType.ACTION_DENIED,
                AuditEventType.ACTION_EXECUTED,
                AuditEventType.ACTION_FAILED,
            ],
            limit=limit,
        )

    async def verify_integrity(self) -> tuple[bool, list[str]]:
        """
        Verify audit log integrity using hash chain.

        Events are walked in insertion order (persisted first, then buffered),
        which is the order in which they were chained.

        Returns:
            Tuple of (is_valid, list of issues found)
        """
        if not self._enable_hash_chain:
            return True, []

        issues: list[str] = []
        # No re-sort by timestamp: timestamps are taken before the lock is
        # acquired in log(), so they may disagree with the chain order.
        all_events = list(self._persisted) + list(self._buffer)

        # Start from the anchor so a log whose head was pruned by retention
        # still verifies.
        prev_hash: str | None = self._chain_anchor
        for event in all_events:
            stored_prev = event.details.get("_prev_hash")
            stored_hash = event.details.get("_hash")

            if stored_prev != prev_hash:
                issues.append(
                    f"Hash chain broken at event {event.id}: "
                    f"expected prev_hash={prev_hash}, got {stored_prev}"
                )

            if stored_hash:
                computed = self._compute_hash(event)
                if computed != stored_hash:
                    issues.append(
                        f"Hash mismatch at event {event.id}: "
                        f"expected {computed}, got {stored_hash}"
                    )
            else:
                issues.append(f"Missing hash at event {event.id}")

            prev_hash = stored_hash

        return len(issues) == 0, issues

    def add_handler(self, handler: Any) -> None:
        """Add a real-time event handler (sync or async callable taking the event)."""
        self._handlers.append(handler)

    def remove_handler(self, handler: Any) -> None:
        """Remove an event handler; unknown handlers are ignored."""
        if handler in self._handlers:
            self._handlers.remove(handler)

    def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
        """
        Sanitize sensitive data from details.

        Values whose key contains a sensitive marker are replaced with
        ``"[REDACTED]"``. Nested dicts, and dicts inside lists or tuples, are
        sanitized recursively. The input is never modified; a new dict is
        returned even when sensitive data is allowed.
        """
        if self._include_sensitive:
            # Shallow copy so the hash markers added by log() never end up in
            # the caller's own dict.
            return dict(details)

        sanitized: dict[str, Any] = {}
        for key, value in details.items():
            lowered = str(key).lower()
            if any(marker in lowered for marker in self._SENSITIVE_KEY_MARKERS):
                sanitized[key] = "[REDACTED]"
            else:
                sanitized[key] = self._sanitize_value(value)

        return sanitized

    def _sanitize_value(self, value: Any) -> Any:
        """Sanitize dicts found directly in, or nested inside, a details value."""
        if isinstance(value, dict):
            return self._sanitize_details(value)
        if isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        if isinstance(value, tuple):
            return tuple(self._sanitize_value(item) for item in value)
        return value

    def _compute_hash(self, event: AuditEvent) -> str:
        """
        Compute hash for an event (excluding hash fields).

        The predecessor hash mixed into the digest is the event's own stored
        ``_prev_hash`` when it has one, so recomputing during verification
        gives the value produced at log time. ``self._last_hash`` is only a
        fallback for events that do not carry the marker yet.
        """
        details = event.details
        if "_prev_hash" in details:
            prev_hash = details["_prev_hash"]
        else:
            prev_hash = self._last_hash

        data: dict[str, Any] = {
            "id": event.id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "project_id": event.project_id,
            "session_id": event.session_id,
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            "details": {
                k: v
                for k, v in details.items()
                if not k.startswith("_")
            },
            "correlation_id": event.correlation_id,
        }

        if prev_hash:
            data["_prev_hash"] = prev_hash

        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def _log_to_logger(self, event: AuditEvent) -> None:
        """Log event to standard Python logger."""
        log_data = {
            "audit_event": event.event_type.value,
            "event_id": event.id,
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "decision": event.decision.value if event.decision else None,
        }

        # Use appropriate log level based on event type
        if event.event_type in {
            AuditEventType.ACTION_DENIED,
            AuditEventType.POLICY_VIOLATION,
            AuditEventType.EMERGENCY_STOP,
        }:
            logger.warning("Audit: %s", log_data)
        elif event.event_type in {
            AuditEventType.ACTION_FAILED,
            AuditEventType.ROLLBACK_FAILED,
        }:
            logger.error("Audit: %s", log_data)
        else:
            logger.info("Audit: %s", log_data)

    def _enforce_retention(self) -> None:
        """
        Enforce retention policy on persisted events.

        Events older than the retention window are dropped. The hash of the
        last dropped event is remembered as the chain anchor so that
        verify_integrity() accepts the first event that remains.
        """
        if not self._retention_days:
            return

        # Naive UTC, matching the timestamps assigned in log().
        cutoff = datetime.utcnow() - timedelta(days=self._retention_days)

        kept: list[AuditEvent] = []
        last_dropped: AuditEvent | None = None
        for event in self._persisted:
            if event.timestamp >= cutoff:
                kept.append(event)
            else:
                last_dropped = event

        if last_dropped is None:
            return

        removed = len(self._persisted) - len(kept)
        self._persisted = kept
        if self._enable_hash_chain:
            self._chain_anchor = last_dropped.details.get("_hash")
        logger.info("Removed %d expired audit events", removed)

    async def _periodic_flush(self) -> None:
        """Background task for periodic flushing."""
        while self._running:
            try:
                await asyncio.sleep(self._flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; one failed flush must not end auditing.
                logger.exception("Error in periodic audit flush: %s", e)

    async def _notify_handlers(self, event: AuditEvent) -> None:
        """
        Notify all registered handlers of a new event.

        Handlers may be plain callables or return an awaitable. A failing
        handler is logged and does not stop the remaining ones.
        """
        import inspect

        # Iterate over a snapshot so a handler may remove itself while running.
        for handler in list(self._handlers):
            try:
                outcome = handler(event)
                # Await whatever awaitable comes back; this also covers async
                # callables that iscoroutinefunction() does not recognise.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.exception("Error in audit event handler: %s", e)
|
||||
|
||||
|
||||
# Singleton instance
# Shared AuditLogger; stays None until get_audit_logger() creates it.
_audit_logger: AuditLogger | None = None
# Held by get_audit_logger() and shutdown_audit_logger() while they read or
# replace _audit_logger.
_audit_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_audit_logger() -> AuditLogger:
    """
    Return the shared audit logger, creating and starting it on first use.

    The check-and-create step runs under the module-level lock so that
    concurrent first callers do not each build their own instance.
    """
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            fresh = AuditLogger()
            # Publish before starting, then start the background flush task.
            _audit_logger = fresh
            await fresh.start()

    return _audit_logger
|
||||
|
||||
|
||||
async def shutdown_audit_logger() -> None:
    """
    Stop the shared audit logger, if one exists, and forget it.

    Does nothing when no logger has been created. The global reference is
    cleared only after stop() has completed.
    """
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            return
        await _audit_logger.stop()
        _audit_logger = None
|
||||
|
||||
|
||||
def reset_audit_logger() -> None:
    """
    Reset the audit logger (for testing).

    Only drops the module-level reference; the previous instance, if any, is
    not stopped or flushed by this function.
    """
    global _audit_logger
    _audit_logger = None
|
||||
Reference in New Issue
Block a user