forked from cardosofelipe/pragma-stack
Compare commits
6 Commits
e5975fa5d0
...
c8b88dadc3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8b88dadc3 | ||
|
|
015f2de6c6 | ||
|
|
f36bfb3781 | ||
|
|
ef659cd72d | ||
|
|
728edd1453 | ||
|
|
498c0a0e94 |
170
backend/app/services/safety/__init__.py
Normal file
170
backend/app/services/safety/__init__.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Safety and Guardrails Framework
|
||||
|
||||
Comprehensive safety framework for autonomous agent operation.
|
||||
Provides multi-layered protection including:
|
||||
- Pre-execution validation
|
||||
- Cost and budget controls
|
||||
- Rate limiting
|
||||
- Loop detection and prevention
|
||||
- Human-in-the-loop approval
|
||||
- Rollback and checkpointing
|
||||
- Content filtering
|
||||
- Sandboxed execution
|
||||
- Emergency controls
|
||||
- Complete audit trail
|
||||
|
||||
Usage:
|
||||
from app.services.safety import get_safety_guardian, SafetyGuardian
|
||||
|
||||
guardian = await get_safety_guardian()
|
||||
result = await guardian.validate(action_request)
|
||||
|
||||
if result.allowed:
|
||||
# Execute action
|
||||
pass
|
||||
else:
|
||||
# Handle denial
|
||||
print(f"Action denied: {result.reasons}")
|
||||
"""
|
||||
|
||||
# Exceptions
|
||||
# Audit
|
||||
from .audit import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
AutonomyConfig,
|
||||
SafetyConfig,
|
||||
get_autonomy_config,
|
||||
get_default_policy,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
load_policies_from_directory,
|
||||
load_policy_from_file,
|
||||
reset_config_cache,
|
||||
)
|
||||
from .exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
BudgetExceededError,
|
||||
CheckpointError,
|
||||
ContentFilterError,
|
||||
EmergencyStopError,
|
||||
LoopDetectedError,
|
||||
PermissionDeniedError,
|
||||
PolicyViolationError,
|
||||
RateLimitExceededError,
|
||||
RollbackError,
|
||||
SafetyError,
|
||||
SandboxError,
|
||||
SandboxTimeoutError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
# Guardian
|
||||
from .guardian import (
|
||||
SafetyGuardian,
|
||||
get_safety_guardian,
|
||||
reset_safety_guardian,
|
||||
shutdown_safety_guardian,
|
||||
)
|
||||
|
||||
# Models
|
||||
from .models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
GuardianResult,
|
||||
PermissionLevel,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
ResourceType,
|
||||
RollbackResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionMetadata",
|
||||
"ActionRequest",
|
||||
"ActionResult",
|
||||
# Models
|
||||
"ActionType",
|
||||
"ApprovalDeniedError",
|
||||
"ApprovalRequest",
|
||||
"ApprovalRequiredError",
|
||||
"ApprovalResponse",
|
||||
"ApprovalStatus",
|
||||
"ApprovalTimeoutError",
|
||||
"AuditEvent",
|
||||
"AuditEventType",
|
||||
# Audit
|
||||
"AuditLogger",
|
||||
"AutonomyConfig",
|
||||
"AutonomyLevel",
|
||||
"BudgetExceededError",
|
||||
"BudgetScope",
|
||||
"BudgetStatus",
|
||||
"Checkpoint",
|
||||
"CheckpointError",
|
||||
"CheckpointType",
|
||||
"ContentFilterError",
|
||||
"EmergencyStopError",
|
||||
"GuardianResult",
|
||||
"LoopDetectedError",
|
||||
"PermissionDeniedError",
|
||||
"PermissionLevel",
|
||||
"PolicyViolationError",
|
||||
"RateLimitConfig",
|
||||
"RateLimitExceededError",
|
||||
"RateLimitStatus",
|
||||
"ResourceType",
|
||||
"RollbackError",
|
||||
"RollbackResult",
|
||||
# Configuration
|
||||
"SafetyConfig",
|
||||
"SafetyDecision",
|
||||
# Exceptions
|
||||
"SafetyError",
|
||||
# Guardian
|
||||
"SafetyGuardian",
|
||||
"SafetyPolicy",
|
||||
"SandboxError",
|
||||
"SandboxTimeoutError",
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"ValidationRule",
|
||||
"get_audit_logger",
|
||||
"get_autonomy_config",
|
||||
"get_default_policy",
|
||||
"get_policy_for_autonomy_level",
|
||||
"get_safety_config",
|
||||
"get_safety_guardian",
|
||||
"load_policies_from_directory",
|
||||
"load_policy_from_file",
|
||||
"reset_audit_logger",
|
||||
"reset_config_cache",
|
||||
"reset_safety_guardian",
|
||||
"shutdown_audit_logger",
|
||||
"shutdown_safety_guardian",
|
||||
]
|
||||
19
backend/app/services/safety/audit/__init__.py
Normal file
19
backend/app/services/safety/audit/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Audit System
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuditLogger",
|
||||
"get_audit_logger",
|
||||
"reset_audit_logger",
|
||||
"shutdown_audit_logger",
|
||||
]
|
||||
585
backend/app/services/safety/audit/logger.py
Normal file
585
backend/app/services/safety/audit/logger.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
Audit Logger
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
Provides tamper detection, structured logging, and compliance support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditLogger:
    """
    Audit logger for safety events.

    Features:
    - Structured event logging
    - In-memory buffer with async flush
    - Tamper detection via hash chains
    - Query/search capability
    - Retention policy enforcement
    """

    def __init__(
        self,
        max_buffer_size: int = 1000,
        flush_interval_seconds: float = 10.0,
        enable_hash_chain: bool = True,
    ) -> None:
        """
        Initialize the audit logger.

        Args:
            max_buffer_size: Maximum events to buffer before auto-flush
            flush_interval_seconds: Interval for periodic flush
            enable_hash_chain: Enable tamper detection via hash chain
        """
        self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
        self._persisted: list[AuditEvent] = []
        self._flush_interval = flush_interval_seconds
        self._enable_hash_chain = enable_hash_chain
        # Hash of the most recently logged event; links each new event into
        # the tamper-detection chain.
        self._last_hash: str | None = None
        self._lock = asyncio.Lock()
        self._flush_task: asyncio.Task[None] | None = None
        self._running = False

        # Event handlers for real-time processing
        self._handlers: list[Any] = []

        config = get_safety_config()
        self._retention_days = config.audit_retention_days
        self._include_sensitive = config.audit_include_sensitive

    async def start(self) -> None:
        """Start the audit logger background tasks."""
        if self._running:
            return

        self._running = True
        self._flush_task = asyncio.create_task(self._periodic_flush())
        logger.info("Audit logger started")

    async def stop(self) -> None:
        """Stop the audit logger and flush remaining events."""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Final flush
        await self.flush()
        logger.info("Audit logger stopped")

    async def log(
        self,
        event_type: AuditEventType,
        *,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        decision: SafetyDecision | None = None,
        details: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> AuditEvent:
        """
        Log an audit event.

        Args:
            event_type: Type of audit event
            agent_id: Agent ID if applicable
            action_id: Action ID if applicable
            project_id: Project ID if applicable
            session_id: Session ID if applicable
            user_id: User ID if applicable
            decision: Safety decision if applicable
            details: Additional event details
            correlation_id: Correlation ID for tracing

        Returns:
            The created audit event
        """
        # Sanitize sensitive data if needed
        sanitized_details = self._sanitize_details(details) if details else {}

        event = AuditEvent(
            id=str(uuid4()),
            event_type=event_type,
            # NOTE(review): naive UTC timestamp (datetime.utcnow() is
            # deprecated); kept for compatibility with existing comparisons.
            timestamp=datetime.utcnow(),
            agent_id=agent_id,
            action_id=action_id,
            project_id=project_id,
            session_id=session_id,
            user_id=user_id,
            decision=decision,
            details=sanitized_details,
            correlation_id=correlation_id,
        )

        async with self._lock:
            # Add hash chain for tamper detection. The hash is computed over
            # the event plus the previous event's hash, then stored in the
            # details dict (which the event shares by reference).
            if self._enable_hash_chain:
                event_hash = self._compute_hash(event, self._last_hash)
                sanitized_details["_hash"] = event_hash
                sanitized_details["_prev_hash"] = self._last_hash
                self._last_hash = event_hash

            self._buffer.append(event)

        # Notify handlers
        await self._notify_handlers(event)

        # Log to standard logger as well
        self._log_to_logger(event)

        return event

    async def log_action_request(
        self,
        action: ActionRequest,
        decision: SafetyDecision,
        reasons: list[str] | None = None,
    ) -> AuditEvent:
        """Log an action request with its validation decision."""
        event_type = (
            AuditEventType.ACTION_DENIED
            if decision == SafetyDecision.DENY
            else AuditEventType.ACTION_VALIDATED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=action.metadata.user_id,
            decision=decision,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
                "is_destructive": action.is_destructive,
                "reasons": reasons or [],
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_action_executed(
        self,
        action: ActionRequest,
        success: bool,
        execution_time_ms: float,
        error: str | None = None,
    ) -> AuditEvent:
        """Log an action execution result."""
        event_type = (
            AuditEventType.ACTION_EXECUTED
            if success
            else AuditEventType.ACTION_FAILED
        )

        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
            details={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "success": success,
                "execution_time_ms": execution_time_ms,
                "error": error,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_approval_event(
        self,
        event_type: AuditEventType,
        approval_id: str,
        action: ActionRequest,
        decided_by: str | None = None,
        reason: str | None = None,
    ) -> AuditEvent:
        """Log an approval-related event."""
        return await self.log(
            event_type,
            agent_id=action.metadata.agent_id,
            action_id=action.id,
            project_id=action.metadata.project_id,
            session_id=action.metadata.session_id,
            user_id=decided_by,
            details={
                "approval_id": approval_id,
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "decided_by": decided_by,
                "reason": reason,
            },
            correlation_id=action.metadata.correlation_id,
        )

    async def log_budget_event(
        self,
        event_type: AuditEventType,
        agent_id: str,
        scope: str,
        current_usage: float,
        limit: float,
        unit: str = "tokens",
    ) -> AuditEvent:
        """Log a budget-related event."""
        return await self.log(
            event_type,
            agent_id=agent_id,
            details={
                "scope": scope,
                "current_usage": current_usage,
                "limit": limit,
                "unit": unit,
                # Guard against division by zero for unlimited (0) budgets.
                "usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
            },
        )

    async def log_emergency_stop(
        self,
        stop_type: str,
        triggered_by: str,
        reason: str,
        affected_agents: list[str] | None = None,
    ) -> AuditEvent:
        """Log an emergency stop event."""
        return await self.log(
            AuditEventType.EMERGENCY_STOP,
            user_id=triggered_by,
            details={
                "stop_type": stop_type,
                "triggered_by": triggered_by,
                "reason": reason,
                "affected_agents": affected_agents or [],
            },
        )

    async def flush(self) -> int:
        """
        Flush buffered events to persistent storage.

        Returns:
            Number of events flushed
        """
        async with self._lock:
            if not self._buffer:
                return 0

            events = list(self._buffer)
            self._buffer.clear()

            # Persist events (in production, this would go to database/storage)
            self._persisted.extend(events)

            # Enforce retention
            self._enforce_retention()

        logger.debug("Flushed %d audit events", len(events))
        return len(events)

    async def query(
        self,
        *,
        event_types: list[AuditEventType] | None = None,
        agent_id: str | None = None,
        action_id: str | None = None,
        project_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        correlation_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEvent]:
        """
        Query audit events with filters.

        Args:
            event_types: Filter by event types
            agent_id: Filter by agent ID
            action_id: Filter by action ID
            project_id: Filter by project ID
            session_id: Filter by session ID
            user_id: Filter by user ID
            start_time: Filter events after this time
            end_time: Filter events before this time
            correlation_id: Filter by correlation ID
            limit: Maximum results to return
            offset: Result offset for pagination

        Returns:
            List of matching audit events
        """
        # Combine buffer and persisted for query.
        # NOTE(review): intentionally lock-free so handlers invoked while the
        # log lock is held can query without deadlocking; a concurrent flush
        # may be observed mid-update.
        all_events = list(self._persisted) + list(self._buffer)

        results = []
        for event in all_events:
            if event_types and event.event_type not in event_types:
                continue
            if agent_id and event.agent_id != agent_id:
                continue
            if action_id and event.action_id != action_id:
                continue
            if project_id and event.project_id != project_id:
                continue
            if session_id and event.session_id != session_id:
                continue
            if user_id and event.user_id != user_id:
                continue
            if start_time and event.timestamp < start_time:
                continue
            if end_time and event.timestamp > end_time:
                continue
            if correlation_id and event.correlation_id != correlation_id:
                continue

            results.append(event)

        # Sort by timestamp descending
        results.sort(key=lambda e: e.timestamp, reverse=True)

        # Apply pagination
        return results[offset : offset + limit]

    async def get_action_history(
        self,
        agent_id: str,
        limit: int = 100,
    ) -> list[AuditEvent]:
        """Get action history for an agent."""
        return await self.query(
            agent_id=agent_id,
            event_types=[
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_VALIDATED,
                AuditEventType.ACTION_DENIED,
                AuditEventType.ACTION_EXECUTED,
                AuditEventType.ACTION_FAILED,
            ],
            limit=limit,
        )

    async def verify_integrity(self) -> tuple[bool, list[str]]:
        """
        Verify audit log integrity using hash chain.

        Returns:
            Tuple of (is_valid, list of issues found)
        """
        if not self._enable_hash_chain:
            return True, []

        issues: list[str] = []
        all_events = list(self._persisted) + list(self._buffer)

        prev_hash: str | None = None
        # NOTE(review): events with identical timestamps may sort in a
        # different order than they were logged, which would flag a false
        # chain break — acceptable for sub-ms granularity, but worth noting.
        for event in sorted(all_events, key=lambda e: e.timestamp):
            stored_prev = event.details.get("_prev_hash")
            stored_hash = event.details.get("_hash")

            if stored_prev != prev_hash:
                issues.append(
                    f"Hash chain broken at event {event.id}: "
                    f"expected prev_hash={prev_hash}, got {stored_prev}"
                )

            if stored_hash:
                # Recompute with the prev-hash recorded at log time so the
                # check is independent of the logger's current chain tip.
                computed = self._compute_hash(event, stored_prev)
                if computed != stored_hash:
                    issues.append(
                        f"Hash mismatch at event {event.id}: "
                        f"expected {computed}, got {stored_hash}"
                    )

            prev_hash = stored_hash

        return len(issues) == 0, issues

    def add_handler(self, handler: Any) -> None:
        """Add a real-time event handler (sync or async callable)."""
        self._handlers.append(handler)

    def remove_handler(self, handler: Any) -> None:
        """Remove an event handler."""
        if handler in self._handlers:
            self._handlers.remove(handler)

    def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
        """Sanitize sensitive data from details.

        Recursively redacts values whose key contains a sensitive substring
        (password, secret, token, ...). Returns the input unchanged when
        audit_include_sensitive is enabled in config.
        """
        if self._include_sensitive:
            return details

        sanitized: dict[str, Any] = {}
        sensitive_keys = {
            "password",
            "secret",
            "token",
            "api_key",
            "apikey",
            "auth",
            "credential",
        }

        for key, value in details.items():
            lower_key = key.lower()
            if any(s in lower_key for s in sensitive_keys):
                sanitized[key] = "[REDACTED]"
            elif isinstance(value, dict):
                sanitized[key] = self._sanitize_details(value)
            else:
                sanitized[key] = value

        return sanitized

    def _compute_hash(
        self, event: AuditEvent, prev_hash: str | None = None
    ) -> str:
        """Compute the chained hash for an event (excluding hash fields).

        FIX: the previous hash is now an explicit argument instead of being
        read from the mutable ``self._last_hash``. The old code made
        verify_integrity() recompute historical events against the *current*
        chain tip, so every event except the newest produced a spurious
        "Hash mismatch". Callers pass the prev-hash in effect when the event
        was logged.
        """
        data = {
            "id": event.id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "project_id": event.project_id,
            "session_id": event.session_id,
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            # Exclude bookkeeping keys (_hash/_prev_hash) so the hash is
            # stable whether computed before or after they are injected.
            "details": {
                k: v
                for k, v in event.details.items()
                if not k.startswith("_")
            },
            "correlation_id": event.correlation_id,
        }

        if prev_hash:
            data["_prev_hash"] = prev_hash

        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def _log_to_logger(self, event: AuditEvent) -> None:
        """Log event to standard Python logger."""
        log_data = {
            "audit_event": event.event_type.value,
            "event_id": event.id,
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "decision": event.decision.value if event.decision else None,
        }

        # Use appropriate log level based on event type
        if event.event_type in {
            AuditEventType.ACTION_DENIED,
            AuditEventType.POLICY_VIOLATION,
            AuditEventType.EMERGENCY_STOP,
        }:
            logger.warning("Audit: %s", log_data)
        elif event.event_type in {
            AuditEventType.ACTION_FAILED,
            AuditEventType.ROLLBACK_FAILED,
        }:
            logger.error("Audit: %s", log_data)
        else:
            logger.info("Audit: %s", log_data)

    def _enforce_retention(self) -> None:
        """Enforce retention policy on persisted events.

        Drops persisted events older than the configured retention window.
        Must be called with the lock held (flush() does so).
        """
        if not self._retention_days:
            return

        cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
        before_count = len(self._persisted)

        self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]

        removed = before_count - len(self._persisted)
        if removed > 0:
            logger.info("Removed %d expired audit events", removed)

    async def _periodic_flush(self) -> None:
        """Background task for periodic flushing."""
        while self._running:
            try:
                await asyncio.sleep(self._flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the flusher alive; a transient failure must not stop
                # future flushes.
                logger.error("Error in periodic audit flush: %s", e)

    async def _notify_handlers(self, event: AuditEvent) -> None:
        """Notify all registered handlers of a new event.

        Handler exceptions are logged and swallowed so one bad handler
        cannot break audit logging for everyone else.
        """
        for handler in self._handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(event)
                else:
                    handler(event)
            except Exception as e:
                logger.error("Error in audit event handler: %s", e)
|
||||
|
||||
|
||||
# Module-level singleton state for the shared audit logger.
_audit_logger: AuditLogger | None = None
_audit_lock = asyncio.Lock()


async def get_audit_logger() -> AuditLogger:
    """Get the global audit logger instance, creating and starting it on
    first use."""
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            instance = AuditLogger()
            await instance.start()
            _audit_logger = instance
        return _audit_logger


async def shutdown_audit_logger() -> None:
    """Shutdown the global audit logger."""
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            return
        await _audit_logger.stop()
        _audit_logger = None


def reset_audit_logger() -> None:
    """Reset the audit logger (for testing)."""
    global _audit_logger
    _audit_logger = None
|
||||
304
backend/app/services/safety/config.py
Normal file
304
backend/app/services/safety/config.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Safety Framework Configuration
|
||||
|
||||
Pydantic settings for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .models import AutonomyLevel, SafetyPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyConfig(BaseSettings):
    """Configuration for the safety framework.

    Values are read from environment variables with the ``SAFETY_`` prefix
    (or a ``.env`` file); unknown variables are ignored. Use
    ``get_safety_config()`` for the cached singleton.
    """

    model_config = SettingsConfigDict(
        env_prefix="SAFETY_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # General settings
    enabled: bool = Field(True, description="Enable safety framework")
    strict_mode: bool = Field(
        True, description="Strict mode (fail closed on errors)"
    )
    log_level: str = Field("INFO", description="Logging level")

    # Default autonomy level
    default_autonomy_level: AutonomyLevel = Field(
        AutonomyLevel.MILESTONE,
        description="Default autonomy level for new agents",
    )

    # Default budget limits
    default_session_token_budget: int = Field(
        100_000, description="Default tokens per session"
    )
    default_daily_token_budget: int = Field(
        1_000_000, description="Default tokens per day"
    )
    default_session_cost_limit: float = Field(
        10.0, description="Default USD per session"
    )
    default_daily_cost_limit: float = Field(100.0, description="Default USD per day")

    # Default rate limits
    default_actions_per_minute: int = Field(60, description="Default actions per min")
    default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min")
    default_file_ops_per_minute: int = Field(100, description="Default file ops/min")

    # Loop detection
    loop_detection_enabled: bool = Field(True, description="Enable loop detection")
    max_repeated_actions: int = Field(5, description="Max exact repetitions")
    max_similar_actions: int = Field(10, description="Max similar actions")
    loop_history_size: int = Field(100, description="Action history size for loops")

    # HITL (human-in-the-loop) settings
    hitl_enabled: bool = Field(True, description="Enable human-in-the-loop")
    hitl_default_timeout: int = Field(300, description="Default approval timeout (s)")
    hitl_notification_channels: list[str] = Field(
        default_factory=list, description="Notification channels"
    )

    # Rollback settings
    rollback_enabled: bool = Field(True, description="Enable rollback capability")
    checkpoint_dir: str = Field(
        "/tmp/syndarix_checkpoints",  # noqa: S108
        description="Directory for checkpoint storage",
    )
    checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
    auto_checkpoint_destructive: bool = Field(
        True, description="Auto-checkpoint destructive actions"
    )

    # Sandbox settings
    sandbox_enabled: bool = Field(False, description="Enable sandbox execution")
    sandbox_timeout: int = Field(300, description="Sandbox timeout (s)")
    sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)")
    sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit")
    sandbox_network_enabled: bool = Field(False, description="Allow sandbox network")

    # Audit settings (consumed by AuditLogger at construction time)
    audit_enabled: bool = Field(True, description="Enable audit logging")
    audit_retention_days: int = Field(90, description="Audit log retention (days)")
    audit_include_sensitive: bool = Field(
        False, description="Include sensitive data in audit"
    )

    # Content filtering
    content_filter_enabled: bool = Field(True, description="Enable content filtering")
    filter_pii: bool = Field(True, description="Filter PII")
    filter_secrets: bool = Field(True, description="Filter secrets")

    # Emergency controls
    emergency_stop_enabled: bool = Field(True, description="Enable emergency stop")
    emergency_webhook_url: str | None = Field(None, description="Emergency webhook")

    # Policy file path
    policy_file: str | None = Field(None, description="Path to policy YAML file")

    # Validation cache
    validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)")
    validation_cache_size: int = Field(1000, description="Validation cache size")
|
||||
|
||||
|
||||
class AutonomyConfig(BaseSettings):
    """Configuration for autonomy levels.

    Values are read from environment variables with the ``AUTONOMY_`` prefix
    (or a ``.env`` file). Each group below parameterizes the policy built by
    ``get_policy_for_autonomy_level()`` for the corresponding AutonomyLevel.
    """

    model_config = SettingsConfigDict(
        env_prefix="AUTONOMY_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # FULL_CONTROL settings (most restrictive level)
    full_control_cost_limit: float = Field(1.0, description="USD limit per session")
    full_control_require_all_approval: bool = Field(
        True, description="Require approval for all"
    )
    full_control_block_destructive: bool = Field(
        True, description="Block destructive actions"
    )

    # MILESTONE settings (approval only at milestones)
    milestone_cost_limit: float = Field(10.0, description="USD limit per session")
    milestone_require_critical_approval: bool = Field(
        True, description="Require approval for critical"
    )
    milestone_auto_checkpoint: bool = Field(
        True, description="Auto-checkpoint destructive"
    )

    # AUTONOMOUS settings (least restrictive level)
    autonomous_cost_limit: float = Field(100.0, description="USD limit per session")
    autonomous_auto_approve_normal: bool = Field(
        True, description="Auto-approve normal actions"
    )
    autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all")
|
||||
|
||||
|
||||
def _expand_env_vars(value: Any) -> Any:
|
||||
"""Recursively expand environment variables in values."""
|
||||
if isinstance(value, str):
|
||||
return os.path.expandvars(value)
|
||||
elif isinstance(value, dict):
|
||||
return {k: _expand_env_vars(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_expand_env_vars(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None:
    """Load a safety policy from a YAML file.

    Returns None (after logging a warning/error) when the file is missing,
    empty, or cannot be parsed into a SafetyPolicy. Environment variables in
    string values are expanded before validation.
    """
    path = Path(file_path)
    if not path.exists():
        logger.warning("Policy file not found: %s", path)
        return None

    try:
        raw = yaml.safe_load(path.read_text())

        if raw is None:
            logger.warning("Empty policy file: %s", path)
            return None

        # Expand environment variables before model validation.
        return SafetyPolicy(**_expand_env_vars(raw))

    except Exception as e:
        logger.error("Failed to load policy file %s: %s", path, e)
        return None
|
||||
|
||||
|
||||
def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]:
    """Load all safety policies from a directory.

    Scans ``directory`` for ``*.yaml`` files, keyed by each policy's name.
    Files that fail to load are skipped (load_policy_from_file logs them).
    """
    path = Path(directory)
    loaded: dict[str, SafetyPolicy] = {}

    if not (path.exists() and path.is_dir()):
        logger.warning("Policy directory not found: %s", path)
        return loaded

    for candidate in path.glob("*.yaml"):
        policy = load_policy_from_file(candidate)
        if policy is None:
            continue
        loaded[policy.name] = policy
        logger.info("Loaded policy: %s from %s", policy.name, candidate.name)

    return loaded
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_safety_config() -> SafetyConfig:
    """Get the safety configuration (cached singleton).

    The first call reads the environment/.env; later calls return the same
    instance. Clear with ``get_safety_config.cache_clear()`` (or the module's
    reset helper) to force a re-read — e.g. in tests.
    """
    return SafetyConfig()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_autonomy_config() -> AutonomyConfig:
    """Get the autonomy configuration (cached singleton).

    The first call reads the environment/.env; later calls return the same
    instance. Clear with ``get_autonomy_config.cache_clear()`` to force a
    re-read — e.g. in tests.
    """
    return AutonomyConfig()
|
||||
|
||||
|
||||
def get_default_policy() -> SafetyPolicy:
    """Get the default safety policy.

    Builds a fresh SafetyPolicy whose limits mirror the current SafetyConfig
    defaults (budgets, rate limits, loop thresholds, sandbox settings).
    """
    config = get_safety_config()

    return SafetyPolicy(
        name="default",
        description="Default safety policy",
        max_tokens_per_session=config.default_session_token_budget,
        max_tokens_per_day=config.default_daily_token_budget,
        max_cost_per_session_usd=config.default_session_cost_limit,
        max_cost_per_day_usd=config.default_daily_cost_limit,
        max_actions_per_minute=config.default_actions_per_minute,
        max_llm_calls_per_minute=config.default_llm_calls_per_minute,
        max_file_operations_per_minute=config.default_file_ops_per_minute,
        max_repeated_actions=config.max_repeated_actions,
        max_similar_actions=config.max_similar_actions,
        require_sandbox=config.sandbox_enabled,
        sandbox_timeout_seconds=config.sandbox_timeout,
        sandbox_memory_mb=config.sandbox_memory_mb,
    )
|
||||
|
||||
|
||||
def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
    """Return the safety policy matching an autonomy level.

    FULL_CONTROL tightens the default limits and gates every action behind
    approval; MILESTONE keeps default limits and gates only risky
    operations; AUTONOMOUS (the fallback) widens limits and gates only
    production-impacting actions.
    """
    autonomy_cfg = get_autonomy_config()
    defaults = get_default_policy()

    if level == AutonomyLevel.FULL_CONTROL:
        session_cost = autonomy_cfg.full_control_cost_limit
        return SafetyPolicy(
            name="full_control",
            description="Full control mode - all actions require approval",
            max_cost_per_session_usd=session_cost,
            max_cost_per_day_usd=session_cost * 10,
            require_approval_for=["*"],  # every action is gated
            # Shrink all default budgets/rates for maximum oversight.
            max_tokens_per_session=defaults.max_tokens_per_session // 10,
            max_tokens_per_day=defaults.max_tokens_per_day // 10,
            max_actions_per_minute=defaults.max_actions_per_minute // 2,
            max_llm_calls_per_minute=defaults.max_llm_calls_per_minute // 2,
            max_file_operations_per_minute=defaults.max_file_operations_per_minute // 2,
            denied_tools=["delete_*", "destroy_*", "drop_*"],
        )

    if level == AutonomyLevel.MILESTONE:
        session_cost = autonomy_cfg.milestone_cost_limit
        return SafetyPolicy(
            name="milestone",
            description="Milestone mode - approval at milestones only",
            max_cost_per_session_usd=session_cost,
            max_cost_per_day_usd=session_cost * 10,
            require_approval_for=[
                "delete_file",
                "push_to_remote",
                "deploy_*",
                "modify_critical_*",
                "create_pull_request",
            ],
            # Budgets and rates stay at the defaults.
            max_tokens_per_session=defaults.max_tokens_per_session,
            max_tokens_per_day=defaults.max_tokens_per_day,
            max_actions_per_minute=defaults.max_actions_per_minute,
            max_llm_calls_per_minute=defaults.max_llm_calls_per_minute,
            max_file_operations_per_minute=defaults.max_file_operations_per_minute,
        )

    # AUTONOMOUS: widest budgets, smallest approval surface.
    session_cost = autonomy_cfg.autonomous_cost_limit
    return SafetyPolicy(
        name="autonomous",
        description="Autonomous mode - minimal intervention",
        max_cost_per_session_usd=session_cost,
        max_cost_per_day_usd=session_cost * 10,
        require_approval_for=[
            "deploy_to_production",
            "delete_repository",
            "modify_production_config",
        ],
        max_tokens_per_session=defaults.max_tokens_per_session * 5,
        max_tokens_per_day=defaults.max_tokens_per_day * 5,
        max_actions_per_minute=defaults.max_actions_per_minute * 2,
        max_llm_calls_per_minute=defaults.max_llm_calls_per_minute * 2,
        max_file_operations_per_minute=defaults.max_file_operations_per_minute * 2,
    )
|
||||
|
||||
|
||||
def reset_config_cache() -> None:
    """Reset configuration caches (for testing).

    Clears both ``lru_cache`` singletons so the next call to
    ``get_safety_config()`` / ``get_autonomy_config()`` constructs a fresh
    instance, picking up any changed environment/settings.
    """
    get_safety_config.cache_clear()
    get_autonomy_config.cache_clear()
|
||||
23
backend/app/services/safety/content/__init__.py
Normal file
23
backend/app/services/safety/content/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Content filtering for safety."""
|
||||
|
||||
from .filter import (
|
||||
ContentCategory,
|
||||
ContentFilter,
|
||||
FilterAction,
|
||||
FilterMatch,
|
||||
FilterPattern,
|
||||
FilterResult,
|
||||
filter_content,
|
||||
scan_for_secrets,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContentCategory",
|
||||
"ContentFilter",
|
||||
"FilterAction",
|
||||
"FilterMatch",
|
||||
"FilterPattern",
|
||||
"FilterResult",
|
||||
"filter_content",
|
||||
"scan_for_secrets",
|
||||
]
|
||||
533
backend/app/services/safety/content/filter.py
Normal file
533
backend/app/services/safety/content/filter.py
Normal file
@@ -0,0 +1,533 @@
|
||||
"""
|
||||
Content Filter
|
||||
|
||||
Filters and sanitizes content for safety, including PII detection and secret scanning.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..exceptions import ContentFilterError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
    """Categories of sensitive content.

    Inherits from ``str`` so members compare equal to their string values
    and serialize cleanly (e.g. in JSON / audit logs).
    """

    PII = "pii"  # personally identifiable information (emails, phones, SSNs)
    SECRETS = "secrets"  # API keys, tokens, private keys
    CREDENTIALS = "credentials"  # passwords, credentials embedded in URLs
    FINANCIAL = "financial"  # card numbers and similar financial data
    HEALTH = "health"  # health-related data (no default patterns ship for this)
    PROFANITY = "profanity"  # offensive language (no default patterns ship for this)
    INJECTION = "injection"  # SQL/command injection payloads
    CUSTOM = "custom"  # user-supplied pattern category
|
||||
|
||||
|
||||
class FilterAction(str, Enum):
    """Actions to take on detected content.

    Ordered roughly by severity: ALLOW < WARN < REDACT < BLOCK.
    """

    ALLOW = "allow"  # record nothing, pass content through
    REDACT = "redact"  # replace the matched span with the pattern's replacement
    BLOCK = "block"  # reject the content entirely
    WARN = "warn"  # pass through but attach a warning
|
||||
|
||||
|
||||
@dataclass
class FilterMatch:
    """A match found by a filter."""

    category: ContentCategory  # which kind of sensitive content was hit
    pattern_name: str  # name of the FilterPattern that matched
    matched_text: str  # the raw matched span from the input
    start_pos: int  # character offset of the match start (inclusive)
    end_pos: int  # character offset of the match end (exclusive)
    confidence: float = 1.0  # heuristic certainty inherited from the pattern
    redacted_text: str | None = None  # replacement text to substitute, if any
|
||||
|
||||
|
||||
@dataclass
class FilterResult:
    """Result of content filtering."""

    original_content: str  # the unmodified input
    filtered_content: str  # redacted text; empty string when blocked
    matches: list[FilterMatch] = field(default_factory=list)  # all hits, sorted by position
    blocked: bool = False  # True when any BLOCK-action pattern matched
    block_reason: str | None = None  # human-readable reason for the block
    warnings: list[str] = field(default_factory=list)  # messages from WARN-action patterns

    @property
    def has_sensitive_content(self) -> bool:
        """Check if any sensitive content was found."""
        return len(self.matches) > 0
|
||||
|
||||
|
||||
@dataclass
class FilterPattern:
    """A pattern for detecting sensitive content.

    Attributes:
        name: Unique identifier, used for lookup/enable/remove by name.
        category: Kind of sensitive content this pattern targets.
        pattern: Regex source string; compiled once in ``__post_init__``.
        action: What the filter does with matches (default: redact).
        replacement: Text substituted for the match under REDACT/BLOCK.
        confidence: Heuristic certainty (0-1) attached to each match.
        enabled: Disabled patterns are skipped by the filter entirely.
    """

    name: str
    category: ContentCategory
    pattern: str  # Regex pattern
    action: FilterAction = FilterAction.REDACT
    replacement: str = "[REDACTED]"
    confidence: float = 1.0
    enabled: bool = True

    def __post_init__(self) -> None:
        """Compile the regex once so repeated find_matches calls stay cheap."""
        # NOTE(review): IGNORECASE is applied to *every* pattern, including
        # case-sensitive token formats (e.g. AWS "AKIA..." access keys) —
        # confirm this deliberate loosening is intended.
        self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE)

    def find_matches(self, content: str) -> list[FilterMatch]:
        """Return a FilterMatch for every occurrence of the pattern in *content*."""
        return [
            FilterMatch(
                category=self.category,
                pattern_name=self.name,
                matched_text=match.group(),
                start_pos=match.start(),
                end_pos=match.end(),
                confidence=self.confidence,
                redacted_text=self.replacement,
            )
            for match in self._compiled.finditer(content)
        ]
|
||||
|
||||
|
||||
class ContentFilter:
    """
    Filters content for sensitive information.

    Features:
    - PII detection (emails, phones, SSN, etc.)
    - Secret scanning (API keys, tokens, passwords)
    - Credential detection
    - Injection attack prevention
    - Custom pattern support
    - Configurable actions (allow, redact, block, warn)
    """

    # Default patterns for common sensitive data.
    # Shared class-level defaults; __init__ copies each entry via replace()
    # so per-instance mutation (e.g. enable_pattern) cannot leak across
    # instances.
    DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
        # PII Patterns
        FilterPattern(
            name="email",
            category=ContentCategory.PII,
            pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
            action=FilterAction.REDACT,
            replacement="[EMAIL]",
        ),
        FilterPattern(
            name="phone_us",
            category=ContentCategory.PII,
            pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
            action=FilterAction.REDACT,
            replacement="[PHONE]",
        ),
        FilterPattern(
            name="ssn",
            category=ContentCategory.PII,
            pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
            action=FilterAction.REDACT,
            replacement="[SSN]",
        ),
        FilterPattern(
            name="credit_card",
            category=ContentCategory.FINANCIAL,
            pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
            action=FilterAction.REDACT,
            replacement="[CREDIT_CARD]",
        ),
        FilterPattern(
            name="ip_address",
            category=ContentCategory.PII,
            pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
            action=FilterAction.WARN,
            replacement="[IP]",
            confidence=0.8,
        ),
        # Secret Patterns
        FilterPattern(
            name="api_key_generic",
            category=ContentCategory.SECRETS,
            pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
            action=FilterAction.BLOCK,
            replacement="[API_KEY]",
        ),
        FilterPattern(
            name="aws_access_key",
            category=ContentCategory.SECRETS,
            pattern=r"\bAKIA[0-9A-Z]{16}\b",
            action=FilterAction.BLOCK,
            replacement="[AWS_KEY]",
        ),
        FilterPattern(
            name="aws_secret_key",
            category=ContentCategory.SECRETS,
            pattern=r"\b[A-Za-z0-9/+=]{40}\b",
            action=FilterAction.WARN,
            replacement="[AWS_SECRET]",
            confidence=0.6,  # Lower confidence - might be false positive
        ),
        FilterPattern(
            name="github_token",
            category=ContentCategory.SECRETS,
            pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
            action=FilterAction.BLOCK,
            replacement="[GITHUB_TOKEN]",
        ),
        FilterPattern(
            name="jwt_token",
            category=ContentCategory.SECRETS,
            # "eyJ" is base64 for '{"' — the standard JWT header/payload prefix.
            pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
            action=FilterAction.BLOCK,
            replacement="[JWT]",
        ),
        # Credential Patterns
        FilterPattern(
            name="password_in_url",
            category=ContentCategory.CREDENTIALS,
            pattern=r"://[^:]+:([^@]+)@",
            action=FilterAction.BLOCK,
            replacement="://[REDACTED]@",
        ),
        FilterPattern(
            name="password_assignment",
            category=ContentCategory.CREDENTIALS,
            pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
            action=FilterAction.REDACT,
            replacement="[PASSWORD]",
        ),
        FilterPattern(
            name="private_key",
            category=ContentCategory.SECRETS,
            pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
            action=FilterAction.BLOCK,
            replacement="[PRIVATE_KEY]",
        ),
        # Injection Patterns
        FilterPattern(
            name="sql_injection",
            category=ContentCategory.INJECTION,
            pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
            action=FilterAction.BLOCK,
            replacement="[BLOCKED]",
        ),
        FilterPattern(
            name="command_injection",
            category=ContentCategory.INJECTION,
            pattern=r"[;&|`$]|\$\(|\$\{",
            action=FilterAction.WARN,
            replacement="[CMD]",
            confidence=0.5,  # Low confidence - common in code
        ),
    ]

    def __init__(
        self,
        enable_pii_filter: bool = True,
        enable_secret_filter: bool = True,
        enable_injection_filter: bool = True,
        custom_patterns: list[FilterPattern] | None = None,
        default_action: FilterAction = FilterAction.REDACT,
    ) -> None:
        """
        Initialize the ContentFilter.

        Args:
            enable_pii_filter: Enable PII detection
            enable_secret_filter: Enable secret scanning (also gates
                CREDENTIALS-category patterns)
            enable_injection_filter: Enable injection detection
            custom_patterns: Additional custom patterns
            default_action: Default action for matches
        """
        self._patterns: list[FilterPattern] = []
        # NOTE(review): _default_action is stored but never consulted by the
        # methods below (each pattern carries its own action) — confirm it
        # is intended for future use.
        self._default_action = default_action
        # NOTE(review): this lock is never acquired by any method of this
        # class — confirm whether filter()/add_pattern() were meant to hold it.
        self._lock = asyncio.Lock()

        # Load default patterns based on configuration
        # Use replace() to create a copy of each pattern to avoid mutating shared defaults
        for pattern in self.DEFAULT_PATTERNS:
            if pattern.category == ContentCategory.PII and not enable_pii_filter:
                continue
            if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
                continue
            if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
                continue
            if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
                continue
            self._patterns.append(replace(pattern))

        # Add custom patterns (appended as-is, not copied)
        if custom_patterns:
            self._patterns.extend(custom_patterns)

        logger.info("ContentFilter initialized with %d patterns", len(self._patterns))

    def add_pattern(self, pattern: FilterPattern) -> None:
        """Add a custom pattern."""
        self._patterns.append(pattern)
        logger.debug("Added pattern: %s", pattern.name)

    def remove_pattern(self, pattern_name: str) -> bool:
        """Remove a pattern by name.

        Returns True if a pattern was removed, False if no pattern with
        that name exists. Only the first match is removed.
        """
        for i, pattern in enumerate(self._patterns):
            if pattern.name == pattern_name:
                del self._patterns[i]
                logger.debug("Removed pattern: %s", pattern_name)
                return True
        return False

    def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
        """Enable or disable a pattern.

        Returns True if the named pattern was found and toggled.
        """
        for pattern in self._patterns:
            if pattern.name == pattern_name:
                pattern.enabled = enabled
                return True
        return False

    async def filter(
        self,
        content: str,
        context: dict[str, Any] | None = None,
        raise_on_block: bool = False,
    ) -> FilterResult:
        """
        Filter content for sensitive information.

        Args:
            content: Content to filter
            context: Optional context for filtering decisions
                (NOTE(review): currently unused by this implementation)
            raise_on_block: Raise exception if content is blocked

        Returns:
            FilterResult with filtered content and match details.
            When blocked, ``filtered_content`` is the empty string.

        Raises:
            ContentFilterError: If content is blocked and raise_on_block=True
        """
        all_matches: list[FilterMatch] = []
        blocked = False
        block_reason: str | None = None
        warnings: list[str] = []

        # Find all matches
        for pattern in self._patterns:
            if not pattern.enabled:
                continue

            matches = pattern.find_matches(content)
            for match in matches:
                all_matches.append(match)

                if pattern.action == FilterAction.BLOCK:
                    blocked = True
                    # Last blocking pattern wins as the reported reason.
                    block_reason = f"Blocked by pattern: {pattern.name}"
                elif pattern.action == FilterAction.WARN:
                    warnings.append(
                        f"Warning: {pattern.name} detected at position {match.start_pos}"
                    )

        # Sort matches by position (reverse for replacement) — replacing
        # right-to-left keeps earlier match offsets valid as the string
        # shrinks/grows.
        # NOTE(review): overlapping matches from different patterns can
        # still produce garbled output since offsets refer to the original
        # string — confirm overlaps are acceptable here.
        all_matches.sort(key=lambda m: m.start_pos, reverse=True)

        # Apply redactions
        filtered_content = content
        for match in all_matches:
            matched_pattern = self._get_pattern(match.pattern_name)
            if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
                filtered_content = (
                    filtered_content[: match.start_pos]
                    + (match.redacted_text or "[REDACTED]")
                    + filtered_content[match.end_pos :]
                )

        # Re-sort for result (ascending position for callers)
        all_matches.sort(key=lambda m: m.start_pos)

        result = FilterResult(
            original_content=content,
            filtered_content=filtered_content if not blocked else "",
            matches=all_matches,
            blocked=blocked,
            block_reason=block_reason,
            warnings=warnings,
        )

        if blocked:
            logger.warning(
                "Content blocked: %s (%d matches)",
                block_reason,
                len(all_matches),
            )
            if raise_on_block:
                raise ContentFilterError(
                    block_reason or "Content blocked",
                    filter_type=all_matches[0].category.value if all_matches else "unknown",
                    detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
                )
        elif all_matches:
            logger.debug(
                "Content filtered: %d matches, %d warnings",
                len(all_matches),
                len(warnings),
            )

        return result

    async def filter_dict(
        self,
        data: dict[str, Any],
        keys_to_filter: list[str] | None = None,
        recursive: bool = True,
    ) -> dict[str, Any]:
        """
        Filter string values in a dictionary.

        Args:
            data: Dictionary to filter
            keys_to_filter: Specific keys to filter (None = all)
            recursive: Filter nested dictionaries

        Returns:
            Filtered dictionary (a new dict; the input is not mutated).
            List values have their string items filtered; non-string items
            (including nested dicts inside lists) are passed through as-is.
        """
        result: dict[str, Any] = {}

        for key, value in data.items():
            if isinstance(value, str):
                if keys_to_filter is None or key in keys_to_filter:
                    filter_result = await self.filter(value)
                    result[key] = filter_result.filtered_content
                else:
                    result[key] = value
            elif isinstance(value, dict) and recursive:
                result[key] = await self.filter_dict(value, keys_to_filter, recursive)
            elif isinstance(value, list):
                # Note: keys_to_filter does not apply inside lists — every
                # string item is filtered.
                result[key] = [
                    (await self.filter(item)).filtered_content
                    if isinstance(item, str)
                    else item
                    for item in value
                ]
            else:
                result[key] = value

        return result

    async def scan(
        self,
        content: str,
        categories: list[ContentCategory] | None = None,
    ) -> list[FilterMatch]:
        """
        Scan content without filtering (detection only).

        Args:
            content: Content to scan
            categories: Limit to specific categories

        Returns:
            List of matches found, sorted by start position
        """
        all_matches: list[FilterMatch] = []

        for pattern in self._patterns:
            if not pattern.enabled:
                continue
            if categories and pattern.category not in categories:
                continue

            matches = pattern.find_matches(content)
            all_matches.extend(matches)

        all_matches.sort(key=lambda m: m.start_pos)
        return all_matches

    async def validate_safe(
        self,
        content: str,
        categories: list[ContentCategory] | None = None,
        allow_warnings: bool = True,
    ) -> tuple[bool, list[str]]:
        """
        Validate that content is safe (no blocked patterns).

        Args:
            content: Content to validate
            categories: Limit to specific categories
            allow_warnings: Allow content with warnings

        Returns:
            Tuple of (is_safe, list of issues)
        """
        issues: list[str] = []

        for pattern in self._patterns:
            if not pattern.enabled:
                continue
            if categories and pattern.category not in categories:
                continue

            matches = pattern.find_matches(content)
            for match in matches:
                if pattern.action == FilterAction.BLOCK:
                    issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
                elif pattern.action == FilterAction.WARN and not allow_warnings:
                    issues.append(f"Warning: {pattern.name} at position {match.start_pos}")

        return len(issues) == 0, issues

    def _get_pattern(self, name: str) -> FilterPattern | None:
        """Get a pattern by name (linear scan; None if absent)."""
        for pattern in self._patterns:
            if pattern.name == name:
                return pattern
        return None

    def get_pattern_stats(self) -> dict[str, Any]:
        """Get statistics about configured patterns.

        Returns counts of total/enabled patterns plus per-category and
        per-action breakdowns.
        """
        by_category: dict[str, int] = {}
        by_action: dict[str, int] = {}

        for pattern in self._patterns:
            cat = pattern.category.value
            by_category[cat] = by_category.get(cat, 0) + 1

            act = pattern.action.value
            by_action[act] = by_action.get(act, 0) + 1

        return {
            "total_patterns": len(self._patterns),
            "enabled_patterns": sum(1 for p in self._patterns if p.enabled),
            "by_category": by_category,
            "by_action": by_action,
        }
|
||||
|
||||
|
||||
# Convenience function for quick filtering
|
||||
async def filter_content(content: str) -> str:
    """Run *content* through a default-configured ContentFilter.

    Convenience wrapper: builds a fresh filter with all defaults and
    returns only the sanitized text.
    """
    default_filter = ContentFilter()
    return (await default_filter.filter(content)).filtered_content
|
||||
|
||||
|
||||
async def scan_for_secrets(content: str) -> list[FilterMatch]:
    """Detect secrets and credentials in *content* without modifying it.

    Convenience wrapper: builds a filter with only secret scanning enabled
    and returns the matches from a detection-only scan.
    """
    secret_scanner = ContentFilter(
        enable_pii_filter=False,
        enable_injection_filter=False,
    )
    secret_categories = [ContentCategory.SECRETS, ContentCategory.CREDENTIALS]
    return await secret_scanner.scan(content, categories=secret_categories)
|
||||
15
backend/app/services/safety/costs/__init__.py
Normal file
15
backend/app/services/safety/costs/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Cost Control Module
|
||||
|
||||
Budget management and cost tracking.
|
||||
"""
|
||||
|
||||
from .controller import (
|
||||
BudgetTracker,
|
||||
CostController,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BudgetTracker",
|
||||
"CostController",
|
||||
]
|
||||
479
backend/app/services/safety/costs/controller.py
Normal file
479
backend/app/services/safety/costs/controller.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""
|
||||
Cost Controller
|
||||
|
||||
Budget management and cost tracking for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BudgetTracker:
    """Tracks usage against a budget limit.

    Accumulates token and USD usage for one (scope, scope_id) pair and
    reports status relative to the configured limits. All public
    coroutines serialize on an asyncio.Lock, so one tracker may be shared
    across tasks. An optional reset_interval makes the budget roll over
    automatically (checked lazily on each access, not via a timer).
    """

    def __init__(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
        reset_interval: timedelta | None = None,
        warning_threshold: float = 0.8,
    ) -> None:
        """Initialize the tracker.

        Args:
            scope: Budget scope (e.g. session, daily).
            scope_id: Identifier within the scope.
            tokens_limit: Maximum tokens before the budget is exceeded.
            cost_limit_usd: Maximum USD cost before the budget is exceeded.
            reset_interval: If set, usage resets to zero once this much
                time has elapsed since the last reset.
            warning_threshold: Fraction (0-1) of either limit at which
                status reports is_warning.
        """
        self.scope = scope
        self.scope_id = scope_id
        self.tokens_limit = tokens_limit
        self.cost_limit_usd = cost_limit_usd
        self.warning_threshold = warning_threshold
        self._reset_interval = reset_interval

        # Running totals, guarded by _lock.
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        # NOTE(review): datetime.utcnow() returns naive datetimes — so
        # reset_at in BudgetStatus is naive UTC; confirm consumers expect that.
        self._created_at = datetime.utcnow()
        self._last_reset = datetime.utcnow()
        self._lock = asyncio.Lock()

    async def add_usage(self, tokens: int, cost_usd: float) -> None:
        """Add usage to the tracker."""
        async with self._lock:
            self._check_reset()
            self._tokens_used += tokens
            self._cost_used_usd += cost_usd

    async def get_status(self) -> BudgetStatus:
        """Get current budget status.

        Applies any pending interval reset first, then snapshots usage,
        remaining headroom, and warning/exceeded flags.
        """
        async with self._lock:
            self._check_reset()

            tokens_remaining = max(0, self.tokens_limit - self._tokens_used)
            cost_remaining = max(0, self.cost_limit_usd - self._cost_used_usd)

            # Guard against division by zero for zero-limit budgets.
            token_usage_ratio = (
                self._tokens_used / self.tokens_limit if self.tokens_limit > 0 else 0
            )
            cost_usage_ratio = (
                self._cost_used_usd / self.cost_limit_usd
                if self.cost_limit_usd > 0
                else 0
            )

            # Warning triggers on whichever dimension is closest to its limit.
            is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
            is_exceeded = (
                self._tokens_used >= self.tokens_limit
                or self._cost_used_usd >= self.cost_limit_usd
            )

            reset_at = None
            if self._reset_interval:
                reset_at = self._last_reset + self._reset_interval

            return BudgetStatus(
                scope=self.scope,
                scope_id=self.scope_id,
                tokens_used=self._tokens_used,
                tokens_limit=self.tokens_limit,
                cost_used_usd=self._cost_used_usd,
                cost_limit_usd=self.cost_limit_usd,
                tokens_remaining=tokens_remaining,
                cost_remaining_usd=cost_remaining,
                warning_threshold=self.warning_threshold,
                is_warning=is_warning,
                is_exceeded=is_exceeded,
                reset_at=reset_at,
            )

    async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
        """Check if there's enough budget for an operation.

        Returns True when adding the estimates would keep BOTH tokens and
        cost at or below their limits. Does not record any usage.
        """
        async with self._lock:
            self._check_reset()

            would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit
            would_exceed_cost = (
                self._cost_used_usd + estimated_cost_usd
            ) > self.cost_limit_usd

            return not (would_exceed_tokens or would_exceed_cost)

    def _check_reset(self) -> None:
        """Check if budget should reset.

        Must be called with self._lock held — it mutates the usage counters
        without acquiring the lock itself.
        """
        if self._reset_interval is None:
            return

        now = datetime.utcnow()
        if now >= self._last_reset + self._reset_interval:
            logger.info(
                "Resetting budget for %s:%s",
                self.scope.value,
                self.scope_id,
            )
            self._tokens_used = 0
            self._cost_used_usd = 0.0
            self._last_reset = now

    async def reset(self) -> None:
        """Manually reset the budget.

        Zeroes usage and restarts the reset-interval clock from now.
        """
        async with self._lock:
            self._tokens_used = 0
            self._cost_used_usd = 0.0
            self._last_reset = datetime.utcnow()
