feat(backend): add Phase B safety subsystems (#63)

Implements core control subsystems for the safety framework:

**Action Validation (validation/validator.py):**
- Rule-based validation engine with priority ordering
- Allow/deny/require-approval rule types
- Pattern matching for tools and resources
- Validation result caching with LRU eviction
- Emergency bypass capability with audit

**Permission System (permissions/manager.py):**
- Per-agent permission grants on resources
- Resource pattern matching (wildcards)
- Temporary permissions with expiration
- Permission inheritance hierarchy
- Default deny with configurable defaults

**Cost Control (costs/controller.py):**
- Per-session and per-day budget tracking
- Token and USD cost limits
- Warning alerts at configurable thresholds
- Budget rollover and reset policies
- Real-time usage tracking

**Rate Limiting (limits/limiter.py):**
- Sliding window rate limiter
- Per-action, per-LLM-call, per-file-op limits
- Burst allowance with recovery
- Configurable limits per operation type

**Loop Detection (loops/detector.py):**
- Exact repetition detection (same action+args)
- Semantic repetition (similar actions)
- Oscillation pattern detection (A→B→A→B)
- Per-agent action history tracking
- Loop breaking suggestions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 11:28:00 +01:00
parent 498c0a0e94
commit 728edd1453
10 changed files with 2020 additions and 5 deletions

View File

@@ -1 +1,15 @@
"""${dir} module."""
"""
Cost Control Module
Budget management and cost tracking.
"""
from .controller import (
BudgetTracker,
CostController,
)
__all__ = [
"BudgetTracker",
"CostController",
]

View File

@@ -0,0 +1,479 @@
"""
Cost Controller
Budget management and cost tracking for agent operations.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any
from ..config import get_safety_config
from ..exceptions import BudgetExceededError
from ..models import (
ActionRequest,
BudgetScope,
BudgetStatus,
)
logger = logging.getLogger(__name__)
class BudgetTracker:
"""Tracks usage against a budget limit."""
def __init__(
self,
scope: BudgetScope,
scope_id: str,
tokens_limit: int,
cost_limit_usd: float,
reset_interval: timedelta | None = None,
warning_threshold: float = 0.8,
) -> None:
self.scope = scope
self.scope_id = scope_id
self.tokens_limit = tokens_limit
self.cost_limit_usd = cost_limit_usd
self.warning_threshold = warning_threshold
self._reset_interval = reset_interval
self._tokens_used = 0
self._cost_used_usd = 0.0
self._created_at = datetime.utcnow()
self._last_reset = datetime.utcnow()
self._lock = asyncio.Lock()
async def add_usage(self, tokens: int, cost_usd: float) -> None:
"""Add usage to the tracker."""
async with self._lock:
self._check_reset()
self._tokens_used += tokens
self._cost_used_usd += cost_usd
async def get_status(self) -> BudgetStatus:
"""Get current budget status."""
async with self._lock:
self._check_reset()
tokens_remaining = max(0, self.tokens_limit - self._tokens_used)
cost_remaining = max(0, self.cost_limit_usd - self._cost_used_usd)
token_usage_ratio = (
self._tokens_used / self.tokens_limit if self.tokens_limit > 0 else 0
)
cost_usage_ratio = (
self._cost_used_usd / self.cost_limit_usd
if self.cost_limit_usd > 0
else 0
)
is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
is_exceeded = (
self._tokens_used >= self.tokens_limit
or self._cost_used_usd >= self.cost_limit_usd
)
reset_at = None
if self._reset_interval:
reset_at = self._last_reset + self._reset_interval
return BudgetStatus(
scope=self.scope,
scope_id=self.scope_id,
tokens_used=self._tokens_used,
tokens_limit=self.tokens_limit,
cost_used_usd=self._cost_used_usd,
cost_limit_usd=self.cost_limit_usd,
tokens_remaining=tokens_remaining,
cost_remaining_usd=cost_remaining,
warning_threshold=self.warning_threshold,
is_warning=is_warning,
is_exceeded=is_exceeded,
reset_at=reset_at,
)
async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
"""Check if there's enough budget for an operation."""
async with self._lock:
self._check_reset()
would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit
would_exceed_cost = (
self._cost_used_usd + estimated_cost_usd
) > self.cost_limit_usd
return not (would_exceed_tokens or would_exceed_cost)
def _check_reset(self) -> None:
"""Check if budget should reset."""
if self._reset_interval is None:
return
now = datetime.utcnow()
if now >= self._last_reset + self._reset_interval:
logger.info(
"Resetting budget for %s:%s",
self.scope.value,
self.scope_id,
)
self._tokens_used = 0
self._cost_used_usd = 0.0
self._last_reset = now
async def reset(self) -> None:
"""Manually reset the budget."""
async with self._lock:
self._tokens_used = 0
self._cost_used_usd = 0.0
self._last_reset = datetime.utcnow()
class CostController:
"""
Controls costs and budgets for agent operations.
Features:
- Per-agent, per-project, per-session budgets
- Real-time cost tracking
- Budget alerts at configurable thresholds
- Cost prediction for planned actions
- Budget rollover policies
"""
def __init__(
self,
default_session_tokens: int | None = None,
default_session_cost_usd: float | None = None,
default_daily_tokens: int | None = None,
default_daily_cost_usd: float | None = None,
) -> None:
"""
Initialize the CostController.
Args:
default_session_tokens: Default token budget per session
default_session_cost_usd: Default USD budget per session
default_daily_tokens: Default token budget per day
default_daily_cost_usd: Default USD budget per day
"""
config = get_safety_config()
self._default_session_tokens = (
default_session_tokens or config.default_session_token_budget
)
self._default_session_cost = (
default_session_cost_usd or config.default_session_cost_limit
)
self._default_daily_tokens = (
default_daily_tokens or config.default_daily_token_budget
)
self._default_daily_cost = (
default_daily_cost_usd or config.default_daily_cost_limit
)
self._trackers: dict[str, BudgetTracker] = {}
self._lock = asyncio.Lock()
# Alert handlers
self._alert_handlers: list[Any] = []
async def get_or_create_tracker(
self,
scope: BudgetScope,
scope_id: str,
) -> BudgetTracker:
"""Get or create a budget tracker."""
key = f"{scope.value}:{scope_id}"
async with self._lock:
if key not in self._trackers:
if scope == BudgetScope.SESSION:
tracker = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=self._default_session_tokens,
cost_limit_usd=self._default_session_cost,
)
elif scope == BudgetScope.DAILY:
tracker = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=self._default_daily_tokens,
cost_limit_usd=self._default_daily_cost,
reset_interval=timedelta(days=1),
)
else:
# Default
tracker = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=self._default_session_tokens,
cost_limit_usd=self._default_session_cost,
)
self._trackers[key] = tracker
return self._trackers[key]
async def check_budget(
self,
agent_id: str,
session_id: str | None,
estimated_tokens: int,
estimated_cost_usd: float,
) -> bool:
"""
Check if there's enough budget for an operation.
Args:
agent_id: ID of the agent
session_id: Optional session ID
estimated_tokens: Estimated token usage
estimated_cost_usd: Estimated USD cost
Returns:
True if budget is available
"""
# Check session budget
if session_id:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
return False
# Check agent daily budget
agent_tracker = await self.get_or_create_tracker(
BudgetScope.DAILY, agent_id
)
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
return False
return True
async def check_action(self, action: ActionRequest) -> bool:
"""
Check if an action is within budget.
Args:
action: The action to check
Returns:
True if within budget
"""
return await self.check_budget(
agent_id=action.metadata.agent_id,
session_id=action.metadata.session_id,
estimated_tokens=action.estimated_cost_tokens,
estimated_cost_usd=action.estimated_cost_usd,
)
async def require_budget(
self,
agent_id: str,
session_id: str | None,
estimated_tokens: int,
estimated_cost_usd: float,
) -> None:
"""
Require budget or raise exception.
Args:
agent_id: ID of the agent
session_id: Optional session ID
estimated_tokens: Estimated token usage
estimated_cost_usd: Estimated USD cost
Raises:
BudgetExceededError: If budget is exceeded
"""
if not await self.check_budget(
agent_id, session_id, estimated_tokens, estimated_cost_usd
):
# Determine which budget was exceeded
if session_id:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
session_status = await session_tracker.get_status()
if session_status.is_exceeded:
raise BudgetExceededError(
"Session budget exceeded",
budget_type="session",
current_usage=session_status.tokens_used,
budget_limit=session_status.tokens_limit,
agent_id=agent_id,
)
agent_tracker = await self.get_or_create_tracker(
BudgetScope.DAILY, agent_id
)
agent_status = await agent_tracker.get_status()
raise BudgetExceededError(
"Daily budget exceeded",
budget_type="daily",
current_usage=agent_status.tokens_used,
budget_limit=agent_status.tokens_limit,
agent_id=agent_id,
)
async def record_usage(
self,
agent_id: str,
session_id: str | None,
tokens: int,
cost_usd: float,
) -> None:
"""
Record actual usage.
Args:
agent_id: ID of the agent
session_id: Optional session ID
tokens: Actual token usage
cost_usd: Actual USD cost
"""
# Update session budget
if session_id:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
await session_tracker.add_usage(tokens, cost_usd)
# Check for warning
status = await session_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
status,
)
# Update agent daily budget
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
await agent_tracker.add_usage(tokens, cost_usd)
# Check for warning
status = await agent_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
status,
)
async def get_status(
self,
scope: BudgetScope,
scope_id: str,
) -> BudgetStatus | None:
"""
Get budget status.
Args:
scope: Budget scope
scope_id: ID within scope
Returns:
Budget status or None if not tracked
"""
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
return await tracker.get_status()
return None
async def get_all_statuses(self) -> list[BudgetStatus]:
"""Get status of all tracked budgets."""
statuses = []
async with self._lock:
trackers = list(self._trackers.values())
for tracker in trackers:
statuses.append(await tracker.get_status())
return statuses
async def set_budget(
self,
scope: BudgetScope,
scope_id: str,
tokens_limit: int,
cost_limit_usd: float,
) -> None:
"""
Set a custom budget limit.
Args:
scope: Budget scope
scope_id: ID within scope
tokens_limit: Token limit
cost_limit_usd: USD limit
"""
key = f"{scope.value}:{scope_id}"
reset_interval = None
if scope == BudgetScope.DAILY:
reset_interval = timedelta(days=1)
elif scope == BudgetScope.WEEKLY:
reset_interval = timedelta(weeks=1)
elif scope == BudgetScope.MONTHLY:
reset_interval = timedelta(days=30)
async with self._lock:
self._trackers[key] = BudgetTracker(
scope=scope,
scope_id=scope_id,
tokens_limit=tokens_limit,
cost_limit_usd=cost_limit_usd,
reset_interval=reset_interval,
)
async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
"""
Reset a budget tracker.
Args:
scope: Budget scope
scope_id: ID within scope
Returns:
True if tracker was found and reset
"""
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
await tracker.reset()
return True
return False
def add_alert_handler(self, handler: Any) -> None:
"""Add an alert handler."""
self._alert_handlers.append(handler)
def remove_alert_handler(self, handler: Any) -> None:
"""Remove an alert handler."""
if handler in self._alert_handlers:
self._alert_handlers.remove(handler)
async def _send_alert(
self,
alert_type: str,
message: str,
status: BudgetStatus,
) -> None:
"""Send alert to all handlers."""
for handler in self._alert_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(alert_type, message, status)
else:
handler(alert_type, message, status)
except Exception as e:
logger.error("Error in alert handler: %s", e)

