6 Commits

Author SHA1 Message Date
Felipe Cardoso
c8b88dadc3 fix(safety): copy default patterns to avoid test pollution
The ContentFilter was appending references to DEFAULT_PATTERNS objects,
so when tests modified patterns (e.g., disabling them), those changes
persisted across test runs. Use dataclass replace() to create copies.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 12:08:43 +01:00
Felipe Cardoso
015f2de6c6 test(safety): add Phase E comprehensive safety tests
- Add tests for models: ActionMetadata, ActionRequest, ActionResult,
  ValidationRule, BudgetStatus, RateLimitConfig, ApprovalRequest/Response,
  Checkpoint, RollbackResult, AuditEvent, SafetyPolicy, GuardianResult
- Add tests for validation: ActionValidator rules, priorities, patterns,
  bypass mode, batch validation, rule creation helpers
- Add tests for loops: LoopDetector exact/semantic/oscillation detection,
  LoopBreaker throttle/backoff, history management
- Add tests for content filter: PII filtering (email, phone, SSN, credit card),
  secret blocking (API keys, GitHub tokens, private keys), custom patterns,
  scan without filtering, dict filtering
- Add tests for emergency controls: state management, pause/resume/reset,
  scoped emergency stops, callbacks, EmergencyTrigger events
- Fix exception kwargs in content filter and emergency controls to match
  exception class signatures

All 108 tests passing with lint and type checks clean.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 11:52:35 +01:00
Felipe Cardoso
f36bfb3781 feat(safety): add Phase D MCP integration and metrics
- Add MCPSafetyWrapper for safe MCP tool execution
- Add MCPToolCall/MCPToolResult models for MCP interactions
- Add SafeToolExecutor context manager
- Add SafetyMetrics collector with Prometheus export support
- Track validations, approvals, rate limits, budgets, and more
- Support for counters, gauges, and histograms

Issue #63

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 11:40:14 +01:00
Felipe Cardoso
ef659cd72d feat(safety): add Phase C advanced controls
- Add rollback manager with file checkpointing and transaction context
- Add HITL manager with approval queues and notification handlers
- Add content filter with PII, secrets, and injection detection
- Add emergency controls with stop/pause/resume capabilities
- Update SafetyConfig with checkpoint_dir setting

