From 728edd145326c7e1041eaed969ee5065db9ffdac Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sat, 3 Jan 2026 11:28:00 +0100 Subject: [PATCH] feat(backend): add Phase B safety subsystems (#63) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements core control subsystems for the safety framework: **Action Validation (validation/validator.py):** - Rule-based validation engine with priority ordering - Allow/deny/require-approval rule types - Pattern matching for tools and resources - Validation result caching with LRU eviction - Emergency bypass capability with audit **Permission System (permissions/manager.py):** - Per-agent permission grants on resources - Resource pattern matching (wildcards) - Temporary permissions with expiration - Permission inheritance hierarchy - Default deny with configurable defaults **Cost Control (costs/controller.py):** - Per-session and per-day budget tracking - Token and USD cost limits - Warning alerts at configurable thresholds - Budget rollover and reset policies - Real-time usage tracking **Rate Limiting (limits/limiter.py):** - Sliding window rate limiter - Per-action, per-LLM-call, per-file-op limits - Burst allowance with recovery - Configurable limits per operation type **Loop Detection (loops/detector.py):** - Exact repetition detection (same action+args) - Semantic repetition (similar actions) - Oscillation pattern detection (A→B→A→B) - Per-agent action history tracking - Loop breaking suggestions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/services/safety/costs/__init__.py | 16 +- .../app/services/safety/costs/controller.py | 479 ++++++++++++++++++ .../app/services/safety/limits/__init__.py | 16 +- backend/app/services/safety/limits/limiter.py | 368 ++++++++++++++ backend/app/services/safety/loops/__init__.py | 18 +- backend/app/services/safety/loops/detector.py | 267 ++++++++++ .../services/safety/permissions/__init__.py | 16 +- .../services/safety/permissions/manager.py | 384 ++++++++++++++ .../services/safety/validation/__init__.py | 22 +- .../services/safety/validation/validator.py | 439 ++++++++++++++++ 10 files changed, 2020 insertions(+), 5 deletions(-) create mode 100644 backend/app/services/safety/costs/controller.py create mode 100644 backend/app/services/safety/limits/limiter.py create mode 100644 backend/app/services/safety/loops/detector.py create mode 100644 backend/app/services/safety/permissions/manager.py create mode 100644 backend/app/services/safety/validation/validator.py diff --git a/backend/app/services/safety/costs/__init__.py b/backend/app/services/safety/costs/__init__.py index 9f4729c..a9a585c 100644 --- a/backend/app/services/safety/costs/__init__.py +++ b/backend/app/services/safety/costs/__init__.py @@ -1 +1,15 @@ -"""${dir} module.""" +""" +Cost Control Module + +Budget management and cost tracking. +""" + +from .controller import ( + BudgetTracker, + CostController, +) + +__all__ = [ + "BudgetTracker", + "CostController", +] diff --git a/backend/app/services/safety/costs/controller.py b/backend/app/services/safety/costs/controller.py new file mode 100644 index 0000000..da61ba3 --- /dev/null +++ b/backend/app/services/safety/costs/controller.py @@ -0,0 +1,479 @@ +""" +Cost Controller + +Budget management and cost tracking for agent operations. +""" + +import asyncio +import logging +from datetime import datetime, timedelta +from typing import Any + +from ..config import get_safety_config +from ..exceptions import BudgetExceededError +from ..models import ( + ActionRequest, + BudgetScope, + BudgetStatus, +) + +logger = logging.getLogger(__name__) + + +class BudgetTracker: + """Tracks usage against a budget limit.""" + + def __init__( + self, + scope: BudgetScope, + scope_id: str, + tokens_limit: int, + cost_limit_usd: float, + reset_interval: timedelta | None = None, + warning_threshold: float = 0.8, + ) -> None: + self.scope = scope + self.scope_id = scope_id + self.tokens_limit = tokens_limit + self.cost_limit_usd = cost_limit_usd + self.warning_threshold = warning_threshold + self._reset_interval = reset_interval + + self._tokens_used = 0 + self._cost_used_usd = 0.0 + self._created_at = datetime.utcnow() + self._last_reset = datetime.utcnow() + self._lock = asyncio.Lock() + + async def add_usage(self, tokens: int, cost_usd: float) -> None: + """Add usage to the tracker.""" + async with self._lock: + self._check_reset() + self._tokens_used += tokens + self._cost_used_usd += cost_usd + + async def get_status(self) -> BudgetStatus: + """Get current budget status.""" + async with self._lock: + self._check_reset() + + tokens_remaining = max(0, self.tokens_limit - self._tokens_used) + cost_remaining = max(0, self.cost_limit_usd - self._cost_used_usd) + + token_usage_ratio = ( + self._tokens_used / self.tokens_limit if self.tokens_limit > 0 else 0 + ) + cost_usage_ratio = ( + self._cost_used_usd / self.cost_limit_usd + if self.cost_limit_usd > 0 + else 0 + ) + + is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold + is_exceeded = ( + self._tokens_used >= self.tokens_limit + or self._cost_used_usd >= self.cost_limit_usd + ) + + reset_at = None + if self._reset_interval: + reset_at = self._last_reset + self._reset_interval + + return BudgetStatus( + scope=self.scope, + scope_id=self.scope_id, + tokens_used=self._tokens_used, + tokens_limit=self.tokens_limit, + cost_used_usd=self._cost_used_usd, + cost_limit_usd=self.cost_limit_usd, + tokens_remaining=tokens_remaining, + cost_remaining_usd=cost_remaining, + warning_threshold=self.warning_threshold, + is_warning=is_warning, + is_exceeded=is_exceeded, + reset_at=reset_at, + ) + + async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool: + """Check if there's enough budget for an operation.""" + async with self._lock: + self._check_reset() + + would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit + would_exceed_cost = ( + self._cost_used_usd + estimated_cost_usd + ) > self.cost_limit_usd + + return not (would_exceed_tokens or would_exceed_cost) + + def _check_reset(self) -> None: + """Check if budget should reset.""" + if self._reset_interval is None: + return + + now = datetime.utcnow() + if now >= self._last_reset + self._reset_interval: + logger.info( + "Resetting budget for %s:%s", + self.scope.value, + self.scope_id, + ) + self._tokens_used = 0 + self._cost_used_usd = 0.0 + self._last_reset = now + + async def reset(self) -> None: + """Manually reset the budget.""" + async with self._lock: + self._tokens_used = 0 + self._cost_used_usd = 0.0 + self._last_reset = datetime.utcnow() + + +class CostController: + """ + Controls costs and budgets for agent operations. + + Features: + - Per-agent, per-project, per-session budgets + - Real-time cost tracking + - Budget alerts at configurable thresholds + - Cost prediction for planned actions + - Budget rollover policies + """ + + def __init__( + self, + default_session_tokens: int | None = None, + default_session_cost_usd: float | None = None, + default_daily_tokens: int | None = None, + default_daily_cost_usd: float | None = None, + ) -> None: + """ + Initialize the CostController. + + Args: + default_session_tokens: Default token budget per session + default_session_cost_usd: Default USD budget per session + default_daily_tokens: Default token budget per day + default_daily_cost_usd: Default USD budget per day + """ + config = get_safety_config() + + self._default_session_tokens = ( + default_session_tokens or config.default_session_token_budget + ) + self._default_session_cost = ( + default_session_cost_usd or config.default_session_cost_limit + ) + self._default_daily_tokens = ( + default_daily_tokens or config.default_daily_token_budget + ) + self._default_daily_cost = ( + default_daily_cost_usd or config.default_daily_cost_limit + ) + + self._trackers: dict[str, BudgetTracker] = {} + self._lock = asyncio.Lock() + + # Alert handlers + self._alert_handlers: list[Any] = [] + + async def get_or_create_tracker( + self, + scope: BudgetScope, + scope_id: str, + ) -> BudgetTracker: + """Get or create a budget tracker.""" + key = f"{scope.value}:{scope_id}" + + async with self._lock: + if key not in self._trackers: + if scope == BudgetScope.SESSION: + tracker = BudgetTracker( + scope=scope, + scope_id=scope_id, + tokens_limit=self._default_session_tokens, + cost_limit_usd=self._default_session_cost, + ) + elif scope == BudgetScope.DAILY: + tracker = BudgetTracker( + scope=scope, + scope_id=scope_id, + tokens_limit=self._default_daily_tokens, + cost_limit_usd=self._default_daily_cost, + reset_interval=timedelta(days=1), + ) + else: + # Default + tracker = BudgetTracker( + scope=scope, + scope_id=scope_id, + tokens_limit=self._default_session_tokens, + cost_limit_usd=self._default_session_cost, + ) + + self._trackers[key] = tracker + + return self._trackers[key] + + async def check_budget( + self, + agent_id: str, + session_id: str | None, + estimated_tokens: int, + estimated_cost_usd: float, + ) -> bool: + """ + Check if there's enough budget for an operation. + + Args: + agent_id: ID of the agent + session_id: Optional session ID + estimated_tokens: Estimated token usage + estimated_cost_usd: Estimated USD cost + + Returns: + True if budget is available + """ + # Check session budget + if session_id: + session_tracker = await self.get_or_create_tracker( + BudgetScope.SESSION, session_id + ) + if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd): + return False + + # Check agent daily budget + agent_tracker = await self.get_or_create_tracker( + BudgetScope.DAILY, agent_id + ) + if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd): + return False + + return True + + async def check_action(self, action: ActionRequest) -> bool: + """ + Check if an action is within budget. + + Args: + action: The action to check + + Returns: + True if within budget + """ + return await self.check_budget( + agent_id=action.metadata.agent_id, + session_id=action.metadata.session_id, + estimated_tokens=action.estimated_cost_tokens, + estimated_cost_usd=action.estimated_cost_usd, + ) + + async def require_budget( + self, + agent_id: str, + session_id: str | None, + estimated_tokens: int, + estimated_cost_usd: float, + ) -> None: + """ + Require budget or raise exception. + + Args: + agent_id: ID of the agent + session_id: Optional session ID + estimated_tokens: Estimated token usage + estimated_cost_usd: Estimated USD cost + + Raises: + BudgetExceededError: If budget is exceeded + """ + if not await self.check_budget( + agent_id, session_id, estimated_tokens, estimated_cost_usd + ): + # Determine which budget was exceeded + if session_id: + session_tracker = await self.get_or_create_tracker( + BudgetScope.SESSION, session_id + ) + session_status = await session_tracker.get_status() + if session_status.is_exceeded: + raise BudgetExceededError( + "Session budget exceeded", + budget_type="session", + current_usage=session_status.tokens_used, + budget_limit=session_status.tokens_limit, + agent_id=agent_id, + ) + + agent_tracker = await self.get_or_create_tracker( + BudgetScope.DAILY, agent_id + ) + agent_status = await agent_tracker.get_status() + raise BudgetExceededError( + "Daily budget exceeded", + budget_type="daily", + current_usage=agent_status.tokens_used, + budget_limit=agent_status.tokens_limit, + agent_id=agent_id, + ) + + async def record_usage( + self, + agent_id: str, + session_id: str | None, + tokens: int, + cost_usd: float, + ) -> None: + """ + Record actual usage. + + Args: + agent_id: ID of the agent + session_id: Optional session ID + tokens: Actual token usage + cost_usd: Actual USD cost + """ + # Update session budget + if session_id: + session_tracker = await self.get_or_create_tracker( + BudgetScope.SESSION, session_id + ) + await session_tracker.add_usage(tokens, cost_usd) + + # Check for warning + status = await session_tracker.get_status() + if status.is_warning and not status.is_exceeded: + await self._send_alert( + "warning", + f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens", + status, + ) + + # Update agent daily budget + agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id) + await agent_tracker.add_usage(tokens, cost_usd) + + # Check for warning + status = await agent_tracker.get_status() + if status.is_warning and not status.is_exceeded: + await self._send_alert( + "warning", + f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens", + status, + ) + + async def get_status( + self, + scope: BudgetScope, + scope_id: str, + ) -> BudgetStatus | None: + """ + Get budget status. + + Args: + scope: Budget scope + scope_id: ID within scope + + Returns: + Budget status or None if not tracked + """ + key = f"{scope.value}:{scope_id}" + async with self._lock: + tracker = self._trackers.get(key) + + if tracker: + return await tracker.get_status() + return None + + async def get_all_statuses(self) -> list[BudgetStatus]: + """Get status of all tracked budgets.""" + statuses = [] + async with self._lock: + trackers = list(self._trackers.values()) + + for tracker in trackers: + statuses.append(await tracker.get_status()) + + return statuses + + async def set_budget( + self, + scope: BudgetScope, + scope_id: str, + tokens_limit: int, + cost_limit_usd: float, + ) -> None: + """ + Set a custom budget limit. + + Args: + scope: Budget scope + scope_id: ID within scope + tokens_limit: Token limit + cost_limit_usd: USD limit + """ + key = f"{scope.value}:{scope_id}" + + reset_interval = None + if scope == BudgetScope.DAILY: + reset_interval = timedelta(days=1) + elif scope == BudgetScope.WEEKLY: + reset_interval = timedelta(weeks=1) + elif scope == BudgetScope.MONTHLY: + reset_interval = timedelta(days=30) + + async with self._lock: + self._trackers[key] = BudgetTracker( + scope=scope, + scope_id=scope_id, + tokens_limit=tokens_limit, + cost_limit_usd=cost_limit_usd, + reset_interval=reset_interval, + ) + + async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool: + """ + Reset a budget tracker. + + Args: + scope: Budget scope + scope_id: ID within scope + + Returns: + True if tracker was found and reset + """ + key = f"{scope.value}:{scope_id}" + async with self._lock: + tracker = self._trackers.get(key) + + if tracker: + await tracker.reset() + return True + return False + + def add_alert_handler(self, handler: Any) -> None: + """Add an alert handler.""" + self._alert_handlers.append(handler) + + def remove_alert_handler(self, handler: Any) -> None: + """Remove an alert handler.""" + if handler in self._alert_handlers: + self._alert_handlers.remove(handler) + + async def _send_alert( + self, + alert_type: str, + message: str, + status: BudgetStatus, + ) -> None: + """Send alert to all handlers.""" + for handler in self._alert_handlers: + try: + if asyncio.iscoroutinefunction(handler): + await handler(alert_type, message, status) + else: + handler(alert_type, message, status) + except Exception as e: + logger.error("Error in alert handler: %s", e) diff --git a/backend/app/services/safety/limits/__init__.py b/backend/app/services/safety/limits/__init__.py index 9f4729c..739d149 100644 --- a/backend/app/services/safety/limits/__init__.py +++ b/backend/app/services/safety/limits/__init__.py @@ -1 +1,15 @@ -"""${dir} module.""" +""" +Rate Limiting Module + +Sliding window rate limiting for agent operations. +""" + +from .limiter import ( + RateLimiter, + SlidingWindowCounter, +) + +__all__ = [ + "RateLimiter", + "SlidingWindowCounter", +] diff --git a/backend/app/services/safety/limits/limiter.py b/backend/app/services/safety/limits/limiter.py new file mode 100644 index 0000000..bc94ab0 --- /dev/null +++ b/backend/app/services/safety/limits/limiter.py @@ -0,0 +1,368 @@ +""" +Rate Limiter + +Sliding window rate limiting for agent operations. +""" + +import asyncio +import logging +import time +from collections import deque + +from ..config import get_safety_config +from ..exceptions import RateLimitExceededError +from ..models import ( + ActionRequest, + RateLimitConfig, + RateLimitStatus, +) + +logger = logging.getLogger(__name__) + + +class SlidingWindowCounter: + """Sliding window counter for rate limiting.""" + + def __init__( + self, + limit: int, + window_seconds: int, + burst_limit: int | None = None, + ) -> None: + self.limit = limit + self.window_seconds = window_seconds + self.burst_limit = burst_limit or limit + self._timestamps: deque[float] = deque() + self._lock = asyncio.Lock() + + async def try_acquire(self) -> tuple[bool, float]: + """ + Try to acquire a slot. + + Returns: + Tuple of (allowed, retry_after_seconds) + """ + now = time.time() + window_start = now - self.window_seconds + + async with self._lock: + # Remove expired entries + while self._timestamps and self._timestamps[0] < window_start: + self._timestamps.popleft() + + current_count = len(self._timestamps) + + # Check burst limit (instant check) + if current_count >= self.burst_limit: + # Calculate retry time + oldest = self._timestamps[0] if self._timestamps else now + retry_after = oldest + self.window_seconds - now + return False, max(0, retry_after) + + # Check window limit + if current_count >= self.limit: + oldest = self._timestamps[0] if self._timestamps else now + retry_after = oldest + self.window_seconds - now + return False, max(0, retry_after) + + # Allow and record + self._timestamps.append(now) + return True, 0.0 + + async def get_status(self) -> tuple[int, int, float]: + """ + Get current status. + + Returns: + Tuple of (current_count, remaining, reset_in_seconds) + """ + now = time.time() + window_start = now - self.window_seconds + + async with self._lock: + # Remove expired entries + while self._timestamps and self._timestamps[0] < window_start: + self._timestamps.popleft() + + current_count = len(self._timestamps) + remaining = max(0, self.limit - current_count) + + if self._timestamps: + reset_in = self._timestamps[0] + self.window_seconds - now + else: + reset_in = 0.0 + + return current_count, remaining, max(0, reset_in) + + +class RateLimiter: + """ + Rate limiter for agent operations. + + Features: + - Per-tool rate limits + - Per-agent rate limits + - Per-resource rate limits + - Sliding window implementation + - Burst allowance with recovery + - Slowdown before hard block + """ + + def __init__(self) -> None: + """Initialize the RateLimiter.""" + config = get_safety_config() + + self._configs: dict[str, RateLimitConfig] = {} + self._counters: dict[str, SlidingWindowCounter] = {} + self._lock = asyncio.Lock() + + # Default rate limits + self._default_limits = { + "actions": RateLimitConfig( + name="actions", + limit=config.default_actions_per_minute, + window_seconds=60, + ), + "llm_calls": RateLimitConfig( + name="llm_calls", + limit=config.default_llm_calls_per_minute, + window_seconds=60, + ), + "file_ops": RateLimitConfig( + name="file_ops", + limit=config.default_file_ops_per_minute, + window_seconds=60, + ), + } + + def configure(self, config: RateLimitConfig) -> None: + """ + Configure a rate limit. + + Args: + config: Rate limit configuration + """ + self._configs[config.name] = config + logger.debug( + "Configured rate limit: %s = %d/%ds", + config.name, + config.limit, + config.window_seconds, + ) + + async def check( + self, + limit_name: str, + key: str, + ) -> RateLimitStatus: + """ + Check rate limit without consuming a slot. + + Args: + limit_name: Name of the rate limit + key: Key for tracking (e.g., agent_id) + + Returns: + Rate limit status + """ + counter = await self._get_counter(limit_name, key) + config = self._get_config(limit_name) + + current, remaining, reset_in = await counter.get_status() + from datetime import datetime, timedelta + + return RateLimitStatus( + name=limit_name, + current_count=current, + limit=config.limit, + window_seconds=config.window_seconds, + remaining=remaining, + reset_at=datetime.utcnow() + timedelta(seconds=reset_in), + is_limited=remaining <= 0, + retry_after_seconds=reset_in if remaining <= 0 else 0.0, + ) + + async def acquire( + self, + limit_name: str, + key: str, + ) -> tuple[bool, RateLimitStatus]: + """ + Try to acquire a rate limit slot. + + Args: + limit_name: Name of the rate limit + key: Key for tracking (e.g., agent_id) + + Returns: + Tuple of (allowed, status) + """ + counter = await self._get_counter(limit_name, key) + config = self._get_config(limit_name) + + allowed, retry_after = await counter.try_acquire() + current, remaining, reset_in = await counter.get_status() + + from datetime import datetime, timedelta + + status = RateLimitStatus( + name=limit_name, + current_count=current, + limit=config.limit, + window_seconds=config.window_seconds, + remaining=remaining, + reset_at=datetime.utcnow() + timedelta(seconds=reset_in), + is_limited=not allowed, + retry_after_seconds=retry_after, + ) + + return allowed, status + + async def check_action( + self, + action: ActionRequest, + ) -> tuple[bool, list[RateLimitStatus]]: + """ + Check all applicable rate limits for an action. + + Args: + action: The action to check + + Returns: + Tuple of (allowed, list of statuses) + """ + agent_id = action.metadata.agent_id + statuses: list[RateLimitStatus] = [] + allowed = True + + # Check general actions limit + actions_allowed, actions_status = await self.acquire("actions", agent_id) + statuses.append(actions_status) + if not actions_allowed: + allowed = False + + # Check LLM-specific limit for LLM calls + if action.action_type.value == "llm_call": + llm_allowed, llm_status = await self.acquire("llm_calls", agent_id) + statuses.append(llm_status) + if not llm_allowed: + allowed = False + + # Check file ops limit for file operations + if action.action_type.value in {"file_read", "file_write", "file_delete"}: + file_allowed, file_status = await self.acquire("file_ops", agent_id) + statuses.append(file_status) + if not file_allowed: + allowed = False + + return allowed, statuses + + async def require( + self, + limit_name: str, + key: str, + ) -> None: + """ + Require rate limit slot or raise exception. + + Args: + limit_name: Name of the rate limit + key: Key for tracking + + Raises: + RateLimitExceededError: If rate limit exceeded + """ + allowed, status = await self.acquire(limit_name, key) + if not allowed: + raise RateLimitExceededError( + f"Rate limit exceeded: {limit_name}", + limit_type=limit_name, + limit_value=status.limit, + window_seconds=status.window_seconds, + retry_after_seconds=status.retry_after_seconds, + ) + + async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]: + """ + Get status of all rate limits for a key. + + Args: + key: Key for tracking + + Returns: + Dict of limit name to status + """ + statuses = {} + for name in self._default_limits: + statuses[name] = await self.check(name, key) + for name in self._configs: + if name not in statuses: + statuses[name] = await self.check(name, key) + return statuses + + async def reset(self, limit_name: str, key: str) -> bool: + """ + Reset a rate limit counter. + + Args: + limit_name: Name of the rate limit + key: Key for tracking + + Returns: + True if counter was found and reset + """ + counter_key = f"{limit_name}:{key}" + async with self._lock: + if counter_key in self._counters: + del self._counters[counter_key] + return True + return False + + async def reset_all(self, key: str) -> int: + """ + Reset all rate limit counters for a key. + + Args: + key: Key for tracking + + Returns: + Number of counters reset + """ + count = 0 + async with self._lock: + to_remove = [k for k in self._counters if k.endswith(f":{key}")] + for k in to_remove: + del self._counters[k] + count += 1 + return count + + def _get_config(self, limit_name: str) -> RateLimitConfig: + """Get configuration for a rate limit.""" + if limit_name in self._configs: + return self._configs[limit_name] + if limit_name in self._default_limits: + return self._default_limits[limit_name] + # Return default + return RateLimitConfig( + name=limit_name, + limit=60, + window_seconds=60, + ) + + async def _get_counter( + self, + limit_name: str, + key: str, + ) -> SlidingWindowCounter: + """Get or create a counter.""" + counter_key = f"{limit_name}:{key}" + config = self._get_config(limit_name) + + async with self._lock: + if counter_key not in self._counters: + self._counters[counter_key] = SlidingWindowCounter( + limit=config.limit, + window_seconds=config.window_seconds, + burst_limit=config.burst_limit, + ) + return self._counters[counter_key] diff --git a/backend/app/services/safety/loops/__init__.py b/backend/app/services/safety/loops/__init__.py index 9f4729c..5ff9a8f 100644 --- a/backend/app/services/safety/loops/__init__.py +++ b/backend/app/services/safety/loops/__init__.py @@ -1 +1,17 @@ -"""${dir} module.""" +""" +Loop Detection Module + +Detects and prevents action loops in agent behavior. +""" + +from .detector import ( + ActionSignature, + LoopBreaker, + LoopDetector, +) + +__all__ = [ + "ActionSignature", + "LoopBreaker", + "LoopDetector", +] diff --git a/backend/app/services/safety/loops/detector.py b/backend/app/services/safety/loops/detector.py new file mode 100644 index 0000000..91a9216 --- /dev/null +++ b/backend/app/services/safety/loops/detector.py @@ -0,0 +1,267 @@ +""" +Loop Detector + +Detects and prevents action loops in agent behavior. +""" + +import asyncio +import hashlib +import json +import logging +from collections import Counter, deque +from typing import Any + +from ..config import get_safety_config +from ..exceptions import LoopDetectedError +from ..models import ActionRequest + +logger = logging.getLogger(__name__) + + +class ActionSignature: + """Signature of an action for comparison.""" + + def __init__(self, action: ActionRequest) -> None: + self.action_type = action.action_type.value + self.tool_name = action.tool_name + self.resource = action.resource + self.args_hash = self._hash_args(action.arguments) + + def _hash_args(self, args: dict[str, Any]) -> str: + """Create a hash of the arguments.""" + try: + serialized = json.dumps(args, sort_keys=True, default=str) + return hashlib.sha256(serialized.encode()).hexdigest()[:8] + except Exception: + return "" + + def exact_key(self) -> str: + """Key for exact match detection.""" + return f"{self.action_type}:{self.tool_name}:{self.resource}:{self.args_hash}" + + def semantic_key(self) -> str: + """Key for semantic (similar) match detection.""" + return f"{self.action_type}:{self.tool_name}:{self.resource}" + + def type_key(self) -> str: + """Key for action type only.""" + return f"{self.action_type}" + + +class LoopDetector: + """ + Detects action loops and repetitive behavior. + + Loop Types: + - Exact: Same action with same arguments + - Semantic: Similar actions (same type/tool/resource, different args) + - Oscillation: A→B→A→B patterns + """ + + def __init__( + self, + history_size: int | None = None, + max_exact_repetitions: int | None = None, + max_semantic_repetitions: int | None = None, + ) -> None: + """ + Initialize the LoopDetector. + + Args: + history_size: Size of action history to track + max_exact_repetitions: Max allowed exact repetitions + max_semantic_repetitions: Max allowed semantic repetitions + """ + config = get_safety_config() + + self._history_size = history_size or config.loop_history_size + self._max_exact = max_exact_repetitions or config.max_repeated_actions + self._max_semantic = max_semantic_repetitions or config.max_similar_actions + + # Per-agent history + self._histories: dict[str, deque[ActionSignature]] = {} + self._lock = asyncio.Lock() + + async def check(self, action: ActionRequest) -> tuple[bool, str | None]: + """ + Check if an action would create a loop. + + Args: + action: The action to check + + Returns: + Tuple of (is_loop, loop_type) + """ + agent_id = action.metadata.agent_id + signature = ActionSignature(action) + + async with self._lock: + history = self._get_history(agent_id) + + # Check exact repetition + exact_key = signature.exact_key() + exact_count = sum(1 for h in history if h.exact_key() == exact_key) + if exact_count >= self._max_exact: + return True, "exact" + + # Check semantic repetition + semantic_key = signature.semantic_key() + semantic_count = sum(1 for h in history if h.semantic_key() == semantic_key) + if semantic_count >= self._max_semantic: + return True, "semantic" + + # Check oscillation (A→B→A→B pattern) + if len(history) >= 3: + pattern = self._detect_oscillation(history, signature) + if pattern: + return True, "oscillation" + + return False, None + + async def check_and_raise(self, action: ActionRequest) -> None: + """ + Check for loops and raise if detected. + + Args: + action: The action to check + + Raises: + LoopDetectedError: If loop is detected + """ + is_loop, loop_type = await self.check(action) + if is_loop: + signature = ActionSignature(action) + raise LoopDetectedError( + f"Loop detected: {loop_type}", + loop_type=loop_type or "unknown", + repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic, + action_pattern=[signature.semantic_key()], + agent_id=action.metadata.agent_id, + action_id=action.id, + ) + + async def record(self, action: ActionRequest) -> None: + """ + Record an action in history. + + Args: + action: The action to record + """ + agent_id = action.metadata.agent_id + signature = ActionSignature(action) + + async with self._lock: + history = self._get_history(agent_id) + history.append(signature) + + async def clear_history(self, agent_id: str) -> None: + """ + Clear history for an agent. + + Args: + agent_id: ID of the agent + """ + async with self._lock: + if agent_id in self._histories: + self._histories[agent_id].clear() + + async def get_stats(self, agent_id: str) -> dict[str, Any]: + """ + Get loop detection stats for an agent. + + Args: + agent_id: ID of the agent + + Returns: + Stats dictionary + """ + async with self._lock: + history = self._get_history(agent_id) + + # Count action types + type_counts = Counter(h.type_key() for h in history) + semantic_counts = Counter(h.semantic_key() for h in history) + + return { + "history_size": len(history), + "max_history": self._history_size, + "action_type_counts": dict(type_counts), + "top_semantic_patterns": semantic_counts.most_common(5), + } + + def _get_history(self, agent_id: str) -> deque[ActionSignature]: + """Get or create history for an agent.""" + if agent_id not in self._histories: + self._histories[agent_id] = deque(maxlen=self._history_size) + return self._histories[agent_id] + + def _detect_oscillation( + self, + history: deque[ActionSignature], + current: ActionSignature, + ) -> bool: + """ + Detect A→B→A→B oscillation pattern. + + Looks at last 4+ actions including current. + """ + if len(history) < 3: + return False + + # Get last 3 actions + current + recent = [*list(history)[-3:], current] + + # Check for A→B→A→B pattern + if len(recent) >= 4: + # Get semantic keys + keys = [a.semantic_key() for a in recent[-4:]] + + # Pattern: k[0]==k[2] and k[1]==k[3] and k[0]!=k[1] + if keys[0] == keys[2] and keys[1] == keys[3] and keys[0] != keys[1]: + return True + + return False + + +class LoopBreaker: + """ + Strategies for breaking detected loops. + """ + + @staticmethod + async def suggest_alternatives( + action: ActionRequest, + loop_type: str, + ) -> list[str]: + """ + Suggest alternative actions when loop is detected. + + Args: + action: The looping action + loop_type: Type of loop detected + + Returns: + List of suggestions + """ + suggestions = [] + + if loop_type == "exact": + suggestions.append( + "The same action with identical arguments has been repeated too many times. " + "Consider: (1) Verify the action succeeded, (2) Try a different approach, " + "(3) Escalate for human review" + ) + elif loop_type == "semantic": + suggestions.append( + "Similar actions have been repeated too many times. " + "Consider: (1) Review if the approach is working, (2) Try an alternative method, " + "(3) Request clarification on the goal" + ) + elif loop_type == "oscillation": + suggestions.append( + "An oscillating pattern was detected (A→B→A→B). " + "This usually indicates conflicting goals or a stuck state. " + "Consider: (1) Step back and reassess, (2) Request human guidance" + ) + + return suggestions diff --git a/backend/app/services/safety/permissions/__init__.py b/backend/app/services/safety/permissions/__init__.py index 9f4729c..83c022a 100644 --- a/backend/app/services/safety/permissions/__init__.py +++ b/backend/app/services/safety/permissions/__init__.py @@ -1 +1,15 @@ -"""${dir} module.""" +""" +Permission Management Module + +Agent permissions for resource access. +""" + +from .manager import ( + PermissionGrant, + PermissionManager, +) + +__all__ = [ + "PermissionGrant", + "PermissionManager", +] diff --git a/backend/app/services/safety/permissions/manager.py b/backend/app/services/safety/permissions/manager.py new file mode 100644 index 0000000..c417c9d --- /dev/null +++ b/backend/app/services/safety/permissions/manager.py @@ -0,0 +1,384 @@ +""" +Permission Manager + +Manages permissions for agent actions on resources. +""" + +import asyncio +import fnmatch +import logging +from datetime import datetime, timedelta +from uuid import uuid4 + +from ..exceptions import PermissionDeniedError +from ..models import ( + ActionRequest, + ActionType, + PermissionLevel, + ResourceType, +) + +logger = logging.getLogger(__name__) + + +class PermissionGrant: + """A permission grant for an agent on a resource.""" + + def __init__( + self, + agent_id: str, + resource_pattern: str, + resource_type: ResourceType, + level: PermissionLevel, + *, + expires_at: datetime | None = None, + granted_by: str | None = None, + reason: str | None = None, + ) -> None: + self.id = str(uuid4()) + self.agent_id = agent_id + self.resource_pattern = resource_pattern + self.resource_type = resource_type + self.level = level + self.expires_at = expires_at + self.granted_by = granted_by + self.reason = reason + self.created_at = datetime.utcnow() + + def is_expired(self) -> bool: + """Check if the grant has expired.""" + if self.expires_at is None: + return False + return datetime.utcnow() > self.expires_at + + def matches(self, resource: str, resource_type: ResourceType) -> bool: + """Check if this grant applies to a resource.""" + if self.resource_type != resource_type: + return False + return fnmatch.fnmatch(resource, self.resource_pattern) + + def allows(self, required_level: PermissionLevel) -> bool: + """Check if this grant allows the required permission level.""" + # Permission level hierarchy + hierarchy = { + PermissionLevel.NONE: 0, + PermissionLevel.READ: 1, + PermissionLevel.WRITE: 2, + PermissionLevel.EXECUTE: 3, + PermissionLevel.DELETE: 4, + PermissionLevel.ADMIN: 5, + } + + return hierarchy[self.level] >= hierarchy[required_level] + + +class PermissionManager: + """ + Manages permissions for agent access to resources. + + Features: + - Permission grants by agent/resource pattern + - Permission inheritance (project → agent → action) + - Temporary permissions with expiration + - Least-privilege defaults + - Permission escalation logging + """ + + def __init__( + self, + default_deny: bool = True, + ) -> None: + """ + Initialize the PermissionManager. + + Args: + default_deny: If True, deny access unless explicitly granted + """ + self._grants: list[PermissionGrant] = [] + self._default_deny = default_deny + self._lock = asyncio.Lock() + + # Default permissions for common resources + self._default_permissions: dict[ResourceType, PermissionLevel] = { + ResourceType.FILE: PermissionLevel.READ, + ResourceType.DATABASE: PermissionLevel.READ, + ResourceType.API: PermissionLevel.READ, + ResourceType.GIT: PermissionLevel.READ, + ResourceType.LLM: PermissionLevel.EXECUTE, + ResourceType.SHELL: PermissionLevel.NONE, + ResourceType.NETWORK: PermissionLevel.READ, + } + + async def grant( + self, + agent_id: str, + resource_pattern: str, + resource_type: ResourceType, + level: PermissionLevel, + *, + duration_seconds: int | None = None, + granted_by: str | None = None, + reason: str | None = None, + ) -> PermissionGrant: + """ + Grant a permission to an agent. + + Args: + agent_id: ID of the agent + resource_pattern: Pattern for matching resources (supports wildcards) + resource_type: Type of resource + level: Permission level to grant + duration_seconds: Optional duration for temporary permission + granted_by: Who granted the permission + reason: Reason for granting + + Returns: + The created permission grant + """ + expires_at = None + if duration_seconds: + expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds) + + grant = PermissionGrant( + agent_id=agent_id, + resource_pattern=resource_pattern, + resource_type=resource_type, + level=level, + expires_at=expires_at, + granted_by=granted_by, + reason=reason, + ) + + async with self._lock: + self._grants.append(grant) + + logger.info( + "Permission granted: agent=%s, resource=%s, type=%s, level=%s", + agent_id, + resource_pattern, + resource_type.value, + level.value, + ) + + return grant + + async def revoke(self, grant_id: str) -> bool: + """ + Revoke a permission grant. + + Args: + grant_id: ID of the grant to revoke + + Returns: + True if grant was found and revoked + """ + async with self._lock: + for i, grant in enumerate(self._grants): + if grant.id == grant_id: + del self._grants[i] + logger.info("Permission revoked: %s", grant_id) + return True + return False + + async def revoke_all(self, agent_id: str) -> int: + """ + Revoke all permissions for an agent. + + Args: + agent_id: ID of the agent + + Returns: + Number of grants revoked + """ + async with self._lock: + original_count = len(self._grants) + self._grants = [g for g in self._grants if g.agent_id != agent_id] + revoked = original_count - len(self._grants) + + if revoked: + logger.info("Revoked %d permissions for agent %s", revoked, agent_id) + + return revoked + + async def check( + self, + agent_id: str, + resource: str, + resource_type: ResourceType, + required_level: PermissionLevel, + ) -> bool: + """ + Check if an agent has permission to access a resource. + + Args: + agent_id: ID of the agent + resource: Resource to access + resource_type: Type of resource + required_level: Required permission level + + Returns: + True if access is allowed + """ + # Clean up expired grants + await self._cleanup_expired() + + async with self._lock: + for grant in self._grants: + if grant.agent_id != agent_id: + continue + + if grant.is_expired(): + continue + + if grant.matches(resource, resource_type): + if grant.allows(required_level): + return True + + # Check default permissions + if not self._default_deny: + default_level = self._default_permissions.get( + resource_type, PermissionLevel.NONE + ) + hierarchy = { + PermissionLevel.NONE: 0, + PermissionLevel.READ: 1, + PermissionLevel.WRITE: 2, + PermissionLevel.EXECUTE: 3, + PermissionLevel.DELETE: 4, + PermissionLevel.ADMIN: 5, + } + if hierarchy[default_level] >= hierarchy[required_level]: + return True + + return False + + async def check_action(self, action: ActionRequest) -> bool: + """ + Check if an action is permitted. + + Args: + action: The action to check + + Returns: + True if action is allowed + """ + # Determine required permission level from action type + level_map = { + ActionType.FILE_READ: PermissionLevel.READ, + ActionType.FILE_WRITE: PermissionLevel.WRITE, + ActionType.FILE_DELETE: PermissionLevel.DELETE, + ActionType.DATABASE_QUERY: PermissionLevel.READ, + ActionType.DATABASE_MUTATE: PermissionLevel.WRITE, + ActionType.SHELL_COMMAND: PermissionLevel.EXECUTE, + ActionType.API_CALL: PermissionLevel.EXECUTE, + ActionType.GIT_OPERATION: PermissionLevel.WRITE, + ActionType.LLM_CALL: PermissionLevel.EXECUTE, + ActionType.NETWORK_REQUEST: PermissionLevel.READ, + ActionType.TOOL_CALL: PermissionLevel.EXECUTE, + } + + required_level = level_map.get(action.action_type, PermissionLevel.EXECUTE) + + # Determine resource type from action + resource_type_map = { + ActionType.FILE_READ: ResourceType.FILE, + ActionType.FILE_WRITE: ResourceType.FILE, + ActionType.FILE_DELETE: ResourceType.FILE, + ActionType.DATABASE_QUERY: ResourceType.DATABASE, + ActionType.DATABASE_MUTATE: ResourceType.DATABASE, + ActionType.SHELL_COMMAND: ResourceType.SHELL, + ActionType.API_CALL: ResourceType.API, + ActionType.GIT_OPERATION: ResourceType.GIT, + ActionType.LLM_CALL: ResourceType.LLM, + ActionType.NETWORK_REQUEST: ResourceType.NETWORK, + } + + resource_type = resource_type_map.get(action.action_type, ResourceType.CUSTOM) + resource = action.resource or action.tool_name or "*" + + return await self.check( + agent_id=action.metadata.agent_id, + resource=resource, + resource_type=resource_type, + required_level=required_level, + ) + + async def require_permission( + self, + agent_id: str, + resource: str, + resource_type: ResourceType, + required_level: PermissionLevel, + ) -> None: + """ + Require permission or raise exception. + + Args: + agent_id: ID of the agent + resource: Resource to access + resource_type: Type of resource + required_level: Required permission level + + Raises: + PermissionDeniedError: If permission is denied + """ + if not await self.check(agent_id, resource, resource_type, required_level): + raise PermissionDeniedError( + f"Permission denied: {resource}", + action_type=None, + resource=resource, + required_permission=required_level.value, + agent_id=agent_id, + ) + + async def list_grants( + self, + agent_id: str | None = None, + resource_type: ResourceType | None = None, + ) -> list[PermissionGrant]: + """ + List permission grants. + + Args: + agent_id: Optional filter by agent + resource_type: Optional filter by resource type + + Returns: + List of matching grants + """ + await self._cleanup_expired() + + async with self._lock: + grants = list(self._grants) + + if agent_id: + grants = [g for g in grants if g.agent_id == agent_id] + + if resource_type: + grants = [g for g in grants if g.resource_type == resource_type] + + return grants + + def set_default_permission( + self, + resource_type: ResourceType, + level: PermissionLevel, + ) -> None: + """ + Set the default permission level for a resource type. + + Args: + resource_type: Type of resource + level: Default permission level + """ + self._default_permissions[resource_type] = level + + async def _cleanup_expired(self) -> None: + """Remove expired grants.""" + async with self._lock: + original_count = len(self._grants) + self._grants = [g for g in self._grants if not g.is_expired()] + removed = original_count - len(self._grants) + + if removed: + logger.debug("Cleaned up %d expired permission grants", removed) diff --git a/backend/app/services/safety/validation/__init__.py b/backend/app/services/safety/validation/__init__.py index 9f4729c..20df8a5 100644 --- a/backend/app/services/safety/validation/__init__.py +++ b/backend/app/services/safety/validation/__init__.py @@ -1 +1,21 @@ -"""${dir} module.""" +""" +Action Validation Module + +Pre-execution validation with rule engine. +""" + +from .validator import ( + ActionValidator, + ValidationCache, + create_allow_rule, + create_approval_rule, + create_deny_rule, +) + +__all__ = [ + "ActionValidator", + "ValidationCache", + "create_allow_rule", + "create_approval_rule", + "create_deny_rule", +] diff --git a/backend/app/services/safety/validation/validator.py b/backend/app/services/safety/validation/validator.py new file mode 100644 index 0000000..0187414 --- /dev/null +++ b/backend/app/services/safety/validation/validator.py @@ -0,0 +1,439 @@ +""" +Action Validator + +Pre-execution validation with rule engine for action requests. +""" + +import asyncio +import fnmatch +import logging +from collections import OrderedDict + +from ..config import get_safety_config +from ..models import ( + ActionRequest, + ActionType, + SafetyDecision, + SafetyPolicy, + ValidationResult, + ValidationRule, +) + +logger = logging.getLogger(__name__) + + +class ValidationCache: + """LRU cache for validation results.""" + + def __init__(self, max_size: int = 1000, ttl_seconds: int = 60) -> None: + self._cache: OrderedDict[str, tuple[ValidationResult, float]] = OrderedDict() + self._max_size = max_size + self._ttl = ttl_seconds + self._lock = asyncio.Lock() + + async def get(self, key: str) -> ValidationResult | None: + """Get cached validation result.""" + import time + + async with self._lock: + if key not in self._cache: + return None + + result, timestamp = self._cache[key] + if time.time() - timestamp > self._ttl: + del self._cache[key] + return None + + # Move to end (LRU) + self._cache.move_to_end(key) + return result + + async def set(self, key: str, result: ValidationResult) -> None: + """Cache a validation result.""" + import time + + async with self._lock: + if key in self._cache: + self._cache.move_to_end(key) + else: + if len(self._cache) >= self._max_size: + self._cache.popitem(last=False) + self._cache[key] = (result, time.time()) + + async def clear(self) -> None: + """Clear the cache.""" + async with self._lock: + self._cache.clear() + + +class ActionValidator: + """ + Validates actions against safety rules before execution. + + Features: + - Rule-based validation engine + - Allow/deny/require-approval rules + - Pattern matching for tools and resources + - Validation result caching + - Bypass capability for emergencies + """ + + def __init__( + self, + cache_enabled: bool = True, + cache_size: int = 1000, + cache_ttl: int = 60, + ) -> None: + """ + Initialize the ActionValidator. + + Args: + cache_enabled: Whether to cache validation results + cache_size: Maximum cache entries + cache_ttl: Cache TTL in seconds + """ + self._rules: list[ValidationRule] = [] + self._cache_enabled = cache_enabled + self._cache = ValidationCache(max_size=cache_size, ttl_seconds=cache_ttl) + self._bypass_enabled = False + self._bypass_reason: str | None = None + + config = get_safety_config() + self._cache_enabled = cache_enabled + self._cache_ttl = config.validation_cache_ttl + self._cache_size = config.validation_cache_size + + def add_rule(self, rule: ValidationRule) -> None: + """ + Add a validation rule. + + Args: + rule: The rule to add + """ + self._rules.append(rule) + # Re-sort by priority (higher first) + self._rules.sort(key=lambda r: r.priority, reverse=True) + logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority) + + def remove_rule(self, rule_id: str) -> bool: + """ + Remove a validation rule by ID. + + Args: + rule_id: ID of the rule to remove + + Returns: + True if rule was found and removed + """ + for i, rule in enumerate(self._rules): + if rule.id == rule_id: + del self._rules[i] + logger.debug("Removed validation rule: %s", rule_id) + return True + return False + + def clear_rules(self) -> None: + """Remove all validation rules.""" + self._rules.clear() + + def load_rules_from_policy(self, policy: SafetyPolicy) -> None: + """ + Load validation rules from a safety policy. + + Args: + policy: The policy to load rules from + """ + # Clear existing rules + self.clear_rules() + + # Add rules from policy + for rule in policy.validation_rules: + self.add_rule(rule) + + # Create implicit rules from policy settings + + # Denied tools + for i, pattern in enumerate(policy.denied_tools): + self.add_rule( + ValidationRule( + name=f"deny_tool_{i}", + description=f"Deny tool pattern: {pattern}", + priority=100, # High priority for denials + tool_patterns=[pattern], + decision=SafetyDecision.DENY, + reason=f"Tool matches denied pattern: {pattern}", + ) + ) + + # Require approval patterns + for i, pattern in enumerate(policy.require_approval_for): + if pattern == "*": + # All actions require approval + self.add_rule( + ValidationRule( + name="require_approval_all", + description="All actions require approval", + priority=50, + action_types=list(ActionType), + decision=SafetyDecision.REQUIRE_APPROVAL, + reason="All actions require human approval", + ) + ) + else: + self.add_rule( + ValidationRule( + name=f"require_approval_{i}", + description=f"Require approval for: {pattern}", + priority=50, + tool_patterns=[pattern], + decision=SafetyDecision.REQUIRE_APPROVAL, + reason=f"Action matches approval-required pattern: {pattern}", + ) + ) + + logger.info("Loaded %d rules from policy: %s", len(self._rules), policy.name) + + async def validate( + self, + action: ActionRequest, + policy: SafetyPolicy | None = None, + ) -> ValidationResult: + """ + Validate an action against all rules. + + Args: + action: The action to validate + policy: Optional policy override + + Returns: + ValidationResult with decision and details + """ + # Check bypass + if self._bypass_enabled: + logger.warning( + "Validation bypass active: %s - allowing action %s", + self._bypass_reason, + action.id, + ) + return ValidationResult( + action_id=action.id, + decision=SafetyDecision.ALLOW, + applied_rules=[], + reasons=[f"Validation bypassed: {self._bypass_reason}"], + ) + + # Check cache + if self._cache_enabled: + cache_key = self._get_cache_key(action) + cached = await self._cache.get(cache_key) + if cached: + logger.debug("Using cached validation for action %s", action.id) + return cached + + # Load rules from policy if provided + if policy and not self._rules: + self.load_rules_from_policy(policy) + + # Validate against rules + applied_rules: list[str] = [] + reasons: list[str] = [] + final_decision = SafetyDecision.ALLOW + approval_id: str | None = None + + for rule in self._rules: + if not rule.enabled: + continue + + if self._rule_matches(rule, action): + applied_rules.append(rule.id) + + if rule.reason: + reasons.append(rule.reason) + + # Handle decision priority + if rule.decision == SafetyDecision.DENY: + # Deny takes precedence + final_decision = SafetyDecision.DENY + break + + elif rule.decision == SafetyDecision.REQUIRE_APPROVAL: + # Upgrade to require approval + if final_decision != SafetyDecision.DENY: + final_decision = SafetyDecision.REQUIRE_APPROVAL + + # If no rules matched and no explicit allow, default to allow + if not applied_rules: + reasons.append("No matching rules - default allow") + + result = ValidationResult( + action_id=action.id, + decision=final_decision, + applied_rules=applied_rules, + reasons=reasons, + approval_id=approval_id, + ) + + # Cache result + if self._cache_enabled: + cache_key = self._get_cache_key(action) + await self._cache.set(cache_key, result) + + return result + + async def validate_batch( + self, + actions: list[ActionRequest], + policy: SafetyPolicy | None = None, + ) -> list[ValidationResult]: + """ + Validate multiple actions. + + Args: + actions: Actions to validate + policy: Optional policy override + + Returns: + List of validation results + """ + tasks = [self.validate(action, policy) for action in actions] + return await asyncio.gather(*tasks) + + def enable_bypass(self, reason: str) -> None: + """ + Enable validation bypass (emergency use only). + + Args: + reason: Reason for enabling bypass + """ + logger.critical("Validation bypass enabled: %s", reason) + self._bypass_enabled = True + self._bypass_reason = reason + + def disable_bypass(self) -> None: + """Disable validation bypass.""" + logger.info("Validation bypass disabled") + self._bypass_enabled = False + self._bypass_reason = None + + async def clear_cache(self) -> None: + """Clear the validation cache.""" + await self._cache.clear() + + def _rule_matches(self, rule: ValidationRule, action: ActionRequest) -> bool: + """Check if a rule matches an action.""" + # Check action types + if rule.action_types: + if action.action_type not in rule.action_types: + return False + + # Check tool patterns + if rule.tool_patterns: + if not action.tool_name: + return False + matched = False + for pattern in rule.tool_patterns: + if self._matches_pattern(action.tool_name, pattern): + matched = True + break + if not matched: + return False + + # Check resource patterns + if rule.resource_patterns: + if not action.resource: + return False + matched = False + for pattern in rule.resource_patterns: + if self._matches_pattern(action.resource, pattern): + matched = True + break + if not matched: + return False + + # Check agent IDs + if rule.agent_ids: + if action.metadata.agent_id not in rule.agent_ids: + return False + + return True + + def _matches_pattern(self, value: str, pattern: str) -> bool: + """Check if value matches a pattern (supports wildcards).""" + if pattern == "*": + return True + + # Use fnmatch for glob-style matching + return fnmatch.fnmatch(value, pattern) + + def _get_cache_key(self, action: ActionRequest) -> str: + """Generate a cache key for an action.""" + # Key based on action characteristics that affect validation + key_parts = [ + action.action_type.value, + action.tool_name or "", + action.resource or "", + action.metadata.agent_id, + action.metadata.autonomy_level.value, + ] + return ":".join(key_parts) + + +# Module-level convenience functions + + +def create_allow_rule( + name: str, + tool_patterns: list[str] | None = None, + resource_patterns: list[str] | None = None, + action_types: list[ActionType] | None = None, + priority: int = 0, +) -> ValidationRule: + """Create an allow rule.""" + return ValidationRule( + name=name, + tool_patterns=tool_patterns, + resource_patterns=resource_patterns, + action_types=action_types, + decision=SafetyDecision.ALLOW, + priority=priority, + ) + + +def create_deny_rule( + name: str, + tool_patterns: list[str] | None = None, + resource_patterns: list[str] | None = None, + action_types: list[ActionType] | None = None, + reason: str | None = None, + priority: int = 100, +) -> ValidationRule: + """Create a deny rule.""" + return ValidationRule( + name=name, + tool_patterns=tool_patterns, + resource_patterns=resource_patterns, + action_types=action_types, + decision=SafetyDecision.DENY, + reason=reason, + priority=priority, + ) + + +def create_approval_rule( + name: str, + tool_patterns: list[str] | None = None, + resource_patterns: list[str] | None = None, + action_types: list[ActionType] | None = None, + reason: str | None = None, + priority: int = 50, +) -> ValidationRule: + """Create a require-approval rule.""" + return ValidationRule( + name=name, + tool_patterns=tool_patterns, + resource_patterns=resource_patterns, + action_types=action_types, + decision=SafetyDecision.REQUIRE_APPROVAL, + reason=reason, + priority=priority, + )