forked from cardosofelipe/fast-next-template
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation. - Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming. - Refactored `CostController`'s session and daily budget alert handling for improved clarity. - Implemented test suites for `CostController` and `SafetyGuardian` to validate changes. - Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
397 lines
12 KiB
Python
397 lines
12 KiB
Python
"""
|
|
Rate Limiter
|
|
|
|
Sliding window rate limiting for agent operations.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from collections import deque
|
|
|
|
from ..config import get_safety_config
|
|
from ..exceptions import RateLimitExceededError
|
|
from ..models import (
|
|
ActionRequest,
|
|
RateLimitConfig,
|
|
RateLimitStatus,
|
|
)
|
|
|
|
# Logger for this module; handlers and levels are configured by the application.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SlidingWindowCounter:
|
|
"""Sliding window counter for rate limiting."""
|
|
|
|
def __init__(
|
|
self,
|
|
limit: int,
|
|
window_seconds: int,
|
|
burst_limit: int | None = None,
|
|
) -> None:
|
|
self.limit = limit
|
|
self.window_seconds = window_seconds
|
|
self.burst_limit = burst_limit or limit
|
|
self._timestamps: deque[float] = deque()
|
|
self._lock = asyncio.Lock()
|
|
|
|
async def try_acquire(self) -> tuple[bool, float]:
|
|
"""
|
|
Try to acquire a slot.
|
|
|
|
Returns:
|
|
Tuple of (allowed, retry_after_seconds)
|
|
"""
|
|
now = time.time()
|
|
window_start = now - self.window_seconds
|
|
|
|
async with self._lock:
|
|
# Remove expired entries
|
|
while self._timestamps and self._timestamps[0] < window_start:
|
|
self._timestamps.popleft()
|
|
|
|
current_count = len(self._timestamps)
|
|
|
|
# Check burst limit (instant check)
|
|
if current_count >= self.burst_limit:
|
|
# Calculate retry time
|
|
oldest = self._timestamps[0] if self._timestamps else now
|
|
retry_after = oldest + self.window_seconds - now
|
|
return False, max(0, retry_after)
|
|
|
|
# Check window limit
|
|
if current_count >= self.limit:
|
|
oldest = self._timestamps[0] if self._timestamps else now
|
|
retry_after = oldest + self.window_seconds - now
|
|
return False, max(0, retry_after)
|
|
|
|
# Allow and record
|
|
self._timestamps.append(now)
|
|
return True, 0.0
|
|
|
|
async def get_status(self) -> tuple[int, int, float]:
|
|
"""
|
|
Get current status.
|
|
|
|
Returns:
|
|
Tuple of (current_count, remaining, reset_in_seconds)
|
|
"""
|
|
now = time.time()
|
|
window_start = now - self.window_seconds
|
|
|
|
async with self._lock:
|
|
# Remove expired entries
|
|
while self._timestamps and self._timestamps[0] < window_start:
|
|
self._timestamps.popleft()
|
|
|
|
current_count = len(self._timestamps)
|
|
remaining = max(0, self.limit - current_count)
|
|
|
|
if self._timestamps:
|
|
reset_in = self._timestamps[0] + self.window_seconds - now
|
|
else:
|
|
reset_in = 0.0
|
|
|
|
return current_count, remaining, max(0, reset_in)
|
|
|
|
|
|
class RateLimiter:
    """
    Rate limiter for agent operations.

    Features:
    - Per-tool rate limits
    - Per-agent rate limits
    - Per-resource rate limits
    - Sliding window implementation
    - Burst allowance with recovery
    - Slowdown before hard block

    One ``SlidingWindowCounter`` is created lazily per ``(limit_name, key)``
    pair and sized from that limit's configuration; ``configure()`` pushes
    new settings into counters that already exist.
    """

    # Action types whose execution also counts against the "file_ops" limit.
    _FILE_OP_ACTION_TYPES = frozenset({"file_read", "file_write", "file_delete"})

    def __init__(self) -> None:
        """Initialize the RateLimiter."""
        config = get_safety_config()

        self._configs: dict[str, RateLimitConfig] = {}
        # Keyed by (limit_name, key) so lookups never depend on how a
        # delimiter-joined string would split when names or keys contain ':'.
        self._counters: dict[tuple[str, str], SlidingWindowCounter] = {}
        self._lock = asyncio.Lock()

        # Default rate limits
        self._default_limits = {
            "actions": RateLimitConfig(
                name="actions",
                limit=config.default_actions_per_minute,
                window_seconds=60,
            ),
            "llm_calls": RateLimitConfig(
                name="llm_calls",
                limit=config.default_llm_calls_per_minute,
                window_seconds=60,
            ),
            "file_ops": RateLimitConfig(
                name="file_ops",
                limit=config.default_file_ops_per_minute,
                window_seconds=60,
            ),
        }

    def configure(self, config: RateLimitConfig) -> None:
        """
        Configure a rate limit.

        Counters already created for this limit are updated in place; the
        slots they have consumed are kept, so reconfiguring never grants a
        fresh quota.

        Args:
            config: Rate limit configuration
        """
        self._configs[config.name] = config

        # Without this, existing counters would keep enforcing the old limit
        # while statuses report the new one. No await occurs here, so the
        # update cannot interleave with a counter's critical section.
        for (limit_name, _key), counter in self._counters.items():
            if limit_name == config.name:
                counter.limit = config.limit
                counter.window_seconds = config.window_seconds
                # Same fallback SlidingWindowCounter applies at construction.
                counter.burst_limit = config.burst_limit or config.limit

        logger.debug(
            "Configured rate limit: %s = %d/%ds",
            config.name,
            config.limit,
            config.window_seconds,
        )

    @staticmethod
    def _build_status(
        limit_name: str,
        config: RateLimitConfig,
        current: int,
        remaining: int,
        reset_in: float,
        is_limited: bool,
        retry_after: float,
    ) -> RateLimitStatus:
        """Assemble a RateLimitStatus from counter readings and the config."""
        from datetime import datetime, timedelta, timezone

        # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return RateLimitStatus(
            name=limit_name,
            current_count=current,
            limit=config.limit,
            window_seconds=config.window_seconds,
            remaining=remaining,
            reset_at=now_utc + timedelta(seconds=reset_in),
            is_limited=is_limited,
            retry_after_seconds=retry_after,
        )

    async def check(
        self,
        limit_name: str,
        key: str,
    ) -> RateLimitStatus:
        """
        Check rate limit without consuming a slot.

        The counter for ``(limit_name, key)`` is created if it does not exist
        yet, but no slot is recorded.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Rate limit status
        """
        counter = await self._get_counter(limit_name, key)
        config = self._get_config(limit_name)

        current, remaining, reset_in = await counter.get_status()
        is_limited = remaining <= 0
        return self._build_status(
            limit_name,
            config,
            current,
            remaining,
            reset_in,
            is_limited=is_limited,
            retry_after=reset_in if is_limited else 0.0,
        )

    async def acquire(
        self,
        limit_name: str,
        key: str,
    ) -> tuple[bool, RateLimitStatus]:
        """
        Try to acquire a rate limit slot.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking (e.g., agent_id)

        Returns:
            Tuple of (allowed, status); the status reflects the counter after
            the attempt.
        """
        counter = await self._get_counter(limit_name, key)
        config = self._get_config(limit_name)

        allowed, retry_after = await counter.try_acquire()
        current, remaining, reset_in = await counter.get_status()

        status = self._build_status(
            limit_name,
            config,
            current,
            remaining,
            reset_in,
            is_limited=not allowed,
            retry_after=retry_after,
        )
        return allowed, status

    def _limit_names_for(self, action: ActionRequest) -> list[str]:
        """Return the limit names an action counts against, in check order."""
        action_type = action.action_type.value
        names = ["actions"]
        if action_type == "llm_call":
            names.append("llm_calls")
        if action_type in self._FILE_OP_ACTION_TYPES:
            names.append("file_ops")
        return names

    async def check_action(
        self,
        action: ActionRequest,
    ) -> tuple[bool, list[RateLimitStatus]]:
        """
        Check all applicable rate limits for an action WITHOUT consuming slots.

        Use this during validation to check if action would be allowed.
        Call record_action() after successful execution to consume slots.

        Args:
            action: The action to check

        Returns:
            Tuple of (allowed, list of statuses); allowed is False if any
            applicable limit is exhausted.
        """
        agent_id = action.metadata.agent_id
        statuses: list[RateLimitStatus] = [
            await self.check(name, agent_id) for name in self._limit_names_for(action)
        ]
        allowed = not any(status.is_limited for status in statuses)
        return allowed, statuses

    async def record_action(
        self,
        action: ActionRequest,
    ) -> None:
        """
        Record an action by consuming rate limit slots.

        Call this AFTER successful execution to properly count the action.
        A limit that is already exhausted does not record the action; that
        case is logged at debug level rather than raised.

        Args:
            action: The executed action
        """
        agent_id = action.metadata.agent_id
        for name in self._limit_names_for(action):
            allowed, _status = await self.acquire(name, agent_id)
            if not allowed:
                logger.debug(
                    "Rate limit %s exhausted for %s; action not counted",
                    name,
                    agent_id,
                )

    async def require(
        self,
        limit_name: str,
        key: str,
    ) -> None:
        """
        Require rate limit slot or raise exception.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Raises:
            RateLimitExceededError: If rate limit exceeded
        """
        allowed, status = await self.acquire(limit_name, key)
        if not allowed:
            raise RateLimitExceededError(
                f"Rate limit exceeded: {limit_name}",
                limit_type=limit_name,
                limit_value=status.limit,
                window_seconds=status.window_seconds,
                retry_after_seconds=status.retry_after_seconds,
            )

    async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
        """
        Get status of all rate limits for a key.

        Args:
            key: Key for tracking

        Returns:
            Dict of limit name to status (default limits first, then any
            additionally configured ones)
        """
        names = list(self._default_limits)
        names.extend(name for name in self._configs if name not in self._default_limits)
        return {name: await self.check(name, key) for name in names}

    async def reset(self, limit_name: str, key: str) -> bool:
        """
        Reset a rate limit counter.

        Args:
            limit_name: Name of the rate limit
            key: Key for tracking

        Returns:
            True if counter was found and reset
        """
        async with self._lock:
            return self._counters.pop((limit_name, key), None) is not None

    async def reset_all(self, key: str) -> int:
        """
        Reset all rate limit counters for a key.

        Only counters tracked under exactly ``key`` are removed; keys that
        merely end with ``key`` are unaffected.

        Args:
            key: Key for tracking

        Returns:
            Number of counters reset
        """
        async with self._lock:
            to_remove = [ck for ck in self._counters if ck[1] == key]
            for ck in to_remove:
                del self._counters[ck]
            return len(to_remove)

    def _get_config(self, limit_name: str) -> RateLimitConfig:
        """Get configuration for a rate limit.

        Explicit configuration wins over the built-in defaults; unknown names
        get a 60-per-60-seconds limit.
        """
        if limit_name in self._configs:
            return self._configs[limit_name]
        if limit_name in self._default_limits:
            return self._default_limits[limit_name]
        # Return default
        return RateLimitConfig(
            name=limit_name,
            limit=60,
            window_seconds=60,
        )

    async def _get_counter(
        self,
        limit_name: str,
        key: str,
    ) -> SlidingWindowCounter:
        """Get or create the counter for ``(limit_name, key)``."""
        counter_key = (limit_name, key)
        async with self._lock:
            counter = self._counters.get(counter_key)
            if counter is None:
                config = self._get_config(limit_name)
                counter = SlidingWindowCounter(
                    limit=config.limit,
                    window_seconds=config.window_seconds,
                    burst_limit=config.burst_limit,
                )
                self._counters[counter_key] = counter
            return counter
|