forked from cardosofelipe/fast-next-template
feat(backend): add Phase B safety subsystems (#63)
Implements core control subsystems for the safety framework: **Action Validation (validation/validator.py):** - Rule-based validation engine with priority ordering - Allow/deny/require-approval rule types - Pattern matching for tools and resources - Validation result caching with LRU eviction - Emergency bypass capability with audit **Permission System (permissions/manager.py):** - Per-agent permission grants on resources - Resource pattern matching (wildcards) - Temporary permissions with expiration - Permission inheritance hierarchy - Default deny with configurable defaults **Cost Control (costs/controller.py):** - Per-session and per-day budget tracking - Token and USD cost limits - Warning alerts at configurable thresholds - Budget rollover and reset policies - Real-time usage tracking **Rate Limiting (limits/limiter.py):** - Sliding window rate limiter - Per-action, per-LLM-call, per-file-op limits - Burst allowance with recovery - Configurable limits per operation type **Loop Detection (loops/detector.py):** - Exact repetition detection (same action+args) - Semantic repetition (similar actions) - Oscillation pattern detection (A→B→A→B) - Per-agent action history tracking - Loop breaking suggestions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1 +1,15 @@
|
||||
"""${dir} module."""
|
||||
"""
|
||||
Cost Control Module
|
||||
|
||||
Budget management and cost tracking.
|
||||
"""
|
||||
|
||||
from .controller import (
|
||||
BudgetTracker,
|
||||
CostController,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BudgetTracker",
|
||||
"CostController",
|
||||
]
|
||||
|
||||
479
backend/app/services/safety/costs/controller.py
Normal file
479
backend/app/services/safety/costs/controller.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""
|
||||
Cost Controller
|
||||
|
||||
Budget management and cost tracking for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BudgetTracker:
    """Running token/USD usage counter for one budget scope.

    Usage is accumulated via ``add_usage`` and compared against the fixed
    ``tokens_limit`` / ``cost_limit_usd``. When a ``reset_interval`` is
    given, the counters are zeroed the first time the tracker is touched
    after that interval has elapsed since the last reset. Every public
    coroutine runs under the tracker's own ``asyncio.Lock``.
    """

    def __init__(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
        reset_interval: timedelta | None = None,
        warning_threshold: float = 0.8,
    ) -> None:
        """Create an empty tracker.

        Args:
            scope: Budget scope this tracker belongs to.
            scope_id: Identifier within the scope.
            tokens_limit: Maximum number of tokens.
            cost_limit_usd: Maximum spend in USD.
            reset_interval: Period after which usage is zeroed; ``None``
                disables automatic resets.
            warning_threshold: Usage ratio at which ``is_warning`` becomes
                true in the reported status.
        """
        # Public configuration.
        self.scope = scope
        self.scope_id = scope_id
        self.tokens_limit = tokens_limit
        self.cost_limit_usd = cost_limit_usd
        self.warning_threshold = warning_threshold
        self._reset_interval = reset_interval

        # Mutable usage state, guarded by ``_lock``.
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        self._created_at = datetime.utcnow()
        self._last_reset = datetime.utcnow()
        self._lock = asyncio.Lock()

    async def add_usage(self, tokens: int, cost_usd: float) -> None:
        """Add ``tokens`` and ``cost_usd`` to the totals, after any due reset."""
        async with self._lock:
            self._check_reset()
            self._tokens_used += tokens
            self._cost_used_usd += cost_usd

    async def get_status(self) -> BudgetStatus:
        """Return a ``BudgetStatus`` snapshot of usage, limits and flags."""
        async with self._lock:
            self._check_reset()

            used_tokens = self._tokens_used
            used_cost = self._cost_used_usd

            # A non-positive limit yields a ratio of 0 instead of dividing.
            token_ratio = 0
            if self.tokens_limit > 0:
                token_ratio = used_tokens / self.tokens_limit

            cost_ratio = 0
            if self.cost_limit_usd > 0:
                cost_ratio = used_cost / self.cost_limit_usd

            # Only trackers with a reset interval have a next reset time.
            next_reset = None
            if self._reset_interval:
                next_reset = self._last_reset + self._reset_interval

            return BudgetStatus(
                scope=self.scope,
                scope_id=self.scope_id,
                tokens_used=used_tokens,
                tokens_limit=self.tokens_limit,
                cost_used_usd=used_cost,
                cost_limit_usd=self.cost_limit_usd,
                tokens_remaining=max(0, self.tokens_limit - used_tokens),
                cost_remaining_usd=max(0, self.cost_limit_usd - used_cost),
                warning_threshold=self.warning_threshold,
                is_warning=max(token_ratio, cost_ratio) >= self.warning_threshold,
                is_exceeded=(
                    used_tokens >= self.tokens_limit
                    or used_cost >= self.cost_limit_usd
                ),
                reset_at=next_reset,
            )

    async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
        """Return True when the estimate fits within both limits.

        Landing exactly on a limit is still allowed; only going over it
        makes the check fail.
        """
        async with self._lock:
            self._check_reset()

            if (self._tokens_used + estimated_tokens) > self.tokens_limit:
                return False
            if (self._cost_used_usd + estimated_cost_usd) > self.cost_limit_usd:
                return False
            return True

    def _check_reset(self) -> None:
        """Zero the counters if the reset interval has elapsed.

        Callers in this class invoke it while holding ``_lock``.
        """
        interval = self._reset_interval
        if interval is None:
            return

        now = datetime.utcnow()
        if now < self._last_reset + interval:
            return

        logger.info(
            "Resetting budget for %s:%s",
            self.scope.value,
            self.scope_id,
        )
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        self._last_reset = now

    async def reset(self) -> None:
        """Zero the counters immediately and restart the reset interval."""
        async with self._lock:
            self._tokens_used = 0
            self._cost_used_usd = 0.0
            self._last_reset = datetime.utcnow()
|
||||
|
||||
|
||||
class CostController:
    """
    Controls costs and budgets for agent operations.

    One ``BudgetTracker`` is kept per ``"<scope value>:<scope_id>"`` key and
    created lazily with the configured defaults. Budget checks consult the
    session tracker (when a session id is given) and then the agent's daily
    tracker. Recording usage updates both and sends a ``"warning"`` alert
    to the registered handlers once a tracker reports ``is_warning`` without
    being exceeded.
    """

    def __init__(
        self,
        default_session_tokens: int | None = None,
        default_session_cost_usd: float | None = None,
        default_daily_tokens: int | None = None,
        default_daily_cost_usd: float | None = None,
    ) -> None:
        """
        Initialize the CostController.

        Args:
            default_session_tokens: Default token budget per session
            default_session_cost_usd: Default USD budget per session
            default_daily_tokens: Default token budget per day
            default_daily_cost_usd: Default USD budget per day

        Any argument left as ``None`` falls back to the matching value from
        ``get_safety_config()``.
        """
        config = get_safety_config()

        # Compare against None explicitly: with `value or default`, an
        # explicit 0 budget would be silently replaced by the config default.
        self._default_session_tokens = (
            config.default_session_token_budget
            if default_session_tokens is None
            else default_session_tokens
        )
        self._default_session_cost = (
            config.default_session_cost_limit
            if default_session_cost_usd is None
            else default_session_cost_usd
        )
        self._default_daily_tokens = (
            config.default_daily_token_budget
            if default_daily_tokens is None
            else default_daily_tokens
        )
        self._default_daily_cost = (
            config.default_daily_cost_limit
            if default_daily_cost_usd is None
            else default_daily_cost_usd
        )

        self._trackers: dict[str, BudgetTracker] = {}
        self._lock = asyncio.Lock()

        # Alert handlers
        self._alert_handlers: list[Any] = []

    async def get_or_create_tracker(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetTracker:
        """Get the tracker for ``scope``/``scope_id``, creating it on first use.

        DAILY trackers use the daily defaults and reset every day. Every
        other scope uses the session defaults and never resets on its own.
        """
        key = f"{scope.value}:{scope_id}"

        async with self._lock:
            tracker = self._trackers.get(key)
            if tracker is None:
                if scope == BudgetScope.DAILY:
                    tracker = BudgetTracker(
                        scope=scope,
                        scope_id=scope_id,
                        tokens_limit=self._default_daily_tokens,
                        cost_limit_usd=self._default_daily_cost,
                        reset_interval=timedelta(days=1),
                    )
                else:
                    tracker = BudgetTracker(
                        scope=scope,
                        scope_id=scope_id,
                        tokens_limit=self._default_session_tokens,
                        cost_limit_usd=self._default_session_cost,
                    )
                self._trackers[key] = tracker

            return tracker

    async def _find_blocking_tracker(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> tuple[str, BudgetTracker] | None:
        """Return ``(budget_type, tracker)`` for the first budget the estimate does not fit, else None.

        The session budget is consulted first (only when ``session_id`` is
        truthy), then the agent's daily budget. ``budget_type`` is
        ``"session"`` or ``"daily"``.
        """
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
                return "session", session_tracker

        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
            return "daily", agent_tracker

        return None

    async def check_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> bool:
        """
        Check if there's enough budget for an operation.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Returns:
            True if budget is available
        """
        blocking = await self._find_blocking_tracker(
            agent_id, session_id, estimated_tokens, estimated_cost_usd
        )
        return blocking is None

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is within budget.

        Args:
            action: The action to check

        Returns:
            True if within budget
        """
        return await self.check_budget(
            agent_id=action.metadata.agent_id,
            session_id=action.metadata.session_id,
            estimated_tokens=action.estimated_cost_tokens,
            estimated_cost_usd=action.estimated_cost_usd,
        )

    async def require_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> None:
        """
        Require budget or raise exception.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Raises:
            BudgetExceededError: If budget is exceeded
        """
        blocking = await self._find_blocking_tracker(
            agent_id, session_id, estimated_tokens, estimated_cost_usd
        )
        if blocking is None:
            return

        # Report the budget that actually rejected the estimate. Testing
        # `is_exceeded` here would blame the daily budget whenever the
        # session budget is not yet spent but the estimate does not fit it.
        budget_type, tracker = blocking
        status = await tracker.get_status()
        message = (
            "Session budget exceeded"
            if budget_type == "session"
            else "Daily budget exceeded"
        )
        # NOTE(review): usage/limit are reported in tokens even when the USD
        # limit is the one that was hit - confirm what BudgetExceededError expects.
        raise BudgetExceededError(
            message,
            budget_type=budget_type,
            current_usage=status.tokens_used,
            budget_limit=status.tokens_limit,
            agent_id=agent_id,
        )

    async def _add_usage_and_warn(
        self,
        tracker: BudgetTracker,
        tokens: int,
        cost_usd: float,
        label: str,
        unit: str,
    ) -> None:
        """Record usage on ``tracker`` and send a warning alert if it is now in the warning band."""
        await tracker.add_usage(tokens, cost_usd)

        status = await tracker.get_status()
        if status.is_warning and not status.is_exceeded:
            await self._send_alert(
                "warning",
                f"{label} at {status.tokens_used}/{status.tokens_limit} {unit}",
                status,
            )

    async def record_usage(
        self,
        agent_id: str,
        session_id: str | None,
        tokens: int,
        cost_usd: float,
    ) -> None:
        """
        Record actual usage.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            tokens: Actual token usage
            cost_usd: Actual USD cost
        """
        # Update session budget
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            await self._add_usage_and_warn(
                session_tracker, tokens, cost_usd, f"Session {session_id}", "tokens"
            )

        # Update agent daily budget
        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        await self._add_usage_and_warn(
            agent_tracker, tokens, cost_usd, f"Agent {agent_id}", "daily tokens"
        )

    async def get_status(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetStatus | None:
        """
        Get budget status.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            Budget status or None if not tracked
        """
        key = f"{scope.value}:{scope_id}"
        async with self._lock:
            tracker = self._trackers.get(key)

        # The tracker has its own lock, so query it outside the registry lock.
        if tracker:
            return await tracker.get_status()
        return None

    async def get_all_statuses(self) -> list[BudgetStatus]:
        """Get status of all tracked budgets."""
        # Snapshot under the lock so trackers added meanwhile cannot break iteration.
        async with self._lock:
            trackers = list(self._trackers.values())

        return [await tracker.get_status() for tracker in trackers]

    async def set_budget(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
    ) -> None:
        """
        Set a custom budget limit.

        Args:
            scope: Budget scope
            scope_id: ID within scope
            tokens_limit: Token limit
            cost_limit_usd: USD limit
        """
        key = f"{scope.value}:{scope_id}"

        reset_interval = None
        if scope == BudgetScope.DAILY:
            reset_interval = timedelta(days=1)
        elif scope == BudgetScope.WEEKLY:
            reset_interval = timedelta(weeks=1)
        elif scope == BudgetScope.MONTHLY:
            # A month is approximated as 30 days.
            reset_interval = timedelta(days=30)

        # NOTE(review): this installs a fresh tracker, so usage already
        # recorded for the key is discarded - confirm that is intended.
        async with self._lock:
            self._trackers[key] = BudgetTracker(
                scope=scope,
                scope_id=scope_id,
                tokens_limit=tokens_limit,
                cost_limit_usd=cost_limit_usd,
                reset_interval=reset_interval,
            )

    async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
        """
        Reset a budget tracker.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            True if tracker was found and reset
        """
        key = f"{scope.value}:{scope_id}"
        async with self._lock:
            tracker = self._trackers.get(key)

        if tracker:
            await tracker.reset()
            return True
        return False

    def add_alert_handler(self, handler: Any) -> None:
        """Add an alert handler.

        Handlers are called as ``handler(alert_type, message, status)`` and
        may be plain callables or return an awaitable.
        """
        self._alert_handlers.append(handler)

    def remove_alert_handler(self, handler: Any) -> None:
        """Remove an alert handler; unknown handlers are ignored."""
        if handler in self._alert_handlers:
            self._alert_handlers.remove(handler)

    async def _send_alert(
        self,
        alert_type: str,
        message: str,
        status: BudgetStatus,
    ) -> None:
        """Send alert to all handlers.

        A failing handler is logged and does not stop the remaining
        handlers from being called.
        """
        # Imported here because only this method needs it.
        import inspect

        # Iterate over a copy so a handler may unregister itself while running.
        for handler in list(self._alert_handlers):
            try:
                outcome = handler(alert_type, message, status)
                # Await whatever awaitable comes back. This also covers sync
                # callables that return a coroutine and objects with an async
                # __call__, which a coroutine-function check would miss.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as exc:
                logger.error("Error in alert handler: %s", exc, exc_info=True)
|
||||
@@ -1 +1,15 @@
|
||||
"""${dir} module."""
|
||||
"""
|
||||
Rate Limiting Module
|
||||
|
||||
Sliding window rate limiting for agent operations.
|
||||
"""
|
||||
|
||||
from .limiter import (
|
||||
RateLimiter,
|
||||
SlidingWindowCounter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RateLimiter",
|
||||
"SlidingWindowCounter",
|
||||
]
|
||||
|
||||
368
backend/app/services/safety/limits/limiter.py
Normal file
368
backend/app/services/safety/limits/limiter.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Rate Limiter
|
||||
|
||||
Sliding window rate limiting for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RateLimitExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for rate limiting."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
burst_limit: int | None = None,
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.window_seconds = window_seconds
|
||||
self.burst_limit = burst_limit or limit
|
||||
self._timestamps: deque[float] = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def try_acquire(self) -> tuple[bool, float]:
|
||||
"""
|
||||
Try to acquire a slot.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, retry_after_seconds)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window_seconds
|
||||
|
||||
async with self._lock:
|
||||
# Remove expired entries
|
||||
while self._timestamps and self._timestamps[0] < window_start:
|
||||
self._timestamps.popleft()
|
||||
|
||||
current_count = len(self._timestamps)
|
||||
|
||||
# Check burst limit (instant check)
|
||||
if current_count >= self.burst_limit:
|
||||
# Calculate retry time
|
||||
oldest = self._timestamps[0] if self._timestamps else now
|
||||
retry_after = oldest + self.window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
# Check window limit
|
||||
if current_count >= self.limit:
|
||||
oldest = self._timestamps[0] if self._timestamps else now
|
||||
retry_after = oldest + self.window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
# Allow and record
|
||||
self._timestamps.append(now)
|
||||
return True, 0.0
|
||||
|
||||
async def get_status(self) -> tuple[int, int, float]:
|
||||
"""
|
||||
Get current status.
|
||||
|
||||
Returns:
|
||||
Tuple of (current_count, remaining, reset_in_seconds)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window_seconds
|
||||
|
||||
async with self._lock:
|
||||
# Remove expired entries
|
||||
while self._timestamps and self._timestamps[0] < window_start:
|
||||
self._timestamps.popleft()
|
||||
|
||||
current_count = len(self._timestamps)
|
||||
remaining = max(0, self.limit - current_count)
|
||||
|
||||
if self._timestamps:
|
||||
reset_in = self._timestamps[0] + self.window_seconds - now
|
||||
else:
|
||||
reset_in = 0.0
|
||||
|
||||
return current_count, remaining, max(0, reset_in)
|
||||
|
||||
|
||||
class RateLimiter:
    """
    Rate limiter for agent operations.

    Each ``(limit_name, key)`` pair gets its own ``SlidingWindowCounter``,
    created on first use from the configuration for ``limit_name``: an
    explicitly configured limit, else one of the built-in defaults
    (``actions``, ``llm_calls``, ``file_ops``), else 60 per 60 seconds.
    """

    def __init__(self) -> None:
        """Initialize the RateLimiter with per-minute defaults from the safety config."""
        config = get_safety_config()

        self._configs: dict[str, RateLimitConfig] = {}
        # Counters are keyed by (limit_name, key) tuples. A joined
        # "name:key" string is ambiguous once either part contains ":".
        self._counters: dict[tuple[str, str], SlidingWindowCounter] = {}
        self._lock = asyncio.Lock()

        # Default rate limits
        per_minute = {
            "actions": config.default_actions_per_minute,
            "llm_calls": config.default_llm_calls_per_minute,
            "file_ops": config.default_file_ops_per_minute,
        }
        self._default_limits = {
            name: RateLimitConfig(name=name, limit=limit, window_seconds=60)
            for name, limit in per_minute.items()
        }

    def configure(self, config: RateLimitConfig) -> None:
        """
        Configure a rate limit.

        Counters that already exist for this limit are updated in place, so
        the enforced limit and the reported limit stay in agreement.

        Args:
            config: Rate limit configuration
        """
        self._configs[config.name] = config

        for (limit_name, _), counter in self._counters.items():
            if limit_name == config.name:
                counter.limit = config.limit
                counter.window_seconds = config.window_seconds
                # Same fallback SlidingWindowCounter applies at construction.
                counter.burst_limit = config.burst_limit or config.limit

        logger.debug(
            "Configured rate limit: %s = %d/%ds",
            config.name,
            config.limit,
            config.window_seconds,
        )

    def _build_status(
        self,
        limit_name: str,
        current: int,
        remaining: int,
        reset_in: float,
        *,
        is_limited: bool,
        retry_after_seconds: float,
    ) -> RateLimitStatus:
        """Assemble a ``RateLimitStatus`` from counter figures and the limit's configuration."""
        # Imported here, as in the code this helper replaces.
        from datetime import datetime, timedelta

        config = self._get_config(limit_name)
        return RateLimitStatus(
            name=limit_name,
            current_count=current,
            limit=config.limit,
            window_seconds=config.window_seconds,
            remaining=remaining,
            reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
            is_limited=is_limited,
            retry_after_seconds=retry_after_seconds,
        )

    async def check(
        self,
        limit_name: str,
        key: str,
    ) -> RateLimitStatus:
        """
        Check rate limit without consuming a slot.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Rate limit status
        """
        counter = await self._get_counter(limit_name, key)
        current, remaining, reset_in = await counter.get_status()

        exhausted = remaining <= 0
        return self._build_status(
            limit_name,
            current,
            remaining,
            reset_in,
            is_limited=exhausted,
            retry_after_seconds=reset_in if exhausted else 0.0,
        )

    async def acquire(
        self,
        limit_name: str,
        key: str,
    ) -> tuple[bool, RateLimitStatus]:
        """
        Try to acquire a rate limit slot.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Tuple of (allowed, status)
        """
        counter = await self._get_counter(limit_name, key)

        allowed, retry_after = await counter.try_acquire()
        # Read the counter again so the status includes a slot just taken.
        current, remaining, reset_in = await counter.get_status()

        status = self._build_status(
            limit_name,
            current,
            remaining,
            reset_in,
            is_limited=not allowed,
            retry_after_seconds=retry_after,
        )
        return allowed, status

    async def check_action(
        self,
        action: ActionRequest,
    ) -> tuple[bool, list[RateLimitStatus]]:
        """
        Check all applicable rate limits for an action.

        Slots are acquired limit by limit, in the order actions, llm_calls,
        file_ops. Once one limit rejects the action, the remaining limits
        are only inspected, so a rejected action uses no further quota.

        Args:
            action: The action to check

        Returns:
            Tuple of (allowed, list of statuses)
        """
        agent_id = action.metadata.agent_id
        action_type = action.action_type.value

        # The general limit always applies; type-specific limits follow it.
        limit_names = ["actions"]
        if action_type == "llm_call":
            limit_names.append("llm_calls")
        if action_type in {"file_read", "file_write", "file_delete"}:
            limit_names.append("file_ops")

        allowed = True
        statuses: list[RateLimitStatus] = []
        for limit_name in limit_names:
            if allowed:
                allowed, status = await self.acquire(limit_name, agent_id)
            else:
                # Already rejected: report the limit without consuming a slot.
                status = await self.check(limit_name, agent_id)
            statuses.append(status)

        # NOTE(review): a slot taken on an earlier limit is not returned when
        # a later limit rejects the action.
        return allowed, statuses

    async def require(
        self,
        limit_name: str,
        key: str,
    ) -> None:
        """
        Require rate limit slot or raise exception.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Raises:
            RateLimitExceededError: If rate limit exceeded
        """
        allowed, status = await self.acquire(limit_name, key)
        if not allowed:
            raise RateLimitExceededError(
                f"Rate limit exceeded: {limit_name}",
                limit_type=limit_name,
                limit_value=status.limit,
                window_seconds=status.window_seconds,
                retry_after_seconds=status.retry_after_seconds,
            )

    async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
        """
        Get status of all rate limits for a key.

        Args:
            key: Key for tracking

        Returns:
            Dict of limit name to status
        """
        # Built-in limits first, then configured limits not already listed.
        names = list(self._default_limits)
        names.extend(name for name in self._configs if name not in self._default_limits)
        return {name: await self.check(name, key) for name in names}

    async def reset(self, limit_name: str, key: str) -> bool:
        """
        Reset a rate limit counter.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Returns:
            True if counter was found and reset
        """
        async with self._lock:
            # Dropping the counter is the reset; a fresh one is created on next use.
            return self._counters.pop((limit_name, key), None) is not None

    async def reset_all(self, key: str) -> int:
        """
        Reset all rate limit counters for a key.

        Only counters whose tracking key equals ``key`` exactly are removed.

        Args:
            key: Key for tracking

        Returns:
            Number of counters reset
        """
        async with self._lock:
            doomed = [
                counter_key for counter_key in self._counters if counter_key[1] == key
            ]
            for counter_key in doomed:
                del self._counters[counter_key]
        return len(doomed)

    def _get_config(self, limit_name: str) -> RateLimitConfig:
        """Get configuration for a rate limit (explicit, then built-in, then 60 per 60s)."""
        if limit_name in self._configs:
            return self._configs[limit_name]
        if limit_name in self._default_limits:
            return self._default_limits[limit_name]
        # Return default
        return RateLimitConfig(
            name=limit_name,
            limit=60,
            window_seconds=60,
        )

    async def _get_counter(
        self,
        limit_name: str,
        key: str,
    ) -> SlidingWindowCounter:
        """Get or create the counter for ``(limit_name, key)``."""
        counter_key = (limit_name, key)
        config = self._get_config(limit_name)

        async with self._lock:
            counter = self._counters.get(counter_key)
            if counter is None:
                counter = SlidingWindowCounter(
                    limit=config.limit,
                    window_seconds=config.window_seconds,
                    burst_limit=config.burst_limit,
                )
                self._counters[counter_key] = counter
            return counter
|
||||
@@ -1 +1,17 @@
|
||||
"""${dir} module."""
|
||||
"""
|
||||
Loop Detection Module
|
||||
|
||||
Detects and prevents action loops in agent behavior.
|
||||
"""
|
||||
|
||||
from .detector import (
|
||||
ActionSignature,
|
||||
LoopBreaker,
|
||||
LoopDetector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionSignature",
|
||||
"LoopBreaker",
|
||||
"LoopDetector",
|
||||
]
|
||||
|
||||
267
backend/app/services/safety/loops/detector.py
Normal file
267
backend/app/services/safety/loops/detector.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Loop Detector
|
||||
|
||||
Detects and prevents action loops in agent behavior.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter, deque
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import LoopDetectedError
|
||||
from ..models import ActionRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionSignature:
    """Comparable fingerprint of an action.

    Stores the action's type value, tool name, resource and a short digest
    of its arguments, and derives string keys of decreasing specificity
    from them.
    """

    def __init__(self, action: ActionRequest) -> None:
        """Capture the identifying fields of ``action``."""
        self.action_type = action.action_type.value
        self.tool_name = action.tool_name
        self.resource = action.resource
        self.args_hash = self._hash_args(action.arguments)

    def _hash_args(self, args: dict[str, Any]) -> str:
        """Return the first 8 hex digits of the SHA-256 of the arguments.

        Arguments are serialized as key-sorted JSON, with values JSON cannot
        encode converted via ``str``. Any failure yields an empty string.
        """
        try:
            canonical = json.dumps(args, sort_keys=True, default=str)
            digest = hashlib.sha256(canonical.encode()).hexdigest()
        except Exception:
            return ""
        return digest[:8]

    def exact_key(self) -> str:
        """Key for exact match detection (semantic key plus argument digest)."""
        return f"{self.semantic_key()}:{self.args_hash}"

    def semantic_key(self) -> str:
        """Key for semantic (similar) match detection; arguments are ignored."""
        return "{}:{}:{}".format(self.action_type, self.tool_name, self.resource)

    def type_key(self) -> str:
        """Key for action type only."""
        return "{}".format(self.action_type)
|
||||
|
||||
|
||||
class LoopDetector:
|
||||
"""
|
||||
Detects action loops and repetitive behavior.
|
||||
|
||||
Loop Types:
|
||||
- Exact: Same action with same arguments
|
||||
- Semantic: Similar actions (same type/tool/resource, different args)
|
||||
- Oscillation: A→B→A→B patterns
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_size: int | None = None,
|
||||
max_exact_repetitions: int | None = None,
|
||||
max_semantic_repetitions: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the LoopDetector.
|
||||
|
||||
Args:
|
||||
history_size: Size of action history to track
|
||||
max_exact_repetitions: Max allowed exact repetitions
|
||||
max_semantic_repetitions: Max allowed semantic repetitions
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._history_size = history_size or config.loop_history_size
|
||||
self._max_exact = max_exact_repetitions or config.max_repeated_actions
|
||||
self._max_semantic = max_semantic_repetitions or config.max_similar_actions
|
||||
|
||||
# Per-agent history
|
||||
self._histories: dict[str, deque[ActionSignature]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def check(self, action: ActionRequest) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if an action would create a loop.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
|
||||
Returns:
|
||||
Tuple of (is_loop, loop_type)
|
||||
"""
|
||||
agent_id = action.metadata.agent_id
|
||||
signature = ActionSignature(action)
|
||||
|
||||
async with self._lock:
|
||||
history = self._get_history(agent_id)
|
||||
|
||||
# Check exact repetition
|
||||
exact_key = signature.exact_key()
|
||||
exact_count = sum(1 for h in history if h.exact_key() == exact_key)
|
||||
if exact_count >= self._max_exact:
|
||||
return True, "exact"
|
||||
|
||||
# Check semantic repetition
|
||||
semantic_key = signature.semantic_key()
|
||||
semantic_count = sum(1 for h in history if h.semantic_key() == semantic_key)
|
||||
if semantic_count >= self._max_semantic:
|
||||
return True, "semantic"
|
||||
|
||||
# Check oscillation (A→B→A→B pattern)
|
||||
if len(history) >= 3:
|
||||
pattern = self._detect_oscillation(history, signature)
|
||||
if pattern:
|
||||
return True, "oscillation"
|
||||
|
||||
return False, None
|
||||
|
||||
async def check_and_raise(self, action: ActionRequest) -> None:
    """
    Run :meth:`check` and turn a positive result into an exception.

    Args:
        action: The action to check

    Raises:
        LoopDetectedError: If loop is detected
    """
    is_loop, loop_type = await self.check(action)
    if not is_loop:
        return

    signature = ActionSignature(action)
    # Only "exact" loops report the exact limit; every other loop type
    # (including oscillation) reports the semantic limit.
    if loop_type == "exact":
        limit = self._max_exact
    else:
        limit = self._max_semantic

    raise LoopDetectedError(
        f"Loop detected: {loop_type}",
        loop_type=loop_type or "unknown",
        repetition_count=limit,
        action_pattern=[signature.semantic_key()],
        agent_id=action.metadata.agent_id,
        action_id=action.id,
    )
|
||||
|
||||
async def record(self, action: ActionRequest) -> None:
    """
    Append the signature of ``action`` to its agent's history.

    Args:
        action: The action to record
    """
    agent_id = action.metadata.agent_id
    entry = ActionSignature(action)

    async with self._lock:
        self._get_history(agent_id).append(entry)
|
||||
|
||||
async def clear_history(self, agent_id: str) -> None:
    """
    Empty the recorded history of one agent.

    Unknown agents are ignored; no history entry is created for them.

    Args:
        agent_id: ID of the agent
    """
    async with self._lock:
        recorded = self._histories.get(agent_id)
        if recorded is not None:
            recorded.clear()
|
||||
|
||||
async def get_stats(self, agent_id: str) -> dict[str, Any]:
    """
    Summarise the recorded history of one agent.

    Args:
        agent_id: ID of the agent

    Returns:
        Dictionary with the current and maximum history length, a count
        per action type key, and the five most common semantic keys as
        ``(key, count)`` pairs
    """
    async with self._lock:
        entries = self._get_history(agent_id)

        # Tally by action type, then by semantic key.
        per_type = Counter([entry.type_key() for entry in entries])
        per_pattern = Counter([entry.semantic_key() for entry in entries])

        stats: dict[str, Any] = {}
        stats["history_size"] = len(entries)
        stats["max_history"] = self._history_size
        stats["action_type_counts"] = dict(per_type)
        stats["top_semantic_patterns"] = per_pattern.most_common(5)
        return stats
|
||||
|
||||
def _get_history(self, agent_id: str) -> deque[ActionSignature]:
    """Return the agent's history, creating an empty bounded deque on first use."""
    try:
        return self._histories[agent_id]
    except KeyError:
        # First sighting of this agent: start a history capped at _history_size.
        fresh: deque[ActionSignature] = deque(maxlen=self._history_size)
        self._histories[agent_id] = fresh
        return fresh
|
||||
|
||||
def _detect_oscillation(
    self,
    history: deque[ActionSignature],
    current: ActionSignature,
) -> bool:
    """
    Report whether the last three recorded actions plus ``current``
    form an A→B→A→B pattern.

    Comparison uses semantic keys: positions one and three must match,
    positions two and four must match, and A must differ from B.
    Returns False when fewer than three actions are recorded.
    """
    if len(history) < 3:
        return False

    # Three most recent recorded signatures followed by the candidate.
    window = list(history)[-3:]
    window.append(current)
    first, second, third, fourth = (sig.semantic_key() for sig in window)

    return first == third and second == fourth and first != second
|
||||
|
||||
|
||||
class LoopBreaker:
    """
    Produces human-readable advice for getting out of a detected loop.
    """

    @staticmethod
    async def suggest_alternatives(
        action: ActionRequest,
        loop_type: str,
    ) -> list[str]:
        """
        Return advice text for the given kind of loop.

        Args:
            action: The looping action (not inspected by this method)
            loop_type: Type of loop detected

        Returns:
            A one-element list for "exact", "semantic" or "oscillation";
            an empty list for any other loop type
        """
        advice_by_type = {
            "exact": (
                "The same action with identical arguments has been repeated too many times. "
                "Consider: (1) Verify the action succeeded, (2) Try a different approach, "
                "(3) Escalate for human review"
            ),
            "semantic": (
                "Similar actions have been repeated too many times. "
                "Consider: (1) Review if the approach is working, (2) Try an alternative method, "
                "(3) Request clarification on the goal"
            ),
            "oscillation": (
                "An oscillating pattern was detected (A→B→A→B). "
                "This usually indicates conflicting goals or a stuck state. "
                "Consider: (1) Step back and reassess, (2) Request human guidance"
            ),
        }

        advice = advice_by_type.get(loop_type)
        if advice is None:
            return []
        return [advice]
|
||||
@@ -1 +1,15 @@
|
||||
"""${dir} module."""
|
||||
"""
|
||||
Permission Management Module
|
||||
|
||||
Agent permissions for resource access.
|
||||
"""
|
||||
|
||||
from .manager import (
|
||||
PermissionGrant,
|
||||
PermissionManager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PermissionGrant",
|
||||
"PermissionManager",
|
||||
]
|
||||
|
||||
384
backend/app/services/safety/permissions/manager.py
Normal file
384
backend/app/services/safety/permissions/manager.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Permission Manager
|
||||
|
||||
Manages permissions for agent actions on resources.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from ..exceptions import PermissionDeniedError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
PermissionLevel,
|
||||
ResourceType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PermissionGrant:
    """One permission held by an agent over resources matching a pattern."""

    def __init__(
        self,
        agent_id: str,
        resource_pattern: str,
        resource_type: ResourceType,
        level: PermissionLevel,
        *,
        expires_at: datetime | None = None,
        granted_by: str | None = None,
        reason: str | None = None,
    ) -> None:
        """
        Store the grant details and stamp it with a fresh id and creation time.

        Args:
            agent_id: ID of the agent holding the grant
            resource_pattern: fnmatch-style pattern of covered resources
            resource_type: Type of resource the grant applies to
            level: Permission level conferred
            expires_at: Moment after which the grant is expired; ``None``
                means it never expires
            granted_by: Who issued the grant
            reason: Why the grant was issued
        """
        # Identity and bookkeeping generated here rather than supplied.
        self.id = str(uuid4())
        self.created_at = datetime.utcnow()

        # Who may do what, on which resources.
        self.agent_id = agent_id
        self.resource_pattern = resource_pattern
        self.resource_type = resource_type
        self.level = level

        # Optional metadata.
        self.expires_at = expires_at
        self.granted_by = granted_by
        self.reason = reason

    def is_expired(self) -> bool:
        """Return True once the current UTC time is past ``expires_at``."""
        return self.expires_at is not None and datetime.utcnow() > self.expires_at

    def matches(self, resource: str, resource_type: ResourceType) -> bool:
        """Return True when the type is equal and the resource fits the pattern."""
        return self.resource_type == resource_type and fnmatch.fnmatch(
            resource, self.resource_pattern
        )

    def allows(self, required_level: PermissionLevel) -> bool:
        """Return True when this grant's level ranks at or above ``required_level``."""
        # Levels listed from weakest to strongest; the index is the rank.
        ascending = (
            PermissionLevel.NONE,
            PermissionLevel.READ,
            PermissionLevel.WRITE,
            PermissionLevel.EXECUTE,
            PermissionLevel.DELETE,
            PermissionLevel.ADMIN,
        )
        rank = {lvl: position for position, lvl in enumerate(ascending)}

        return rank[self.level] >= rank[required_level]
|
||||
|
||||
|
||||
class PermissionManager:
    """
    Manages permissions for agent access to resources.

    Grants live in an in-memory list guarded by an ``asyncio.Lock``.
    A check succeeds when an unexpired grant for the agent matches the
    resource and carries a sufficient level. When no grant applies, the
    per-resource-type defaults are consulted, but only if the manager was
    created with ``default_deny=False``.
    """

    def __init__(
        self,
        default_deny: bool = True,
    ) -> None:
        """
        Initialize the PermissionManager.

        Args:
            default_deny: If True, deny access unless explicitly granted
        """
        self._grants: list[PermissionGrant] = []
        self._default_deny = default_deny
        self._lock = asyncio.Lock()

        # Fallback levels per resource type; check() only reads these when
        # default_deny is False.
        self._default_permissions: dict[ResourceType, PermissionLevel] = {
            ResourceType.FILE: PermissionLevel.READ,
            ResourceType.DATABASE: PermissionLevel.READ,
            ResourceType.API: PermissionLevel.READ,
            ResourceType.GIT: PermissionLevel.READ,
            ResourceType.LLM: PermissionLevel.EXECUTE,
            ResourceType.SHELL: PermissionLevel.NONE,
            ResourceType.NETWORK: PermissionLevel.READ,
        }

    async def grant(
        self,
        agent_id: str,
        resource_pattern: str,
        resource_type: ResourceType,
        level: PermissionLevel,
        *,
        duration_seconds: int | None = None,
        granted_by: str | None = None,
        reason: str | None = None,
    ) -> PermissionGrant:
        """
        Grant a permission to an agent.

        Args:
            agent_id: ID of the agent
            resource_pattern: Pattern for matching resources (supports wildcards)
            resource_type: Type of resource
            level: Permission level to grant
            duration_seconds: Lifetime of a temporary permission in seconds;
                ``None`` creates a grant without expiration. ``0`` (or a
                negative value) creates a grant that is already expired.
            granted_by: Who granted the permission
            reason: Reason for granting

        Returns:
            The created permission grant
        """
        expires_at = None
        # Compare against None, not truthiness: a duration of 0 must produce
        # an expiring grant rather than silently becoming a permanent one.
        if duration_seconds is not None:
            expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)

        grant = PermissionGrant(
            agent_id=agent_id,
            resource_pattern=resource_pattern,
            resource_type=resource_type,
            level=level,
            expires_at=expires_at,
            granted_by=granted_by,
            reason=reason,
        )

        async with self._lock:
            self._grants.append(grant)

        logger.info(
            "Permission granted: agent=%s, resource=%s, type=%s, level=%s",
            agent_id,
            resource_pattern,
            resource_type.value,
            level.value,
        )

        return grant

    async def revoke(self, grant_id: str) -> bool:
        """
        Revoke a permission grant.

        Args:
            grant_id: ID of the grant to revoke

        Returns:
            True if grant was found and revoked
        """
        async with self._lock:
            for i, grant in enumerate(self._grants):
                if grant.id == grant_id:
                    # Safe to delete mid-iteration because we return at once.
                    del self._grants[i]
                    logger.info("Permission revoked: %s", grant_id)
                    return True
            return False

    async def revoke_all(self, agent_id: str) -> int:
        """
        Revoke all permissions for an agent.

        Args:
            agent_id: ID of the agent

        Returns:
            Number of grants revoked
        """
        async with self._lock:
            original_count = len(self._grants)
            self._grants = [g for g in self._grants if g.agent_id != agent_id]
            revoked = original_count - len(self._grants)

            if revoked:
                logger.info("Revoked %d permissions for agent %s", revoked, agent_id)

            return revoked

    async def check(
        self,
        agent_id: str,
        resource: str,
        resource_type: ResourceType,
        required_level: PermissionLevel,
    ) -> bool:
        """
        Check if an agent has permission to access a resource.

        Expired grants are purged first. Explicit grants are consulted
        before the per-type defaults, and the defaults are skipped entirely
        when the manager is in default-deny mode.

        Args:
            agent_id: ID of the agent
            resource: Resource to access
            resource_type: Type of resource
            required_level: Required permission level

        Returns:
            True if access is allowed
        """
        # Clean up expired grants
        await self._cleanup_expired()

        async with self._lock:
            for grant in self._grants:
                if grant.agent_id != agent_id:
                    continue

                # A grant may have expired since the cleanup above.
                if grant.is_expired():
                    continue

                if grant.matches(resource, resource_type):
                    if grant.allows(required_level):
                        return True

            # Check default permissions
            if not self._default_deny:
                default_level = self._default_permissions.get(
                    resource_type, PermissionLevel.NONE
                )
                # Same ranking that PermissionGrant.allows() uses.
                hierarchy = {
                    PermissionLevel.NONE: 0,
                    PermissionLevel.READ: 1,
                    PermissionLevel.WRITE: 2,
                    PermissionLevel.EXECUTE: 3,
                    PermissionLevel.DELETE: 4,
                    PermissionLevel.ADMIN: 5,
                }
                if hierarchy[default_level] >= hierarchy[required_level]:
                    return True

            return False

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is permitted.

        The action type is translated into a required permission level and
        a resource type, then delegated to :meth:`check`.

        Args:
            action: The action to check

        Returns:
            True if action is allowed
        """
        # Determine required permission level from action type
        level_map = {
            ActionType.FILE_READ: PermissionLevel.READ,
            ActionType.FILE_WRITE: PermissionLevel.WRITE,
            ActionType.FILE_DELETE: PermissionLevel.DELETE,
            ActionType.DATABASE_QUERY: PermissionLevel.READ,
            ActionType.DATABASE_MUTATE: PermissionLevel.WRITE,
            ActionType.SHELL_COMMAND: PermissionLevel.EXECUTE,
            ActionType.API_CALL: PermissionLevel.EXECUTE,
            ActionType.GIT_OPERATION: PermissionLevel.WRITE,
            ActionType.LLM_CALL: PermissionLevel.EXECUTE,
            ActionType.NETWORK_REQUEST: PermissionLevel.READ,
            ActionType.TOOL_CALL: PermissionLevel.EXECUTE,
        }

        # Unmapped action types require EXECUTE.
        required_level = level_map.get(action.action_type, PermissionLevel.EXECUTE)

        # Determine resource type from action
        resource_type_map = {
            ActionType.FILE_READ: ResourceType.FILE,
            ActionType.FILE_WRITE: ResourceType.FILE,
            ActionType.FILE_DELETE: ResourceType.FILE,
            ActionType.DATABASE_QUERY: ResourceType.DATABASE,
            ActionType.DATABASE_MUTATE: ResourceType.DATABASE,
            ActionType.SHELL_COMMAND: ResourceType.SHELL,
            ActionType.API_CALL: ResourceType.API,
            ActionType.GIT_OPERATION: ResourceType.GIT,
            ActionType.LLM_CALL: ResourceType.LLM,
            ActionType.NETWORK_REQUEST: ResourceType.NETWORK,
        }

        # Unmapped action types (TOOL_CALL among them) use the CUSTOM type.
        resource_type = resource_type_map.get(action.action_type, ResourceType.CUSTOM)
        # Fall back to the tool name, then to "*", when no resource is set.
        resource = action.resource or action.tool_name or "*"

        return await self.check(
            agent_id=action.metadata.agent_id,
            resource=resource,
            resource_type=resource_type,
            required_level=required_level,
        )

    async def require_permission(
        self,
        agent_id: str,
        resource: str,
        resource_type: ResourceType,
        required_level: PermissionLevel,
    ) -> None:
        """
        Require permission or raise exception.

        Args:
            agent_id: ID of the agent
            resource: Resource to access
            resource_type: Type of resource
            required_level: Required permission level

        Raises:
            PermissionDeniedError: If permission is denied
        """
        if not await self.check(agent_id, resource, resource_type, required_level):
            raise PermissionDeniedError(
                f"Permission denied: {resource}",
                action_type=None,
                resource=resource,
                required_permission=required_level.value,
                agent_id=agent_id,
            )

    async def list_grants(
        self,
        agent_id: str | None = None,
        resource_type: ResourceType | None = None,
    ) -> list[PermissionGrant]:
        """
        List permission grants.

        Expired grants are purged before listing. A falsy filter value is
        treated as "no filter".

        Args:
            agent_id: Optional filter by agent
            resource_type: Optional filter by resource type

        Returns:
            List of matching grants
        """
        await self._cleanup_expired()

        async with self._lock:
            # Copy so callers never hold the internal list.
            grants = list(self._grants)

        if agent_id:
            grants = [g for g in grants if g.agent_id == agent_id]

        if resource_type:
            grants = [g for g in grants if g.resource_type == resource_type]

        return grants

    def set_default_permission(
        self,
        resource_type: ResourceType,
        level: PermissionLevel,
    ) -> None:
        """
        Set the default permission level for a resource type.

        Args:
            resource_type: Type of resource
            level: Default permission level
        """
        self._default_permissions[resource_type] = level

    async def _cleanup_expired(self) -> None:
        """Remove expired grants."""
        async with self._lock:
            original_count = len(self._grants)
            self._grants = [g for g in self._grants if not g.is_expired()]
            removed = original_count - len(self._grants)

            if removed:
                logger.debug("Cleaned up %d expired permission grants", removed)
|
||||
@@ -1 +1,21 @@
|
||||
"""${dir} module."""
|
||||
"""
|
||||
Action Validation Module
|
||||
|
||||
Pre-execution validation with rule engine.
|
||||
"""
|
||||
|
||||
from .validator import (
|
||||
ActionValidator,
|
||||
ValidationCache,
|
||||
create_allow_rule,
|
||||
create_approval_rule,
|
||||
create_deny_rule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionValidator",
|
||||
"ValidationCache",
|
||||
"create_allow_rule",
|
||||
"create_approval_rule",
|
||||
"create_deny_rule",
|
||||
]
|
||||
|
||||
439
backend/app/services/safety/validation/validator.py
Normal file
439
backend/app/services/safety/validation/validator.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Action Validator
|
||||
|
||||
Pre-execution validation with rule engine for action requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationCache:
    """
    LRU cache for validation results with a per-entry time-to-live.

    Entries are kept in an ``OrderedDict`` ordered from least to most
    recently used. A ``max_size`` of zero or less disables storage.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 60) -> None:
        """
        Create an empty cache.

        Args:
            max_size: Maximum number of entries kept; ``<= 0`` stores nothing
            ttl_seconds: Age in seconds after which an entry is discarded
        """
        # key -> (result, monotonic timestamp of when it was stored)
        self._cache: OrderedDict[str, tuple[ValidationResult, float]] = OrderedDict()
        self._max_size = max_size
        self._ttl = ttl_seconds
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> ValidationResult | None:
        """
        Get cached validation result.

        Args:
            key: Cache key

        Returns:
            The cached result, or ``None`` when the key is absent or its
            entry is older than the TTL (the stale entry is removed)
        """
        import time

        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None

            result, stored_at = entry
            # Monotonic time keeps entry age immune to wall-clock changes.
            if time.monotonic() - stored_at > self._ttl:
                del self._cache[key]
                return None

            # Move to end (LRU)
            self._cache.move_to_end(key)
            return result

    async def set(self, key: str, result: ValidationResult) -> None:
        """
        Cache a validation result.

        Storing an existing key refreshes both its value and its timestamp.
        When the cache is full the least recently used entry is evicted.

        Args:
            key: Cache key
            result: Result to store
        """
        import time

        # A non-positive capacity means "cache nothing"; without this guard
        # popitem() would be called on an empty dict and raise KeyError.
        if self._max_size <= 0:
            return

        async with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
            else:
                while len(self._cache) >= self._max_size:
                    self._cache.popitem(last=False)
            self._cache[key] = (result, time.monotonic())

    async def clear(self) -> None:
        """Clear the cache."""
        async with self._lock:
            self._cache.clear()
