feat(llm-gateway): implement LLM Gateway MCP Server (#56)
Implements complete LLM Gateway MCP Server with: - FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens - LiteLLM Router with multi-provider failover chains - Circuit breaker pattern for fault tolerance - Redis-based cost tracking per project/agent - Comprehensive test suite (209 tests, 92% coverage) Model groups defined per ADR-004: - reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro - code: claude-sonnet-4 → gpt-4.1 → deepseek-coder - fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
467
mcp-servers/llm-gateway/cost_tracking.py
Normal file
467
mcp-servers/llm-gateway/cost_tracking.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Cost tracking for LLM Gateway.
|
||||
|
||||
Tracks LLM usage costs per project and agent using Redis.
|
||||
Provides aggregation by hour, day, and month with TTL-based expiry.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from config import Settings, get_settings
|
||||
from models import (
|
||||
MODEL_CONFIGS,
|
||||
UsageReport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CostTracker:
|
||||
"""
|
||||
Redis-based cost tracker for LLM usage.
|
||||
|
||||
Key structure:
|
||||
- {prefix}:cost:project:{project_id}:{date} -> Hash of usage by model
|
||||
- {prefix}:cost:agent:{agent_id}:{date} -> Hash of usage by model
|
||||
- {prefix}:cost:session:{session_id} -> Hash of session usage
|
||||
- {prefix}:requests:{project_id}:{date} -> Request count
|
||||
|
||||
Date formats:
|
||||
- hour: YYYYMMDDHH
|
||||
- day: YYYYMMDD
|
||||
- month: YYYYMM
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: redis.Redis | None = None,
|
||||
settings: Settings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize cost tracker.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client (creates one if None)
|
||||
settings: Application settings
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self._redis: redis.Redis | None = redis_client
|
||||
self._prefix = self._settings.redis_prefix
|
||||
|
||||
async def _get_redis(self) -> redis.Redis:
|
||||
"""Get or create Redis client."""
|
||||
if self._redis is None:
|
||||
self._redis = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
if self._redis:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
def _get_date_keys(self, timestamp: datetime | None = None) -> dict[str, str]:
|
||||
"""Get date format keys for different periods."""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now(UTC)
|
||||
return {
|
||||
"hour": timestamp.strftime("%Y%m%d%H"),
|
||||
"day": timestamp.strftime("%Y%m%d"),
|
||||
"month": timestamp.strftime("%Y%m"),
|
||||
}
|
||||
|
||||
def _get_ttl_seconds(self, period: str) -> int:
|
||||
"""Get TTL in seconds for a period."""
|
||||
ttls = {
|
||||
"hour": 24 * 3600, # 24 hours
|
||||
"day": 30 * 24 * 3600, # 30 days
|
||||
"month": 365 * 24 * 3600, # 1 year
|
||||
}
|
||||
return ttls.get(period, 30 * 24 * 3600)
|
||||
|
||||
async def record_usage(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cost_usd: float,
|
||||
session_id: str | None = None,
|
||||
request_id: str | None = None, # noqa: ARG002 - reserved for future logging
|
||||
) -> None:
|
||||
"""
|
||||
Record LLM usage.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
agent_id: Agent ID
|
||||
model: Model name
|
||||
prompt_tokens: Input tokens
|
||||
completion_tokens: Output tokens
|
||||
cost_usd: Cost in USD
|
||||
session_id: Optional session ID
|
||||
request_id: Optional request ID
|
||||
"""
|
||||
if not self._settings.cost_tracking_enabled:
|
||||
return
|
||||
|
||||
r = await self._get_redis()
|
||||
date_keys = self._get_date_keys()
|
||||
pipe = r.pipeline()
|
||||
|
||||
# Record for each time period
|
||||
for period, date_key in date_keys.items():
|
||||
# Project-level tracking
|
||||
project_key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
|
||||
await self._increment_usage(
|
||||
pipe, project_key, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
pipe.expire(project_key, self._get_ttl_seconds(period))
|
||||
|
||||
# Agent-level tracking
|
||||
agent_key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
|
||||
await self._increment_usage(
|
||||
pipe, agent_key, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
pipe.expire(agent_key, self._get_ttl_seconds(period))
|
||||
|
||||
# Request counter
|
||||
requests_key = f"{self._prefix}:requests:{project_id}:{date_key}"
|
||||
pipe.incr(requests_key)
|
||||
pipe.expire(requests_key, self._get_ttl_seconds(period))
|
||||
|
||||
# Session tracking (if session_id provided)
|
||||
if session_id:
|
||||
session_key = f"{self._prefix}:cost:session:{session_id}"
|
||||
await self._increment_usage(
|
||||
pipe, session_key, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
pipe.expire(session_key, 24 * 3600) # 24 hour TTL for sessions
|
||||
|
||||
await pipe.execute()
|
||||
|
||||
logger.debug(
|
||||
f"Recorded usage: project={project_id}, agent={agent_id}, "
|
||||
f"model={model}, tokens={prompt_tokens + completion_tokens}, "
|
||||
f"cost=${cost_usd:.6f}"
|
||||
)
|
||||
|
||||
async def _increment_usage(
|
||||
self,
|
||||
pipe: redis.client.Pipeline,
|
||||
key: str,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cost_usd: float,
|
||||
) -> None:
|
||||
"""Increment usage in a hash."""
|
||||
# Store as JSON fields within the hash
|
||||
pipe.hincrby(key, f"{model}:prompt_tokens", prompt_tokens)
|
||||
pipe.hincrby(key, f"{model}:completion_tokens", completion_tokens)
|
||||
pipe.hincrbyfloat(key, f"{model}:cost_usd", cost_usd)
|
||||
pipe.hincrby(key, f"{model}:requests", 1)
|
||||
# Totals
|
||||
pipe.hincrby(key, "total:prompt_tokens", prompt_tokens)
|
||||
pipe.hincrby(key, "total:completion_tokens", completion_tokens)
|
||||
pipe.hincrbyfloat(key, "total:cost_usd", cost_usd)
|
||||
pipe.hincrby(key, "total:requests", 1)
|
||||
|
||||
async def get_project_usage(
|
||||
self,
|
||||
project_id: str,
|
||||
period: str = "day",
|
||||
timestamp: datetime | None = None,
|
||||
) -> UsageReport:
|
||||
"""
|
||||
Get usage report for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
period: Time period (hour, day, month)
|
||||
timestamp: Specific time to query (defaults to now)
|
||||
|
||||
Returns:
|
||||
Usage report
|
||||
"""
|
||||
date_keys = self._get_date_keys(timestamp)
|
||||
date_key = date_keys.get(period, date_keys["day"])
|
||||
|
||||
key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
|
||||
return await self._get_usage_report(
|
||||
key, project_id, "project", period, timestamp
|
||||
)
|
||||
|
||||
async def get_agent_usage(
|
||||
self,
|
||||
agent_id: str,
|
||||
period: str = "day",
|
||||
timestamp: datetime | None = None,
|
||||
) -> UsageReport:
|
||||
"""
|
||||
Get usage report for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
period: Time period (hour, day, month)
|
||||
timestamp: Specific time to query (defaults to now)
|
||||
|
||||
Returns:
|
||||
Usage report
|
||||
"""
|
||||
date_keys = self._get_date_keys(timestamp)
|
||||
date_key = date_keys.get(period, date_keys["day"])
|
||||
|
||||
key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
|
||||
return await self._get_usage_report(key, agent_id, "agent", period, timestamp)
|
||||
|
||||
async def _get_usage_report(
|
||||
self,
|
||||
key: str,
|
||||
entity_id: str,
|
||||
entity_type: str,
|
||||
period: str,
|
||||
timestamp: datetime | None,
|
||||
) -> UsageReport:
|
||||
"""Get usage report from a Redis hash."""
|
||||
r = await self._get_redis()
|
||||
data = await r.hgetall(key)
|
||||
|
||||
# Parse the hash data
|
||||
by_model: dict[str, dict[str, Any]] = {}
|
||||
total_requests = 0
|
||||
total_tokens = 0
|
||||
total_cost = 0.0
|
||||
|
||||
for field, value in data.items():
|
||||
parts = field.split(":")
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
model, metric = parts
|
||||
if model == "total":
|
||||
if metric == "requests":
|
||||
total_requests = int(value)
|
||||
elif metric == "prompt_tokens" or metric == "completion_tokens":
|
||||
total_tokens += int(value)
|
||||
elif metric == "cost_usd":
|
||||
total_cost = float(value)
|
||||
else:
|
||||
if model not in by_model:
|
||||
by_model[model] = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"cost_usd": 0.0,
|
||||
"requests": 0,
|
||||
}
|
||||
if metric == "prompt_tokens":
|
||||
by_model[model]["prompt_tokens"] = int(value)
|
||||
elif metric == "completion_tokens":
|
||||
by_model[model]["completion_tokens"] = int(value)
|
||||
elif metric == "cost_usd":
|
||||
by_model[model]["cost_usd"] = float(value)
|
||||
elif metric == "requests":
|
||||
by_model[model]["requests"] = int(value)
|
||||
|
||||
# Calculate period boundaries
|
||||
now = timestamp or datetime.now(UTC)
|
||||
if period == "hour":
|
||||
period_start = now.replace(minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + timedelta(hours=1)
|
||||
elif period == "day":
|
||||
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + timedelta(days=1)
|
||||
else: # month
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
# Next month
|
||||
if now.month == 12:
|
||||
period_end = period_start.replace(year=now.year + 1, month=1)
|
||||
else:
|
||||
period_end = period_start.replace(month=now.month + 1)
|
||||
|
||||
return UsageReport(
|
||||
entity_id=entity_id,
|
||||
entity_type=entity_type,
|
||||
period=period,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
total_requests=total_requests,
|
||||
total_tokens=total_tokens,
|
||||
total_cost_usd=round(total_cost, 6),
|
||||
by_model=by_model,
|
||||
)
|
||||
|
||||
async def get_session_usage(self, session_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get usage for a specific session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Session usage data
|
||||
"""
|
||||
r = await self._get_redis()
|
||||
key = f"{self._prefix}:cost:session:{session_id}"
|
||||
data = await r.hgetall(key)
|
||||
|
||||
# Parse similar to _get_usage_report
|
||||
result: dict[str, Any] = {
|
||||
"session_id": session_id,
|
||||
"total_tokens": 0,
|
||||
"total_cost_usd": 0.0,
|
||||
"by_model": {},
|
||||
}
|
||||
|
||||
for field, value in data.items():
|
||||
parts = field.split(":")
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
model, metric = parts
|
||||
if model == "total":
|
||||
if metric == "prompt_tokens" or metric == "completion_tokens":
|
||||
result["total_tokens"] += int(value)
|
||||
elif metric == "cost_usd":
|
||||
result["total_cost_usd"] = float(value)
|
||||
else:
|
||||
if model not in result["by_model"]:
|
||||
result["by_model"][model] = {}
|
||||
if metric in ("prompt_tokens", "completion_tokens", "requests"):
|
||||
result["by_model"][model][metric] = int(value)
|
||||
elif metric == "cost_usd":
|
||||
result["by_model"][model][metric] = float(value)
|
||||
|
||||
return result
|
||||
|
||||
async def check_budget(
|
||||
self,
|
||||
project_id: str,
|
||||
budget_limit: float | None = None,
|
||||
) -> tuple[bool, float, float]:
|
||||
"""
|
||||
Check if project is within budget.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
budget_limit: Budget limit (uses default if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (within_budget, current_cost, limit)
|
||||
"""
|
||||
limit = budget_limit or self._settings.default_budget_limit
|
||||
|
||||
# Get current month usage
|
||||
report = await self.get_project_usage(project_id, period="month")
|
||||
current_cost = report.total_cost_usd
|
||||
|
||||
within_budget = current_cost < limit
|
||||
return within_budget, current_cost, limit
|
||||
|
||||
async def estimate_request_cost(
|
||||
self,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
max_completion_tokens: int,
|
||||
) -> float:
|
||||
"""
|
||||
Estimate cost for a request.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
prompt_tokens: Input token count
|
||||
max_completion_tokens: Maximum output tokens
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD
|
||||
"""
|
||||
config = MODEL_CONFIGS.get(model)
|
||||
if not config:
|
||||
# Use a default estimate
|
||||
return (prompt_tokens + max_completion_tokens) * 0.00001
|
||||
|
||||
input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
|
||||
output_cost = (max_completion_tokens / 1_000_000) * config.cost_per_1m_output
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
async def should_alert(
|
||||
self,
|
||||
project_id: str,
|
||||
threshold: float | None = None,
|
||||
) -> tuple[bool, float]:
|
||||
"""
|
||||
Check if cost alert should be triggered.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
threshold: Alert threshold (uses default if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (should_alert, current_cost)
|
||||
"""
|
||||
thresh = threshold or self._settings.cost_alert_threshold
|
||||
|
||||
report = await self.get_project_usage(project_id, period="day")
|
||||
current_cost = report.total_cost_usd
|
||||
|
||||
return current_cost >= thresh, current_cost
|
||||
|
||||
|
||||
def calculate_cost(
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate cost for a completion.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
prompt_tokens: Input tokens
|
||||
completion_tokens: Output tokens
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
config = MODEL_CONFIGS.get(model)
|
||||
if not config:
|
||||
logger.warning(f"Unknown model {model} for cost calculation")
|
||||
return 0.0
|
||||
|
||||
input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
|
||||
output_cost = (completion_tokens / 1_000_000) * config.cost_per_1m_output
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
|
||||
# Global tracker instance (lazy initialization)
|
||||
_tracker: CostTracker | None = None
|
||||
|
||||
|
||||
def get_cost_tracker() -> CostTracker:
|
||||
"""Get the global cost tracker instance."""
|
||||
global _tracker
|
||||
if _tracker is None:
|
||||
_tracker = CostTracker()
|
||||
return _tracker
|
||||
|
||||
|
||||
async def close_cost_tracker() -> None:
|
||||
"""Close the global cost tracker."""
|
||||
global _tracker
|
||||
if _tracker:
|
||||
await _tracker.close()
|
||||
_tracker = None
|
||||
|
||||
|
||||
def reset_cost_tracker() -> None:
|
||||
"""Reset the global tracker (for testing)."""
|
||||
global _tracker
|
||||
_tracker = None
|
||||
Reference in New Issue
Block a user