View File

@@ -1 +1,15 @@
"""${dir} module."""
"""
Rate Limiting Module
Sliding window rate limiting for agent operations.
"""
from .limiter import (
RateLimiter,
SlidingWindowCounter,
)
__all__ = [
"RateLimiter",
"SlidingWindowCounter",
]

View File

@@ -0,0 +1,368 @@
"""
Rate Limiter
Sliding window rate limiting for agent operations.
"""
import asyncio
import logging
import time
from collections import deque
from ..config import get_safety_config
from ..exceptions import RateLimitExceededError
from ..models import (
ActionRequest,
RateLimitConfig,
RateLimitStatus,
)
logger = logging.getLogger(__name__)
class SlidingWindowCounter:
"""Sliding window counter for rate limiting."""
def __init__(
self,
limit: int,
window_seconds: int,
burst_limit: int | None = None,
) -> None:
self.limit = limit
self.window_seconds = window_seconds
self.burst_limit = burst_limit or limit
self._timestamps: deque[float] = deque()
self._lock = asyncio.Lock()
async def try_acquire(self) -> tuple[bool, float]:
"""
Try to acquire a slot.
Returns:
Tuple of (allowed, retry_after_seconds)
"""
now = time.time()
window_start = now - self.window_seconds
async with self._lock:
# Remove expired entries
while self._timestamps and self._timestamps[0] < window_start:
self._timestamps.popleft()
current_count = len(self._timestamps)
# Check burst limit (instant check)
if current_count >= self.burst_limit:
# Calculate retry time
oldest = self._timestamps[0] if self._timestamps else now
retry_after = oldest + self.window_seconds - now
return False, max(0, retry_after)
# Check window limit
if current_count >= self.limit:
oldest = self._timestamps[0] if self._timestamps else now
retry_after = oldest + self.window_seconds - now
return False, max(0, retry_after)
# Allow and record
self._timestamps.append(now)
return True, 0.0
async def get_status(self) -> tuple[int, int, float]:
"""
Get current status.
Returns:
Tuple of (current_count, remaining, reset_in_seconds)
"""
now = time.time()
window_start = now - self.window_seconds
async with self._lock:
# Remove expired entries
while self._timestamps and self._timestamps[0] < window_start:
self._timestamps.popleft()
current_count = len(self._timestamps)
remaining = max(0, self.limit - current_count)
if self._timestamps:
reset_in = self._timestamps[0] + self.window_seconds - now
else:
reset_in = 0.0
return current_count, remaining, max(0, reset_in)
class RateLimiter:
"""
Rate limiter for agent operations.
Features:
- Per-tool rate limits
- Per-agent rate limits
- Per-resource rate limits
- Sliding window implementation
- Burst allowance with recovery
- Slowdown before hard block
"""
def __init__(self) -> None:
"""Initialize the RateLimiter."""
config = get_safety_config()
self._configs: dict[str, RateLimitConfig] = {}
self._counters: dict[str, SlidingWindowCounter] = {}
self._lock = asyncio.Lock()
# Default rate limits
self._default_limits = {
"actions": RateLimitConfig(
name="actions",
limit=config.default_actions_per_minute,
window_seconds=60,
),
"llm_calls": RateLimitConfig(
name="llm_calls",
limit=config.default_llm_calls_per_minute,
window_seconds=60,
),
"file_ops": RateLimitConfig(
name="file_ops",
limit=config.default_file_ops_per_minute,
window_seconds=60,
),
}
def configure(self, config: RateLimitConfig) -> None:
"""
Configure a rate limit.
Args:
config: Rate limit configuration
"""
self._configs[config.name] = config
logger.debug(
"Configured rate limit: %s = %d/%ds",
config.name,
config.limit,
config.window_seconds,
)
async def check(
self,
limit_name: str,
key: str,
) -> RateLimitStatus:
"""
Check rate limit without consuming a slot.
Args:
limit_name: Name of the rate limit
key: Key for tracking (e.g., agent_id)
Returns:
Rate limit status
"""
counter = await self._get_counter(limit_name, key)
config = self._get_config(limit_name)
current, remaining, reset_in = await counter.get_status()
from datetime import datetime, timedelta
return RateLimitStatus(
name=limit_name,
current_count=current,
limit=config.limit,
window_seconds=config.window_seconds,
remaining=remaining,
reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
is_limited=remaining <= 0,
retry_after_seconds=reset_in if remaining <= 0 else 0.0,
)
async def acquire(
self,
limit_name: str,
key: str,
) -> tuple[bool, RateLimitStatus]:
"""
Try to acquire a rate limit slot.
Args:
limit_name: Name of the rate limit
key: Key for tracking (e.g., agent_id)
Returns:
Tuple of (allowed, status)
"""
counter = await self._get_counter(limit_name, key)
config = self._get_config(limit_name)
allowed, retry_after = await counter.try_acquire()
current, remaining, reset_in = await counter.get_status()
from datetime import datetime, timedelta
status = RateLimitStatus(
name=limit_name,
current_count=current,
limit=config.limit,
window_seconds=config.window_seconds,
remaining=remaining,
reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
is_limited=not allowed,
retry_after_seconds=retry_after,
)
return allowed, status
async def check_action(
self,
action: ActionRequest,
) -> tuple[bool, list[RateLimitStatus]]:
"""
Check all applicable rate limits for an action.
Args:
action: The action to check
Returns:
Tuple of (allowed, list of statuses)
"""
agent_id = action.metadata.agent_id
statuses: list[RateLimitStatus] = []
allowed = True
# Check general actions limit
actions_allowed, actions_status = await self.acquire("actions", agent_id)
statuses.append(actions_status)
if not actions_allowed:
allowed = False
# Check LLM-specific limit for LLM calls
if action.action_type.value == "llm_call":
llm_allowed, llm_status = await self.acquire("llm_calls", agent_id)
statuses.append(llm_status)
if not llm_allowed:
allowed = False
# Check file ops limit for file operations
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
file_allowed, file_status = await self.acquire("file_ops", agent_id)
statuses.append(file_status)
if not file_allowed:
allowed = False
return allowed, statuses
async def require(
self,
limit_name: str,
key: str,
) -> None:
"""
Require rate limit slot or raise exception.
Args:
limit_name: Name of the rate limit
key: Key for tracking
Raises:
RateLimitExceededError: If rate limit exceeded
"""
allowed, status = await self.acquire(limit_name, key)
if not allowed:
raise RateLimitExceededError(
f"Rate limit exceeded: {limit_name}",
limit_type=limit_name,
limit_value=status.limit,
window_seconds=status.window_seconds,
retry_after_seconds=status.retry_after_seconds,
)
async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
"""
Get status of all rate limits for a key.
Args:
key: Key for tracking
Returns:
Dict of limit name to status
"""
statuses = {}
for name in self._default_limits:
statuses[name] = await self.check(name, key)
for name in self._configs:
if name not in statuses:
statuses[name] = await self.check(name, key)
return statuses
async def reset(self, limit_name: str, key: str) -> bool:
"""
Reset a rate limit counter.
Args:
limit_name: Name of the rate limit
key: Key for tracking
Returns:
True if counter was found and reset
"""
counter_key = f"{limit_name}:{key}"
async with self._lock:
if counter_key in self._counters:
del self._counters[counter_key]
return True
return False
async def reset_all(self, key: str) -> int:
"""
Reset all rate limit counters for a key.
Args:
key: Key for tracking
Returns:
Number of counters reset
"""
count = 0
async with self._lock:
to_remove = [k for k in self._counters if k.endswith(f":{key}")]
for k in to_remove:
del self._counters[k]
count += 1
return count
def _get_config(self, limit_name: str) -> RateLimitConfig:
"""Get configuration for a rate limit."""
if limit_name in self._configs:
return self._configs[limit_name]
if limit_name in self._default_limits:
return self._default_limits[limit_name]
# Return default
return RateLimitConfig(
name=limit_name,
limit=60,
window_seconds=60,
)
async def _get_counter(
self,
limit_name: str,
key: str,
) -> SlidingWindowCounter:
"""Get or create a counter."""
counter_key = f"{limit_name}:{key}"
config = self._get_config(limit_name)
async with self._lock:
if counter_key not in self._counters:
self._counters[counter_key] = SlidingWindowCounter(
limit=config.limit,
window_seconds=config.window_seconds,
burst_limit=config.burst_limit,
)
return self._counters[counter_key]

