Files
syndarix/mcp-servers/llm-gateway/cost_tracking.py
Felipe Cardoso 95342cc94d fix(mcp-gateway): address critical issues from deep review
Frontend:
- Fix debounce race condition in UserListTable search handler
- Use useRef to properly track and cleanup timeout between keystrokes

Backend (LLM Gateway):
- Add thread-safe double-checked locking for global singletons
  (providers, circuit registry, cost tracker)
- Fix Redis URL parsing with proper urlparse validation
- Add explicit error handling for malformed Redis URLs
- Document circuit breaker state transition safety

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:36:55 +01:00

479 lines
15 KiB
Python

"""
Cost tracking for LLM Gateway.
Tracks LLM usage costs per project and agent using Redis.
Provides aggregation by hour, day, and month with TTL-based expiry.
"""
import logging
import threading
from datetime import UTC, datetime, timedelta
from typing import Any
import redis.asyncio as redis
from config import Settings, get_settings
from models import (
MODEL_CONFIGS,
UsageReport,
)
# Module-scoped logger; records are emitted under this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
class CostTracker:
    """
    Redis-based cost tracker for LLM usage.

    Key structure:
    - {prefix}:cost:project:{project_id}:{date} -> Hash of usage by model
    - {prefix}:cost:agent:{agent_id}:{date} -> Hash of usage by model
    - {prefix}:cost:session:{session_id} -> Hash of session usage
    - {prefix}:requests:{project_id}:{date} -> Request count

    Hash fields are "{model}:{metric}" plus "total:{metric}", where metric is
    one of prompt_tokens, completion_tokens, cost_usd, requests.

    Date formats (buckets are written from UTC time):
    - hour: YYYYMMDDHH
    - day: YYYYMMDD
    - month: YYYYMM
    """

    # Per-metric converters for values read back from Redis hashes.
    _METRIC_TYPES: dict[str, type] = {
        "prompt_tokens": int,
        "completion_tokens": int,
        "cost_usd": float,
        "requests": int,
    }

    def __init__(
        self,
        redis_client: redis.Redis | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize cost tracker.

        Args:
            redis_client: Redis client; if None, one is created lazily from
                settings.redis_url on first use
            settings: Application settings (defaults to get_settings())
        """
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = redis_client
        self._prefix = self._settings.redis_prefix

    async def _get_redis(self) -> redis.Redis:
        """Get the Redis client, creating it from settings on first use."""
        if self._redis is None:
            # decode_responses=True so hash fields come back as str, which the
            # report parsers below rely on.
            self._redis = redis.from_url(
                self._settings.redis_url,
                decode_responses=True,
            )
        return self._redis

    async def close(self) -> None:
        """Close the Redis connection, if any, and drop the client reference."""
        if self._redis is not None:
            try:
                await self._redis.aclose()
            finally:
                # Forget the client even if closing failed so a later call
                # does not try to reuse a half-closed connection.
                self._redis = None

    @staticmethod
    def _to_utc(timestamp: datetime | None) -> datetime:
        """
        Normalise a query timestamp to the UTC frame used for bucket keys.

        None means "now" (UTC). Aware timestamps are converted to UTC; naive
        timestamps are assumed to already be UTC and returned unchanged.
        """
        if timestamp is None:
            return datetime.now(UTC)
        if timestamp.tzinfo is not None:
            return timestamp.astimezone(UTC)
        return timestamp

    @staticmethod
    def _normalize_period(period: str) -> str:
        """Return period if it has its own bucket, otherwise "day"."""
        return period if period in ("hour", "day", "month") else "day"

    def _get_date_keys(self, timestamp: datetime | None = None) -> dict[str, str]:
        """
        Get date-bucket keys for every period.

        record_usage() writes buckets from datetime.now(UTC), so the timestamp
        is normalised to UTC first (see _to_utc) to hit the same buckets.
        """
        timestamp = self._to_utc(timestamp)
        return {
            "hour": timestamp.strftime("%Y%m%d%H"),
            "day": timestamp.strftime("%Y%m%d"),
            "month": timestamp.strftime("%Y%m"),
        }

    def _get_ttl_seconds(self, period: str) -> int:
        """Get TTL in seconds for a period (30 days for unknown periods)."""
        ttls = {
            "hour": 24 * 3600,  # 24 hours
            "day": 30 * 24 * 3600,  # 30 days
            "month": 365 * 24 * 3600,  # 1 year
        }
        return ttls.get(period, 30 * 24 * 3600)

    async def record_usage(
        self,
        project_id: str,
        agent_id: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost_usd: float,
        session_id: str | None = None,
        request_id: str | None = None,  # noqa: ARG002 - reserved for future logging
    ) -> None:
        """
        Record LLM usage.

        Increments project, agent and request-count buckets for the current
        hour, day and month in one pipeline, plus the session hash when a
        session_id is given. Does nothing if cost tracking is disabled.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            model: Model name
            prompt_tokens: Input tokens
            completion_tokens: Output tokens
            cost_usd: Cost in USD
            session_id: Optional session ID
            request_id: Optional request ID (currently unused)
        """
        if not self._settings.cost_tracking_enabled:
            return

        r = await self._get_redis()
        pipe = r.pipeline()

        for period, date_key in self._get_date_keys().items():
            ttl = self._get_ttl_seconds(period)

            # Project-level tracking
            project_key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
            await self._increment_usage(
                pipe, project_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(project_key, ttl)

            # Agent-level tracking
            agent_key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
            await self._increment_usage(
                pipe, agent_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(agent_key, ttl)

            # Request counter
            requests_key = f"{self._prefix}:requests:{project_id}:{date_key}"
            pipe.incr(requests_key)
            pipe.expire(requests_key, ttl)

        # Session tracking (if session_id provided)
        if session_id:
            session_key = f"{self._prefix}:cost:session:{session_id}"
            await self._increment_usage(
                pipe, session_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(session_key, 24 * 3600)  # 24 hour TTL for sessions

        await pipe.execute()

        # Lazy %-formatting: this runs on every request, and the message is
        # only built if DEBUG is enabled.
        logger.debug(
            "Recorded usage: project=%s, agent=%s, model=%s, tokens=%s, cost=$%.6f",
            project_id,
            agent_id,
            model,
            prompt_tokens + completion_tokens,
            cost_usd,
        )

    async def _increment_usage(
        self,
        pipe: redis.client.Pipeline,
        key: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost_usd: float,
    ) -> None:
        """
        Queue per-model and total increments for one hash on the pipeline.

        Only queues commands; nothing is sent until the caller executes the
        pipeline.
        """
        # Per-model fields: "{model}:{metric}"
        pipe.hincrby(key, f"{model}:prompt_tokens", prompt_tokens)
        pipe.hincrby(key, f"{model}:completion_tokens", completion_tokens)
        pipe.hincrbyfloat(key, f"{model}:cost_usd", cost_usd)
        pipe.hincrby(key, f"{model}:requests", 1)

        # Totals across all models: "total:{metric}"
        pipe.hincrby(key, "total:prompt_tokens", prompt_tokens)
        pipe.hincrby(key, "total:completion_tokens", completion_tokens)
        pipe.hincrbyfloat(key, "total:cost_usd", cost_usd)
        pipe.hincrby(key, "total:requests", 1)

    async def get_project_usage(
        self,
        project_id: str,
        period: str = "day",
        timestamp: datetime | None = None,
    ) -> UsageReport:
        """
        Get usage report for a project.

        Args:
            project_id: Project ID
            period: Time period (hour, day, month); anything else is treated
                as "day" for both the bucket and the reported bounds
            timestamp: Specific time to query (defaults to now)

        Returns:
            Usage report
        """
        period = self._normalize_period(period)
        date_key = self._get_date_keys(timestamp)[period]
        key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
        return await self._get_usage_report(
            key, project_id, "project", period, timestamp
        )

    async def get_agent_usage(
        self,
        agent_id: str,
        period: str = "day",
        timestamp: datetime | None = None,
    ) -> UsageReport:
        """
        Get usage report for an agent.

        Args:
            agent_id: Agent ID
            period: Time period (hour, day, month); anything else is treated
                as "day" for both the bucket and the reported bounds
            timestamp: Specific time to query (defaults to now)

        Returns:
            Usage report
        """
        period = self._normalize_period(period)
        date_key = self._get_date_keys(timestamp)[period]
        key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
        return await self._get_usage_report(key, agent_id, "agent", period, timestamp)

    @staticmethod
    def _split_field(field: str) -> tuple[str, str] | None:
        """
        Split a hash field "{model}:{metric}" into (model, metric).

        Splits on the LAST colon so model names that themselves contain
        colons (e.g. "llama3:8b") stay intact. Returns None if there is no
        colon at all.
        """
        model, sep, metric = field.rpartition(":")
        if not sep:
            return None
        return model, metric

    @staticmethod
    def _period_bounds(period: str, now: datetime) -> tuple[datetime, datetime]:
        """
        Return the [start, end) window of the given period containing now.

        "hour" and "month" get their own windows; "day" and any other value
        get a day window, matching the day-bucket fallback used for keys.
        """
        if period == "hour":
            start = now.replace(minute=0, second=0, microsecond=0)
            return start, start + timedelta(hours=1)
        if period == "month":
            start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            if start.month == 12:
                end = start.replace(year=start.year + 1, month=1)
            else:
                end = start.replace(month=start.month + 1)
            return start, end
        start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        return start, start + timedelta(days=1)

    async def _get_usage_report(
        self,
        key: str,
        entity_id: str,
        entity_type: str,
        period: str,
        timestamp: datetime | None,
    ) -> UsageReport:
        """
        Build a UsageReport from one usage hash.

        total_tokens is prompt + completion tokens from the "total:" fields.
        by_model holds per-model prompt_tokens, completion_tokens, cost_usd
        and requests, defaulting to zero for metrics missing from the hash.
        """
        r = await self._get_redis()
        data = await r.hgetall(key)  # type: ignore[misc]

        by_model: dict[str, dict[str, Any]] = {}
        total_requests = 0
        total_tokens = 0
        total_cost = 0.0

        for field, value in data.items():
            parsed = self._split_field(field)
            if parsed is None:
                continue
            model, metric = parsed

            if model == "total":
                if metric == "requests":
                    total_requests = int(value)
                elif metric in ("prompt_tokens", "completion_tokens"):
                    total_tokens += int(value)
                elif metric == "cost_usd":
                    total_cost = float(value)
                continue

            stats = by_model.setdefault(
                model,
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "cost_usd": 0.0,
                    "requests": 0,
                },
            )
            convert = self._METRIC_TYPES.get(metric)
            if convert is not None:
                stats[metric] = convert(value)

        # Same UTC normalisation as the bucket lookup, so bounds match data.
        period_start, period_end = self._period_bounds(
            period, self._to_utc(timestamp)
        )

        return UsageReport(
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            period_start=period_start,
            period_end=period_end,
            total_requests=total_requests,
            total_tokens=total_tokens,
            total_cost_usd=round(total_cost, 6),
            by_model=by_model,
        )

    async def get_session_usage(self, session_id: str) -> dict[str, Any]:
        """
        Get usage for a specific session.

        Args:
            session_id: Session ID

        Returns:
            Dict with session_id, total_tokens, total_cost_usd and by_model.
            by_model only contains the metrics actually present in Redis.
        """
        r = await self._get_redis()
        key = f"{self._prefix}:cost:session:{session_id}"
        data = await r.hgetall(key)  # type: ignore[misc]

        result: dict[str, Any] = {
            "session_id": session_id,
            "total_tokens": 0,
            "total_cost_usd": 0.0,
            "by_model": {},
        }
        by_model: dict[str, dict[str, Any]] = result["by_model"]

        for field, value in data.items():
            parsed = self._split_field(field)
            if parsed is None:
                continue
            model, metric = parsed

            if model == "total":
                if metric in ("prompt_tokens", "completion_tokens"):
                    result["total_tokens"] += int(value)
                elif metric == "cost_usd":
                    result["total_cost_usd"] = float(value)
                continue

            stats = by_model.setdefault(model, {})
            convert = self._METRIC_TYPES.get(metric)
            if convert is not None:
                stats[metric] = convert(value)

        return result

    async def check_budget(
        self,
        project_id: str,
        budget_limit: float | None = None,
    ) -> tuple[bool, float, float]:
        """
        Check if a project's current-month spend is below its budget.

        Args:
            project_id: Project ID
            budget_limit: Budget limit in USD; None uses
                settings.default_budget_limit (an explicit 0.0 is honoured)

        Returns:
            Tuple of (within_budget, current_cost, limit), where within_budget
            is current_cost < limit
        """
        limit = (
            self._settings.default_budget_limit
            if budget_limit is None
            else budget_limit
        )

        # Budgets are evaluated against the current month's bucket.
        report = await self.get_project_usage(project_id, period="month")
        current_cost = report.total_cost_usd

        return current_cost < limit, current_cost, limit

    async def estimate_request_cost(
        self,
        model: str,
        prompt_tokens: int,
        max_completion_tokens: int,
    ) -> float:
        """
        Estimate cost for a request before it is sent.

        Args:
            model: Model name
            prompt_tokens: Input token count
            max_completion_tokens: Maximum output tokens

        Returns:
            Estimated cost in USD. Known models use MODEL_CONFIGS pricing
            (rounded to 6 decimals); unknown models use a flat
            0.00001 USD per token.
        """
        config = MODEL_CONFIGS.get(model)
        if not config:
            # Use a default estimate
            return (prompt_tokens + max_completion_tokens) * 0.00001

        input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
        output_cost = (max_completion_tokens / 1_000_000) * config.cost_per_1m_output
        return round(input_cost + output_cost, 6)

    async def should_alert(
        self,
        project_id: str,
        threshold: float | None = None,
    ) -> tuple[bool, float]:
        """
        Check if a cost alert should be triggered for today's spend.

        Args:
            project_id: Project ID
            threshold: Alert threshold in USD; None uses
                settings.cost_alert_threshold (an explicit 0.0 is honoured)

        Returns:
            Tuple of (should_alert, current_cost), where should_alert is
            current_cost >= threshold
        """
        thresh = (
            self._settings.cost_alert_threshold if threshold is None else threshold
        )
        # Alerts are evaluated against the current day's bucket.
        report = await self.get_project_usage(project_id, period="day")
        current_cost = report.total_cost_usd
        return current_cost >= thresh, current_cost
def calculate_cost(
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float:
    """
    Compute the USD cost of a completed request from its token counts.

    Pricing comes from MODEL_CONFIGS (per-million-token input and output
    rates). Unknown models are logged as a warning and priced at 0.0.

    Args:
        model: Model name, used as the MODEL_CONFIGS key
        prompt_tokens: Input tokens consumed
        completion_tokens: Output tokens generated

    Returns:
        Cost in USD rounded to 6 decimal places (0.0 for unknown models)
    """
    pricing = MODEL_CONFIGS.get(model)
    if pricing:
        per_million = 1_000_000
        input_usd = prompt_tokens / per_million * pricing.cost_per_1m_input
        output_usd = completion_tokens / per_million * pricing.cost_per_1m_output
        return round(input_usd + output_usd, 6)

    logger.warning(f"Unknown model {model} for cost calculation")
    return 0.0
# Process-wide CostTracker singleton: created lazily by get_cost_tracker()
# and cleared by close_cost_tracker() / reset_cost_tracker().
# _tracker_lock is a threading.Lock (not an asyncio.Lock). It serializes
# access across threads. Holding it across an `await` would block the
# event-loop thread for any other caller on that thread that tries to take it.
_tracker: CostTracker | None = None
_tracker_lock = threading.Lock()
def get_cost_tracker() -> CostTracker:
    """
    Get the global cost tracker instance, creating it on first use.

    Uses double-checked locking. The unlocked fast path skips _tracker_lock
    once a tracker exists, and the re-check under the lock ensures only one
    thread constructs it. The global is copied into a local on each path, so
    a concurrent reset_cost_tracker() / close_cost_tracker() cannot make this
    function return None.
    """
    global _tracker
    tracker = _tracker
    if tracker is None:
        with _tracker_lock:
            # Double-check after acquiring the lock: another thread may have
            # created the tracker while we were waiting.
            if _tracker is None:
                _tracker = CostTracker()
            tracker = _tracker
    return tracker
async def close_cost_tracker() -> None:
    """
    Close and discard the global cost tracker, if one exists.

    The tracker is detached from the global while holding _tracker_lock, and
    the lock is released BEFORE awaiting close(). _tracker_lock is a
    threading.Lock, so holding it across an await would deadlock the event
    loop if another coroutine on the same thread (e.g. reset_cost_tracker()
    or a second close) tried to acquire it while close() was suspended.
    Once detached, a concurrent get_cost_tracker() creates a fresh tracker
    instead of returning the one being closed.
    """
    global _tracker
    with _tracker_lock:
        tracker, _tracker = _tracker, None
    if tracker is not None:
        await tracker.close()
def reset_cost_tracker() -> None:
    """
    Forget the global cost tracker (intended for tests).

    The current tracker, if any, is dropped WITHOUT calling its close(), so
    its Redis client is not closed here. Use close_cost_tracker() for that.
    """
    global _tracker
    # Take the same lock as get_cost_tracker()'s creation path so the reset
    # is serialized with lazy initialisation in other threads.
    with _tracker_lock:
        _tracker = None