Issue #63

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 11:36:24 +01:00
Felipe Cardoso
728edd1453 feat(backend): add Phase B safety subsystems (#63)
Implements core control subsystems for the safety framework:

**Action Validation (validation/validator.py):**
- Rule-based validation engine with priority ordering
- Allow/deny/require-approval rule types
- Pattern matching for tools and resources
- Validation result caching with LRU eviction
- Emergency bypass capability with audit

**Permission System (permissions/manager.py):**
- Per-agent permission grants on resources
- Resource pattern matching (wildcards)
- Temporary permissions with expiration
- Permission inheritance hierarchy
- Default deny with configurable defaults

**Cost Control (costs/controller.py):**
- Per-session and per-day budget tracking
- Token and USD cost limits
- Warning alerts at configurable thresholds
- Budget rollover and reset policies
- Real-time usage tracking

**Rate Limiting (limits/limiter.py):**
- Sliding window rate limiter
- Per-action, per-LLM-call, per-file-op limits
- Burst allowance with recovery
- Configurable limits per operation type

**Loop Detection (loops/detector.py):**
- Exact repetition detection (same action+args)
- Semantic repetition (similar actions)
- Oscillation pattern detection (A→B→A→B)
- Per-agent action history tracking
- Loop breaking suggestions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 11:28:00 +01:00
Felipe Cardoso
498c0a0e94 feat(backend): add safety framework foundation (Phase A) (#63)
Core safety framework architecture for autonomous agent guardrails:

**Core Components:**
- SafetyGuardian: Main orchestrator for all safety checks
- AuditLogger: Comprehensive audit logging with hash chain tamper detection
- SafetyConfig: Pydantic-based configuration
- Models: Action requests, validation results, policies, checkpoints

**Exception Hierarchy:**
- SafetyError base with context preservation
- Permission, Budget, RateLimit, Loop errors
- Approval workflow errors (Required, Denied, Timeout)
- Rollback, Sandbox, Emergency exceptions

**Safety Policy System:**
- Autonomy level based policies (FULL_CONTROL, MILESTONE, AUTONOMOUS)
- Cost limits, rate limits, permission patterns
- HITL approval requirements per action type
- Configurable loop detection thresholds

**Directory Structure:**
- validation/, costs/, limits/, loops/ - Control subsystems
- permissions/, rollback/, hitl/ - Access and recovery
- content/, sandbox/, emergency/ - Protection systems
- audit/, policies/ - Logging and configuration

Phase A establishes the architecture. Subsystems to be implemented in Phase B-C.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 11:22:25 +01:00
37 changed files with 9299 additions and 0 deletions

View File

@@ -0,0 +1,170 @@
"""
Safety and Guardrails Framework
Comprehensive safety framework for autonomous agent operation.
Provides multi-layered protection including:
- Pre-execution validation
- Cost and budget controls
- Rate limiting
- Loop detection and prevention
- Human-in-the-loop approval
- Rollback and checkpointing
- Content filtering
- Sandboxed execution
- Emergency controls
- Complete audit trail
Usage:
from app.services.safety import get_safety_guardian, SafetyGuardian
guardian = await get_safety_guardian()
result = await guardian.validate(action_request)
if result.allowed:
# Execute action
pass
else:
# Handle denial
print(f"Action denied: {result.reasons}")
"""
# Exceptions
# Audit
from .audit import (
AuditLogger,
get_audit_logger,
reset_audit_logger,
shutdown_audit_logger,
)
# Configuration
from .config import (
AutonomyConfig,
SafetyConfig,
get_autonomy_config,
get_default_policy,
get_policy_for_autonomy_level,
get_safety_config,
load_policies_from_directory,
load_policy_from_file,
reset_config_cache,
)
from .exceptions import (
ApprovalDeniedError,
ApprovalRequiredError,
ApprovalTimeoutError,
BudgetExceededError,
CheckpointError,
ContentFilterError,
EmergencyStopError,
LoopDetectedError,
PermissionDeniedError,
PolicyViolationError,
RateLimitExceededError,
RollbackError,
SafetyError,
SandboxError,
SandboxTimeoutError,
ValidationError,
)
# Guardian
from .guardian import (
SafetyGuardian,
get_safety_guardian,
reset_safety_guardian,
shutdown_safety_guardian,
)
# Models
from .models import (
ActionMetadata,
ActionRequest,
ActionResult,
ActionType,
ApprovalRequest,
ApprovalResponse,
ApprovalStatus,
AuditEvent,
AuditEventType,
AutonomyLevel,
BudgetScope,
BudgetStatus,
Checkpoint,
CheckpointType,
GuardianResult,
PermissionLevel,
RateLimitConfig,
RateLimitStatus,
ResourceType,
RollbackResult,
SafetyDecision,
SafetyPolicy,
ValidationResult,
ValidationRule,
)
__all__ = [
"ActionMetadata",
"ActionRequest",
"ActionResult",
# Models
"ActionType",
"ApprovalDeniedError",
"ApprovalRequest",
"ApprovalRequiredError",
"ApprovalResponse",
"ApprovalStatus",
"ApprovalTimeoutError",
"AuditEvent",
"AuditEventType",
# Audit
"AuditLogger",
"AutonomyConfig",
"AutonomyLevel",
"BudgetExceededError",
"BudgetScope",
"BudgetStatus",
"Checkpoint",
"CheckpointError",
"CheckpointType",
"ContentFilterError",
"EmergencyStopError",
"GuardianResult",
"LoopDetectedError",
"PermissionDeniedError",
"PermissionLevel",
"PolicyViolationError",
"RateLimitConfig",
"RateLimitExceededError",
"RateLimitStatus",
"ResourceType",
"RollbackError",
"RollbackResult",
# Configuration
"SafetyConfig",
"SafetyDecision",
# Exceptions
"SafetyError",
# Guardian
"SafetyGuardian",
"SafetyPolicy",
"SandboxError",
"SandboxTimeoutError",
"ValidationError",
"ValidationResult",
"ValidationRule",
"get_audit_logger",
"get_autonomy_config",
"get_default_policy",
"get_policy_for_autonomy_level",
"get_safety_config",
"get_safety_guardian",
"load_policies_from_directory",
"load_policy_from_file",
"reset_audit_logger",
"reset_config_cache",
"reset_safety_guardian",
"shutdown_audit_logger",
"shutdown_safety_guardian",
]

View File

@@ -0,0 +1,19 @@
"""
Audit System
Comprehensive audit logging for all safety-related events.
"""
from .logger import (
AuditLogger,
get_audit_logger,
reset_audit_logger,
shutdown_audit_logger,
)
__all__ = [
"AuditLogger",
"get_audit_logger",
"reset_audit_logger",
"shutdown_audit_logger",
]

View File

@@ -0,0 +1,585 @@
"""
Audit Logger
Comprehensive audit logging for all safety-related events.
Provides tamper detection, structured logging, and compliance support.
"""
import asyncio
import hashlib
import json
import logging
from collections import deque
from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..models import (
ActionRequest,
AuditEvent,
AuditEventType,
SafetyDecision,
)
logger = logging.getLogger(__name__)
class AuditLogger:
"""
Audit logger for safety events.
Features:
- Structured event logging
- In-memory buffer with async flush
- Tamper detection via hash chains
- Query/search capability
- Retention policy enforcement
"""
def __init__(
self,
max_buffer_size: int = 1000,
flush_interval_seconds: float = 10.0,
enable_hash_chain: bool = True,
) -> None:
"""
Initialize the audit logger.
Args:
max_buffer_size: Maximum events to buffer before auto-flush
flush_interval_seconds: Interval for periodic flush
enable_hash_chain: Enable tamper detection via hash chain
"""
self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
self._persisted: list[AuditEvent] = []
self._flush_interval = flush_interval_seconds
self._enable_hash_chain = enable_hash_chain
self._last_hash: str | None = None
self._lock = asyncio.Lock()
self._flush_task: asyncio.Task[None] | None = None
self._running = False
# Event handlers for real-time processing
self._handlers: list[Any] = []
config = get_safety_config()
self._retention_days = config.audit_retention_days
self._include_sensitive = config.audit_include_sensitive
async def start(self) -> None:
"""Start the audit logger background tasks."""
if self._running:
return
self._running = True
self._flush_task = asyncio.create_task(self._periodic_flush())
logger.info("Audit logger started")
async def stop(self) -> None:
"""Stop the audit logger and flush remaining events."""
self._running = False
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
# Final flush
await self.flush()
logger.info("Audit logger stopped")
async def log(
self,
event_type: AuditEventType,
*,
agent_id: str | None = None,
action_id: str | None = None,
project_id: str | None = None,
session_id: str | None = None,
user_id: str | None = None,
decision: SafetyDecision | None = None,
details: dict[str, Any] | None = None,
correlation_id: str | None = None,
) -> AuditEvent:
"""
Log an audit event.
Args:
event_type: Type of audit event
agent_id: Agent ID if applicable
action_id: Action ID if applicable
project_id: Project ID if applicable
session_id: Session ID if applicable
user_id: User ID if applicable
decision: Safety decision if applicable
details: Additional event details
correlation_id: Correlation ID for tracing
Returns:
The created audit event
"""
# Sanitize sensitive data if needed
sanitized_details = self._sanitize_details(details) if details else {}
event = AuditEvent(
id=str(uuid4()),
event_type=event_type,
timestamp=datetime.utcnow(),
agent_id=agent_id,
action_id=action_id,
project_id=project_id,
session_id=session_id,
user_id=user_id,
decision=decision,
details=sanitized_details,
correlation_id=correlation_id,
)
async with self._lock:
# Add hash chain for tamper detection
if self._enable_hash_chain:
event_hash = self._compute_hash(event)
sanitized_details["_hash"] = event_hash
sanitized_details["_prev_hash"] = self._last_hash
self._last_hash = event_hash
self._buffer.append(event)
# Notify handlers
await self._notify_handlers(event)
# Log to standard logger as well
self._log_to_logger(event)
return event
async def log_action_request(
self,
action: ActionRequest,
decision: SafetyDecision,
reasons: list[str] | None = None,
) -> AuditEvent:
"""Log an action request with its validation decision."""
event_type = (
AuditEventType.ACTION_DENIED
if decision == SafetyDecision.DENY
else AuditEventType.ACTION_VALIDATED
)
return await self.log(
event_type,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
user_id=action.metadata.user_id,
decision=decision,
details={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"resource": action.resource,
"is_destructive": action.is_destructive,
"reasons": reasons or [],
},
correlation_id=action.metadata.correlation_id,
)
async def log_action_executed(
self,
action: ActionRequest,
success: bool,
execution_time_ms: float,
error: str | None = None,
) -> AuditEvent:
"""Log an action execution result."""
event_type = (
AuditEventType.ACTION_EXECUTED
if success
else AuditEventType.ACTION_FAILED
)
return await self.log(
event_type,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
details={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"success": success,
"execution_time_ms": execution_time_ms,
"error": error,
},
correlation_id=action.metadata.correlation_id,
)
async def log_approval_event(
self,
event_type: AuditEventType,
approval_id: str,
action: ActionRequest,
decided_by: str | None = None,
reason: str | None = None,
) -> AuditEvent:
"""Log an approval-related event."""
return await self.log(
event_type,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
user_id=decided_by,
details={
"approval_id": approval_id,
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"decided_by": decided_by,
"reason": reason,
},
correlation_id=action.metadata.correlation_id,
)
async def log_budget_event(
self,
event_type: AuditEventType,
agent_id: str,
scope: str,
current_usage: float,
limit: float,
unit: str = "tokens",
) -> AuditEvent:
"""Log a budget-related event."""
return await self.log(
event_type,
agent_id=agent_id,
details={
"scope": scope,
"current_usage": current_usage,
"limit": limit,
"unit": unit,
"usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
},
)
async def log_emergency_stop(
self,
stop_type: str,
triggered_by: str,
reason: str,
affected_agents: list[str] | None = None,
) -> AuditEvent:
"""Log an emergency stop event."""
return await self.log(
AuditEventType.EMERGENCY_STOP,
user_id=triggered_by,
details={
"stop_type": stop_type,
"triggered_by": triggered_by,
"reason": reason,
"affected_agents": affected_agents or [],
},
)
async def flush(self) -> int:
"""
Flush buffered events to persistent storage.
Returns:
Number of events flushed
"""
async with self._lock:
if not self._buffer:
return 0
events = list(self._buffer)
self._buffer.clear()
# Persist events (in production, this would go to database/storage)
self._persisted.extend(events)
# Enforce retention
self._enforce_retention()
logger.debug("Flushed %d audit events", len(events))
return len(events)
async def query(
self,
*,
event_types: list[AuditEventType] | None = None,
agent_id: str | None = None,
action_id: str | None = None,
project_id: str | None = None,
session_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
correlation_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[AuditEvent]:
"""
Query audit events with filters.
Args:
event_types: Filter by event types
agent_id: Filter by agent ID
action_id: Filter by action ID
project_id: Filter by project ID
session_id: Filter by session ID
user_id: Filter by user ID
start_time: Filter events after this time
end_time: Filter events before this time
correlation_id: Filter by correlation ID
limit: Maximum results to return
offset: Result offset for pagination
Returns:
List of matching audit events
"""
# Combine buffer and persisted for query
all_events = list(self._persisted) + list(self._buffer)
results = []
for event in all_events:
if event_types and event.event_type not in event_types:
continue
if agent_id and event.agent_id != agent_id:
continue
if action_id and event.action_id != action_id:
continue
if project_id and event.project_id != project_id:
continue
if session_id and event.session_id != session_id:
continue
if user_id and event.user_id != user_id:
continue
if start_time and event.timestamp < start_time:
continue
if end_time and event.timestamp > end_time:
continue
if correlation_id and event.correlation_id != correlation_id:
continue
results.append(event)
# Sort by timestamp descending
results.sort(key=lambda e: e.timestamp, reverse=True)
# Apply pagination
return results[offset : offset + limit]
async def get_action_history(
self,
agent_id: str,
limit: int = 100,
) -> list[AuditEvent]:
"""Get action history for an agent."""
return await self.query(
agent_id=agent_id,
event_types=[
AuditEventType.ACTION_REQUESTED,
AuditEventType.ACTION_VALIDATED,
AuditEventType.ACTION_DENIED,
AuditEventType.ACTION_EXECUTED,
AuditEventType.ACTION_FAILED,
],
limit=limit,
)
async def verify_integrity(self) -> tuple[bool, list[str]]:
"""
Verify audit log integrity using hash chain.
Returns:
Tuple of (is_valid, list of issues found)
"""
if not self._enable_hash_chain:
return True, []
issues: list[str] = []
all_events = list(self._persisted) + list(self._buffer)
prev_hash: str | None = None
for event in sorted(all_events, key=lambda e: e.timestamp):
stored_prev = event.details.get("_prev_hash")
stored_hash = event.details.get("_hash")
if stored_prev != prev_hash:
issues.append(
f"Hash chain broken at event {event.id}: "
f"expected prev_hash={prev_hash}, got {stored_prev}"
)
if stored_hash:
computed = self._compute_hash(event)
if computed != stored_hash:
issues.append(
f"Hash mismatch at event {event.id}: "
f"expected {computed}, got {stored_hash}"
)
prev_hash = stored_hash
return len(issues) == 0, issues
def add_handler(self, handler: Any) -> None:
"""Add a real-time event handler."""
self._handlers.append(handler)
def remove_handler(self, handler: Any) -> None:
"""Remove an event handler."""
if handler in self._handlers:
self._handlers.remove(handler)
def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
"""Sanitize sensitive data from details."""
if self._include_sensitive:
return details
sanitized: dict[str, Any] = {}
sensitive_keys = {
"password",
"secret",
"token",
"api_key",
"apikey",
"auth",
"credential",
}
for key, value in details.items():
lower_key = key.lower()
if any(s in lower_key for s in sensitive_keys):
sanitized[key] = "[REDACTED]"
elif isinstance(value, dict):
sanitized[key] = self._sanitize_details(value)
else:
sanitized[key] = value
return sanitized
def _compute_hash(self, event: AuditEvent) -> str:
"""Compute hash for an event (excluding hash fields)."""
data = {
"id": event.id,
"event_type": event.event_type.value,
"timestamp": event.timestamp.isoformat(),
"agent_id": event.agent_id,
"action_id": event.action_id,
"project_id": event.project_id,
"session_id": event.session_id,
"user_id": event.user_id,
"decision": event.decision.value if event.decision else None,
"details": {
k: v
for k, v in event.details.items()
if not k.startswith("_")
},
"correlation_id": event.correlation_id,
}
if self._last_hash:
data["_prev_hash"] = self._last_hash
serialized = json.dumps(data, sort_keys=True, default=str)
return hashlib.sha256(serialized.encode()).hexdigest()
def _log_to_logger(self, event: AuditEvent) -> None:
"""Log event to standard Python logger."""
log_data = {
"audit_event": event.event_type.value,
"event_id": event.id,
"agent_id": event.agent_id,
"action_id": event.action_id,
"decision": event.decision.value if event.decision else None,
}
# Use appropriate log level based on event type
if event.event_type in {
AuditEventType.ACTION_DENIED,
AuditEventType.POLICY_VIOLATION,
AuditEventType.EMERGENCY_STOP,
}:
logger.warning("Audit: %s", log_data)
elif event.event_type in {
AuditEventType.ACTION_FAILED,
AuditEventType.ROLLBACK_FAILED,
}:
logger.error("Audit: %s", log_data)
else:
logger.info("Audit: %s", log_data)
def _enforce_retention(self) -> None:
"""Enforce retention policy on persisted events."""
if not self._retention_days:
return
cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
before_count = len(self._persisted)
self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]
removed = before_count - len(self._persisted)
if removed > 0:
logger.info("Removed %d expired audit events", removed)
async def _periodic_flush(self) -> None:
"""Background task for periodic flushing."""
while self._running:
try:
await asyncio.sleep(self._flush_interval)
await self.flush()
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in periodic audit flush: %s", e)
async def _notify_handlers(self, event: AuditEvent) -> None:
"""Notify all registered handlers of a new event."""
for handler in self._handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
logger.error("Error in audit event handler: %s", e)
# Singleton instance
_audit_logger: AuditLogger | None = None
_audit_lock = asyncio.Lock()
async def get_audit_logger() -> AuditLogger:
"""Get the global audit logger instance."""
global _audit_logger
async with _audit_lock:
if _audit_logger is None:
_audit_logger = AuditLogger()
await _audit_logger.start()
return _audit_logger
async def shutdown_audit_logger() -> None:
"""Shutdown the global audit logger."""
global _audit_logger
async with _audit_lock:
if _audit_logger is not None:
await _audit_logger.stop()
_audit_logger = None
def reset_audit_logger() -> None:
"""Reset the audit logger (for testing)."""
global _audit_logger
_audit_logger = None

View File

@@ -0,0 +1,304 @@
"""
Safety Framework Configuration
Pydantic settings for the safety and guardrails framework.
"""
import logging
import os
from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from .models import AutonomyLevel, SafetyPolicy
logger = logging.getLogger(__name__)
class SafetyConfig(BaseSettings):
"""Configuration for the safety framework."""
model_config = SettingsConfigDict(
env_prefix="SAFETY_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# General settings
enabled: bool = Field(True, description="Enable safety framework")
strict_mode: bool = Field(
True, description="Strict mode (fail closed on errors)"
)
log_level: str = Field("INFO", description="Logging level")
# Default autonomy level
default_autonomy_level: AutonomyLevel = Field(
AutonomyLevel.MILESTONE,
description="Default autonomy level for new agents",
)
# Default budget limits
default_session_token_budget: int = Field(
100_000, description="Default tokens per session"
)
default_daily_token_budget: int = Field(
1_000_000, description="Default tokens per day"
)
default_session_cost_limit: float = Field(
10.0, description="Default USD per session"
)
default_daily_cost_limit: float = Field(100.0, description="Default USD per day")
# Default rate limits
default_actions_per_minute: int = Field(60, description="Default actions per min")
default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min")
default_file_ops_per_minute: int = Field(100, description="Default file ops/min")
# Loop detection
loop_detection_enabled: bool = Field(True, description="Enable loop detection")
max_repeated_actions: int = Field(5, description="Max exact repetitions")
max_similar_actions: int = Field(10, description="Max similar actions")
loop_history_size: int = Field(100, description="Action history size for loops")
# HITL settings
hitl_enabled: bool = Field(True, description="Enable human-in-the-loop")
hitl_default_timeout: int = Field(300, description="Default approval timeout (s)")
hitl_notification_channels: list[str] = Field(
default_factory=list, description="Notification channels"
)
# Rollback settings
rollback_enabled: bool = Field(True, description="Enable rollback capability")
checkpoint_dir: str = Field(
"/tmp/syndarix_checkpoints", # noqa: S108
description="Directory for checkpoint storage",
)
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
auto_checkpoint_destructive: bool = Field(
True, description="Auto-checkpoint destructive actions"
)
# Sandbox settings
sandbox_enabled: bool = Field(False, description="Enable sandbox execution")
sandbox_timeout: int = Field(300, description="Sandbox timeout (s)")
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)")
sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit")
sandbox_network_enabled: bool = Field(False, description="Allow sandbox network")
# Audit settings
audit_enabled: bool = Field(True, description="Enable audit logging")
audit_retention_days: int = Field(90, description="Audit log retention (days)")
audit_include_sensitive: bool = Field(
False, description="Include sensitive data in audit"
)
# Content filtering
content_filter_enabled: bool = Field(True, description="Enable content filtering")
filter_pii: bool = Field(True, description="Filter PII")
filter_secrets: bool = Field(True, description="Filter secrets")
# Emergency controls
emergency_stop_enabled: bool = Field(True, description="Enable emergency stop")
emergency_webhook_url: str | None = Field(None, description="Emergency webhook")
# Policy file path
policy_file: str | None = Field(None, description="Path to policy YAML file")
# Validation cache
validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)")
validation_cache_size: int = Field(1000, description="Validation cache size")
class AutonomyConfig(BaseSettings):
"""Configuration for autonomy levels."""
model_config = SettingsConfigDict(
env_prefix="AUTONOMY_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# FULL_CONTROL settings
full_control_cost_limit: float = Field(1.0, description="USD limit per session")
full_control_require_all_approval: bool = Field(
True, description="Require approval for all"
)
full_control_block_destructive: bool = Field(
True, description="Block destructive actions"
)
# MILESTONE settings
milestone_cost_limit: float = Field(10.0, description="USD limit per session")
milestone_require_critical_approval: bool = Field(
True, description="Require approval for critical"
)
milestone_auto_checkpoint: bool = Field(
True, description="Auto-checkpoint destructive"
)
# AUTONOMOUS settings
autonomous_cost_limit: float = Field(100.0, description="USD limit per session")
autonomous_auto_approve_normal: bool = Field(
True, description="Auto-approve normal actions"
)
autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all")
def _expand_env_vars(value: Any) -> Any:
"""Recursively expand environment variables in values."""
if isinstance(value, str):
return os.path.expandvars(value)
elif isinstance(value, dict):
return {k: _expand_env_vars(v) for k, v in value.items()}
elif isinstance(value, list):
return [_expand_env_vars(v) for v in value]
return value
def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None:
"""Load a safety policy from a YAML file."""
path = Path(file_path)
if not path.exists():
logger.warning("Policy file not found: %s", path)
return None
try:
with open(path) as f:
data = yaml.safe_load(f)
if data is None:
logger.warning("Empty policy file: %s", path)
return None
# Expand environment variables
data = _expand_env_vars(data)
return SafetyPolicy(**data)
except Exception as e:
logger.error("Failed to load policy file %s: %s", path, e)
return None
def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]:
"""Load all safety policies from a directory."""
policies: dict[str, SafetyPolicy] = {}
path = Path(directory)
if not path.exists() or not path.is_dir():
logger.warning("Policy directory not found: %s", path)
return policies
for file_path in path.glob("*.yaml"):
policy = load_policy_from_file(file_path)
if policy:
policies[policy.name] = policy
logger.info("Loaded policy: %s from %s", policy.name, file_path.name)
return policies
@lru_cache(maxsize=1)
def get_safety_config() -> SafetyConfig:
"""Get the safety configuration (cached singleton)."""
return SafetyConfig()
@lru_cache(maxsize=1)
def get_autonomy_config() -> AutonomyConfig:
"""Get the autonomy configuration (cached singleton)."""
return AutonomyConfig()
def get_default_policy() -> SafetyPolicy:
"""Get the default safety policy."""
config = get_safety_config()
return SafetyPolicy(
name="default",
description="Default safety policy",
max_tokens_per_session=config.default_session_token_budget,
max_tokens_per_day=config.default_daily_token_budget,
max_cost_per_session_usd=config.default_session_cost_limit,
max_cost_per_day_usd=config.default_daily_cost_limit,
max_actions_per_minute=config.default_actions_per_minute,
max_llm_calls_per_minute=config.default_llm_calls_per_minute,
max_file_operations_per_minute=config.default_file_ops_per_minute,
max_repeated_actions=config.max_repeated_actions,
max_similar_actions=config.max_similar_actions,
require_sandbox=config.sandbox_enabled,
sandbox_timeout_seconds=config.sandbox_timeout,
sandbox_memory_mb=config.sandbox_memory_mb,
)
def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
"""Get the safety policy for a given autonomy level."""
autonomy = get_autonomy_config()
base_policy = get_default_policy()
if level == AutonomyLevel.FULL_CONTROL:
return SafetyPolicy(
name="full_control",
description="Full control mode - all actions require approval",
max_cost_per_session_usd=autonomy.full_control_cost_limit,
max_cost_per_day_usd=autonomy.full_control_cost_limit * 10,
require_approval_for=["*"], # All actions
max_tokens_per_session=base_policy.max_tokens_per_session // 10,
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2,
denied_tools=["delete_*", "destroy_*", "drop_*"],
)
elif level == AutonomyLevel.MILESTONE:
return SafetyPolicy(
name="milestone",
description="Milestone mode - approval at milestones only",
max_cost_per_session_usd=autonomy.milestone_cost_limit,
max_cost_per_day_usd=autonomy.milestone_cost_limit * 10,
require_approval_for=[
"delete_file",
"push_to_remote",
"deploy_*",
"modify_critical_*",
"create_pull_request",
],
max_tokens_per_session=base_policy.max_tokens_per_session,
max_tokens_per_day=base_policy.max_tokens_per_day,
max_actions_per_minute=base_policy.max_actions_per_minute,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute,
)
else: # AUTONOMOUS
return SafetyPolicy(
name="autonomous",
description="Autonomous mode - minimal intervention",
max_cost_per_session_usd=autonomy.autonomous_cost_limit,
max_cost_per_day_usd=autonomy.autonomous_cost_limit * 10,
require_approval_for=[
"deploy_to_production",
"delete_repository",
"modify_production_config",
],
max_tokens_per_session=base_policy.max_tokens_per_session * 5,
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2,
)
def reset_config_cache() -> None:
"""Reset configuration caches (for testing)."""
get_safety_config.cache_clear()
get_autonomy_config.cache_clear()

View File

@@ -0,0 +1,23 @@
"""Content filtering for safety."""
from .filter import (
ContentCategory,
ContentFilter,
FilterAction,
FilterMatch,
FilterPattern,
FilterResult,
filter_content,
scan_for_secrets,
)
__all__ = [
"ContentCategory",
"ContentFilter",
"FilterAction",
"FilterMatch",
"FilterPattern",
"FilterResult",
"filter_content",
"scan_for_secrets",
]

View File

@@ -0,0 +1,533 @@
"""
Content Filter
Filters and sanitizes content for safety, including PII detection and secret scanning.
"""
import asyncio
import logging
import re
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import Any, ClassVar
from ..exceptions import ContentFilterError
logger = logging.getLogger(__name__)
class ContentCategory(str, Enum):
"""Categories of sensitive content."""
PII = "pii"
SECRETS = "secrets"
CREDENTIALS = "credentials"
FINANCIAL = "financial"
HEALTH = "health"
PROFANITY = "profanity"
INJECTION = "injection"
CUSTOM = "custom"
class FilterAction(str, Enum):
"""Actions to take on detected content."""
ALLOW = "allow"
REDACT = "redact"
BLOCK = "block"
WARN = "warn"
@dataclass
class FilterMatch:
"""A match found by a filter."""
category: ContentCategory
pattern_name: str
matched_text: str
start_pos: int
end_pos: int
confidence: float = 1.0
redacted_text: str | None = None
@dataclass
class FilterResult:
"""Result of content filtering."""
original_content: str
filtered_content: str
matches: list[FilterMatch] = field(default_factory=list)
blocked: bool = False
block_reason: str | None = None
warnings: list[str] = field(default_factory=list)
@property
def has_sensitive_content(self) -> bool:
"""Check if any sensitive content was found."""
return len(self.matches) > 0
@dataclass
class FilterPattern:
"""A pattern for detecting sensitive content."""
name: str
category: ContentCategory
pattern: str # Regex pattern
action: FilterAction = FilterAction.REDACT
replacement: str = "[REDACTED]"
confidence: float = 1.0
enabled: bool = True
def __post_init__(self) -> None:
"""Compile the regex pattern."""
self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE)
def find_matches(self, content: str) -> list[FilterMatch]:
"""Find all matches in content."""
matches = []
for match in self._compiled.finditer(content):
matches.append(
FilterMatch(
category=self.category,
pattern_name=self.name,
matched_text=match.group(),
start_pos=match.start(),
end_pos=match.end(),
confidence=self.confidence,
redacted_text=self.replacement,
)
)
return matches
class ContentFilter:
"""
Filters content for sensitive information.
Features:
- PII detection (emails, phones, SSN, etc.)
- Secret scanning (API keys, tokens, passwords)
- Credential detection
- Injection attack prevention
- Custom pattern support
- Configurable actions (allow, redact, block, warn)
"""
# Default patterns for common sensitive data
DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
# PII Patterns
FilterPattern(
name="email",
category=ContentCategory.PII,
pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
action=FilterAction.REDACT,
replacement="[EMAIL]",
),
FilterPattern(
name="phone_us",
category=ContentCategory.PII,
pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
action=FilterAction.REDACT,
replacement="[PHONE]",
),
FilterPattern(
name="ssn",
category=ContentCategory.PII,
pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
action=FilterAction.REDACT,
replacement="[SSN]",
),
FilterPattern(
name="credit_card",
category=ContentCategory.FINANCIAL,
pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
action=FilterAction.REDACT,
replacement="[CREDIT_CARD]",
),
FilterPattern(
name="ip_address",
category=ContentCategory.PII,
pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
action=FilterAction.WARN,
replacement="[IP]",
confidence=0.8,
),
# Secret Patterns
FilterPattern(
name="api_key_generic",
category=ContentCategory.SECRETS,
pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
action=FilterAction.BLOCK,
replacement="[API_KEY]",
),
FilterPattern(
name="aws_access_key",
category=ContentCategory.SECRETS,
pattern=r"\bAKIA[0-9A-Z]{16}\b",
action=FilterAction.BLOCK,
replacement="[AWS_KEY]",
),
FilterPattern(
name="aws_secret_key",
category=ContentCategory.SECRETS,
pattern=r"\b[A-Za-z0-9/+=]{40}\b",
action=FilterAction.WARN,
replacement="[AWS_SECRET]",
confidence=0.6, # Lower confidence - might be false positive
),
FilterPattern(
name="github_token",
category=ContentCategory.SECRETS,
pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
action=FilterAction.BLOCK,
replacement="[GITHUB_TOKEN]",
),
FilterPattern(
name="jwt_token",
category=ContentCategory.SECRETS,
pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
action=FilterAction.BLOCK,
replacement="[JWT]",
),
# Credential Patterns
FilterPattern(
name="password_in_url",
category=ContentCategory.CREDENTIALS,
pattern=r"://[^:]+:([^@]+)@",
action=FilterAction.BLOCK,
replacement="://[REDACTED]@",
),
FilterPattern(
name="password_assignment",
category=ContentCategory.CREDENTIALS,
pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
action=FilterAction.REDACT,
replacement="[PASSWORD]",
),
FilterPattern(
name="private_key",
category=ContentCategory.SECRETS,
pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
action=FilterAction.BLOCK,
replacement="[PRIVATE_KEY]",
),
# Injection Patterns
FilterPattern(
name="sql_injection",
category=ContentCategory.INJECTION,
pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
action=FilterAction.BLOCK,
replacement="[BLOCKED]",
),
FilterPattern(
name="command_injection",
category=ContentCategory.INJECTION,
pattern=r"[;&|`$]|\$\(|\$\{",
action=FilterAction.WARN,
replacement="[CMD]",
confidence=0.5, # Low confidence - common in code
),
]
def __init__(
self,
enable_pii_filter: bool = True,
enable_secret_filter: bool = True,
enable_injection_filter: bool = True,
custom_patterns: list[FilterPattern] | None = None,
default_action: FilterAction = FilterAction.REDACT,
) -> None:
"""
Initialize the ContentFilter.
Args:
enable_pii_filter: Enable PII detection
enable_secret_filter: Enable secret scanning
enable_injection_filter: Enable injection detection
custom_patterns: Additional custom patterns
default_action: Default action for matches
"""
self._patterns: list[FilterPattern] = []
self._default_action = default_action
self._lock = asyncio.Lock()
# Load default patterns based on configuration
# Use replace() to create a copy of each pattern to avoid mutating shared defaults
for pattern in self.DEFAULT_PATTERNS:
if pattern.category == ContentCategory.PII and not enable_pii_filter:
continue
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
continue
if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
continue
if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
continue
self._patterns.append(replace(pattern))
# Add custom patterns
if custom_patterns:
self._patterns.extend(custom_patterns)
logger.info("ContentFilter initialized with %d patterns", len(self._patterns))
def add_pattern(self, pattern: FilterPattern) -> None:
"""Add a custom pattern."""
self._patterns.append(pattern)
logger.debug("Added pattern: %s", pattern.name)
def remove_pattern(self, pattern_name: str) -> bool:
"""Remove a pattern by name."""
for i, pattern in enumerate(self._patterns):
if pattern.name == pattern_name:
del self._patterns[i]
logger.debug("Removed pattern: %s", pattern_name)
return True
return False
def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
"""Enable or disable a pattern."""
for pattern in self._patterns:
if pattern.name == pattern_name:
pattern.enabled = enabled
return True
return False
async def filter(
self,
content: str,
context: dict[str, Any] | None = None,
raise_on_block: bool = False,
) -> FilterResult:
"""
Filter content for sensitive information.
Args:
content: Content to filter
context: Optional context for filtering decisions
raise_on_block: Raise exception if content is blocked
Returns:
FilterResult with filtered content and match details
Raises:
ContentFilterError: If content is blocked and raise_on_block=True
"""
all_matches: list[FilterMatch] = []
blocked = False
block_reason: str | None = None
warnings: list[str] = []
# Find all matches
for pattern in self._patterns:
if not pattern.enabled:
continue
matches = pattern.find_matches(content)
for match in matches:
all_matches.append(match)
if pattern.action == FilterAction.BLOCK:
blocked = True
block_reason = f"Blocked by pattern: {pattern.name}"
elif pattern.action == FilterAction.WARN:
warnings.append(
f"Warning: {pattern.name} detected at position {match.start_pos}"
)
# Sort matches by position (reverse for replacement)
all_matches.sort(key=lambda m: m.start_pos, reverse=True)
# Apply redactions
filtered_content = content
for match in all_matches:
matched_pattern = self._get_pattern(match.pattern_name)
if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
filtered_content = (
filtered_content[: match.start_pos]
+ (match.redacted_text or "[REDACTED]")
+ filtered_content[match.end_pos :]
)
# Re-sort for result
all_matches.sort(key=lambda m: m.start_pos)
result = FilterResult(
original_content=content,
filtered_content=filtered_content if not blocked else "",
matches=all_matches,
blocked=blocked,
block_reason=block_reason,
warnings=warnings,
)
if blocked:
logger.warning(
"Content blocked: %s (%d matches)",
block_reason,
len(all_matches),
)
if raise_on_block:
raise ContentFilterError(
block_reason or "Content blocked",
filter_type=all_matches[0].category.value if all_matches else "unknown",
detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
)
elif all_matches:
logger.debug(
"Content filtered: %d matches, %d warnings",
len(all_matches),
len(warnings),
)
return result
async def filter_dict(
self,
data: dict[str, Any],
keys_to_filter: list[str] | None = None,
recursive: bool = True,
) -> dict[str, Any]:
"""
Filter string values in a dictionary.
Args:
data: Dictionary to filter
keys_to_filter: Specific keys to filter (None = all)
recursive: Filter nested dictionaries
Returns:
Filtered dictionary
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, str):
if keys_to_filter is None or key in keys_to_filter:
filter_result = await self.filter(value)
result[key] = filter_result.filtered_content
else:
result[key] = value
elif isinstance(value, dict) and recursive:
result[key] = await self.filter_dict(value, keys_to_filter, recursive)
elif isinstance(value, list):
result[key] = [
(await self.filter(item)).filtered_content
if isinstance(item, str)
else item
for item in value
]
else:
result[key] = value
return result
async def scan(
self,
content: str,
categories: list[ContentCategory] | None = None,
) -> list[FilterMatch]:
"""
Scan content without filtering (detection only).
Args:
content: Content to scan
categories: Limit to specific categories
Returns:
List of matches found
"""
all_matches: list[FilterMatch] = []
for pattern in self._patterns:
if not pattern.enabled:
continue
if categories and pattern.category not in categories:
continue
matches = pattern.find_matches(content)
all_matches.extend(matches)
all_matches.sort(key=lambda m: m.start_pos)
return all_matches
async def validate_safe(
self,
content: str,
categories: list[ContentCategory] | None = None,
allow_warnings: bool = True,
) -> tuple[bool, list[str]]:
"""
Validate that content is safe (no blocked patterns).
Args:
content: Content to validate
categories: Limit to specific categories
allow_warnings: Allow content with warnings
Returns:
Tuple of (is_safe, list of issues)
"""
issues: list[str] = []
for pattern in self._patterns:
if not pattern.enabled:
continue
if categories and pattern.category not in categories:
continue
matches = pattern.find_matches(content)
for match in matches:
if pattern.action == FilterAction.BLOCK:
issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
elif pattern.action == FilterAction.WARN and not allow_warnings:
issues.append(f"Warning: {pattern.name} at position {match.start_pos}")
return len(issues) == 0, issues
def _get_pattern(self, name: str) -> FilterPattern | None:
"""Get a pattern by name."""
for pattern in self._patterns:
if pattern.name == name:
return pattern
return None
def get_pattern_stats(self) -> dict[str, Any]:
"""Get statistics about configured patterns."""
by_category: dict[str, int] = {}
by_action: dict[str, int] = {}
for pattern in self._patterns:
cat = pattern.category.value
by_category[cat] = by_category.get(cat, 0) + 1
act = pattern.action.value
by_action[act] = by_action.get(act, 0) + 1
return {
"total_patterns": len(self._patterns),
"enabled_patterns": sum(1 for p in self._patterns if p.enabled),
"by_category": by_category,
"by_action": by_action,
}
# Convenience function for quick filtering
async def filter_content(content: str) -> str:
"""Quick filter content with default settings."""
filter_instance = ContentFilter()
result = await filter_instance.filter(content)
return result.filtered_content
async def scan_for_secrets(content: str) -> list[FilterMatch]:
"""Quick scan for secrets only."""
filter_instance = ContentFilter(
enable_pii_filter=False,
enable_injection_filter=False,
)
return await filter_instance.scan(
content,
categories=[ContentCategory.SECRETS, ContentCategory.CREDENTIALS],
)

View File

@@ -0,0 +1,15 @@
"""
Cost Control Module
Budget management and cost tracking.
"""
from .controller import (
BudgetTracker,
CostController,
)
__all__ = [
"BudgetTracker",
"CostController",
]

View File

@@ -0,0 +1,479 @@
"""
Cost Controller
Budget management and cost tracking for agent operations.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any
from ..config import get_safety_config
from ..exceptions import BudgetExceededError
from ..models import (
ActionRequest,
BudgetScope,
BudgetStatus,
)
logger = logging.getLogger(__name__)
class BudgetTracker:
"""Tracks usage against a budget limit."""
def __init__(
self,
scope: BudgetScope,
scope_id: str,
tokens_limit: int,
cost_limit_usd: float,
reset_interval: timedelta | None = None,
warning_threshold: float = 0.8,
) -> None:
self.scope = scope
self.scope_id = scope_id
self.tokens_limit = tokens_limit
self.cost_limit_usd = cost_limit_usd
self.warning_threshold = warning_threshold
self._reset_interval = reset_interval
self._tokens_used = 0
self._cost_used_usd = 0.0
self._created_at = datetime.utcnow()
self._last_reset = datetime.utcnow()
self._lock = asyncio.Lock()
async def add_usage(self, tokens: int, cost_usd: float) -> None:
"""Add usage to the tracker."""
async with self._lock:
self._check_reset()
self._tokens_used += tokens
self._cost_used_usd += cost_usd
async def get_status(self) -> BudgetStatus:
"""Get current budget status."""
async with self._lock:
self._check_reset()
tokens_remaining = max(0, self.tokens_limit - self._tokens_used)
cost_remaining = max(0, self.cost_limit_usd - self._cost_used_usd)
token_usage_ratio = (
self._tokens_used / self.tokens_limit if self.tokens_limit > 0 else 0
)
cost_usage_ratio = (
self._cost_used_usd / self.cost_limit_usd
if self.cost_limit_usd > 0
else 0
)
is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
is_exceeded = (
self._tokens_used >= self.tokens_limit
or self._cost_used_usd >= self.cost_limit_usd
)
reset_at = None
if self._reset_interval:
reset_at = self._last_reset + self._reset_interval
return BudgetStatus(
scope=self.scope,
scope_id=self.scope_id,
tokens_used=self._tokens_used,
tokens_limit=self.tokens_limit,
cost_used_usd=self._cost_used_usd,
cost_limit_usd=self.cost_limit_usd,
tokens_remaining=tokens_remaining,
cost_remaining_usd=cost_remaining,
warning_threshold=self.warning_threshold,
is_warning=is_warning,
is_exceeded=is_exceeded,
reset_at=reset_at,
)
async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
"""Check if there's enough budget for an operation."""
async with self._lock:
self._check_reset()
would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit
would_exceed_cost = (
self._cost_used_usd + estimated_cost_usd
) > self.cost_limit_usd
return not (would_exceed_tokens or would_exceed_cost)
def _check_reset(self) -> None:
"""Check if budget should reset."""
if self._reset_interval is None:
return
now = datetime.utcnow()
if now >= self._last_reset + self._reset_interval:
logger.info(
"Resetting budget for %s:%s",
self.scope.value,
self.scope_id,
)
self._tokens_used = 0
self._cost_used_usd = 0.0
self._last_reset = now
async def reset(self) -> None:
"""Manually reset the budget."""
async with self._lock:
self._tokens_used = 0
self._cost_used_usd = 0.0
self._last_reset = datetime.utcnow()
class CostController:
"""
Controls costs and budgets for agent operations.
Features:
- Per-agent, per-project, per-session budgets
- Real-time cost tracking
- Budget alerts at configurable thresholds
- Cost prediction for planned actions
- Budget rollover policies
"""
def __init__(
self,
default_session_tokens: int | None = None,
default_session_cost_usd: float | None = None,
default_daily_tokens: int | None = None,
default_daily_cost_usd: float | None = None,
) -> None:
"""
Initialize the CostController.
Args:
default_session_tokens: Default token budget per session
default_session_cost_usd: Default USD budget per session
default_daily_tokens: Default token budget per day
default_daily_cost_usd: Default USD budget per day
"""
config = get_safety_config()
self._default_session_tokens = (
default_session_tokens or config.default_session_token_budget
)
self._default_session_cost = (
default_session_cost_usd or config.default_session_cost_limit
)
self._default_daily_tokens = (
default_daily_tokens or config.default_daily_token_budget
)
self._default_daily_cost = (
default_daily_cost_usd or config.default_daily_cost_limit
)
self._trackers: dict[str, BudgetTracker] = {}
self._lock = asyncio.Lock()
# Alert handlers
self._alert_handlers: list[Any] = []
async def get_or_create_tracker(
self,
scope: BudgetScope,
scope_id: str,
) -> BudgetTracker:
"""Get or create a budget tracker."""
key = f"{scope.value}:{scope_id}"
async with self._lock:
if key not in self._trackers:
if scope == BudgetScope.SESSION:
tracker = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=self._default_session_tokens,
cost_limit_usd=self._default_session_cost,
)
elif scope == BudgetScope.DAILY:
tracker = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=self._default_daily_tokens,
cost_limit_usd=self._default_daily_cost,
reset_interval=timedelta(days=1),
)
else:
# Default
tracker = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=self._default_session_tokens,
cost_limit_usd=self._default_session_cost,
)
self._trackers[key] = tracker
return self._trackers[key]
async def check_budget(
self,
agent_id: str,
session_id: str | None,
estimated_tokens: int,
estimated_cost_usd: float,
) -> bool:
"""
Check if there's enough budget for an operation.
Args:
agent_id: ID of the agent
session_id: Optional session ID
estimated_tokens: Estimated token usage
estimated_cost_usd: Estimated USD cost
Returns:
True if budget is available
"""
# Check session budget
if session_id:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
return False
# Check agent daily budget
agent_tracker = await self.get_or_create_tracker(
BudgetScope.DAILY, agent_id
)
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
return False
return True
async def check_action(self, action: ActionRequest) -> bool:
"""
Check if an action is within budget.
Args:
action: The action to check
Returns:
True if within budget
"""
return await self.check_budget(
agent_id=action.metadata.agent_id,
session_id=action.metadata.session_id,
estimated_tokens=action.estimated_cost_tokens,
estimated_cost_usd=action.estimated_cost_usd,
)
async def require_budget(
self,
agent_id: str,
session_id: str | None,
estimated_tokens: int,
estimated_cost_usd: float,
) -> None:
"""
Require budget or raise exception.
Args:
agent_id: ID of the agent
session_id: Optional session ID
estimated_tokens: Estimated token usage
estimated_cost_usd: Estimated USD cost
Raises:
BudgetExceededError: If budget is exceeded
"""
if not await self.check_budget(
agent_id, session_id, estimated_tokens, estimated_cost_usd
):
# Determine which budget was exceeded
if session_id:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
session_status = await session_tracker.get_status()
if session_status.is_exceeded:
raise BudgetExceededError(
"Session budget exceeded",
budget_type="session",
current_usage=session_status.tokens_used,
budget_limit=session_status.tokens_limit,
agent_id=agent_id,
)
agent_tracker = await self.get_or_create_tracker(
BudgetScope.DAILY, agent_id
)
agent_status = await agent_tracker.get_status()
raise BudgetExceededError(
"Daily budget exceeded",
budget_type="daily",
current_usage=agent_status.tokens_used,
budget_limit=agent_status.tokens_limit,
agent_id=agent_id,
)
async def record_usage(
self,
agent_id: str,
session_id: str | None,
tokens: int,
cost_usd: float,
) -> None:
"""
Record actual usage.
Args:
agent_id: ID of the agent
session_id: Optional session ID
tokens: Actual token usage
cost_usd: Actual USD cost
"""
# Update session budget
if session_id:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
await session_tracker.add_usage(tokens, cost_usd)
# Check for warning
status = await session_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
status,
)
# Update agent daily budget
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
await agent_tracker.add_usage(tokens, cost_usd)
# Check for warning
status = await agent_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
status,
)
async def get_status(
self,
scope: BudgetScope,
scope_id: str,
) -> BudgetStatus | None:
"""
Get budget status.
Args:
scope: Budget scope
scope_id: ID within scope
Returns:
Budget status or None if not tracked
"""
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
return await tracker.get_status()
return None
async def get_all_statuses(self) -> list[BudgetStatus]:
"""Get status of all tracked budgets."""
statuses = []
async with self._lock:
trackers = list(self._trackers.values())
for tracker in trackers:
statuses.append(await tracker.get_status())
return statuses
async def set_budget(
self,
scope: BudgetScope,
scope_id: str,
tokens_limit: int,
cost_limit_usd: float,
) -> None:
"""
Set a custom budget limit.
Args:
scope: Budget scope
scope_id: ID within scope
tokens_limit: Token limit
cost_limit_usd: USD limit
"""
key = f"{scope.value}:{scope_id}"
reset_interval = None
if scope == BudgetScope.DAILY:
reset_interval = timedelta(days=1)
elif scope == BudgetScope.WEEKLY:
reset_interval = timedelta(weeks=1)
elif scope == BudgetScope.MONTHLY:
reset_interval = timedelta(days=30)
async with self._lock:
self._trackers[key] = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=tokens_limit,
cost_limit_usd=cost_limit_usd,
reset_interval=reset_interval,
)
async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
"""
Reset a budget tracker.
Args:
scope: Budget scope
scope_id: ID within scope
Returns:
True if tracker was found and reset
"""
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
await tracker.reset()
return True
return False
def add_alert_handler(self, handler: Any) -> None:
"""Add an alert handler."""
self._alert_handlers.append(handler)
def remove_alert_handler(self, handler: Any) -> None:
"""Remove an alert handler."""
if handler in self._alert_handlers:
self._alert_handlers.remove(handler)
async def _send_alert(
self,
alert_type: str,
message: str,
status: BudgetStatus,
) -> None:
"""Send alert to all handlers."""
for handler in self._alert_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(alert_type, message, status)
else:
handler(alert_type, message, status)
except Exception as e:
logger.error("Error in alert handler: %s", e)

View File

@@ -0,0 +1,23 @@
"""Emergency controls for agent safety."""
from .controls import (
EmergencyControls,
EmergencyEvent,
EmergencyReason,
EmergencyState,
EmergencyTrigger,
check_emergency_allowed,
emergency_stop_global,
get_emergency_controls,
)
__all__ = [
"EmergencyControls",
"EmergencyEvent",
"EmergencyReason",
"EmergencyState",
"EmergencyTrigger",
"check_emergency_allowed",
"emergency_stop_global",
"get_emergency_controls",
]

View File