|
||||
|
||||
|
||||
class CostController:
|
||||
"""
|
||||
Controls costs and budgets for agent operations.
|
||||
|
||||
Features:
|
||||
- Per-agent, per-project, per-session budgets
|
||||
- Real-time cost tracking
|
||||
- Budget alerts at configurable thresholds
|
||||
- Cost prediction for planned actions
|
||||
- Budget rollover policies
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_session_tokens: int | None = None,
|
||||
default_session_cost_usd: float | None = None,
|
||||
default_daily_tokens: int | None = None,
|
||||
default_daily_cost_usd: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the CostController.
|
||||
|
||||
Args:
|
||||
default_session_tokens: Default token budget per session
|
||||
default_session_cost_usd: Default USD budget per session
|
||||
default_daily_tokens: Default token budget per day
|
||||
default_daily_cost_usd: Default USD budget per day
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._default_session_tokens = (
|
||||
default_session_tokens or config.default_session_token_budget
|
||||
)
|
||||
self._default_session_cost = (
|
||||
default_session_cost_usd or config.default_session_cost_limit
|
||||
)
|
||||
self._default_daily_tokens = (
|
||||
default_daily_tokens or config.default_daily_token_budget
|
||||
)
|
||||
self._default_daily_cost = (
|
||||
default_daily_cost_usd or config.default_daily_cost_limit
|
||||
)
|
||||
|
||||
self._trackers: dict[str, BudgetTracker] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Alert handlers
|
||||
self._alert_handlers: list[Any] = []
|
||||
|
||||
async def get_or_create_tracker(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
) -> BudgetTracker:
|
||||
"""Get or create a budget tracker."""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
|
||||
async with self._lock:
|
||||
if key not in self._trackers:
|
||||
if scope == BudgetScope.SESSION:
|
||||
tracker = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=self._default_session_tokens,
|
||||
cost_limit_usd=self._default_session_cost,
|
||||
)
|
||||
elif scope == BudgetScope.DAILY:
|
||||
tracker = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=self._default_daily_tokens,
|
||||
cost_limit_usd=self._default_daily_cost,
|
||||
reset_interval=timedelta(days=1),
|
||||
)
|
||||
else:
|
||||
# Default
|
||||
tracker = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=self._default_session_tokens,
|
||||
cost_limit_usd=self._default_session_cost,
|
||||
)
|
||||
|
||||
self._trackers[key] = tracker
|
||||
|
||||
return self._trackers[key]
|
||||
|
||||
async def check_budget(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str | None,
|
||||
estimated_tokens: int,
|
||||
estimated_cost_usd: float,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if there's enough budget for an operation.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
session_id: Optional session ID
|
||||
estimated_tokens: Estimated token usage
|
||||
estimated_cost_usd: Estimated USD cost
|
||||
|
||||
Returns:
|
||||
True if budget is available
|
||||
"""
|
||||
# Check session budget
|
||||
if session_id:
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
|
||||
return False
|
||||
|
||||
# Check agent daily budget
|
||||
agent_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.DAILY, agent_id
|
||||
)
|
||||
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def check_action(self, action: ActionRequest) -> bool:
|
||||
"""
|
||||
Check if an action is within budget.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
|
||||
Returns:
|
||||
True if within budget
|
||||
"""
|
||||
return await self.check_budget(
|
||||
agent_id=action.metadata.agent_id,
|
||||
session_id=action.metadata.session_id,
|
||||
estimated_tokens=action.estimated_cost_tokens,
|
||||
estimated_cost_usd=action.estimated_cost_usd,
|
||||
)
|
||||
|
||||
async def require_budget(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str | None,
|
||||
estimated_tokens: int,
|
||||
estimated_cost_usd: float,
|
||||
) -> None:
|
||||
"""
|
||||
Require budget or raise exception.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
session_id: Optional session ID
|
||||
estimated_tokens: Estimated token usage
|
||||
estimated_cost_usd: Estimated USD cost
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If budget is exceeded
|
||||
"""
|
||||
if not await self.check_budget(
|
||||
agent_id, session_id, estimated_tokens, estimated_cost_usd
|
||||
):
|
||||
# Determine which budget was exceeded
|
||||
if session_id:
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
session_status = await session_tracker.get_status()
|
||||
if session_status.is_exceeded:
|
||||
raise BudgetExceededError(
|
||||
"Session budget exceeded",
|
||||
budget_type="session",
|
||||
current_usage=session_status.tokens_used,
|
||||
budget_limit=session_status.tokens_limit,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
agent_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.DAILY, agent_id
|
||||
)
|
||||
agent_status = await agent_tracker.get_status()
|
||||
raise BudgetExceededError(
|
||||
"Daily budget exceeded",
|
||||
budget_type="daily",
|
||||
current_usage=agent_status.tokens_used,
|
||||
budget_limit=agent_status.tokens_limit,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def record_usage(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str | None,
|
||||
tokens: int,
|
||||
cost_usd: float,
|
||||
) -> None:
|
||||
"""
|
||||
Record actual usage.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
session_id: Optional session ID
|
||||
tokens: Actual token usage
|
||||
cost_usd: Actual USD cost
|
||||
"""
|
||||
# Update session budget
|
||||
if session_id:
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
await session_tracker.add_usage(tokens, cost_usd)
|
||||
|
||||
# Check for warning
|
||||
status = await session_tracker.get_status()
|
||||
if status.is_warning and not status.is_exceeded:
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
|
||||
status,
|
||||
)
|
||||
|
||||
# Update agent daily budget
|
||||
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
|
||||
await agent_tracker.add_usage(tokens, cost_usd)
|
||||
|
||||
# Check for warning
|
||||
status = await agent_tracker.get_status()
|
||||
if status.is_warning and not status.is_exceeded:
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
|
||||
status,
|
||||
)
|
||||
|
||||
async def get_status(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
) -> BudgetStatus | None:
|
||||
"""
|
||||
Get budget status.
|
||||
|
||||
Args:
|
||||
scope: Budget scope
|
||||
scope_id: ID within scope
|
||||
|
||||
Returns:
|
||||
Budget status or None if not tracked
|
||||
"""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
async with self._lock:
|
||||
tracker = self._trackers.get(key)
|
||||
|
||||
if tracker:
|
||||
return await tracker.get_status()
|
||||
return None
|
||||
|
||||
async def get_all_statuses(self) -> list[BudgetStatus]:
|
||||
"""Get status of all tracked budgets."""
|
||||
statuses = []
|
||||
async with self._lock:
|
||||
trackers = list(self._trackers.values())
|
||||
|
||||
for tracker in trackers:
|
||||
statuses.append(await tracker.get_status())
|
||||
|
||||
return statuses
|
||||
|
||||
async def set_budget(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
tokens_limit: int,
|
||||
cost_limit_usd: float,
|
||||
) -> None:
|
||||
"""
|
||||
Set a custom budget limit.
|
||||
|
||||
Args:
|
||||
scope: Budget scope
|
||||
scope_id: ID within scope
|
||||
tokens_limit: Token limit
|
||||
cost_limit_usd: USD limit
|
||||
"""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
|
||||
reset_interval = None
|
||||
if scope == BudgetScope.DAILY:
|
||||
reset_interval = timedelta(days=1)
|
||||
elif scope == BudgetScope.WEEKLY:
|
||||
reset_interval = timedelta(weeks=1)
|
||||
elif scope == BudgetScope.MONTHLY:
|
||||
reset_interval = timedelta(days=30)
|
||||
|
||||
async with self._lock:
|
||||
self._trackers[key] = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=tokens_limit,
|
||||
cost_limit_usd=cost_limit_usd,
|
||||
reset_interval=reset_interval,
|
||||
)
|
||||
|
||||
async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
|
||||
"""
|
||||
Reset a budget tracker.
|
||||
|
||||
Args:
|
||||
scope: Budget scope
|
||||
scope_id: ID within scope
|
||||
|
||||
Returns:
|
||||
True if tracker was found and reset
|
||||
"""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
async with self._lock:
|
||||
tracker = self._trackers.get(key)
|
||||
|
||||
if tracker:
|
||||
await tracker.reset()
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_alert_handler(self, handler: Any) -> None:
|
||||
"""Add an alert handler."""
|
||||
self._alert_handlers.append(handler)
|
||||
|
||||
def remove_alert_handler(self, handler: Any) -> None:
|
||||
"""Remove an alert handler."""
|
||||
if handler in self._alert_handlers:
|
||||
self._alert_handlers.remove(handler)
|
||||
|
||||
async def _send_alert(
|
||||
self,
|
||||
alert_type: str,
|
||||
message: str,
|
||||
status: BudgetStatus,
|
||||
) -> None:
|
||||
"""Send alert to all handlers."""
|
||||
for handler in self._alert_handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(alert_type, message, status)
|
||||
else:
|
||||
handler(alert_type, message, status)
|
||||
except Exception as e:
|
||||
logger.error("Error in alert handler: %s", e)
|
||||
23
backend/app/services/safety/emergency/__init__.py
Normal file
23
backend/app/services/safety/emergency/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Emergency controls for agent safety."""
|
||||
|
||||
from .controls import (
|
||||
EmergencyControls,
|
||||
EmergencyEvent,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
check_emergency_allowed,
|
||||
emergency_stop_global,
|
||||
get_emergency_controls,
|
||||
)
|
||||
|
||||
# Public API of the emergency-controls package (everything re-exported
# from .controls above).
__all__ = [
    "EmergencyControls",
    "EmergencyEvent",
    "EmergencyReason",
    "EmergencyState",
    "EmergencyTrigger",
    "check_emergency_allowed",
    "emergency_stop_global",
    "get_emergency_controls",
]
|
||||
594
backend/app/services/safety/emergency/controls.py
Normal file
594
backend/app/services/safety/emergency/controls.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
Emergency Controls
|
||||
|
||||
Emergency stop and pause functionality for agent safety.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import EmergencyStopError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmergencyState(str, Enum):
    """Emergency control states."""

    NORMAL = "normal"    # operations proceed
    PAUSED = "paused"    # halted, but can be resumed via resume()
    STOPPED = "stopped"  # hard stop; requires an explicit reset()
|
||||
|
||||
|
||||
class EmergencyReason(str, Enum):
    """Reasons for emergency actions (recorded on each EmergencyEvent)."""

    MANUAL = "manual"                        # human/operator initiated
    SAFETY_VIOLATION = "safety_violation"    # raised by the safety system
    BUDGET_EXCEEDED = "budget_exceeded"      # cost/token budget exhausted
    LOOP_DETECTED = "loop_detected"          # agent action loop detected
    RATE_LIMIT = "rate_limit"
    CONTENT_VIOLATION = "content_violation"  # content filter match
    SYSTEM_ERROR = "system_error"
    EXTERNAL_TRIGGER = "external_trigger"
|
||||
|
||||
|
||||
@dataclass
class EmergencyEvent:
    """Record of an emergency action (stop or pause) and its resolution."""

    id: str                     # generated by EmergencyControls._generate_event_id
    state: EmergencyState       # STOPPED or PAUSED at creation time
    reason: EmergencyReason
    triggered_by: str           # who/what triggered the action
    message: str                # human-readable description
    scope: str  # "global", "project:<id>", "agent:<id>"
    # NOTE(review): datetime.utcnow() returns a naive datetime and is
    # deprecated since Python 3.12 — consider datetime.now(timezone.utc);
    # confirm callers do not compare against aware datetimes.
    timestamp: datetime = field(default_factory=datetime.utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)
    resolved_at: datetime | None = None  # set on resume()/reset(); None while active
    resolved_by: str | None = None
|
||||
|
||||
|
||||
class EmergencyControls:
    """
    Emergency stop and pause controls for agent safety.

    Features:
    - Global emergency stop
    - Per-project/agent emergency controls (scopes: "global",
      "project:<id>", "agent:<id>")
    - Graceful pause with resume; hard stop requires an explicit reset
    - Automatic triggers from safety violations
    - Event history and audit trail

    Concurrency: all mutable state is guarded by a non-reentrant
    ``asyncio.Lock``. State-change callbacks and notification handlers are
    invoked AFTER the lock is released (bug fix: they were previously run
    while holding the lock, so any handler that called back into this
    object — e.g. ``get_state`` or ``check_allowed`` — would deadlock).
    """

    def __init__(
        self,
        notification_handlers: list[Callable[..., Any]] | None = None,
    ) -> None:
        """
        Initialize EmergencyControls.

        Args:
            notification_handlers: Handlers to call on emergency events
        """
        self._global_state = EmergencyState.NORMAL
        # Per-scope states; scopes absent from this dict are NORMAL.
        self._scoped_states: dict[str, EmergencyState] = {}
        self._events: list[EmergencyEvent] = []
        self._notification_handlers = notification_handlers or []
        self._lock = asyncio.Lock()
        self._event_id_counter = 0

        # Callbacks for state changes
        self._on_stop_callbacks: list[Callable[..., Any]] = []
        self._on_pause_callbacks: list[Callable[..., Any]] = []
        self._on_resume_callbacks: list[Callable[..., Any]] = []

    def _generate_event_id(self) -> str:
        """Generate a unique, monotonically increasing event ID."""
        self._event_id_counter += 1
        return f"emerg-{self._event_id_counter:06d}"

    def _set_state(self, scope: str, state: EmergencyState) -> None:
        """Set the state for a scope. Caller must hold self._lock."""
        if scope == "global":
            self._global_state = state
        else:
            self._scoped_states[scope] = state

    async def emergency_stop(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Trigger emergency stop.

        Args:
            reason: Reason for the stop
            triggered_by: Who/what triggered the stop
            message: Human-readable message
            scope: Scope of the stop (global, project:<id>, agent:<id>)
            metadata: Additional context

        Returns:
            The emergency event record
        """
        async with self._lock:
            event = EmergencyEvent(
                id=self._generate_event_id(),
                state=EmergencyState.STOPPED,
                reason=reason,
                triggered_by=triggered_by,
                message=message,
                scope=scope,
                metadata=metadata or {},
            )

            self._set_state(scope, EmergencyState.STOPPED)
            self._events.append(event)

            logger.critical(
                "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
                scope,
                reason.value,
                triggered_by,
                message,
            )

        # Callbacks run outside the lock so they may safely call back
        # into this object without deadlocking.
        await self._execute_callbacks(self._on_stop_callbacks, event)
        await self._notify_handlers("emergency_stop", event)

        return event

    async def pause(
        self,
        reason: EmergencyReason,
        triggered_by: str,
        message: str,
        scope: str = "global",
        metadata: dict[str, Any] | None = None,
    ) -> EmergencyEvent:
        """
        Pause operations (can be resumed).

        Args:
            reason: Reason for the pause
            triggered_by: Who/what triggered the pause
            message: Human-readable message
            scope: Scope of the pause
            metadata: Additional context

        Returns:
            The emergency event record
        """
        async with self._lock:
            event = EmergencyEvent(
                id=self._generate_event_id(),
                state=EmergencyState.PAUSED,
                reason=reason,
                triggered_by=triggered_by,
                message=message,
                scope=scope,
                metadata=metadata or {},
            )

            self._set_state(scope, EmergencyState.PAUSED)
            self._events.append(event)

            logger.warning(
                "PAUSE: scope=%s, reason=%s, by=%s - %s",
                scope,
                reason.value,
                triggered_by,
                message,
            )

        # Callbacks run outside the lock (see class docstring).
        await self._execute_callbacks(self._on_pause_callbacks, event)
        await self._notify_handlers("pause", event)

        return event

    async def resume(
        self,
        scope: str = "global",
        resumed_by: str = "system",
        message: str | None = None,
    ) -> bool:
        """
        Resume operations from paused state.

        Args:
            scope: Scope to resume
            resumed_by: Who/what is resuming
            message: Optional message

        Returns:
            True if resumed, False if not in paused state
        """
        async with self._lock:
            current_state = self._get_state(scope)

            if current_state == EmergencyState.STOPPED:
                # A hard stop can only be cleared via reset().
                logger.warning(
                    "Cannot resume from STOPPED state: %s (requires reset)",
                    scope,
                )
                return False

            if current_state == EmergencyState.NORMAL:
                return True  # Already normal

            # Mark the most recent pause event for this scope as resolved.
            for event in reversed(self._events):
                if event.scope == scope and event.state == EmergencyState.PAUSED:
                    if event.resolved_at is None:
                        event.resolved_at = datetime.utcnow()
                        event.resolved_by = resumed_by
                    break

            self._set_state(scope, EmergencyState.NORMAL)

            logger.info(
                "RESUMED: scope=%s, by=%s%s",
                scope,
                resumed_by,
                f" - {message}" if message else "",
            )

        # Callbacks run outside the lock (see class docstring).
        await self._execute_callbacks(
            self._on_resume_callbacks,
            {"scope": scope, "resumed_by": resumed_by},
        )
        await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})

        return True

    async def reset(
        self,
        scope: str = "global",
        reset_by: str = "admin",
        message: str | None = None,
    ) -> bool:
        """
        Reset from stopped state (requires explicit action).

        Args:
            scope: Scope to reset
            reset_by: Who is resetting (should be admin)
            message: Optional message

        Returns:
            True if reset successful
        """
        async with self._lock:
            current_state = self._get_state(scope)

            if current_state == EmergencyState.NORMAL:
                return True

            # Mark the most recent stop event for this scope as resolved.
            for event in reversed(self._events):
                if event.scope == scope and event.state == EmergencyState.STOPPED:
                    if event.resolved_at is None:
                        event.resolved_at = datetime.utcnow()
                        event.resolved_by = reset_by
                    break

            self._set_state(scope, EmergencyState.NORMAL)

            logger.warning(
                "EMERGENCY RESET: scope=%s, by=%s%s",
                scope,
                reset_by,
                f" - {message}" if message else "",
            )

        # Handlers notified outside the lock (see class docstring).
        await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})

        return True

    async def check_allowed(
        self,
        scope: str | None = None,
        raise_if_blocked: bool = True,
    ) -> bool:
        """
        Check if operations are allowed.

        Args:
            scope: Specific scope to check (also checks global)
            raise_if_blocked: Raise exception if blocked

        Returns:
            True if operations are allowed

        Raises:
            EmergencyStopError: If blocked and raise_if_blocked=True
        """
        async with self._lock:
            # Always check global state first; it overrides scoped states.
            if self._global_state != EmergencyState.NORMAL:
                if raise_if_blocked:
                    raise EmergencyStopError(
                        f"Global emergency state: {self._global_state.value}",
                        stop_type=self._get_last_reason("global") or "emergency",
                        triggered_by=self._get_last_triggered_by("global"),
                    )
                return False

            # Check specific scope
            if scope and scope in self._scoped_states:
                state = self._scoped_states[scope]
                if state != EmergencyState.NORMAL:
                    if raise_if_blocked:
                        raise EmergencyStopError(
                            f"Emergency state for {scope}: {state.value}",
                            stop_type=self._get_last_reason(scope) or "emergency",
                            triggered_by=self._get_last_triggered_by(scope),
                            details={"scope": scope},
                        )
                    return False

            return True

    def _get_state(self, scope: str) -> EmergencyState:
        """Get state for a scope (missing scopes default to NORMAL)."""
        if scope == "global":
            return self._global_state
        return self._scoped_states.get(scope, EmergencyState.NORMAL)

    def _get_last_reason(self, scope: str) -> str:
        """Get reason from the most recent unresolved event for scope."""
        for event in reversed(self._events):
            if event.scope == scope and event.resolved_at is None:
                return event.reason.value
        return "unknown"

    def _get_last_triggered_by(self, scope: str) -> str:
        """Get triggered_by from the most recent unresolved event for scope."""
        for event in reversed(self._events):
            if event.scope == scope and event.resolved_at is None:
                return event.triggered_by
        return "unknown"

    async def get_state(self, scope: str = "global") -> EmergencyState:
        """Get current state for a scope."""
        async with self._lock:
            return self._get_state(scope)

    async def get_all_states(self) -> dict[str, EmergencyState]:
        """Get all current states (keyed by scope; includes "global")."""
        async with self._lock:
            states = {"global": self._global_state}
            states.update(self._scoped_states)
            return states

    async def get_active_events(self) -> list[EmergencyEvent]:
        """Get all unresolved emergency events."""
        async with self._lock:
            return [e for e in self._events if e.resolved_at is None]

    async def get_event_history(
        self,
        scope: str | None = None,
        limit: int = 100,
    ) -> list[EmergencyEvent]:
        """Get emergency event history (most recent `limit`, oldest first)."""
        async with self._lock:
            events = list(self._events)

        if scope:
            events = [e for e in events if e.scope == scope]

        return events[-limit:]

    def on_stop(self, callback: Callable[..., Any]) -> None:
        """Register callback for stop events (receives the EmergencyEvent)."""
        self._on_stop_callbacks.append(callback)

    def on_pause(self, callback: Callable[..., Any]) -> None:
        """Register callback for pause events (receives the EmergencyEvent)."""
        self._on_pause_callbacks.append(callback)

    def on_resume(self, callback: Callable[..., Any]) -> None:
        """Register callback for resume events (receives a dict payload)."""
        self._on_resume_callbacks.append(callback)

    def add_notification_handler(self, handler: Callable[..., Any]) -> None:
        """Add a notification handler called as handler(event_type, data)."""
        self._notification_handlers.append(handler)

    async def _execute_callbacks(
        self,
        callbacks: list[Callable[..., Any]],
        data: Any,
    ) -> None:
        """Execute callbacks safely; a failing callback never blocks the rest."""
        for callback in callbacks:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(data)
                else:
                    callback(data)
            except Exception as e:
                logger.error("Error in callback: %s", e)

    async def _notify_handlers(self, event_type: str, data: Any) -> None:
        """Notify all handlers of an event; failures are logged, not raised."""
        for handler in self._notification_handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(event_type, data)
                else:
                    handler(event_type, data)
            except Exception as e:
                logger.error("Error in notification handler: %s", e)