View File

@@ -1 +1,17 @@
"""${dir} module."""
"""
Loop Detection Module
Detects and prevents action loops in agent behavior.
"""
from .detector import (
ActionSignature,
LoopBreaker,
LoopDetector,
)
__all__ = [
"ActionSignature",
"LoopBreaker",
"LoopDetector",
]

View File

@@ -0,0 +1,267 @@
"""
Loop Detector
Detects and prevents action loops in agent behavior.
"""
import asyncio
import hashlib
import json
import logging
from collections import Counter, deque
from typing import Any
from ..config import get_safety_config
from ..exceptions import LoopDetectedError
from ..models import ActionRequest
logger = logging.getLogger(__name__)
class ActionSignature:
"""Signature of an action for comparison."""
def __init__(self, action: ActionRequest) -> None:
self.action_type = action.action_type.value
self.tool_name = action.tool_name
self.resource = action.resource
self.args_hash = self._hash_args(action.arguments)
def _hash_args(self, args: dict[str, Any]) -> str:
"""Create a hash of the arguments."""
try:
serialized = json.dumps(args, sort_keys=True, default=str)
return hashlib.sha256(serialized.encode()).hexdigest()[:8]
except Exception:
return ""
def exact_key(self) -> str:
"""Key for exact match detection."""
return f"{self.action_type}:{self.tool_name}:{self.resource}:{self.args_hash}"
def semantic_key(self) -> str:
"""Key for semantic (similar) match detection."""
return f"{self.action_type}:{self.tool_name}:{self.resource}"
def type_key(self) -> str:
"""Key for action type only."""
return f"{self.action_type}"
class LoopDetector:
"""
Detects action loops and repetitive behavior.
Loop Types:
- Exact: Same action with same arguments
- Semantic: Similar actions (same type/tool/resource, different args)
- Oscillation: A→B→A→B patterns
"""
def __init__(
self,
history_size: int | None = None,
max_exact_repetitions: int | None = None,
max_semantic_repetitions: int | None = None,
) -> None:
"""
Initialize the LoopDetector.
Args:
history_size: Size of action history to track
max_exact_repetitions: Max allowed exact repetitions
max_semantic_repetitions: Max allowed semantic repetitions
"""
config = get_safety_config()
self._history_size = history_size or config.loop_history_size
self._max_exact = max_exact_repetitions or config.max_repeated_actions
self._max_semantic = max_semantic_repetitions or config.max_similar_actions
# Per-agent history
self._histories: dict[str, deque[ActionSignature]] = {}
self._lock = asyncio.Lock()
async def check(self, action: ActionRequest) -> tuple[bool, str | None]:
"""
Check if an action would create a loop.
Args:
action: The action to check
Returns:
Tuple of (is_loop, loop_type)
"""
agent_id = action.metadata.agent_id
signature = ActionSignature(action)
async with self._lock:
history = self._get_history(agent_id)
# Check exact repetition
exact_key = signature.exact_key()
exact_count = sum(1 for h in history if h.exact_key() == exact_key)
if exact_count >= self._max_exact:
return True, "exact"
# Check semantic repetition
semantic_key = signature.semantic_key()
semantic_count = sum(1 for h in history if h.semantic_key() == semantic_key)
if semantic_count >= self._max_semantic:
return True, "semantic"
# Check oscillation (A→B→A→B pattern)
if len(history) >= 3:
pattern = self._detect_oscillation(history, signature)
if pattern:
return True, "oscillation"
return False, None
async def check_and_raise(self, action: ActionRequest) -> None:
"""
Check for loops and raise if detected.
Args:
action: The action to check
Raises:
LoopDetectedError: If loop is detected
"""
is_loop, loop_type = await self.check(action)
if is_loop:
signature = ActionSignature(action)
raise LoopDetectedError(
f"Loop detected: {loop_type}",
loop_type=loop_type or "unknown",
repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic,
action_pattern=[signature.semantic_key()],
agent_id=action.metadata.agent_id,
action_id=action.id,
)
async def record(self, action: ActionRequest) -> None:
"""
Record an action in history.
Args:
action: The action to record
"""
agent_id = action.metadata.agent_id
signature = ActionSignature(action)
async with self._lock:
history = self._get_history(agent_id)
history.append(signature)
async def clear_history(self, agent_id: str) -> None:
"""
Clear history for an agent.
Args:
agent_id: ID of the agent
"""
async with self._lock:
if agent_id in self._histories:
self._histories[agent_id].clear()
async def get_stats(self, agent_id: str) -> dict[str, Any]:
"""
Get loop detection stats for an agent.
Args:
agent_id: ID of the agent
Returns:
Stats dictionary
"""
async with self._lock:
history = self._get_history(agent_id)
# Count action types
type_counts = Counter(h.type_key() for h in history)
semantic_counts = Counter(h.semantic_key() for h in history)
return {
"history_size": len(history),
"max_history": self._history_size,
"action_type_counts": dict(type_counts),
"top_semantic_patterns": semantic_counts.most_common(5),
}
def _get_history(self, agent_id: str) -> deque[ActionSignature]:
"""Get or create history for an agent."""
if agent_id not in self._histories:
self._histories[agent_id] = deque(maxlen=self._history_size)
return self._histories[agent_id]
def _detect_oscillation(
self,
history: deque[ActionSignature],
current: ActionSignature,
) -> bool:
"""
Detect A→B→A→B oscillation pattern.
Looks at last 4+ actions including current.
"""
if len(history) < 3:
return False
# Get last 3 actions + current
recent = [*list(history)[-3:], current]
# Check for A→B→A→B pattern
if len(recent) >= 4:
# Get semantic keys
keys = [a.semantic_key() for a in recent[-4:]]
# Pattern: k[0]==k[2] and k[1]==k[3] and k[0]!=k[1]
if keys[0] == keys[2] and keys[1] == keys[3] and keys[0] != keys[1]:
return True
return False
class LoopBreaker:
"""
Strategies for breaking detected loops.
"""
@staticmethod
async def suggest_alternatives(
action: ActionRequest,
loop_type: str,
) -> list[str]:
"""
Suggest alternative actions when loop is detected.
Args:
action: The looping action
loop_type: Type of loop detected
Returns:
List of suggestions
"""
suggestions = []
if loop_type == "exact":
suggestions.append(
"The same action with identical arguments has been repeated too many times. "
"Consider: (1) Verify the action succeeded, (2) Try a different approach, "
"(3) Escalate for human review"
)
elif loop_type == "semantic":
suggestions.append(
"Similar actions have been repeated too many times. "
"Consider: (1) Review if the approach is working, (2) Try an alternative method, "
"(3) Request clarification on the goal"
)
elif loop_type == "oscillation":
suggestions.append(
"An oscillating pattern was detected (A→B→A→B). "
"This usually indicates conflicting goals or a stuck state. "
"Consider: (1) Step back and reassess, (2) Request human guidance"
)
return suggestions