@@ -0,0 +1,594 @@
"""
Emergency Controls
Emergency stop and pause functionality for agent safety.
"""
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from ..exceptions import EmergencyStopError
logger = logging.getLogger(__name__)
class EmergencyState(str, Enum):
"""Emergency control states."""
NORMAL = "normal"
PAUSED = "paused"
STOPPED = "stopped"
class EmergencyReason(str, Enum):
"""Reasons for emergency actions."""
MANUAL = "manual"
SAFETY_VIOLATION = "safety_violation"
BUDGET_EXCEEDED = "budget_exceeded"
LOOP_DETECTED = "loop_detected"
RATE_LIMIT = "rate_limit"
CONTENT_VIOLATION = "content_violation"
SYSTEM_ERROR = "system_error"
EXTERNAL_TRIGGER = "external_trigger"
@dataclass
class EmergencyEvent:
"""Record of an emergency action."""
id: str
state: EmergencyState
reason: EmergencyReason
triggered_by: str
message: str
scope: str # "global", "project:<id>", "agent:<id>"
timestamp: datetime = field(default_factory=datetime.utcnow)
metadata: dict[str, Any] = field(default_factory=dict)
resolved_at: datetime | None = None
resolved_by: str | None = None
class EmergencyControls:
"""
Emergency stop and pause controls for agent safety.
Features:
- Global emergency stop
- Per-project/agent emergency controls
- Graceful pause with state preservation
- Automatic triggers from safety violations
- Manual override capabilities
- Event history and audit trail
"""
def __init__(
self,
notification_handlers: list[Callable[..., Any]] | None = None,
) -> None:
"""
Initialize EmergencyControls.
Args:
notification_handlers: Handlers to call on emergency events
"""
self._global_state = EmergencyState.NORMAL
self._scoped_states: dict[str, EmergencyState] = {}
self._events: list[EmergencyEvent] = []
self._notification_handlers = notification_handlers or []
self._lock = asyncio.Lock()
self._event_id_counter = 0
# Callbacks for state changes
self._on_stop_callbacks: list[Callable[..., Any]] = []
self._on_pause_callbacks: list[Callable[..., Any]] = []
self._on_resume_callbacks: list[Callable[..., Any]] = []
def _generate_event_id(self) -> str:
"""Generate a unique event ID."""
self._event_id_counter += 1
return f"emerg-{self._event_id_counter:06d}"
async def emergency_stop(
self,
reason: EmergencyReason,
triggered_by: str,
message: str,
scope: str = "global",
metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
"""
Trigger emergency stop.
Args:
reason: Reason for the stop
triggered_by: Who/what triggered the stop
message: Human-readable message
scope: Scope of the stop (global, project:<id>, agent:<id>)
metadata: Additional context
Returns:
The emergency event record
"""
async with self._lock:
event = EmergencyEvent(
id=self._generate_event_id(),
state=EmergencyState.STOPPED,
reason=reason,
triggered_by=triggered_by,
message=message,
scope=scope,
metadata=metadata or {},
)
if scope == "global":
self._global_state = EmergencyState.STOPPED
else:
self._scoped_states[scope] = EmergencyState.STOPPED
self._events.append(event)
logger.critical(
"EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
scope,
reason.value,
triggered_by,
message,
)
# Execute callbacks
await self._execute_callbacks(self._on_stop_callbacks, event)
await self._notify_handlers("emergency_stop", event)
return event
async def pause(
self,
reason: EmergencyReason,
triggered_by: str,
message: str,
scope: str = "global",
metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
"""
Pause operations (can be resumed).
Args:
reason: Reason for the pause
triggered_by: Who/what triggered the pause
message: Human-readable message
scope: Scope of the pause
metadata: Additional context
Returns:
The emergency event record
"""
async with self._lock:
event = EmergencyEvent(
id=self._generate_event_id(),
state=EmergencyState.PAUSED,
reason=reason,
triggered_by=triggered_by,
message=message,
scope=scope,
metadata=metadata or {},
)
if scope == "global":
self._global_state = EmergencyState.PAUSED
else:
self._scoped_states[scope] = EmergencyState.PAUSED
self._events.append(event)
logger.warning(
"PAUSE: scope=%s, reason=%s, by=%s - %s",
scope,
reason.value,
triggered_by,
message,
)
await self._execute_callbacks(self._on_pause_callbacks, event)
await self._notify_handlers("pause", event)
return event
async def resume(
self,
scope: str = "global",
resumed_by: str = "system",
message: str | None = None,
) -> bool:
"""
Resume operations from paused state.
Args:
scope: Scope to resume
resumed_by: Who/what is resuming
message: Optional message
Returns:
True if resumed, False if not in paused state
"""
async with self._lock:
current_state = self._get_state(scope)
if current_state == EmergencyState.STOPPED:
logger.warning(
"Cannot resume from STOPPED state: %s (requires reset)",
scope,
)
return False
if current_state == EmergencyState.NORMAL:
return True # Already normal
# Find the pause event and mark as resolved
for event in reversed(self._events):
if event.scope == scope and event.state == EmergencyState.PAUSED:
if event.resolved_at is None:
event.resolved_at = datetime.utcnow()
event.resolved_by = resumed_by
break
if scope == "global":
self._global_state = EmergencyState.NORMAL
else:
self._scoped_states[scope] = EmergencyState.NORMAL
logger.info(
"RESUMED: scope=%s, by=%s%s",
scope,
resumed_by,
f" - {message}" if message else "",
)
await self._execute_callbacks(
self._on_resume_callbacks,
{"scope": scope, "resumed_by": resumed_by},
)
await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})
return True
async def reset(
self,
scope: str = "global",
reset_by: str = "admin",
message: str | None = None,
) -> bool:
"""
Reset from stopped state (requires explicit action).
Args:
scope: Scope to reset
reset_by: Who is resetting (should be admin)
message: Optional message
Returns:
True if reset successful
"""
async with self._lock:
current_state = self._get_state(scope)
if current_state == EmergencyState.NORMAL:
return True
# Find the stop event and mark as resolved
for event in reversed(self._events):
if event.scope == scope and event.state == EmergencyState.STOPPED:
if event.resolved_at is None:
event.resolved_at = datetime.utcnow()
event.resolved_by = reset_by
break
if scope == "global":
self._global_state = EmergencyState.NORMAL
else:
self._scoped_states[scope] = EmergencyState.NORMAL
logger.warning(
"EMERGENCY RESET: scope=%s, by=%s%s",
scope,
reset_by,
f" - {message}" if message else "",
)
await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})
return True
async def check_allowed(
self,
scope: str | None = None,
raise_if_blocked: bool = True,
) -> bool:
"""
Check if operations are allowed.
Args:
scope: Specific scope to check (also checks global)
raise_if_blocked: Raise exception if blocked
Returns:
True if operations are allowed
Raises:
EmergencyStopError: If blocked and raise_if_blocked=True
"""
async with self._lock:
# Always check global state
if self._global_state != EmergencyState.NORMAL:
if raise_if_blocked:
raise EmergencyStopError(
f"Global emergency state: {self._global_state.value}",
stop_type=self._get_last_reason("global") or "emergency",
triggered_by=self._get_last_triggered_by("global"),
)
return False
# Check specific scope
if scope and scope in self._scoped_states:
state = self._scoped_states[scope]
if state != EmergencyState.NORMAL:
if raise_if_blocked:
raise EmergencyStopError(
f"Emergency state for {scope}: {state.value}",
stop_type=self._get_last_reason(scope) or "emergency",
triggered_by=self._get_last_triggered_by(scope),
details={"scope": scope},
)
return False
return True
def _get_state(self, scope: str) -> EmergencyState:
"""Get state for a scope."""
if scope == "global":
return self._global_state
return self._scoped_states.get(scope, EmergencyState.NORMAL)
def _get_last_reason(self, scope: str) -> str:
"""Get reason from last event for scope."""
for event in reversed(self._events):
if event.scope == scope and event.resolved_at is None:
return event.reason.value
return "unknown"
def _get_last_triggered_by(self, scope: str) -> str:
"""Get triggered_by from last event for scope."""
for event in reversed(self._events):
if event.scope == scope and event.resolved_at is None:
return event.triggered_by
return "unknown"
async def get_state(self, scope: str = "global") -> EmergencyState:
"""Get current state for a scope."""
async with self._lock:
return self._get_state(scope)
async def get_all_states(self) -> dict[str, EmergencyState]:
"""Get all current states."""
async with self._lock:
states = {"global": self._global_state}
states.update(self._scoped_states)
return states
async def get_active_events(self) -> list[EmergencyEvent]:
"""Get all unresolved emergency events."""
async with self._lock:
return [e for e in self._events if e.resolved_at is None]
async def get_event_history(
self,
scope: str | None = None,
limit: int = 100,
) -> list[EmergencyEvent]:
"""Get emergency event history."""
async with self._lock:
events = list(self._events)
if scope:
events = [e for e in events if e.scope == scope]
return events[-limit:]
def on_stop(self, callback: Callable[..., Any]) -> None:
"""Register callback for stop events."""
self._on_stop_callbacks.append(callback)
def on_pause(self, callback: Callable[..., Any]) -> None:
"""Register callback for pause events."""
self._on_pause_callbacks.append(callback)
def on_resume(self, callback: Callable[..., Any]) -> None:
"""Register callback for resume events."""
self._on_resume_callbacks.append(callback)
def add_notification_handler(self, handler: Callable[..., Any]) -> None:
"""Add a notification handler."""
self._notification_handlers.append(handler)
async def _execute_callbacks(
self,
callbacks: list[Callable[..., Any]],
data: Any,
) -> None:
"""Execute callbacks safely."""
for callback in callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(data)
else:
callback(data)
except Exception as e:
logger.error("Error in callback: %s", e)
async def _notify_handlers(self, event_type: str, data: Any) -> None:
"""Notify all handlers of an event."""
for handler in self._notification_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event_type, data)
else:
handler(event_type, data)
except Exception as e:
logger.error("Error in notification handler: %s", e)
class EmergencyTrigger:
"""
Automatic emergency triggers based on conditions.
"""
def __init__(self, controls: EmergencyControls) -> None:
"""
Initialize EmergencyTrigger.
Args:
controls: EmergencyControls instance to trigger
"""
self._controls = controls
async def trigger_on_safety_violation(
self,
violation_type: str,
details: dict[str, Any],
scope: str = "global",
) -> EmergencyEvent:
"""
Trigger emergency from safety violation.
Args:
violation_type: Type of violation
details: Violation details
scope: Scope for the emergency
Returns:
Emergency event
"""
return await self._controls.emergency_stop(
reason=EmergencyReason.SAFETY_VIOLATION,
triggered_by="safety_system",
message=f"Safety violation: {violation_type}",
scope=scope,
metadata={"violation_type": violation_type, **details},
)
async def trigger_on_budget_exceeded(
self,
budget_type: str,
current: float,
limit: float,
scope: str = "global",
) -> EmergencyEvent:
"""
Trigger emergency from budget exceeded.
Args:
budget_type: Type of budget
current: Current usage
limit: Budget limit
scope: Scope for the emergency
Returns:
Emergency event
"""
return await self._controls.pause(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget_controller",
message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
scope=scope,
metadata={"budget_type": budget_type, "current": current, "limit": limit},
)
async def trigger_on_loop_detected(
self,
loop_type: str,
agent_id: str,
details: dict[str, Any],
) -> EmergencyEvent:
"""
Trigger emergency from loop detection.
Args:
loop_type: Type of loop
agent_id: Agent that's looping
details: Loop details
Returns:
Emergency event
"""
return await self._controls.pause(
reason=EmergencyReason.LOOP_DETECTED,
triggered_by="loop_detector",
message=f"Loop detected: {loop_type} in agent {agent_id}",
scope=f"agent:{agent_id}",
metadata={"loop_type": loop_type, "agent_id": agent_id, **details},
)
async def trigger_on_content_violation(
self,
category: str,
pattern: str,
scope: str = "global",
) -> EmergencyEvent:
"""
Trigger emergency from content violation.
Args:
category: Content category
pattern: Pattern that matched
scope: Scope for the emergency
Returns:
Emergency event
"""
return await self._controls.emergency_stop(
reason=EmergencyReason.CONTENT_VIOLATION,
triggered_by="content_filter",
message=f"Content violation: {category} ({pattern})",
scope=scope,
metadata={"category": category, "pattern": pattern},
)
# Singleton instance
_emergency_controls: EmergencyControls | None = None
_lock = asyncio.Lock()
async def get_emergency_controls() -> EmergencyControls:
"""Get the singleton EmergencyControls instance."""
global _emergency_controls
async with _lock:
if _emergency_controls is None:
_emergency_controls = EmergencyControls()
return _emergency_controls
async def emergency_stop_global(
reason: str,
triggered_by: str = "system",
) -> EmergencyEvent:
"""Quick global emergency stop."""
controls = await get_emergency_controls()
return await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by=triggered_by,
message=reason,
scope="global",
)
async def check_emergency_allowed(scope: str | None = None) -> bool:
"""Quick check if operations are allowed."""
controls = await get_emergency_controls()
return await controls.check_allowed(scope=scope, raise_if_blocked=False)

View File

