"""
Cost Controller

Budget management and cost tracking for agent operations.
"""

import asyncio
import inspect
import logging
from datetime import datetime, timedelta, timezone
from typing import Any

from ..config import get_safety_config
from ..exceptions import BudgetExceededError
from ..models import (
    ActionRequest,
    BudgetScope,
    BudgetStatus,
)

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()`` (deprecated since Python
    3.12). The result is kept naive so it remains comparable with the naive
    timestamps this module has always produced (e.g. ``BudgetStatus.reset_at``).
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _reset_interval_for(scope: BudgetScope) -> timedelta | None:
    """Return the automatic reset period for *scope*, or None if it never resets."""
    if scope == BudgetScope.DAILY:
        return timedelta(days=1)
    if scope == BudgetScope.WEEKLY:
        return timedelta(weeks=1)
    if scope == BudgetScope.MONTHLY:
        # Approximation: a fixed 30-day window, not a calendar month.
        return timedelta(days=30)
    return None


class BudgetTracker:
    """Tracks token and USD usage against a single budget limit.

    All public methods are coroutines guarded by a per-tracker
    ``asyncio.Lock``. If ``reset_interval`` is set, usage is zeroed lazily
    the next time any public method runs after the interval has elapsed.
    """

    def __init__(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
        reset_interval: timedelta | None = None,
        warning_threshold: float = 0.8,
    ) -> None:
        """
        Args:
            scope: Scope this budget applies to.
            scope_id: Identifier within the scope (e.g. a session or agent id).
            tokens_limit: Maximum tokens allowed before the budget is exceeded.
            cost_limit_usd: Maximum USD cost allowed before the budget is exceeded.
            reset_interval: Period after which usage resets; None means never.
            warning_threshold: Usage ratio (0..1) at or above which status
                reports ``is_warning``.
        """
        self.scope = scope
        self.scope_id = scope_id
        self.tokens_limit = tokens_limit
        self.cost_limit_usd = cost_limit_usd
        self.warning_threshold = warning_threshold

        self._reset_interval = reset_interval
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        now = _utcnow()
        self._created_at = now
        self._last_reset = now
        self._lock = asyncio.Lock()

    async def add_usage(self, tokens: int, cost_usd: float) -> None:
        """Add actual usage to the tracker (after applying any due reset)."""
        async with self._lock:
            self._check_reset()
            self._tokens_used += tokens
            self._cost_used_usd += cost_usd

    async def get_status(self) -> BudgetStatus:
        """Return a snapshot of current usage, remaining budget and flags."""
        async with self._lock:
            self._check_reset()

            tokens_remaining = max(0, self.tokens_limit - self._tokens_used)
            # 0.0 (not 0) so the remaining cost is always a float.
            cost_remaining = max(0.0, self.cost_limit_usd - self._cost_used_usd)

            token_usage_ratio = (
                self._tokens_used / self.tokens_limit if self.tokens_limit > 0 else 0
            )
            cost_usage_ratio = (
                self._cost_used_usd / self.cost_limit_usd
                if self.cost_limit_usd > 0
                else 0
            )

            # Warn when either dimension crosses the threshold.
            is_warning = (
                max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
            )
            is_exceeded = (
                self._tokens_used >= self.tokens_limit
                or self._cost_used_usd >= self.cost_limit_usd
            )

            reset_at = None
            if self._reset_interval:
                reset_at = self._last_reset + self._reset_interval

            return BudgetStatus(
                scope=self.scope,
                scope_id=self.scope_id,
                tokens_used=self._tokens_used,
                tokens_limit=self.tokens_limit,
                cost_used_usd=self._cost_used_usd,
                cost_limit_usd=self.cost_limit_usd,
                tokens_remaining=tokens_remaining,
                cost_remaining_usd=cost_remaining,
                warning_threshold=self.warning_threshold,
                is_warning=is_warning,
                is_exceeded=is_exceeded,
                reset_at=reset_at,
            )

    async def check_budget(
        self, estimated_tokens: int, estimated_cost_usd: float
    ) -> bool:
        """Return True if adding the estimate would stay within both limits.

        Reaching a limit exactly is allowed; only strictly exceeding it fails.
        """
        async with self._lock:
            self._check_reset()

            would_exceed_tokens = (
                self._tokens_used + estimated_tokens
            ) > self.tokens_limit
            would_exceed_cost = (
                self._cost_used_usd + estimated_cost_usd
            ) > self.cost_limit_usd

            return not (would_exceed_tokens or would_exceed_cost)

    def _check_reset(self) -> None:
        """Zero usage if the reset interval has elapsed. Caller must hold the lock."""
        if self._reset_interval is None:
            return

        now = _utcnow()
        if now >= self._last_reset + self._reset_interval:
            logger.info(
                "Resetting budget for %s:%s",
                self.scope.value,
                self.scope_id,
            )
            self._tokens_used = 0
            self._cost_used_usd = 0.0
            self._last_reset = now

    async def reset(self) -> None:
        """Manually zero usage and restart the reset window."""
        async with self._lock:
            self._tokens_used = 0
            self._cost_used_usd = 0.0
            self._last_reset = _utcnow()