|
||||
|
||||
|
||||
class ActionValidator:
    """
    Validates actions against safety rules before execution.

    Features:
    - Rule-based validation engine
    - Allow/deny/require-approval rules
    - Pattern matching for tools and resources
    - Validation result caching
    - Bypass capability for emergencies
    """

    def __init__(
        self,
        cache_enabled: bool = True,
        cache_size: int = 1000,
        cache_ttl: int = 60,
    ) -> None:
        """
        Initialize the ActionValidator.

        Args:
            cache_enabled: Whether to cache validation results
            cache_size: Maximum cache entries
            cache_ttl: Cache TTL in seconds
        """
        self._rules: list[ValidationRule] = []
        self._cache_enabled = cache_enabled
        self._cache = ValidationCache(max_size=cache_size, ttl_seconds=cache_ttl)
        self._bypass_enabled = False
        self._bypass_reason: str | None = None

        # The configured cache settings are only recorded here; the cache
        # above is sized from the constructor arguments.
        config = get_safety_config()
        self._cache_ttl = config.validation_cache_ttl
        self._cache_size = config.validation_cache_size

    def add_rule(self, rule: ValidationRule) -> None:
        """
        Add a validation rule.

        Args:
            rule: The rule to add
        """
        self._rules.append(rule)
        # Re-sort by priority (higher first)
        self._rules.sort(key=lambda r: r.priority, reverse=True)
        logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)

    def remove_rule(self, rule_id: str) -> bool:
        """
        Remove a validation rule by ID.

        Args:
            rule_id: ID of the rule to remove

        Returns:
            True if rule was found and removed
        """
        for i, rule in enumerate(self._rules):
            if rule.id == rule_id:
                # Safe to delete mid-iteration because we return at once.
                del self._rules[i]
                logger.debug("Removed validation rule: %s", rule_id)
                return True
        return False

    def clear_rules(self) -> None:
        """Remove all validation rules."""
        self._rules.clear()

    def load_rules_from_policy(self, policy: SafetyPolicy) -> None:
        """
        Load validation rules from a safety policy.

        Existing rules are discarded. In addition to the policy's explicit
        rules, one DENY rule (priority 100) is created per denied tool
        pattern and one REQUIRE_APPROVAL rule (priority 50) per
        approval-required pattern.

        Args:
            policy: The policy to load rules from
        """
        # Clear existing rules
        self.clear_rules()

        # Add rules from policy
        for rule in policy.validation_rules:
            self.add_rule(rule)

        # Create implicit rules from policy settings

        # Denied tools
        for i, pattern in enumerate(policy.denied_tools):
            self.add_rule(
                ValidationRule(
                    name=f"deny_tool_{i}",
                    description=f"Deny tool pattern: {pattern}",
                    priority=100,  # High priority for denials
                    tool_patterns=[pattern],
                    decision=SafetyDecision.DENY,
                    reason=f"Tool matches denied pattern: {pattern}",
                )
            )

        # Require approval patterns
        for i, pattern in enumerate(policy.require_approval_for):
            if pattern == "*":
                # All actions require approval: match on every action type
                # rather than on a tool pattern.
                self.add_rule(
                    ValidationRule(
                        name="require_approval_all",
                        description="All actions require approval",
                        priority=50,
                        action_types=list(ActionType),
                        decision=SafetyDecision.REQUIRE_APPROVAL,
                        reason="All actions require human approval",
                    )
                )
            else:
                self.add_rule(
                    ValidationRule(
                        name=f"require_approval_{i}",
                        description=f"Require approval for: {pattern}",
                        priority=50,
                        tool_patterns=[pattern],
                        decision=SafetyDecision.REQUIRE_APPROVAL,
                        reason=f"Action matches approval-required pattern: {pattern}",
                    )
                )

        logger.info("Loaded %d rules from policy: %s", len(self._rules), policy.name)

    async def validate(
        self,
        action: ActionRequest,
        policy: SafetyPolicy | None = None,
    ) -> ValidationResult:
        """
        Validate an action against all rules.

        Order of evaluation: active bypass (always ALLOW, never cached),
        cached result, then the enabled rules in priority order. The first
        matching DENY rule ends evaluation; a matching REQUIRE_APPROVAL
        rule upgrades an ALLOW. With no matching rule the action is allowed.

        Args:
            action: The action to validate
            policy: Policy whose rules are loaded, but only when this
                validator currently has no rules

        Returns:
            ValidationResult with decision and details; its ``action_id``
            always refers to ``action``, including on a cache hit
        """
        # Check bypass
        if self._bypass_enabled:
            logger.warning(
                "Validation bypass active: %s - allowing action %s",
                self._bypass_reason,
                action.id,
            )
            return ValidationResult(
                action_id=action.id,
                decision=SafetyDecision.ALLOW,
                applied_rules=[],
                reasons=[f"Validation bypassed: {self._bypass_reason}"],
            )

        # Check cache
        cache_key: str | None = None
        if self._cache_enabled:
            cache_key = self._get_cache_key(action)
            cached = await self._cache.get(cache_key)
            if cached is not None:
                logger.debug("Using cached validation for action %s", action.id)
                if cached.action_id == action.id:
                    return cached
                # The cache key ignores the action id, so the stored result
                # belongs to an earlier action. Re-issue the decision under
                # the current id instead of returning the other action's id.
                # NOTE(review): assumes these five fields are all that must
                # carry over from a cached ValidationResult - confirm.
                return ValidationResult(
                    action_id=action.id,
                    decision=cached.decision,
                    applied_rules=list(cached.applied_rules),
                    reasons=list(cached.reasons),
                    approval_id=cached.approval_id,
                )

        # Load rules from policy if provided
        if policy and not self._rules:
            self.load_rules_from_policy(policy)

        # Validate against rules
        applied_rules: list[str] = []
        reasons: list[str] = []
        final_decision = SafetyDecision.ALLOW
        approval_id: str | None = None

        for rule in self._rules:
            if not rule.enabled:
                continue

            if not self._rule_matches(rule, action):
                continue

            applied_rules.append(rule.id)
            if rule.reason:
                reasons.append(rule.reason)

            if rule.decision == SafetyDecision.DENY:
                # Deny takes precedence and stops evaluation.
                final_decision = SafetyDecision.DENY
                break

            if rule.decision == SafetyDecision.REQUIRE_APPROVAL:
                # Upgrade to require approval
                final_decision = SafetyDecision.REQUIRE_APPROVAL

        # If no rules matched and no explicit allow, default to allow
        if not applied_rules:
            reasons.append("No matching rules - default allow")

        result = ValidationResult(
            action_id=action.id,
            decision=final_decision,
            applied_rules=applied_rules,
            reasons=reasons,
            approval_id=approval_id,
        )

        # Cache result
        if cache_key is not None:
            await self._cache.set(cache_key, result)

        return result

    async def validate_batch(
        self,
        actions: list[ActionRequest],
        policy: SafetyPolicy | None = None,
    ) -> list[ValidationResult]:
        """
        Validate multiple actions.

        Args:
            actions: Actions to validate
            policy: Optional policy, passed through to :meth:`validate`

        Returns:
            List of validation results, in the order of ``actions``
        """
        tasks = [self.validate(action, policy) for action in actions]
        return await asyncio.gather(*tasks)

    def enable_bypass(self, reason: str) -> None:
        """
        Enable validation bypass (emergency use only).

        Args:
            reason: Reason for enabling bypass
        """
        logger.critical("Validation bypass enabled: %s", reason)
        self._bypass_enabled = True
        self._bypass_reason = reason

    def disable_bypass(self) -> None:
        """Disable validation bypass."""
        logger.info("Validation bypass disabled")
        self._bypass_enabled = False
        self._bypass_reason = None

    async def clear_cache(self) -> None:
        """Clear the validation cache."""
        await self._cache.clear()

    def _rule_matches(self, rule: ValidationRule, action: ActionRequest) -> bool:
        """
        Check if a rule matches an action.

        Each criterion the rule defines (action types, tool patterns,
        resource patterns, agent IDs) must be satisfied; criteria the rule
        leaves empty are ignored, so a rule without criteria matches
        every action.
        """
        # Check action types
        if rule.action_types and action.action_type not in rule.action_types:
            return False

        # Check tool patterns: the action needs a tool name that fits one.
        if rule.tool_patterns:
            if not action.tool_name:
                return False
            if not any(
                self._matches_pattern(action.tool_name, pattern)
                for pattern in rule.tool_patterns
            ):
                return False

        # Check resource patterns: the action needs a resource that fits one.
        if rule.resource_patterns:
            if not action.resource:
                return False
            if not any(
                self._matches_pattern(action.resource, pattern)
                for pattern in rule.resource_patterns
            ):
                return False

        # Check agent IDs
        if rule.agent_ids and action.metadata.agent_id not in rule.agent_ids:
            return False

        return True

    def _matches_pattern(self, value: str, pattern: str) -> bool:
        """Check if value matches a pattern (supports wildcards)."""
        if pattern == "*":
            return True

        # Use fnmatch for glob-style matching
        return fnmatch.fnmatch(value, pattern)

    def _get_cache_key(self, action: ActionRequest) -> str:
        """
        Generate a cache key for an action.

        The key joins action type, tool name, resource, agent id and
        autonomy level with ":"; the action id is not part of it.
        """
        # Key based on action characteristics that affect validation
        key_parts = [
            action.action_type.value,
            action.tool_name or "",
            action.resource or "",
            action.metadata.agent_id,
            action.metadata.autonomy_level.value,
        ]
        return ":".join(key_parts)
