""" Rate Limiter Sliding window rate limiting for agent operations. """ import asyncio import logging import time from collections import deque from ..config import get_safety_config from ..exceptions import RateLimitExceededError from ..models import ( ActionRequest, RateLimitConfig, RateLimitStatus, ) logger = logging.getLogger(__name__) class SlidingWindowCounter: """Sliding window counter for rate limiting.""" def __init__( self, limit: int, window_seconds: int, burst_limit: int | None = None, ) -> None: self.limit = limit self.window_seconds = window_seconds self.burst_limit = burst_limit or limit self._timestamps: deque[float] = deque() self._lock = asyncio.Lock() async def try_acquire(self) -> tuple[bool, float]: """ Try to acquire a slot. Returns: Tuple of (allowed, retry_after_seconds) """ now = time.time() window_start = now - self.window_seconds async with self._lock: # Remove expired entries while self._timestamps and self._timestamps[0] < window_start: self._timestamps.popleft() current_count = len(self._timestamps) # Check burst limit (instant check) if current_count >= self.burst_limit: # Calculate retry time oldest = self._timestamps[0] if self._timestamps else now retry_after = oldest + self.window_seconds - now return False, max(0, retry_after) # Check window limit if current_count >= self.limit: oldest = self._timestamps[0] if self._timestamps else now retry_after = oldest + self.window_seconds - now return False, max(0, retry_after) # Allow and record self._timestamps.append(now) return True, 0.0 async def get_status(self) -> tuple[int, int, float]: """ Get current status. Returns: Tuple of (current_count, remaining, reset_in_seconds) """ now = time.time() window_start = now - self.window_seconds async with self._lock: # Remove expired entries while self._timestamps and self._timestamps[0] < window_start: self._timestamps.popleft() current_count = len(self._timestamps) remaining = max(0, self.limit - current_count) if self._timestamps: reset_in = self._timestamps[0] + self.window_seconds - now else: reset_in = 0.0 return current_count, remaining, max(0, reset_in) class RateLimiter: """ Rate limiter for agent operations. Features: - Per-tool rate limits - Per-agent rate limits - Per-resource rate limits - Sliding window implementation - Burst allowance with recovery - Slowdown before hard block """ def __init__(self) -> None: """Initialize the RateLimiter.""" config = get_safety_config() self._configs: dict[str, RateLimitConfig] = {} self._counters: dict[str, SlidingWindowCounter] = {} self._lock = asyncio.Lock() # Default rate limits self._default_limits = { "actions": RateLimitConfig( name="actions", limit=config.default_actions_per_minute, window_seconds=60, ), "llm_calls": RateLimitConfig( name="llm_calls", limit=config.default_llm_calls_per_minute, window_seconds=60, ), "file_ops": RateLimitConfig( name="file_ops", limit=config.default_file_ops_per_minute, window_seconds=60, ), } def configure(self, config: RateLimitConfig) -> None: """ Configure a rate limit. Args: config: Rate limit configuration """ self._configs[config.name] = config logger.debug( "Configured rate limit: %s = %d/%ds", config.name, config.limit, config.window_seconds, ) async def check( self, limit_name: str, key: str, ) -> RateLimitStatus: """ Check rate limit without consuming a slot. Args: limit_name: Name of the rate limit key: Key for tracking (e.g., agent_id) Returns: Rate limit status """ counter = await self._get_counter(limit_name, key) config = self._get_config(limit_name) current, remaining, reset_in = await counter.get_status() from datetime import datetime, timedelta return RateLimitStatus( name=limit_name, current_count=current, limit=config.limit, window_seconds=config.window_seconds, remaining=remaining, reset_at=datetime.utcnow() + timedelta(seconds=reset_in), is_limited=remaining <= 0, retry_after_seconds=reset_in if remaining <= 0 else 0.0, ) async def acquire( self, limit_name: str, key: str, ) -> tuple[bool, RateLimitStatus]: """ Try to acquire a rate limit slot. Args: limit_name: Name of the rate limit key: Key for tracking (e.g., agent_id) Returns: Tuple of (allowed, status) """ counter = await self._get_counter(limit_name, key) config = self._get_config(limit_name) allowed, retry_after = await counter.try_acquire() current, remaining, reset_in = await counter.get_status() from datetime import datetime, timedelta status = RateLimitStatus( name=limit_name, current_count=current, limit=config.limit, window_seconds=config.window_seconds, remaining=remaining, reset_at=datetime.utcnow() + timedelta(seconds=reset_in), is_limited=not allowed, retry_after_seconds=retry_after, ) return allowed, status async def check_action( self, action: ActionRequest, ) -> tuple[bool, list[RateLimitStatus]]: """ Check all applicable rate limits for an action. Args: action: The action to check Returns: Tuple of (allowed, list of statuses) """ agent_id = action.metadata.agent_id statuses: list[RateLimitStatus] = [] allowed = True # Check general actions limit actions_allowed, actions_status = await self.acquire("actions", agent_id) statuses.append(actions_status) if not actions_allowed: allowed = False # Check LLM-specific limit for LLM calls if action.action_type.value == "llm_call": llm_allowed, llm_status = await self.acquire("llm_calls", agent_id) statuses.append(llm_status) if not llm_allowed: allowed = False # Check file ops limit for file operations if action.action_type.value in {"file_read", "file_write", "file_delete"}: file_allowed, file_status = await self.acquire("file_ops", agent_id) statuses.append(file_status) if not file_allowed: allowed = False return allowed, statuses async def require( self, limit_name: str, key: str, ) -> None: """ Require rate limit slot or raise exception. Args: limit_name: Name of the rate limit key: Key for tracking Raises: RateLimitExceededError: If rate limit exceeded """ allowed, status = await self.acquire(limit_name, key) if not allowed: raise RateLimitExceededError( f"Rate limit exceeded: {limit_name}", limit_type=limit_name, limit_value=status.limit, window_seconds=status.window_seconds, retry_after_seconds=status.retry_after_seconds, ) async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]: """ Get status of all rate limits for a key. Args: key: Key for tracking Returns: Dict of limit name to status """ statuses = {} for name in self._default_limits: statuses[name] = await self.check(name, key) for name in self._configs: if name not in statuses: statuses[name] = await self.check(name, key) return statuses async def reset(self, limit_name: str, key: str) -> bool: """ Reset a rate limit counter. Args: limit_name: Name of the rate limit key: Key for tracking Returns: True if counter was found and reset """ counter_key = f"{limit_name}:{key}" async with self._lock: if counter_key in self._counters: del self._counters[counter_key] return True return False async def reset_all(self, key: str) -> int: """ Reset all rate limit counters for a key. Args: key: Key for tracking Returns: Number of counters reset """ count = 0 async with self._lock: to_remove = [k for k in self._counters if k.endswith(f":{key}")] for k in to_remove: del self._counters[k] count += 1 return count def _get_config(self, limit_name: str) -> RateLimitConfig: """Get configuration for a rate limit.""" if limit_name in self._configs: return self._configs[limit_name] if limit_name in self._default_limits: return self._default_limits[limit_name] # Return default return RateLimitConfig( name=limit_name, limit=60, window_seconds=60, ) async def _get_counter( self, limit_name: str, key: str, ) -> SlidingWindowCounter: """Get or create a counter.""" counter_key = f"{limit_name}:{key}" config = self._get_config(limit_name) async with self._lock: if counter_key not in self._counters: self._counters[counter_key] = SlidingWindowCounter( limit=config.limit, window_seconds=config.window_seconds, burst_limit=config.burst_limit, ) return self._counters[counter_key]