@@ -0,0 +1,277 @@
"""
Safety Framework Exceptions
Custom exception classes for the safety and guardrails framework.
"""
from typing import Any
class SafetyError(Exception):
"""Base exception for all safety-related errors."""
def __init__(
self,
message: str,
*,
action_id: str | None = None,
agent_id: str | None = None,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message)
self.message = message
self.action_id = action_id
self.agent_id = agent_id
self.details = details or {}
class PermissionDeniedError(SafetyError):
"""Raised when an action is not permitted."""
def __init__(
self,
message: str = "Permission denied",
*,
action_type: str | None = None,
resource: str | None = None,
required_permission: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.action_type = action_type
self.resource = resource
self.required_permission = required_permission
class BudgetExceededError(SafetyError):
"""Raised when cost budget is exceeded."""
def __init__(
self,
message: str = "Budget exceeded",
*,
budget_type: str = "session",
current_usage: float = 0.0,
budget_limit: float = 0.0,
unit: str = "tokens",
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.budget_type = budget_type
self.current_usage = current_usage
self.budget_limit = budget_limit
self.unit = unit
class RateLimitExceededError(SafetyError):
"""Raised when rate limit is exceeded."""
def __init__(
self,
message: str = "Rate limit exceeded",
*,
limit_type: str = "actions",
limit_value: int = 0,
window_seconds: int = 60,
retry_after_seconds: float = 0.0,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.limit_type = limit_type
self.limit_value = limit_value
self.window_seconds = window_seconds
self.retry_after_seconds = retry_after_seconds
class LoopDetectedError(SafetyError):
"""Raised when an action loop is detected."""
def __init__(
self,
message: str = "Loop detected",
*,
loop_type: str = "exact",
repetition_count: int = 0,
action_pattern: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.loop_type = loop_type
self.repetition_count = repetition_count
self.action_pattern = action_pattern or []
class ApprovalRequiredError(SafetyError):
"""Raised when human approval is required."""
def __init__(
self,
message: str = "Human approval required",
*,
approval_id: str | None = None,
reason: str | None = None,
timeout_seconds: int = 300,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.approval_id = approval_id
self.reason = reason
self.timeout_seconds = timeout_seconds
class ApprovalDeniedError(SafetyError):
"""Raised when human explicitly denies an action."""
def __init__(
self,
message: str = "Approval denied by human",
*,
approval_id: str | None = None,
denied_by: str | None = None,
denial_reason: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.approval_id = approval_id
self.denied_by = denied_by
self.denial_reason = denial_reason
class ApprovalTimeoutError(SafetyError):
"""Raised when approval request times out."""
def __init__(
self,
message: str = "Approval request timed out",
*,
approval_id: str | None = None,
timeout_seconds: int = 300,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.approval_id = approval_id
self.timeout_seconds = timeout_seconds
class RollbackError(SafetyError):
"""Raised when rollback fails."""
def __init__(
self,
message: str = "Rollback failed",
*,
checkpoint_id: str | None = None,
failed_actions: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.checkpoint_id = checkpoint_id
self.failed_actions = failed_actions or []
class CheckpointError(SafetyError):
"""Raised when checkpoint creation fails."""
def __init__(
self,
message: str = "Checkpoint creation failed",
*,
checkpoint_type: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.checkpoint_type = checkpoint_type
class ValidationError(SafetyError):
"""Raised when action validation fails."""
def __init__(
self,
message: str = "Validation failed",
*,
validation_rules: list[str] | None = None,
failed_rules: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.validation_rules = validation_rules or []
self.failed_rules = failed_rules or []
class ContentFilterError(SafetyError):
"""Raised when content filtering detects prohibited content."""
def __init__(
self,
message: str = "Prohibited content detected",
*,
filter_type: str | None = None,
detected_patterns: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.filter_type = filter_type
self.detected_patterns = detected_patterns or []
class SandboxError(SafetyError):
"""Raised when sandbox execution fails."""
def __init__(
self,
message: str = "Sandbox execution failed",
*,
exit_code: int | None = None,
stderr: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.exit_code = exit_code
self.stderr = stderr
class SandboxTimeoutError(SandboxError):
"""Raised when sandbox execution times out."""
def __init__(
self,
message: str = "Sandbox execution timed out",
*,
timeout_seconds: int = 300,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.timeout_seconds = timeout_seconds
class EmergencyStopError(SafetyError):
"""Raised when emergency stop is triggered."""
def __init__(
self,
message: str = "Emergency stop triggered",
*,
stop_type: str = "kill",
triggered_by: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.stop_type = stop_type
self.triggered_by = triggered_by
class PolicyViolationError(SafetyError):
"""Raised when an action violates a safety policy."""
def __init__(
self,
message: str = "Policy violation",
*,
policy_name: str | None = None,
violated_rules: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.policy_name = policy_name
self.violated_rules = violated_rules or []

View File

@@ -0,0 +1,614 @@
"""
Safety Guardian
Main facade for the safety framework. Orchestrates all safety checks
before, during, and after action execution.
"""
import asyncio
import logging
from typing import Any
from .audit import AuditLogger, get_audit_logger
from .config import (
SafetyConfig,
get_policy_for_autonomy_level,
get_safety_config,
)
from .exceptions import (
SafetyError,
)
from .models import (
ActionRequest,
ActionResult,
AuditEventType,
GuardianResult,
SafetyDecision,
SafetyPolicy,
)
logger = logging.getLogger(__name__)
class SafetyGuardian:
"""
Central orchestrator for all safety checks.
The SafetyGuardian is the main entry point for validating agent actions.
It coordinates multiple safety subsystems:
- Permission checking
- Cost/budget control
- Rate limiting
- Loop detection
- Human-in-the-loop approval
- Rollback/checkpoint management
- Content filtering
- Sandbox execution
Usage:
guardian = SafetyGuardian()
await guardian.initialize()
# Before executing an action
result = await guardian.validate(action_request)
if not result.allowed:
# Handle denial
# After action execution
await guardian.record_execution(action_request, action_result)
"""
def __init__(
self,
config: SafetyConfig | None = None,
audit_logger: AuditLogger | None = None,
) -> None:
"""
Initialize the SafetyGuardian.
Args:
config: Optional safety configuration. If None, loads from environment.
audit_logger: Optional audit logger. If None, uses global instance.
"""
self._config = config or get_safety_config()
self._audit_logger = audit_logger
self._initialized = False
self._lock = asyncio.Lock()
# Subsystem references (will be initialized lazily)
self._permission_manager: Any = None
self._cost_controller: Any = None
self._rate_limiter: Any = None
self._loop_detector: Any = None
self._hitl_manager: Any = None
self._rollback_manager: Any = None
self._content_filter: Any = None
self._sandbox_executor: Any = None
self._emergency_controls: Any = None
# Policy cache
self._policies: dict[str, SafetyPolicy] = {}
self._default_policy: SafetyPolicy | None = None
@property
def is_initialized(self) -> bool:
"""Check if the guardian is initialized."""
return self._initialized
async def initialize(self) -> None:
"""Initialize the SafetyGuardian and all subsystems."""
async with self._lock:
if self._initialized:
logger.warning("SafetyGuardian already initialized")
return
logger.info("Initializing SafetyGuardian")
# Get audit logger
if self._audit_logger is None:
self._audit_logger = await get_audit_logger()
# Initialize subsystems lazily as they're implemented
# For now, we'll import and initialize them when available
self._initialized = True
logger.info("SafetyGuardian initialized")
async def shutdown(self) -> None:
"""Shutdown the SafetyGuardian and all subsystems."""
async with self._lock:
if not self._initialized:
return
logger.info("Shutting down SafetyGuardian")
# Shutdown subsystems
# (Will be implemented as subsystems are added)
self._initialized = False
logger.info("SafetyGuardian shutdown complete")
async def validate(
self,
action: ActionRequest,
policy: SafetyPolicy | None = None,
) -> GuardianResult:
"""
Validate an action before execution.
Runs all safety checks in order:
1. Permission check
2. Cost/budget check
3. Rate limit check
4. Loop detection
5. HITL check (if required)
6. Checkpoint creation (if destructive)
Args:
action: The action to validate
policy: Optional policy override. If None, uses autonomy-level policy.
Returns:
GuardianResult with decision and details
"""
if not self._initialized:
await self.initialize()
if not self._config.enabled:
# Safety disabled - allow everything (NOT RECOMMENDED)
logger.warning("Safety framework disabled - allowing action %s", action.id)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Safety framework disabled"],
)
# Get policy for this action
effective_policy = policy or self._get_policy(action)
reasons: list[str] = []
audit_events = []
try:
# Log action request
if self._audit_logger:
event = await self._audit_logger.log(
AuditEventType.ACTION_REQUESTED,
agent_id=action.metadata.agent_id,
action_id=action.id,
project_id=action.metadata.project_id,
session_id=action.metadata.session_id,
details={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"resource": action.resource,
},
correlation_id=action.metadata.correlation_id,
)
audit_events.append(event)
# 1. Permission check
permission_result = await self._check_permissions(action, effective_policy)
if permission_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action, permission_result.reasons, audit_events
)
# 2. Cost/budget check
budget_result = await self._check_budget(action, effective_policy)
if budget_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action, budget_result.reasons, audit_events
)
# 3. Rate limit check
rate_result = await self._check_rate_limit(action, effective_policy)
if rate_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action,
rate_result.reasons,
audit_events,
retry_after=rate_result.retry_after_seconds,
)
if rate_result.decision == SafetyDecision.DELAY:
# Return delay decision
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DELAY,
reasons=rate_result.reasons,
retry_after_seconds=rate_result.retry_after_seconds,
audit_events=audit_events,
)
# 4. Loop detection
loop_result = await self._check_loops(action, effective_policy)
if loop_result.decision == SafetyDecision.DENY:
return await self._create_denial_result(
action, loop_result.reasons, audit_events
)
# 5. HITL check
hitl_result = await self._check_hitl(action, effective_policy)
if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=hitl_result.reasons,
approval_id=hitl_result.approval_id,
audit_events=audit_events,
)
# 6. Create checkpoint if destructive
checkpoint_id = None
if action.is_destructive and self._config.auto_checkpoint_destructive:
checkpoint_id = await self._create_checkpoint(action)
# All checks passed
reasons.append("All safety checks passed")
if self._audit_logger:
event = await self._audit_logger.log_action_request(
action, SafetyDecision.ALLOW, reasons
)
audit_events.append(event)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=reasons,
checkpoint_id=checkpoint_id,
audit_events=audit_events,
)
except SafetyError as e:
# Known safety error
return await self._create_denial_result(
action, [str(e)], audit_events
)
except Exception as e:
# Unknown error - fail closed in strict mode
logger.error("Unexpected error in safety validation: %s", e)
if self._config.strict_mode:
return await self._create_denial_result(
action,
[f"Safety validation error: {e}"],
audit_events,
)
else:
# Non-strict mode - allow with warning
logger.warning("Non-strict mode: allowing action despite error")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Allowed despite validation error (non-strict mode)"],
audit_events=audit_events,
)
async def record_execution(
self,
action: ActionRequest,
result: ActionResult,
) -> None:
"""
Record action execution result for auditing and tracking.
Args:
action: The executed action
result: The execution result
"""
if self._audit_logger:
await self._audit_logger.log_action_executed(
action,
success=result.success,
execution_time_ms=result.execution_time_ms,
error=result.error,
)
# Update cost tracking
if self._cost_controller:
# Track actual cost
pass
# Update loop detection history
if self._loop_detector:
# Add to action history
pass
async def rollback(self, checkpoint_id: str) -> bool:
"""
Rollback to a checkpoint.
Args:
checkpoint_id: ID of the checkpoint to rollback to
Returns:
True if rollback succeeded
"""
if self._rollback_manager is None:
logger.warning("Rollback manager not available")
return False
# Delegate to rollback manager
return await self._rollback_manager.rollback(checkpoint_id)
async def emergency_stop(
self,
stop_type: str = "kill",
reason: str = "Manual emergency stop",
triggered_by: str = "system",
) -> None:
"""
Trigger emergency stop.
Args:
stop_type: Type of stop (kill, pause, lockdown)
reason: Reason for the stop
triggered_by: Who triggered the stop
"""
logger.critical(
"Emergency stop triggered: type=%s, reason=%s, by=%s",
stop_type,
reason,
triggered_by,
)
if self._audit_logger:
await self._audit_logger.log_emergency_stop(
stop_type=stop_type,
triggered_by=triggered_by,
reason=reason,
)
if self._emergency_controls:
await self._emergency_controls.execute_stop(stop_type)
def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
"""Get the effective policy for an action."""
# Check cached policies
autonomy_level = action.metadata.autonomy_level
if autonomy_level.value not in self._policies:
self._policies[autonomy_level.value] = get_policy_for_autonomy_level(
autonomy_level
)
return self._policies[autonomy_level.value]
async def _check_permissions(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is permitted."""
reasons: list[str] = []
# Check denied tools
if action.tool_name:
for pattern in policy.denied_tools:
if self._matches_pattern(action.tool_name, pattern):
reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
)
# Check allowed tools (if not "*")
if action.tool_name and "*" not in policy.allowed_tools:
allowed = False
for pattern in policy.allowed_tools:
if self._matches_pattern(action.tool_name, pattern):
allowed = True
break
if not allowed:
reasons.append(f"Tool '{action.tool_name}' not in allowed list")
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
)
# Check file patterns
if action.resource:
for pattern in policy.denied_file_patterns:
if self._matches_pattern(action.resource, pattern):
reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Permission check passed"],
)
async def _check_budget(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within budget."""
# TODO: Implement with CostController
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed (not fully implemented)"],
)
async def _check_rate_limit(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within rate limits."""
# TODO: Implement with RateLimiter
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed (not fully implemented)"],
)
async def _check_loops(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check for action loops."""
# TODO: Implement with LoopDetector
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed (not fully implemented)"],
)
async def _check_hitl(
self,
action: ActionRequest,
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if human approval is required."""
if not self._config.hitl_enabled:
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["HITL disabled"],
)
# Check if action requires approval
requires_approval = False
for pattern in policy.require_approval_for:
if pattern == "*":
requires_approval = True
break
if action.tool_name and self._matches_pattern(action.tool_name, pattern):
requires_approval = True
break
if action.action_type.value and self._matches_pattern(
action.action_type.value, pattern
):
requires_approval = True
break
if requires_approval:
# TODO: Create approval request with HITLManager
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=["Action requires human approval"],
approval_id=None, # Will be set by HITLManager
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["No approval required"],
)
async def _create_checkpoint(self, action: ActionRequest) -> str | None:
"""Create a checkpoint before destructive action."""
if self._rollback_manager is None:
logger.warning("Rollback manager not available - skipping checkpoint")
return None
# TODO: Implement with RollbackManager
return None
async def _create_denial_result(
self,
action: ActionRequest,
reasons: list[str],
audit_events: list[Any],
retry_after: float | None = None,
) -> GuardianResult:
"""Create a denial result with audit logging."""
if self._audit_logger:
event = await self._audit_logger.log_action_request(
action, SafetyDecision.DENY, reasons
)
audit_events.append(event)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=reasons,
retry_after_seconds=retry_after,
audit_events=audit_events,
)
def _matches_pattern(self, value: str, pattern: str) -> bool:
"""Check if value matches a pattern (supports * wildcard)."""
if pattern == "*":
return True
if "*" not in pattern:
return value == pattern
# Simple wildcard matching
if pattern.startswith("*") and pattern.endswith("*"):
return pattern[1:-1] in value
elif pattern.startswith("*"):
return value.endswith(pattern[1:])
elif pattern.endswith("*"):
return value.startswith(pattern[:-1])
else:
# Pattern like "foo*bar"
parts = pattern.split("*")
if len(parts) == 2:
return value.startswith(parts[0]) and value.endswith(parts[1])
return False
# Singleton instance
_guardian_instance: SafetyGuardian | None = None
_guardian_lock = asyncio.Lock()
async def get_safety_guardian() -> SafetyGuardian:
"""Get the global SafetyGuardian instance."""
global _guardian_instance
async with _guardian_lock:
if _guardian_instance is None:
_guardian_instance = SafetyGuardian()
await _guardian_instance.initialize()
return _guardian_instance
async def shutdown_safety_guardian() -> None:
"""Shutdown the global SafetyGuardian."""
global _guardian_instance
async with _guardian_lock:
if _guardian_instance is not None:
await _guardian_instance.shutdown()
_guardian_instance = None
def reset_safety_guardian() -> None:
"""Reset the SafetyGuardian (for testing)."""
global _guardian_instance
_guardian_instance = None

View File

@@ -0,0 +1,5 @@
"""Human-in-the-Loop approval workflows."""
from .manager import ApprovalQueue, HITLManager
__all__ = ["ApprovalQueue", "HITLManager"]

View File

@@ -0,0 +1,449 @@
"""
Human-in-the-Loop (HITL) Manager
Manages approval workflows for actions requiring human oversight.
"""
import asyncio
import logging
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..exceptions import (
ApprovalDeniedError,
ApprovalRequiredError,
ApprovalTimeoutError,
)
from ..models import (
ActionRequest,
ApprovalRequest,
ApprovalResponse,
ApprovalStatus,
)
logger = logging.getLogger(__name__)
class ApprovalQueue:
"""Queue for pending approval requests."""
def __init__(self) -> None:
self._pending: dict[str, ApprovalRequest] = {}
self._completed: dict[str, ApprovalResponse] = {}
self._waiters: dict[str, asyncio.Event] = {}
self._lock = asyncio.Lock()
async def add(self, request: ApprovalRequest) -> None:
"""Add an approval request to the queue."""
async with self._lock:
self._pending[request.id] = request
self._waiters[request.id] = asyncio.Event()
async def get_pending(self, request_id: str) -> ApprovalRequest | None:
"""Get a pending request by ID."""
async with self._lock:
return self._pending.get(request_id)
async def complete(self, response: ApprovalResponse) -> bool:
"""Complete an approval request."""
async with self._lock:
if response.request_id not in self._pending:
return False
del self._pending[response.request_id]
self._completed[response.request_id] = response
# Notify waiters
if response.request_id in self._waiters:
self._waiters[response.request_id].set()
return True
async def wait_for_response(
self,
request_id: str,
timeout_seconds: float,
) -> ApprovalResponse | None:
"""Wait for a response to an approval request."""
async with self._lock:
waiter = self._waiters.get(request_id)
if not waiter:
return self._completed.get(request_id)
try:
await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
except TimeoutError:
return None
async with self._lock:
return self._completed.get(request_id)
async def list_pending(self) -> list[ApprovalRequest]:
"""List all pending requests."""
async with self._lock:
return list(self._pending.values())
async def cancel(self, request_id: str) -> bool:
"""Cancel a pending request."""
async with self._lock:
if request_id not in self._pending:
return False
del self._pending[request_id]
# Create cancelled response
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.CANCELLED,
reason="Cancelled",
)
self._completed[request_id] = response
# Notify waiters
if request_id in self._waiters:
self._waiters[request_id].set()
return True
async def cleanup_expired(self) -> int:
"""Clean up expired requests."""
now = datetime.utcnow()
to_timeout: list[str] = []
async with self._lock:
for request_id, request in self._pending.items():
if request.expires_at and request.expires_at < now:
to_timeout.append(request_id)
count = 0
for request_id in to_timeout:
async with self._lock:
if request_id in self._pending:
del self._pending[request_id]
self._completed[request_id] = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.TIMEOUT,
reason="Request timed out",
)
if request_id in self._waiters:
self._waiters[request_id].set()
count += 1
return count
class HITLManager:
"""
Manages Human-in-the-Loop approval workflows.
Features:
- Approval request queue
- Configurable timeout handling (default deny)
- Approval delegation
- Batch approval for similar actions
- Approval with modifications
- Notification channels
"""
def __init__(
self,
default_timeout: int | None = None,
) -> None:
"""
Initialize the HITLManager.
Args:
default_timeout: Default timeout for approval requests in seconds
"""
config = get_safety_config()
self._default_timeout = default_timeout or config.hitl_default_timeout
self._queue = ApprovalQueue()
self._notification_handlers: list[Callable[..., Any]] = []
self._running = False
self._cleanup_task: asyncio.Task[None] | None = None
async def start(self) -> None:
"""Start the HITL manager background tasks."""
if self._running:
return
self._running = True
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
logger.info("HITL Manager started")
async def stop(self) -> None:
"""Stop the HITL manager."""
self._running = False
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("HITL Manager stopped")
async def request_approval(
self,
action: ActionRequest,
reason: str,
timeout_seconds: int | None = None,
urgency: str = "normal",
context: dict[str, Any] | None = None,
) -> ApprovalRequest:
"""
Create an approval request for an action.
Args:
action: The action requiring approval
reason: Why approval is required
timeout_seconds: Timeout for this request
urgency: Urgency level (low, normal, high, critical)
context: Additional context for the approver
Returns:
The created approval request
"""
timeout = timeout_seconds or self._default_timeout
expires_at = datetime.utcnow() + timedelta(seconds=timeout)
request = ApprovalRequest(
id=str(uuid4()),
action=action,
reason=reason,
urgency=urgency,
timeout_seconds=timeout,
expires_at=expires_at,
context=context or {},
)
await self._queue.add(request)
# Notify handlers
await self._notify_handlers("approval_requested", request)
logger.info(
"Approval requested: %s for action %s (timeout: %ds)",
request.id,
action.id,
timeout,
)
return request
async def wait_for_approval(
self,
request_id: str,
timeout_seconds: int | None = None,
) -> ApprovalResponse:
"""
Wait for an approval decision.
Args:
request_id: ID of the approval request
timeout_seconds: Override timeout
Returns:
The approval response
Raises:
ApprovalTimeoutError: If timeout expires
ApprovalDeniedError: If approval is denied
"""
request = await self._queue.get_pending(request_id)
if not request:
raise ApprovalRequiredError(
f"Approval request not found: {request_id}",
approval_id=request_id,
)
timeout = timeout_seconds or request.timeout_seconds or self._default_timeout
response = await self._queue.wait_for_response(request_id, timeout)
if response is None:
# Timeout - default deny
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.TIMEOUT,
reason="Request timed out (default deny)",
)
await self._queue.complete(response)
raise ApprovalTimeoutError(
"Approval request timed out",
approval_id=request_id,
timeout_seconds=timeout,
)
if response.status == ApprovalStatus.DENIED:
raise ApprovalDeniedError(
response.reason or "Approval denied",
approval_id=request_id,
denied_by=response.decided_by,
denial_reason=response.reason,
)
if response.status == ApprovalStatus.TIMEOUT:
raise ApprovalTimeoutError(
"Approval request timed out",
approval_id=request_id,
timeout_seconds=timeout,
)
if response.status == ApprovalStatus.CANCELLED:
raise ApprovalDeniedError(
"Approval request was cancelled",
approval_id=request_id,
denial_reason="Cancelled",
)
return response
async def approve(
self,
request_id: str,
decided_by: str,
reason: str | None = None,
modifications: dict[str, Any] | None = None,
) -> bool:
"""
Approve a pending request.
Args:
request_id: ID of the approval request
decided_by: Who approved
reason: Optional approval reason
modifications: Optional modifications to the action
Returns:
True if approval was recorded
"""
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.APPROVED,
decided_by=decided_by,
reason=reason,
modifications=modifications,
)
success = await self._queue.complete(response)
if success:
logger.info(
"Approval granted: %s by %s",
request_id,
decided_by,
)
await self._notify_handlers("approval_granted", response)
return success
async def deny(
self,
request_id: str,
decided_by: str,
reason: str | None = None,
) -> bool:
"""
Deny a pending request.
Args:
request_id: ID of the approval request
decided_by: Who denied
reason: Denial reason
Returns:
True if denial was recorded
"""
response = ApprovalResponse(
request_id=request_id,
status=ApprovalStatus.DENIED,
decided_by=decided_by,
reason=reason,
)
success = await self._queue.complete(response)
if success:
logger.info(
"Approval denied: %s by %s - %s",
request_id,
decided_by,
reason,
)
await self._notify_handlers("approval_denied", response)
return success
async def cancel(self, request_id: str) -> bool:
"""
Cancel a pending request.
Args:
request_id: ID of the approval request
Returns:
True if request was cancelled
"""
success = await self._queue.cancel(request_id)
if success:
logger.info("Approval request cancelled: %s", request_id)
return success
async def list_pending(self) -> list[ApprovalRequest]:
"""List all pending approval requests."""
return await self._queue.list_pending()
async def get_request(self, request_id: str) -> ApprovalRequest | None:
"""Get an approval request by ID."""
return await self._queue.get_pending(request_id)
def add_notification_handler(
self,
handler: Callable[..., Any],
) -> None:
"""Add a notification handler."""
self._notification_handlers.append(handler)
def remove_notification_handler(
self,
handler: Callable[..., Any],
) -> None:
"""Remove a notification handler."""
if handler in self._notification_handlers:
self._notification_handlers.remove(handler)
async def _notify_handlers(
self,
event_type: str,
data: Any,
) -> None:
"""Notify all handlers of an event."""
for handler in self._notification_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event_type, data)
else:
handler(event_type, data)
except Exception as e:
logger.error("Error in notification handler: %s", e)
async def _periodic_cleanup(self) -> None:
"""Background task for cleaning up expired requests."""
while self._running:
try:
await asyncio.sleep(30) # Check every 30 seconds
count = await self._queue.cleanup_expired()
if count:
logger.debug("Cleaned up %d expired approval requests", count)
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in approval cleanup: %s", e)

View File

@@ -0,0 +1,15 @@
"""
Rate Limiting Module
Sliding window rate limiting for agent operations.
"""
from .limiter import (
RateLimiter,
SlidingWindowCounter,
)
__all__ = [
"RateLimiter",
"SlidingWindowCounter",
]

View File

@@ -0,0 +1,368 @@
"""
Rate Limiter
Sliding window rate limiting for agent operations.
"""
import asyncio
import logging
import time
from collections import deque
from ..config import get_safety_config
from ..exceptions import RateLimitExceededError
from ..models import (
ActionRequest,
RateLimitConfig,
RateLimitStatus,
)
logger = logging.getLogger(__name__)
class SlidingWindowCounter:
"""Sliding window counter for rate limiting."""
def __init__(
self,
limit: int,
window_seconds: int,
burst_limit: int | None = None,
) -> None:
self.limit = limit
self.window_seconds = window_seconds
self.burst_limit = burst_limit or limit
self._timestamps: deque[float] = deque()
self._lock = asyncio.Lock()
async def try_acquire(self) -> tuple[bool, float]:
"""
Try to acquire a slot.
Returns:
Tuple of (allowed, retry_after_seconds)
"""
now = time.time()
window_start = now - self.window_seconds
async with self._lock:
# Remove expired entries
while self._timestamps and self._timestamps[0] < window_start:
self._timestamps.popleft()
current_count = len(self._timestamps)
# Check burst limit (instant check)
if current_count >= self.burst_limit:
# Calculate retry time
oldest = self._timestamps[0] if self._timestamps else now
retry_after = oldest + self.window_seconds - now
return False, max(0, retry_after)
# Check window limit
if current_count >= self.limit:
oldest = self._timestamps[0] if self._timestamps else now
retry_after = oldest + self.window_seconds - now
return False, max(0, retry_after)
# Allow and record
self._timestamps.append(now)
return True, 0.0
async def get_status(self) -> tuple[int, int, float]:
"""
Get current status.
Returns:
Tuple of (current_count, remaining, reset_in_seconds)
"""
now = time.time()
window_start = now - self.window_seconds
async with self._lock:
# Remove expired entries
while self._timestamps and self._timestamps[0] < window_start:
self._timestamps.popleft()
current_count = len(self._timestamps)
remaining = max(0, self.limit - current_count)
if self._timestamps:
reset_in = self._timestamps[0] + self.window_seconds - now
else:
reset_in = 0.0
return current_count, remaining, max(0, reset_in)
class RateLimiter:
"""
Rate limiter for agent operations.
Features:
- Per-tool rate limits
- Per-agent rate limits
- Per-resource rate limits
- Sliding window implementation
- Burst allowance with recovery
- Slowdown before hard block
"""
def __init__(self) -> None:
"""Initialize the RateLimiter."""
config = get_safety_config()
self._configs: dict[str, RateLimitConfig] = {}
self._counters: dict[str, SlidingWindowCounter] = {}
self._lock = asyncio.Lock()
# Default rate limits
self._default_limits = {
"actions": RateLimitConfig(
name="actions",
limit=config.default_actions_per_minute,
window_seconds=60,
),
"llm_calls": RateLimitConfig(
name="llm_calls",
limit=config.default_llm_calls_per_minute,
window_seconds=60,
),
"file_ops": RateLimitConfig(
name="file_ops",
limit=config.default_file_ops_per_minute,
window_seconds=60,
),
}
def configure(self, config: RateLimitConfig) -> None:
"""
Configure a rate limit.
Args:
config: Rate limit configuration
"""
self._configs[config.name] = config
logger.debug(
"Configured rate limit: %s = %d/%ds",
config.name,
config.limit,
config.window_seconds,
)
async def check(
self,
limit_name: str,
key: str,
) -> RateLimitStatus:
"""
Check rate limit without consuming a slot.
Args:
limit_name: Name of the rate limit
key: Key for tracking (e.g., agent_id)
Returns:
Rate limit status
"""
counter = await self._get_counter(limit_name, key)
config = self._get_config(limit_name)
current, remaining, reset_in = await counter.get_status()
from datetime import datetime, timedelta
return RateLimitStatus(
name=limit_name,
current_count=current,
limit=config.limit,
window_seconds=config.window_seconds,
remaining=remaining,
reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
is_limited=remaining <= 0,
retry_after_seconds=reset_in if remaining <= 0 else 0.0,
)
async def acquire(
self,
limit_name: str,
key: str,
) -> tuple[bool, RateLimitStatus]:
"""
Try to acquire a rate limit slot.
Args:
limit_name: Name of the rate limit
key: Key for tracking (e.g., agent_id)
Returns:
Tuple of (allowed, status)
"""
counter = await self._get_counter(limit_name, key)
config = self._get_config(limit_name)
allowed, retry_after = await counter.try_acquire()
current, remaining, reset_in = await counter.get_status()
from datetime import datetime, timedelta
status = RateLimitStatus(
name=limit_name,
current_count=current,
limit=config.limit,
window_seconds=config.window_seconds,
remaining=remaining,
reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
is_limited=not allowed,
retry_after_seconds=retry_after,
)
return allowed, status
async def check_action(
self,
action: ActionRequest,
) -> tuple[bool, list[RateLimitStatus]]:
"""
Check all applicable rate limits for an action.
Args:
action: The action to check
Returns:
Tuple of (allowed, list of statuses)
"""
agent_id = action.metadata.agent_id
statuses: list[RateLimitStatus] = []
allowed = True
# Check general actions limit
actions_allowed, actions_status = await self.acquire("actions", agent_id)
statuses.append(actions_status)
if not actions_allowed:
allowed = False
# Check LLM-specific limit for LLM calls
if action.action_type.value == "llm_call":
llm_allowed, llm_status = await self.acquire("llm_calls", agent_id)
statuses.append(llm_status)
if not llm_allowed:
allowed = False
# Check file ops limit for file operations
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
file_allowed, file_status = await self.acquire("file_ops", agent_id)
statuses.append(file_status)
if not file_allowed:
allowed = False
return allowed, statuses
async def require(
self,
limit_name: str,
key: str,
) -> None:
"""
Require rate limit slot or raise exception.
Args:
limit_name: Name of the rate limit
key: Key for tracking
Raises:
RateLimitExceededError: If rate limit exceeded
"""
allowed, status = await self.acquire(limit_name, key)
if not allowed:
raise RateLimitExceededError(
f"Rate limit exceeded: {limit_name}",
limit_type=limit_name,
limit_value=status.limit,
window_seconds=status.window_seconds,
retry_after_seconds=status.retry_after_seconds,
)
async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
"""
Get status of all rate limits for a key.
Args:
key: Key for tracking
Returns:
Dict of limit name to status
"""
statuses = {}
for name in self._default_limits:
statuses[name] = await self.check(name, key)
for name in self._configs:
if name not in statuses:
statuses[name] = await self.check(name, key)
return statuses
async def reset(self, limit_name: str, key: str) -> bool:
"""
Reset a rate limit counter.
Args:
limit_name: Name of the rate limit
key: Key for tracking
Returns:
True if counter was found and reset
"""
counter_key = f"{limit_name}:{key}"
async with self._lock:
if counter_key in self._counters:
del self._counters[counter_key]
return True
return False
async def reset_all(self, key: str) -> int:
"""
Reset all rate limit counters for a key.
Args:
key: Key for tracking
Returns:
Number of counters reset
"""
count = 0
async with self._lock:
to_remove = [k for k in self._counters if k.endswith(f":{key}")]
for k in to_remove:
del self._counters[k]
count += 1
return count
def _get_config(self, limit_name: str) -> RateLimitConfig:
"""Get configuration for a rate limit."""
if limit_name in self._configs:
return self._configs[limit_name]
if limit_name in self._default_limits:
return self._default_limits[limit_name]
# Return default
return RateLimitConfig(
name=limit_name,
limit=60,
window_seconds=60,
)
async def _get_counter(
self,
limit_name: str,
key: str,
) -> SlidingWindowCounter:
"""Get or create a counter."""
counter_key = f"{limit_name}:{key}"
config = self._get_config(limit_name)
async with self._lock:
if counter_key not in self._counters:
self._counters[counter_key] = SlidingWindowCounter(
limit=config.limit,
window_seconds=config.window_seconds,
burst_limit=config.burst_limit,
)
return self._counters[counter_key]

View File

@@ -0,0 +1,17 @@
"""
Loop Detection Module
Detects and prevents action loops in agent behavior.
"""
from .detector import (
ActionSignature,
LoopBreaker,
LoopDetector,
)
__all__ = [
"ActionSignature",
"LoopBreaker",
"LoopDetector",
]

View File

@@ -0,0 +1,267 @@
"""
Loop Detector
Detects and prevents action loops in agent behavior.
"""
import asyncio
import hashlib
import json
import logging
from collections import Counter, deque
from typing import Any
from ..config import get_safety_config
from ..exceptions import LoopDetectedError
from ..models import ActionRequest
logger = logging.getLogger(__name__)
class ActionSignature:
"""Signature of an action for comparison."""
def __init__(self, action: ActionRequest) -> None:
self.action_type = action.action_type.value
self.tool_name = action.tool_name
self.resource = action.resource
self.args_hash = self._hash_args(action.arguments)
def _hash_args(self, args: dict[str, Any]) -> str:
"""Create a hash of the arguments."""
try:
serialized = json.dumps(args, sort_keys=True, default=str)
return hashlib.sha256(serialized.encode()).hexdigest()[:8]
except Exception:
return ""
def exact_key(self) -> str:
"""Key for exact match detection."""
return f"{self.action_type}:{self.tool_name}:{self.resource}:{self.args_hash}"
def semantic_key(self) -> str:
"""Key for semantic (similar) match detection."""
return f"{self.action_type}:{self.tool_name}:{self.resource}"
def type_key(self) -> str:
"""Key for action type only."""
return f"{self.action_type}"
class LoopDetector:
"""
Detects action loops and repetitive behavior.
Loop Types:
- Exact: Same action with same arguments
- Semantic: Similar actions (same type/tool/resource, different args)
- Oscillation: A→B→A→B patterns
"""
def __init__(
self,
history_size: int | None = None,
max_exact_repetitions: int | None = None,
max_semantic_repetitions: int | None = None,
) -> None:
"""
Initialize the LoopDetector.
Args:
history_size: Size of action history to track
max_exact_repetitions: Max allowed exact repetitions
max_semantic_repetitions: Max allowed semantic repetitions
"""
config = get_safety_config()
self._history_size = history_size or config.loop_history_size
self._max_exact = max_exact_repetitions or config.max_repeated_actions
self._max_semantic = max_semantic_repetitions or config.max_similar_actions
# Per-agent history
self._histories: dict[str, deque[ActionSignature]] = {}
self._lock = asyncio.Lock()
async def check(self, action: ActionRequest) -> tuple[bool, str | None]:
"""
Check if an action would create a loop.
Args:
action: The action to check
Returns:
Tuple of (is_loop, loop_type)
"""
agent_id = action.metadata.agent_id
signature = ActionSignature(action)
async with self._lock:
history = self._get_history(agent_id)
# Check exact repetition
exact_key = signature.exact_key()
exact_count = sum(1 for h in history if h.exact_key() == exact_key)
if exact_count >= self._max_exact:
return True, "exact"
# Check semantic repetition
semantic_key = signature.semantic_key()
semantic_count = sum(1 for h in history if h.semantic_key() == semantic_key)
if semantic_count >= self._max_semantic:
return True, "semantic"
# Check oscillation (A→B→A→B pattern)
if len(history) >= 3:
pattern = self._detect_oscillation(history, signature)
if pattern:
return True, "oscillation"
return False, None
async def check_and_raise(self, action: ActionRequest) -> None:
"""
Check for loops and raise if detected.
Args:
action: The action to check
Raises:
LoopDetectedError: If loop is detected
"""
is_loop, loop_type = await self.check(action)
if is_loop:
signature = ActionSignature(action)
raise LoopDetectedError(
f"Loop detected: {loop_type}",
loop_type=loop_type or "unknown",
repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic,
action_pattern=[signature.semantic_key()],
agent_id=action.metadata.agent_id,
action_id=action.id,
)
async def record(self, action: ActionRequest) -> None:
"""
Record an action in history.
Args:
action: The action to record
"""
agent_id = action.metadata.agent_id
signature = ActionSignature(action)
async with self._lock:
history = self._get_history(agent_id)
history.append(signature)
async def clear_history(self, agent_id: str) -> None:
"""
Clear history for an agent.
Args:
agent_id: ID of the agent
"""
async with self._lock:
if agent_id in self._histories:
self._histories[agent_id].clear()
async def get_stats(self, agent_id: str) -> dict[str, Any]:
"""
Get loop detection stats for an agent.
Args:
agent_id: ID of the agent
Returns:
Stats dictionary
"""
async with self._lock:
history = self._get_history(agent_id)
# Count action types
type_counts = Counter(h.type_key() for h in history)
semantic_counts = Counter(h.semantic_key() for h in history)
return {
"history_size": len(history),
"max_history": self._history_size,
"action_type_counts": dict(type_counts),
"top_semantic_patterns": semantic_counts.most_common(5),
}
def _get_history(self, agent_id: str) -> deque[ActionSignature]:
"""Get or create history for an agent."""
if agent_id not in self._histories:
self._histories[agent_id] = deque(maxlen=self._history_size)
return self._histories[agent_id]
def _detect_oscillation(
self,
history: deque[ActionSignature],
current: ActionSignature,
) -> bool:
"""
Detect A→B→A→B oscillation pattern.
Looks at last 4+ actions including current.
"""
if len(history) < 3:
return False
# Get last 3 actions + current
recent = [*list(history)[-3:], current]
# Check for A→B→A→B pattern
if len(recent) >= 4:
# Get semantic keys
keys = [a.semantic_key() for a in recent[-4:]]
# Pattern: k[0]==k[2] and k[1]==k[3] and k[0]!=k[1]
if keys[0] == keys[2] and keys[1] == keys[3] and keys[0] != keys[1]:
return True
return False
class LoopBreaker:
"""
Strategies for breaking detected loops.
"""
@staticmethod
async def suggest_alternatives(
action: ActionRequest,
loop_type: str,
) -> list[str]:
"""
Suggest alternative actions when loop is detected.
Args:
action: The looping action
loop_type: Type of loop detected
Returns:
List of suggestions
"""
suggestions = []
if loop_type == "exact":
suggestions.append(
"The same action with identical arguments has been repeated too many times. "
"Consider: (1) Verify the action succeeded, (2) Try a different approach, "
"(3) Escalate for human review"
)
elif loop_type == "semantic":
suggestions.append(
"Similar actions have been repeated too many times. "
"Consider: (1) Review if the approach is working, (2) Try an alternative method, "
"(3) Request clarification on the goal"
)
elif loop_type == "oscillation":
suggestions.append(
"An oscillating pattern was detected (A→B→A→B). "
"This usually indicates conflicting goals or a stuck state. "
"Consider: (1) Step back and reassess, (2) Request human guidance"
)
return suggestions

View File

@@ -0,0 +1,17 @@
"""MCP safety integration."""
from .integration import (
MCPSafetyWrapper,
MCPToolCall,
MCPToolResult,
SafeToolExecutor,
create_mcp_wrapper,
)
__all__ = [
"MCPSafetyWrapper",
"MCPToolCall",
"MCPToolResult",
"SafeToolExecutor",
"create_mcp_wrapper",
]

View File

@@ -0,0 +1,405 @@
"""
MCP Safety Integration
Provides safety-aware wrappers for MCP tool execution.
"""
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, ClassVar, TypeVar
from ..audit import AuditLogger
from ..emergency import EmergencyControls, get_emergency_controls
from ..exceptions import (
EmergencyStopError,
SafetyError,
)
from ..guardian import SafetyGuardian, get_safety_guardian
from ..models import (
ActionMetadata,
ActionRequest,
ActionType,
AutonomyLevel,
SafetyDecision,
)
logger = logging.getLogger(__name__)
T = TypeVar("T")
@dataclass
class MCPToolCall:
"""Represents an MCP tool call."""
tool_name: str
arguments: dict[str, Any]
server_name: str | None = None
project_id: str | None = None
context: dict[str, Any] = field(default_factory=dict)
@dataclass
class MCPToolResult:
"""Result of an MCP tool execution."""
success: bool
result: Any = None
error: str | None = None
safety_decision: SafetyDecision = SafetyDecision.ALLOW
execution_time_ms: float = 0.0
approval_id: str | None = None
checkpoint_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
class MCPSafetyWrapper:
"""
Wraps MCP tool execution with safety checks.
Features:
- Pre-execution validation via SafetyGuardian
- Permission checking per tool/resource
- Budget and rate limit enforcement
- Audit logging of all MCP calls
- Emergency stop integration
- Checkpoint creation for destructive operations
"""
# Tool categories for automatic classification
DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
"file_write",
"file_delete",
"database_mutate",
"shell_execute",
"git_push",
"git_commit",
"deploy",
}
READ_ONLY_TOOLS: ClassVar[set[str]] = {
"file_read",
"database_query",
"git_status",
"git_log",
"list_files",
"search",
}
def __init__(
self,
guardian: SafetyGuardian | None = None,
audit_logger: AuditLogger | None = None,
emergency_controls: EmergencyControls | None = None,
) -> None:
"""
Initialize MCPSafetyWrapper.
Args:
guardian: SafetyGuardian instance (uses singleton if not provided)
audit_logger: AuditLogger instance
emergency_controls: EmergencyControls instance
"""
self._guardian = guardian
self._audit_logger = audit_logger
self._emergency_controls = emergency_controls
self._tool_handlers: dict[str, Callable[..., Any]] = {}
self._lock = asyncio.Lock()
async def _get_guardian(self) -> SafetyGuardian:
"""Get or create SafetyGuardian."""
if self._guardian is None:
self._guardian = await get_safety_guardian()
return self._guardian
async def _get_emergency_controls(self) -> EmergencyControls:
"""Get or create EmergencyControls."""
if self._emergency_controls is None:
self._emergency_controls = await get_emergency_controls()
return self._emergency_controls
def register_tool_handler(
self,
tool_name: str,
handler: Callable[..., Any],
) -> None:
"""
Register a handler for a tool.
Args:
tool_name: Name of the tool
handler: Async function to handle the tool call
"""
self._tool_handlers[tool_name] = handler
logger.debug("Registered handler for tool: %s", tool_name)
async def execute(
self,
tool_call: MCPToolCall,
agent_id: str,
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
bypass_safety: bool = False,
) -> MCPToolResult:
"""
Execute an MCP tool call with safety checks.
Args:
tool_call: The tool call to execute
agent_id: ID of the calling agent
autonomy_level: Agent's autonomy level
bypass_safety: Bypass safety checks (emergency only)
Returns:
MCPToolResult with execution outcome
"""
start_time = datetime.utcnow()
# Check emergency controls first
emergency = await self._get_emergency_controls()
scope = f"agent:{agent_id}"
if tool_call.project_id:
scope = f"project:{tool_call.project_id}"
try:
await emergency.check_allowed(scope=scope, raise_if_blocked=True)
except EmergencyStopError as e:
return MCPToolResult(
success=False,
error=str(e),
safety_decision=SafetyDecision.DENY,
metadata={"emergency_stop": True},
)
# Build action request
action = self._build_action_request(
tool_call=tool_call,
agent_id=agent_id,
autonomy_level=autonomy_level,
)
# Skip safety checks if bypass is enabled
if bypass_safety:
logger.warning(
"Safety bypass enabled for tool: %s (agent: %s)",
tool_call.tool_name,
agent_id,
)
return await self._execute_tool(tool_call, action, start_time)
# Run safety validation
guardian = await self._get_guardian()
try:
guardian_result = await guardian.validate(action)
except SafetyError as e:
return MCPToolResult(
success=False,
error=str(e),
safety_decision=SafetyDecision.DENY,
execution_time_ms=self._elapsed_ms(start_time),
)
# Handle safety decision
if guardian_result.decision == SafetyDecision.DENY:
return MCPToolResult(
success=False,
error="; ".join(guardian_result.reasons),
safety_decision=SafetyDecision.DENY,
execution_time_ms=self._elapsed_ms(start_time),
)
if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
# For now, just return that approval is required
# The caller should handle the approval flow
return MCPToolResult(
success=False,
error="Action requires human approval",
safety_decision=SafetyDecision.REQUIRE_APPROVAL,
approval_id=guardian_result.approval_id,
execution_time_ms=self._elapsed_ms(start_time),
)
# Execute the tool
result = await self._execute_tool(
tool_call,
action,
start_time,
checkpoint_id=guardian_result.checkpoint_id,
)
return result
async def _execute_tool(
self,
tool_call: MCPToolCall,
action: ActionRequest,
start_time: datetime,
checkpoint_id: str | None = None,
) -> MCPToolResult:
"""Execute the actual tool call."""
handler = self._tool_handlers.get(tool_call.tool_name)
if handler is None:
return MCPToolResult(
success=False,
error=f"No handler registered for tool: {tool_call.tool_name}",
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=self._elapsed_ms(start_time),
)
try:
if asyncio.iscoroutinefunction(handler):
result = await handler(**tool_call.arguments)
else:
result = handler(**tool_call.arguments)
return MCPToolResult(
success=True,
result=result,
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=self._elapsed_ms(start_time),
checkpoint_id=checkpoint_id,
)
except Exception as e:
logger.error("Tool execution failed: %s - %s", tool_call.tool_name, e)
return MCPToolResult(
success=False,
error=str(e),
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=self._elapsed_ms(start_time),
checkpoint_id=checkpoint_id,
)
def _build_action_request(
self,
tool_call: MCPToolCall,
agent_id: str,
autonomy_level: AutonomyLevel,
) -> ActionRequest:
"""Build an ActionRequest from an MCP tool call."""
action_type = self._classify_tool(tool_call.tool_name)
metadata = ActionMetadata(
agent_id=agent_id,
session_id=tool_call.context.get("session_id", ""),
project_id=tool_call.project_id or "",
autonomy_level=autonomy_level,
)
return ActionRequest(
action_type=action_type,
tool_name=tool_call.tool_name,
arguments=tool_call.arguments,
resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
metadata=metadata,
)
def _classify_tool(self, tool_name: str) -> ActionType:
"""Classify a tool into an action type."""
tool_lower = tool_name.lower()
# Check destructive patterns
if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
if "file" in tool_lower:
if "delete" in tool_lower or "remove" in tool_lower:
return ActionType.FILE_DELETE
return ActionType.FILE_WRITE
if "database" in tool_lower or "db" in tool_lower:
return ActionType.DATABASE_MUTATE
# Check read patterns
if any(r in tool_lower for r in ["read", "get", "list", "search", "query"]):
if "file" in tool_lower:
return ActionType.FILE_READ
if "database" in tool_lower or "db" in tool_lower:
return ActionType.DATABASE_QUERY
# Check specific types
if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
return ActionType.SHELL_COMMAND
if "git" in tool_lower:
return ActionType.GIT_OPERATION
if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
return ActionType.NETWORK_REQUEST
if "llm" in tool_lower or "ai" in tool_lower or "claude" in tool_lower:
return ActionType.LLM_CALL
# Default to tool call
return ActionType.TOOL_CALL
def _elapsed_ms(self, start_time: datetime) -> float:
"""Calculate elapsed time in milliseconds."""
return (datetime.utcnow() - start_time).total_seconds() * 1000
class SafeToolExecutor:
"""
Context manager for safe tool execution with automatic cleanup.
Usage:
async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
result = await executor.execute()
if result.success:
# Use result
else:
# Handle error or approval required
"""
def __init__(
self,
wrapper: MCPSafetyWrapper,
tool_call: MCPToolCall,
agent_id: str,
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
) -> None:
self._wrapper = wrapper
self._tool_call = tool_call
self._agent_id = agent_id
self._autonomy_level = autonomy_level
self._result: MCPToolResult | None = None
async def __aenter__(self) -> "SafeToolExecutor":
return self
async def __aexit__(
self,
exc_type: type[Exception] | None,
exc_val: Exception | None,
exc_tb: Any,
) -> bool:
# Could trigger rollback here if needed
return False
async def execute(self) -> MCPToolResult:
"""Execute the tool call."""
self._result = await self._wrapper.execute(
self._tool_call,
self._agent_id,
self._autonomy_level,
)
return self._result
@property
def result(self) -> MCPToolResult | None:
"""Get the execution result."""
return self._result
# Factory function
async def create_mcp_wrapper(
guardian: SafetyGuardian | None = None,
) -> MCPSafetyWrapper:
"""Create an MCPSafetyWrapper with default configuration."""
if guardian is None:
guardian = await get_safety_guardian()
return MCPSafetyWrapper(
guardian=guardian,
emergency_controls=await get_emergency_controls(),
)

View File

@@ -0,0 +1,19 @@
"""Safety metrics collection and export."""
from .collector import (
MetricType,
MetricValue,
SafetyMetrics,
get_safety_metrics,
record_mcp_call,
record_validation,
)
__all__ = [
"MetricType",
"MetricValue",
"SafetyMetrics",
"get_safety_metrics",
"record_mcp_call",
"record_validation",
]

View File

@@ -0,0 +1,416 @@
"""
Safety Metrics Collector
Collects and exposes metrics for the safety framework.
"""
import asyncio
import logging
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
logger = logging.getLogger(__name__)
class MetricType(str, Enum):
"""Types of metrics."""
COUNTER = "counter"
GAUGE = "gauge"
HISTOGRAM = "histogram"
@dataclass
class MetricValue:
"""A single metric value."""
name: str
metric_type: MetricType
value: float
labels: dict[str, str] = field(default_factory=dict)
timestamp: datetime = field(default_factory=datetime.utcnow)
@dataclass
class HistogramBucket:
"""Histogram bucket for distribution metrics."""
le: float # Less than or equal
count: int = 0
class SafetyMetrics:
"""
Collects safety framework metrics.
Metrics tracked:
- Action validation counts (by decision type)
- Approval request counts and latencies
- Budget usage and remaining
- Rate limit hits
- Loop detections
- Emergency events
- Content filter matches
"""
def __init__(self) -> None:
"""Initialize SafetyMetrics."""
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
self._histograms: dict[str, list[float]] = defaultdict(list)
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
self._lock = asyncio.Lock()
# Initialize histogram buckets
self._init_histogram_buckets()
def _init_histogram_buckets(self) -> None:
"""Initialize histogram buckets for latency metrics."""
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
for name in [
"validation_latency_seconds",
"approval_latency_seconds",
"mcp_execution_latency_seconds",
]:
self._histogram_buckets[name] = [
HistogramBucket(le=b) for b in latency_buckets
]
# Counter methods
async def inc_validations(
self,
decision: str,
agent_id: str | None = None,
) -> None:
"""Increment validation counter."""
async with self._lock:
labels = f"decision={decision}"
if agent_id:
labels += f",agent_id={agent_id}"
self._counters["safety_validations_total"][labels] += 1
async def inc_approvals_requested(self, urgency: str = "normal") -> None:
"""Increment approval requests counter."""
async with self._lock:
labels = f"urgency={urgency}"
self._counters["safety_approvals_requested_total"][labels] += 1
async def inc_approvals_granted(self) -> None:
"""Increment approvals granted counter."""
async with self._lock:
self._counters["safety_approvals_granted_total"][""] += 1
async def inc_approvals_denied(self, reason: str = "manual") -> None:
"""Increment approvals denied counter."""
async with self._lock:
labels = f"reason={reason}"
self._counters["safety_approvals_denied_total"][labels] += 1
async def inc_rate_limit_exceeded(self, limit_type: str) -> None:
"""Increment rate limit exceeded counter."""
async with self._lock:
labels = f"limit_type={limit_type}"
self._counters["safety_rate_limit_exceeded_total"][labels] += 1
async def inc_budget_exceeded(self, budget_type: str) -> None:
"""Increment budget exceeded counter."""
async with self._lock:
labels = f"budget_type={budget_type}"
self._counters["safety_budget_exceeded_total"][labels] += 1
async def inc_loops_detected(self, loop_type: str) -> None:
"""Increment loop detection counter."""
async with self._lock:
labels = f"loop_type={loop_type}"
self._counters["safety_loops_detected_total"][labels] += 1
async def inc_emergency_events(self, event_type: str, scope: str) -> None:
"""Increment emergency events counter."""
async with self._lock:
labels = f"event_type={event_type},scope={scope}"
self._counters["safety_emergency_events_total"][labels] += 1
async def inc_content_filtered(self, category: str, action: str) -> None:
"""Increment content filter counter."""
async with self._lock:
labels = f"category={category},action={action}"
self._counters["safety_content_filtered_total"][labels] += 1
async def inc_checkpoints_created(self) -> None:
"""Increment checkpoints created counter."""
async with self._lock:
self._counters["safety_checkpoints_created_total"][""] += 1
async def inc_rollbacks_executed(self, success: bool) -> None:
"""Increment rollbacks counter."""
async with self._lock:
labels = f"success={str(success).lower()}"
self._counters["safety_rollbacks_total"][labels] += 1
async def inc_mcp_calls(self, tool_name: str, success: bool) -> None:
"""Increment MCP tool calls counter."""
async with self._lock:
labels = f"tool_name={tool_name},success={str(success).lower()}"
self._counters["safety_mcp_calls_total"][labels] += 1
# Gauge methods
async def set_budget_remaining(
self,
scope: str,
budget_type: str,
remaining: float,
) -> None:
"""Set remaining budget gauge."""
async with self._lock:
labels = f"scope={scope},budget_type={budget_type}"
self._gauges["safety_budget_remaining"][labels] = remaining
async def set_rate_limit_remaining(
self,
scope: str,
limit_type: str,
remaining: int,
) -> None:
"""Set remaining rate limit gauge."""
async with self._lock:
labels = f"scope={scope},limit_type={limit_type}"
self._gauges["safety_rate_limit_remaining"][labels] = float(remaining)
async def set_pending_approvals(self, count: int) -> None:
"""Set pending approvals gauge."""
async with self._lock:
self._gauges["safety_pending_approvals"][""] = float(count)
async def set_active_checkpoints(self, count: int) -> None:
"""Set active checkpoints gauge."""
async with self._lock:
self._gauges["safety_active_checkpoints"][""] = float(count)
async def set_emergency_state(self, scope: str, state: str) -> None:
"""Set emergency state gauge (0=normal, 1=paused, 2=stopped)."""
async with self._lock:
state_value = {"normal": 0, "paused": 1, "stopped": 2}.get(state, -1)
labels = f"scope={scope}"
self._gauges["safety_emergency_state"][labels] = float(state_value)
# Histogram methods
async def observe_validation_latency(self, latency_seconds: float) -> None:
"""Observe validation latency."""
async with self._lock:
self._observe_histogram("validation_latency_seconds", latency_seconds)
async def observe_approval_latency(self, latency_seconds: float) -> None:
"""Observe approval latency."""
async with self._lock:
self._observe_histogram("approval_latency_seconds", latency_seconds)
async def observe_mcp_execution_latency(self, latency_seconds: float) -> None:
"""Observe MCP execution latency."""
async with self._lock:
self._observe_histogram("mcp_execution_latency_seconds", latency_seconds)
def _observe_histogram(self, name: str, value: float) -> None:
"""Record a value in a histogram."""
self._histograms[name].append(value)
# Update buckets
if name in self._histogram_buckets:
for bucket in self._histogram_buckets[name]:
if value <= bucket.le:
bucket.count += 1
# Export methods
async def get_all_metrics(self) -> list[MetricValue]:
"""Get all metrics as MetricValue objects."""
metrics: list[MetricValue] = []
async with self._lock:
# Export counters
for name, counter in self._counters.items():
for labels_str, value in counter.items():
labels = self._parse_labels(labels_str)
metrics.append(
MetricValue(
name=name,
metric_type=MetricType.COUNTER,
value=float(value),
labels=labels,
)
)
# Export gauges
for name, gauge_dict in self._gauges.items():
for labels_str, gauge_value in gauge_dict.items():
gauge_labels = self._parse_labels(labels_str)
metrics.append(
MetricValue(
name=name,
metric_type=MetricType.GAUGE,
value=gauge_value,
labels=gauge_labels,
)
)
# Export histogram summaries
for name, values in self._histograms.items():
if values:
metrics.append(
MetricValue(
name=f"{name}_count",
metric_type=MetricType.COUNTER,
value=float(len(values)),
)
)
metrics.append(
MetricValue(
name=f"{name}_sum",
metric_type=MetricType.COUNTER,
value=sum(values),
)
)
return metrics
async def get_prometheus_format(self) -> str:
"""Export metrics in Prometheus text format."""
lines: list[str] = []
async with self._lock:
# Export counters
for name, counter in self._counters.items():
lines.append(f"# TYPE {name} counter")
for labels_str, value in counter.items():
if labels_str:
lines.append(f"{name}{{{labels_str}}} {value}")
else:
lines.append(f"{name} {value}")
# Export gauges
for name, gauge_dict in self._gauges.items():
lines.append(f"# TYPE {name} gauge")
for labels_str, gauge_value in gauge_dict.items():
if labels_str:
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
else:
lines.append(f"{name} {gauge_value}")
# Export histograms
for name, buckets in self._histogram_buckets.items():
lines.append(f"# TYPE {name} histogram")
for bucket in buckets:
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
if name in self._histograms:
values = self._histograms[name]
lines.append(f"{name}_count {len(values)}")
lines.append(f"{name}_sum {sum(values)}")
return "\n".join(lines)
async def get_summary(self) -> dict[str, Any]:
"""Get a summary of key metrics."""
async with self._lock:
total_validations = sum(self._counters["safety_validations_total"].values())
denied_validations = sum(
v for k, v in self._counters["safety_validations_total"].items()
if "decision=deny" in k
)
return {
"total_validations": total_validations,
"denied_validations": denied_validations,
"approval_requests": sum(
self._counters["safety_approvals_requested_total"].values()
),
"approvals_granted": sum(
self._counters["safety_approvals_granted_total"].values()
),
"approvals_denied": sum(
self._counters["safety_approvals_denied_total"].values()
),
"rate_limit_hits": sum(
self._counters["safety_rate_limit_exceeded_total"].values()
),
"budget_exceeded": sum(
self._counters["safety_budget_exceeded_total"].values()
),
"loops_detected": sum(
self._counters["safety_loops_detected_total"].values()
),
"emergency_events": sum(
self._counters["safety_emergency_events_total"].values()
),
"content_filtered": sum(
self._counters["safety_content_filtered_total"].values()
),
"checkpoints_created": sum(
self._counters["safety_checkpoints_created_total"].values()
),
"rollbacks_executed": sum(
self._counters["safety_rollbacks_total"].values()
),
"mcp_calls": sum(
self._counters["safety_mcp_calls_total"].values()
),
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
}
async def reset(self) -> None:
"""Reset all metrics."""
async with self._lock:
self._counters.clear()
self._gauges.clear()
self._histograms.clear()
self._init_histogram_buckets()
def _parse_labels(self, labels_str: str) -> dict[str, str]:
"""Parse labels string into dictionary."""
if not labels_str:
return {}
labels = {}
for pair in labels_str.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
labels[key.strip()] = value.strip()
return labels
# Singleton instance
_metrics: SafetyMetrics | None = None
_lock = asyncio.Lock()
async def get_safety_metrics() -> SafetyMetrics:
"""Get the singleton SafetyMetrics instance."""
global _metrics
async with _lock:
if _metrics is None:
_metrics = SafetyMetrics()
return _metrics
# Convenience functions
async def record_validation(decision: str, agent_id: str | None = None) -> None:
"""Record a validation event."""
metrics = await get_safety_metrics()
await metrics.inc_validations(decision, agent_id)
async def record_mcp_call(tool_name: str, success: bool, latency_ms: float) -> None:
"""Record an MCP tool call."""
metrics = await get_safety_metrics()
await metrics.inc_mcp_calls(tool_name, success)
await metrics.observe_mcp_execution_latency(latency_ms / 1000)

View File

@@ -0,0 +1,474 @@
"""
Safety Framework Models
Core Pydantic models for actions, events, policies, and safety decisions.
"""
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
# ============================================================================
# Enums
# ============================================================================
class ActionType(str, Enum):
"""Types of actions that can be performed."""
TOOL_CALL = "tool_call"
FILE_READ = "file_read"
FILE_WRITE = "file_write"
FILE_DELETE = "file_delete"
API_CALL = "api_call"
DATABASE_QUERY = "database_query"
DATABASE_MUTATE = "database_mutate"
GIT_OPERATION = "git_operation"
SHELL_COMMAND = "shell_command"
LLM_CALL = "llm_call"
NETWORK_REQUEST = "network_request"
CUSTOM = "custom"
class ResourceType(str, Enum):
"""Types of resources that can be accessed."""
FILE = "file"
DATABASE = "database"
API = "api"
NETWORK = "network"
GIT = "git"
SHELL = "shell"
LLM = "llm"
MEMORY = "memory"
CUSTOM = "custom"
class PermissionLevel(str, Enum):
"""Permission levels for resource access."""
NONE = "none"
READ = "read"
WRITE = "write"
EXECUTE = "execute"
DELETE = "delete"
ADMIN = "admin"
class AutonomyLevel(str, Enum):
"""Autonomy levels for agent operation."""
FULL_CONTROL = "full_control" # Approve every action
MILESTONE = "milestone" # Approve at milestones
AUTONOMOUS = "autonomous" # Only major decisions
class SafetyDecision(str, Enum):
"""Result of safety validation."""
ALLOW = "allow"
DENY = "deny"
REQUIRE_APPROVAL = "require_approval"
DELAY = "delay"
SANDBOX = "sandbox"
class ApprovalStatus(str, Enum):
"""Status of approval request."""
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
TIMEOUT = "timeout"
CANCELLED = "cancelled"
class AuditEventType(str, Enum):
"""Types of audit events."""
ACTION_REQUESTED = "action_requested"
ACTION_VALIDATED = "action_validated"
ACTION_DENIED = "action_denied"
ACTION_EXECUTED = "action_executed"
ACTION_FAILED = "action_failed"
APPROVAL_REQUESTED = "approval_requested"
APPROVAL_GRANTED = "approval_granted"
APPROVAL_DENIED = "approval_denied"
APPROVAL_TIMEOUT = "approval_timeout"
CHECKPOINT_CREATED = "checkpoint_created"
ROLLBACK_STARTED = "rollback_started"
ROLLBACK_COMPLETED = "rollback_completed"
ROLLBACK_FAILED = "rollback_failed"
BUDGET_WARNING = "budget_warning"
BUDGET_EXCEEDED = "budget_exceeded"
RATE_LIMITED = "rate_limited"
LOOP_DETECTED = "loop_detected"
EMERGENCY_STOP = "emergency_stop"
POLICY_VIOLATION = "policy_violation"
CONTENT_FILTERED = "content_filtered"
# ============================================================================
# Action Models
# ============================================================================
class ActionMetadata(BaseModel):
"""Metadata associated with an action."""
agent_id: str = Field(..., description="ID of the agent performing the action")
project_id: str | None = Field(None, description="ID of the project context")
session_id: str | None = Field(None, description="ID of the current session")
task_id: str | None = Field(None, description="ID of the current task")
parent_action_id: str | None = Field(None, description="ID of the parent action")
correlation_id: str | None = Field(None, description="Correlation ID for tracing")
user_id: str | None = Field(None, description="ID of the user who initiated")
autonomy_level: AutonomyLevel = Field(
default=AutonomyLevel.MILESTONE,
description="Current autonomy level",
)
context: dict[str, Any] = Field(
default_factory=dict,
description="Additional context",
)
class ActionRequest(BaseModel):
"""Request to perform an action."""
id: str = Field(default_factory=lambda: str(uuid4()))
action_type: ActionType = Field(..., description="Type of action to perform")
tool_name: str | None = Field(None, description="Name of the tool to call")
resource: str | None = Field(None, description="Resource being accessed")
resource_type: ResourceType | None = Field(None, description="Type of resource")
arguments: dict[str, Any] = Field(
default_factory=dict,
description="Action arguments",
)
metadata: ActionMetadata = Field(..., description="Action metadata")
estimated_cost_tokens: int = Field(0, description="Estimated token cost")
estimated_cost_usd: float = Field(0.0, description="Estimated USD cost")
is_destructive: bool = Field(False, description="Whether action is destructive")
is_reversible: bool = Field(True, description="Whether action can be rolled back")
timestamp: datetime = Field(default_factory=datetime.utcnow)
class ActionResult(BaseModel):
"""Result of an executed action."""
action_id: str = Field(..., description="ID of the action")
success: bool = Field(..., description="Whether action succeeded")
data: Any = Field(None, description="Action result data")
error: str | None = Field(None, description="Error message if failed")
error_code: str | None = Field(None, description="Error code if failed")
execution_time_ms: float = Field(0.0, description="Execution time in ms")
actual_cost_tokens: int = Field(0, description="Actual token cost")
actual_cost_usd: float = Field(0.0, description="Actual USD cost")
checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
timestamp: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Validation Models
# ============================================================================
class ValidationRule(BaseModel):
"""A single validation rule."""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(..., description="Rule name")
description: str | None = Field(None, description="Rule description")
priority: int = Field(0, description="Rule priority (higher = evaluated first)")
enabled: bool = Field(True, description="Whether rule is enabled")
# Rule conditions
action_types: list[ActionType] | None = Field(
None, description="Action types this rule applies to"
)
tool_patterns: list[str] | None = Field(
None, description="Tool name patterns (supports wildcards)"
)
resource_patterns: list[str] | None = Field(
None, description="Resource patterns (supports wildcards)"
)
agent_ids: list[str] | None = Field(
None, description="Agent IDs this rule applies to"
)
# Rule decision
decision: SafetyDecision = Field(..., description="Decision when rule matches")
reason: str | None = Field(None, description="Reason for decision")
class ValidationResult(BaseModel):
"""Result of action validation."""
action_id: str = Field(..., description="ID of the validated action")
decision: SafetyDecision = Field(..., description="Validation decision")
applied_rules: list[str] = Field(
default_factory=list, description="IDs of applied rules"
)
reasons: list[str] = Field(
default_factory=list, description="Reasons for decision"
)
approval_id: str | None = Field(None, description="Approval request ID if needed")
retry_after_seconds: float | None = Field(
None, description="Retry delay if rate limited"
)
timestamp: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Budget Models
# ============================================================================
class BudgetScope(str, Enum):
"""Scope of a budget limit."""
SESSION = "session"
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
PROJECT = "project"
AGENT = "agent"
class BudgetStatus(BaseModel):
"""Current budget status."""
scope: BudgetScope = Field(..., description="Budget scope")
scope_id: str = Field(..., description="ID within scope (session/agent/project)")
tokens_used: int = Field(0, description="Tokens used in this scope")
tokens_limit: int = Field(100000, description="Token limit for this scope")
cost_used_usd: float = Field(0.0, description="USD spent in this scope")
cost_limit_usd: float = Field(10.0, description="USD limit for this scope")
tokens_remaining: int = Field(0, description="Remaining tokens")
cost_remaining_usd: float = Field(0.0, description="Remaining USD budget")
warning_threshold: float = Field(0.8, description="Warn at this usage fraction")
is_warning: bool = Field(False, description="Whether at warning level")
is_exceeded: bool = Field(False, description="Whether budget exceeded")
reset_at: datetime | None = Field(None, description="When budget resets")
# ============================================================================
# Rate Limit Models
# ============================================================================
class RateLimitConfig(BaseModel):
"""Configuration for a rate limit."""
name: str = Field(..., description="Rate limit name")
limit: int = Field(..., description="Maximum allowed in window")
window_seconds: int = Field(60, description="Time window in seconds")
burst_limit: int | None = Field(None, description="Burst allowance")
slowdown_threshold: float = Field(
0.8, description="Start slowing at this fraction"
)
class RateLimitStatus(BaseModel):
"""Current rate limit status."""
name: str = Field(..., description="Rate limit name")
current_count: int = Field(0, description="Current count in window")
limit: int = Field(..., description="Maximum allowed")
window_seconds: int = Field(..., description="Time window")
remaining: int = Field(..., description="Remaining in window")
reset_at: datetime = Field(..., description="When window resets")
is_limited: bool = Field(False, description="Whether currently limited")
retry_after_seconds: float = Field(0.0, description="Seconds until retry")
# ============================================================================
# Approval Models
# ============================================================================
class ApprovalRequest(BaseModel):
"""Request for human approval."""
id: str = Field(default_factory=lambda: str(uuid4()))
action: ActionRequest = Field(..., description="Action requiring approval")
reason: str = Field(..., description="Why approval is required")
urgency: str = Field("normal", description="Urgency level")
timeout_seconds: int = Field(300, description="Timeout for approval")
created_at: datetime = Field(default_factory=datetime.utcnow)
expires_at: datetime | None = Field(None, description="When request expires")
suggested_action: str | None = Field(None, description="Suggested response")
context: dict[str, Any] = Field(default_factory=dict, description="Extra context")
class ApprovalResponse(BaseModel):
"""Response to an approval request."""
request_id: str = Field(..., description="ID of the approval request")
status: ApprovalStatus = Field(..., description="Approval status")
decided_by: str | None = Field(None, description="Who made the decision")
reason: str | None = Field(None, description="Reason for decision")
modifications: dict[str, Any] | None = Field(
None, description="Modifications to action"
)
decided_at: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Checkpoint/Rollback Models
# ============================================================================
class CheckpointType(str, Enum):
"""Types of checkpoints."""
FILE = "file"
DATABASE = "database"
GIT = "git"
COMPOSITE = "composite"
class Checkpoint(BaseModel):
"""A rollback checkpoint."""
id: str = Field(default_factory=lambda: str(uuid4()))
checkpoint_type: CheckpointType = Field(..., description="Type of checkpoint")
action_id: str = Field(..., description="Action this checkpoint is for")
created_at: datetime = Field(default_factory=datetime.utcnow)
expires_at: datetime | None = Field(None, description="When checkpoint expires")
data: dict[str, Any] = Field(default_factory=dict, description="Checkpoint data")
description: str | None = Field(None, description="Description of checkpoint")
is_valid: bool = Field(True, description="Whether checkpoint is still valid")
class RollbackResult(BaseModel):
"""Result of a rollback operation."""
checkpoint_id: str = Field(..., description="ID of checkpoint rolled back to")
success: bool = Field(..., description="Whether rollback succeeded")
actions_rolled_back: list[str] = Field(
default_factory=list, description="IDs of rolled back actions"
)
failed_actions: list[str] = Field(
default_factory=list, description="IDs of actions that failed to rollback"
)
error: str | None = Field(None, description="Error message if failed")
timestamp: datetime = Field(default_factory=datetime.utcnow)
# ============================================================================
# Audit Models
# ============================================================================
class AuditEvent(BaseModel):
"""An audit log event."""
id: str = Field(default_factory=lambda: str(uuid4()))
event_type: AuditEventType = Field(..., description="Type of audit event")
timestamp: datetime = Field(default_factory=datetime.utcnow)
agent_id: str | None = Field(None, description="Agent ID if applicable")
action_id: str | None = Field(None, description="Action ID if applicable")
project_id: str | None = Field(None, description="Project ID if applicable")
session_id: str | None = Field(None, description="Session ID if applicable")
user_id: str | None = Field(None, description="User ID if applicable")
decision: SafetyDecision | None = Field(None, description="Safety decision")
details: dict[str, Any] = Field(default_factory=dict, description="Event details")
correlation_id: str | None = Field(None, description="Correlation ID for tracing")
# ============================================================================
# Policy Models
# ============================================================================
class SafetyPolicy(BaseModel):
"""A complete safety policy configuration."""
name: str = Field(..., description="Policy name")
description: str | None = Field(None, description="Policy description")
version: str = Field("1.0.0", description="Policy version")
enabled: bool = Field(True, description="Whether policy is enabled")
# Cost controls
max_tokens_per_session: int = Field(100_000, description="Max tokens per session")
max_tokens_per_day: int = Field(1_000_000, description="Max tokens per day")
max_cost_per_session_usd: float = Field(10.0, description="Max USD per session")
max_cost_per_day_usd: float = Field(100.0, description="Max USD per day")
# Rate limits
max_actions_per_minute: int = Field(60, description="Max actions per minute")
max_llm_calls_per_minute: int = Field(20, description="Max LLM calls per minute")
max_file_operations_per_minute: int = Field(
100, description="Max file ops per minute"
)
# Permissions
allowed_tools: list[str] = Field(
default_factory=lambda: ["*"],
description="Allowed tool patterns",
)
denied_tools: list[str] = Field(
default_factory=list,
description="Denied tool patterns",
)
allowed_file_patterns: list[str] = Field(
default_factory=lambda: ["**/*"],
description="Allowed file patterns",
)
denied_file_patterns: list[str] = Field(
default_factory=lambda: ["**/.env", "**/secrets/**"],
description="Denied file patterns",
)
# HITL
require_approval_for: list[str] = Field(
default_factory=lambda: [
"delete_file",
"push_to_remote",
"deploy_to_production",
"modify_critical_config",
],
description="Actions requiring approval",
)
# Loop detection
max_repeated_actions: int = Field(5, description="Max exact repetitions")
max_similar_actions: int = Field(10, description="Max similar actions")
# Sandbox
require_sandbox: bool = Field(False, description="Require sandbox execution")
sandbox_timeout_seconds: int = Field(300, description="Sandbox timeout")
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit")
# Validation rules
validation_rules: list[ValidationRule] = Field(
default_factory=list,
description="Custom validation rules",
)
# ============================================================================
# Guardian Result Models
# ============================================================================
class GuardianResult(BaseModel):
"""Result of SafetyGuardian evaluation."""
action_id: str = Field(..., description="ID of the action")
allowed: bool = Field(..., description="Whether action is allowed")
decision: SafetyDecision = Field(..., description="Safety decision")
reasons: list[str] = Field(default_factory=list, description="Decision reasons")
approval_id: str | None = Field(None, description="Approval ID if needed")
checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
retry_after_seconds: float | None = Field(None, description="Retry delay")
modified_action: ActionRequest | None = Field(
None, description="Modified action if changed"
)
audit_events: list[AuditEvent] = Field(
default_factory=list, description="Generated audit events"
)

View File

@@ -0,0 +1,15 @@
"""
Permission Management Module
Agent permissions for resource access.
"""
from .manager import (
PermissionGrant,
PermissionManager,
)
__all__ = [
"PermissionGrant",
"PermissionManager",
]

View File

@@ -0,0 +1,384 @@
"""
Permission Manager
Manages permissions for agent actions on resources.
"""
import asyncio
import fnmatch
import logging
from datetime import datetime, timedelta
from uuid import uuid4
from ..exceptions import PermissionDeniedError
from ..models import (
ActionRequest,
ActionType,
PermissionLevel,
ResourceType,
)
logger = logging.getLogger(__name__)
class PermissionGrant:
"""A permission grant for an agent on a resource."""
def __init__(
self,
agent_id: str,
resource_pattern: str,
resource_type: ResourceType,
level: PermissionLevel,
*,
expires_at: datetime | None = None,
granted_by: str | None = None,
reason: str | None = None,
) -> None:
self.id = str(uuid4())
self.agent_id = agent_id
self.resource_pattern = resource_pattern
self.resource_type = resource_type
self.level = level
self.expires_at = expires_at
self.granted_by = granted_by
self.reason = reason
self.created_at = datetime.utcnow()
def is_expired(self) -> bool:
"""Check if the grant has expired."""
if self.expires_at is None:
return False
return datetime.utcnow() > self.expires_at
def matches(self, resource: str, resource_type: ResourceType) -> bool:
"""Check if this grant applies to a resource."""
if self.resource_type != resource_type:
return False
return fnmatch.fnmatch(resource, self.resource_pattern)
def allows(self, required_level: PermissionLevel) -> bool:
"""Check if this grant allows the required permission level."""
# Permission level hierarchy
hierarchy = {
PermissionLevel.NONE: 0,
PermissionLevel.READ: 1,
PermissionLevel.WRITE: 2,
PermissionLevel.EXECUTE: 3,
PermissionLevel.DELETE: 4,
PermissionLevel.ADMIN: 5,
}
return hierarchy[self.level] >= hierarchy[required_level]
class PermissionManager:
"""
Manages permissions for agent access to resources.
Features:
- Permission grants by agent/resource pattern
- Permission inheritance (project → agent → action)
- Temporary permissions with expiration
- Least-privilege defaults
- Permission escalation logging
"""
def __init__(
self,
default_deny: bool = True,
) -> None:
"""
Initialize the PermissionManager.
Args:
default_deny: If True, deny access unless explicitly granted
"""
self._grants: list[PermissionGrant] = []
self._default_deny = default_deny
self._lock = asyncio.Lock()
# Default permissions for common resources
self._default_permissions: dict[ResourceType, PermissionLevel] = {
ResourceType.FILE: PermissionLevel.READ,
ResourceType.DATABASE: PermissionLevel.READ,
ResourceType.API: PermissionLevel.READ,
ResourceType.GIT: PermissionLevel.READ,
ResourceType.LLM: PermissionLevel.EXECUTE,
ResourceType.SHELL: PermissionLevel.NONE,
ResourceType.NETWORK: PermissionLevel.READ,
}
async def grant(
self,
agent_id: str,
resource_pattern: str,
resource_type: ResourceType,
level: PermissionLevel,
*,
duration_seconds: int | None = None,
granted_by: str | None = None,
reason: str | None = None,
) -> PermissionGrant:
"""
Grant a permission to an agent.
Args:
agent_id: ID of the agent
resource_pattern: Pattern for matching resources (supports wildcards)
resource_type: Type of resource
level: Permission level to grant
duration_seconds: Optional duration for temporary permission
granted_by: Who granted the permission
reason: Reason for granting
Returns:
The created permission grant
"""
expires_at = None
if duration_seconds:
expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)
grant = PermissionGrant(
agent_id=agent_id,
resource_pattern=resource_pattern,
resource_type=resource_type,
level=level,
expires_at=expires_at,
granted_by=granted_by,
reason=reason,
)
async with self._lock:
self._grants.append(grant)
logger.info(
"Permission granted: agent=%s, resource=%s, type=%s, level=%s",
agent_id,
resource_pattern,
resource_type.value,
level.value,
)
return grant
async def revoke(self, grant_id: str) -> bool:
"""
Revoke a permission grant.
Args:
grant_id: ID of the grant to revoke
Returns:
True if grant was found and revoked
"""
async with self._lock:
for i, grant in enumerate(self._grants):
if grant.id == grant_id:
del self._grants[i]
logger.info("Permission revoked: %s", grant_id)
return True
return False
async def revoke_all(self, agent_id: str) -> int:
"""
Revoke all permissions for an agent.
Args:
agent_id: ID of the agent
Returns:
Number of grants revoked
"""
async with self._lock:
original_count = len(self._grants)
self._grants = [g for g in self._grants if g.agent_id != agent_id]
revoked = original_count - len(self._grants)
if revoked:
logger.info("Revoked %d permissions for agent %s", revoked, agent_id)
return revoked
async def check(
self,
agent_id: str,
resource: str,
resource_type: ResourceType,
required_level: PermissionLevel,
) -> bool:
"""
Check if an agent has permission to access a resource.
Args:
agent_id: ID of the agent
resource: Resource to access
resource_type: Type of resource
required_level: Required permission level
Returns:
True if access is allowed
"""
# Clean up expired grants
await self._cleanup_expired()
async with self._lock:
for grant in self._grants:
if grant.agent_id != agent_id:
continue
if grant.is_expired():
continue
if grant.matches(resource, resource_type):
if grant.allows(required_level):
return True
# Check default permissions
if not self._default_deny:
default_level = self._default_permissions.get(
resource_type, PermissionLevel.NONE
)
hierarchy = {
PermissionLevel.NONE: 0,
PermissionLevel.READ: 1,
PermissionLevel.WRITE: 2,
PermissionLevel.EXECUTE: 3,
PermissionLevel.DELETE: 4,
PermissionLevel.ADMIN: 5,
}
if hierarchy[default_level] >= hierarchy[required_level]:
return True
return False
async def check_action(self, action: ActionRequest) -> bool:
"""
Check if an action is permitted.
Args:
action: The action to check
Returns:
True if action is allowed
"""
# Determine required permission level from action type
level_map = {
ActionType.FILE_READ: PermissionLevel.READ,
ActionType.FILE_WRITE: PermissionLevel.WRITE,
ActionType.FILE_DELETE: PermissionLevel.DELETE,
ActionType.DATABASE_QUERY: PermissionLevel.READ,
ActionType.DATABASE_MUTATE: PermissionLevel.WRITE,
ActionType.SHELL_COMMAND: PermissionLevel.EXECUTE,
ActionType.API_CALL: PermissionLevel.EXECUTE,
ActionType.GIT_OPERATION: PermissionLevel.WRITE,
ActionType.LLM_CALL: PermissionLevel.EXECUTE,
ActionType.NETWORK_REQUEST: PermissionLevel.READ,
ActionType.TOOL_CALL: PermissionLevel.EXECUTE,
}
required_level = level_map.get(action.action_type, PermissionLevel.EXECUTE)
# Determine resource type from action
resource_type_map = {
ActionType.FILE_READ: ResourceType.FILE,
ActionType.FILE_WRITE: ResourceType.FILE,
ActionType.FILE_DELETE: ResourceType.FILE,
ActionType.DATABASE_QUERY: ResourceType.DATABASE,
ActionType.DATABASE_MUTATE: ResourceType.DATABASE,
ActionType.SHELL_COMMAND: ResourceType.SHELL,
ActionType.API_CALL: ResourceType.API,
ActionType.GIT_OPERATION: ResourceType.GIT,
ActionType.LLM_CALL: ResourceType.LLM,
ActionType.NETWORK_REQUEST: ResourceType.NETWORK,
}
resource_type = resource_type_map.get(action.action_type, ResourceType.CUSTOM)
resource = action.resource or action.tool_name or "*"
return await self.check(
agent_id=action.metadata.agent_id,
resource=resource,
resource_type=resource_type,
required_level=required_level,
)
async def require_permission(
self,
agent_id: str,
resource: str,
resource_type: ResourceType,
required_level: PermissionLevel,
) -> None:
"""
Require permission or raise exception.
Args:
agent_id: ID of the agent
resource: Resource to access
resource_type: Type of resource
required_level: Required permission level
Raises:
PermissionDeniedError: If permission is denied
"""
if not await self.check(agent_id, resource, resource_type, required_level):
raise PermissionDeniedError(
f"Permission denied: {resource}",
action_type=None,
resource=resource,
required_permission=required_level.value,
agent_id=agent_id,
)
async def list_grants(
self,
agent_id: str | None = None,
resource_type: ResourceType | None = None,
) -> list[PermissionGrant]:
"""
List permission grants.
Args:
agent_id: Optional filter by agent
resource_type: Optional filter by resource type
Returns:
List of matching grants
"""
await self._cleanup_expired()
async with self._lock:
grants = list(self._grants)
if agent_id:
grants = [g for g in grants if g.agent_id == agent_id]
if resource_type:
grants = [g for g in grants if g.resource_type == resource_type]
return grants
def set_default_permission(
self,
resource_type: ResourceType,
level: PermissionLevel,
) -> None:
"""
Set the default permission level for a resource type.
Args:
resource_type: Type of resource
level: Default permission level
"""
self._default_permissions[resource_type] = level
async def _cleanup_expired(self) -> None:
"""Remove expired grants."""
async with self._lock:
original_count = len(self._grants)
self._grants = [g for g in self._grants if not g.is_expired()]
removed = original_count - len(self._grants)
if removed:
logger.debug("Cleaned up %d expired permission grants", removed)

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1,5 @@
"""Rollback management for agent actions."""
from .manager import RollbackManager, TransactionContext
__all__ = ["RollbackManager", "TransactionContext"]

View File

@@ -0,0 +1,418 @@
"""
Rollback Manager
Manages checkpoints and rollback operations for agent actions.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from uuid import uuid4
from ..config import get_safety_config
from ..exceptions import RollbackError
from ..models import (
ActionRequest,
Checkpoint,
CheckpointType,
RollbackResult,
)
logger = logging.getLogger(__name__)
class FileCheckpoint:
"""Stores file state for rollback."""
def __init__(
self,
checkpoint_id: str,
file_path: str,
original_content: bytes | None,
existed: bool,
) -> None:
self.checkpoint_id = checkpoint_id
self.file_path = file_path
self.original_content = original_content
self.existed = existed
self.created_at = datetime.utcnow()
class RollbackManager:
"""
Manages checkpoints and rollback operations.
Features:
- File system checkpoints
- Transaction wrapping for actions
- Automatic checkpoint for destructive actions
- Rollback triggers on failure
- Checkpoint expiration and cleanup
"""
def __init__(
self,
checkpoint_dir: str | None = None,
retention_hours: int | None = None,
) -> None:
"""
Initialize the RollbackManager.
Args:
checkpoint_dir: Directory for storing checkpoint data
retention_hours: Hours to retain checkpoints
"""
config = get_safety_config()
self._checkpoint_dir = Path(
checkpoint_dir or config.checkpoint_dir
)
self._retention_hours = retention_hours or config.checkpoint_retention_hours
self._checkpoints: dict[str, Checkpoint] = {}
self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
self._lock = asyncio.Lock()
# Ensure checkpoint directory exists
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
async def create_checkpoint(
self,
action: ActionRequest,
checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
description: str | None = None,
) -> Checkpoint:
"""
Create a checkpoint before an action.
Args:
action: The action to checkpoint for
checkpoint_type: Type of checkpoint
description: Optional description
Returns:
The created checkpoint
"""
checkpoint_id = str(uuid4())
checkpoint = Checkpoint(
id=checkpoint_id,
checkpoint_type=checkpoint_type,
action_id=action.id,
created_at=datetime.utcnow(),
expires_at=datetime.utcnow() + timedelta(hours=self._retention_hours),
data={
"action_type": action.action_type.value,
"tool_name": action.tool_name,
"resource": action.resource,
},
description=description or f"Checkpoint for {action.tool_name}",
)
async with self._lock:
self._checkpoints[checkpoint_id] = checkpoint
self._file_checkpoints[checkpoint_id] = []
logger.info(
"Created checkpoint %s for action %s",
checkpoint_id,
action.id,
)
return checkpoint
async def checkpoint_file(
self,
checkpoint_id: str,
file_path: str,
) -> None:
"""
Store current state of a file for checkpoint.
Args:
checkpoint_id: ID of the checkpoint
file_path: Path to the file
"""
path = Path(file_path)
if path.exists():
content = path.read_bytes()
existed = True
else:
content = None
existed = False
file_checkpoint = FileCheckpoint(
checkpoint_id=checkpoint_id,
file_path=file_path,
original_content=content,
existed=existed,
)
async with self._lock:
if checkpoint_id not in self._file_checkpoints:
self._file_checkpoints[checkpoint_id] = []
self._file_checkpoints[checkpoint_id].append(file_checkpoint)
logger.debug(
"Stored file state for checkpoint %s: %s (existed=%s)",
checkpoint_id,
file_path,
existed,
)
async def checkpoint_files(
self,
checkpoint_id: str,
file_paths: list[str],
) -> None:
"""
Store current state of multiple files.
Args:
checkpoint_id: ID of the checkpoint
file_paths: Paths to the files
"""
for path in file_paths:
await self.checkpoint_file(checkpoint_id, path)
async def rollback(
self,
checkpoint_id: str,
) -> RollbackResult:
"""
Rollback to a checkpoint.
Args:
checkpoint_id: ID of the checkpoint
Returns:
Result of the rollback operation
"""
async with self._lock:
checkpoint = self._checkpoints.get(checkpoint_id)
if not checkpoint:
raise RollbackError(
f"Checkpoint not found: {checkpoint_id}",
checkpoint_id=checkpoint_id,
)
if not checkpoint.is_valid:
raise RollbackError(
f"Checkpoint is no longer valid: {checkpoint_id}",
checkpoint_id=checkpoint_id,
)
file_checkpoints = self._file_checkpoints.get(checkpoint_id, [])
actions_rolled_back: list[str] = []
failed_actions: list[str] = []
# Rollback file changes
for fc in file_checkpoints:
try:
await self._rollback_file(fc)
actions_rolled_back.append(f"file:{fc.file_path}")
except Exception as e:
logger.error("Failed to rollback file %s: %s", fc.file_path, e)
failed_actions.append(f"file:{fc.file_path}")
success = len(failed_actions) == 0
# Mark checkpoint as used
async with self._lock:
if checkpoint_id in self._checkpoints:
self._checkpoints[checkpoint_id].is_valid = False
result = RollbackResult(
checkpoint_id=checkpoint_id,
success=success,
actions_rolled_back=actions_rolled_back,
failed_actions=failed_actions,
error=None if success else f"Failed to rollback {len(failed_actions)} items",
)
if success:
logger.info("Rollback successful for checkpoint %s", checkpoint_id)
else:
logger.error(
"Rollback partially failed for checkpoint %s: %d failures",
checkpoint_id,
len(failed_actions),
)
return result
async def discard_checkpoint(self, checkpoint_id: str) -> bool:
"""
Discard a checkpoint without rolling back.
Args:
checkpoint_id: ID of the checkpoint
Returns:
True if checkpoint was found and discarded
"""
async with self._lock:
if checkpoint_id in self._checkpoints:
del self._checkpoints[checkpoint_id]
if checkpoint_id in self._file_checkpoints:
del self._file_checkpoints[checkpoint_id]
logger.debug("Discarded checkpoint %s", checkpoint_id)
return True
return False
async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
"""Get a checkpoint by ID."""
async with self._lock:
return self._checkpoints.get(checkpoint_id)
async def list_checkpoints(
self,
action_id: str | None = None,
include_expired: bool = False,
) -> list[Checkpoint]:
"""
List checkpoints.
Args:
action_id: Optional filter by action ID
include_expired: Include expired checkpoints
Returns:
List of checkpoints
"""
now = datetime.utcnow()
async with self._lock:
checkpoints = list(self._checkpoints.values())
if action_id:
checkpoints = [c for c in checkpoints if c.action_id == action_id]
if not include_expired:
checkpoints = [
c for c in checkpoints
if c.expires_at is None or c.expires_at > now
]
return checkpoints
async def cleanup_expired(self) -> int:
"""
Clean up expired checkpoints.
Returns:
Number of checkpoints cleaned up
"""
now = datetime.utcnow()
to_remove: list[str] = []
async with self._lock:
for checkpoint_id, checkpoint in self._checkpoints.items():
if checkpoint.expires_at and checkpoint.expires_at < now:
to_remove.append(checkpoint_id)
for checkpoint_id in to_remove:
del self._checkpoints[checkpoint_id]
if checkpoint_id in self._file_checkpoints:
del self._file_checkpoints[checkpoint_id]
if to_remove:
logger.info("Cleaned up %d expired checkpoints", len(to_remove))
return len(to_remove)
async def _rollback_file(self, fc: FileCheckpoint) -> None:
"""Rollback a single file to its checkpoint state."""
path = Path(fc.file_path)
if fc.existed:
# Restore original content
if fc.original_content is not None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(fc.original_content)
logger.debug("Restored file: %s", fc.file_path)
else:
# File didn't exist before - delete it
if path.exists():
path.unlink()
logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
class TransactionContext:
"""
Context manager for transactional action execution.
Usage:
async with TransactionContext(rollback_manager, action) as tx:
tx.checkpoint_file("/path/to/file")
# Do work...
# If exception occurs, automatic rollback
"""
def __init__(
self,
manager: RollbackManager,
action: ActionRequest,
auto_rollback: bool = True,
) -> None:
self._manager = manager
self._action = action
self._auto_rollback = auto_rollback
self._checkpoint: Checkpoint | None = None
self._committed = False
async def __aenter__(self) -> "TransactionContext":
self._checkpoint = await self._manager.create_checkpoint(self._action)
return self
async def __aexit__(
self,
exc_type: type | None,
exc_val: Exception | None,
exc_tb: Any,
) -> bool:
if exc_val is not None and self._auto_rollback and not self._committed:
# Exception occurred - rollback
if self._checkpoint:
try:
await self._manager.rollback(self._checkpoint.id)
logger.info(
"Auto-rollback completed for action %s",
self._action.id,
)
except Exception as e:
logger.error("Auto-rollback failed: %s", e)
elif self._committed and self._checkpoint:
# Committed - discard checkpoint
await self._manager.discard_checkpoint(self._checkpoint.id)
return False # Don't suppress the exception
@property
def checkpoint_id(self) -> str | None:
"""Get the checkpoint ID."""
return self._checkpoint.id if self._checkpoint else None
async def checkpoint_file(self, file_path: str) -> None:
"""Checkpoint a file for this transaction."""
if self._checkpoint:
await self._manager.checkpoint_file(self._checkpoint.id, file_path)
async def checkpoint_files(self, file_paths: list[str]) -> None:
"""Checkpoint multiple files for this transaction."""
if self._checkpoint:
await self._manager.checkpoint_files(self._checkpoint.id, file_paths)
def commit(self) -> None:
"""Mark transaction as committed (no rollback on exit)."""
self._committed = True
async def rollback(self) -> RollbackResult | None:
"""Manually trigger rollback."""
if self._checkpoint:
return await self._manager.rollback(self._checkpoint.id)
return None

View File

@@ -0,0 +1 @@
"""${dir} module."""

View File

@@ -0,0 +1,21 @@
"""
Action Validation Module
Pre-execution validation with rule engine.
"""
from .validator import (
ActionValidator,
ValidationCache,
create_allow_rule,
create_approval_rule,
create_deny_rule,
)
__all__ = [
"ActionValidator",
"ValidationCache",
"create_allow_rule",
"create_approval_rule",
"create_deny_rule",
]

View File

@@ -0,0 +1,439 @@
"""
Action Validator
Pre-execution validation with rule engine for action requests.
"""
import asyncio
import fnmatch
import logging
from collections import OrderedDict
from ..config import get_safety_config
from ..models import (
ActionRequest,
ActionType,
SafetyDecision,
SafetyPolicy,
ValidationResult,
ValidationRule,
)
logger = logging.getLogger(__name__)
class ValidationCache:
"""LRU cache for validation results."""
def __init__(self, max_size: int = 1000, ttl_seconds: int = 60) -> None:
self._cache: OrderedDict[str, tuple[ValidationResult, float]] = OrderedDict()
self._max_size = max_size
self._ttl = ttl_seconds
self._lock = asyncio.Lock()
async def get(self, key: str) -> ValidationResult | None:
"""Get cached validation result."""
import time
async with self._lock:
if key not in self._cache:
return None
result, timestamp = self._cache[key]
if time.time() - timestamp > self._ttl:
del self._cache[key]
return None
# Move to end (LRU)
self._cache.move_to_end(key)
return result
async def set(self, key: str, result: ValidationResult) -> None:
"""Cache a validation result."""
import time
async with self._lock:
if key in self._cache:
self._cache.move_to_end(key)
else:
if len(self._cache) >= self._max_size:
self._cache.popitem(last=False)
self._cache[key] = (result, time.time())
async def clear(self) -> None:
"""Clear the cache."""
async with self._lock:
self._cache.clear()
class ActionValidator:
"""
Validates actions against safety rules before execution.
Features:
- Rule-based validation engine
- Allow/deny/require-approval rules
- Pattern matching for tools and resources
- Validation result caching
- Bypass capability for emergencies
"""
def __init__(
self,
cache_enabled: bool = True,
cache_size: int = 1000,
cache_ttl: int = 60,
) -> None:
"""
Initialize the ActionValidator.
Args:
cache_enabled: Whether to cache validation results
cache_size: Maximum cache entries
cache_ttl: Cache TTL in seconds
"""
self._rules: list[ValidationRule] = []
self._cache_enabled = cache_enabled
self._cache = ValidationCache(max_size=cache_size, ttl_seconds=cache_ttl)
self._bypass_enabled = False
self._bypass_reason: str | None = None
config = get_safety_config()
self._cache_enabled = cache_enabled
self._cache_ttl = config.validation_cache_ttl
self._cache_size = config.validation_cache_size
def add_rule(self, rule: ValidationRule) -> None:
"""
Add a validation rule.
Args:
rule: The rule to add
"""
self._rules.append(rule)
# Re-sort by priority (higher first)
self._rules.sort(key=lambda r: r.priority, reverse=True)
logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)
def remove_rule(self, rule_id: str) -> bool:
"""
Remove a validation rule by ID.
Args:
rule_id: ID of the rule to remove
Returns:
True if rule was found and removed
"""
for i, rule in enumerate(self._rules):
if rule.id == rule_id:
del self._rules[i]
logger.debug("Removed validation rule: %s", rule_id)
return True
return False
def clear_rules(self) -> None:
"""Remove all validation rules."""
self._rules.clear()
def load_rules_from_policy(self, policy: SafetyPolicy) -> None:
"""
Load validation rules from a safety policy.
Args:
policy: The policy to load rules from
"""
# Clear existing rules
self.clear_rules()
# Add rules from policy
for rule in policy.validation_rules:
self.add_rule(rule)
# Create implicit rules from policy settings
# Denied tools
for i, pattern in enumerate(policy.denied_tools):
self.add_rule(
ValidationRule(
name=f"deny_tool_{i}",
description=f"Deny tool pattern: {pattern}",
priority=100, # High priority for denials
tool_patterns=[pattern],
decision=SafetyDecision.DENY,
reason=f"Tool matches denied pattern: {pattern}",
)
)
# Require approval patterns
for i, pattern in enumerate(policy.require_approval_for):
if pattern == "*":
# All actions require approval
self.add_rule(
ValidationRule(
name="require_approval_all",
description="All actions require approval",
priority=50,
action_types=list(ActionType),
decision=SafetyDecision.REQUIRE_APPROVAL,
reason="All actions require human approval",
)
)
else:
self.add_rule(
ValidationRule(
name=f"require_approval_{i}",
description=f"Require approval for: {pattern}",
priority=50,
tool_patterns=[pattern],
decision=SafetyDecision.REQUIRE_APPROVAL,
reason=f"Action matches approval-required pattern: {pattern}",
)
)
logger.info("Loaded %d rules from policy: %s", len(self._rules), policy.name)
async def validate(
self,
action: ActionRequest,
policy: SafetyPolicy | None = None,
) -> ValidationResult:
"""
Validate an action against all rules.
Args:
action: The action to validate
policy: Optional policy override
Returns:
ValidationResult with decision and details
"""
# Check bypass
if self._bypass_enabled:
logger.warning(
"Validation bypass active: %s - allowing action %s",
self._bypass_reason,
action.id,
)
return ValidationResult(
action_id=action.id,
decision=SafetyDecision.ALLOW,
applied_rules=[],
reasons=[f"Validation bypassed: {self._bypass_reason}"],
)
# Check cache
if self._cache_enabled:
cache_key = self._get_cache_key(action)
cached = await self._cache.get(cache_key)
if cached:
logger.debug("Using cached validation for action %s", action.id)
return cached
# Load rules from policy if provided
if policy and not self._rules:
self.load_rules_from_policy(policy)
# Validate against rules
applied_rules: list[str] = []
reasons: list[str] = []
final_decision = SafetyDecision.ALLOW
approval_id: str | None = None
for rule in self._rules:
if not rule.enabled:
continue
if self._rule_matches(rule, action):
applied_rules.append(rule.id)
if rule.reason:
reasons.append(rule.reason)
# Handle decision priority
if rule.decision == SafetyDecision.DENY:
# Deny takes precedence
final_decision = SafetyDecision.DENY
break
elif rule.decision == SafetyDecision.REQUIRE_APPROVAL:
# Upgrade to require approval
if final_decision != SafetyDecision.DENY:
final_decision = SafetyDecision.REQUIRE_APPROVAL
# If no rules matched and no explicit allow, default to allow
if not applied_rules:
reasons.append("No matching rules - default allow")
result = ValidationResult(
action_id=action.id,
decision=final_decision,
applied_rules=applied_rules,
reasons=reasons,
approval_id=approval_id,
)
# Cache result
if self._cache_enabled:
cache_key = self._get_cache_key(action)
await self._cache.set(cache_key, result)
return result
async def validate_batch(
self,
actions: list[ActionRequest],
policy: SafetyPolicy | None = None,
) -> list[ValidationResult]:
"""
Validate multiple actions.
Args:
actions: Actions to validate
policy: Optional policy override
Returns:
List of validation results
"""
tasks = [self.validate(action, policy) for action in actions]
return await asyncio.gather(*tasks)
def enable_bypass(self, reason: str) -> None:
"""
Enable validation bypass (emergency use only).
Args:
reason: Reason for enabling bypass
"""
logger.critical("Validation bypass enabled: %s", reason)
self._bypass_enabled = True
self._bypass_reason = reason
def disable_bypass(self) -> None:
"""Disable validation bypass."""
logger.info("Validation bypass disabled")
self._bypass_enabled = False
self._bypass_reason = None
async def clear_cache(self) -> None:
"""Clear the validation cache."""
await self._cache.clear()
def _rule_matches(self, rule: ValidationRule, action: ActionRequest) -> bool:
"""Check if a rule matches an action."""
# Check action types
if rule.action_types:
if action.action_type not in rule.action_types:
return False
# Check tool patterns
if rule.tool_patterns:
if not action.tool_name:
return False
matched = False
for pattern in rule.tool_patterns:
if self._matches_pattern(action.tool_name, pattern):
matched = True
break
if not matched:
return False
# Check resource patterns
if rule.resource_patterns:
if not action.resource:
return False
matched = False
for pattern in rule.resource_patterns:
if self._matches_pattern(action.resource, pattern):
matched = True
break
if not matched:
return False
# Check agent IDs
if rule.agent_ids:
if action.metadata.agent_id not in rule.agent_ids:
return False
return True
def _matches_pattern(self, value: str, pattern: str) -> bool:
"""Check if value matches a pattern (supports wildcards)."""
if pattern == "*":
return True
# Use fnmatch for glob-style matching
return fnmatch.fnmatch(value, pattern)
def _get_cache_key(self, action: ActionRequest) -> str:
"""Generate a cache key for an action."""
# Key based on action characteristics that affect validation
key_parts = [
action.action_type.value,
action.tool_name or "",
action.resource or "",
action.metadata.agent_id,
action.metadata.autonomy_level.value,
]
return ":".join(key_parts)
# Module-level convenience functions
def create_allow_rule(
name: str,
tool_patterns: list[str] | None = None,
resource_patterns: list[str] | None = None,
action_types: list[ActionType] | None = None,
priority: int = 0,
) -> ValidationRule:
"""Create an allow rule."""
return ValidationRule(
name=name,
tool_patterns=tool_patterns,
resource_patterns=resource_patterns,
action_types=action_types,
decision=SafetyDecision.ALLOW,
priority=priority,
)
def create_deny_rule(
name: str,
tool_patterns: list[str] | None = None,
resource_patterns: list[str] | None = None,
action_types: list[ActionType] | None = None,
reason: str | None = None,
priority: int = 100,
) -> ValidationRule:
"""Create a deny rule."""
return ValidationRule(
name=name,
tool_patterns=tool_patterns,
resource_patterns=resource_patterns,
action_types=action_types,
decision=SafetyDecision.DENY,
reason=reason,
priority=priority,
)
def create_approval_rule(
name: str,
tool_patterns: list[str] | None = None,
resource_patterns: list[str] | None = None,
action_types: list[ActionType] | None = None,
reason: str | None = None,
priority: int = 50,
) -> ValidationRule:
"""Create a require-approval rule."""
return ValidationRule(
name=name,
tool_patterns=tool_patterns,
resource_patterns=resource_patterns,
action_types=action_types,
decision=SafetyDecision.REQUIRE_APPROVAL,
reason=reason,
priority=priority,
)

View File

@@ -0,0 +1,345 @@
"""Tests for content filtering module."""
import pytest
from app.services.safety.content.filter import (
ContentCategory,
ContentFilter,
FilterAction,
FilterPattern,
filter_content,
scan_for_secrets,
)
from app.services.safety.exceptions import ContentFilterError
@pytest.fixture
def filter_all() -> ContentFilter:
"""Create a ContentFilter with all filters enabled."""
return ContentFilter(
enable_pii_filter=True,
enable_secret_filter=True,
enable_injection_filter=True,
)
@pytest.fixture
def filter_pii_only() -> ContentFilter:
"""Create a ContentFilter with only PII filter."""
return ContentFilter(
enable_pii_filter=True,
enable_secret_filter=False,
enable_injection_filter=False,
)
class TestContentFilter:
"""Tests for ContentFilter class."""
@pytest.mark.asyncio
async def test_filter_email(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering email addresses."""
content = "Contact me at john.doe@example.com for details."
result = await filter_pii_only.filter(content)
assert result.has_sensitive_content
assert "[EMAIL]" in result.filtered_content
assert "john.doe@example.com" not in result.filtered_content
assert len(result.matches) == 1
assert result.matches[0].category == ContentCategory.PII
@pytest.mark.asyncio
async def test_filter_phone_number(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering phone numbers."""
content = "Call me at 555-123-4567 or (555) 987-6543."
result = await filter_pii_only.filter(content)
assert result.has_sensitive_content
assert "[PHONE]" in result.filtered_content
# Should redact both phone numbers
assert "555-123-4567" not in result.filtered_content
@pytest.mark.asyncio
async def test_filter_ssn(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering Social Security Numbers."""
content = "SSN: 123-45-6789"
result = await filter_pii_only.filter(content)
assert result.has_sensitive_content
assert "[SSN]" in result.filtered_content
assert "123-45-6789" not in result.filtered_content
@pytest.mark.asyncio
async def test_filter_credit_card(self, filter_all: ContentFilter) -> None:
"""Test filtering credit card numbers."""
content = "Card: 4111-2222-3333-4444"
result = await filter_all.filter(content)
assert result.has_sensitive_content
assert "[CREDIT_CARD]" in result.filtered_content
@pytest.mark.asyncio
async def test_block_api_key(self, filter_all: ContentFilter) -> None:
"""Test blocking API keys."""
content = "api_key: sk-abcdef1234567890abcdef1234567890"
result = await filter_all.filter(content)
assert result.blocked
assert result.block_reason is not None
assert "api_key" in result.block_reason.lower()
@pytest.mark.asyncio
async def test_block_github_token(self, filter_all: ContentFilter) -> None:
"""Test blocking GitHub tokens."""
content = "token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
result = await filter_all.filter(content)
assert result.blocked
assert len(result.matches) > 0
@pytest.mark.asyncio
async def test_block_private_key(self, filter_all: ContentFilter) -> None:
"""Test blocking private keys."""
content = """
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA...
-----END RSA PRIVATE KEY-----
"""
result = await filter_all.filter(content)
assert result.blocked
@pytest.mark.asyncio
async def test_block_password_in_url(self, filter_all: ContentFilter) -> None:
"""Test blocking passwords in URLs."""
content = "Connect to: postgres://user:secretpassword@localhost/db"
result = await filter_all.filter(content)
assert result.blocked or result.has_sensitive_content
@pytest.mark.asyncio
async def test_warn_ip_address(self, filter_pii_only: ContentFilter) -> None:
"""Test warning on IP addresses."""
content = "Server IP: 192.168.1.100"
result = await filter_pii_only.filter(content)
# IP addresses generate warnings, not blocks
assert len(result.warnings) > 0 or result.has_sensitive_content
@pytest.mark.asyncio
async def test_no_false_positives_clean_content(
self,
filter_all: ContentFilter,
) -> None:
"""Test that clean content passes through."""
content = "This is a normal message with no sensitive data."
result = await filter_all.filter(content)
assert not result.blocked
assert result.filtered_content == content
@pytest.mark.asyncio
async def test_raise_on_block(self, filter_all: ContentFilter) -> None:
"""Test raising exception on blocked content."""
content = "-----BEGIN RSA PRIVATE KEY-----"
with pytest.raises(ContentFilterError):
await filter_all.filter(content, raise_on_block=True)
class TestFilterDict:
"""Tests for dictionary filtering."""
@pytest.mark.asyncio
async def test_filter_dict_values(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering string values in a dictionary."""
data = {
"name": "John Doe",
"email": "john@example.com",
"age": 30,
}
result = await filter_pii_only.filter_dict(data)
assert "[EMAIL]" in result["email"]
assert result["age"] == 30 # Non-string unchanged
@pytest.mark.asyncio
async def test_filter_dict_recursive(
self,
filter_pii_only: ContentFilter,
) -> None:
"""Test recursive dictionary filtering."""
data = {
"user": {
"contact": {
"email": "test@example.com",
}
}
}
result = await filter_pii_only.filter_dict(data, recursive=True)
assert "[EMAIL]" in result["user"]["contact"]["email"]
@pytest.mark.asyncio
async def test_filter_dict_specific_keys(
self,
filter_pii_only: ContentFilter,
) -> None:
"""Test filtering specific keys only."""
data = {
"public_email": "public@example.com",
"private_email": "private@example.com",
}
result = await filter_pii_only.filter_dict(
data,
keys_to_filter=["private_email"],
)
# Only private_email should be filtered
assert "public@example.com" in result["public_email"]
assert "[EMAIL]" in result["private_email"]
class TestScan:
"""Tests for content scanning."""
@pytest.mark.asyncio
async def test_scan_without_filtering(
self,
filter_all: ContentFilter,
) -> None:
"""Test scanning without modifying content."""
content = "Email: test@example.com, SSN: 123-45-6789"
matches = await filter_all.scan(content)
assert len(matches) >= 2
@pytest.mark.asyncio
async def test_scan_specific_categories(
self,
filter_all: ContentFilter,
) -> None:
"""Test scanning for specific categories only."""
content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
# Scan only for secrets
matches = await filter_all.scan(
content,
categories=[ContentCategory.SECRETS],
)
# Should only find the token, not the email
assert all(m.category == ContentCategory.SECRETS for m in matches)
class TestValidateSafe:
"""Tests for safe validation."""
@pytest.mark.asyncio
async def test_validate_safe_clean(self, filter_all: ContentFilter) -> None:
"""Test validation of clean content."""
content = "This is safe content."
is_safe, issues = await filter_all.validate_safe(content)
assert is_safe is True
assert len(issues) == 0
@pytest.mark.asyncio
async def test_validate_safe_with_secrets(
self,
filter_all: ContentFilter,
) -> None:
"""Test validation of content with secrets."""
content = "-----BEGIN RSA PRIVATE KEY-----"
is_safe, issues = await filter_all.validate_safe(content)
assert is_safe is False
assert len(issues) > 0
class TestCustomPatterns:
"""Tests for custom pattern support."""
@pytest.mark.asyncio
async def test_add_custom_pattern(self) -> None:
"""Test adding a custom filter pattern."""
content_filter = ContentFilter(
enable_pii_filter=False,
enable_secret_filter=False,
enable_injection_filter=False,
)
# Add custom pattern for internal IDs
content_filter.add_pattern(
FilterPattern(
name="internal_id",
category=ContentCategory.CUSTOM,
pattern=r"INTERNAL-[A-Z0-9]{8}",
action=FilterAction.REDACT,
replacement="[INTERNAL_ID]",
)
)
content = "Reference: INTERNAL-ABC12345"
result = await content_filter.filter(content)
assert "[INTERNAL_ID]" in result.filtered_content
@pytest.mark.asyncio
async def test_disable_pattern(self, filter_pii_only: ContentFilter) -> None:
"""Test disabling a pattern."""
filter_pii_only.enable_pattern("email", enabled=False)
content = "Email: test@example.com"
result = await filter_pii_only.filter(content)
# Email should not be filtered
assert "test@example.com" in result.filtered_content
@pytest.mark.asyncio
async def test_remove_pattern(self, filter_pii_only: ContentFilter) -> None:
"""Test removing a pattern."""
removed = filter_pii_only.remove_pattern("email")
assert removed is True
content = "Email: test@example.com"
result = await filter_pii_only.filter(content)
# Email should not be filtered
assert "test@example.com" in result.filtered_content
class TestPatternStats:
"""Tests for pattern statistics."""
def test_get_pattern_stats(self, filter_all: ContentFilter) -> None:
"""Test getting pattern statistics."""
stats = filter_all.get_pattern_stats()
assert stats["total_patterns"] > 0
assert "by_category" in stats
assert "by_action" in stats
class TestConvenienceFunctions:
"""Tests for convenience functions."""
@pytest.mark.asyncio
async def test_filter_content_function(self) -> None:
"""Test the quick filter function."""
content = "Email: test@example.com"
filtered = await filter_content(content)
assert "test@example.com" not in filtered
@pytest.mark.asyncio
async def test_scan_for_secrets_function(self) -> None:
"""Test the quick secret scan function."""
content = "Token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
matches = await scan_for_secrets(content)
assert len(matches) > 0
assert matches[0].category in (
ContentCategory.SECRETS,
ContentCategory.CREDENTIALS,
)

View File

@@ -0,0 +1,425 @@
"""Tests for emergency controls module."""
import pytest
from app.services.safety.emergency.controls import (
EmergencyControls,
EmergencyReason,
EmergencyState,
EmergencyTrigger,
)
from app.services.safety.exceptions import EmergencyStopError
@pytest.fixture
def controls() -> EmergencyControls:
"""Create fresh EmergencyControls."""
return EmergencyControls()
class TestEmergencyControls:
"""Tests for EmergencyControls class."""
@pytest.mark.asyncio
async def test_initial_state_is_normal(
self,
controls: EmergencyControls,
) -> None:
"""Test that initial state is normal."""
state = await controls.get_state("global")
assert state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_emergency_stop_changes_state(
self,
controls: EmergencyControls,
) -> None:
"""Test that emergency stop changes state to stopped."""
event = await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Test emergency stop",
)
assert event.state == EmergencyState.STOPPED
state = await controls.get_state("global")
assert state == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_pause_changes_state(
self,
controls: EmergencyControls,
) -> None:
"""Test that pause changes state to paused."""
event = await controls.pause(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget_controller",
message="Budget exceeded",
)
assert event.state == EmergencyState.PAUSED
state = await controls.get_state("global")
assert state == EmergencyState.PAUSED
@pytest.mark.asyncio
async def test_resume_from_paused(
self,
controls: EmergencyControls,
) -> None:
"""Test resuming from paused state."""
await controls.pause(
reason=EmergencyReason.RATE_LIMIT,
triggered_by="limiter",
message="Rate limited",
)
resumed = await controls.resume(resumed_by="admin")
assert resumed is True
state = await controls.get_state("global")
assert state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_cannot_resume_from_stopped(
self,
controls: EmergencyControls,
) -> None:
"""Test that you cannot resume from stopped state."""
await controls.emergency_stop(
reason=EmergencyReason.SAFETY_VIOLATION,
triggered_by="safety",
message="Critical violation",
)
resumed = await controls.resume(resumed_by="admin")
assert resumed is False
state = await controls.get_state("global")
assert state == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_reset_from_stopped(
self,
controls: EmergencyControls,
) -> None:
"""Test resetting from stopped state."""
await controls.emergency_stop(
reason=EmergencyReason.SAFETY_VIOLATION,
triggered_by="safety",
message="Critical violation",
)
reset = await controls.reset(reset_by="admin")
assert reset is True
state = await controls.get_state("global")
assert state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_scoped_emergency_stop(
self,
controls: EmergencyControls,
) -> None:
"""Test emergency stop with specific scope."""
await controls.emergency_stop(
reason=EmergencyReason.LOOP_DETECTED,
triggered_by="detector",
message="Loop in agent",
scope="agent:agent-123",
)
# Agent scope should be stopped
agent_state = await controls.get_state("agent:agent-123")
assert agent_state == EmergencyState.STOPPED
# Global should still be normal
global_state = await controls.get_state("global")
assert global_state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_check_allowed_when_normal(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed returns True when state is normal."""
allowed = await controls.check_allowed()
assert allowed is True
@pytest.mark.asyncio
async def test_check_allowed_when_stopped(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed returns False when stopped."""
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
)
allowed = await controls.check_allowed(raise_if_blocked=False)
assert allowed is False
@pytest.mark.asyncio
async def test_check_allowed_raises_when_blocked(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed raises exception when blocked."""
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
)
with pytest.raises(EmergencyStopError):
await controls.check_allowed(raise_if_blocked=True)
@pytest.mark.asyncio
async def test_check_allowed_with_scope(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed with specific scope."""
await controls.pause(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget",
message="Paused",
scope="project:proj-123",
)
# Project scope should be blocked
allowed_project = await controls.check_allowed(
scope="project:proj-123",
raise_if_blocked=False,
)
assert allowed_project is False
# Different scope should be allowed
allowed_other = await controls.check_allowed(
scope="project:proj-456",
raise_if_blocked=False,
)
assert allowed_other is True
@pytest.mark.asyncio
async def test_get_all_states(
self,
controls: EmergencyControls,
) -> None:
"""Test getting all states."""
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause",
scope="agent:a1",
)
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
scope="agent:a2",
)
states = await controls.get_all_states()
assert states["global"] == EmergencyState.NORMAL
assert states["agent:a1"] == EmergencyState.PAUSED
assert states["agent:a2"] == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_get_active_events(
self,
controls: EmergencyControls,
) -> None:
"""Test getting active (unresolved) events."""
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause 1",
)
events = await controls.get_active_events()
assert len(events) == 1
# Resume should resolve the event
await controls.resume()
events_after = await controls.get_active_events()
assert len(events_after) == 0
@pytest.mark.asyncio
async def test_event_history(
self,
controls: EmergencyControls,
) -> None:
"""Test getting event history."""
await controls.pause(
reason=EmergencyReason.RATE_LIMIT,
triggered_by="test",
message="Rate limited",
)
await controls.resume()
history = await controls.get_event_history()
assert len(history) == 1
assert history[0].reason == EmergencyReason.RATE_LIMIT
@pytest.mark.asyncio
async def test_event_metadata(
self,
controls: EmergencyControls,
) -> None:
"""Test event metadata storage."""
event = await controls.emergency_stop(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget_controller",
message="Over budget",
metadata={"budget_type": "tokens", "usage": 150000},
)
assert event.metadata["budget_type"] == "tokens"
assert event.metadata["usage"] == 150000
class TestCallbacks:
"""Tests for callback functionality."""
@pytest.mark.asyncio
async def test_on_stop_callback(
self,
controls: EmergencyControls,
) -> None:
"""Test on_stop callback is called."""
callback_called = []
def callback(event: object) -> None:
callback_called.append(event)
controls.on_stop(callback)
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
)
assert len(callback_called) == 1
@pytest.mark.asyncio
async def test_on_pause_callback(
self,
controls: EmergencyControls,
) -> None:
"""Test on_pause callback is called."""
callback_called = []
def callback(event: object) -> None:
callback_called.append(event)
controls.on_pause(callback)
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause",
)
assert len(callback_called) == 1
@pytest.mark.asyncio
async def test_on_resume_callback(
self,
controls: EmergencyControls,
) -> None:
"""Test on_resume callback is called."""
callback_called = []
def callback(data: object) -> None:
callback_called.append(data)
controls.on_resume(callback)
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause",
)
await controls.resume()
assert len(callback_called) == 1
class TestEmergencyTrigger:
"""Tests for EmergencyTrigger class."""
@pytest.fixture
def trigger(self, controls: EmergencyControls) -> EmergencyTrigger:
"""Create an EmergencyTrigger."""
return EmergencyTrigger(controls)
@pytest.mark.asyncio
async def test_trigger_on_safety_violation(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering emergency on safety violation."""
event = await trigger.trigger_on_safety_violation(
violation_type="unauthorized_access",
details={"resource": "/secrets/key"},
)
assert event.reason == EmergencyReason.SAFETY_VIOLATION
assert event.state == EmergencyState.STOPPED
state = await controls.get_state("global")
assert state == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_trigger_on_budget_exceeded(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering pause on budget exceeded."""
event = await trigger.trigger_on_budget_exceeded(
budget_type="tokens",
current=150000,
limit=100000,
)
assert event.reason == EmergencyReason.BUDGET_EXCEEDED
assert event.state == EmergencyState.PAUSED # Pause, not stop
@pytest.mark.asyncio
async def test_trigger_on_loop_detected(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering pause on loop detection."""
event = await trigger.trigger_on_loop_detected(
loop_type="exact",
agent_id="agent-123",
details={"pattern": "file_read"},
)
assert event.reason == EmergencyReason.LOOP_DETECTED
assert event.scope == "agent:agent-123"
assert event.state == EmergencyState.PAUSED
@pytest.mark.asyncio
async def test_trigger_on_content_violation(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering stop on content violation."""
event = await trigger.trigger_on_content_violation(
category="secrets",
pattern="private_key",
)
assert event.reason == EmergencyReason.CONTENT_VIOLATION
assert event.state == EmergencyState.STOPPED

View File

@@ -0,0 +1,316 @@
"""Tests for loop detection module."""
import pytest
from app.services.safety.exceptions import LoopDetectedError
from app.services.safety.loops.detector import (
ActionSignature,
LoopBreaker,
LoopDetector,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
AutonomyLevel,
)
@pytest.fixture
def detector() -> LoopDetector:
"""Create a fresh LoopDetector with low thresholds for testing."""
return LoopDetector(
history_size=20,
max_exact_repetitions=3,
max_semantic_repetitions=5,
)
@pytest.fixture
def sample_metadata() -> ActionMetadata:
"""Create sample action metadata."""
return ActionMetadata(
agent_id="test-agent",
session_id="test-session",
autonomy_level=AutonomyLevel.MILESTONE,
)
def create_action(
metadata: ActionMetadata,
tool_name: str,
resource: str = "/tmp/test.txt", # noqa: S108
arguments: dict | None = None,
) -> ActionRequest:
"""Helper to create test actions."""
return ActionRequest(
action_type=ActionType.FILE_READ,
tool_name=tool_name,
resource=resource,
arguments=arguments or {},
metadata=metadata,
)
class TestActionSignature:
"""Tests for ActionSignature class."""
def test_exact_key_includes_args(self, sample_metadata: ActionMetadata) -> None:
"""Test that exact key includes argument hash."""
action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"})
action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"})
sig1 = ActionSignature(action1)
sig2 = ActionSignature(action2)
assert sig1.exact_key() != sig2.exact_key()
def test_semantic_key_ignores_args(self, sample_metadata: ActionMetadata) -> None:
"""Test that semantic key ignores arguments."""
action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"})
action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"})
sig1 = ActionSignature(action1)
sig2 = ActionSignature(action2)
assert sig1.semantic_key() == sig2.semantic_key()
def test_type_key(self, sample_metadata: ActionMetadata) -> None:
"""Test type key extraction."""
action = create_action(sample_metadata, "file_read")
sig = ActionSignature(action)
assert sig.type_key() == "file_read"
class TestLoopDetector:
"""Tests for LoopDetector class."""
@pytest.mark.asyncio
async def test_no_loop_on_first_action(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test that first action is never a loop."""
action = create_action(sample_metadata, "file_read")
is_loop, loop_type = await detector.check(action)
assert is_loop is False
assert loop_type is None
@pytest.mark.asyncio
async def test_exact_loop_detection(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test detection of exact repetitions."""
action = create_action(
sample_metadata,
"file_read",
resource="/tmp/same.txt", # noqa: S108
arguments={"path": "/tmp/same.txt"}, # noqa: S108
)
# Record the same action 3 times (threshold)
for _ in range(3):
await detector.record(action)
# Next should be detected as a loop
is_loop, loop_type = await detector.check(action)
assert is_loop is True
assert loop_type == "exact"
@pytest.mark.asyncio
async def test_semantic_loop_detection(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test detection of semantic (similar) repetitions."""
# Record same tool/resource but different arguments
test_resource = "/tmp/test.txt" # noqa: S108
for i in range(5):
action = create_action(
sample_metadata,
"file_read",
resource=test_resource,
arguments={"offset": i},
)
await detector.record(action)
# Next similar action should be detected as semantic loop
action = create_action(
sample_metadata,
"file_read",
resource=test_resource,
arguments={"offset": 100},
)
is_loop, loop_type = await detector.check(action)
assert is_loop is True
assert loop_type == "semantic"
@pytest.mark.asyncio
async def test_oscillation_detection(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test detection of A→B→A→B oscillation pattern."""
action_a = create_action(sample_metadata, "tool_a", resource="/a")
action_b = create_action(sample_metadata, "tool_b", resource="/b")
# Create A→B→A pattern
await detector.record(action_a)
await detector.record(action_b)
await detector.record(action_a)
# Fourth action completing A→B→A→B should be detected as oscillation
is_loop, loop_type = await detector.check(action_b)
assert is_loop is True
assert loop_type == "oscillation"
@pytest.mark.asyncio
async def test_different_actions_no_loop(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test that different actions don't trigger loops."""
for i in range(10):
action = create_action(
sample_metadata,
f"tool_{i}",
resource=f"/resource_{i}",
)
is_loop, _ = await detector.check(action)
assert is_loop is False
await detector.record(action)
@pytest.mark.asyncio
async def test_check_and_raise(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test check_and_raise raises on loop detection."""
action = create_action(sample_metadata, "file_read")
# Record threshold number of times
for _ in range(3):
await detector.record(action)
# Should raise
with pytest.raises(LoopDetectedError) as exc_info:
await detector.check_and_raise(action)
assert "exact" in exc_info.value.loop_type.lower()
@pytest.mark.asyncio
async def test_clear_history(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test clearing agent history."""
action = create_action(sample_metadata, "file_read")
# Record multiple times
for _ in range(3):
await detector.record(action)
# Clear history
await detector.clear_history(sample_metadata.agent_id)
# Should no longer detect loop
is_loop, _ = await detector.check(action)
assert is_loop is False
@pytest.mark.asyncio
async def test_per_agent_history(
self,
detector: LoopDetector,
) -> None:
"""Test that history is tracked per agent."""
metadata1 = ActionMetadata(agent_id="agent-1", session_id="s1")
metadata2 = ActionMetadata(agent_id="agent-2", session_id="s2")
action1 = create_action(metadata1, "file_read")
action2 = create_action(metadata2, "file_read")
# Record for agent 1 (threshold times)
for _ in range(3):
await detector.record(action1)
# Agent 1 should detect loop
is_loop1, _ = await detector.check(action1)
assert is_loop1 is True
# Agent 2 should not detect loop
is_loop2, _ = await detector.check(action2)
assert is_loop2 is False
@pytest.mark.asyncio
async def test_get_stats(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test getting loop detection stats."""
for i in range(5):
action = create_action(
sample_metadata,
f"tool_{i % 2}", # Alternate between 2 tools
resource=f"/resource_{i}",
)
await detector.record(action)
stats = await detector.get_stats(sample_metadata.agent_id)
assert stats["history_size"] == 5
assert len(stats["action_type_counts"]) > 0
class TestLoopBreaker:
"""Tests for LoopBreaker class."""
@pytest.mark.asyncio
async def test_suggest_alternatives_exact(
self,
sample_metadata: ActionMetadata,
) -> None:
"""Test suggestions for exact loops."""
action = create_action(sample_metadata, "file_read")
suggestions = await LoopBreaker.suggest_alternatives(action, "exact")
assert len(suggestions) > 0
assert "same action" in suggestions[0].lower()
@pytest.mark.asyncio
async def test_suggest_alternatives_semantic(
self,
sample_metadata: ActionMetadata,
) -> None:
"""Test suggestions for semantic loops."""
action = create_action(sample_metadata, "file_read")
suggestions = await LoopBreaker.suggest_alternatives(action, "semantic")
assert len(suggestions) > 0
assert "similar" in suggestions[0].lower()
@pytest.mark.asyncio
async def test_suggest_alternatives_oscillation(
self,
sample_metadata: ActionMetadata,
) -> None:
"""Test suggestions for oscillation loops."""
action = create_action(sample_metadata, "file_read")
suggestions = await LoopBreaker.suggest_alternatives(action, "oscillation")
assert len(suggestions) > 0
assert "oscillat" in suggestions[0].lower()

View File

@@ -0,0 +1,437 @@
"""Tests for safety framework models."""
from datetime import datetime, timedelta
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionResult,
ActionType,
ApprovalRequest,
ApprovalResponse,
ApprovalStatus,
AuditEvent,
AuditEventType,
AutonomyLevel,
BudgetScope,
BudgetStatus,
Checkpoint,
CheckpointType,
GuardianResult,
PermissionLevel,
RateLimitConfig,
RollbackResult,
SafetyDecision,
SafetyPolicy,
ValidationResult,
ValidationRule,
)
class TestActionMetadata:
"""Tests for ActionMetadata model."""
def test_create_with_defaults(self) -> None:
"""Test creating metadata with default values."""
metadata = ActionMetadata(
agent_id="agent-1",
)
assert metadata.agent_id == "agent-1"
assert metadata.autonomy_level == AutonomyLevel.MILESTONE
assert metadata.project_id is None
assert metadata.session_id is None
def test_create_with_all_fields(self) -> None:
"""Test creating metadata with all fields."""
metadata = ActionMetadata(
agent_id="agent-1",
session_id="session-1",
project_id="project-1",
user_id="user-1",
autonomy_level=AutonomyLevel.AUTONOMOUS,
)
assert metadata.project_id == "project-1"
assert metadata.user_id == "user-1"
assert metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
class TestActionRequest:
"""Tests for ActionRequest model."""
def test_create_basic_action(self) -> None:
"""Test creating a basic action request."""
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata,
)
assert action.action_type == ActionType.FILE_READ
assert action.tool_name == "file_read"
assert action.id is not None
assert action.metadata.agent_id == "agent-1"
def test_action_with_arguments(self) -> None:
"""Test action with arguments."""
test_path = "/tmp/test.txt" # noqa: S108
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_WRITE,
tool_name="file_write",
arguments={"path": test_path, "content": "hello"},
resource=test_path,
metadata=metadata,
)
assert action.arguments["path"] == test_path
assert action.resource == test_path
class TestActionResult:
"""Tests for ActionResult model."""
def test_successful_result(self) -> None:
"""Test creating a successful result."""
result = ActionResult(
action_id="action-1",
success=True,
data={"output": "done"},
)
assert result.success is True
assert result.data["output"] == "done"
assert result.error is None
def test_failed_result(self) -> None:
"""Test creating a failed result."""
result = ActionResult(
action_id="action-1",
success=False,
error="Permission denied",
)
assert result.success is False
assert result.error == "Permission denied"
class TestValidationRule:
"""Tests for ValidationRule model."""
def test_create_rule(self) -> None:
"""Test creating a validation rule."""
rule = ValidationRule(
name="deny_shell",
description="Deny shell commands",
priority=100,
tool_patterns=["shell_*"],
decision=SafetyDecision.DENY,
reason="Shell commands are not allowed",
)
assert rule.name == "deny_shell"
assert rule.priority == 100
assert rule.decision == SafetyDecision.DENY
assert rule.enabled is True
def test_rule_defaults(self) -> None:
"""Test rule default values."""
rule = ValidationRule(name="test_rule", decision=SafetyDecision.ALLOW)
assert rule.id is not None
assert rule.priority == 0
assert rule.enabled is True
assert rule.decision == SafetyDecision.ALLOW
class TestValidationResult:
"""Tests for ValidationResult model."""
def test_allow_result(self) -> None:
"""Test an allow result."""
result = ValidationResult(
action_id="action-1",
decision=SafetyDecision.ALLOW,
applied_rules=["rule-1"],
reasons=["Action is allowed"],
)
assert result.decision == SafetyDecision.ALLOW
assert len(result.applied_rules) == 1
def test_deny_result(self) -> None:
"""Test a deny result."""
result = ValidationResult(
action_id="action-1",
decision=SafetyDecision.DENY,
applied_rules=["deny_rule"],
reasons=["Action is not permitted"],
)
assert result.decision == SafetyDecision.DENY
class TestBudgetStatus:
"""Tests for BudgetStatus model."""
def test_under_budget(self) -> None:
"""Test status when under budget."""
status = BudgetStatus(
scope=BudgetScope.SESSION,
scope_id="session-1",
tokens_used=5000,
tokens_limit=10000,
tokens_remaining=5000,
)
assert status.tokens_remaining == 5000
assert status.tokens_used == 5000
def test_over_budget(self) -> None:
"""Test status when over budget."""
status = BudgetStatus(
scope=BudgetScope.SESSION,
scope_id="session-1",
cost_used_usd=15.0,
cost_limit_usd=10.0,
cost_remaining_usd=0.0,
is_exceeded=True,
)
assert status.is_exceeded is True
assert status.cost_remaining_usd == 0.0
class TestRateLimitConfig:
"""Tests for RateLimitConfig model."""
def test_create_config(self) -> None:
"""Test creating rate limit config."""
config = RateLimitConfig(
name="actions",
limit=60,
window_seconds=60,
)
assert config.name == "actions"
assert config.limit == 60
assert config.window_seconds == 60
class TestApprovalRequest:
"""Tests for ApprovalRequest model."""
def test_create_request(self) -> None:
"""Test creating an approval request."""
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
action = ActionRequest(
action_type=ActionType.DATABASE_MUTATE,
tool_name="db_delete",
metadata=metadata,
)
request = ApprovalRequest(
id="approval-1",
action=action,
reason="Database mutation requires approval",
urgency="high",
timeout_seconds=300,
)
assert request.id == "approval-1"
assert request.urgency == "high"
assert request.timeout_seconds == 300
class TestApprovalResponse:
"""Tests for ApprovalResponse model."""
def test_approved_response(self) -> None:
"""Test an approved response."""
response = ApprovalResponse(
request_id="approval-1",
status=ApprovalStatus.APPROVED,
decided_by="admin",
reason="Looks safe",
)
assert response.status == ApprovalStatus.APPROVED
assert response.decided_by == "admin"
def test_denied_response(self) -> None:
"""Test a denied response."""
response = ApprovalResponse(
request_id="approval-1",
status=ApprovalStatus.DENIED,
decided_by="admin",
reason="Too risky",
)
assert response.status == ApprovalStatus.DENIED
class TestCheckpoint:
"""Tests for Checkpoint model."""
def test_create_checkpoint(self) -> None:
"""Test creating a checkpoint."""
test_path = "/tmp/test.txt" # noqa: S108
checkpoint = Checkpoint(
id="checkpoint-1",
checkpoint_type=CheckpointType.FILE,
action_id="action-1",
created_at=datetime.utcnow(),
data={"path": test_path},
description="File checkpoint",
)
assert checkpoint.id == "checkpoint-1"
assert checkpoint.is_valid is True
def test_expired_checkpoint(self) -> None:
"""Test an expired checkpoint."""
checkpoint = Checkpoint(
id="checkpoint-1",
checkpoint_type=CheckpointType.FILE,
action_id="action-1",
created_at=datetime.utcnow() - timedelta(hours=2),
expires_at=datetime.utcnow() - timedelta(hours=1),
data={},
)
# is_valid is a simple bool, not computed from expires_at
# The RollbackManager handles expiration logic
assert checkpoint.is_valid is True # Default value
class TestRollbackResult:
"""Tests for RollbackResult model."""
def test_successful_rollback(self) -> None:
"""Test a successful rollback."""
result = RollbackResult(
checkpoint_id="checkpoint-1",
success=True,
actions_rolled_back=["file:/tmp/test.txt"],
failed_actions=[],
)
assert result.success is True
assert len(result.actions_rolled_back) == 1
def test_partial_rollback(self) -> None:
"""Test a partial rollback."""
result = RollbackResult(
checkpoint_id="checkpoint-1",
success=False,
actions_rolled_back=["file:/tmp/a.txt"],
failed_actions=["file:/tmp/b.txt"],
error="Failed to rollback 1 item",
)
assert result.success is False
assert len(result.failed_actions) == 1
class TestAuditEvent:
"""Tests for AuditEvent model."""
def test_create_event(self) -> None:
"""Test creating an audit event."""
event = AuditEvent(
id="event-1",
event_type=AuditEventType.ACTION_EXECUTED,
timestamp=datetime.utcnow(),
agent_id="agent-1",
action_id="action-1",
data={"tool": "file_read"},
)
assert event.event_type == AuditEventType.ACTION_EXECUTED
assert event.agent_id == "agent-1"
class TestSafetyPolicy:
"""Tests for SafetyPolicy model."""
def test_default_policy(self) -> None:
"""Test creating a default policy."""
policy = SafetyPolicy(
name="default",
description="Default safety policy",
)
assert policy.name == "default"
assert policy.enabled is True
def test_restrictive_policy(self) -> None:
"""Test creating a restrictive policy."""
policy = SafetyPolicy(
name="restrictive",
description="Restrictive policy",
denied_tools=["shell_*", "exec_*"],
require_approval_for=["database_*", "git_push"],
)
assert len(policy.denied_tools) == 2
assert len(policy.require_approval_for) == 2
class TestGuardianResult:
"""Tests for GuardianResult model."""
def test_allowed_result(self) -> None:
"""Test an allowed result."""
result = GuardianResult(
action_id="action-1",
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["All checks passed"],
)
assert result.decision == SafetyDecision.ALLOW
assert result.allowed is True
assert result.approval_id is None
def test_approval_required_result(self) -> None:
"""Test a result requiring approval."""
result = GuardianResult(
action_id="action-1",
allowed=False,
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=["Action requires human approval"],
approval_id="approval-123",
)
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
assert result.approval_id == "approval-123"
class TestEnums:
"""Tests for enum values."""
def test_action_types(self) -> None:
"""Test action type enum values."""
assert ActionType.FILE_READ.value == "file_read"
assert ActionType.SHELL_COMMAND.value == "shell_command"
def test_autonomy_levels(self) -> None:
"""Test autonomy level enum values."""
assert AutonomyLevel.FULL_CONTROL.value == "full_control"
assert AutonomyLevel.MILESTONE.value == "milestone"
assert AutonomyLevel.AUTONOMOUS.value == "autonomous"
def test_permission_levels(self) -> None:
"""Test permission level enum values."""
assert PermissionLevel.NONE.value == "none"
assert PermissionLevel.READ.value == "read"
assert PermissionLevel.WRITE.value == "write"
assert PermissionLevel.ADMIN.value == "admin"
def test_safety_decisions(self) -> None:
"""Test safety decision enum values."""
assert SafetyDecision.ALLOW.value == "allow"
assert SafetyDecision.DENY.value == "deny"
assert SafetyDecision.REQUIRE_APPROVAL.value == "require_approval"

View File

@@ -0,0 +1,404 @@
"""Tests for safety validation module."""
import pytest
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
AutonomyLevel,
SafetyDecision,
SafetyPolicy,
ValidationRule,
)
from app.services.safety.validation.validator import (
ActionValidator,
create_allow_rule,
create_approval_rule,
create_deny_rule,
)
@pytest.fixture
def validator() -> ActionValidator:
"""Create a fresh ActionValidator."""
return ActionValidator(cache_enabled=False)
@pytest.fixture
def sample_action() -> ActionRequest:
"""Create a sample action request."""
metadata = ActionMetadata(
agent_id="test-agent",
session_id="test-session",
autonomy_level=AutonomyLevel.MILESTONE,
)
return ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/tmp/test.txt", # noqa: S108
metadata=metadata,
)
class TestActionValidator:
"""Tests for ActionValidator class."""
@pytest.mark.asyncio
async def test_no_rules_allows_by_default(
self,
validator: ActionValidator,
sample_action: ActionRequest,
) -> None:
"""Test that actions are allowed by default with no rules."""
result = await validator.validate(sample_action)
assert result.decision == SafetyDecision.ALLOW
assert "No matching rules" in result.reasons[0]
@pytest.mark.asyncio
async def test_deny_rule_blocks_action(
self,
validator: ActionValidator,
) -> None:
"""Test that a deny rule blocks matching actions."""
validator.add_rule(
create_deny_rule(
name="deny_shell",
tool_patterns=["shell_*"],
reason="Shell commands not allowed",
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.SHELL_COMMAND,
tool_name="shell_exec",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.DENY
assert len(result.applied_rules) == 1 # One rule applied
@pytest.mark.asyncio
async def test_approval_rule_requires_approval(
self,
validator: ActionValidator,
) -> None:
"""Test that an approval rule requires approval."""
validator.add_rule(
create_approval_rule(
name="approve_db",
tool_patterns=["database_*"],
reason="Database operations require approval",
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.DATABASE_MUTATE,
tool_name="database_delete",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
@pytest.mark.asyncio
async def test_deny_takes_precedence(
self,
validator: ActionValidator,
) -> None:
"""Test that deny rules take precedence over allow rules."""
validator.add_rule(
create_allow_rule(
name="allow_files",
tool_patterns=["file_*"],
priority=10,
)
)
validator.add_rule(
create_deny_rule(
name="deny_delete",
action_types=[ActionType.FILE_DELETE],
priority=100,
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_DELETE,
tool_name="file_delete",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_rule_priority_ordering(
self,
validator: ActionValidator,
) -> None:
"""Test that rules are evaluated in priority order."""
validator.add_rule(
ValidationRule(
name="low_priority",
priority=1,
decision=SafetyDecision.ALLOW,
)
)
validator.add_rule(
ValidationRule(
name="high_priority",
priority=100,
decision=SafetyDecision.DENY,
)
)
# High priority should be first in the list
assert validator._rules[0].name == "high_priority"
@pytest.mark.asyncio
async def test_disabled_rule_not_applied(
self,
validator: ActionValidator,
sample_action: ActionRequest,
) -> None:
"""Test that disabled rules are not applied."""
rule = create_deny_rule(
name="deny_all",
tool_patterns=["*"],
)
rule.enabled = False
validator.add_rule(rule)
result = await validator.validate(sample_action)
assert result.decision == SafetyDecision.ALLOW
@pytest.mark.asyncio
async def test_resource_pattern_matching(
self,
validator: ActionValidator,
) -> None:
"""Test resource pattern matching."""
validator.add_rule(
create_deny_rule(
name="deny_secrets",
resource_patterns=["*/secrets/*", "*.env"],
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/app/secrets/api_key.txt",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_agent_id_filter(
self,
validator: ActionValidator,
) -> None:
"""Test filtering by agent ID."""
rule = ValidationRule(
name="restrict_agent",
agent_ids=["restricted-agent"],
decision=SafetyDecision.DENY,
reason="Restricted agent",
)
validator.add_rule(rule)
# Restricted agent should be denied
metadata1 = ActionMetadata(agent_id="restricted-agent")
action1 = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata1,
)
result1 = await validator.validate(action1)
assert result1.decision == SafetyDecision.DENY
# Other agents should be allowed
metadata2 = ActionMetadata(agent_id="normal-agent")
action2 = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata2,
)
result2 = await validator.validate(action2)
assert result2.decision == SafetyDecision.ALLOW
@pytest.mark.asyncio
async def test_bypass_mode(
self,
validator: ActionValidator,
sample_action: ActionRequest,
) -> None:
"""Test validation bypass mode."""
validator.add_rule(create_deny_rule(name="deny_all", tool_patterns=["*"]))
# Should be denied normally
result1 = await validator.validate(sample_action)
assert result1.decision == SafetyDecision.DENY
# Enable bypass
validator.enable_bypass("Emergency situation")
result2 = await validator.validate(sample_action)
assert result2.decision == SafetyDecision.ALLOW
assert "bypassed" in result2.reasons[0].lower()
# Disable bypass
validator.disable_bypass()
result3 = await validator.validate(sample_action)
assert result3.decision == SafetyDecision.DENY
def test_remove_rule(self, validator: ActionValidator) -> None:
"""Test removing a rule."""
rule = create_deny_rule(name="test_rule", tool_patterns=["test"])
validator.add_rule(rule)
assert len(validator._rules) == 1
assert validator.remove_rule(rule.id) is True
assert len(validator._rules) == 0
def test_remove_nonexistent_rule(self, validator: ActionValidator) -> None:
"""Test removing a nonexistent rule returns False."""
assert validator.remove_rule("nonexistent") is False
def test_clear_rules(self, validator: ActionValidator) -> None:
"""Test clearing all rules."""
validator.add_rule(create_deny_rule(name="rule1", tool_patterns=["a"]))
validator.add_rule(create_deny_rule(name="rule2", tool_patterns=["b"]))
assert len(validator._rules) == 2
validator.clear_rules()
assert len(validator._rules) == 0
class TestLoadRulesFromPolicy:
"""Tests for loading rules from policies."""
@pytest.mark.asyncio
async def test_load_denied_tools(
self,
validator: ActionValidator,
) -> None:
"""Test loading denied tools from policy."""
policy = SafetyPolicy(
name="test",
denied_tools=["shell_*", "exec_*"],
)
validator.load_rules_from_policy(policy)
# Should have 2 deny rules
deny_rules = [r for r in validator._rules if r.decision == SafetyDecision.DENY]
assert len(deny_rules) == 2
@pytest.mark.asyncio
async def test_load_approval_patterns(
self,
validator: ActionValidator,
) -> None:
"""Test loading approval patterns from policy."""
policy = SafetyPolicy(
name="test",
require_approval_for=["database_*"],
)
validator.load_rules_from_policy(policy)
approval_rules = [
r for r in validator._rules
if r.decision == SafetyDecision.REQUIRE_APPROVAL
]
assert len(approval_rules) == 1
class TestValidationBatch:
"""Tests for batch validation."""
@pytest.mark.asyncio
async def test_validate_batch(
self,
validator: ActionValidator,
) -> None:
"""Test validating multiple actions."""
validator.add_rule(
create_deny_rule(
name="deny_shell",
tool_patterns=["shell_*"],
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
actions = [
ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata,
),
ActionRequest(
action_type=ActionType.SHELL_COMMAND,
tool_name="shell_exec",
metadata=metadata,
),
]
results = await validator.validate_batch(actions)
assert len(results) == 2
assert results[0].decision == SafetyDecision.ALLOW
assert results[1].decision == SafetyDecision.DENY
class TestHelperFunctions:
"""Tests for rule creation helper functions."""
def test_create_allow_rule(self) -> None:
"""Test creating an allow rule."""
rule = create_allow_rule(
name="allow_test",
tool_patterns=["test_*"],
priority=50,
)
assert rule.name == "allow_test"
assert rule.decision == SafetyDecision.ALLOW
assert rule.priority == 50
def test_create_deny_rule(self) -> None:
"""Test creating a deny rule."""
rule = create_deny_rule(
name="deny_test",
tool_patterns=["dangerous_*"],
reason="Too dangerous",
)
assert rule.name == "deny_test"
assert rule.decision == SafetyDecision.DENY
assert rule.reason == "Too dangerous"
assert rule.priority == 100 # Default priority for deny
def test_create_approval_rule(self) -> None:
"""Test creating an approval rule."""
rule = create_approval_rule(
name="approve_test",
action_types=[ActionType.DATABASE_MUTATE],
)
assert rule.name == "approve_test"
assert rule.decision == SafetyDecision.REQUIRE_APPROVAL
assert rule.priority == 50 # Default priority for approval