|
||||
|
||||
|
||||
# Module-level convenience functions
|
||||
|
||||
|
||||
def create_allow_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    priority: int = 0,
) -> ValidationRule:
    """Build a ValidationRule whose decision is ALLOW (default priority 0)."""
    rule_fields = {
        "name": name,
        "tool_patterns": tool_patterns,
        "resource_patterns": resource_patterns,
        "action_types": action_types,
        "decision": SafetyDecision.ALLOW,
        "priority": priority,
    }
    return ValidationRule(**rule_fields)
|
||||
|
||||
|
||||
def create_deny_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    reason: str | None = None,
    priority: int = 100,
) -> ValidationRule:
    """Build a ValidationRule whose decision is DENY (default priority 100)."""
    rule_fields = {
        "name": name,
        "tool_patterns": tool_patterns,
        "resource_patterns": resource_patterns,
        "action_types": action_types,
        "decision": SafetyDecision.DENY,
        "reason": reason,
        "priority": priority,
    }
    return ValidationRule(**rule_fields)
|
||||
|
||||
|
||||
def create_approval_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    reason: str | None = None,
    priority: int = 50,
) -> ValidationRule:
    """Build a ValidationRule whose decision is REQUIRE_APPROVAL (default priority 50)."""
    rule_fields = {
        "name": name,
        "tool_patterns": tool_patterns,
        "resource_patterns": resource_patterns,
        "action_types": action_types,
        "decision": SafetyDecision.REQUIRE_APPROVAL,
        "reason": reason,
        "priority": priority,
    }
    return ValidationRule(**rule_fields)
|
||||
Reference in New Issue
Block a user