Files
syndarix/backend/app/services/safety/limits/limiter.py
Felipe Cardoso caf283bed2 feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
2026-01-03 17:55:34 +01:00

397 lines
12 KiB
Python

"""
Rate Limiter
Sliding window rate limiting for agent operations.
"""
import asyncio
import logging
import time
from collections import deque
from ..config import get_safety_config
from ..exceptions import RateLimitExceededError
from ..models import (
ActionRequest,
RateLimitConfig,
RateLimitStatus,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class SlidingWindowCounter:
"""Sliding window counter for rate limiting."""
def __init__(
self,
limit: int,
window_seconds: int,
burst_limit: int | None = None,
) -> None:
self.limit = limit
self.window_seconds = window_seconds
self.burst_limit = burst_limit or limit
self._timestamps: deque[float] = deque()
self._lock = asyncio.Lock()
async def try_acquire(self) -> tuple[bool, float]:
"""
Try to acquire a slot.
Returns:
Tuple of (allowed, retry_after_seconds)
"""
now = time.time()
window_start = now - self.window_seconds
async with self._lock:
# Remove expired entries
while self._timestamps and self._timestamps[0] < window_start:
self._timestamps.popleft()
current_count = len(self._timestamps)
# Check burst limit (instant check)
if current_count >= self.burst_limit:
# Calculate retry time
oldest = self._timestamps[0] if self._timestamps else now
retry_after = oldest + self.window_seconds - now
return False, max(0, retry_after)
# Check window limit
if current_count >= self.limit:
oldest = self._timestamps[0] if self._timestamps else now
retry_after = oldest + self.window_seconds - now
return False, max(0, retry_after)
# Allow and record
self._timestamps.append(now)
return True, 0.0
async def get_status(self) -> tuple[int, int, float]:
"""
Get current status.
Returns:
Tuple of (current_count, remaining, reset_in_seconds)
"""
now = time.time()
window_start = now - self.window_seconds
async with self._lock:
# Remove expired entries
while self._timestamps and self._timestamps[0] < window_start:
self._timestamps.popleft()
current_count = len(self._timestamps)
remaining = max(0, self.limit - current_count)
if self._timestamps:
reset_in = self._timestamps[0] + self.window_seconds - now
else:
reset_in = 0.0
return current_count, remaining, max(0, reset_in)
class RateLimiter:
    """
    Rate limiter for agent operations.

    Each named limit ("actions", "llm_calls", "file_ops", or any name passed
    to ``configure``) is tracked separately per key (e.g. an agent id) with a
    ``SlidingWindowCounter``. Limits without an explicit configuration fall
    back to the built-in defaults, and unknown names to 60 per 60 seconds.
    """

    def __init__(self) -> None:
        """Initialize the RateLimiter."""
        config = get_safety_config()
        self._configs: dict[str, RateLimitConfig] = {}
        self._counters: dict[str, SlidingWindowCounter] = {}
        # counter key -> (limit_name, key). Kept beside the string-keyed
        # ``_counters`` so counters can be matched exactly by limit or by
        # key without parsing the ":"-joined counter key.
        self._counter_owners: dict[str, tuple[str, str]] = {}
        self._lock = asyncio.Lock()
        # Default rate limits
        self._default_limits = {
            "actions": RateLimitConfig(
                name="actions",
                limit=config.default_actions_per_minute,
                window_seconds=60,
            ),
            "llm_calls": RateLimitConfig(
                name="llm_calls",
                limit=config.default_llm_calls_per_minute,
                window_seconds=60,
            ),
            "file_ops": RateLimitConfig(
                name="file_ops",
                limit=config.default_file_ops_per_minute,
                window_seconds=60,
            ),
        }

    def configure(self, config: RateLimitConfig) -> None:
        """
        Configure a rate limit.

        Counters that already exist for this limit adopt the new settings
        while keeping their recorded usage; otherwise they would go on
        enforcing the values they were created with.

        Args:
            config: Rate limit configuration
        """
        self._configs[config.name] = config

        for counter_key, (limit_name, _) in self._counter_owners.items():
            if limit_name != config.name:
                continue
            counter = self._counters.get(counter_key)
            if counter is None:
                continue
            counter.limit = config.limit
            counter.window_seconds = config.window_seconds
            # Same fallback the counter applies in its constructor.
            counter.burst_limit = config.burst_limit or config.limit

        logger.debug(
            "Configured rate limit: %s = %d/%ds",
            config.name,
            config.limit,
            config.window_seconds,
        )

    async def check(
        self,
        limit_name: str,
        key: str,
    ) -> RateLimitStatus:
        """
        Check rate limit without consuming a slot.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Rate limit status
        """
        counter = await self._get_counter(limit_name, key)
        current, remaining, reset_in = await counter.get_status()
        exhausted = remaining <= 0
        return self._build_status(
            limit_name,
            current,
            remaining,
            reset_in,
            is_limited=exhausted,
            retry_after=reset_in if exhausted else 0.0,
        )

    async def acquire(
        self,
        limit_name: str,
        key: str,
    ) -> tuple[bool, RateLimitStatus]:
        """
        Try to acquire a rate limit slot.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Tuple of (allowed, status)
        """
        counter = await self._get_counter(limit_name, key)
        allowed, retry_after = await counter.try_acquire()
        current, remaining, reset_in = await counter.get_status()
        status = self._build_status(
            limit_name,
            current,
            remaining,
            reset_in,
            is_limited=not allowed,
            retry_after=retry_after,
        )
        return allowed, status

    async def check_action(
        self,
        action: ActionRequest,
    ) -> tuple[bool, list[RateLimitStatus]]:
        """
        Check all applicable rate limits for an action WITHOUT consuming slots.

        Use this during validation to check if action would be allowed.
        Call record_action() after successful execution to consume slots.

        Args:
            action: The action to check

        Returns:
            Tuple of (allowed, list of statuses). ``allowed`` is False as
            soon as any applicable limit is exhausted.
        """
        agent_id = action.metadata.agent_id
        statuses: list[RateLimitStatus] = []
        for limit_name in self._limits_for_action(action):
            statuses.append(await self.check(limit_name, agent_id))
        allowed = not any(status.is_limited for status in statuses)
        return allowed, statuses

    async def record_action(
        self,
        action: ActionRequest,
    ) -> None:
        """
        Record an action by consuming rate limit slots.

        Call this AFTER successful execution to properly count the action.
        A limit that is already exhausted cannot take another slot; that case
        is logged instead of being dropped silently.

        Args:
            action: The executed action
        """
        agent_id = action.metadata.agent_id
        for limit_name in self._limits_for_action(action):
            allowed, _ = await self.acquire(limit_name, agent_id)
            if not allowed:
                logger.warning(
                    "Rate limit %s already exhausted for %s; "
                    "executed action was not counted",
                    limit_name,
                    agent_id,
                )

    async def require(
        self,
        limit_name: str,
        key: str,
    ) -> None:
        """
        Require rate limit slot or raise exception.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Raises:
            RateLimitExceededError: If rate limit exceeded
        """
        allowed, status = await self.acquire(limit_name, key)
        if allowed:
            return
        raise RateLimitExceededError(
            f"Rate limit exceeded: {limit_name}",
            limit_type=limit_name,
            limit_value=status.limit,
            window_seconds=status.window_seconds,
            retry_after_seconds=status.retry_after_seconds,
        )

    async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
        """
        Get status of all rate limits for a key.

        Covers the built-in default limits first, then every explicitly
        configured limit that is not one of the defaults.

        Args:
            key: Key for tracking

        Returns:
            Dict of limit name to status
        """
        statuses: dict[str, RateLimitStatus] = {}
        for name in self._default_limits:
            statuses[name] = await self.check(name, key)
        for name in self._configs:
            if name not in statuses:
                statuses[name] = await self.check(name, key)
        return statuses

    async def reset(self, limit_name: str, key: str) -> bool:
        """
        Reset a rate limit counter.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Returns:
            True if counter was found and reset
        """
        counter_key = f"{limit_name}:{key}"
        async with self._lock:
            if counter_key not in self._counters:
                return False
            del self._counters[counter_key]
            self._counter_owners.pop(counter_key, None)
            return True

    async def reset_all(self, key: str) -> int:
        """
        Reset all rate limit counters for a key.

        Only counters created for exactly this key are removed; a key that
        merely ends with the same text after a ":" is left alone.

        Args:
            key: Key for tracking

        Returns:
            Number of counters reset
        """
        count = 0
        async with self._lock:
            to_remove = [
                counter_key
                for counter_key, (_, owner_key) in self._counter_owners.items()
                if owner_key == key
            ]
            for counter_key in to_remove:
                del self._counter_owners[counter_key]
                if self._counters.pop(counter_key, None) is not None:
                    count += 1
        return count

    @staticmethod
    def _limits_for_action(action: ActionRequest) -> list[str]:
        """Return the names of the limits that apply to an action."""
        names = ["actions"]
        action_type = action.action_type.value
        if action_type == "llm_call":
            names.append("llm_calls")
        if action_type in {"file_read", "file_write", "file_delete"}:
            names.append("file_ops")
        return names

    def _build_status(
        self,
        limit_name: str,
        current: int,
        remaining: int,
        reset_in: float,
        *,
        is_limited: bool,
        retry_after: float,
    ) -> RateLimitStatus:
        """Assemble a RateLimitStatus from counter readings and the config."""
        from datetime import datetime, timedelta, timezone

        config = self._get_config(limit_name)
        # Naive UTC, the same value datetime.utcnow() produced, without
        # calling the deprecated function.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return RateLimitStatus(
            name=limit_name,
            current_count=current,
            limit=config.limit,
            window_seconds=config.window_seconds,
            remaining=remaining,
            reset_at=now_utc + timedelta(seconds=reset_in),
            is_limited=is_limited,
            retry_after_seconds=retry_after,
        )

    def _get_config(self, limit_name: str) -> RateLimitConfig:
        """Get configuration for a rate limit.

        Explicit configuration wins over the built-in defaults; unknown names
        get a generic 60-per-60-seconds limit.
        """
        if limit_name in self._configs:
            return self._configs[limit_name]
        if limit_name in self._default_limits:
            return self._default_limits[limit_name]
        # Return default
        return RateLimitConfig(
            name=limit_name,
            limit=60,
            window_seconds=60,
        )

    async def _get_counter(
        self,
        limit_name: str,
        key: str,
    ) -> SlidingWindowCounter:
        """Get or create the counter for a (limit, key) pair."""
        counter_key = f"{limit_name}:{key}"
        config = self._get_config(limit_name)
        async with self._lock:
            if counter_key not in self._counters:
                self._counters[counter_key] = SlidingWindowCounter(
                    limit=config.limit,
                    window_seconds=config.window_seconds,
                    burst_limit=config.burst_limit,
                )
                # Remember which limit and key this counter belongs to.
                self._counter_owners[counter_key] = (limit_name, key)
            return self._counters[counter_key]