|
||||
|
||||
|
||||
class EmergencyTrigger:
    """Translates safety-subsystem findings into emergency actions."""

    def __init__(self, controls: EmergencyControls) -> None:
        """
        Initialize EmergencyTrigger.

        Args:
            controls: EmergencyControls instance to trigger
        """
        self._controls = controls

    async def trigger_on_safety_violation(
        self,
        violation_type: str,
        details: dict[str, Any],
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger a hard emergency stop for a safety violation.

        Args:
            violation_type: Type of violation
            details: Violation details
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        meta: dict[str, Any] = {"violation_type": violation_type}
        meta.update(details)
        return await self._controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety_system",
            message=f"Safety violation: {violation_type}",
            scope=scope,
            metadata=meta,
        )

    async def trigger_on_budget_exceeded(
        self,
        budget_type: str,
        current: float,
        limit: float,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Pause operations when a budget is exceeded (recoverable).

        Args:
            budget_type: Type of budget
            current: Current usage
            limit: Budget limit
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        msg = f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})"
        return await self._controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message=msg,
            scope=scope,
            metadata={"budget_type": budget_type, "current": current, "limit": limit},
        )

    async def trigger_on_loop_detected(
        self,
        loop_type: str,
        agent_id: str,
        details: dict[str, Any],
    ) -> EmergencyEvent:
        """
        Pause the offending agent when a loop is detected (recoverable).

        Args:
            loop_type: Type of loop
            agent_id: Agent that's looping
            details: Loop details

        Returns:
            Emergency event
        """
        meta: dict[str, Any] = {"loop_type": loop_type, "agent_id": agent_id}
        meta.update(details)
        return await self._controls.pause(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="loop_detector",
            message=f"Loop detected: {loop_type} in agent {agent_id}",
            scope=f"agent:{agent_id}",
            metadata=meta,
        )

    async def trigger_on_content_violation(
        self,
        category: str,
        pattern: str,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Trigger a hard emergency stop for a content violation.

        Args:
            category: Content category
            pattern: Pattern that matched
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        return await self._controls.emergency_stop(
            reason=EmergencyReason.CONTENT_VIOLATION,
            triggered_by="content_filter",
            message=f"Content violation: {category} ({pattern})",
            scope=scope,
            metadata={"category": category, "pattern": pattern},
        )
|
||||
|
||||
|
||||
# Module-level singleton, created lazily by get_emergency_controls().
_emergency_controls: EmergencyControls | None = None
# Guards singleton creation. NOTE(review): instantiated at import time;
# on Python < 3.10, asyncio.Lock() bound to the running event loop —
# confirm the supported interpreter versions.
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_emergency_controls() -> EmergencyControls:
    """Return the process-wide EmergencyControls, creating it on first use."""
    global _emergency_controls

    async with _lock:
        if _emergency_controls is None:
            _emergency_controls = EmergencyControls()
        instance = _emergency_controls

    return instance
|
||||
|
||||
|
||||
async def emergency_stop_global(
    reason: str,
    triggered_by: str = "system",
) -> EmergencyEvent:
    """Convenience wrapper: trigger a manual, globally scoped emergency stop."""
    controls = await get_emergency_controls()
    event = await controls.emergency_stop(
        reason=EmergencyReason.MANUAL,
        triggered_by=triggered_by,
        message=reason,
        scope="global",
    )
    return event
|
||||
|
||||
|
||||
async def check_emergency_allowed(scope: str | None = None) -> bool:
    """Non-raising convenience check of whether operations are allowed."""
    controls = await get_emergency_controls()
    allowed = await controls.check_allowed(scope=scope, raise_if_blocked=False)
    return allowed
|
||||
277
backend/app/services/safety/exceptions.py
Normal file
277
backend/app/services/safety/exceptions.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Safety Framework Exceptions
|
||||
|
||||
Custom exception classes for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SafetyError(Exception):
|
||||
"""Base exception for all safety-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
action_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.action_id = action_id
|
||||
self.agent_id = agent_id
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class PermissionDeniedError(SafetyError):
    """The requested action is not permitted."""

    def __init__(
        self,
        message: str = "Permission denied",
        *,
        action_type: str | None = None,
        resource: str | None = None,
        required_permission: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            action_type: Kind of action that was attempted.
            resource: Resource the action targeted.
            required_permission: Permission that would have allowed it.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.action_type = action_type
        self.resource = resource
        self.required_permission = required_permission
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class BudgetExceededError(SafetyError):
    """A cost or token budget has been exhausted."""

    def __init__(
        self,
        message: str = "Budget exceeded",
        *,
        budget_type: str = "session",
        current_usage: float = 0.0,
        budget_limit: float = 0.0,
        unit: str = "tokens",
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            budget_type: Which budget was hit (e.g. "session", "daily").
            current_usage: Usage at the time of the error.
            budget_limit: The configured limit.
            unit: Unit of the usage figures (e.g. "tokens").
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.budget_type = budget_type
        self.current_usage = current_usage
        self.budget_limit = budget_limit
        self.unit = unit
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class RateLimitExceededError(SafetyError):
    """Too many operations within the configured window."""

    def __init__(
        self,
        message: str = "Rate limit exceeded",
        *,
        limit_type: str = "actions",
        limit_value: int = 0,
        window_seconds: int = 60,
        retry_after_seconds: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            limit_type: What is being limited (e.g. "actions").
            limit_value: Maximum allowed within the window.
            window_seconds: Size of the rate-limit window.
            retry_after_seconds: Suggested wait before retrying.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.limit_type = limit_type
        self.limit_value = limit_value
        self.window_seconds = window_seconds
        self.retry_after_seconds = retry_after_seconds
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class LoopDetectedError(SafetyError):
    """An agent appears to be repeating the same actions."""

    def __init__(
        self,
        message: str = "Loop detected",
        *,
        loop_type: str = "exact",
        repetition_count: int = 0,
        action_pattern: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            loop_type: Kind of loop that was detected.
            repetition_count: How many repetitions were observed.
            action_pattern: The repeating sequence of actions, if known.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.loop_type = loop_type
        self.repetition_count = repetition_count
        self.action_pattern = action_pattern or []
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class ApprovalRequiredError(SafetyError):
    """The action needs a human sign-off before it can run."""

    def __init__(
        self,
        message: str = "Human approval required",
        *,
        approval_id: str | None = None,
        reason: str | None = None,
        timeout_seconds: int = 300,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            approval_id: ID of the pending approval request, if created.
            reason: Why approval is required.
            timeout_seconds: How long the request stays open.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.approval_id = approval_id
        self.reason = reason
        self.timeout_seconds = timeout_seconds
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class ApprovalDeniedError(SafetyError):
    """A human explicitly refused the requested action."""

    def __init__(
        self,
        message: str = "Approval denied by human",
        *,
        approval_id: str | None = None,
        denied_by: str | None = None,
        denial_reason: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            approval_id: ID of the approval request that was denied.
            denied_by: Who denied the request.
            denial_reason: Why it was denied.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.approval_id = approval_id
        self.denied_by = denied_by
        self.denial_reason = denial_reason
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class ApprovalTimeoutError(SafetyError):
    """No human responded to the approval request in time."""

    def __init__(
        self,
        message: str = "Approval request timed out",
        *,
        approval_id: str | None = None,
        timeout_seconds: int = 300,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            approval_id: ID of the approval request that expired.
            timeout_seconds: Timeout that was in effect.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.approval_id = approval_id
        self.timeout_seconds = timeout_seconds
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class RollbackError(SafetyError):
    """Restoring a checkpoint did not complete."""

    def __init__(
        self,
        message: str = "Rollback failed",
        *,
        checkpoint_id: str | None = None,
        failed_actions: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            checkpoint_id: Checkpoint being rolled back to, if known.
            failed_actions: Actions that could not be undone.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.checkpoint_id = checkpoint_id
        self.failed_actions = failed_actions or []
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class CheckpointError(SafetyError):
    """A checkpoint could not be created."""

    def __init__(
        self,
        message: str = "Checkpoint creation failed",
        *,
        checkpoint_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            checkpoint_type: Kind of checkpoint that failed, if known.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.checkpoint_type = checkpoint_type
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class ValidationError(SafetyError):
    """Pre-execution validation of an action failed."""

    def __init__(
        self,
        message: str = "Validation failed",
        *,
        validation_rules: list[str] | None = None,
        failed_rules: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            validation_rules: All rules that were evaluated.
            failed_rules: The subset of rules that failed.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.validation_rules = validation_rules or []
        self.failed_rules = failed_rules or []
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class ContentFilterError(SafetyError):
    """Content filtering detected prohibited content."""

    def __init__(
        self,
        message: str = "Prohibited content detected",
        *,
        filter_type: str | None = None,
        detected_patterns: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            filter_type: Which filter flagged the content, if known.
            detected_patterns: The patterns that matched.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.filter_type = filter_type
        self.detected_patterns = detected_patterns or []
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class SandboxError(SafetyError):
    """Execution inside the sandbox failed."""

    def __init__(
        self,
        message: str = "Sandbox execution failed",
        *,
        exit_code: int | None = None,
        stderr: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            exit_code: Process exit code, if the sandbox produced one.
            stderr: Captured standard error, if any.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.exit_code = exit_code
        self.stderr = stderr
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class SandboxTimeoutError(SandboxError):
    """Sandboxed execution exceeded its time budget."""

    def __init__(
        self,
        message: str = "Sandbox execution timed out",
        *,
        timeout_seconds: int = 300,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            timeout_seconds: Timeout that was in effect.
            **kwargs: Forwarded to SandboxError/SafetyError.
        """
        self.timeout_seconds = timeout_seconds
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class EmergencyStopError(SafetyError):
    """Operations are blocked by an active emergency stop/pause."""

    def __init__(
        self,
        message: str = "Emergency stop triggered",
        *,
        stop_type: str = "kill",
        triggered_by: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            stop_type: Kind of stop (e.g. reason value of the active event).
            triggered_by: Who/what triggered the emergency, if known.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.stop_type = stop_type
        self.triggered_by = triggered_by
        super().__init__(message, **kwargs)
|
||||
|
||||
|
||||
class PolicyViolationError(SafetyError):
    """An action violates a configured safety policy."""

    def __init__(
        self,
        message: str = "Policy violation",
        *,
        policy_name: str | None = None,
        violated_rules: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            message: Human-readable error message.
            policy_name: Name of the violated policy, if known.
            violated_rules: Specific rules that were violated.
            **kwargs: Forwarded to SafetyError (action_id, agent_id, details).
        """
        self.policy_name = policy_name
        self.violated_rules = violated_rules or []
        super().__init__(message, **kwargs)
|
||||
614
backend/app/services/safety/guardian.py
Normal file
614
backend/app/services/safety/guardian.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Safety Guardian
|
||||
|
||||
Main facade for the safety framework. Orchestrates all safety checks
|
||||
before, during, and after action execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .audit import AuditLogger, get_audit_logger
|
||||
from .config import (
|
||||
SafetyConfig,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
)
|
||||
from .exceptions import (
|
||||
SafetyError,
|
||||
)
|
||||
from .models import (
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
AuditEventType,
|
||||
GuardianResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyGuardian:
    """
    Central orchestrator for all safety checks.

    The SafetyGuardian is the main entry point for validating agent actions.
    It coordinates multiple safety subsystems:
    - Permission checking
    - Cost/budget control
    - Rate limiting
    - Loop detection
    - Human-in-the-loop approval
    - Rollback/checkpoint management
    - Content filtering
    - Sandbox execution

    Usage:
        guardian = SafetyGuardian()
        await guardian.initialize()

        # Before executing an action
        result = await guardian.validate(action_request)
        if not result.allowed:
            # Handle denial
            ...

        # After action execution
        await guardian.record_execution(action_request, action_result)
    """

    def __init__(
        self,
        config: SafetyConfig | None = None,
        audit_logger: AuditLogger | None = None,
    ) -> None:
        """
        Initialize the SafetyGuardian.

        Args:
            config: Optional safety configuration. If None, loads from environment.
            audit_logger: Optional audit logger. If None, uses global instance.
        """
        self._config = config or get_safety_config()
        self._audit_logger = audit_logger
        self._initialized = False
        # Guards initialize()/shutdown() against concurrent callers.
        self._lock = asyncio.Lock()

        # Subsystem references (initialized lazily as they are implemented).
        self._permission_manager: Any = None
        self._cost_controller: Any = None
        self._rate_limiter: Any = None
        self._loop_detector: Any = None
        self._hitl_manager: Any = None
        self._rollback_manager: Any = None
        self._content_filter: Any = None
        self._sandbox_executor: Any = None
        self._emergency_controls: Any = None

        # Policy cache, keyed by autonomy-level value.
        self._policies: dict[str, SafetyPolicy] = {}
        self._default_policy: SafetyPolicy | None = None

    @property
    def is_initialized(self) -> bool:
        """Check if the guardian is initialized."""
        return self._initialized

    async def initialize(self) -> None:
        """Initialize the SafetyGuardian and all subsystems."""
        async with self._lock:
            if self._initialized:
                logger.warning("SafetyGuardian already initialized")
                return

            logger.info("Initializing SafetyGuardian")

            # Get audit logger
            if self._audit_logger is None:
                self._audit_logger = await get_audit_logger()

            # Initialize subsystems lazily as they're implemented
            # For now, we'll import and initialize them when available

            self._initialized = True
            logger.info("SafetyGuardian initialized")

    async def shutdown(self) -> None:
        """Shutdown the SafetyGuardian and all subsystems."""
        async with self._lock:
            if not self._initialized:
                return

            logger.info("Shutting down SafetyGuardian")

            # Shutdown subsystems
            # (Will be implemented as subsystems are added)

            self._initialized = False
            logger.info("SafetyGuardian shutdown complete")

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> GuardianResult:
        """
        Validate an action before execution.

        Runs all safety checks in order:
        1. Permission check
        2. Cost/budget check
        3. Rate limit check
        4. Loop detection
        5. HITL check (if required)
        6. Checkpoint creation (if destructive)

        Args:
            action: The action to validate
            policy: Optional policy override. If None, uses autonomy-level policy.

        Returns:
            GuardianResult with decision and details
        """
        if not self._initialized:
            await self.initialize()

        if not self._config.enabled:
            # Safety disabled - allow everything (NOT RECOMMENDED)
            logger.warning("Safety framework disabled - allowing action %s", action.id)
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Safety framework disabled"],
            )

        # Get policy for this action
        effective_policy = policy or self._get_policy(action)

        reasons: list[str] = []
        audit_events = []

        try:
            # Log action request
            if self._audit_logger:
                event = await self._audit_logger.log(
                    AuditEventType.ACTION_REQUESTED,
                    agent_id=action.metadata.agent_id,
                    action_id=action.id,
                    project_id=action.metadata.project_id,
                    session_id=action.metadata.session_id,
                    details={
                        "action_type": action.action_type.value,
                        "tool_name": action.tool_name,
                        "resource": action.resource,
                    },
                    correlation_id=action.metadata.correlation_id,
                )
                audit_events.append(event)

            # 1. Permission check
            permission_result = await self._check_permissions(action, effective_policy)
            if permission_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, permission_result.reasons, audit_events
                )

            # 2. Cost/budget check
            budget_result = await self._check_budget(action, effective_policy)
            if budget_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, budget_result.reasons, audit_events
                )

            # 3. Rate limit check
            rate_result = await self._check_rate_limit(action, effective_policy)
            if rate_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action,
                    rate_result.reasons,
                    audit_events,
                    retry_after=rate_result.retry_after_seconds,
                )
            if rate_result.decision == SafetyDecision.DELAY:
                # Return delay decision
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DELAY,
                    reasons=rate_result.reasons,
                    retry_after_seconds=rate_result.retry_after_seconds,
                    audit_events=audit_events,
                )

            # 4. Loop detection
            loop_result = await self._check_loops(action, effective_policy)
            if loop_result.decision == SafetyDecision.DENY:
                return await self._create_denial_result(
                    action, loop_result.reasons, audit_events
                )

            # 5. HITL check
            hitl_result = await self._check_hitl(action, effective_policy)
            if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.REQUIRE_APPROVAL,
                    reasons=hitl_result.reasons,
                    approval_id=hitl_result.approval_id,
                    audit_events=audit_events,
                )

            # 6. Create checkpoint if destructive
            checkpoint_id = None
            if action.is_destructive and self._config.auto_checkpoint_destructive:
                checkpoint_id = await self._create_checkpoint(action)

            # All checks passed
            reasons.append("All safety checks passed")

            if self._audit_logger:
                event = await self._audit_logger.log_action_request(
                    action, SafetyDecision.ALLOW, reasons
                )
                audit_events.append(event)

            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=reasons,
                checkpoint_id=checkpoint_id,
                audit_events=audit_events,
            )

        except SafetyError as e:
            # Known safety error - deny with its message
            return await self._create_denial_result(
                action, [str(e)], audit_events
            )
        except Exception as e:
            # Unknown error - fail closed in strict mode.
            # logger.exception preserves the traceback for diagnosis.
            logger.exception("Unexpected error in safety validation: %s", e)
            if self._config.strict_mode:
                return await self._create_denial_result(
                    action,
                    [f"Safety validation error: {e}"],
                    audit_events,
                )
            else:
                # Non-strict mode - allow with warning
                logger.warning("Non-strict mode: allowing action despite error")
                return GuardianResult(
                    action_id=action.id,
                    allowed=True,
                    decision=SafetyDecision.ALLOW,
                    reasons=["Allowed despite validation error (non-strict mode)"],
                    audit_events=audit_events,
                )

    async def record_execution(
        self,
        action: ActionRequest,
        result: ActionResult,
    ) -> None:
        """
        Record action execution result for auditing and tracking.

        Args:
            action: The executed action
            result: The execution result
        """
        if self._audit_logger:
            await self._audit_logger.log_action_executed(
                action,
                success=result.success,
                execution_time_ms=result.execution_time_ms,
                error=result.error,
            )

        # Update cost tracking
        if self._cost_controller:
            # TODO: Track actual cost once CostController is implemented
            pass

        # Update loop detection history
        if self._loop_detector:
            # TODO: Add to action history once LoopDetector is implemented
            pass

    async def rollback(self, checkpoint_id: str) -> bool:
        """
        Rollback to a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to rollback to

        Returns:
            True if rollback succeeded
        """
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available")
            return False

        # Delegate to rollback manager
        return await self._rollback_manager.rollback(checkpoint_id)

    async def emergency_stop(
        self,
        stop_type: str = "kill",
        reason: str = "Manual emergency stop",
        triggered_by: str = "system",
    ) -> None:
        """
        Trigger emergency stop.

        Args:
            stop_type: Type of stop (kill, pause, lockdown)
            reason: Reason for the stop
            triggered_by: Who triggered the stop
        """
        logger.critical(
            "Emergency stop triggered: type=%s, reason=%s, by=%s",
            stop_type,
            reason,
            triggered_by,
        )

        if self._audit_logger:
            await self._audit_logger.log_emergency_stop(
                stop_type=stop_type,
                triggered_by=triggered_by,
                reason=reason,
            )

        if self._emergency_controls:
            await self._emergency_controls.execute_stop(stop_type)

    def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
        """Get the effective policy for an action (cached per autonomy level)."""
        autonomy_level = action.metadata.autonomy_level

        if autonomy_level.value not in self._policies:
            self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
                autonomy_level
            )

        return self._policies[autonomy_level.value]

    async def _check_permissions(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if action is permitted by the policy's tool/file rules."""
        reasons: list[str] = []

        # Check denied tools (deny-list wins over allow-list)
        if action.tool_name:
            for pattern in policy.denied_tools:
                if self._matches_pattern(action.tool_name, pattern):
                    reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        # Check allowed tools (if not "*")
        if action.tool_name and "*" not in policy.allowed_tools:
            allowed = False
            for pattern in policy.allowed_tools:
                if self._matches_pattern(action.tool_name, pattern):
                    allowed = True
                    break
            if not allowed:
                reasons.append(f"Tool '{action.tool_name}' not in allowed list")
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=reasons,
                )

        # Check file patterns
        if action.resource:
            for pattern in policy.denied_file_patterns:
                if self._matches_pattern(action.resource, pattern):
                    reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=reasons,
                    )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Permission check passed"],
        )

    async def _check_budget(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if action is within budget."""
        # TODO: Implement with CostController
        # For now, return allow
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Budget check passed (not fully implemented)"],
        )

    async def _check_rate_limit(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if action is within rate limits."""
        # TODO: Implement with RateLimiter
        # For now, return allow
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Rate limit check passed (not fully implemented)"],
        )

    async def _check_loops(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check for action loops."""
        # TODO: Implement with LoopDetector
        # For now, return allow
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Loop check passed (not fully implemented)"],
        )

    async def _check_hitl(
        self,
        action: ActionRequest,
        policy: SafetyPolicy,
    ) -> GuardianResult:
        """Check if human approval is required."""
        if not self._config.hitl_enabled:
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["HITL disabled"],
            )

        # Check if action requires approval: patterns may match the
        # wildcard "*", the tool name, or the action type.
        requires_approval = False
        for pattern in policy.require_approval_for:
            if pattern == "*":
                requires_approval = True
                break
            if action.tool_name and self._matches_pattern(action.tool_name, pattern):
                requires_approval = True
                break
            if action.action_type.value and self._matches_pattern(
                action.action_type.value, pattern
            ):
                requires_approval = True
                break

        if requires_approval:
            # TODO: Create approval request with HITLManager
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reasons=["Action requires human approval"],
                approval_id=None,  # Will be set by HITLManager
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["No approval required"],
        )

    async def _create_checkpoint(self, action: ActionRequest) -> str | None:
        """Create a checkpoint before destructive action."""
        if self._rollback_manager is None:
            logger.warning("Rollback manager not available - skipping checkpoint")
            return None

        # TODO: Implement with RollbackManager
        return None

    async def _create_denial_result(
        self,
        action: ActionRequest,
        reasons: list[str],
        audit_events: list[Any],
        retry_after: float | None = None,
    ) -> GuardianResult:
        """Create a denial result with audit logging."""
        if self._audit_logger:
            event = await self._audit_logger.log_action_request(
                action, SafetyDecision.DENY, reasons
            )
            audit_events.append(event)

        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=reasons,
            retry_after_seconds=retry_after,
            audit_events=audit_events,
        )

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """
        Check if value matches a pattern, where '*' matches any
        (possibly empty) substring.

        Supports any number of wildcards (e.g. "a*b*c"). The previous
        implementation returned False for patterns with more than one
        inner '*' and could spuriously match when the literal prefix and
        suffix overlapped (e.g. "aba" vs "ab*ba").
        """
        if pattern == "*":
            return True

        if "*" not in pattern:
            return value == pattern

        parts = pattern.split("*")
        prefix, suffix = parts[0], parts[-1]

        # Prefix and suffix must both match without overlapping.
        if len(value) < len(prefix) + len(suffix):
            return False
        if not value.startswith(prefix) or not value.endswith(suffix):
            return False

        # Each middle segment must appear, in order, strictly between
        # the matched prefix and suffix.
        pos = len(prefix)
        end = len(value) - len(suffix)
        for segment in parts[1:-1]:
            if not segment:
                continue  # consecutive '*' match the empty string
            idx = value.find(segment, pos, end)
            if idx == -1:
                return False
            pos = idx + len(segment)

        return True
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_guardian_instance: SafetyGuardian | None = None
|
||||
_guardian_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_safety_guardian() -> SafetyGuardian:
    """Get the global SafetyGuardian instance, creating it on first use.

    The instance is cached only after initialize() succeeds, so a failed
    initialization does not leave a broken half-initialized singleton
    cached for all subsequent callers (the previous version did).
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is None:
            guardian = SafetyGuardian()
            await guardian.initialize()
            _guardian_instance = guardian

        return _guardian_instance
|
||||
|
||||
|
||||
async def shutdown_safety_guardian() -> None:
    """Shutdown the global SafetyGuardian and drop the cached instance."""
    global _guardian_instance

    async with _guardian_lock:
        guardian = _guardian_instance
        if guardian is not None:
            await guardian.shutdown()
            _guardian_instance = None
|
||||
|
||||
|
||||
def reset_safety_guardian() -> None:
    """Reset the SafetyGuardian (for testing).

    Drops the cached global instance without shutting it down; intended
    only for test isolation where subsystem cleanup does not matter.
    """
    global _guardian_instance
    _guardian_instance = None
|
||||
5
backend/app/services/safety/hitl/__init__.py
Normal file
5
backend/app/services/safety/hitl/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Human-in-the-Loop approval workflows."""
|
||||
|
||||
from .manager import ApprovalQueue, HITLManager
|
||||
|
||||
__all__ = ["ApprovalQueue", "HITLManager"]
|
||||
449
backend/app/services/safety/hitl/manager.py
Normal file
449
backend/app/services/safety/hitl/manager.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Human-in-the-Loop (HITL) Manager
|
||||
|
||||
Manages approval workflows for actions requiring human oversight.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
)
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApprovalQueue:
    """In-memory queue of pending approval requests.

    Tracks pending requests, completed responses, and per-request
    asyncio.Event objects used to wake coroutines blocked in
    wait_for_response(). All mutating operations are serialized through
    an asyncio.Lock.

    Fixes over the previous version:
    - Waiter events are removed from _waiters once a request completes,
      is cancelled, or times out; previously they accumulated forever
      (memory leak). Waiting coroutines keep their own reference to the
      Event, so dropping the dict entry is safe.
    - wait_for_response catches asyncio.TimeoutError, which is the
      exception asyncio.wait_for actually raises on Python < 3.11
      (on 3.11+ it is an alias of the builtin TimeoutError).

    NOTE(review): _completed is retained indefinitely; high-volume
    deployments may want periodic pruning.
    """

    def __init__(self) -> None:
        self._pending: dict[str, ApprovalRequest] = {}
        self._completed: dict[str, ApprovalResponse] = {}
        self._waiters: dict[str, asyncio.Event] = {}
        self._lock = asyncio.Lock()

    async def add(self, request: ApprovalRequest) -> None:
        """Add an approval request to the queue."""
        async with self._lock:
            self._pending[request.id] = request
            self._waiters[request.id] = asyncio.Event()

    async def get_pending(self, request_id: str) -> ApprovalRequest | None:
        """Get a pending request by ID, or None if it is not pending."""
        async with self._lock:
            return self._pending.get(request_id)

    async def complete(self, response: ApprovalResponse) -> bool:
        """Record a decision for a pending request.

        Returns:
            True if the request was pending and is now completed.
        """
        async with self._lock:
            if response.request_id not in self._pending:
                return False

            del self._pending[response.request_id]
            self._completed[response.request_id] = response

            # Wake any waiting coroutine and drop our reference to the
            # event so _waiters does not grow without bound.
            waiter = self._waiters.pop(response.request_id, None)
            if waiter is not None:
                waiter.set()

            return True

    async def wait_for_response(
        self,
        request_id: str,
        timeout_seconds: float,
    ) -> ApprovalResponse | None:
        """Wait for a response to an approval request.

        Returns:
            The response, or None if the timeout expired first.
        """
        async with self._lock:
            waiter = self._waiters.get(request_id)
            if not waiter:
                # Already completed (or unknown) - return whatever we have.
                return self._completed.get(request_id)

        try:
            await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            return None

        async with self._lock:
            return self._completed.get(request_id)

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending requests."""
        async with self._lock:
            return list(self._pending.values())

    async def cancel(self, request_id: str) -> bool:
        """Cancel a pending request.

        Returns:
            True if the request was pending and is now cancelled.
        """
        async with self._lock:
            if request_id not in self._pending:
                return False

            del self._pending[request_id]

            # Record a cancelled response so late readers see the outcome.
            response = ApprovalResponse(
                request_id=request_id,
                status=ApprovalStatus.CANCELLED,
                reason="Cancelled",
            )
            self._completed[request_id] = response

            # Notify waiters and release the event reference.
            waiter = self._waiters.pop(request_id, None)
            if waiter is not None:
                waiter.set()

            return True

    async def cleanup_expired(self) -> int:
        """Time out pending requests whose expires_at has passed.

        The whole sweep now runs under a single lock acquisition, so the
        scan and the removals are atomic with respect to other operations.

        Returns:
            Number of requests timed out.
        """
        # NOTE(review): naive utcnow() must match the naive expires_at set
        # by HITLManager.request_approval; do not mix in tz-aware datetimes.
        now = datetime.utcnow()
        count = 0

        async with self._lock:
            expired = [
                request_id
                for request_id, request in self._pending.items()
                if request.expires_at and request.expires_at < now
            ]
            for request_id in expired:
                del self._pending[request_id]
                self._completed[request_id] = ApprovalResponse(
                    request_id=request_id,
                    status=ApprovalStatus.TIMEOUT,
                    reason="Request timed out",
                )
                waiter = self._waiters.pop(request_id, None)
                if waiter is not None:
                    waiter.set()
                count += 1

        return count
|
||||
|
||||
|
||||
class HITLManager:
    """
    Manages Human-in-the-Loop approval workflows.

    Features:
    - Approval request queue
    - Configurable timeout handling (default deny)
    - Approval delegation
    - Batch approval for similar actions
    - Approval with modifications
    - Notification channels
    """

    def __init__(
        self,
        default_timeout: int | None = None,
    ) -> None:
        """
        Initialize the HITLManager.

        Args:
            default_timeout: Default timeout for approval requests in seconds.
                Falls back to the safety config's hitl_default_timeout.
        """
        config = get_safety_config()

        self._default_timeout = default_timeout or config.hitl_default_timeout
        self._queue = ApprovalQueue()
        # Handlers are called with (event_type, data); sync or async.
        self._notification_handlers: list[Callable[..., Any]] = []
        self._running = False
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start(self) -> None:
        """Start the HITL manager background tasks. Idempotent."""
        if self._running:
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("HITL Manager started")

    async def stop(self) -> None:
        """Stop the HITL manager and its background cleanup task."""
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Drop the finished task so start()/stop() cycles don't hold
            # a stale Task reference (previously left dangling).
            self._cleanup_task = None

        logger.info("HITL Manager stopped")

    async def request_approval(
        self,
        action: ActionRequest,
        reason: str,
        timeout_seconds: int | None = None,
        urgency: str = "normal",
        context: dict[str, Any] | None = None,
    ) -> ApprovalRequest:
        """
        Create an approval request for an action.

        Args:
            action: The action requiring approval
            reason: Why approval is required
            timeout_seconds: Timeout for this request
            urgency: Urgency level (low, normal, high, critical)
            context: Additional context for the approver

        Returns:
            The created approval request
        """
        timeout = timeout_seconds or self._default_timeout
        # NOTE(review): naive utcnow() pairs with ApprovalQueue.cleanup_expired,
        # which also compares naive datetimes; datetime.utcnow() is deprecated
        # as of Python 3.12 - migrate both sites together.
        expires_at = datetime.utcnow() + timedelta(seconds=timeout)

        request = ApprovalRequest(
            id=str(uuid4()),
            action=action,
            reason=reason,
            urgency=urgency,
            timeout_seconds=timeout,
            expires_at=expires_at,
            context=context or {},
        )

        await self._queue.add(request)

        # Notify handlers
        await self._notify_handlers("approval_requested", request)

        logger.info(
            "Approval requested: %s for action %s (timeout: %ds)",
            request.id,
            action.id,
            timeout,
        )

        return request

    async def wait_for_approval(
        self,
        request_id: str,
        timeout_seconds: int | None = None,
    ) -> ApprovalResponse:
        """
        Wait for an approval decision.

        Args:
            request_id: ID of the approval request
            timeout_seconds: Override timeout

        Returns:
            The approval response

        Raises:
            ApprovalRequiredError: If the request is unknown
            ApprovalTimeoutError: If timeout expires
            ApprovalDeniedError: If approval is denied or cancelled
        """
        request = await self._queue.get_pending(request_id)
        if not request:
            raise ApprovalRequiredError(
                f"Approval request not found: {request_id}",
                approval_id=request_id,
            )

        timeout = timeout_seconds or request.timeout_seconds or self._default_timeout
        response = await self._queue.wait_for_response(request_id, timeout)

        if response is None:
            # Timeout - default deny: record the outcome, then raise.
            response = ApprovalResponse(
                request_id=request_id,
                status=ApprovalStatus.TIMEOUT,
                reason="Request timed out (default deny)",
            )
            await self._queue.complete(response)

            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.DENIED:
            raise ApprovalDeniedError(
                response.reason or "Approval denied",
                approval_id=request_id,
                denied_by=response.decided_by,
                denial_reason=response.reason,
            )

        if response.status == ApprovalStatus.TIMEOUT:
            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.CANCELLED:
            raise ApprovalDeniedError(
                "Approval request was cancelled",
                approval_id=request_id,
                denial_reason="Cancelled",
            )

        return response

    async def approve(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
        modifications: dict[str, Any] | None = None,
    ) -> bool:
        """
        Approve a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who approved
            reason: Optional approval reason
            modifications: Optional modifications to the action

        Returns:
            True if approval was recorded
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.APPROVED,
            decided_by=decided_by,
            reason=reason,
            modifications=modifications,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval granted: %s by %s",
                request_id,
                decided_by,
            )
            await self._notify_handlers("approval_granted", response)

        return success

    async def deny(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
    ) -> bool:
        """
        Deny a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who denied
            reason: Denial reason

        Returns:
            True if denial was recorded
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.DENIED,
            decided_by=decided_by,
            reason=reason,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval denied: %s by %s - %s",
                request_id,
                decided_by,
                reason,
            )
            await self._notify_handlers("approval_denied", response)

        return success

    async def cancel(self, request_id: str) -> bool:
        """
        Cancel a pending request.

        Args:
            request_id: ID of the approval request

        Returns:
            True if request was cancelled
        """
        success = await self._queue.cancel(request_id)

        if success:
            logger.info("Approval request cancelled: %s", request_id)

        return success

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending approval requests."""
        return await self._queue.list_pending()

    async def get_request(self, request_id: str) -> ApprovalRequest | None:
        """Get an approval request by ID."""
        return await self._queue.get_pending(request_id)

    def add_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Add a notification handler (called with event_type, data)."""
        self._notification_handlers.append(handler)

    def remove_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Remove a previously added notification handler."""
        if handler in self._notification_handlers:
            self._notification_handlers.remove(handler)

    async def _notify_handlers(
        self,
        event_type: str,
        data: Any,
    ) -> None:
        """Notify all handlers of an event; handler errors are logged, not raised."""
        for handler in self._notification_handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(event_type, data)
                else:
                    handler(event_type, data)
            except Exception as e:
                logger.error("Error in notification handler: %s", e)

    async def _periodic_cleanup(self) -> None:
        """Background task for cleaning up expired requests."""
        while self._running:
            try:
                await asyncio.sleep(30)  # Check every 30 seconds
                count = await self._queue.cleanup_expired()
                if count:
                    logger.debug("Cleaned up %d expired approval requests", count)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Error in approval cleanup: %s", e)
|
||||
15
backend/app/services/safety/limits/__init__.py
Normal file
15
backend/app/services/safety/limits/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Rate Limiting Module
|
||||
|
||||
Sliding window rate limiting for agent operations.
|
||||
"""
|
||||
|
||||
from .limiter import (
|
||||
RateLimiter,
|
||||
SlidingWindowCounter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RateLimiter",
|
||||
"SlidingWindowCounter",
|
||||
]
|
||||
368
backend/app/services/safety/limits/limiter.py
Normal file
368
backend/app/services/safety/limits/limiter.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Rate Limiter
|
||||
|
||||
Sliding window rate limiting for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RateLimitExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for rate limiting."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
burst_limit: int | None = None,
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.window_seconds = window_seconds
|
||||
self.burst_limit = burst_limit or limit
|
||||
self._timestamps: deque[float] = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def try_acquire(self) -> tuple[bool, float]:
|
||||
"""
|
||||
Try to acquire a slot.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, retry_after_seconds)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window_seconds
|
||||
|
||||
async with self._lock:
|
||||
# Remove expired entries
|
||||
while self._timestamps and self._timestamps[0] < window_start:
|
||||
self._timestamps.popleft()
|
||||
|
||||
current_count = len(self._timestamps)
|
||||
|
||||
# Check burst limit (instant check)
|
||||
if current_count >= self.burst_limit:
|
||||
# Calculate retry time
|
||||
oldest = self._timestamps[0] if self._timestamps else now
|
||||
retry_after = oldest + self.window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
# Check window limit
|
||||
if current_count >= self.limit:
|
||||
oldest = self._timestamps[0] if self._timestamps else now
|
||||
retry_after = oldest + self.window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
# Allow and record
|
||||
self._timestamps.append(now)
|
||||
return True, 0.0
|
||||
|
||||
async def get_status(self) -> tuple[int, int, float]:
|
||||
"""
|
||||
Get current status.
|
||||
|
||||
Returns:
|
||||
Tuple of (current_count, remaining, reset_in_seconds)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window_seconds
|
||||
|
||||
async with self._lock:
|
||||
# Remove expired entries
|
||||
while self._timestamps and self._timestamps[0] < window_start:
|
||||
self._timestamps.popleft()
|
||||
|
||||
current_count = len(self._timestamps)
|
||||
remaining = max(0, self.limit - current_count)
|
||||
|
||||
if self._timestamps:
|
||||
reset_in = self._timestamps[0] + self.window_seconds - now
|
||||
else:
|
||||
reset_in = 0.0
|
||||
|
||||
return current_count, remaining, max(0, reset_in)
|
||||
|
||||
|
||||
class RateLimiter:
    """
    Rate limiter for agent operations.

    Features:
    - Per-tool rate limits
    - Per-agent rate limits
    - Per-resource rate limits
    - Sliding window implementation
    - Burst allowance with recovery
    - Slowdown before hard block
    """

    def __init__(self) -> None:
        """Initialize the RateLimiter with defaults from the safety config."""
        config = get_safety_config()

        # Limits registered via configure(); these take precedence over
        # the built-in defaults below (see _get_config).
        self._configs: dict[str, RateLimitConfig] = {}
        # One SlidingWindowCounter per "<limit_name>:<key>" pair.
        self._counters: dict[str, SlidingWindowCounter] = {}
        self._lock = asyncio.Lock()

        # Default rate limits, driven by the global safety configuration.
        self._default_limits = {
            "actions": RateLimitConfig(
                name="actions",
                limit=config.default_actions_per_minute,
                window_seconds=60,
            ),
            "llm_calls": RateLimitConfig(
                name="llm_calls",
                limit=config.default_llm_calls_per_minute,
                window_seconds=60,
            ),
            "file_ops": RateLimitConfig(
                name="file_ops",
                limit=config.default_file_ops_per_minute,
                window_seconds=60,
            ),
        }

    def configure(self, config: RateLimitConfig) -> None:
        """
        Configure (or override) a named rate limit.

        Args:
            config: Rate limit configuration

        Note:
            Existing counters are not rebuilt; the new configuration only
            applies to counters created after this call (or after reset()).
        """
        self._configs[config.name] = config
        logger.debug(
            "Configured rate limit: %s = %d/%ds",
            config.name,
            config.limit,
            config.window_seconds,
        )

    @staticmethod
    def _build_status(
        limit_name: str,
        config: RateLimitConfig,
        current: int,
        remaining: int,
        reset_in: float,
        *,
        is_limited: bool,
        retry_after: float,
    ) -> RateLimitStatus:
        """Assemble a RateLimitStatus snapshot (shared by check/acquire)."""
        # Local import mirrors the original module's style of deferring
        # datetime imports to the point of use.
        from datetime import datetime, timedelta

        return RateLimitStatus(
            name=limit_name,
            current_count=current,
            limit=config.limit,
            window_seconds=config.window_seconds,
            remaining=remaining,
            reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
            is_limited=is_limited,
            retry_after_seconds=retry_after,
        )

    async def check(
        self,
        limit_name: str,
        key: str,
    ) -> RateLimitStatus:
        """
        Check rate limit without consuming a slot.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Rate limit status
        """
        counter = await self._get_counter(limit_name, key)
        config = self._get_config(limit_name)

        current, remaining, reset_in = await counter.get_status()
        limited = remaining <= 0
        return self._build_status(
            limit_name,
            config,
            current,
            remaining,
            reset_in,
            is_limited=limited,
            retry_after=reset_in if limited else 0.0,
        )

    async def acquire(
        self,
        limit_name: str,
        key: str,
    ) -> tuple[bool, RateLimitStatus]:
        """
        Try to acquire a rate limit slot (consumes one slot on success).

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Tuple of (allowed, status)
        """
        counter = await self._get_counter(limit_name, key)
        config = self._get_config(limit_name)

        allowed, retry_after = await counter.try_acquire()
        current, remaining, reset_in = await counter.get_status()

        status = self._build_status(
            limit_name,
            config,
            current,
            remaining,
            reset_in,
            is_limited=not allowed,
            retry_after=retry_after,
        )
        return allowed, status

    async def check_action(
        self,
        action: ActionRequest,
    ) -> tuple[bool, list[RateLimitStatus]]:
        """
        Apply all rate limits relevant to an action.

        Despite the name, this CONSUMES a slot from every applicable
        limit (it calls acquire, not check): the general "actions"
        limit always, plus "llm_calls" for LLM calls and "file_ops"
        for file operations.

        Args:
            action: The action to check

        Returns:
            Tuple of (allowed, list of statuses); allowed is False if
            any applicable limit was exceeded.
        """
        agent_id = action.metadata.agent_id
        statuses: list[RateLimitStatus] = []
        allowed = True

        # Check general actions limit
        actions_allowed, actions_status = await self.acquire("actions", agent_id)
        statuses.append(actions_status)
        if not actions_allowed:
            allowed = False

        # Check LLM-specific limit for LLM calls
        if action.action_type.value == "llm_call":
            llm_allowed, llm_status = await self.acquire("llm_calls", agent_id)
            statuses.append(llm_status)
            if not llm_allowed:
                allowed = False

        # Check file ops limit for file operations
        if action.action_type.value in {"file_read", "file_write", "file_delete"}:
            file_allowed, file_status = await self.acquire("file_ops", agent_id)
            statuses.append(file_status)
            if not file_allowed:
                allowed = False

        return allowed, statuses

    async def require(
        self,
        limit_name: str,
        key: str,
    ) -> None:
        """
        Require a rate limit slot or raise.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Raises:
            RateLimitExceededError: If the rate limit is exceeded
        """
        allowed, status = await self.acquire(limit_name, key)
        if not allowed:
            raise RateLimitExceededError(
                f"Rate limit exceeded: {limit_name}",
                limit_type=limit_name,
                limit_value=status.limit,
                window_seconds=status.window_seconds,
                retry_after_seconds=status.retry_after_seconds,
            )

    async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
        """
        Get status of all rate limits for a key.

        Covers both the built-in default limits and any configured via
        configure(). Does not consume slots.

        Args:
            key: Key for tracking

        Returns:
            Dict of limit name to status
        """
        statuses = {}
        for name in self._default_limits:
            statuses[name] = await self.check(name, key)
        for name in self._configs:
            if name not in statuses:
                statuses[name] = await self.check(name, key)
        return statuses

    async def reset(self, limit_name: str, key: str) -> bool:
        """
        Reset a rate limit counter.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Returns:
            True if a counter was found and reset
        """
        counter_key = f"{limit_name}:{key}"
        async with self._lock:
            if counter_key in self._counters:
                del self._counters[counter_key]
                return True
        return False

    async def reset_all(self, key: str) -> int:
        """
        Reset all rate limit counters for a key.

        Args:
            key: Key for tracking

        Returns:
            Number of counters reset
        """
        count = 0
        async with self._lock:
            to_remove = [k for k in self._counters if k.endswith(f":{key}")]
            for k in to_remove:
                del self._counters[k]
                count += 1
        return count

    def _get_config(self, limit_name: str) -> RateLimitConfig:
        """Get the configuration for a limit: explicit > default > fallback."""
        if limit_name in self._configs:
            return self._configs[limit_name]
        if limit_name in self._default_limits:
            return self._default_limits[limit_name]
        # Unknown limit names fall back to 60 events per 60 seconds.
        return RateLimitConfig(
            name=limit_name,
            limit=60,
            window_seconds=60,
        )

    async def _get_counter(
        self,
        limit_name: str,
        key: str,
    ) -> SlidingWindowCounter:
        """Get or lazily create the counter for a (limit, key) pair."""
        counter_key = f"{limit_name}:{key}"
        config = self._get_config(limit_name)

        async with self._lock:
            if counter_key not in self._counters:
                self._counters[counter_key] = SlidingWindowCounter(
                    limit=config.limit,
                    window_seconds=config.window_seconds,
                    burst_limit=config.burst_limit,
                )
            return self._counters[counter_key]