class CostController:
    """
    Controls costs and budgets for agent operations.

    Features:
    - Per-agent, per-project, per-session budgets
    - Real-time cost tracking
    - Budget alerts at configurable thresholds
    - Cost prediction for planned actions
    - Budget rollover policies
    """

    def __init__(
        self,
        default_session_tokens: int | None = None,
        default_session_cost_usd: float | None = None,
        default_daily_tokens: int | None = None,
        default_daily_cost_usd: float | None = None,
    ) -> None:
        """
        Initialize the CostController.

        Any argument left as None falls back to the safety config value.
        An explicit 0 is honoured as a zero budget.

        Args:
            default_session_tokens: Default token budget per session
            default_session_cost_usd: Default USD budget per session
            default_daily_tokens: Default token budget per day
            default_daily_cost_usd: Default USD budget per day
        """
        config = get_safety_config()

        self._default_session_tokens = (
            default_session_tokens
            if default_session_tokens is not None
            else config.default_session_token_budget
        )
        self._default_session_cost = (
            default_session_cost_usd
            if default_session_cost_usd is not None
            else config.default_session_cost_limit
        )
        self._default_daily_tokens = (
            default_daily_tokens
            if default_daily_tokens is not None
            else config.default_daily_token_budget
        )
        self._default_daily_cost = (
            default_daily_cost_usd
            if default_daily_cost_usd is not None
            else config.default_daily_cost_limit
        )

        self._trackers: dict[str, BudgetTracker] = {}
        self._lock = asyncio.Lock()

        # Alert handlers: sync or async callables taking (alert_type, message, status).
        self._alert_handlers: list[Any] = []

    @staticmethod
    def _tracker_key(scope: BudgetScope, scope_id: str) -> str:
        """Build the dictionary key under which a tracker is stored."""
        return f"{scope.value}:{scope_id}"

    async def get_or_create_tracker(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetTracker:
        """Get the tracker for (scope, scope_id), creating it with defaults if absent.

        DAILY trackers use the daily defaults; every other scope uses the
        session defaults. The reset interval matches the one ``set_budget``
        uses for the same scope.
        """
        key = self._tracker_key(scope, scope_id)

        async with self._lock:
            tracker = self._trackers.get(key)
            if tracker is None:
                if scope == BudgetScope.DAILY:
                    tokens_limit = self._default_daily_tokens
                    cost_limit = self._default_daily_cost
                else:
                    tokens_limit = self._default_session_tokens
                    cost_limit = self._default_session_cost

                tracker = BudgetTracker(
                    scope=scope,
                    scope_id=scope_id,
                    tokens_limit=tokens_limit,
                    cost_limit_usd=cost_limit,
                    reset_interval=_reset_interval_for(scope),
                )
                self._trackers[key] = tracker

            return tracker

    async def check_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> bool:
        """
        Check if there's enough budget for an operation.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Returns:
            True if both the session budget (when a session is given) and the
            agent's daily budget can absorb the estimate.
        """
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            if not await session_tracker.check_budget(
                estimated_tokens, estimated_cost_usd
            ):
                return False

        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        return await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd)

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is within budget.

        Args:
            action: The action to check

        Returns:
            True if within budget
        """
        return await self.check_budget(
            agent_id=action.metadata.agent_id,
            session_id=action.metadata.session_id,
            estimated_tokens=action.estimated_cost_tokens,
            estimated_cost_usd=action.estimated_cost_usd,
        )

    async def require_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> None:
        """
        Require budget or raise exception.

        The session budget is checked first, then the agent's daily budget.
        The error names whichever budget actually rejected the estimate.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Raises:
            BudgetExceededError: If budget is exceeded
        """
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            # Check the tracker directly: a session can reject an estimate
            # before its recorded usage alone reaches the limit, so
            # `is_exceeded` is not a reliable signal of which budget failed.
            if not await session_tracker.check_budget(
                estimated_tokens, estimated_cost_usd
            ):
                session_status = await session_tracker.get_status()
                raise BudgetExceededError(
                    "Session budget exceeded",
                    budget_type="session",
                    current_usage=session_status.tokens_used,
                    budget_limit=session_status.tokens_limit,
                    agent_id=agent_id,
                )

        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
            agent_status = await agent_tracker.get_status()
            raise BudgetExceededError(
                "Daily budget exceeded",
                budget_type="daily",
                current_usage=agent_status.tokens_used,
                budget_limit=agent_status.tokens_limit,
                agent_id=agent_id,
            )

    async def record_usage(
        self,
        agent_id: str,
        session_id: str | None,
        tokens: int,
        cost_usd: float,
    ) -> None:
        """
        Record actual usage.

        Usage is added to the session tracker (if a session is given) and to
        the agent's daily tracker. A "warning" alert is sent for each tracker
        that is at or above its warning threshold but not yet exceeded.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            tokens: Actual token usage
            cost_usd: Actual USD cost
        """
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            await session_tracker.add_usage(tokens, cost_usd)

            status = await session_tracker.get_status()
            if status.is_warning and not status.is_exceeded:
                await self._send_alert(
                    "warning",
                    f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
                    status,
                )

        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        await agent_tracker.add_usage(tokens, cost_usd)

        status = await agent_tracker.get_status()
        if status.is_warning and not status.is_exceeded:
            await self._send_alert(
                "warning",
                f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
                status,
            )

    async def get_status(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetStatus | None:
        """
        Get budget status.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            Budget status or None if not tracked
        """
        key = self._tracker_key(scope, scope_id)
        async with self._lock:
            tracker = self._trackers.get(key)
        # Query the tracker outside the controller lock; it has its own lock.
        if tracker:
            return await tracker.get_status()
        return None

    async def get_all_statuses(self) -> list[BudgetStatus]:
        """Get status of all tracked budgets."""
        async with self._lock:
            trackers = list(self._trackers.values())
        return [await tracker.get_status() for tracker in trackers]

    async def set_budget(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
    ) -> None:
        """
        Set a custom budget limit.

        Replaces any existing tracker for (scope, scope_id) with a fresh one,
        so previously recorded usage for that scope is discarded. DAILY,
        WEEKLY and MONTHLY budgets reset every 1 day, 1 week and 30 days.

        Args:
            scope: Budget scope
            scope_id: ID within scope
            tokens_limit: Token limit
            cost_limit_usd: USD limit
        """
        key = self._tracker_key(scope, scope_id)
        tracker = BudgetTracker(
            scope=scope,
            scope_id=scope_id,
            tokens_limit=tokens_limit,
            cost_limit_usd=cost_limit_usd,
            reset_interval=_reset_interval_for(scope),
        )
        async with self._lock:
            self._trackers[key] = tracker

    async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
        """
        Reset a budget tracker.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            True if tracker was found and reset
        """
        key = self._tracker_key(scope, scope_id)
        async with self._lock:
            tracker = self._trackers.get(key)
        # Reset outside the controller lock; the tracker has its own lock.
        if tracker:
            await tracker.reset()
            return True
        return False

    def add_alert_handler(self, handler: Any) -> None:
        """Add an alert handler (sync or async callable)."""
        self._alert_handlers.append(handler)

    def remove_alert_handler(self, handler: Any) -> None:
        """Remove an alert handler if it is registered."""
        if handler in self._alert_handlers:
            self._alert_handlers.remove(handler)

    async def _send_alert(
        self,
        alert_type: str,
        message: str,
        status: BudgetStatus,
    ) -> None:
        """Send an alert to every handler.

        Each handler is called with (alert_type, message, status). If a
        handler returns an awaitable, it is awaited. This covers coroutine
        functions as well as partials and objects with an async
        ``__call__``. A failing handler is logged with its traceback and
        does not stop the remaining handlers.
        """
        # Iterate over a copy so a handler that (un)registers handlers
        # cannot cause others to be skipped.
        for handler in list(self._alert_handlers):
            try:
                result = handler(alert_type, message, status)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Error in alert handler: %s", e)