View File

@@ -1 +1,15 @@
"""${dir} module."""
"""
Permission Management Module
Agent permissions for resource access.
"""
from .manager import (
PermissionGrant,
PermissionManager,
)
__all__ = [
"PermissionGrant",
"PermissionManager",
]

View File

@@ -0,0 +1,384 @@
"""
Permission Manager
Manages permissions for agent actions on resources.
"""
import asyncio
import fnmatch
import logging
from datetime import datetime, timedelta
from uuid import uuid4
from ..exceptions import PermissionDeniedError
from ..models import (
ActionRequest,
ActionType,
PermissionLevel,
ResourceType,
)
logger = logging.getLogger(__name__)
class PermissionGrant:
"""A permission grant for an agent on a resource."""
def __init__(
self,
agent_id: str,
resource_pattern: str,
resource_type: ResourceType,
level: PermissionLevel,
*,
expires_at: datetime | None = None,
granted_by: str | None = None,
reason: str | None = None,
) -> None:
self.id = str(uuid4())
self.agent_id = agent_id
self.resource_pattern = resource_pattern
self.resource_type = resource_type
self.level = level
self.expires_at = expires_at
self.granted_by = granted_by
self.reason = reason
self.created_at = datetime.utcnow()
def is_expired(self) -> bool:
"""Check if the grant has expired."""
if self.expires_at is None:
return False
return datetime.utcnow() > self.expires_at
def matches(self, resource: str, resource_type: ResourceType) -> bool:
"""Check if this grant applies to a resource."""
if self.resource_type != resource_type:
return False
return fnmatch.fnmatch(resource, self.resource_pattern)
def allows(self, required_level: PermissionLevel) -> bool:
"""Check if this grant allows the required permission level."""
# Permission level hierarchy
hierarchy = {
PermissionLevel.NONE: 0,
PermissionLevel.READ: 1,
PermissionLevel.WRITE: 2,
PermissionLevel.EXECUTE: 3,
PermissionLevel.DELETE: 4,
PermissionLevel.ADMIN: 5,
}
return hierarchy[self.level] >= hierarchy[required_level]
class PermissionManager:
"""
Manages permissions for agent access to resources.
Features:
- Permission grants by agent/resource pattern
- Permission inheritance (project → agent → action)
- Temporary permissions with expiration
- Least-privilege defaults
- Permission escalation logging
"""
def __init__(
self,
default_deny: bool = True,
) -> None:
"""
Initialize the PermissionManager.
Args:
default_deny: If True, deny access unless explicitly granted
"""
self._grants: list[PermissionGrant] = []
self._default_deny = default_deny
self._lock = asyncio.Lock()
# Default permissions for common resources
self._default_permissions: dict[ResourceType, PermissionLevel] = {
ResourceType.FILE: PermissionLevel.READ,
ResourceType.DATABASE: PermissionLevel.READ,
ResourceType.API: PermissionLevel.READ,
ResourceType.GIT: PermissionLevel.READ,
ResourceType.LLM: PermissionLevel.EXECUTE,
ResourceType.SHELL: PermissionLevel.NONE,
ResourceType.NETWORK: PermissionLevel.READ,
}
async def grant(
self,
agent_id: str,
resource_pattern: str,
resource_type: ResourceType,
level: PermissionLevel,
*,
duration_seconds: int | None = None,
granted_by: str | None = None,
reason: str | None = None,
) -> PermissionGrant:
"""
Grant a permission to an agent.
Args:
agent_id: ID of the agent
resource_pattern: Pattern for matching resources (supports wildcards)
resource_type: Type of resource
level: Permission level to grant
duration_seconds: Optional duration for temporary permission
granted_by: Who granted the permission
reason: Reason for granting
Returns:
The created permission grant
"""
expires_at = None
if duration_seconds:
expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)
grant = PermissionGrant(
agent_id=agent_id,
resource_pattern=resource_pattern,
resource_type=resource_type,
level=level,
expires_at=expires_at,
granted_by=granted_by,
reason=reason,
)
async with self._lock:
self._grants.append(grant)
logger.info(
"Permission granted: agent=%s, resource=%s, type=%s, level=%s",
agent_id,
resource_pattern,
resource_type.value,
level.value,
)
return grant
async def revoke(self, grant_id: str) -> bool:
"""
Revoke a permission grant.
Args:
grant_id: ID of the grant to revoke
Returns:
True if grant was found and revoked
"""
async with self._lock:
for i, grant in enumerate(self._grants):
if grant.id == grant_id:
del self._grants[i]
logger.info("Permission revoked: %s", grant_id)
return True
return False
async def revoke_all(self, agent_id: str) -> int:
"""
Revoke all permissions for an agent.
Args:
agent_id: ID of the agent
Returns:
Number of grants revoked
"""
async with self._lock:
original_count = len(self._grants)
self._grants = [g for g in self._grants if g.agent_id != agent_id]
revoked = original_count - len(self._grants)
if revoked:
logger.info("Revoked %d permissions for agent %s", revoked, agent_id)
return revoked
async def check(
self,
agent_id: str,
resource: str,
resource_type: ResourceType,
required_level: PermissionLevel,
) -> bool:
"""
Check if an agent has permission to access a resource.
Args:
agent_id: ID of the agent
resource: Resource to access
resource_type: Type of resource
required_level: Required permission level
Returns:
True if access is allowed
"""
# Clean up expired grants
await self._cleanup_expired()
async with self._lock:
for grant in self._grants:
if grant.agent_id != agent_id:
continue
if grant.is_expired():
continue
if grant.matches(resource, resource_type):
if grant.allows(required_level):
return True
# Check default permissions
if not self._default_deny:
default_level = self._default_permissions.get(
resource_type, PermissionLevel.NONE
)
hierarchy = {
PermissionLevel.NONE: 0,
PermissionLevel.READ: 1,
PermissionLevel.WRITE: 2,
PermissionLevel.EXECUTE: 3,
PermissionLevel.DELETE: 4,
PermissionLevel.ADMIN: 5,
}
if hierarchy[default_level] >= hierarchy[required_level]:
return True
return False
async def check_action(self, action: ActionRequest) -> bool:
"""
Check if an action is permitted.
Args:
action: The action to check
Returns:
True if action is allowed
"""
# Determine required permission level from action type
level_map = {
ActionType.FILE_READ: PermissionLevel.READ,
ActionType.FILE_WRITE: PermissionLevel.WRITE,
ActionType.FILE_DELETE: PermissionLevel.DELETE,
ActionType.DATABASE_QUERY: PermissionLevel.READ,
ActionType.DATABASE_MUTATE: PermissionLevel.WRITE,
ActionType.SHELL_COMMAND: PermissionLevel.EXECUTE,
ActionType.API_CALL: PermissionLevel.EXECUTE,
ActionType.GIT_OPERATION: PermissionLevel.WRITE,
ActionType.LLM_CALL: PermissionLevel.EXECUTE,
ActionType.NETWORK_REQUEST: PermissionLevel.READ,
ActionType.TOOL_CALL: PermissionLevel.EXECUTE,
}
required_level = level_map.get(action.action_type, PermissionLevel.EXECUTE)
# Determine resource type from action
resource_type_map = {
ActionType.FILE_READ: ResourceType.FILE,
ActionType.FILE_WRITE: ResourceType.FILE,
ActionType.FILE_DELETE: ResourceType.FILE,
ActionType.DATABASE_QUERY: ResourceType.DATABASE,
ActionType.DATABASE_MUTATE: ResourceType.DATABASE,
ActionType.SHELL_COMMAND: ResourceType.SHELL,
ActionType.API_CALL: ResourceType.API,
ActionType.GIT_OPERATION: ResourceType.GIT,
ActionType.LLM_CALL: ResourceType.LLM,
ActionType.NETWORK_REQUEST: ResourceType.NETWORK,
}
resource_type = resource_type_map.get(action.action_type, ResourceType.CUSTOM)
resource = action.resource or action.tool_name or "*"
return await self.check(
agent_id=action.metadata.agent_id,
resource=resource,
resource_type=resource_type,
required_level=required_level,
)
async def require_permission(
self,
agent_id: str,
resource: str,
resource_type: ResourceType,
required_level: PermissionLevel,
) -> None:
"""
Require permission or raise exception.
Args:
agent_id: ID of the agent
resource: Resource to access
resource_type: Type of resource
required_level: Required permission level
Raises:
PermissionDeniedError: If permission is denied
"""
if not await self.check(agent_id, resource, resource_type, required_level):
raise PermissionDeniedError(
f"Permission denied: {resource}",
action_type=None,
resource=resource,
required_permission=required_level.value,
agent_id=agent_id,
)
async def list_grants(
self,
agent_id: str | None = None,
resource_type: ResourceType | None = None,
) -> list[PermissionGrant]:
"""
List permission grants.
Args:
agent_id: Optional filter by agent
resource_type: Optional filter by resource type
Returns:
List of matching grants
"""
await self._cleanup_expired()
async with self._lock:
grants = list(self._grants)
if agent_id:
grants = [g for g in grants if g.agent_id == agent_id]
if resource_type:
grants = [g for g in grants if g.resource_type == resource_type]
return grants
def set_default_permission(
self,
resource_type: ResourceType,
level: PermissionLevel,
) -> None:
"""
Set the default permission level for a resource type.
Args:
resource_type: Type of resource
level: Default permission level
"""
self._default_permissions[resource_type] = level
async def _cleanup_expired(self) -> None:
"""Remove expired grants."""
async with self._lock:
original_count = len(self._grants)
self._grants = [g for g in self._grants if not g.is_expired()]
removed = original_count - len(self._grants)
if removed:
logger.debug("Cleaned up %d expired permission grants", removed)