|
||||
17
backend/app/services/safety/loops/__init__.py
Normal file
17
backend/app/services/safety/loops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Loop Detection Module
|
||||
|
||||
Detects and prevents action loops in agent behavior.
|
||||
"""
|
||||
|
||||
from .detector import (
|
||||
ActionSignature,
|
||||
LoopBreaker,
|
||||
LoopDetector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionSignature",
|
||||
"LoopBreaker",
|
||||
"LoopDetector",
|
||||
]
|
||||
267
backend/app/services/safety/loops/detector.py
Normal file
267
backend/app/services/safety/loops/detector.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Loop Detector
|
||||
|
||||
Detects and prevents action loops in agent behavior.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter, deque
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import LoopDetectedError
|
||||
from ..models import ActionRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionSignature:
    """Compact, comparable fingerprint of an ActionRequest.

    Three granularities are exposed for loop detection: exact
    (type/tool/resource/args), semantic (type/tool/resource), and
    type-only.
    """

    def __init__(self, action: ActionRequest) -> None:
        self.action_type = action.action_type.value
        self.tool_name = action.tool_name
        self.resource = action.resource
        self.args_hash = self._hash_args(action.arguments)

    def _hash_args(self, args: dict[str, Any]) -> str:
        """Return a short, stable digest of *args* ("" if unserializable)."""
        try:
            canonical = json.dumps(args, sort_keys=True, default=str)
            digest = hashlib.sha256(canonical.encode()).hexdigest()
        except Exception:
            # Best effort: arguments that defeat JSON serialization simply
            # collapse to an empty hash rather than failing the check.
            return ""
        return digest[:8]

    def exact_key(self) -> str:
        """Key matching only the identical action with identical arguments."""
        parts = (self.action_type, self.tool_name, self.resource, self.args_hash)
        return "{}:{}:{}:{}".format(*parts)

    def semantic_key(self) -> str:
        """Key matching similar actions (same type/tool/resource, any args)."""
        return f"{self.action_type}:{self.tool_name}:{self.resource}"

    def type_key(self) -> str:
        """Key matching on action type alone."""
        return str(self.action_type)
|
||||
|
||||
|
||||
class LoopDetector:
    """
    Detects action loops and repetitive agent behavior.

    Three loop flavours are recognised:
    - "exact": the identical action (same arguments) repeated too often
    - "semantic": similar actions (same type/tool/resource) repeated too often
    - "oscillation": an A→B→A→B alternation between two action shapes
    """

    def __init__(
        self,
        history_size: int | None = None,
        max_exact_repetitions: int | None = None,
        max_semantic_repetitions: int | None = None,
    ) -> None:
        """
        Initialize the LoopDetector.

        Args:
            history_size: Size of action history to track
            max_exact_repetitions: Max allowed exact repetitions
            max_semantic_repetitions: Max allowed semantic repetitions
        """
        cfg = get_safety_config()

        # Any threshold left unspecified falls back to the safety config.
        self._history_size = history_size or cfg.loop_history_size
        self._max_exact = max_exact_repetitions or cfg.max_repeated_actions
        self._max_semantic = max_semantic_repetitions or cfg.max_similar_actions

        # Bounded per-agent action history, guarded by a single lock.
        self._histories: dict[str, deque[ActionSignature]] = {}
        self._lock = asyncio.Lock()

    async def check(self, action: ActionRequest) -> tuple[bool, str | None]:
        """
        Check whether executing *action* would constitute a loop.

        Args:
            action: The action to check

        Returns:
            Tuple of (is_loop, loop_type); loop_type is None when no loop
            was detected.
        """
        sig = ActionSignature(action)

        async with self._lock:
            history = self._get_history(action.metadata.agent_id)

            # Exact repetition: identical action with identical arguments.
            exact = sig.exact_key()
            if sum(h.exact_key() == exact for h in history) >= self._max_exact:
                return True, "exact"

            # Semantic repetition: same type/tool/resource, any arguments.
            semantic = sig.semantic_key()
            if sum(h.semantic_key() == semantic for h in history) >= self._max_semantic:
                return True, "semantic"

            # Oscillation: A→B→A→B over the recent window plus this action.
            if len(history) >= 3 and self._detect_oscillation(history, sig):
                return True, "oscillation"

        return False, None

    async def check_and_raise(self, action: ActionRequest) -> None:
        """
        Check for loops and raise if one is detected.

        Args:
            action: The action to check

        Raises:
            LoopDetectedError: If a loop is detected
        """
        is_loop, loop_type = await self.check(action)
        if not is_loop:
            return

        sig = ActionSignature(action)
        threshold = self._max_exact if loop_type == "exact" else self._max_semantic
        raise LoopDetectedError(
            f"Loop detected: {loop_type}",
            loop_type=loop_type or "unknown",
            repetition_count=threshold,
            action_pattern=[sig.semantic_key()],
            agent_id=action.metadata.agent_id,
            action_id=action.id,
        )

    async def record(self, action: ActionRequest) -> None:
        """
        Append *action* to its agent's history ring buffer.

        Args:
            action: The action to record
        """
        sig = ActionSignature(action)
        async with self._lock:
            self._get_history(action.metadata.agent_id).append(sig)

    async def clear_history(self, agent_id: str) -> None:
        """
        Drop all recorded actions for an agent (no-op if unknown).

        Args:
            agent_id: ID of the agent
        """
        async with self._lock:
            existing = self._histories.get(agent_id)
            if existing is not None:
                existing.clear()

    async def get_stats(self, agent_id: str) -> dict[str, Any]:
        """
        Return loop-detection statistics for an agent.

        Args:
            agent_id: ID of the agent

        Returns:
            Stats dictionary with history sizes and pattern frequencies.
        """
        async with self._lock:
            history = self._get_history(agent_id)
            by_type = Counter(h.type_key() for h in history)
            by_semantic = Counter(h.semantic_key() for h in history)
            return {
                "history_size": len(history),
                "max_history": self._history_size,
                "action_type_counts": dict(by_type),
                "top_semantic_patterns": by_semantic.most_common(5),
            }

    def _get_history(self, agent_id: str) -> deque[ActionSignature]:
        """Get or lazily create the bounded history for *agent_id*."""
        history = self._histories.get(agent_id)
        if history is None:
            history = deque(maxlen=self._history_size)
            self._histories[agent_id] = history
        return history

    def _detect_oscillation(
        self,
        history: deque[ActionSignature],
        current: ActionSignature,
    ) -> bool:
        """
        Detect an A→B→A→B oscillation over the last three recorded
        actions plus *current* (semantic keys, so arguments may differ).
        """
        if len(history) < 3:
            return False

        window = [*list(history)[-3:], current]
        a, b, c, d = (s.semantic_key() for s in window[-4:])
        # Alternating pattern between two *different* shapes.
        return a == c and b == d and a != b
|
||||
|
||||
|
||||
class LoopBreaker:
    """
    Strategies for breaking detected loops.
    """

    # Canned guidance per loop type; unrecognised types yield no suggestions.
    _GUIDANCE = {
        "exact": (
            "The same action with identical arguments has been repeated too many times. "
            "Consider: (1) Verify the action succeeded, (2) Try a different approach, "
            "(3) Escalate for human review"
        ),
        "semantic": (
            "Similar actions have been repeated too many times. "
            "Consider: (1) Review if the approach is working, (2) Try an alternative method, "
            "(3) Request clarification on the goal"
        ),
        "oscillation": (
            "An oscillating pattern was detected (A→B→A→B). "
            "This usually indicates conflicting goals or a stuck state. "
            "Consider: (1) Step back and reassess, (2) Request human guidance"
        ),
    }

    @staticmethod
    async def suggest_alternatives(
        action: ActionRequest,
        loop_type: str,
    ) -> list[str]:
        """
        Suggest alternative actions when a loop is detected.

        Args:
            action: The looping action (not consulted by the current
                strategy; kept for API compatibility)
            loop_type: Type of loop detected

        Returns:
            List of suggestions (empty for unrecognised loop types).
        """
        message = LoopBreaker._GUIDANCE.get(loop_type)
        return [message] if message is not None else []
|
||||
17
backend/app/services/safety/mcp/__init__.py
Normal file
17
backend/app/services/safety/mcp/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""MCP safety integration."""
|
||||
|
||||
from .integration import (
|
||||
MCPSafetyWrapper,
|
||||
MCPToolCall,
|
||||
MCPToolResult,
|
||||
SafeToolExecutor,
|
||||
create_mcp_wrapper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MCPSafetyWrapper",
|
||||
"MCPToolCall",
|
||||
"MCPToolResult",
|
||||
"SafeToolExecutor",
|
||||
"create_mcp_wrapper",
|
||||
]
|
||||
405
backend/app/services/safety/mcp/integration.py
Normal file
405
backend/app/services/safety/mcp/integration.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
MCP Safety Integration
|
||||
|
||||
Provides safety-aware wrappers for MCP tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar, TypeVar
|
||||
|
||||
from ..audit import AuditLogger
|
||||
from ..emergency import EmergencyControls, get_emergency_controls
|
||||
from ..exceptions import (
|
||||
EmergencyStopError,
|
||||
SafetyError,
|
||||
)
|
||||
from ..guardian import SafetyGuardian, get_safety_guardian
|
||||
from ..models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
class MCPToolCall:
    """Represents an MCP tool call.

    Plain data container handed to MCPSafetyWrapper.execute; the wrapper
    derives the safety ActionRequest and the emergency-stop scope from
    these fields.
    """

    tool_name: str  # name of the MCP tool to invoke (dispatch key for handlers)
    arguments: dict[str, Any]  # keyword arguments forwarded to the tool handler
    server_name: str | None = None  # originating MCP server, if known
    project_id: str | None = None  # when set, emergency checks use the project scope
    context: dict[str, Any] = field(default_factory=dict)  # extra caller context (e.g. session_id)
