forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
486 lines
15 KiB
Python
486 lines
15 KiB
Python
"""
|
|
Cost Controller
|
|
|
|
Budget management and cost tracking for agent operations.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
|
|
from ..config import get_safety_config
|
|
from ..exceptions import BudgetExceededError
|
|
from ..models import (
|
|
ActionRequest,
|
|
BudgetScope,
|
|
BudgetStatus,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BudgetTracker:
    """
    Accumulates token and USD usage for one scope and compares it to limits.

    All public coroutines serialise on an internal ``asyncio.Lock`` and, when
    a reset interval is configured, lazily start a fresh window before they
    read or modify the counters.
    """

    def __init__(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
        reset_interval: timedelta | None = None,
        warning_threshold: float = 0.8,
    ) -> None:
        """
        Create a tracker with zeroed counters.

        Args:
            scope: Budget scope this tracker belongs to
            scope_id: ID within the scope
            tokens_limit: Maximum tokens for the current window
            cost_limit_usd: Maximum USD cost for the current window
            reset_interval: Window length; None means the counters never
                reset automatically
            warning_threshold: Usage ratio at or above which the status is
                flagged as a warning
        """
        # Limits and identity
        self.scope = scope
        self.scope_id = scope_id
        self.tokens_limit = tokens_limit
        self.cost_limit_usd = cost_limit_usd
        self.warning_threshold = warning_threshold
        self._reset_interval = reset_interval

        # Running counters for the current window
        self._tokens_used = 0
        self._cost_used_usd = 0.0

        # Timestamps are naive datetimes taken from datetime.utcnow()
        self._created_at = datetime.utcnow()
        self._last_reset = datetime.utcnow()

        self._lock = asyncio.Lock()

    async def add_usage(self, tokens: int, cost_usd: float) -> None:
        """Add consumed tokens and cost to the current window."""
        async with self._lock:
            self._check_reset()
            self._tokens_used = self._tokens_used + tokens
            self._cost_used_usd = self._cost_used_usd + cost_usd

    async def get_status(self) -> BudgetStatus:
        """Build a BudgetStatus snapshot of the current window."""
        async with self._lock:
            self._check_reset()

            used_tokens = self._tokens_used
            used_cost = self._cost_used_usd
            limit_tokens = self.tokens_limit
            limit_cost = self.cost_limit_usd

            # A non-positive limit yields a ratio of 0 instead of dividing
            if limit_tokens > 0:
                ratio_tokens = used_tokens / limit_tokens
            else:
                ratio_tokens = 0
            if limit_cost > 0:
                ratio_cost = used_cost / limit_cost
            else:
                ratio_cost = 0

            # Next automatic reset, only when an interval is configured
            next_reset = None
            if self._reset_interval:
                next_reset = self._last_reset + self._reset_interval

            return BudgetStatus(
                scope=self.scope,
                scope_id=self.scope_id,
                tokens_used=used_tokens,
                tokens_limit=limit_tokens,
                cost_used_usd=used_cost,
                cost_limit_usd=limit_cost,
                tokens_remaining=max(0, limit_tokens - used_tokens),
                cost_remaining_usd=max(0, limit_cost - used_cost),
                warning_threshold=self.warning_threshold,
                is_warning=max(ratio_tokens, ratio_cost) >= self.warning_threshold,
                is_exceeded=(used_tokens >= limit_tokens or used_cost >= limit_cost),
                reset_at=next_reset,
            )

    async def check_budget(
        self, estimated_tokens: int, estimated_cost_usd: float
    ) -> bool:
        """
        Report whether an estimated operation fits in the remaining budget.

        Returns:
            False if adding the estimate would push either counter strictly
            above its limit, True otherwise
        """
        async with self._lock:
            self._check_reset()

            if self._tokens_used + estimated_tokens > self.tokens_limit:
                return False
            if self._cost_used_usd + estimated_cost_usd > self.cost_limit_usd:
                return False
            return True

    def _restart_window(self, started_at: datetime) -> None:
        """Zero both counters and record ``started_at`` as the window start."""
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        self._last_reset = started_at

    def _check_reset(self) -> None:
        """Start a new window if the reset interval has elapsed."""
        interval = self._reset_interval
        if interval is not None:
            current = datetime.utcnow()
            if current >= self._last_reset + interval:
                logger.info(
                    "Resetting budget for %s:%s",
                    self.scope.value,
                    self.scope_id,
                )
                self._restart_window(current)

    async def reset(self) -> None:
        """Zero the counters immediately and restart the window from now."""
        async with self._lock:
            self._restart_window(datetime.utcnow())
|
|
|
|
|
|
class CostController:
    """
    Controls costs and budgets for agent operations.

    Budgets are held in BudgetTracker instances keyed by
    ``"<scope value>:<scope id>"``. The controller:

    - checks and records usage against a per-session budget (when a session
      ID is given) and a daily per-agent budget
    - creates trackers lazily with default limits, or with custom limits via
      ``set_budget``
    - notifies registered alert handlers when a tracker reports a warning
      that is not yet exceeded
    """

    def __init__(
        self,
        default_session_tokens: int | None = None,
        default_session_cost_usd: float | None = None,
        default_daily_tokens: int | None = None,
        default_daily_cost_usd: float | None = None,
    ) -> None:
        """
        Initialize the CostController.

        Any argument left as None falls back to the value from
        ``get_safety_config()``. An explicit ``0`` is honoured as a real
        limit rather than being treated as "unset".

        Args:
            default_session_tokens: Default token budget per session
            default_session_cost_usd: Default USD budget per session
            default_daily_tokens: Default token budget per day
            default_daily_cost_usd: Default USD budget per day
        """
        config = get_safety_config()

        # Compare against None (not truthiness) so 0 / 0.0 overrides survive.
        self._default_session_tokens = (
            default_session_tokens
            if default_session_tokens is not None
            else config.default_session_token_budget
        )
        self._default_session_cost = (
            default_session_cost_usd
            if default_session_cost_usd is not None
            else config.default_session_cost_limit
        )
        self._default_daily_tokens = (
            default_daily_tokens
            if default_daily_tokens is not None
            else config.default_daily_token_budget
        )
        self._default_daily_cost = (
            default_daily_cost_usd
            if default_daily_cost_usd is not None
            else config.default_daily_cost_limit
        )

        self._trackers: dict[str, BudgetTracker] = {}
        self._lock = asyncio.Lock()

        # Alert handlers
        self._alert_handlers: list[Any] = []

    @staticmethod
    def _tracker_key(scope: BudgetScope, scope_id: str) -> str:
        """Build the dictionary key under which a tracker is stored."""
        return f"{scope.value}:{scope_id}"

    @staticmethod
    def _reset_interval_for(scope: BudgetScope) -> timedelta | None:
        """Return the automatic reset interval for a scope, or None."""
        if scope == BudgetScope.DAILY:
            return timedelta(days=1)
        if scope == BudgetScope.WEEKLY:
            return timedelta(weeks=1)
        if scope == BudgetScope.MONTHLY:
            # A month is approximated as 30 days
            return timedelta(days=30)
        return None

    async def get_or_create_tracker(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetTracker:
        """
        Get the tracker for a scope/ID pair, creating it with defaults.

        DAILY trackers use the daily default limits; every other scope uses
        the session default limits. Periodic scopes get the same reset
        interval that ``set_budget`` assigns.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            The existing or newly created tracker
        """
        key = self._tracker_key(scope, scope_id)

        async with self._lock:
            tracker = self._trackers.get(key)
            if tracker is None:
                if scope == BudgetScope.DAILY:
                    tokens_limit = self._default_daily_tokens
                    cost_limit_usd = self._default_daily_cost
                else:
                    tokens_limit = self._default_session_tokens
                    cost_limit_usd = self._default_session_cost

                tracker = BudgetTracker(
                    scope=scope,
                    scope_id=scope_id,
                    tokens_limit=tokens_limit,
                    cost_limit_usd=cost_limit_usd,
                    reset_interval=self._reset_interval_for(scope),
                )
                self._trackers[key] = tracker

            return tracker

    async def check_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> bool:
        """
        Check if there's enough budget for an operation.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Returns:
            True if budget is available
        """
        # Check session budget (skipped when no session ID is given)
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            if not await session_tracker.check_budget(
                estimated_tokens, estimated_cost_usd
            ):
                return False

        # Check agent daily budget
        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        return await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd)

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is within budget.

        Args:
            action: The action to check

        Returns:
            True if within budget
        """
        return await self.check_budget(
            agent_id=action.metadata.agent_id,
            session_id=action.metadata.session_id,
            estimated_tokens=action.estimated_cost_tokens,
            estimated_cost_usd=action.estimated_cost_usd,
        )

    async def require_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> None:
        """
        Require budget or raise exception.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Raises:
            BudgetExceededError: If budget is exceeded
        """
        if await self.check_budget(
            agent_id, session_id, estimated_tokens, estimated_cost_usd
        ):
            return

        # Attribute the failure with the same estimate-aware check that
        # check_budget uses. ``status.is_exceeded`` only looks at current
        # usage, so it can blame the wrong budget.
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            if not await session_tracker.check_budget(
                estimated_tokens, estimated_cost_usd
            ):
                session_status = await session_tracker.get_status()
                raise BudgetExceededError(
                    "Session budget exceeded",
                    budget_type="session",
                    current_usage=session_status.tokens_used,
                    budget_limit=session_status.tokens_limit,
                    agent_id=agent_id,
                )

        agent_tracker = await self.get_or_create_tracker(
            BudgetScope.DAILY, agent_id
        )
        agent_status = await agent_tracker.get_status()
        raise BudgetExceededError(
            "Daily budget exceeded",
            budget_type="daily",
            current_usage=agent_status.tokens_used,
            budget_limit=agent_status.tokens_limit,
            agent_id=agent_id,
        )

    async def _add_usage_and_alert(
        self,
        tracker: BudgetTracker,
        tokens: int,
        cost_usd: float,
        subject: str,
        unit: str,
    ) -> None:
        """Add usage to one tracker and send a warning alert if warranted."""
        await tracker.add_usage(tokens, cost_usd)

        # Warn only while the budget is not yet exceeded
        status = await tracker.get_status()
        if status.is_warning and not status.is_exceeded:
            await self._send_alert(
                "warning",
                f"{subject} at {status.tokens_used}/{status.tokens_limit} {unit}",
                status,
            )

    async def record_usage(
        self,
        agent_id: str,
        session_id: str | None,
        tokens: int,
        cost_usd: float,
    ) -> None:
        """
        Record actual usage.

        Usage is added to the session tracker (when a session ID is given)
        and to the agent's daily tracker.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            tokens: Actual token usage
            cost_usd: Actual USD cost
        """
        # Update session budget
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            await self._add_usage_and_alert(
                session_tracker, tokens, cost_usd, f"Session {session_id}", "tokens"
            )

        # Update agent daily budget
        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        await self._add_usage_and_alert(
            agent_tracker, tokens, cost_usd, f"Agent {agent_id}", "daily tokens"
        )

    async def get_status(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetStatus | None:
        """
        Get budget status.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            Budget status or None if not tracked
        """
        key = self._tracker_key(scope, scope_id)
        async with self._lock:
            tracker = self._trackers.get(key)

        # Query the tracker after releasing the controller lock
        if tracker:
            return await tracker.get_status()
        return None

    async def get_all_statuses(self) -> list[BudgetStatus]:
        """Get status of all tracked budgets."""
        # Snapshot under the lock, then query each tracker outside it
        async with self._lock:
            trackers = list(self._trackers.values())

        return [await tracker.get_status() for tracker in trackers]

    async def set_budget(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
    ) -> None:
        """
        Set a custom budget limit.

        A new tracker is stored for the scope/ID pair, replacing any
        existing one, so previously recorded usage for that pair is not
        carried over.

        Args:
            scope: Budget scope
            scope_id: ID within scope
            tokens_limit: Token limit
            cost_limit_usd: USD limit
        """
        key = self._tracker_key(scope, scope_id)
        replacement = BudgetTracker(
            scope=scope,
            scope_id=scope_id,
            tokens_limit=tokens_limit,
            cost_limit_usd=cost_limit_usd,
            reset_interval=self._reset_interval_for(scope),
        )

        async with self._lock:
            self._trackers[key] = replacement

    async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
        """
        Reset a budget tracker.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            True if tracker was found and reset
        """
        key = self._tracker_key(scope, scope_id)
        async with self._lock:
            tracker = self._trackers.get(key)

        if tracker:
            await tracker.reset()
            return True
        return False

    def add_alert_handler(self, handler: Any) -> None:
        """Add an alert handler."""
        self._alert_handlers.append(handler)

    def remove_alert_handler(self, handler: Any) -> None:
        """Remove an alert handler; unknown handlers are ignored."""
        if handler in self._alert_handlers:
            self._alert_handlers.remove(handler)

    async def _send_alert(
        self,
        alert_type: str,
        message: str,
        status: BudgetStatus,
    ) -> None:
        """
        Send alert to all handlers.

        Each handler is called with ``(alert_type, message, status)``. If
        the call returns a coroutine it is awaited, so plain async
        functions, partials of them and objects with an async ``__call__``
        all work. A failing handler is logged and does not stop the rest.
        """
        # Iterate over a snapshot so a handler may add or remove handlers
        for handler in tuple(self._alert_handlers):
            try:
                outcome = handler(alert_type, message, status)
                if asyncio.iscoroutine(outcome):
                    await outcome
            except Exception as e:
                # Best-effort delivery: log with traceback and continue
                logger.error("Error in alert handler: %s", e, exc_info=True)
|