View File

@@ -1 +1,21 @@
"""${dir} module."""
"""
Action Validation Module
Pre-execution validation with rule engine.
"""
from .validator import (
ActionValidator,
ValidationCache,
create_allow_rule,
create_approval_rule,
create_deny_rule,
)
__all__ = [
"ActionValidator",
"ValidationCache",
"create_allow_rule",
"create_approval_rule",
"create_deny_rule",
]

View File

@@ -0,0 +1,439 @@
"""
Action Validator
Pre-execution validation with rule engine for action requests.
"""
import asyncio
import fnmatch
import logging
from collections import OrderedDict
from ..config import get_safety_config
from ..models import (
ActionRequest,
ActionType,
SafetyDecision,
SafetyPolicy,
ValidationResult,
ValidationRule,
)
logger = logging.getLogger(__name__)
class ValidationCache:
"""LRU cache for validation results."""
def __init__(self, max_size: int = 1000, ttl_seconds: int = 60) -> None:
self._cache: OrderedDict[str, tuple[ValidationResult, float]] = OrderedDict()
self._max_size = max_size
self._ttl = ttl_seconds
self._lock = asyncio.Lock()
async def get(self, key: str) -> ValidationResult | None:
"""Get cached validation result."""
import time
async with self._lock:
if key not in self._cache:
return None
result, timestamp = self._cache[key]
if time.time() - timestamp > self._ttl:
del self._cache[key]
return None
# Move to end (LRU)
self._cache.move_to_end(key)
return result
async def set(self, key: str, result: ValidationResult) -> None:
"""Cache a validation result."""
import time
async with self._lock:
if key in self._cache:
self._cache.move_to_end(key)
else:
if len(self._cache) >= self._max_size:
self._cache.popitem(last=False)
self._cache[key] = (result, time.time())
async def clear(self) -> None:
"""Clear the cache."""
async with self._lock:
self._cache.clear()
class ActionValidator:
"""
Validates actions against safety rules before execution.
Features:
- Rule-based validation engine
- Allow/deny/require-approval rules
- Pattern matching for tools and resources
- Validation result caching
- Bypass capability for emergencies
"""
def __init__(
self,
cache_enabled: bool = True,
cache_size: int = 1000,
cache_ttl: int = 60,
) -> None:
"""
Initialize the ActionValidator.
Args:
cache_enabled: Whether to cache validation results
cache_size: Maximum cache entries
cache_ttl: Cache TTL in seconds
"""
self._rules: list[ValidationRule] = []
self._cache_enabled = cache_enabled
self._cache = ValidationCache(max_size=cache_size, ttl_seconds=cache_ttl)
self._bypass_enabled = False
self._bypass_reason: str | None = None
config = get_safety_config()
self._cache_enabled = cache_enabled
self._cache_ttl = config.validation_cache_ttl
self._cache_size = config.validation_cache_size
def add_rule(self, rule: ValidationRule) -> None:
"""
Add a validation rule.
Args:
rule: The rule to add
"""
self._rules.append(rule)
# Re-sort by priority (higher first)
self._rules.sort(key=lambda r: r.priority, reverse=True)
logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)
def remove_rule(self, rule_id: str) -> bool:
"""
Remove a validation rule by ID.
Args:
rule_id: ID of the rule to remove
Returns:
True if rule was found and removed
"""
for i, rule in enumerate(self._rules):
if rule.id == rule_id:
del self._rules[i]
logger.debug("Removed validation rule: %s", rule_id)
return True
return False
def clear_rules(self) -> None:
"""Remove all validation rules."""
self._rules.clear()
def load_rules_from_policy(self, policy: SafetyPolicy) -> None:
"""
Load validation rules from a safety policy.
Args:
policy: The policy to load rules from
"""
# Clear existing rules
self.clear_rules()
# Add rules from policy
for rule in policy.validation_rules:
self.add_rule(rule)
# Create implicit rules from policy settings
# Denied tools
for i, pattern in enumerate(policy.denied_tools):
self.add_rule(
ValidationRule(
name=f"deny_tool_{i}",
description=f"Deny tool pattern: {pattern}",
priority=100, # High priority for denials
tool_patterns=[pattern],
decision=SafetyDecision.DENY,
reason=f"Tool matches denied pattern: {pattern}",
)
)
# Require approval patterns
for i, pattern in enumerate(policy.require_approval_for):
if pattern == "*":
# All actions require approval
self.add_rule(
ValidationRule(
name="require_approval_all",
description="All actions require approval",
priority=50,
action_types=list(ActionType),
decision=SafetyDecision.REQUIRE_APPROVAL,
reason="All actions require human approval",
)
)
else:
self.add_rule(
ValidationRule(
name=f"require_approval_{i}",
description=f"Require approval for: {pattern}",
priority=50,
tool_patterns=[pattern],
decision=SafetyDecision.REQUIRE_APPROVAL,
reason=f"Action matches approval-required pattern: {pattern}",
)
)
logger.info("Loaded %d rules from policy: %s", len(self._rules), policy.name)
async def validate(
self,
action: ActionRequest,
policy: SafetyPolicy | None = None,
) -> ValidationResult:
"""
Validate an action against all rules.
Args:
action: The action to validate
policy: Optional policy override
Returns:
ValidationResult with decision and details
"""
# Check bypass
if self._bypass_enabled:
logger.warning(
"Validation bypass active: %s - allowing action %s",
self._bypass_reason,
action.id,
)
return ValidationResult(
action_id=action.id,
decision=SafetyDecision.ALLOW,
applied_rules=[],
reasons=[f"Validation bypassed: {self._bypass_reason}"],
)
# Check cache
if self._cache_enabled:
cache_key = self._get_cache_key(action)
cached = await self._cache.get(cache_key)
if cached:
logger.debug("Using cached validation for action %s", action.id)
return cached
# Load rules from policy if provided
if policy and not self._rules:
self.load_rules_from_policy(policy)
# Validate against rules
applied_rules: list[str] = []
reasons: list[str] = []
final_decision = SafetyDecision.ALLOW
approval_id: str | None = None
for rule in self._rules:
if not rule.enabled:
continue
if self._rule_matches(rule, action):
applied_rules.append(rule.id)
if rule.reason:
reasons.append(rule.reason)
# Handle decision priority
if rule.decision == SafetyDecision.DENY:
# Deny takes precedence
final_decision = SafetyDecision.DENY
break
elif rule.decision == SafetyDecision.REQUIRE_APPROVAL:
# Upgrade to require approval
if final_decision != SafetyDecision.DENY:
final_decision = SafetyDecision.REQUIRE_APPROVAL
# If no rules matched and no explicit allow, default to allow
if not applied_rules:
reasons.append("No matching rules - default allow")
result = ValidationResult(
action_id=action.id,
decision=final_decision,
applied_rules=applied_rules,
reasons=reasons,
approval_id=approval_id,
)
# Cache result
if self._cache_enabled:
cache_key = self._get_cache_key(action)
await self._cache.set(cache_key, result)
return result
async def validate_batch(
self,
actions: list[ActionRequest],
policy: SafetyPolicy | None = None,
) -> list[ValidationResult]:
"""
Validate multiple actions.
Args:
actions: Actions to validate
policy: Optional policy override
Returns:
List of validation results
"""
tasks = [self.validate(action, policy) for action in actions]
return await asyncio.gather(*tasks)
def enable_bypass(self, reason: str) -> None:
"""
Enable validation bypass (emergency use only).
Args:
reason: Reason for enabling bypass
"""
logger.critical("Validation bypass enabled: %s", reason)
self._bypass_enabled = True
self._bypass_reason = reason
def disable_bypass(self) -> None:
"""Disable validation bypass."""
logger.info("Validation bypass disabled")
self._bypass_enabled = False
self._bypass_reason = None
async def clear_cache(self) -> None:
"""Clear the validation cache."""
await self._cache.clear()
def _rule_matches(self, rule: ValidationRule, action: ActionRequest) -> bool:
"""Check if a rule matches an action."""
# Check action types
if rule.action_types:
if action.action_type not in rule.action_types:
return False
# Check tool patterns
if rule.tool_patterns:
if not action.tool_name:
return False
matched = False
for pattern in rule.tool_patterns:
if self._matches_pattern(action.tool_name, pattern):
matched = True
break
if not matched:
return False
# Check resource patterns
if rule.resource_patterns:
if not action.resource:
return False
matched = False
for pattern in rule.resource_patterns:
if self._matches_pattern(action.resource, pattern):
matched = True
break
if not matched:
return False
# Check agent IDs
if rule.agent_ids:
if action.metadata.agent_id not in rule.agent_ids:
return False
return True
def _matches_pattern(self, value: str, pattern: str) -> bool:
"""Check if value matches a pattern (supports wildcards)."""
if pattern == "*":
return True
# Use fnmatch for glob-style matching
return fnmatch.fnmatch(value, pattern)
def _get_cache_key(self, action: ActionRequest) -> str:
"""Generate a cache key for an action."""
# Key based on action characteristics that affect validation
key_parts = [
action.action_type.value,
action.tool_name or "",
action.resource or "",
action.metadata.agent_id,
action.metadata.autonomy_level.value,
]
return ":".join(key_parts)
# Module-level convenience functions
def create_allow_rule(
name: str,
tool_patterns: list[str] | None = None,
resource_patterns: list[str] | None = None,
action_types: list[ActionType] | None = None,
priority: int = 0,
) -> ValidationRule:
"""Create an allow rule."""
return ValidationRule(
name=name,
tool_patterns=tool_patterns,
resource_patterns=resource_patterns,
action_types=action_types,
decision=SafetyDecision.ALLOW,
priority=priority,
)
def create_deny_rule(
name: str,
tool_patterns: list[str] | None = None,
resource_patterns: list[str] | None = None,
action_types: list[ActionType] | None = None,
reason: str | None = None,
priority: int = 100,
) -> ValidationRule:
"""Create a deny rule."""
return ValidationRule(
name=name,
tool_patterns=tool_patterns,
resource_patterns=resource_patterns,
action_types=action_types,
decision=SafetyDecision.DENY,
reason=reason,
priority=priority,
)
def create_approval_rule(
name: str,
tool_patterns: list[str] | None = None,
resource_patterns: list[str] | None = None,
action_types: list[ActionType] | None = None,
reason: str | None = None,
priority: int = 50,
) -> ValidationRule:
"""Create a require-approval rule."""
return ValidationRule(
name=name,
tool_patterns=tool_patterns,
resource_patterns=resource_patterns,
action_types=action_types,
decision=SafetyDecision.REQUIRE_APPROVAL,
reason=reason,
priority=priority,
)