|
||||
|
||||
|
||||
@dataclass
class MCPToolResult:
    """Result of an MCP tool execution.

    Returned for every outcome: successful runs, handler failures,
    safety denials, and approval-required short-circuits.
    """

    success: bool  # True only when the registered handler ran without raising
    result: Any = None  # raw handler return value on success
    error: str | None = None  # human-readable failure/denial reason
    safety_decision: SafetyDecision = SafetyDecision.ALLOW  # guardian/emergency verdict
    execution_time_ms: float = 0.0  # wall-clock time since execute() started
    approval_id: str | None = None  # set when the call awaits human approval
    checkpoint_id: str | None = None  # rollback checkpoint created by the guardian, if any
    metadata: dict[str, Any] = field(default_factory=dict)  # extra flags (e.g. emergency_stop)
|
||||
|
||||
|
||||
class MCPSafetyWrapper:
    """
    Wraps MCP tool execution with safety checks.

    Features:
    - Pre-execution validation via SafetyGuardian
    - Permission checking per tool/resource
    - Budget and rate limit enforcement
    - Audit logging of all MCP calls
    - Emergency stop integration
    - Checkpoint creation for destructive operations
    """

    # Tool categories for automatic classification.
    # NOTE(review): neither set is referenced by any method in this class as
    # written — classification is done by substring matching in _classify_tool.
    # Presumably they are consumed elsewhere; confirm before removing.
    DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
        "file_write",
        "file_delete",
        "database_mutate",
        "shell_execute",
        "git_push",
        "git_commit",
        "deploy",
    }

    READ_ONLY_TOOLS: ClassVar[set[str]] = {
        "file_read",
        "database_query",
        "git_status",
        "git_log",
        "list_files",
        "search",
    }

    def __init__(
        self,
        guardian: SafetyGuardian | None = None,
        audit_logger: AuditLogger | None = None,
        emergency_controls: EmergencyControls | None = None,
    ) -> None:
        """
        Initialize MCPSafetyWrapper.

        Collaborators left as None are resolved lazily (see _get_guardian /
        _get_emergency_controls), so the wrapper can be constructed
        synchronously.

        Args:
            guardian: SafetyGuardian instance (uses singleton if not provided)
            audit_logger: AuditLogger instance
            emergency_controls: EmergencyControls instance
        """
        self._guardian = guardian
        self._audit_logger = audit_logger
        self._emergency_controls = emergency_controls
        # tool name -> callable registered via register_tool_handler
        self._tool_handlers: dict[str, Callable[..., Any]] = {}
        self._lock = asyncio.Lock()

    async def _get_guardian(self) -> SafetyGuardian:
        """Get or create SafetyGuardian (lazy singleton lookup).

        NOTE(review): not guarded by self._lock, so concurrent first calls
        may each fetch the singleton — harmless only if get_safety_guardian
        is idempotent; confirm.
        """
        if self._guardian is None:
            self._guardian = await get_safety_guardian()
        return self._guardian

    async def _get_emergency_controls(self) -> EmergencyControls:
        """Get or create EmergencyControls (lazy singleton lookup; same
        unlocked-initialization caveat as _get_guardian)."""
        if self._emergency_controls is None:
            self._emergency_controls = await get_emergency_controls()
        return self._emergency_controls

    def register_tool_handler(
        self,
        tool_name: str,
        handler: Callable[..., Any],
    ) -> None:
        """
        Register a handler for a tool.

        Args:
            tool_name: Name of the tool
            handler: Async function to handle the tool call (a plain sync
                callable is also accepted; _execute_tool dispatches on
                asyncio.iscoroutinefunction)
        """
        self._tool_handlers[tool_name] = handler
        logger.debug("Registered handler for tool: %s", tool_name)

    async def execute(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
        bypass_safety: bool = False,
    ) -> MCPToolResult:
        """
        Execute an MCP tool call with safety checks.

        Order of gates: emergency-stop check, then guardian validation,
        then actual execution. Each gate short-circuits into an
        MCPToolResult describing why the call did not run.

        Args:
            tool_call: The tool call to execute
            agent_id: ID of the calling agent
            autonomy_level: Agent's autonomy level
            bypass_safety: Bypass safety checks (emergency only; the
                emergency-stop gate above still applies)

        Returns:
            MCPToolResult with execution outcome
        """
        start_time = datetime.utcnow()

        # Check emergency controls first — even bypass_safety calls must
        # pass this gate.
        emergency = await self._get_emergency_controls()
        # Project scope (when a project_id is present) wins over agent scope.
        scope = f"agent:{agent_id}"
        if tool_call.project_id:
            scope = f"project:{tool_call.project_id}"

        try:
            await emergency.check_allowed(scope=scope, raise_if_blocked=True)
        except EmergencyStopError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                metadata={"emergency_stop": True},
            )

        # Build action request
        action = self._build_action_request(
            tool_call=tool_call,
            agent_id=agent_id,
            autonomy_level=autonomy_level,
        )

        # Skip safety checks if bypass is enabled
        if bypass_safety:
            logger.warning(
                "Safety bypass enabled for tool: %s (agent: %s)",
                tool_call.tool_name,
                agent_id,
            )
            return await self._execute_tool(tool_call, action, start_time)

        # Run safety validation; a SafetyError from the guardian is
        # reported as a denial rather than propagated to the caller.
        guardian = await self._get_guardian()
        try:
            guardian_result = await guardian.validate(action)
        except SafetyError as e:
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Handle safety decision
        if guardian_result.decision == SafetyDecision.DENY:
            return MCPToolResult(
                success=False,
                error="; ".join(guardian_result.reasons),
                safety_decision=SafetyDecision.DENY,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
            # For now, just return that approval is required
            # The caller should handle the approval flow
            return MCPToolResult(
                success=False,
                error="Action requires human approval",
                safety_decision=SafetyDecision.REQUIRE_APPROVAL,
                approval_id=guardian_result.approval_id,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        # Execute the tool, propagating any checkpoint the guardian
        # created so callers can roll back on failure.
        result = await self._execute_tool(
            tool_call,
            action,
            start_time,
            checkpoint_id=guardian_result.checkpoint_id,
        )

        return result

    async def _execute_tool(
        self,
        tool_call: MCPToolCall,
        action: ActionRequest,
        start_time: datetime,
        checkpoint_id: str | None = None,
    ) -> MCPToolResult:
        """Execute the actual tool call.

        Dispatches to the registered handler (sync or async). Handler
        exceptions are captured into a failed MCPToolResult instead of
        being propagated; safety_decision stays ALLOW because the call
        was already cleared before reaching this point.
        """
        handler = self._tool_handlers.get(tool_call.tool_name)

        if handler is None:
            return MCPToolResult(
                success=False,
                error=f"No handler registered for tool: {tool_call.tool_name}",
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
            )

        try:
            if asyncio.iscoroutinefunction(handler):
                result = await handler(**tool_call.arguments)
            else:
                result = handler(**tool_call.arguments)

            return MCPToolResult(
                success=True,
                result=result,
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )

        except Exception as e:
            logger.error("Tool execution failed: %s - %s", tool_call.tool_name, e)
            return MCPToolResult(
                success=False,
                error=str(e),
                safety_decision=SafetyDecision.ALLOW,
                execution_time_ms=self._elapsed_ms(start_time),
                checkpoint_id=checkpoint_id,
            )

    def _build_action_request(
        self,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel,
    ) -> ActionRequest:
        """Build an ActionRequest from an MCP tool call.

        The resource is guessed from the "path" argument, falling back to
        "resource" (None when neither is present).
        """
        action_type = self._classify_tool(tool_call.tool_name)

        metadata = ActionMetadata(
            agent_id=agent_id,
            session_id=tool_call.context.get("session_id", ""),
            project_id=tool_call.project_id or "",
            autonomy_level=autonomy_level,
        )

        return ActionRequest(
            action_type=action_type,
            tool_name=tool_call.tool_name,
            arguments=tool_call.arguments,
            resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
            metadata=metadata,
        )

    def _classify_tool(self, tool_name: str) -> ActionType:
        """Classify a tool into an action type.

        Heuristic substring matching on the lowercased tool name. Check
        order matters: destructive patterns win over read patterns, which
        win over the specific fallbacks below. Note the matches are loose
        (e.g. "db" also matches names merely containing those letters).
        """
        tool_lower = tool_name.lower()

        # Check destructive patterns
        if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
            if "file" in tool_lower:
                if "delete" in tool_lower or "remove" in tool_lower:
                    return ActionType.FILE_DELETE
                return ActionType.FILE_WRITE
            if "database" in tool_lower or "db" in tool_lower:
                return ActionType.DATABASE_MUTATE

        # Check read patterns
        if any(r in tool_lower for r in ["read", "get", "list", "search", "query"]):
            if "file" in tool_lower:
                return ActionType.FILE_READ
            if "database" in tool_lower or "db" in tool_lower:
                return ActionType.DATABASE_QUERY

        # Check specific types
        if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
            return ActionType.SHELL_COMMAND

        if "git" in tool_lower:
            return ActionType.GIT_OPERATION

        if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
            return ActionType.NETWORK_REQUEST

        if "llm" in tool_lower or "ai" in tool_lower or "claude" in tool_lower:
            return ActionType.LLM_CALL

        # Default to tool call
        return ActionType.TOOL_CALL

    def _elapsed_ms(self, start_time: datetime) -> float:
        """Elapsed wall-clock time in milliseconds since *start_time*.

        Uses naive utcnow() to match the timestamp taken in execute().
        """
        return (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
|
||||
class SafeToolExecutor:
    """
    Async context manager for safe tool execution with automatic cleanup.

    Usage:
        async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
            result = await executor.execute()
            if result.success:
                # Use result
            else:
                # Handle error or approval required
    """

    def __init__(
        self,
        wrapper: MCPSafetyWrapper,
        tool_call: MCPToolCall,
        agent_id: str,
        autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
    ) -> None:
        self._wrapper = wrapper
        self._tool_call = tool_call
        self._agent_id = agent_id
        self._autonomy_level = autonomy_level
        self._result: MCPToolResult | None = None

    async def __aenter__(self) -> "SafeToolExecutor":
        """Enter the context; no setup is currently required."""
        return self

    async def __aexit__(
        self,
        exc_type: type[Exception] | None,
        exc_val: Exception | None,
        exc_tb: Any,
    ) -> bool:
        """Exit the context without suppressing exceptions.

        Placeholder for future rollback-on-error behaviour.
        """
        return False

    async def execute(self) -> MCPToolResult:
        """Run the wrapped tool call through the safety wrapper and cache
        the outcome (also available via the ``result`` property)."""
        outcome = await self._wrapper.execute(
            self._tool_call,
            self._agent_id,
            self._autonomy_level,
        )
        self._result = outcome
        return outcome

    @property
    def result(self) -> MCPToolResult | None:
        """Most recent execution result (None before execute() is called)."""
        return self._result
|
||||
|
||||
|
||||
# Factory function
async def create_mcp_wrapper(
    guardian: SafetyGuardian | None = None,
) -> MCPSafetyWrapper:
    """Create an MCPSafetyWrapper with default configuration.

    Args:
        guardian: Optional pre-built SafetyGuardian; the shared singleton
            is fetched when omitted.

    Returns:
        A wrapper wired to the global emergency controls.
    """
    resolved = guardian if guardian is not None else await get_safety_guardian()
    controls = await get_emergency_controls()
    return MCPSafetyWrapper(guardian=resolved, emergency_controls=controls)
|
||||
19
backend/app/services/safety/metrics/__init__.py
Normal file
19
backend/app/services/safety/metrics/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Safety metrics collection and export."""
|
||||
|
||||
from .collector import (
|
||||
MetricType,
|
||||
MetricValue,
|
||||
SafetyMetrics,
|
||||
get_safety_metrics,
|
||||
record_mcp_call,
|
||||
record_validation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MetricType",
|
||||
"MetricValue",
|
||||
"SafetyMetrics",
|
||||
"get_safety_metrics",
|
||||
"record_mcp_call",
|
||||
"record_validation",
|
||||
]
|
||||
416
backend/app/services/safety/metrics/collector.py
Normal file
416
backend/app/services/safety/metrics/collector.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Safety Metrics Collector
|
||||
|
||||
Collects and exposes metrics for the safety framework.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(str, Enum):
    """Types of metrics (Prometheus-style)."""

    COUNTER = "counter"      # monotonically increasing count
    GAUGE = "gauge"          # point-in-time value that can move up or down
    HISTOGRAM = "histogram"  # distribution recorded in cumulative buckets
|
||||
|
||||
|
||||
@dataclass
class MetricValue:
    """A single metric value (one exported sample)."""

    name: str
    metric_type: MetricType
    value: float
    # Label key/value pairs attached to this sample.
    labels: dict[str, str] = field(default_factory=dict)
    # NOTE(review): datetime.utcnow() yields a naive timestamp and is
    # deprecated in Python 3.12+ — confirm consumers before switching to
    # timezone-aware datetimes.
    timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@dataclass
class HistogramBucket:
    """Histogram bucket for distribution metrics (cumulative, Prometheus-style)."""

    le: float  # Less than or equal (upper bound of the bucket)
    count: int = 0  # cumulative count of observations <= le
|
||||
|
||||
|
||||
class SafetyMetrics:
    """
    Collects safety framework metrics.

    Metrics tracked:
    - Action validation counts (by decision type)
    - Approval request counts and latencies
    - Budget usage and remaining
    - Rate limit hits
    - Loop detections
    - Emergency events
    - Content filter matches

    All public methods are coroutine-safe: access is serialized through a
    single asyncio.Lock. Labels are stored internally as flat
    ``key=value,key=value`` strings and are only rendered into Prometheus
    syntax (quoted values) on export.
    """

    def __init__(self) -> None:
        """Initialize SafetyMetrics."""
        self._counters: dict[str, Counter[str]] = defaultdict(Counter)
        self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
        self._histograms: dict[str, list[float]] = defaultdict(list)
        self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
        self._lock = asyncio.Lock()

        # Initialize histogram buckets
        self._init_histogram_buckets()

    def _init_histogram_buckets(self) -> None:
        """Initialize cumulative histogram buckets for the latency metrics."""
        latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]

        for name in [
            "validation_latency_seconds",
            "approval_latency_seconds",
            "mcp_execution_latency_seconds",
        ]:
            self._histogram_buckets[name] = [
                HistogramBucket(le=b) for b in latency_buckets
            ]

    # Counter methods

    async def inc_validations(
        self,
        decision: str,
        agent_id: str | None = None,
    ) -> None:
        """Increment validation counter (labelled by decision, optionally agent)."""
        async with self._lock:
            labels = f"decision={decision}"
            if agent_id:
                labels += f",agent_id={agent_id}"
            self._counters["safety_validations_total"][labels] += 1

    async def inc_approvals_requested(self, urgency: str = "normal") -> None:
        """Increment approval requests counter."""
        async with self._lock:
            labels = f"urgency={urgency}"
            self._counters["safety_approvals_requested_total"][labels] += 1

    async def inc_approvals_granted(self) -> None:
        """Increment approvals granted counter."""
        async with self._lock:
            self._counters["safety_approvals_granted_total"][""] += 1

    async def inc_approvals_denied(self, reason: str = "manual") -> None:
        """Increment approvals denied counter."""
        async with self._lock:
            labels = f"reason={reason}"
            self._counters["safety_approvals_denied_total"][labels] += 1

    async def inc_rate_limit_exceeded(self, limit_type: str) -> None:
        """Increment rate limit exceeded counter."""
        async with self._lock:
            labels = f"limit_type={limit_type}"
            self._counters["safety_rate_limit_exceeded_total"][labels] += 1

    async def inc_budget_exceeded(self, budget_type: str) -> None:
        """Increment budget exceeded counter."""
        async with self._lock:
            labels = f"budget_type={budget_type}"
            self._counters["safety_budget_exceeded_total"][labels] += 1

    async def inc_loops_detected(self, loop_type: str) -> None:
        """Increment loop detection counter."""
        async with self._lock:
            labels = f"loop_type={loop_type}"
            self._counters["safety_loops_detected_total"][labels] += 1

    async def inc_emergency_events(self, event_type: str, scope: str) -> None:
        """Increment emergency events counter."""
        async with self._lock:
            labels = f"event_type={event_type},scope={scope}"
            self._counters["safety_emergency_events_total"][labels] += 1

    async def inc_content_filtered(self, category: str, action: str) -> None:
        """Increment content filter counter."""
        async with self._lock:
            labels = f"category={category},action={action}"
            self._counters["safety_content_filtered_total"][labels] += 1

    async def inc_checkpoints_created(self) -> None:
        """Increment checkpoints created counter."""
        async with self._lock:
            self._counters["safety_checkpoints_created_total"][""] += 1

    async def inc_rollbacks_executed(self, success: bool) -> None:
        """Increment rollbacks counter."""
        async with self._lock:
            labels = f"success={str(success).lower()}"
            self._counters["safety_rollbacks_total"][labels] += 1

    async def inc_mcp_calls(self, tool_name: str, success: bool) -> None:
        """Increment MCP tool calls counter."""
        async with self._lock:
            labels = f"tool_name={tool_name},success={str(success).lower()}"
            self._counters["safety_mcp_calls_total"][labels] += 1

    # Gauge methods

    async def set_budget_remaining(
        self,
        scope: str,
        budget_type: str,
        remaining: float,
    ) -> None:
        """Set remaining budget gauge."""
        async with self._lock:
            labels = f"scope={scope},budget_type={budget_type}"
            self._gauges["safety_budget_remaining"][labels] = remaining

    async def set_rate_limit_remaining(
        self,
        scope: str,
        limit_type: str,
        remaining: int,
    ) -> None:
        """Set remaining rate limit gauge."""
        async with self._lock:
            labels = f"scope={scope},limit_type={limit_type}"
            self._gauges["safety_rate_limit_remaining"][labels] = float(remaining)

    async def set_pending_approvals(self, count: int) -> None:
        """Set pending approvals gauge."""
        async with self._lock:
            self._gauges["safety_pending_approvals"][""] = float(count)

    async def set_active_checkpoints(self, count: int) -> None:
        """Set active checkpoints gauge."""
        async with self._lock:
            self._gauges["safety_active_checkpoints"][""] = float(count)

    async def set_emergency_state(self, scope: str, state: str) -> None:
        """Set emergency state gauge (0=normal, 1=paused, 2=stopped, -1=unknown)."""
        async with self._lock:
            state_value = {"normal": 0, "paused": 1, "stopped": 2}.get(state, -1)
            labels = f"scope={scope}"
            self._gauges["safety_emergency_state"][labels] = float(state_value)

    # Histogram methods

    async def observe_validation_latency(self, latency_seconds: float) -> None:
        """Observe validation latency."""
        async with self._lock:
            self._observe_histogram("validation_latency_seconds", latency_seconds)

    async def observe_approval_latency(self, latency_seconds: float) -> None:
        """Observe approval latency."""
        async with self._lock:
            self._observe_histogram("approval_latency_seconds", latency_seconds)

    async def observe_mcp_execution_latency(self, latency_seconds: float) -> None:
        """Observe MCP execution latency."""
        async with self._lock:
            self._observe_histogram("mcp_execution_latency_seconds", latency_seconds)

    def _observe_histogram(self, name: str, value: float) -> None:
        """Record a value in a histogram. Caller must hold the lock."""
        self._histograms[name].append(value)

        # Update cumulative buckets: every bucket with le >= value counts it.
        if name in self._histogram_buckets:
            for bucket in self._histogram_buckets[name]:
                if value <= bucket.le:
                    bucket.count += 1

    # Export methods

    async def get_all_metrics(self) -> list[MetricValue]:
        """Get all metrics as MetricValue objects."""
        metrics: list[MetricValue] = []

        async with self._lock:
            # Export counters
            for name, counter in self._counters.items():
                for labels_str, value in counter.items():
                    labels = self._parse_labels(labels_str)
                    metrics.append(
                        MetricValue(
                            name=name,
                            metric_type=MetricType.COUNTER,
                            value=float(value),
                            labels=labels,
                        )
                    )

            # Export gauges
            for name, gauge_dict in self._gauges.items():
                for labels_str, gauge_value in gauge_dict.items():
                    gauge_labels = self._parse_labels(labels_str)
                    metrics.append(
                        MetricValue(
                            name=name,
                            metric_type=MetricType.GAUGE,
                            value=gauge_value,
                            labels=gauge_labels,
                        )
                    )

            # Export histogram summaries (count and sum only)
            for name, values in self._histograms.items():
                if values:
                    metrics.append(
                        MetricValue(
                            name=f"{name}_count",
                            metric_type=MetricType.COUNTER,
                            value=float(len(values)),
                        )
                    )
                    metrics.append(
                        MetricValue(
                            name=f"{name}_sum",
                            metric_type=MetricType.COUNTER,
                            value=sum(values),
                        )
                    )

        return metrics

    def _format_prometheus_labels(self, labels_str: str) -> str:
        """Render an internal ``k=v,k=v`` label string as a Prometheus label block.

        The Prometheus text exposition format requires label values to be
        double-quoted (``name{decision="deny"}``); the internal unquoted form
        is not valid output. Returns "" for an empty label string so callers
        can append the result unconditionally.
        """
        if not labels_str:
            return ""
        pairs = self._parse_labels(labels_str)
        rendered = ",".join(f'{key}="{value}"' for key, value in pairs.items())
        return "{" + rendered + "}"

    async def get_prometheus_format(self) -> str:
        """Export metrics in Prometheus text exposition format."""
        lines: list[str] = []

        async with self._lock:
            # Export counters
            for name, counter in self._counters.items():
                lines.append(f"# TYPE {name} counter")
                for labels_str, value in counter.items():
                    label_block = self._format_prometheus_labels(labels_str)
                    lines.append(f"{name}{label_block} {value}")

            # Export gauges
            for name, gauge_dict in self._gauges.items():
                lines.append(f"# TYPE {name} gauge")
                for labels_str, gauge_value in gauge_dict.items():
                    label_block = self._format_prometheus_labels(labels_str)
                    lines.append(f"{name}{label_block} {gauge_value}")

            # Export histograms (cumulative buckets plus _count/_sum series)
            for name, buckets in self._histogram_buckets.items():
                lines.append(f"# TYPE {name} histogram")
                for bucket in buckets:
                    le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
                    lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')

                if name in self._histograms:
                    values = self._histograms[name]
                    lines.append(f"{name}_count {len(values)}")
                    lines.append(f"{name}_sum {sum(values)}")

        return "\n".join(lines)

    async def get_summary(self) -> dict[str, Any]:
        """Get a summary of key metrics."""
        async with self._lock:
            total_validations = sum(self._counters["safety_validations_total"].values())
            # Parse the label string rather than substring-matching so that
            # "deny" does not also match decisions like "deny_with_reason".
            denied_validations = sum(
                v for k, v in self._counters["safety_validations_total"].items()
                if self._parse_labels(k).get("decision") == "deny"
            )

            return {
                "total_validations": total_validations,
                "denied_validations": denied_validations,
                "approval_requests": sum(
                    self._counters["safety_approvals_requested_total"].values()
                ),
                "approvals_granted": sum(
                    self._counters["safety_approvals_granted_total"].values()
                ),
                "approvals_denied": sum(
                    self._counters["safety_approvals_denied_total"].values()
                ),
                "rate_limit_hits": sum(
                    self._counters["safety_rate_limit_exceeded_total"].values()
                ),
                "budget_exceeded": sum(
                    self._counters["safety_budget_exceeded_total"].values()
                ),
                "loops_detected": sum(
                    self._counters["safety_loops_detected_total"].values()
                ),
                "emergency_events": sum(
                    self._counters["safety_emergency_events_total"].values()
                ),
                "content_filtered": sum(
                    self._counters["safety_content_filtered_total"].values()
                ),
                "checkpoints_created": sum(
                    self._counters["safety_checkpoints_created_total"].values()
                ),
                "rollbacks_executed": sum(
                    self._counters["safety_rollbacks_total"].values()
                ),
                "mcp_calls": sum(
                    self._counters["safety_mcp_calls_total"].values()
                ),
                "pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
                "active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
            }

    async def reset(self) -> None:
        """Reset all metrics."""
        async with self._lock:
            self._counters.clear()
            self._gauges.clear()
            self._histograms.clear()
            self._init_histogram_buckets()

    def _parse_labels(self, labels_str: str) -> dict[str, str]:
        """Parse an internal ``k=v,k=v`` label string into a dictionary.

        NOTE: label values containing "," or "=" are not escaped by the
        writers above, so such values would be split incorrectly here.
        """
        if not labels_str:
            return {}

        labels = {}
        for pair in labels_str.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                labels[key.strip()] = value.strip()

        return labels
|
||||
|
||||
|
||||
# Singleton instance
_metrics: SafetyMetrics | None = None
_lock = asyncio.Lock()


async def get_safety_metrics() -> SafetyMetrics:
    """Return the process-wide SafetyMetrics singleton, creating it lazily.

    The module-level lock guards first-time construction against
    concurrent callers.
    """
    global _metrics

    async with _lock:
        if _metrics is None:
            _metrics = SafetyMetrics()
    return _metrics
|
||||
|
||||
|
||||
# Convenience functions
async def record_validation(decision: str, agent_id: str | None = None) -> None:
    """Record a validation event on the shared metrics singleton."""
    shared = await get_safety_metrics()
    await shared.inc_validations(decision, agent_id)
|
||||
|
||||
|
||||
async def record_mcp_call(tool_name: str, success: bool, latency_ms: float) -> None:
    """Record an MCP tool call outcome plus its latency on the shared singleton.

    The latency is supplied in milliseconds and converted to seconds for the
    execution-latency histogram.
    """
    latency_seconds = latency_ms / 1000
    shared = await get_safety_metrics()
    await shared.inc_mcp_calls(tool_name, success)
    await shared.observe_mcp_execution_latency(latency_seconds)
|
||||
474
backend/app/services/safety/models.py
Normal file
474
backend/app/services/safety/models.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
Safety Framework Models
|
||||
|
||||
Core Pydantic models for actions, events, policies, and safety decisions.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ============================================================================
|
||||
# Enums
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionType(str, Enum):
    """Types of actions that can be performed.

    Classifies what an agent is about to do so validation rules
    (``ValidationRule.action_types``) can target specific categories.
    """

    TOOL_CALL = "tool_call"
    FILE_READ = "file_read"
    FILE_WRITE = "file_write"
    FILE_DELETE = "file_delete"
    API_CALL = "api_call"
    DATABASE_QUERY = "database_query"
    DATABASE_MUTATE = "database_mutate"
    GIT_OPERATION = "git_operation"
    SHELL_COMMAND = "shell_command"
    LLM_CALL = "llm_call"
    NETWORK_REQUEST = "network_request"
    CUSTOM = "custom"  # escape hatch for actions not modeled above
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
    """Types of resources that can be accessed by an action."""

    FILE = "file"
    DATABASE = "database"
    API = "api"
    NETWORK = "network"
    GIT = "git"
    SHELL = "shell"
    LLM = "llm"
    MEMORY = "memory"
    CUSTOM = "custom"  # escape hatch for resources not modeled above
|
||||
|
||||
|
||||
class PermissionLevel(str, Enum):
    """Permission levels for resource access, roughly least- to most-privileged."""

    NONE = "none"
    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    DELETE = "delete"
    ADMIN = "admin"
|
||||
|
||||
|
||||
class AutonomyLevel(str, Enum):
    """Autonomy levels for agent operation (how often a human is in the loop)."""

    FULL_CONTROL = "full_control"  # Approve every action
    MILESTONE = "milestone"  # Approve at milestones
    AUTONOMOUS = "autonomous"  # Only major decisions
|
||||
|
||||
|
||||
class SafetyDecision(str, Enum):
    """Result of safety validation."""

    ALLOW = "allow"                        # execute immediately
    DENY = "deny"                          # block outright
    REQUIRE_APPROVAL = "require_approval"  # route through human approval
    DELAY = "delay"                        # retry later (e.g. rate limited)
    SANDBOX = "sandbox"                    # run in isolated execution
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
    """Status of an approval request through its lifecycle."""

    PENDING = "pending"
    APPROVED = "approved"
    DENIED = "denied"
    TIMEOUT = "timeout"      # no decision before the request expired
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AuditEventType(str, Enum):
    """Types of audit events recorded by the safety framework."""

    # Action lifecycle
    ACTION_REQUESTED = "action_requested"
    ACTION_VALIDATED = "action_validated"
    ACTION_DENIED = "action_denied"
    ACTION_EXECUTED = "action_executed"
    ACTION_FAILED = "action_failed"
    # Human-in-the-loop approvals
    APPROVAL_REQUESTED = "approval_requested"
    APPROVAL_GRANTED = "approval_granted"
    APPROVAL_DENIED = "approval_denied"
    APPROVAL_TIMEOUT = "approval_timeout"
    # Checkpointing and rollback
    CHECKPOINT_CREATED = "checkpoint_created"
    ROLLBACK_STARTED = "rollback_started"
    ROLLBACK_COMPLETED = "rollback_completed"
    ROLLBACK_FAILED = "rollback_failed"
    # Limits and guardrail triggers
    BUDGET_WARNING = "budget_warning"
    BUDGET_EXCEEDED = "budget_exceeded"
    RATE_LIMITED = "rate_limited"
    LOOP_DETECTED = "loop_detected"
    EMERGENCY_STOP = "emergency_stop"
    POLICY_VIOLATION = "policy_violation"
    CONTENT_FILTERED = "content_filtered"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionMetadata(BaseModel):
    """Metadata associated with an action.

    Identifies who/what initiated the action and the tracing context it
    belongs to; every field except ``agent_id`` is optional.
    """

    agent_id: str = Field(..., description="ID of the agent performing the action")
    project_id: str | None = Field(None, description="ID of the project context")
    session_id: str | None = Field(None, description="ID of the current session")
    task_id: str | None = Field(None, description="ID of the current task")
    parent_action_id: str | None = Field(None, description="ID of the parent action")
    correlation_id: str | None = Field(None, description="Correlation ID for tracing")
    user_id: str | None = Field(None, description="ID of the user who initiated")
    autonomy_level: AutonomyLevel = Field(
        default=AutonomyLevel.MILESTONE,
        description="Current autonomy level",
    )
    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional context",
    )
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
    """Request to perform an action; the unit the safety guardian validates."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    action_type: ActionType = Field(..., description="Type of action to perform")
    tool_name: str | None = Field(None, description="Name of the tool to call")
    resource: str | None = Field(None, description="Resource being accessed")
    resource_type: ResourceType | None = Field(None, description="Type of resource")
    arguments: dict[str, Any] = Field(
        default_factory=dict,
        description="Action arguments",
    )
    metadata: ActionMetadata = Field(..., description="Action metadata")
    estimated_cost_tokens: int = Field(0, description="Estimated token cost")
    estimated_cost_usd: float = Field(0.0, description="Estimated USD cost")
    is_destructive: bool = Field(False, description="Whether action is destructive")
    is_reversible: bool = Field(True, description="Whether action can be rolled back")
    # NOTE(review): naive UTC via datetime.utcnow (deprecated in 3.12+) —
    # confirm before migrating to timezone-aware timestamps.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class ActionResult(BaseModel):
    """Result of an executed action, including actual costs vs. the estimates
    carried on the originating ``ActionRequest``."""

    action_id: str = Field(..., description="ID of the action")
    success: bool = Field(..., description="Whether action succeeded")
    data: Any = Field(None, description="Action result data")
    error: str | None = Field(None, description="Error message if failed")
    error_code: str | None = Field(None, description="Error code if failed")
    execution_time_ms: float = Field(0.0, description="Execution time in ms")
    actual_cost_tokens: int = Field(0, description="Actual token cost")
    actual_cost_usd: float = Field(0.0, description="Actual USD cost")
    checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Validation Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ValidationRule(BaseModel):
    """A single validation rule.

    A rule matches when its condition fields apply to an action; a ``None``
    condition places no constraint. On match, ``decision`` is the outcome.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str = Field(..., description="Rule name")
    description: str | None = Field(None, description="Rule description")
    priority: int = Field(0, description="Rule priority (higher = evaluated first)")
    enabled: bool = Field(True, description="Whether rule is enabled")

    # Rule conditions (None = no constraint on that dimension)
    action_types: list[ActionType] | None = Field(
        None, description="Action types this rule applies to"
    )
    tool_patterns: list[str] | None = Field(
        None, description="Tool name patterns (supports wildcards)"
    )
    resource_patterns: list[str] | None = Field(
        None, description="Resource patterns (supports wildcards)"
    )
    agent_ids: list[str] | None = Field(
        None, description="Agent IDs this rule applies to"
    )

    # Rule decision
    decision: SafetyDecision = Field(..., description="Decision when rule matches")
    reason: str | None = Field(None, description="Reason for decision")
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
    """Result of action validation, including which rules fired and why."""

    action_id: str = Field(..., description="ID of the validated action")
    decision: SafetyDecision = Field(..., description="Validation decision")
    applied_rules: list[str] = Field(
        default_factory=list, description="IDs of applied rules"
    )
    reasons: list[str] = Field(
        default_factory=list, description="Reasons for decision"
    )
    # Populated when decision == REQUIRE_APPROVAL.
    approval_id: str | None = Field(None, description="Approval request ID if needed")
    # Populated when decision == DELAY (rate limiting).
    retry_after_seconds: float | None = Field(
        None, description="Retry delay if rate limited"
    )
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Budget Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BudgetScope(str, Enum):
    """Scope of a budget limit (time window or ownership boundary)."""

    SESSION = "session"
    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
    PROJECT = "project"
    AGENT = "agent"
|
||||
|
||||
|
||||
class BudgetStatus(BaseModel):
    """Current budget status for one (scope, scope_id) pair, tracking both
    token and USD spend against their limits."""

    scope: BudgetScope = Field(..., description="Budget scope")
    scope_id: str = Field(..., description="ID within scope (session/agent/project)")
    tokens_used: int = Field(0, description="Tokens used in this scope")
    tokens_limit: int = Field(100000, description="Token limit for this scope")
    cost_used_usd: float = Field(0.0, description="USD spent in this scope")
    cost_limit_usd: float = Field(10.0, description="USD limit for this scope")
    tokens_remaining: int = Field(0, description="Remaining tokens")
    cost_remaining_usd: float = Field(0.0, description="Remaining USD budget")
    warning_threshold: float = Field(0.8, description="Warn at this usage fraction")
    is_warning: bool = Field(False, description="Whether at warning level")
    is_exceeded: bool = Field(False, description="Whether budget exceeded")
    # None for scopes that do not reset on a schedule — confirm with producer.
    reset_at: datetime | None = Field(None, description="When budget resets")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Rate Limit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RateLimitConfig(BaseModel):
    """Configuration for a rate limit (``limit`` events per ``window_seconds``)."""

    name: str = Field(..., description="Rate limit name")
    limit: int = Field(..., description="Maximum allowed in window")
    window_seconds: int = Field(60, description="Time window in seconds")
    burst_limit: int | None = Field(None, description="Burst allowance")
    slowdown_threshold: float = Field(
        0.8, description="Start slowing at this fraction"
    )
|
||||
|
||||
|
||||
class RateLimitStatus(BaseModel):
    """Current rate limit status (a point-in-time view of one window)."""

    name: str = Field(..., description="Rate limit name")
    current_count: int = Field(0, description="Current count in window")
    limit: int = Field(..., description="Maximum allowed")
    window_seconds: int = Field(..., description="Time window")
    remaining: int = Field(..., description="Remaining in window")
    reset_at: datetime = Field(..., description="When window resets")
    is_limited: bool = Field(False, description="Whether currently limited")
    retry_after_seconds: float = Field(0.0, description="Seconds until retry")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Approval Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
    """Request for human approval of a pending action."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    action: ActionRequest = Field(..., description="Action requiring approval")
    reason: str = Field(..., description="Why approval is required")
    urgency: str = Field("normal", description="Urgency level")
    timeout_seconds: int = Field(300, description="Timeout for approval")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    expires_at: datetime | None = Field(None, description="When request expires")
    suggested_action: str | None = Field(None, description="Suggested response")
    context: dict[str, Any] = Field(default_factory=dict, description="Extra context")
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
    """Response to an approval request."""

    request_id: str = Field(..., description="ID of the approval request")
    status: ApprovalStatus = Field(..., description="Approval status")
    decided_by: str | None = Field(None, description="Who made the decision")
    reason: str | None = Field(None, description="Reason for decision")
    # Optional edits the approver made to the action before allowing it.
    modifications: dict[str, Any] | None = Field(
        None, description="Modifications to action"
    )
    decided_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Checkpoint/Rollback Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CheckpointType(str, Enum):
    """Types of checkpoints."""

    FILE = "file"
    DATABASE = "database"
    GIT = "git"
    COMPOSITE = "composite"  # bundles multiple checkpoint kinds together
|
||||
|
||||
|
||||
class Checkpoint(BaseModel):
    """A rollback checkpoint captured before executing an action."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    checkpoint_type: CheckpointType = Field(..., description="Type of checkpoint")
    action_id: str = Field(..., description="Action this checkpoint is for")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    expires_at: datetime | None = Field(None, description="When checkpoint expires")
    # Type-specific restore payload; schema depends on checkpoint_type.
    data: dict[str, Any] = Field(default_factory=dict, description="Checkpoint data")
    description: str | None = Field(None, description="Description of checkpoint")
    is_valid: bool = Field(True, description="Whether checkpoint is still valid")
|
||||
|
||||
|
||||
class RollbackResult(BaseModel):
    """Result of a rollback operation against a checkpoint."""

    checkpoint_id: str = Field(..., description="ID of checkpoint rolled back to")
    success: bool = Field(..., description="Whether rollback succeeded")
    actions_rolled_back: list[str] = Field(
        default_factory=list, description="IDs of rolled back actions"
    )
    failed_actions: list[str] = Field(
        default_factory=list, description="IDs of actions that failed to rollback"
    )
    error: str | None = Field(None, description="Error message if failed")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuditEvent(BaseModel):
    """An audit log event; all context fields are optional so any stage of the
    pipeline can emit one."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    event_type: AuditEventType = Field(..., description="Type of audit event")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    agent_id: str | None = Field(None, description="Agent ID if applicable")
    action_id: str | None = Field(None, description="Action ID if applicable")
    project_id: str | None = Field(None, description="Project ID if applicable")
    session_id: str | None = Field(None, description="Session ID if applicable")
    user_id: str | None = Field(None, description="User ID if applicable")
    decision: SafetyDecision | None = Field(None, description="Safety decision")
    details: dict[str, Any] = Field(default_factory=dict, description="Event details")
    correlation_id: str | None = Field(None, description="Correlation ID for tracing")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Policy Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SafetyPolicy(BaseModel):
    """A complete safety policy configuration.

    Bundles every tunable guardrail (cost ceilings, rate limits, tool and
    file permissions, human-in-the-loop triggers, loop detection, sandbox
    settings and custom validation rules) into one declarative object.
    All limits have permissive-but-bounded defaults.
    """

    name: str = Field(..., description="Policy name")
    description: str | None = Field(None, description="Policy description")
    version: str = Field("1.0.0", description="Policy version")
    enabled: bool = Field(True, description="Whether policy is enabled")

    # Cost controls — token and USD budgets per session and per day.
    max_tokens_per_session: int = Field(100_000, description="Max tokens per session")
    max_tokens_per_day: int = Field(1_000_000, description="Max tokens per day")
    max_cost_per_session_usd: float = Field(10.0, description="Max USD per session")
    max_cost_per_day_usd: float = Field(100.0, description="Max USD per day")

    # Rate limits — per-minute caps on action volume.
    max_actions_per_minute: int = Field(60, description="Max actions per minute")
    max_llm_calls_per_minute: int = Field(20, description="Max LLM calls per minute")
    max_file_operations_per_minute: int = Field(
        100, description="Max file ops per minute"
    )

    # Permissions — glob-style allow/deny patterns; deny lists are expected
    # to take precedence over the wildcard allow defaults.
    allowed_tools: list[str] = Field(
        default_factory=lambda: ["*"],
        description="Allowed tool patterns",
    )
    denied_tools: list[str] = Field(
        default_factory=list,
        description="Denied tool patterns",
    )
    allowed_file_patterns: list[str] = Field(
        default_factory=lambda: ["**/*"],
        description="Allowed file patterns",
    )
    # Secrets-bearing paths are denied out of the box.
    denied_file_patterns: list[str] = Field(
        default_factory=lambda: ["**/.env", "**/secrets/**"],
        description="Denied file patterns",
    )

    # HITL — action names that always require human approval.
    require_approval_for: list[str] = Field(
        default_factory=lambda: [
            "delete_file",
            "push_to_remote",
            "deploy_to_production",
            "modify_critical_config",
        ],
        description="Actions requiring approval",
    )

    # Loop detection — thresholds before an agent is considered stuck.
    max_repeated_actions: int = Field(5, description="Max exact repetitions")
    max_similar_actions: int = Field(10, description="Max similar actions")

    # Sandbox — execution isolation settings.
    require_sandbox: bool = Field(False, description="Require sandbox execution")
    sandbox_timeout_seconds: int = Field(300, description="Sandbox timeout")
    sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit")

    # Validation rules — additional custom rules evaluated by the validator.
    validation_rules: list[ValidationRule] = Field(
        default_factory=list,
        description="Custom validation rules",
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Guardian Result Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class GuardianResult(BaseModel):
    """Result of SafetyGuardian evaluation.

    The verdict returned for a single action request: whether it may run,
    why, and any artifacts produced during evaluation (pending approval,
    pre-action checkpoint, audit events, or a modified version of the
    action to execute instead).
    """

    action_id: str = Field(..., description="ID of the action")
    allowed: bool = Field(..., description="Whether action is allowed")
    decision: SafetyDecision = Field(..., description="Safety decision")
    # Human-readable explanations supporting the decision.
    reasons: list[str] = Field(default_factory=list, description="Decision reasons")
    approval_id: str | None = Field(None, description="Approval ID if needed")
    checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
    # Populated when the caller should back off and retry (e.g. rate limit).
    retry_after_seconds: float | None = Field(None, description="Retry delay")
    # If set, callers should execute this action instead of the original.
    modified_action: ActionRequest | None = Field(
        None, description="Modified action if changed"
    )
    audit_events: list[AuditEvent] = Field(
        default_factory=list, description="Generated audit events"
    )
|
||||
15
backend/app/services/safety/permissions/__init__.py
Normal file
15
backend/app/services/safety/permissions/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Permission Management Module
|
||||
|
||||
Agent permissions for resource access.
|
||||
"""
|
||||
|
||||
from .manager import (
|
||||
PermissionGrant,
|
||||
PermissionManager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PermissionGrant",
|
||||
"PermissionManager",
|
||||
]
|
||||
384
backend/app/services/safety/permissions/manager.py
Normal file
384
backend/app/services/safety/permissions/manager.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Permission Manager
|
||||
|
||||
Manages permissions for agent actions on resources.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from ..exceptions import PermissionDeniedError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
PermissionLevel,
|
||||
ResourceType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PermissionGrant:
|
||||
"""A permission grant for an agent on a resource."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
resource_pattern: str,
|
||||
resource_type: ResourceType,
|
||||
level: PermissionLevel,
|
||||
*,
|
||||
expires_at: datetime | None = None,
|
||||
granted_by: str | None = None,
|
||||
reason: str | None = None,
|
||||
) -> None:
|
||||
self.id = str(uuid4())
|
||||
self.agent_id = agent_id
|
||||
self.resource_pattern = resource_pattern
|
||||
self.resource_type = resource_type
|
||||
self.level = level
|
||||
self.expires_at = expires_at
|
||||
self.granted_by = granted_by
|
||||
self.reason = reason
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the grant has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def matches(self, resource: str, resource_type: ResourceType) -> bool:
|
||||
"""Check if this grant applies to a resource."""
|
||||
if self.resource_type != resource_type:
|
||||
return False
|
||||
return fnmatch.fnmatch(resource, self.resource_pattern)
|
||||
|
||||
def allows(self, required_level: PermissionLevel) -> bool:
|
||||
"""Check if this grant allows the required permission level."""
|
||||
# Permission level hierarchy
|
||||
hierarchy = {
|
||||
PermissionLevel.NONE: 0,
|
||||
PermissionLevel.READ: 1,
|
||||
PermissionLevel.WRITE: 2,
|
||||
PermissionLevel.EXECUTE: 3,
|
||||
PermissionLevel.DELETE: 4,
|
||||
PermissionLevel.ADMIN: 5,
|
||||
}
|
||||
|
||||
return hierarchy[self.level] >= hierarchy[required_level]
|
||||
|
||||
|
||||
class PermissionManager:
    """
    Manages permissions for agent access to resources.

    Features:
    - Permission grants by agent/resource pattern
    - Permission inheritance (project → agent → action)
    - Temporary permissions with expiration
    - Least-privilege defaults
    - Permission escalation logging

    Concurrency: all access to the in-memory grant list is serialized via a
    single non-reentrant asyncio.Lock, so one instance may be shared by
    tasks on the same event loop.  Grants live only in memory.
    """
    # NOTE(review): "Permission inheritance (project → agent → action)" is
    # claimed above but no inheritance logic is visible in this class —
    # confirm whether it is implemented elsewhere or still planned.

    def __init__(
        self,
        default_deny: bool = True,
    ) -> None:
        """
        Initialize the PermissionManager.

        Args:
            default_deny: If True, deny access unless explicitly granted
        """
        self._grants: list[PermissionGrant] = []
        self._default_deny = default_deny
        # Serializes all mutation/inspection of self._grants.
        self._lock = asyncio.Lock()

        # Default permissions for common resources
        # (consulted only when default_deny is False — see check()).
        self._default_permissions: dict[ResourceType, PermissionLevel] = {
            ResourceType.FILE: PermissionLevel.READ,
            ResourceType.DATABASE: PermissionLevel.READ,
            ResourceType.API: PermissionLevel.READ,
            ResourceType.GIT: PermissionLevel.READ,
            ResourceType.LLM: PermissionLevel.EXECUTE,
            ResourceType.SHELL: PermissionLevel.NONE,
            ResourceType.NETWORK: PermissionLevel.READ,
        }

    async def grant(
        self,
        agent_id: str,
        resource_pattern: str,
        resource_type: ResourceType,
        level: PermissionLevel,
        *,
        duration_seconds: int | None = None,
        granted_by: str | None = None,
        reason: str | None = None,
    ) -> PermissionGrant:
        """
        Grant a permission to an agent.

        Args:
            agent_id: ID of the agent
            resource_pattern: Pattern for matching resources (supports wildcards)
            resource_type: Type of resource
            level: Permission level to grant
            duration_seconds: Optional duration for temporary permission
            granted_by: Who granted the permission
            reason: Reason for granting

        Returns:
            The created permission grant
        """
        expires_at = None
        # Falsy duration (None or 0) means the grant never expires.
        if duration_seconds:
            expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)

        grant = PermissionGrant(
            agent_id=agent_id,
            resource_pattern=resource_pattern,
            resource_type=resource_type,
            level=level,
            expires_at=expires_at,
            granted_by=granted_by,
            reason=reason,
        )

        async with self._lock:
            self._grants.append(grant)

        logger.info(
            "Permission granted: agent=%s, resource=%s, type=%s, level=%s",
            agent_id,
            resource_pattern,
            resource_type.value,
            level.value,
        )

        return grant

    async def revoke(self, grant_id: str) -> bool:
        """
        Revoke a permission grant.

        Args:
            grant_id: ID of the grant to revoke

        Returns:
            True if grant was found and revoked
        """
        async with self._lock:
            for i, grant in enumerate(self._grants):
                if grant.id == grant_id:
                    del self._grants[i]
                    logger.info("Permission revoked: %s", grant_id)
                    return True
            return False

    async def revoke_all(self, agent_id: str) -> int:
        """
        Revoke all permissions for an agent.

        Args:
            agent_id: ID of the agent

        Returns:
            Number of grants revoked
        """
        async with self._lock:
            original_count = len(self._grants)
            self._grants = [g for g in self._grants if g.agent_id != agent_id]
            revoked = original_count - len(self._grants)

        if revoked:
            logger.info("Revoked %d permissions for agent %s", revoked, agent_id)

        return revoked

    async def check(
        self,
        agent_id: str,
        resource: str,
        resource_type: ResourceType,
        required_level: PermissionLevel,
    ) -> bool:
        """
        Check if an agent has permission to access a resource.

        Args:
            agent_id: ID of the agent
            resource: Resource to access
            resource_type: Type of resource
            required_level: Required permission level

        Returns:
            True if access is allowed
        """
        # Clean up expired grants.  Done BEFORE taking the lock below:
        # _cleanup_expired acquires the same non-reentrant asyncio.Lock, so
        # calling it while holding the lock would deadlock.
        await self._cleanup_expired()

        async with self._lock:
            for grant in self._grants:
                if grant.agent_id != agent_id:
                    continue

                # Belt-and-braces: grants may expire between cleanup above
                # and this scan.
                if grant.is_expired():
                    continue

                if grant.matches(resource, resource_type):
                    if grant.allows(required_level):
                        return True

            # Check default permissions — only when not in default-deny mode.
            if not self._default_deny:
                default_level = self._default_permissions.get(
                    resource_type, PermissionLevel.NONE
                )
                # Keep this ranking in sync with PermissionGrant.allows().
                hierarchy = {
                    PermissionLevel.NONE: 0,
                    PermissionLevel.READ: 1,
                    PermissionLevel.WRITE: 2,
                    PermissionLevel.EXECUTE: 3,
                    PermissionLevel.DELETE: 4,
                    PermissionLevel.ADMIN: 5,
                }
                if hierarchy[default_level] >= hierarchy[required_level]:
                    return True

            return False

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is permitted.

        Args:
            action: The action to check

        Returns:
            True if action is allowed
        """
        # Determine required permission level from action type.
        level_map = {
            ActionType.FILE_READ: PermissionLevel.READ,
            ActionType.FILE_WRITE: PermissionLevel.WRITE,
            ActionType.FILE_DELETE: PermissionLevel.DELETE,
            ActionType.DATABASE_QUERY: PermissionLevel.READ,
            ActionType.DATABASE_MUTATE: PermissionLevel.WRITE,
            ActionType.SHELL_COMMAND: PermissionLevel.EXECUTE,
            ActionType.API_CALL: PermissionLevel.EXECUTE,
            ActionType.GIT_OPERATION: PermissionLevel.WRITE,
            ActionType.LLM_CALL: PermissionLevel.EXECUTE,
            ActionType.NETWORK_REQUEST: PermissionLevel.READ,
            ActionType.TOOL_CALL: PermissionLevel.EXECUTE,
        }

        # Unknown action types default to EXECUTE (strictest common level here).
        required_level = level_map.get(action.action_type, PermissionLevel.EXECUTE)

        # Determine resource type from action.
        # NOTE: TOOL_CALL has no entry here, so generic tool calls fall
        # through to ResourceType.CUSTOM below.
        resource_type_map = {
            ActionType.FILE_READ: ResourceType.FILE,
            ActionType.FILE_WRITE: ResourceType.FILE,
            ActionType.FILE_DELETE: ResourceType.FILE,
            ActionType.DATABASE_QUERY: ResourceType.DATABASE,
            ActionType.DATABASE_MUTATE: ResourceType.DATABASE,
            ActionType.SHELL_COMMAND: ResourceType.SHELL,
            ActionType.API_CALL: ResourceType.API,
            ActionType.GIT_OPERATION: ResourceType.GIT,
            ActionType.LLM_CALL: ResourceType.LLM,
            ActionType.NETWORK_REQUEST: ResourceType.NETWORK,
        }

        resource_type = resource_type_map.get(action.action_type, ResourceType.CUSTOM)
        # Fall back to the tool name, then a wildcard, when no explicit
        # resource is attached to the action.
        resource = action.resource or action.tool_name or "*"

        return await self.check(
            agent_id=action.metadata.agent_id,
            resource=resource,
            resource_type=resource_type,
            required_level=required_level,
        )

    async def require_permission(
        self,
        agent_id: str,
        resource: str,
        resource_type: ResourceType,
        required_level: PermissionLevel,
    ) -> None:
        """
        Require permission or raise exception.

        Args:
            agent_id: ID of the agent
            resource: Resource to access
            resource_type: Type of resource
            required_level: Required permission level

        Raises:
            PermissionDeniedError: If permission is denied
        """
        if not await self.check(agent_id, resource, resource_type, required_level):
            raise PermissionDeniedError(
                f"Permission denied: {resource}",
                action_type=None,
                resource=resource,
                required_permission=required_level.value,
                agent_id=agent_id,
            )

    async def list_grants(
        self,
        agent_id: str | None = None,
        resource_type: ResourceType | None = None,
    ) -> list[PermissionGrant]:
        """
        List permission grants.

        Args:
            agent_id: Optional filter by agent
            resource_type: Optional filter by resource type

        Returns:
            List of matching grants
        """
        # Purge expired grants first (must happen outside the lock — see check()).
        await self._cleanup_expired()

        async with self._lock:
            # Snapshot so callers never see concurrent mutation.
            grants = list(self._grants)

            if agent_id:
                grants = [g for g in grants if g.agent_id == agent_id]

            if resource_type:
                grants = [g for g in grants if g.resource_type == resource_type]

            return grants

    def set_default_permission(
        self,
        resource_type: ResourceType,
        level: PermissionLevel,
    ) -> None:
        """
        Set the default permission level for a resource type.

        Args:
            resource_type: Type of resource
            level: Default permission level
        """
        # Only consulted when default_deny is False (see check()).
        self._default_permissions[resource_type] = level

    async def _cleanup_expired(self) -> None:
        """Remove expired grants."""
        async with self._lock:
            original_count = len(self._grants)
            self._grants = [g for g in self._grants if not g.is_expired()]
            removed = original_count - len(self._grants)

        if removed:
            logger.debug("Cleaned up %d expired permission grants", removed)
|
||||
1
backend/app/services/safety/policies/__init__.py
Normal file
1
backend/app/services/safety/policies/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
5
backend/app/services/safety/rollback/__init__.py
Normal file
5
backend/app/services/safety/rollback/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Rollback management for agent actions."""
|
||||
|
||||
from .manager import RollbackManager, TransactionContext
|
||||
|
||||
__all__ = ["RollbackManager", "TransactionContext"]
|
||||
418
backend/app/services/safety/rollback/manager.py
Normal file
418
backend/app/services/safety/rollback/manager.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Rollback Manager
|
||||
|
||||
Manages checkpoints and rollback operations for agent actions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RollbackError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
RollbackResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCheckpoint:
|
||||
"""Stores file state for rollback."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
file_path: str,
|
||||
original_content: bytes | None,
|
||||
existed: bool,
|
||||
) -> None:
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.file_path = file_path
|
||||
self.original_content = original_content
|
||||
self.existed = existed
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class RollbackManager:
    """
    Manages checkpoints and rollback operations.

    Features:
    - File system checkpoints
    - Transaction wrapping for actions
    - Automatic checkpoint for destructive actions
    - Rollback triggers on failure
    - Checkpoint expiration and cleanup

    Implementation notes: checkpoint metadata and file snapshots are held
    in memory (dicts guarded by a non-reentrant asyncio.Lock); file
    contents are kept as bytes until the checkpoint is discarded or
    expires, so very large files increase memory use proportionally.
    """

    def __init__(
        self,
        checkpoint_dir: str | None = None,
        retention_hours: int | None = None,
    ) -> None:
        """
        Initialize the RollbackManager.

        Args:
            checkpoint_dir: Directory for storing checkpoint data
            retention_hours: Hours to retain checkpoints
        """
        config = get_safety_config()

        self._checkpoint_dir = Path(
            checkpoint_dir or config.checkpoint_dir
        )
        self._retention_hours = retention_hours or config.checkpoint_retention_hours

        # In-memory stores: checkpoint metadata and per-checkpoint file snapshots.
        self._checkpoints: dict[str, Checkpoint] = {}
        self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
        self._lock = asyncio.Lock()

        # Ensure checkpoint directory exists
        # NOTE(review): the directory is created, but no code in this class
        # writes snapshots to disk — confirm whether on-disk persistence is
        # implemented elsewhere or still planned.
        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

    async def create_checkpoint(
        self,
        action: ActionRequest,
        checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
        description: str | None = None,
    ) -> Checkpoint:
        """
        Create a checkpoint before an action.

        Args:
            action: The action to checkpoint for
            checkpoint_type: Type of checkpoint
            description: Optional description

        Returns:
            The created checkpoint
        """
        checkpoint_id = str(uuid4())

        checkpoint = Checkpoint(
            id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            action_id=action.id,
            created_at=datetime.utcnow(),
            # Expiry drives cleanup_expired() and list_checkpoints() filtering.
            expires_at=datetime.utcnow() + timedelta(hours=self._retention_hours),
            data={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
            },
            description=description or f"Checkpoint for {action.tool_name}",
        )

        async with self._lock:
            self._checkpoints[checkpoint_id] = checkpoint
            # Start with an empty snapshot list; files are added lazily via
            # checkpoint_file().
            self._file_checkpoints[checkpoint_id] = []

        logger.info(
            "Created checkpoint %s for action %s",
            checkpoint_id,
            action.id,
        )

        return checkpoint

    async def checkpoint_file(
        self,
        checkpoint_id: str,
        file_path: str,
    ) -> None:
        """
        Store current state of a file for checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint
            file_path: Path to the file
        """
        path = Path(file_path)

        # Read outside the lock: only the shared dict mutation below needs
        # serialization.  NOTE: read_bytes() is blocking synchronous I/O
        # inside an async method — acceptable for small files.
        if path.exists():
            content = path.read_bytes()
            existed = True
        else:
            content = None
            existed = False

        file_checkpoint = FileCheckpoint(
            checkpoint_id=checkpoint_id,
            file_path=file_path,
            original_content=content,
            existed=existed,
        )

        async with self._lock:
            if checkpoint_id not in self._file_checkpoints:
                self._file_checkpoints[checkpoint_id] = []
            self._file_checkpoints[checkpoint_id].append(file_checkpoint)

        logger.debug(
            "Stored file state for checkpoint %s: %s (existed=%s)",
            checkpoint_id,
            file_path,
            existed,
        )

    async def checkpoint_files(
        self,
        checkpoint_id: str,
        file_paths: list[str],
    ) -> None:
        """
        Store current state of multiple files.

        Args:
            checkpoint_id: ID of the checkpoint
            file_paths: Paths to the files
        """
        for path in file_paths:
            await self.checkpoint_file(checkpoint_id, path)

    async def rollback(
        self,
        checkpoint_id: str,
    ) -> RollbackResult:
        """
        Rollback to a checkpoint.

        Checkpoints are single-use: after a rollback attempt the checkpoint
        is marked invalid, so it cannot be rolled back twice.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            Result of the rollback operation

        Raises:
            RollbackError: If the checkpoint is unknown or already used.
        """
        # Validate and snapshot under the lock; the actual file restores
        # below run outside it (the lock is non-reentrant and file I/O is slow).
        async with self._lock:
            checkpoint = self._checkpoints.get(checkpoint_id)
            if not checkpoint:
                raise RollbackError(
                    f"Checkpoint not found: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            if not checkpoint.is_valid:
                raise RollbackError(
                    f"Checkpoint is no longer valid: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            file_checkpoints = self._file_checkpoints.get(checkpoint_id, [])

        actions_rolled_back: list[str] = []
        failed_actions: list[str] = []

        # Rollback file changes — best-effort: one failure does not stop
        # the remaining restores.
        for fc in file_checkpoints:
            try:
                await self._rollback_file(fc)
                actions_rolled_back.append(f"file:{fc.file_path}")
            except Exception as e:
                logger.error("Failed to rollback file %s: %s", fc.file_path, e)
                failed_actions.append(f"file:{fc.file_path}")

        success = len(failed_actions) == 0

        # Mark checkpoint as used (single-use semantics; data is retained
        # until expiry-based cleanup).
        async with self._lock:
            if checkpoint_id in self._checkpoints:
                self._checkpoints[checkpoint_id].is_valid = False

        result = RollbackResult(
            checkpoint_id=checkpoint_id,
            success=success,
            actions_rolled_back=actions_rolled_back,
            failed_actions=failed_actions,
            error=None if success else f"Failed to rollback {len(failed_actions)} items",
        )

        if success:
            logger.info("Rollback successful for checkpoint %s", checkpoint_id)
        else:
            logger.error(
                "Rollback partially failed for checkpoint %s: %d failures",
                checkpoint_id,
                len(failed_actions),
            )

        return result

    async def discard_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Discard a checkpoint without rolling back.

        Frees the in-memory file snapshots immediately instead of waiting
        for expiry cleanup.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            True if checkpoint was found and discarded
        """
        async with self._lock:
            if checkpoint_id in self._checkpoints:
                del self._checkpoints[checkpoint_id]
                if checkpoint_id in self._file_checkpoints:
                    del self._file_checkpoints[checkpoint_id]
                logger.debug("Discarded checkpoint %s", checkpoint_id)
                return True
            return False

    async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
        """Get a checkpoint by ID."""
        async with self._lock:
            return self._checkpoints.get(checkpoint_id)

    async def list_checkpoints(
        self,
        action_id: str | None = None,
        include_expired: bool = False,
    ) -> list[Checkpoint]:
        """
        List checkpoints.

        Args:
            action_id: Optional filter by action ID
            include_expired: Include expired checkpoints

        Returns:
            List of checkpoints
        """
        now = datetime.utcnow()

        async with self._lock:
            # Snapshot so callers never observe concurrent mutation.
            checkpoints = list(self._checkpoints.values())

            if action_id:
                checkpoints = [c for c in checkpoints if c.action_id == action_id]

            if not include_expired:
                checkpoints = [
                    c for c in checkpoints
                    if c.expires_at is None or c.expires_at > now
                ]

            return checkpoints

    async def cleanup_expired(self) -> int:
        """
        Clean up expired checkpoints.

        Returns:
            Number of checkpoints cleaned up
        """
        now = datetime.utcnow()
        to_remove: list[str] = []

        async with self._lock:
            # Collect first, then delete — avoids mutating the dict while
            # iterating it.
            for checkpoint_id, checkpoint in self._checkpoints.items():
                if checkpoint.expires_at and checkpoint.expires_at < now:
                    to_remove.append(checkpoint_id)

            for checkpoint_id in to_remove:
                del self._checkpoints[checkpoint_id]
                if checkpoint_id in self._file_checkpoints:
                    del self._file_checkpoints[checkpoint_id]

        if to_remove:
            logger.info("Cleaned up %d expired checkpoints", len(to_remove))

        return len(to_remove)

    async def _rollback_file(self, fc: FileCheckpoint) -> None:
        """Rollback a single file to its checkpoint state."""
        # NOTE: blocking synchronous file I/O inside an async method —
        # acceptable for small files, but stalls the event loop otherwise.
        path = Path(fc.file_path)

        if fc.existed:
            # Restore original content
            if fc.original_content is not None:
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(fc.original_content)
                logger.debug("Restored file: %s", fc.file_path)
        else:
            # File didn't exist before - delete it
            if path.exists():
                path.unlink()
                logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
|
||||
|
||||
|
||||
class TransactionContext:
|
||||
"""
|
||||
Context manager for transactional action execution.
|
||||
|
||||
Usage:
|
||||
async with TransactionContext(rollback_manager, action) as tx:
|
||||
tx.checkpoint_file("/path/to/file")
|
||||
# Do work...
|
||||
# If exception occurs, automatic rollback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: RollbackManager,
|
||||
action: ActionRequest,
|
||||
auto_rollback: bool = True,
|
||||
) -> None:
|
||||
self._manager = manager
|
||||
self._action = action
|
||||
self._auto_rollback = auto_rollback
|
||||
self._checkpoint: Checkpoint | None = None
|
||||
self._committed = False
|
||||
|
||||
async def __aenter__(self) -> "TransactionContext":
|
||||
self._checkpoint = await self._manager.create_checkpoint(self._action)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type | None,
|
||||
exc_val: Exception | None,
|
||||
exc_tb: Any,
|
||||
) -> bool:
|
||||
if exc_val is not None and self._auto_rollback and not self._committed:
|
||||
# Exception occurred - rollback
|
||||
if self._checkpoint:
|
||||
try:
|
||||
await self._manager.rollback(self._checkpoint.id)
|
||||
logger.info(
|
||||
"Auto-rollback completed for action %s",
|
||||
self._action.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Auto-rollback failed: %s", e)
|
||||
elif self._committed and self._checkpoint:
|
||||
# Committed - discard checkpoint
|
||||
await self._manager.discard_checkpoint(self._checkpoint.id)
|
||||
|
||||
return False # Don't suppress the exception
|
||||
|
||||
@property
|
||||
def checkpoint_id(self) -> str | None:
|
||||
"""Get the checkpoint ID."""
|
||||
return self._checkpoint.id if self._checkpoint else None
|
||||
|
||||
async def checkpoint_file(self, file_path: str) -> None:
|
||||
"""Checkpoint a file for this transaction."""
|
||||
if self._checkpoint:
|
||||
await self._manager.checkpoint_file(self._checkpoint.id, file_path)
|
||||
|
||||
async def checkpoint_files(self, file_paths: list[str]) -> None:
|
||||
"""Checkpoint multiple files for this transaction."""
|
||||
if self._checkpoint:
|
||||
await self._manager.checkpoint_files(self._checkpoint.id, file_paths)
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Mark transaction as committed (no rollback on exit)."""
|
||||
self._committed = True
|
||||
|
||||
async def rollback(self) -> RollbackResult | None:
|
||||
"""Manually trigger rollback."""
|
||||
if self._checkpoint:
|
||||
return await self._manager.rollback(self._checkpoint.id)
|
||||
return None
|
||||
1
backend/app/services/safety/sandbox/__init__.py
Normal file
1
backend/app/services/safety/sandbox/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
21
backend/app/services/safety/validation/__init__.py
Normal file
21
backend/app/services/safety/validation/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Action Validation Module
|
||||
|
||||
Pre-execution validation with rule engine.
|
||||
"""
|
||||
|
||||
from .validator import (
|
||||
ActionValidator,
|
||||
ValidationCache,
|
||||
create_allow_rule,
|
||||
create_approval_rule,
|
||||
create_deny_rule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionValidator",
|
||||
"ValidationCache",
|
||||
"create_allow_rule",
|
||||
"create_approval_rule",
|
||||
"create_deny_rule",
|
||||
]
|
||||
439
backend/app/services/safety/validation/validator.py
Normal file
439
backend/app/services/safety/validation/validator.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Action Validator
|
||||
|
||||
Pre-execution validation with rule engine for action requests.
|
||||
"""
|
||||
|
||||
import asyncio
import fnmatch
import logging
import time
from collections import OrderedDict

from ..config import get_safety_config
from ..models import (
    ActionRequest,
    ActionType,
    SafetyDecision,
    SafetyPolicy,
    ValidationResult,
    ValidationRule,
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationCache:
|
||||
"""LRU cache for validation results."""
|
||||
|
||||
def __init__(self, max_size: int = 1000, ttl_seconds: int = 60) -> None:
|
||||
self._cache: OrderedDict[str, tuple[ValidationResult, float]] = OrderedDict()
|
||||
self._max_size = max_size
|
||||
self._ttl = ttl_seconds
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get(self, key: str) -> ValidationResult | None:
|
||||
"""Get cached validation result."""
|
||||
import time
|
||||
|
||||
async with self._lock:
|
||||
if key not in self._cache:
|
||||
return None
|
||||
|
||||
result, timestamp = self._cache[key]
|
||||
if time.time() - timestamp > self._ttl:
|
||||
del self._cache[key]
|
||||
return None
|
||||
|
||||
# Move to end (LRU)
|
||||
self._cache.move_to_end(key)
|
||||
return result
|
||||
|
||||
async def set(self, key: str, result: ValidationResult) -> None:
|
||||
"""Cache a validation result."""
|
||||
import time
|
||||
|
||||
async with self._lock:
|
||||
if key in self._cache:
|
||||
self._cache.move_to_end(key)
|
||||
else:
|
||||
if len(self._cache) >= self._max_size:
|
||||
self._cache.popitem(last=False)
|
||||
self._cache[key] = (result, time.time())
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the cache."""
|
||||
async with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class ActionValidator:
    """
    Validates actions against safety rules before execution.

    Features:
    - Rule-based validation engine
    - Allow/deny/require-approval rules
    - Pattern matching for tools and resources
    - Validation result caching
    - Bypass capability for emergencies

    Rules are kept sorted by descending priority. During validation the
    first matching DENY rule short-circuits evaluation, while a matching
    REQUIRE_APPROVAL rule upgrades the decision without stopping the scan.
    """

    def __init__(
        self,
        cache_enabled: bool = True,
        cache_size: int = 1000,
        cache_ttl: int = 60,
    ) -> None:
        """
        Initialize the ActionValidator.

        Args:
            cache_enabled: Whether to cache validation results
            cache_size: Maximum cache entries
            cache_ttl: Cache TTL in seconds
        """
        self._rules: list[ValidationRule] = []
        self._cache_enabled = cache_enabled
        self._cache = ValidationCache(max_size=cache_size, ttl_seconds=cache_ttl)
        self._bypass_enabled = False
        self._bypass_reason: str | None = None

        # Mirror the global safety-config cache settings for introspection.
        # NOTE(review): the cache above is constructed from the constructor
        # arguments, not from these config values -- confirm which source
        # should take precedence.
        config = get_safety_config()
        self._cache_ttl = config.validation_cache_ttl
        self._cache_size = config.validation_cache_size

    def add_rule(self, rule: ValidationRule) -> None:
        """
        Add a validation rule.

        Args:
            rule: The rule to add
        """
        self._rules.append(rule)
        # Re-sort by priority (higher first)
        self._rules.sort(key=lambda r: r.priority, reverse=True)
        # NOTE(review): cached validation results are not invalidated here;
        # stale decisions may persist until the cache TTL lapses.
        logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)

    def remove_rule(self, rule_id: str) -> bool:
        """
        Remove a validation rule by ID.

        Args:
            rule_id: ID of the rule to remove

        Returns:
            True if rule was found and removed
        """
        for i, rule in enumerate(self._rules):
            if rule.id == rule_id:
                del self._rules[i]
                logger.debug("Removed validation rule: %s", rule_id)
                return True
        return False

    def clear_rules(self) -> None:
        """Remove all validation rules."""
        self._rules.clear()

    def load_rules_from_policy(self, policy: SafetyPolicy) -> None:
        """
        Load validation rules from a safety policy.

        Replaces any existing rules, then installs the policy's explicit
        rules plus implicit rules derived from its denied-tool and
        approval-required patterns.

        Args:
            policy: The policy to load rules from
        """
        # Clear existing rules
        self.clear_rules()

        # Add rules from policy
        for rule in policy.validation_rules:
            self.add_rule(rule)

        # Create implicit rules from policy settings

        # Denied tools
        for i, pattern in enumerate(policy.denied_tools):
            self.add_rule(
                ValidationRule(
                    name=f"deny_tool_{i}",
                    description=f"Deny tool pattern: {pattern}",
                    priority=100,  # High priority for denials
                    tool_patterns=[pattern],
                    decision=SafetyDecision.DENY,
                    reason=f"Tool matches denied pattern: {pattern}",
                )
            )

        # Require approval patterns
        for i, pattern in enumerate(policy.require_approval_for):
            if pattern == "*":
                # All actions require approval
                self.add_rule(
                    ValidationRule(
                        name="require_approval_all",
                        description="All actions require approval",
                        priority=50,
                        action_types=list(ActionType),
                        decision=SafetyDecision.REQUIRE_APPROVAL,
                        reason="All actions require human approval",
                    )
                )
            else:
                self.add_rule(
                    ValidationRule(
                        name=f"require_approval_{i}",
                        description=f"Require approval for: {pattern}",
                        priority=50,
                        tool_patterns=[pattern],
                        decision=SafetyDecision.REQUIRE_APPROVAL,
                        reason=f"Action matches approval-required pattern: {pattern}",
                    )
                )

        logger.info("Loaded %d rules from policy: %s", len(self._rules), policy.name)

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> ValidationResult:
        """
        Validate an action against all rules.

        Args:
            action: The action to validate
            policy: Optional policy override (adopted only when no rules
                are currently loaded)

        Returns:
            ValidationResult with decision and details
        """
        # Emergency bypass: allow everything, but log loudly.
        if self._bypass_enabled:
            logger.warning(
                "Validation bypass active: %s - allowing action %s",
                self._bypass_reason,
                action.id,
            )
            return ValidationResult(
                action_id=action.id,
                decision=SafetyDecision.ALLOW,
                applied_rules=[],
                reasons=[f"Validation bypassed: {self._bypass_reason}"],
            )

        # Compute the cache key once; it is reused for both lookup and store.
        cache_key = self._get_cache_key(action) if self._cache_enabled else None
        if cache_key is not None:
            cached = await self._cache.get(cache_key)
            if cached:
                logger.debug("Using cached validation for action %s", action.id)
                return cached

        # Load rules from policy if provided and none are loaded yet
        if policy and not self._rules:
            self.load_rules_from_policy(policy)

        # Validate against rules
        applied_rules: list[str] = []
        reasons: list[str] = []
        final_decision = SafetyDecision.ALLOW
        approval_id: str | None = None

        for rule in self._rules:
            if not rule.enabled:
                continue

            if self._rule_matches(rule, action):
                applied_rules.append(rule.id)

                if rule.reason:
                    reasons.append(rule.reason)

                # Handle decision priority
                if rule.decision == SafetyDecision.DENY:
                    # Deny takes precedence and ends evaluation immediately
                    final_decision = SafetyDecision.DENY
                    break

                elif rule.decision == SafetyDecision.REQUIRE_APPROVAL:
                    # Upgrade ALLOW to REQUIRE_APPROVAL; never downgrade DENY
                    if final_decision != SafetyDecision.DENY:
                        final_decision = SafetyDecision.REQUIRE_APPROVAL

        # If no rules matched and no explicit allow, default to allow
        if not applied_rules:
            reasons.append("No matching rules - default allow")

        result = ValidationResult(
            action_id=action.id,
            decision=final_decision,
            applied_rules=applied_rules,
            reasons=reasons,
            approval_id=approval_id,
        )

        # Cache result
        if cache_key is not None:
            await self._cache.set(cache_key, result)

        return result

    async def validate_batch(
        self,
        actions: list[ActionRequest],
        policy: SafetyPolicy | None = None,
    ) -> list[ValidationResult]:
        """
        Validate multiple actions concurrently.

        Args:
            actions: Actions to validate
            policy: Optional policy override

        Returns:
            List of validation results, in the same order as ``actions``
        """
        tasks = [self.validate(action, policy) for action in actions]
        return await asyncio.gather(*tasks)

    def enable_bypass(self, reason: str) -> None:
        """
        Enable validation bypass (emergency use only).

        While enabled, every action is allowed without rule evaluation.

        Args:
            reason: Reason for enabling bypass
        """
        logger.critical("Validation bypass enabled: %s", reason)
        self._bypass_enabled = True
        self._bypass_reason = reason

    def disable_bypass(self) -> None:
        """Disable validation bypass."""
        logger.info("Validation bypass disabled")
        self._bypass_enabled = False
        self._bypass_reason = None

    async def clear_cache(self) -> None:
        """Clear the validation cache."""
        await self._cache.clear()

    def _rule_matches(self, rule: ValidationRule, action: ActionRequest) -> bool:
        """Return True when every constraint the rule declares matches the action.

        Unset rule fields (empty action types, patterns, agent IDs) are
        treated as wildcards that match anything.
        """
        # Check action types
        if rule.action_types and action.action_type not in rule.action_types:
            return False

        # Check tool patterns (a rule with tool patterns never matches a
        # tool-less action)
        if rule.tool_patterns:
            if not action.tool_name:
                return False
            if not any(
                self._matches_pattern(action.tool_name, pattern)
                for pattern in rule.tool_patterns
            ):
                return False

        # Check resource patterns
        if rule.resource_patterns:
            if not action.resource:
                return False
            if not any(
                self._matches_pattern(action.resource, pattern)
                for pattern in rule.resource_patterns
            ):
                return False

        # Check agent IDs
        if rule.agent_ids and action.metadata.agent_id not in rule.agent_ids:
            return False

        return True

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """Check if value matches a pattern (supports wildcards)."""
        if pattern == "*":
            return True

        # Use fnmatch for glob-style matching
        return fnmatch.fnmatch(value, pattern)

    def _get_cache_key(self, action: ActionRequest) -> str:
        """Generate a cache key for an action.

        The key covers the action characteristics that affect validation;
        two actions with the same key receive the same cached decision.
        """
        key_parts = [
            action.action_type.value,
            action.tool_name or "",
            action.resource or "",
            action.metadata.agent_id,
            action.metadata.autonomy_level.value,
        ]
        return ":".join(key_parts)
|
||||
|
||||
|
||||
# Module-level convenience functions
|
||||
|
||||
|
||||
def create_allow_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    priority: int = 0,
) -> ValidationRule:
    """Build a ValidationRule whose decision is ALLOW."""
    rule_kwargs = {
        "name": name,
        "tool_patterns": tool_patterns,
        "resource_patterns": resource_patterns,
        "action_types": action_types,
        "decision": SafetyDecision.ALLOW,
        "priority": priority,
    }
    return ValidationRule(**rule_kwargs)
|
||||
|
||||
|
||||
def create_deny_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    reason: str | None = None,
    priority: int = 100,
) -> ValidationRule:
    """Build a ValidationRule whose decision is DENY (high priority by default)."""
    rule_kwargs = {
        "name": name,
        "tool_patterns": tool_patterns,
        "resource_patterns": resource_patterns,
        "action_types": action_types,
        "decision": SafetyDecision.DENY,
        "reason": reason,
        "priority": priority,
    }
    return ValidationRule(**rule_kwargs)
|
||||
|
||||
|
||||
def create_approval_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    reason: str | None = None,
    priority: int = 50,
) -> ValidationRule:
    """Build a ValidationRule whose decision is REQUIRE_APPROVAL."""
    rule_kwargs = {
        "name": name,
        "tool_patterns": tool_patterns,
        "resource_patterns": resource_patterns,
        "action_types": action_types,
        "decision": SafetyDecision.REQUIRE_APPROVAL,
        "reason": reason,
        "priority": priority,
    }
    return ValidationRule(**rule_kwargs)
|
||||
0
backend/tests/services/safety/__init__.py
Normal file
0
backend/tests/services/safety/__init__.py
Normal file
345
backend/tests/services/safety/test_content_filter.py
Normal file
345
backend/tests/services/safety/test_content_filter.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for content filtering module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.content.filter import (
|
||||
ContentCategory,
|
||||
ContentFilter,
|
||||
FilterAction,
|
||||
FilterPattern,
|
||||
filter_content,
|
||||
scan_for_secrets,
|
||||
)
|
||||
from app.services.safety.exceptions import ContentFilterError
|
||||
|
||||
|
||||
@pytest.fixture
def filter_all() -> ContentFilter:
    """ContentFilter with PII, secret, and injection filters all enabled."""
    flags = dict.fromkeys(
        ("enable_pii_filter", "enable_secret_filter", "enable_injection_filter"),
        True,
    )
    return ContentFilter(**flags)
|
||||
|
||||
|
||||
@pytest.fixture
def filter_pii_only() -> ContentFilter:
    """ContentFilter with only the PII filter enabled."""
    settings = {
        "enable_pii_filter": True,
        "enable_secret_filter": False,
        "enable_injection_filter": False,
    }
    return ContentFilter(**settings)
|
||||
|
||||
|
||||
class TestContentFilter:
    """Tests for ContentFilter class.

    Covers PII redaction (email, phone, SSN, credit card), hard blocks on
    secrets (API keys, GitHub tokens, private keys, passwords in URLs),
    warning-level matches (IP addresses), clean-content pass-through, and
    the ``raise_on_block`` escape hatch.
    """

    @pytest.mark.asyncio
    async def test_filter_email(self, filter_pii_only: ContentFilter) -> None:
        """Test filtering email addresses."""
        content = "Contact me at john.doe@example.com for details."
        result = await filter_pii_only.filter(content)

        # The address is replaced with the [EMAIL] placeholder and
        # reported as a single PII-category match.
        assert result.has_sensitive_content
        assert "[EMAIL]" in result.filtered_content
        assert "john.doe@example.com" not in result.filtered_content
        assert len(result.matches) == 1
        assert result.matches[0].category == ContentCategory.PII

    @pytest.mark.asyncio
    async def test_filter_phone_number(self, filter_pii_only: ContentFilter) -> None:
        """Test filtering phone numbers."""
        content = "Call me at 555-123-4567 or (555) 987-6543."
        result = await filter_pii_only.filter(content)

        assert result.has_sensitive_content
        assert "[PHONE]" in result.filtered_content
        # Should redact both phone numbers
        assert "555-123-4567" not in result.filtered_content

    @pytest.mark.asyncio
    async def test_filter_ssn(self, filter_pii_only: ContentFilter) -> None:
        """Test filtering Social Security Numbers."""
        content = "SSN: 123-45-6789"
        result = await filter_pii_only.filter(content)

        assert result.has_sensitive_content
        assert "[SSN]" in result.filtered_content
        assert "123-45-6789" not in result.filtered_content

    @pytest.mark.asyncio
    async def test_filter_credit_card(self, filter_all: ContentFilter) -> None:
        """Test filtering credit card numbers."""
        content = "Card: 4111-2222-3333-4444"
        result = await filter_all.filter(content)

        assert result.has_sensitive_content
        assert "[CREDIT_CARD]" in result.filtered_content

    @pytest.mark.asyncio
    async def test_block_api_key(self, filter_all: ContentFilter) -> None:
        """Test blocking API keys."""
        content = "api_key: sk-abcdef1234567890abcdef1234567890"
        result = await filter_all.filter(content)

        # Secrets block the content outright rather than redacting it.
        assert result.blocked
        assert result.block_reason is not None
        assert "api_key" in result.block_reason.lower()

    @pytest.mark.asyncio
    async def test_block_github_token(self, filter_all: ContentFilter) -> None:
        """Test blocking GitHub tokens."""
        content = "token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
        result = await filter_all.filter(content)

        assert result.blocked
        assert len(result.matches) > 0

    @pytest.mark.asyncio
    async def test_block_private_key(self, filter_all: ContentFilter) -> None:
        """Test blocking private keys."""
        content = """
        -----BEGIN RSA PRIVATE KEY-----
        MIIEpAIBAAKCAQEA...
        -----END RSA PRIVATE KEY-----
        """
        result = await filter_all.filter(content)

        assert result.blocked

    @pytest.mark.asyncio
    async def test_block_password_in_url(self, filter_all: ContentFilter) -> None:
        """Test blocking passwords in URLs."""
        content = "Connect to: postgres://user:secretpassword@localhost/db"
        result = await filter_all.filter(content)

        # Accept either a hard block or a sensitive-content flag here.
        assert result.blocked or result.has_sensitive_content

    @pytest.mark.asyncio
    async def test_warn_ip_address(self, filter_pii_only: ContentFilter) -> None:
        """Test warning on IP addresses."""
        content = "Server IP: 192.168.1.100"
        result = await filter_pii_only.filter(content)

        # IP addresses generate warnings, not blocks
        assert len(result.warnings) > 0 or result.has_sensitive_content

    @pytest.mark.asyncio
    async def test_no_false_positives_clean_content(
        self,
        filter_all: ContentFilter,
    ) -> None:
        """Test that clean content passes through."""
        content = "This is a normal message with no sensitive data."
        result = await filter_all.filter(content)

        # Clean input must come back unmodified and unblocked.
        assert not result.blocked
        assert result.filtered_content == content

    @pytest.mark.asyncio
    async def test_raise_on_block(self, filter_all: ContentFilter) -> None:
        """Test raising exception on blocked content."""
        content = "-----BEGIN RSA PRIVATE KEY-----"

        with pytest.raises(ContentFilterError):
            await filter_all.filter(content, raise_on_block=True)
|
||||
|
||||
|
||||
class TestFilterDict:
    """Tests for dictionary filtering.

    Verifies that ``filter_dict`` redacts string values, leaves non-string
    values untouched, recurses into nested dicts when asked, and can be
    restricted to an explicit set of keys.
    """

    @pytest.mark.asyncio
    async def test_filter_dict_values(self, filter_pii_only: ContentFilter) -> None:
        """Test filtering string values in a dictionary."""
        data = {
            "name": "John Doe",
            "email": "john@example.com",
            "age": 30,
        }
        result = await filter_pii_only.filter_dict(data)

        assert "[EMAIL]" in result["email"]
        assert result["age"] == 30  # Non-string unchanged

    @pytest.mark.asyncio
    async def test_filter_dict_recursive(
        self,
        filter_pii_only: ContentFilter,
    ) -> None:
        """Test recursive dictionary filtering."""
        # Email is nested two levels deep; recursive=True must reach it.
        data = {
            "user": {
                "contact": {
                    "email": "test@example.com",
                }
            }
        }
        result = await filter_pii_only.filter_dict(data, recursive=True)

        assert "[EMAIL]" in result["user"]["contact"]["email"]

    @pytest.mark.asyncio
    async def test_filter_dict_specific_keys(
        self,
        filter_pii_only: ContentFilter,
    ) -> None:
        """Test filtering specific keys only."""
        data = {
            "public_email": "public@example.com",
            "private_email": "private@example.com",
        }
        result = await filter_pii_only.filter_dict(
            data,
            keys_to_filter=["private_email"],
        )

        # Only private_email should be filtered
        assert "public@example.com" in result["public_email"]
        assert "[EMAIL]" in result["private_email"]
|
||||
|
||||
|
||||
class TestScan:
    """Tests for content scanning."""

    @pytest.mark.asyncio
    async def test_scan_without_filtering(
        self,
        filter_all: ContentFilter,
    ) -> None:
        """Scanning reports matches without modifying the content."""
        sample = "Email: test@example.com, SSN: 123-45-6789"
        found = await filter_all.scan(sample)
        # At minimum the email and the SSN should each produce a match.
        assert len(found) >= 2

    @pytest.mark.asyncio
    async def test_scan_specific_categories(
        self,
        filter_all: ContentFilter,
    ) -> None:
        """Restricting categories limits matches to those categories only."""
        sample = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"

        # Scan only for secrets; the email must not be reported.
        found = await filter_all.scan(
            sample,
            categories=[ContentCategory.SECRETS],
        )
        assert all(hit.category == ContentCategory.SECRETS for hit in found)
|
||||
|
||||
|
||||
class TestValidateSafe:
    """Tests for safe validation."""

    @pytest.mark.asyncio
    async def test_validate_safe_clean(self, filter_all: ContentFilter) -> None:
        """Clean text is reported safe with zero issues."""
        ok, problems = await filter_all.validate_safe("This is safe content.")
        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_validate_safe_with_secrets(
        self,
        filter_all: ContentFilter,
    ) -> None:
        """A private-key header is reported unsafe with at least one issue."""
        ok, problems = await filter_all.validate_safe(
            "-----BEGIN RSA PRIVATE KEY-----"
        )
        assert ok is False
        assert len(problems) > 0
|
||||
|
||||
|
||||
class TestCustomPatterns:
    """Tests for custom pattern support.

    Exercises the pattern-management API: adding a user-defined
    FilterPattern, disabling a built-in pattern in place, and removing a
    built-in pattern entirely.
    """

    @pytest.mark.asyncio
    async def test_add_custom_pattern(self) -> None:
        """Test adding a custom filter pattern."""
        # Start from a filter with every built-in category disabled so the
        # only possible match comes from the custom pattern below.
        content_filter = ContentFilter(
            enable_pii_filter=False,
            enable_secret_filter=False,
            enable_injection_filter=False,
        )

        # Add custom pattern for internal IDs
        content_filter.add_pattern(
            FilterPattern(
                name="internal_id",
                category=ContentCategory.CUSTOM,
                pattern=r"INTERNAL-[A-Z0-9]{8}",
                action=FilterAction.REDACT,
                replacement="[INTERNAL_ID]",
            )
        )

        content = "Reference: INTERNAL-ABC12345"
        result = await content_filter.filter(content)

        assert "[INTERNAL_ID]" in result.filtered_content

    @pytest.mark.asyncio
    async def test_disable_pattern(self, filter_pii_only: ContentFilter) -> None:
        """Test disabling a pattern."""
        filter_pii_only.enable_pattern("email", enabled=False)

        content = "Email: test@example.com"
        result = await filter_pii_only.filter(content)

        # Email should not be filtered
        assert "test@example.com" in result.filtered_content

    @pytest.mark.asyncio
    async def test_remove_pattern(self, filter_pii_only: ContentFilter) -> None:
        """Test removing a pattern."""
        removed = filter_pii_only.remove_pattern("email")
        assert removed is True

        content = "Email: test@example.com"
        result = await filter_pii_only.filter(content)

        # Email should not be filtered
        assert "test@example.com" in result.filtered_content
|
||||
|
||||
|
||||
class TestPatternStats:
    """Tests for pattern statistics."""

    def test_get_pattern_stats(self, filter_all: ContentFilter) -> None:
        """Stats report a positive pattern count plus category/action breakdowns."""
        report = filter_all.get_pattern_stats()
        assert report["total_patterns"] > 0
        for section in ("by_category", "by_action"):
            assert section in report
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
    """Tests for convenience functions."""

    @pytest.mark.asyncio
    async def test_filter_content_function(self) -> None:
        """filter_content() redacts an email address from the input text."""
        redacted = await filter_content("Email: test@example.com")
        assert "test@example.com" not in redacted

    @pytest.mark.asyncio
    async def test_scan_for_secrets_function(self) -> None:
        """scan_for_secrets() flags a GitHub personal-access token."""
        hits = await scan_for_secrets(
            "Token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
        )
        assert len(hits) > 0
        assert hits[0].category in (
            ContentCategory.SECRETS,
            ContentCategory.CREDENTIALS,
        )
|
||||
425
backend/tests/services/safety/test_emergency.py
Normal file
425
backend/tests/services/safety/test_emergency.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""Tests for emergency controls module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.emergency.controls import (
|
||||
EmergencyControls,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
)
|
||||
from app.services.safety.exceptions import EmergencyStopError
|
||||
|
||||
|
||||
@pytest.fixture
def controls() -> EmergencyControls:
    """Provide a brand-new EmergencyControls instance for each test."""
    fresh_controls = EmergencyControls()
    return fresh_controls
|
||||
|
||||
|
||||
class TestEmergencyControls:
    """Tests for EmergencyControls class.

    Exercises the state machine (NORMAL -> PAUSED/STOPPED, resume, reset),
    per-scope state isolation, the ``check_allowed`` gate (boolean and
    raising forms), and the event bookkeeping (active events, history,
    metadata).
    """

    @pytest.mark.asyncio
    async def test_initial_state_is_normal(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test that initial state is normal."""
        state = await controls.get_state("global")
        assert state == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_emergency_stop_changes_state(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test that emergency stop changes state to stopped."""
        event = await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Test emergency stop",
        )

        # Both the returned event and a fresh state query reflect STOPPED.
        assert event.state == EmergencyState.STOPPED
        state = await controls.get_state("global")
        assert state == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_pause_changes_state(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test that pause changes state to paused."""
        event = await controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message="Budget exceeded",
        )

        assert event.state == EmergencyState.PAUSED
        state = await controls.get_state("global")
        assert state == EmergencyState.PAUSED

    @pytest.mark.asyncio
    async def test_resume_from_paused(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test resuming from paused state."""
        await controls.pause(
            reason=EmergencyReason.RATE_LIMIT,
            triggered_by="limiter",
            message="Rate limited",
        )

        resumed = await controls.resume(resumed_by="admin")

        assert resumed is True
        state = await controls.get_state("global")
        assert state == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_cannot_resume_from_stopped(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test that you cannot resume from stopped state."""
        await controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety",
            message="Critical violation",
        )

        # STOPPED requires an explicit reset; resume() must refuse.
        resumed = await controls.resume(resumed_by="admin")

        assert resumed is False
        state = await controls.get_state("global")
        assert state == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_reset_from_stopped(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test resetting from stopped state."""
        await controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety",
            message="Critical violation",
        )

        reset = await controls.reset(reset_by="admin")

        assert reset is True
        state = await controls.get_state("global")
        assert state == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_scoped_emergency_stop(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test emergency stop with specific scope."""
        await controls.emergency_stop(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="detector",
            message="Loop in agent",
            scope="agent:agent-123",
        )

        # Agent scope should be stopped
        agent_state = await controls.get_state("agent:agent-123")
        assert agent_state == EmergencyState.STOPPED

        # Global should still be normal
        global_state = await controls.get_state("global")
        assert global_state == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_check_allowed_when_normal(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test check_allowed returns True when state is normal."""
        allowed = await controls.check_allowed()
        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_allowed_when_stopped(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test check_allowed returns False when stopped."""
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
        )

        allowed = await controls.check_allowed(raise_if_blocked=False)
        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_allowed_raises_when_blocked(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test check_allowed raises exception when blocked."""
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
        )

        with pytest.raises(EmergencyStopError):
            await controls.check_allowed(raise_if_blocked=True)

    @pytest.mark.asyncio
    async def test_check_allowed_with_scope(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test check_allowed with specific scope."""
        await controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget",
            message="Paused",
            scope="project:proj-123",
        )

        # Project scope should be blocked
        allowed_project = await controls.check_allowed(
            scope="project:proj-123",
            raise_if_blocked=False,
        )
        assert allowed_project is False

        # Different scope should be allowed
        allowed_other = await controls.check_allowed(
            scope="project:proj-456",
            raise_if_blocked=False,
        )
        assert allowed_other is True

    @pytest.mark.asyncio
    async def test_get_all_states(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test getting all states."""
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause",
            scope="agent:a1",
        )
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
            scope="agent:a2",
        )

        states = await controls.get_all_states()

        # Each scope tracks its own state; global is untouched.
        assert states["global"] == EmergencyState.NORMAL
        assert states["agent:a1"] == EmergencyState.PAUSED
        assert states["agent:a2"] == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_get_active_events(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test getting active (unresolved) events."""
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause 1",
        )

        events = await controls.get_active_events()
        assert len(events) == 1

        # Resume should resolve the event
        await controls.resume()
        events_after = await controls.get_active_events()
        assert len(events_after) == 0

    @pytest.mark.asyncio
    async def test_event_history(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test getting event history."""
        await controls.pause(
            reason=EmergencyReason.RATE_LIMIT,
            triggered_by="test",
            message="Rate limited",
        )
        await controls.resume()

        # History retains resolved events, unlike get_active_events().
        history = await controls.get_event_history()

        assert len(history) == 1
        assert history[0].reason == EmergencyReason.RATE_LIMIT

    @pytest.mark.asyncio
    async def test_event_metadata(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Test event metadata storage."""
        event = await controls.emergency_stop(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message="Over budget",
            metadata={"budget_type": "tokens", "usage": 150000},
        )

        assert event.metadata["budget_type"] == "tokens"
        assert event.metadata["usage"] == 150000
|
||||
|
||||
|
||||
class TestCallbacks:
    """Verify that registered lifecycle callbacks fire on state changes."""

    @pytest.mark.asyncio
    async def test_on_stop_callback(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A registered on_stop handler fires once per emergency stop."""
        received: list = []

        def capture(event: object) -> None:
            received.append(event)

        controls.on_stop(capture)
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
        )

        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_on_pause_callback(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A registered on_pause handler fires once per pause."""
        received: list = []

        def capture(event: object) -> None:
            received.append(event)

        controls.on_pause(capture)
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause",
        )

        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_on_resume_callback(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A registered on_resume handler fires once when a pause is resumed."""
        received: list = []

        def capture(data: object) -> None:
            received.append(data)

        controls.on_resume(capture)
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause",
        )
        await controls.resume()

        assert len(received) == 1
|
||||
|
||||
|
||||
class TestEmergencyTrigger:
    """Exercise the EmergencyTrigger convenience wrappers."""

    @pytest.fixture
    def trigger(self, controls: EmergencyControls) -> EmergencyTrigger:
        """Wrap the shared controls fixture in an EmergencyTrigger."""
        return EmergencyTrigger(controls)

    @pytest.mark.asyncio
    async def test_trigger_on_safety_violation(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """A safety violation escalates to a hard stop."""
        event = await trigger.trigger_on_safety_violation(
            violation_type="unauthorized_access",
            details={"resource": "/secrets/key"},
        )

        assert event.reason == EmergencyReason.SAFETY_VIOLATION
        assert event.state == EmergencyState.STOPPED

        # The stop must also be visible via the global scope.
        global_state = await controls.get_state("global")
        assert global_state == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_trigger_on_budget_exceeded(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """Budget overruns pause execution rather than stopping it."""
        event = await trigger.trigger_on_budget_exceeded(
            budget_type="tokens",
            current=150000,
            limit=100000,
        )

        assert event.reason == EmergencyReason.BUDGET_EXCEEDED
        assert event.state == EmergencyState.PAUSED  # Pause, not stop

    @pytest.mark.asyncio
    async def test_trigger_on_loop_detected(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """Loop detection pauses only the offending agent's scope."""
        event = await trigger.trigger_on_loop_detected(
            loop_type="exact",
            agent_id="agent-123",
            details={"pattern": "file_read"},
        )

        assert event.reason == EmergencyReason.LOOP_DETECTED
        assert event.scope == "agent:agent-123"
        assert event.state == EmergencyState.PAUSED

    @pytest.mark.asyncio
    async def test_trigger_on_content_violation(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """Content violations escalate to a hard stop."""
        event = await trigger.trigger_on_content_violation(
            category="secrets",
            pattern="private_key",
        )

        assert event.reason == EmergencyReason.CONTENT_VIOLATION
        assert event.state == EmergencyState.STOPPED
|
||||
316
backend/tests/services/safety/test_loops.py
Normal file
316
backend/tests/services/safety/test_loops.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Tests for loop detection module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.exceptions import LoopDetectedError
|
||||
from app.services.safety.loops.detector import (
|
||||
ActionSignature,
|
||||
LoopBreaker,
|
||||
LoopDetector,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def detector() -> LoopDetector:
    """Provide a LoopDetector with deliberately small thresholds.

    Low repetition limits let the loop tests trip detection with only a
    handful of recorded actions.
    """
    return LoopDetector(
        history_size=20,
        max_exact_repetitions=3,
        max_semantic_repetitions=5,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Metadata identifying a milestone-level test agent and session."""
    return ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
        autonomy_level=AutonomyLevel.MILESTONE,
    )
|
||||
|
||||
|
||||
def create_action(
    metadata: ActionMetadata,
    tool_name: str,
    resource: str = "/tmp/test.txt",  # noqa: S108
    arguments: dict | None = None,
) -> ActionRequest:
    """Build a FILE_READ ActionRequest for tests.

    Args:
        metadata: Agent/session metadata to attach to the request.
        tool_name: Tool name recorded on the request.
        resource: Resource path; defaults to a scratch file.
        arguments: Optional tool arguments; ``None`` becomes an empty dict.
    """
    args = {} if arguments is None else arguments
    return ActionRequest(
        action_type=ActionType.FILE_READ,
        tool_name=tool_name,
        resource=resource,
        arguments=args,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
class TestActionSignature:
    """Tests for ActionSignature key derivation."""

    def test_exact_key_includes_args(self, sample_metadata: ActionMetadata) -> None:
        """Different arguments must yield different exact keys."""
        with_a = create_action(sample_metadata, "file_read", arguments={"path": "a"})
        with_b = create_action(sample_metadata, "file_read", arguments={"path": "b"})

        sig_a = ActionSignature(with_a)
        sig_b = ActionSignature(with_b)

        assert sig_a.exact_key() != sig_b.exact_key()

    def test_semantic_key_ignores_args(self, sample_metadata: ActionMetadata) -> None:
        """Argument differences must not affect the semantic key."""
        with_a = create_action(sample_metadata, "file_read", arguments={"path": "a"})
        with_b = create_action(sample_metadata, "file_read", arguments={"path": "b"})

        sig_a = ActionSignature(with_a)
        sig_b = ActionSignature(with_b)

        assert sig_a.semantic_key() == sig_b.semantic_key()

    def test_type_key(self, sample_metadata: ActionMetadata) -> None:
        """Type key extraction for a file_read action."""
        signature = ActionSignature(create_action(sample_metadata, "file_read"))

        assert signature.type_key() == "file_read"
|
||||
|
||||
|
||||
class TestLoopDetector:
    """Behavioral tests for LoopDetector's repetition tracking."""

    @pytest.mark.asyncio
    async def test_no_loop_on_first_action(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """An empty history can never flag a loop."""
        first = create_action(sample_metadata, "file_read")

        flagged, kind = await detector.check(first)

        assert flagged is False
        assert kind is None

    @pytest.mark.asyncio
    async def test_exact_loop_detection(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Repeating an identical action past the threshold trips 'exact'."""
        repeated = create_action(
            sample_metadata,
            "file_read",
            resource="/tmp/same.txt",  # noqa: S108
            arguments={"path": "/tmp/same.txt"},  # noqa: S108
        )

        # Hit the configured exact-repetition threshold.
        for _ in range(3):
            await detector.record(repeated)

        flagged, kind = await detector.check(repeated)

        assert flagged is True
        assert kind == "exact"

    @pytest.mark.asyncio
    async def test_semantic_loop_detection(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Same tool/resource with varying args trips the 'semantic' check."""
        shared_resource = "/tmp/test.txt"  # noqa: S108
        for offset in range(5):
            await detector.record(
                create_action(
                    sample_metadata,
                    "file_read",
                    resource=shared_resource,
                    arguments={"offset": offset},
                )
            )

        # A further similar action should now count as a semantic loop.
        probe = create_action(
            sample_metadata,
            "file_read",
            resource=shared_resource,
            arguments={"offset": 100},
        )
        flagged, kind = await detector.check(probe)

        assert flagged is True
        assert kind == "semantic"

    @pytest.mark.asyncio
    async def test_oscillation_detection(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """An A→B→A history makes the next B complete an oscillation."""
        step_a = create_action(sample_metadata, "tool_a", resource="/a")
        step_b = create_action(sample_metadata, "tool_b", resource="/b")

        for step in (step_a, step_b, step_a):
            await detector.record(step)

        flagged, kind = await detector.check(step_b)

        assert flagged is True
        assert kind == "oscillation"

    @pytest.mark.asyncio
    async def test_different_actions_no_loop(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Unique tool/resource pairs never accumulate into a loop."""
        for i in range(10):
            unique = create_action(
                sample_metadata,
                f"tool_{i}",
                resource=f"/resource_{i}",
            )
            flagged, _ = await detector.check(unique)
            assert flagged is False
            await detector.record(unique)

    @pytest.mark.asyncio
    async def test_check_and_raise(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """check_and_raise surfaces loops as LoopDetectedError."""
        repeated = create_action(sample_metadata, "file_read")

        # Reach the exact-repetition threshold first.
        for _ in range(3):
            await detector.record(repeated)

        with pytest.raises(LoopDetectedError) as exc_info:
            await detector.check_and_raise(repeated)

        assert "exact" in exc_info.value.loop_type.lower()

    @pytest.mark.asyncio
    async def test_clear_history(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Clearing an agent's history resets loop detection for it."""
        repeated = create_action(sample_metadata, "file_read")
        for _ in range(3):
            await detector.record(repeated)

        await detector.clear_history(sample_metadata.agent_id)

        flagged, _ = await detector.check(repeated)
        assert flagged is False

    @pytest.mark.asyncio
    async def test_per_agent_history(
        self,
        detector: LoopDetector,
    ) -> None:
        """Histories are isolated per agent id."""
        meta_one = ActionMetadata(agent_id="agent-1", session_id="s1")
        meta_two = ActionMetadata(agent_id="agent-2", session_id="s2")

        action_one = create_action(meta_one, "file_read")
        action_two = create_action(meta_two, "file_read")

        # Only agent-1 accumulates threshold-many recordings.
        for _ in range(3):
            await detector.record(action_one)

        flagged_one, _ = await detector.check(action_one)
        assert flagged_one is True

        flagged_two, _ = await detector.check(action_two)
        assert flagged_two is False

    @pytest.mark.asyncio
    async def test_get_stats(
        self,
        detector: LoopDetector,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Stats report the history size and per-type counts."""
        for i in range(5):
            await detector.record(
                create_action(
                    sample_metadata,
                    f"tool_{i % 2}",  # Alternate between two tools
                    resource=f"/resource_{i}",
                )
            )

        stats = await detector.get_stats(sample_metadata.agent_id)

        assert stats["history_size"] == 5
        assert len(stats["action_type_counts"]) > 0
|
||||
|
||||
|
||||
class TestLoopBreaker:
    """Checks on LoopBreaker's human-readable loop-breaking suggestions."""

    @pytest.mark.asyncio
    async def test_suggest_alternatives_exact(
        self,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Exact-loop advice mentions repeating the same action."""
        stuck = create_action(sample_metadata, "file_read")

        advice = await LoopBreaker.suggest_alternatives(stuck, "exact")

        assert len(advice) > 0
        assert "same action" in advice[0].lower()

    @pytest.mark.asyncio
    async def test_suggest_alternatives_semantic(
        self,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Semantic-loop advice mentions similar actions."""
        stuck = create_action(sample_metadata, "file_read")

        advice = await LoopBreaker.suggest_alternatives(stuck, "semantic")

        assert len(advice) > 0
        assert "similar" in advice[0].lower()

    @pytest.mark.asyncio
    async def test_suggest_alternatives_oscillation(
        self,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Oscillation advice mentions the oscillation pattern."""
        stuck = create_action(sample_metadata, "file_read")

        advice = await LoopBreaker.suggest_alternatives(stuck, "oscillation")

        assert len(advice) > 0
        assert "oscillat" in advice[0].lower()
|
||||
437
backend/tests/services/safety/test_models.py
Normal file
437
backend/tests/services/safety/test_models.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""Tests for safety framework models."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
GuardianResult,
|
||||
PermissionLevel,
|
||||
RateLimitConfig,
|
||||
RollbackResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
|
||||
class TestActionMetadata:
    """Tests for the ActionMetadata model."""

    def test_create_with_defaults(self) -> None:
        """Only agent_id is required; everything else falls back to defaults."""
        metadata = ActionMetadata(agent_id="agent-1")

        assert metadata.agent_id == "agent-1"
        assert metadata.autonomy_level == AutonomyLevel.MILESTONE
        assert metadata.project_id is None
        assert metadata.session_id is None

    def test_create_with_all_fields(self) -> None:
        """All optional identifiers and the autonomy level are stored."""
        metadata = ActionMetadata(
            agent_id="agent-1",
            session_id="session-1",
            project_id="project-1",
            user_id="user-1",
            autonomy_level=AutonomyLevel.AUTONOMOUS,
        )

        assert metadata.project_id == "project-1"
        assert metadata.user_id == "user-1"
        assert metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
|
||||
|
||||
class TestActionRequest:
    """Tests for the ActionRequest model."""

    def test_create_basic_action(self) -> None:
        """A minimal request gets an auto id and keeps its metadata."""
        meta = ActionMetadata(agent_id="agent-1", session_id="session-1")
        request = ActionRequest(
            action_type=ActionType.FILE_READ,
            tool_name="file_read",
            metadata=meta,
        )

        assert request.action_type == ActionType.FILE_READ
        assert request.tool_name == "file_read"
        assert request.id is not None
        assert request.metadata.agent_id == "agent-1"

    def test_action_with_arguments(self) -> None:
        """Arguments and the resource path are preserved on the request."""
        target = "/tmp/test.txt"  # noqa: S108
        meta = ActionMetadata(agent_id="agent-1", session_id="session-1")
        request = ActionRequest(
            action_type=ActionType.FILE_WRITE,
            tool_name="file_write",
            arguments={"path": target, "content": "hello"},
            resource=target,
            metadata=meta,
        )

        assert request.arguments["path"] == target
        assert request.resource == target
|
||||
|
||||
|
||||
class TestActionResult:
    """Tests for the ActionResult model."""

    def test_successful_result(self) -> None:
        """Success carries result data and leaves error unset."""
        outcome = ActionResult(
            action_id="action-1",
            success=True,
            data={"output": "done"},
        )

        assert outcome.success is True
        assert outcome.data["output"] == "done"
        assert outcome.error is None

    def test_failed_result(self) -> None:
        """Failure carries an error message."""
        outcome = ActionResult(
            action_id="action-1",
            success=False,
            error="Permission denied",
        )

        assert outcome.success is False
        assert outcome.error == "Permission denied"
|
||||
|
||||
|
||||
class TestValidationRule:
    """Tests for the ValidationRule model."""

    def test_create_rule(self) -> None:
        """Explicit fields round-trip and new rules start enabled."""
        shell_rule = ValidationRule(
            name="deny_shell",
            description="Deny shell commands",
            priority=100,
            tool_patterns=["shell_*"],
            decision=SafetyDecision.DENY,
            reason="Shell commands are not allowed",
        )

        assert shell_rule.name == "deny_shell"
        assert shell_rule.priority == 100
        assert shell_rule.decision == SafetyDecision.DENY
        assert shell_rule.enabled is True

    def test_rule_defaults(self) -> None:
        """Unspecified fields fall back to their defaults."""
        minimal = ValidationRule(name="test_rule", decision=SafetyDecision.ALLOW)

        assert minimal.id is not None
        assert minimal.priority == 0
        assert minimal.enabled is True
        assert minimal.decision == SafetyDecision.ALLOW
|
||||
|
||||
|
||||
class TestValidationResult:
    """Tests for the ValidationResult model."""

    def test_allow_result(self) -> None:
        """An ALLOW result keeps its applied-rule bookkeeping."""
        allowed = ValidationResult(
            action_id="action-1",
            decision=SafetyDecision.ALLOW,
            applied_rules=["rule-1"],
            reasons=["Action is allowed"],
        )

        assert allowed.decision == SafetyDecision.ALLOW
        assert len(allowed.applied_rules) == 1

    def test_deny_result(self) -> None:
        """A DENY result records the denying rule."""
        denied = ValidationResult(
            action_id="action-1",
            decision=SafetyDecision.DENY,
            applied_rules=["deny_rule"],
            reasons=["Action is not permitted"],
        )

        assert denied.decision == SafetyDecision.DENY
|
||||
|
||||
|
||||
class TestBudgetStatus:
    """Tests for the BudgetStatus model."""

    def test_under_budget(self) -> None:
        """Token accounting fields are stored as given while under limit."""
        under = BudgetStatus(
            scope=BudgetScope.SESSION,
            scope_id="session-1",
            tokens_used=5000,
            tokens_limit=10000,
            tokens_remaining=5000,
        )

        assert under.tokens_remaining == 5000
        assert under.tokens_used == 5000

    def test_over_budget(self) -> None:
        """Cost overruns can be flagged via is_exceeded."""
        over = BudgetStatus(
            scope=BudgetScope.SESSION,
            scope_id="session-1",
            cost_used_usd=15.0,
            cost_limit_usd=10.0,
            cost_remaining_usd=0.0,
            is_exceeded=True,
        )

        assert over.is_exceeded is True
        assert over.cost_remaining_usd == 0.0
|
||||
|
||||
|
||||
class TestRateLimitConfig:
    """Tests for the RateLimitConfig model."""

    def test_create_config(self) -> None:
        """Name, limit, and window round-trip unchanged."""
        per_minute = RateLimitConfig(
            name="actions",
            limit=60,
            window_seconds=60,
        )

        assert per_minute.name == "actions"
        assert per_minute.limit == 60
        assert per_minute.window_seconds == 60
|
||||
|
||||
|
||||
class TestApprovalRequest:
    """Tests for the ApprovalRequest model."""

    def test_create_request(self) -> None:
        """A request wraps an action together with urgency and timeout."""
        meta = ActionMetadata(agent_id="agent-1", session_id="session-1")
        risky_action = ActionRequest(
            action_type=ActionType.DATABASE_MUTATE,
            tool_name="db_delete",
            metadata=meta,
        )

        request = ApprovalRequest(
            id="approval-1",
            action=risky_action,
            reason="Database mutation requires approval",
            urgency="high",
            timeout_seconds=300,
        )

        assert request.id == "approval-1"
        assert request.urgency == "high"
        assert request.timeout_seconds == 300
|
||||
|
||||
|
||||
class TestApprovalResponse:
    """Tests for the ApprovalResponse model."""

    def test_approved_response(self) -> None:
        """An approval records who decided and why."""
        approved = ApprovalResponse(
            request_id="approval-1",
            status=ApprovalStatus.APPROVED,
            decided_by="admin",
            reason="Looks safe",
        )

        assert approved.status == ApprovalStatus.APPROVED
        assert approved.decided_by == "admin"

    def test_denied_response(self) -> None:
        """A denial carries the DENIED status."""
        denied = ApprovalResponse(
            request_id="approval-1",
            status=ApprovalStatus.DENIED,
            decided_by="admin",
            reason="Too risky",
        )

        assert denied.status == ApprovalStatus.DENIED
|
||||
|
||||
|
||||
class TestCheckpoint:
    """Tests for the Checkpoint model."""

    def test_create_checkpoint(self) -> None:
        """A fresh checkpoint is valid by default."""
        target = "/tmp/test.txt"  # noqa: S108
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider datetime.now(UTC) once the model accepts aware datetimes.
        snapshot = Checkpoint(
            id="checkpoint-1",
            checkpoint_type=CheckpointType.FILE,
            action_id="action-1",
            created_at=datetime.utcnow(),
            data={"path": target},
            description="File checkpoint",
        )

        assert snapshot.id == "checkpoint-1"
        assert snapshot.is_valid is True

    def test_expired_checkpoint(self) -> None:
        """Setting expires_at does not flip is_valid by itself."""
        stale = Checkpoint(
            id="checkpoint-1",
            checkpoint_type=CheckpointType.FILE,
            action_id="action-1",
            created_at=datetime.utcnow() - timedelta(hours=2),
            expires_at=datetime.utcnow() - timedelta(hours=1),
            data={},
        )

        # is_valid is a plain flag; expiration is enforced by RollbackManager.
        assert stale.is_valid is True  # Default value
|
||||
|
||||
|
||||
class TestRollbackResult:
    """Tests for the RollbackResult model."""

    def test_successful_rollback(self) -> None:
        """A clean rollback lists what was undone and no failures."""
        outcome = RollbackResult(
            checkpoint_id="checkpoint-1",
            success=True,
            actions_rolled_back=["file:/tmp/test.txt"],
            failed_actions=[],
        )

        assert outcome.success is True
        assert len(outcome.actions_rolled_back) == 1

    def test_partial_rollback(self) -> None:
        """A partial rollback is unsuccessful and lists the failures."""
        outcome = RollbackResult(
            checkpoint_id="checkpoint-1",
            success=False,
            actions_rolled_back=["file:/tmp/a.txt"],
            failed_actions=["file:/tmp/b.txt"],
            error="Failed to rollback 1 item",
        )

        assert outcome.success is False
        assert len(outcome.failed_actions) == 1
|
||||
|
||||
|
||||
class TestAuditEvent:
    """Tests for the AuditEvent model."""

    def test_create_event(self) -> None:
        """An audit event stores its type, actors, and payload."""
        record = AuditEvent(
            id="event-1",
            event_type=AuditEventType.ACTION_EXECUTED,
            timestamp=datetime.utcnow(),
            agent_id="agent-1",
            action_id="action-1",
            data={"tool": "file_read"},
        )

        assert record.event_type == AuditEventType.ACTION_EXECUTED
        assert record.agent_id == "agent-1"
|
||||
|
||||
|
||||
class TestSafetyPolicy:
    """Tests for the SafetyPolicy model."""

    def test_default_policy(self) -> None:
        """Policies are enabled by default."""
        policy = SafetyPolicy(
            name="default",
            description="Default safety policy",
        )

        assert policy.name == "default"
        assert policy.enabled is True

    def test_restrictive_policy(self) -> None:
        """Tool denylists and approval lists are stored as given."""
        locked_down = SafetyPolicy(
            name="restrictive",
            description="Restrictive policy",
            denied_tools=["shell_*", "exec_*"],
            require_approval_for=["database_*", "git_push"],
        )

        assert len(locked_down.denied_tools) == 2
        assert len(locked_down.require_approval_for) == 2
|
||||
|
||||
|
||||
class TestGuardianResult:
    """Tests for the GuardianResult model."""

    def test_allowed_result(self) -> None:
        """An allowed result has no approval attached."""
        verdict = GuardianResult(
            action_id="action-1",
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["All checks passed"],
        )

        assert verdict.decision == SafetyDecision.ALLOW
        assert verdict.allowed is True
        assert verdict.approval_id is None

    def test_approval_required_result(self) -> None:
        """A REQUIRE_APPROVAL verdict carries the pending approval id."""
        verdict = GuardianResult(
            action_id="action-1",
            allowed=False,
            decision=SafetyDecision.REQUIRE_APPROVAL,
            reasons=["Action requires human approval"],
            approval_id="approval-123",
        )

        assert verdict.decision == SafetyDecision.REQUIRE_APPROVAL
        assert verdict.approval_id == "approval-123"
|
||||
|
||||
|
||||
class TestEnums:
    """Spot-check the string values of the safety enums."""

    def test_action_types(self) -> None:
        """ActionType members carry snake_case string values."""
        assert ActionType.FILE_READ.value == "file_read"
        assert ActionType.SHELL_COMMAND.value == "shell_command"

    def test_autonomy_levels(self) -> None:
        """AutonomyLevel members carry snake_case string values."""
        assert AutonomyLevel.FULL_CONTROL.value == "full_control"
        assert AutonomyLevel.MILESTONE.value == "milestone"
        assert AutonomyLevel.AUTONOMOUS.value == "autonomous"

    def test_permission_levels(self) -> None:
        """PermissionLevel covers none/read/write/admin."""
        assert PermissionLevel.NONE.value == "none"
        assert PermissionLevel.READ.value == "read"
        assert PermissionLevel.WRITE.value == "write"
        assert PermissionLevel.ADMIN.value == "admin"

    def test_safety_decisions(self) -> None:
        """SafetyDecision covers allow/deny/require_approval."""
        assert SafetyDecision.ALLOW.value == "allow"
        assert SafetyDecision.DENY.value == "deny"
        assert SafetyDecision.REQUIRE_APPROVAL.value == "require_approval"
|
||||
404
backend/tests/services/safety/test_validation.py
Normal file
404
backend/tests/services/safety/test_validation.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Tests for safety validation module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationRule,
|
||||
)
|
||||
from app.services.safety.validation.validator import (
|
||||
ActionValidator,
|
||||
create_allow_rule,
|
||||
create_approval_rule,
|
||||
create_deny_rule,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def validator() -> ActionValidator:
    """Provide an ActionValidator with result caching disabled.

    Disabling the cache keeps each validation independent across tests.
    """
    return ActionValidator(cache_enabled=False)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_action() -> ActionRequest:
    """A benign file-read request from a milestone-level test agent."""
    meta = ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
        autonomy_level=AutonomyLevel.MILESTONE,
    )
    return ActionRequest(
        action_type=ActionType.FILE_READ,
        tool_name="file_read",
        resource="/tmp/test.txt",  # noqa: S108
        metadata=meta,
    )
|
||||
|
||||
|
||||
class TestActionValidator:
|
||||
"""Tests for ActionValidator class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_rules_allows_by_default(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
sample_action: ActionRequest,
|
||||
) -> None:
|
||||
"""Test that actions are allowed by default with no rules."""
|
||||
result = await validator.validate(sample_action)
|
||||
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
assert "No matching rules" in result.reasons[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_rule_blocks_action(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that a deny rule blocks matching actions."""
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_shell",
|
||||
tool_patterns=["shell_*"],
|
||||
reason="Shell commands not allowed",
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.SHELL_COMMAND,
|
||||
tool_name="shell_exec",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
assert len(result.applied_rules) == 1 # One rule applied
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_rule_requires_approval(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that an approval rule requires approval."""
|
||||
validator.add_rule(
|
||||
create_approval_rule(
|
||||
name="approve_db",
|
||||
tool_patterns=["database_*"],
|
||||
reason="Database operations require approval",
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.DATABASE_MUTATE,
|
||||
tool_name="database_delete",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_takes_precedence(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that deny rules take precedence over allow rules."""
|
||||
validator.add_rule(
|
||||
create_allow_rule(
|
||||
name="allow_files",
|
||||
tool_patterns=["file_*"],
|
||||
priority=10,
|
||||
)
|
||||
)
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_delete",
|
||||
action_types=[ActionType.FILE_DELETE],
|
||||
priority=100,
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_DELETE,
|
||||
tool_name="file_delete",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rule_priority_ordering(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that rules are evaluated in priority order."""
|
||||
validator.add_rule(
|
||||
ValidationRule(
|
||||
name="low_priority",
|
||||
priority=1,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
)
|
||||
)
|
||||
validator.add_rule(
|
||||
ValidationRule(
|
||||
name="high_priority",
|
||||
priority=100,
|
||||
decision=SafetyDecision.DENY,
|
||||
)
|
||||
)
|
||||
|
||||
# High priority should be first in the list
|
||||
assert validator._rules[0].name == "high_priority"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disabled_rule_not_applied(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
sample_action: ActionRequest,
|
||||
) -> None:
|
||||
"""Test that disabled rules are not applied."""
|
||||
rule = create_deny_rule(
|
||||
name="deny_all",
|
||||
tool_patterns=["*"],
|
||||
)
|
||||
rule.enabled = False
|
||||
validator.add_rule(rule)
|
||||
|
||||
result = await validator.validate(sample_action)
|
||||
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_pattern_matching(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test resource pattern matching."""
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_secrets",
|
||||
resource_patterns=["*/secrets/*", "*.env"],
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource="/app/secrets/api_key.txt",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_id_filter(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test filtering by agent ID."""
|
||||
rule = ValidationRule(
|
||||
name="restrict_agent",
|
||||
agent_ids=["restricted-agent"],
|
||||
decision=SafetyDecision.DENY,
|
||||
reason="Restricted agent",
|
||||
)
|
||||
validator.add_rule(rule)
|
||||
|
||||
# Restricted agent should be denied
|
||||
metadata1 = ActionMetadata(agent_id="restricted-agent")
|
||||
action1 = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata1,
|
||||
)
|
||||
result1 = await validator.validate(action1)
|
||||
assert result1.decision == SafetyDecision.DENY
|
||||
|
||||
# Other agents should be allowed
|
||||
metadata2 = ActionMetadata(agent_id="normal-agent")
|
||||
action2 = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata2,
|
||||
)
|
||||
result2 = await validator.validate(action2)
|
||||
assert result2.decision == SafetyDecision.ALLOW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bypass_mode(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
sample_action: ActionRequest,
|
||||
) -> None:
|
||||
"""Test validation bypass mode."""
|
||||
validator.add_rule(create_deny_rule(name="deny_all", tool_patterns=["*"]))
|
||||
|
||||
# Should be denied normally
|
||||
result1 = await validator.validate(sample_action)
|
||||
assert result1.decision == SafetyDecision.DENY
|
||||
|
||||
# Enable bypass
|
||||
validator.enable_bypass("Emergency situation")
|
||||
result2 = await validator.validate(sample_action)
|
||||
assert result2.decision == SafetyDecision.ALLOW
|
||||
assert "bypassed" in result2.reasons[0].lower()
|
||||
|
||||
# Disable bypass
|
||||
validator.disable_bypass()
|
||||
result3 = await validator.validate(sample_action)
|
||||
assert result3.decision == SafetyDecision.DENY
|
||||
|
||||
def test_remove_rule(self, validator: ActionValidator) -> None:
|
||||
"""Test removing a rule."""
|
||||
rule = create_deny_rule(name="test_rule", tool_patterns=["test"])
|
||||
validator.add_rule(rule)
|
||||
|
||||
assert len(validator._rules) == 1
|
||||
assert validator.remove_rule(rule.id) is True
|
||||
assert len(validator._rules) == 0
|
||||
|
||||
def test_remove_nonexistent_rule(self, validator: ActionValidator) -> None:
|
||||
"""Test removing a nonexistent rule returns False."""
|
||||
assert validator.remove_rule("nonexistent") is False
|
||||
|
||||
def test_clear_rules(self, validator: ActionValidator) -> None:
|
||||
"""Test clearing all rules."""
|
||||
validator.add_rule(create_deny_rule(name="rule1", tool_patterns=["a"]))
|
||||
validator.add_rule(create_deny_rule(name="rule2", tool_patterns=["b"]))
|
||||
|
||||
assert len(validator._rules) == 2
|
||||
validator.clear_rules()
|
||||
assert len(validator._rules) == 0
|
||||
|
||||
|
||||
class TestLoadRulesFromPolicy:
    """Tests for translating SafetyPolicy entries into validator rules."""

    @pytest.mark.asyncio
    async def test_load_denied_tools(
        self,
        validator: ActionValidator,
    ) -> None:
        """Each denied-tool pattern in the policy becomes one DENY rule."""
        policy = SafetyPolicy(
            name="test",
            denied_tools=["shell_*", "exec_*"],
        )

        validator.load_rules_from_policy(policy)

        deny_count = sum(
            1 for rule in validator._rules
            if rule.decision == SafetyDecision.DENY
        )
        assert deny_count == 2

    @pytest.mark.asyncio
    async def test_load_approval_patterns(
        self,
        validator: ActionValidator,
    ) -> None:
        """Approval patterns in the policy become REQUIRE_APPROVAL rules."""
        policy = SafetyPolicy(
            name="test",
            require_approval_for=["database_*"],
        )

        validator.load_rules_from_policy(policy)

        approval_count = sum(
            1 for rule in validator._rules
            if rule.decision == SafetyDecision.REQUIRE_APPROVAL
        )
        assert approval_count == 1
|
||||
|
||||
|
||||
class TestValidationBatch:
    """Tests for validating several actions in a single call."""

    @pytest.mark.asyncio
    async def test_validate_batch(
        self,
        validator: ActionValidator,
    ) -> None:
        """Batch validation returns one result per action, in input order."""
        validator.add_rule(
            create_deny_rule(
                name="deny_shell",
                tool_patterns=["shell_*"],
            )
        )

        meta = ActionMetadata(agent_id="test-agent", session_id="session-1")
        read_action = ActionRequest(
            action_type=ActionType.FILE_READ,
            tool_name="file_read",
            metadata=meta,
        )
        shell_action = ActionRequest(
            action_type=ActionType.SHELL_COMMAND,
            tool_name="shell_exec",
            metadata=meta,
        )

        results = await validator.validate_batch([read_action, shell_action])

        assert len(results) == 2
        # The read is unmatched (allowed); the shell command hits the deny rule.
        assert results[0].decision == SafetyDecision.ALLOW
        assert results[1].decision == SafetyDecision.DENY
|
||||
|
||||
|
||||
class TestHelperFunctions:
    """Tests for the rule-factory helper functions."""

    def test_create_allow_rule(self) -> None:
        """create_allow_rule builds an ALLOW rule with the given priority."""
        rule = create_allow_rule(
            name="allow_test",
            tool_patterns=["test_*"],
            priority=50,
        )

        assert (rule.name, rule.decision, rule.priority) == (
            "allow_test",
            SafetyDecision.ALLOW,
            50,
        )

    def test_create_deny_rule(self) -> None:
        """create_deny_rule builds a DENY rule carrying the given reason."""
        rule = create_deny_rule(
            name="deny_test",
            tool_patterns=["dangerous_*"],
            reason="Too dangerous",
        )

        assert rule.name == "deny_test"
        assert rule.decision == SafetyDecision.DENY
        assert rule.reason == "Too dangerous"
        # Deny rules default to the highest standard priority.
        assert rule.priority == 100

    def test_create_approval_rule(self) -> None:
        """create_approval_rule builds a REQUIRE_APPROVAL rule."""
        rule = create_approval_rule(
            name="approve_test",
            action_types=[ActionType.DATABASE_MUTATE],
        )

        assert rule.name == "approve_test"
        assert rule.decision == SafetyDecision.REQUIRE_APPROVAL
        # Approval rules default to mid-range priority.
        assert rule.priority == 50
|
||||
Reference in New Issue
Block a user