forked from cardosofelipe/fast-next-template
Frontend: - Fix debounce race condition in UserListTable search handler - Use useRef to properly track and cleanup timeout between keystrokes Backend (LLM Gateway): - Add thread-safe double-checked locking for global singletons (providers, circuit registry, cost tracker) - Fix Redis URL parsing with proper urlparse validation - Add explicit error handling for malformed Redis URLs - Document circuit breaker state transition safety 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
479 lines
15 KiB
Python
479 lines
15 KiB
Python
"""
|
|
Cost tracking for LLM Gateway.
|
|
|
|
Tracks LLM usage costs per project and agent using Redis.
|
|
Provides aggregation by hour, day, and month with TTL-based expiry.
|
|
"""
|
|
|
|
import logging
|
|
import threading
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
import redis.asyncio as redis
|
|
|
|
from config import Settings, get_settings
|
|
from models import (
|
|
MODEL_CONFIGS,
|
|
UsageReport,
|
|
)
|
|
|
|
# Module-level logger used by CostTracker and calculate_cost below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CostTracker:
    """
    Redis-based cost tracker for LLM usage.

    Key structure:
    - {prefix}:cost:project:{project_id}:{date} -> Hash of usage by model
    - {prefix}:cost:agent:{agent_id}:{date} -> Hash of usage by model
    - {prefix}:cost:session:{session_id} -> Hash of session usage
    - {prefix}:requests:{project_id}:{date} -> Request count

    Each usage hash holds ``{model}:{metric}`` and ``total:{metric}`` fields,
    where metric is prompt_tokens, completion_tokens, cost_usd or requests.
    Model names may themselves contain ``:`` (e.g. ``llama3:8b``), so fields
    are split on their *last* colon when read back.

    Date formats (usage is recorded under UTC-derived keys):
    - hour: YYYYMMDDHH
    - day: YYYYMMDD
    - month: YYYYMM
    """

    # Aggregation periods that have their own date key. Unknown periods
    # requested by callers fall back to "day".
    _PERIODS = ("hour", "day", "month")

    # Key expiry per aggregation period.
    _TTL_SECONDS = {
        "hour": 24 * 3600,  # 24 hours
        "day": 30 * 24 * 3600,  # 30 days
        "month": 365 * 24 * 3600,  # 1 year
    }
    _DEFAULT_TTL_SECONDS = 30 * 24 * 3600
    _SESSION_TTL_SECONDS = 24 * 3600  # 24 hour TTL for sessions

    # Metrics stored as integers; cost_usd is the only float metric.
    _INT_METRICS = ("prompt_tokens", "completion_tokens", "requests")

    def __init__(
        self,
        redis_client: redis.Redis | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize cost tracker.

        Args:
            redis_client: Redis client (created lazily from settings if None)
            settings: Application settings (defaults to get_settings())
        """
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = redis_client
        self._prefix = self._settings.redis_prefix

    async def _get_redis(self) -> redis.Redis:
        """Return the Redis client, creating it from settings.redis_url on first use."""
        # There is no await between the None check and the assignment, so
        # coroutines on one event loop cannot race to create two clients.
        if self._redis is None:
            self._redis = redis.from_url(
                self._settings.redis_url,
                decode_responses=True,
            )
        return self._redis

    async def close(self) -> None:
        """Close the Redis connection, if one is open."""
        # Detach before awaiting so nothing picks up a client that is
        # mid-close, and so the reference is dropped even if aclose() raises.
        client, self._redis = self._redis, None
        if client is not None:
            await client.aclose()

    @staticmethod
    def _resolve_timestamp(timestamp: datetime | None) -> datetime:
        """
        Return the moment to bucket by, defaulting to the current UTC time.

        Aware timestamps are converted to UTC, because record_usage always
        writes under keys derived from datetime.now(UTC). Naive timestamps
        are used unchanged (presumably already UTC wall-clock time; TODO
        confirm with callers).
        """
        if timestamp is None:
            return datetime.now(UTC)
        if timestamp.tzinfo is not None:
            return timestamp.astimezone(UTC)
        return timestamp

    def _get_date_keys(self, timestamp: datetime | None = None) -> dict[str, str]:
        """Get the hour/day/month date-key strings for a timestamp (default: now)."""
        ts = self._resolve_timestamp(timestamp)
        return {
            "hour": ts.strftime("%Y%m%d%H"),
            "day": ts.strftime("%Y%m%d"),
            "month": ts.strftime("%Y%m"),
        }

    def _get_ttl_seconds(self, period: str) -> int:
        """Get TTL in seconds for a period (30 days for unknown periods)."""
        return self._TTL_SECONDS.get(period, self._DEFAULT_TTL_SECONDS)

    def _normalize_period(self, period: str) -> str:
        """Return period if supported, else warn and fall back to "day"."""
        if period in self._PERIODS:
            return period
        logger.warning("Unknown usage period %r; falling back to 'day'", period)
        return "day"

    @staticmethod
    def _split_field(field: str) -> tuple[str, str] | None:
        """Split a hash field into (model, metric) on its last colon, or None."""
        model, sep, metric = field.rpartition(":")
        if not sep:
            return None
        return model, metric

    @staticmethod
    def _period_bounds(period: str, moment: datetime) -> tuple[datetime, datetime]:
        """Return [start, end) of the hour/day/month containing moment."""
        if period == "hour":
            start = moment.replace(minute=0, second=0, microsecond=0)
            return start, start + timedelta(hours=1)
        if period == "month":
            start = moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            if start.month == 12:
                return start, start.replace(year=start.year + 1, month=1)
            return start, start.replace(month=start.month + 1)
        # "day" (and the fallback for anything else, matching _normalize_period)
        start = moment.replace(hour=0, minute=0, second=0, microsecond=0)
        return start, start + timedelta(days=1)

    async def record_usage(
        self,
        project_id: str,
        agent_id: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost_usd: float,
        session_id: str | None = None,
        request_id: str | None = None,  # noqa: ARG002 - reserved for future logging
    ) -> None:
        """
        Record LLM usage.

        No-op when settings.cost_tracking_enabled is false. All Redis
        commands are queued on one pipeline and sent by a single execute().

        Args:
            project_id: Project ID
            agent_id: Agent ID
            model: Model name
            prompt_tokens: Input tokens
            completion_tokens: Output tokens
            cost_usd: Cost in USD
            session_id: Optional session ID
            request_id: Optional request ID (currently unused)
        """
        if not self._settings.cost_tracking_enabled:
            return

        r = await self._get_redis()
        pipe = r.pipeline()

        # Record for each time period
        for period, date_key in self._get_date_keys().items():
            ttl = self._get_ttl_seconds(period)

            # Project-level tracking
            project_key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
            await self._increment_usage(
                pipe, project_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(project_key, ttl)

            # Agent-level tracking
            agent_key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
            await self._increment_usage(
                pipe, agent_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(agent_key, ttl)

            # Request counter
            requests_key = f"{self._prefix}:requests:{project_id}:{date_key}"
            pipe.incr(requests_key)
            pipe.expire(requests_key, ttl)

        # Session tracking (if session_id provided)
        if session_id:
            session_key = f"{self._prefix}:cost:session:{session_id}"
            await self._increment_usage(
                pipe, session_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(session_key, self._SESSION_TTL_SECONDS)

        await pipe.execute()

        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug(
            "Recorded usage: project=%s, agent=%s, model=%s, tokens=%s, cost=$%.6f",
            project_id,
            agent_id,
            model,
            prompt_tokens + completion_tokens,
            cost_usd,
        )

    async def _increment_usage(
        self,
        pipe: redis.client.Pipeline,
        key: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost_usd: float,
    ) -> None:
        """Queue per-model and total increments for one request on a usage hash."""
        # Per-model fields: "{model}:{metric}"
        pipe.hincrby(key, f"{model}:prompt_tokens", prompt_tokens)
        pipe.hincrby(key, f"{model}:completion_tokens", completion_tokens)
        pipe.hincrbyfloat(key, f"{model}:cost_usd", cost_usd)
        pipe.hincrby(key, f"{model}:requests", 1)
        # Totals across all models: "total:{metric}"
        pipe.hincrby(key, "total:prompt_tokens", prompt_tokens)
        pipe.hincrby(key, "total:completion_tokens", completion_tokens)
        pipe.hincrbyfloat(key, "total:cost_usd", cost_usd)
        pipe.hincrby(key, "total:requests", 1)

    async def get_project_usage(
        self,
        project_id: str,
        period: str = "day",
        timestamp: datetime | None = None,
    ) -> UsageReport:
        """
        Get usage report for a project.

        Args:
            project_id: Project ID
            period: Time period (hour, day, month); anything else is treated as day
            timestamp: Specific time to query (defaults to now)

        Returns:
            Usage report
        """
        period = self._normalize_period(period)
        date_key = self._get_date_keys(timestamp)[period]
        key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
        return await self._get_usage_report(
            key, project_id, "project", period, timestamp
        )

    async def get_agent_usage(
        self,
        agent_id: str,
        period: str = "day",
        timestamp: datetime | None = None,
    ) -> UsageReport:
        """
        Get usage report for an agent.

        Args:
            agent_id: Agent ID
            period: Time period (hour, day, month); anything else is treated as day
            timestamp: Specific time to query (defaults to now)

        Returns:
            Usage report
        """
        period = self._normalize_period(period)
        date_key = self._get_date_keys(timestamp)[period]
        key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
        return await self._get_usage_report(key, agent_id, "agent", period, timestamp)

    async def _get_usage_report(
        self,
        key: str,
        entity_id: str,
        entity_type: str,
        period: str,
        timestamp: datetime | None,
    ) -> UsageReport:
        """Build a UsageReport from the usage hash stored at key."""
        r = await self._get_redis()
        data = await r.hgetall(key)  # type: ignore[misc]

        by_model: dict[str, dict[str, Any]] = {}
        total_requests = 0
        total_tokens = 0
        total_cost = 0.0

        for field, value in data.items():
            split = self._split_field(field)
            if split is None:
                continue
            model, metric = split

            if model == "total":
                if metric == "requests":
                    total_requests = int(value)
                elif metric in ("prompt_tokens", "completion_tokens"):
                    total_tokens += int(value)
                elif metric == "cost_usd":
                    total_cost = float(value)
                continue

            stats = by_model.setdefault(
                model,
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "cost_usd": 0.0,
                    "requests": 0,
                },
            )
            if metric == "cost_usd":
                stats["cost_usd"] = float(value)
            elif metric in self._INT_METRICS:
                stats[metric] = int(value)

        # Boundaries come from the same UTC-normalized moment as the date key.
        period_start, period_end = self._period_bounds(
            period, self._resolve_timestamp(timestamp)
        )

        return UsageReport(
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            period_start=period_start,
            period_end=period_end,
            total_requests=total_requests,
            total_tokens=total_tokens,
            total_cost_usd=round(total_cost, 6),
            by_model=by_model,
        )

    async def get_session_usage(self, session_id: str) -> dict[str, Any]:
        """
        Get usage for a specific session.

        Args:
            session_id: Session ID

        Returns:
            Dict with session_id, total_tokens, total_cost_usd and by_model,
            where by_model only contains the metrics present in Redis.
        """
        r = await self._get_redis()
        key = f"{self._prefix}:cost:session:{session_id}"
        data = await r.hgetall(key)  # type: ignore[misc]

        result: dict[str, Any] = {
            "session_id": session_id,
            "total_tokens": 0,
            "total_cost_usd": 0.0,
            "by_model": {},
        }

        for field, value in data.items():
            split = self._split_field(field)
            if split is None:
                continue
            model, metric = split

            if model == "total":
                if metric in ("prompt_tokens", "completion_tokens"):
                    result["total_tokens"] += int(value)
                elif metric == "cost_usd":
                    result["total_cost_usd"] = float(value)
                continue

            stats = result["by_model"].setdefault(model, {})
            if metric in self._INT_METRICS:
                stats[metric] = int(value)
            elif metric == "cost_usd":
                stats[metric] = float(value)

        return result

    async def check_budget(
        self,
        project_id: str,
        budget_limit: float | None = None,
    ) -> tuple[bool, float, float]:
        """
        Check if project's current-month spend is within budget.

        Args:
            project_id: Project ID
            budget_limit: Budget limit (uses settings default only if None;
                an explicit 0 means "no spend allowed")

        Returns:
            Tuple of (within_budget, current_cost, limit)
        """
        limit = (
            self._settings.default_budget_limit if budget_limit is None else budget_limit
        )

        report = await self.get_project_usage(project_id, period="month")
        current_cost = report.total_cost_usd

        within_budget = current_cost < limit
        return within_budget, current_cost, limit

    async def estimate_request_cost(
        self,
        model: str,
        prompt_tokens: int,
        max_completion_tokens: int,
    ) -> float:
        """
        Estimate the worst-case cost for a request.

        Args:
            model: Model name
            prompt_tokens: Input token count
            max_completion_tokens: Maximum output tokens

        Returns:
            Estimated cost in USD, rounded to 6 decimal places. Unknown models
            use a flat default rate of 0.00001 USD per token.
        """
        config = MODEL_CONFIGS.get(model)
        if not config:
            # Use a default estimate
            return round((prompt_tokens + max_completion_tokens) * 0.00001, 6)

        input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
        output_cost = (max_completion_tokens / 1_000_000) * config.cost_per_1m_output
        return round(input_cost + output_cost, 6)

    async def should_alert(
        self,
        project_id: str,
        threshold: float | None = None,
    ) -> tuple[bool, float]:
        """
        Check if a cost alert should be triggered for today's spend.

        Args:
            project_id: Project ID
            threshold: Alert threshold (uses settings default only if None)

        Returns:
            Tuple of (should_alert, current_cost)
        """
        thresh = (
            self._settings.cost_alert_threshold if threshold is None else threshold
        )

        report = await self.get_project_usage(project_id, period="day")
        current_cost = report.total_cost_usd

        return current_cost >= thresh, current_cost
|
|
|
|
|
|
def calculate_cost(
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float:
    """
    Price a finished completion from the per-million-token rates in MODEL_CONFIGS.

    Args:
        model: Model name (looked up in MODEL_CONFIGS)
        prompt_tokens: Input tokens
        completion_tokens: Output tokens

    Returns:
        Cost in USD rounded to 6 decimal places, or 0.0 (after logging a
        warning) when the model has no pricing entry.
    """
    pricing = MODEL_CONFIGS.get(model)
    if not pricing:
        logger.warning(f"Unknown model {model} for cost calculation")
        return 0.0

    per_million = 1_000_000
    total = (
        prompt_tokens / per_million * pricing.cost_per_1m_input
        + completion_tokens / per_million * pricing.cost_per_1m_output
    )
    return round(total, 6)
|
|
|
|
|
|
# Process-wide CostTracker singleton. _tracker_lock is taken by
# get_cost_tracker (creation), close_cost_tracker and reset_cost_tracker
# (teardown) whenever they change _tracker.
_tracker_lock = threading.Lock()
_tracker: CostTracker | None = None
|
|
|
|
|
|
def get_cost_tracker() -> CostTracker:
    """
    Return the process-wide CostTracker, creating it on first use.

    Uses double-checked locking: the lock is only taken while the
    singleton does not exist yet.
    """
    global _tracker

    existing = _tracker
    if existing is not None:
        # Fast path: already initialized, no locking needed.
        return existing

    with _tracker_lock:
        # Another thread may have created the tracker while we waited.
        if _tracker is None:
            _tracker = CostTracker()
        return _tracker
|
|
|
|
|
|
async def close_cost_tracker() -> None:
    """
    Close and discard the global cost tracker, if one exists.

    The singleton is detached while holding _tracker_lock, but closed only
    after the lock is released. _tracker_lock is a threading.Lock, so holding
    it across an await would block other threads. Worse, any coroutine on
    this event-loop thread that called reset_cost_tracker() while the close
    was suspended would block the loop forever (deadlock).
    """
    global _tracker
    with _tracker_lock:
        tracker = _tracker
        _tracker = None
    if tracker is not None:
        await tracker.close()
|
|
|
|
|
|
def reset_cost_tracker() -> None:
    """
    Forget the global tracker so the next get_cost_tracker() builds a new one.

    Intended for tests. The old tracker is dropped without being closed.
    """
    global _tracker
    _tracker_lock.acquire()
    try:
        _tracker = None
    finally:
        _tracker_lock.release()
|