"""
Cost tracking for LLM Gateway.

Tracks LLM usage costs per project and agent using Redis.
Provides aggregation by hour, day, and month with TTL-based expiry.
"""

import logging
from datetime import UTC, datetime, timedelta
from typing import Any

import redis.asyncio as redis

from config import Settings, get_settings
from models import (
    MODEL_CONFIGS,
    UsageReport,
)

logger = logging.getLogger(__name__)


class CostTracker:
    """
    Redis-based cost tracker for LLM usage.

    Key structure:
    - {prefix}:cost:project:{project_id}:{date} -> Hash of usage by model
    - {prefix}:cost:agent:{agent_id}:{date} -> Hash of usage by model
    - {prefix}:cost:session:{session_id} -> Hash of session usage
    - {prefix}:requests:{project_id}:{date} -> Request count

    Hash fields are ``"{model}:{metric}"`` plus ``"total:{metric}"`` where
    metric is one of prompt_tokens, completion_tokens, cost_usd, requests.

    Date formats:
    - hour: YYYYMMDDHH
    - day: YYYYMMDD
    - month: YYYYMM
    """

    def __init__(
        self,
        redis_client: redis.Redis | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize cost tracker.

        Args:
            redis_client: Redis client (created lazily from settings if None)
            settings: Application settings (defaults to ``get_settings()``)
        """
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = redis_client
        self._prefix = self._settings.redis_prefix

    async def _get_redis(self) -> redis.Redis:
        """Get or lazily create the Redis client (string-decoding responses)."""
        if self._redis is None:
            self._redis = redis.from_url(
                self._settings.redis_url,
                decode_responses=True,
            )
        return self._redis

    async def close(self) -> None:
        """Close the Redis connection, if one is held."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    def _get_date_keys(self, timestamp: datetime | None = None) -> dict[str, str]:
        """
        Get date-format key suffixes for every aggregation period.

        Args:
            timestamp: Time to format (defaults to the current UTC time)

        Returns:
            Mapping of period name ("hour", "day", "month") to key suffix.
        """
        if timestamp is None:
            timestamp = datetime.now(UTC)

        return {
            "hour": timestamp.strftime("%Y%m%d%H"),
            "day": timestamp.strftime("%Y%m%d"),
            "month": timestamp.strftime("%Y%m"),
        }

    def _get_ttl_seconds(self, period: str) -> int:
        """Get the key TTL in seconds for a period (30 days if unknown)."""
        ttls = {
            "hour": 24 * 3600,  # 24 hours
            "day": 30 * 24 * 3600,  # 30 days
            "month": 365 * 24 * 3600,  # 1 year
        }
        return ttls.get(period, 30 * 24 * 3600)

    @staticmethod
    def _split_field(field: str | bytes) -> tuple[str, str] | None:
        """
        Split a hash field ``"{model}:{metric}"`` into ``(model, metric)``.

        The split is on the *last* colon so model names that themselves
        contain colons (e.g. ``"llama3:8b"``) are kept intact. Bytes fields
        (from a client created without ``decode_responses``) are decoded.

        Returns:
            ``(model, metric)``, or None if the field has no usable prefix.
        """
        if isinstance(field, bytes):
            field = field.decode()
        model, sep, metric = field.rpartition(":")
        if not sep or not model:
            return None
        return model, metric

    @staticmethod
    def _period_bounds(period: str, now: datetime) -> tuple[datetime, datetime]:
        """
        Compute the ``[start, end)`` window of the period containing ``now``.

        Any period other than "hour" or "day" is treated as "month".
        The tzinfo of ``now`` is preserved.
        """
        if period == "hour":
            start = now.replace(minute=0, second=0, microsecond=0)
            return start, start + timedelta(hours=1)
        if period == "day":
            start = now.replace(hour=0, minute=0, second=0, microsecond=0)
            return start, start + timedelta(days=1)

        start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # Safe to bump the month directly because start.day is always 1.
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
        return start, end

    async def record_usage(
        self,
        project_id: str,
        agent_id: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost_usd: float,
        session_id: str | None = None,
        request_id: str | None = None,  # noqa: ARG002 - reserved for future logging
    ) -> None:
        """
        Record LLM usage.

        For each period (hour, day, month) this increments the project hash,
        the agent hash and the project request counter, refreshing each key's
        TTL. If ``session_id`` is given, the session hash is updated too (24h
        TTL). All writes go through one pipeline. Does nothing when cost
        tracking is disabled in settings.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            model: Model name
            prompt_tokens: Input tokens
            completion_tokens: Output tokens
            cost_usd: Cost in USD
            session_id: Optional session ID
            request_id: Optional request ID (currently unused)
        """
        if not self._settings.cost_tracking_enabled:
            return

        r = await self._get_redis()
        date_keys = self._get_date_keys()

        pipe = r.pipeline()

        # Record for each time period
        for period, date_key in date_keys.items():
            ttl = self._get_ttl_seconds(period)

            # Project-level tracking
            project_key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
            await self._increment_usage(
                pipe, project_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(project_key, ttl)

            # Agent-level tracking
            agent_key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
            await self._increment_usage(
                pipe, agent_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(agent_key, ttl)

            # Request counter
            requests_key = f"{self._prefix}:requests:{project_id}:{date_key}"
            pipe.incr(requests_key)
            pipe.expire(requests_key, ttl)

        # Session tracking (if session_id provided)
        if session_id:
            session_key = f"{self._prefix}:cost:session:{session_id}"
            await self._increment_usage(
                pipe, session_key, model, prompt_tokens, completion_tokens, cost_usd
            )
            pipe.expire(session_key, 24 * 3600)  # 24 hour TTL for sessions

        await pipe.execute()

        logger.debug(
            "Recorded usage: project=%s, agent=%s, model=%s, tokens=%d, cost=$%.6f",
            project_id,
            agent_id,
            model,
            prompt_tokens + completion_tokens,
            cost_usd,
        )

    async def _increment_usage(
        self,
        pipe: redis.client.Pipeline,
        key: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        cost_usd: float,
    ) -> None:
        """Queue per-model and total increments for one usage hash on ``pipe``."""
        # Plain hash fields named "{model}:{metric}" (not JSON).
        pipe.hincrby(key, f"{model}:prompt_tokens", prompt_tokens)
        pipe.hincrby(key, f"{model}:completion_tokens", completion_tokens)
        pipe.hincrbyfloat(key, f"{model}:cost_usd", cost_usd)
        pipe.hincrby(key, f"{model}:requests", 1)

        # Totals
        pipe.hincrby(key, "total:prompt_tokens", prompt_tokens)
        pipe.hincrby(key, "total:completion_tokens", completion_tokens)
        pipe.hincrbyfloat(key, "total:cost_usd", cost_usd)
        pipe.hincrby(key, "total:requests", 1)

    async def _get_entity_usage(
        self,
        entity_type: str,
        entity_id: str,
        period: str,
        timestamp: datetime | None,
    ) -> UsageReport:
        """Shared lookup for project/agent usage reports."""
        # Resolve "now" once so the Redis key and the reported boundaries can
        # never straddle an hour/day/month rollover.
        if timestamp is None:
            timestamp = datetime.now(UTC)
        date_keys = self._get_date_keys(timestamp)
        if period not in date_keys:
            # Keep the key, label and boundaries consistent for bad input.
            logger.warning("Unknown usage period %r; falling back to 'day'", period)
            period = "day"
        key = f"{self._prefix}:cost:{entity_type}:{entity_id}:{date_keys[period]}"
        return await self._get_usage_report(
            key, entity_id, entity_type, period, timestamp
        )

    async def get_project_usage(
        self,
        project_id: str,
        period: str = "day",
        timestamp: datetime | None = None,
    ) -> UsageReport:
        """
        Get usage report for a project.

        Args:
            project_id: Project ID
            period: Time period (hour, day, month); unknown values use "day"
            timestamp: Specific time to query (defaults to now, UTC)

        Returns:
            Usage report
        """
        return await self._get_entity_usage("project", project_id, period, timestamp)

    async def get_agent_usage(
        self,
        agent_id: str,
        period: str = "day",
        timestamp: datetime | None = None,
    ) -> UsageReport:
        """
        Get usage report for an agent.

        Args:
            agent_id: Agent ID
            period: Time period (hour, day, month); unknown values use "day"
            timestamp: Specific time to query (defaults to now, UTC)

        Returns:
            Usage report
        """
        return await self._get_entity_usage("agent", agent_id, period, timestamp)

    async def _get_usage_report(
        self,
        key: str,
        entity_id: str,
        entity_type: str,
        period: str,
        timestamp: datetime | None,
    ) -> UsageReport:
        """Build a usage report from the Redis hash stored at ``key``."""
        r = await self._get_redis()
        data = await r.hgetall(key)  # type: ignore[misc]

        by_model: dict[str, dict[str, Any]] = {}
        total_requests = 0
        total_tokens = 0
        total_cost = 0.0

        for field, value in data.items():
            parsed = self._split_field(field)
            if parsed is None:
                continue
            model, metric = parsed

            if model == "total":
                if metric == "requests":
                    total_requests = int(value)
                elif metric in ("prompt_tokens", "completion_tokens"):
                    total_tokens += int(value)
                elif metric == "cost_usd":
                    total_cost = float(value)
                continue

            stats = by_model.setdefault(
                model,
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "cost_usd": 0.0,
                    "requests": 0,
                },
            )
            if metric == "cost_usd":
                stats["cost_usd"] = float(value)
            elif metric in ("prompt_tokens", "completion_tokens", "requests"):
                stats[metric] = int(value)

        now = timestamp if timestamp is not None else datetime.now(UTC)
        period_start, period_end = self._period_bounds(period, now)

        return UsageReport(
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            period_start=period_start,
            period_end=period_end,
            total_requests=total_requests,
            total_tokens=total_tokens,
            total_cost_usd=round(total_cost, 6),
            by_model=by_model,
        )

    async def get_session_usage(self, session_id: str) -> dict[str, Any]:
        """
        Get usage for a specific session.

        Args:
            session_id: Session ID

        Returns:
            Dict with ``session_id``, ``total_tokens``, ``total_cost_usd`` and
            ``by_model`` (per-model dicts containing only the metrics present).
        """
        r = await self._get_redis()
        key = f"{self._prefix}:cost:session:{session_id}"
        data = await r.hgetall(key)  # type: ignore[misc]

        result: dict[str, Any] = {
            "session_id": session_id,
            "total_tokens": 0,
            "total_cost_usd": 0.0,
            "by_model": {},
        }

        for field, value in data.items():
            parsed = self._split_field(field)
            if parsed is None:
                continue
            model, metric = parsed

            if model == "total":
                if metric in ("prompt_tokens", "completion_tokens"):
                    result["total_tokens"] += int(value)
                elif metric == "cost_usd":
                    result["total_cost_usd"] = float(value)
                continue

            model_stats = result["by_model"].setdefault(model, {})
            if metric in ("prompt_tokens", "completion_tokens", "requests"):
                model_stats[metric] = int(value)
            elif metric == "cost_usd":
                model_stats[metric] = float(value)

        return result

    async def check_budget(
        self,
        project_id: str,
        budget_limit: float | None = None,
    ) -> tuple[bool, float, float]:
        """
        Check if the project's current-month spend is within budget.

        Args:
            project_id: Project ID
            budget_limit: Budget limit (uses the settings default only if None;
                an explicit 0 is honoured)

        Returns:
            Tuple of (within_budget, current_cost, limit)
        """
        limit = (
            self._settings.default_budget_limit if budget_limit is None else budget_limit
        )

        # Get current month usage
        report = await self.get_project_usage(project_id, period="month")
        current_cost = report.total_cost_usd

        within_budget = current_cost < limit
        return within_budget, current_cost, limit

    async def estimate_request_cost(
        self,
        model: str,
        prompt_tokens: int,
        max_completion_tokens: int,
    ) -> float:
        """
        Estimate the worst-case cost for a request.

        Args:
            model: Model name
            prompt_tokens: Input token count
            max_completion_tokens: Maximum output tokens

        Returns:
            Estimated cost in USD, rounded to 6 decimals. Unknown models use a
            flat 0.00001 USD per token.
        """
        config = MODEL_CONFIGS.get(model)
        if not config:
            # Use a default estimate
            return round((prompt_tokens + max_completion_tokens) * 0.00001, 6)

        input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
        output_cost = (max_completion_tokens / 1_000_000) * config.cost_per_1m_output
        return round(input_cost + output_cost, 6)

    async def should_alert(
        self,
        project_id: str,
        threshold: float | None = None,
    ) -> tuple[bool, float]:
        """
        Check if a cost alert should be triggered for today's spend.

        Args:
            project_id: Project ID
            threshold: Alert threshold (uses the settings default only if None;
                an explicit 0 is honoured)

        Returns:
            Tuple of (should_alert, current_cost)
        """
        thresh = self._settings.cost_alert_threshold if threshold is None else threshold

        report = await self.get_project_usage(project_id, period="day")
        current_cost = report.total_cost_usd

        return current_cost >= thresh, current_cost


def calculate_cost(
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> float:
    """
    Calculate cost for a completion.

    Args:
        model: Model name
        prompt_tokens: Input tokens
        completion_tokens: Output tokens

    Returns:
        Cost in USD rounded to 6 decimals; 0.0 (with a warning) for unknown
        models.
    """
    config = MODEL_CONFIGS.get(model)
    if not config:
        logger.warning("Unknown model %s for cost calculation", model)
        return 0.0

    input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
    output_cost = (completion_tokens / 1_000_000) * config.cost_per_1m_output
    return round(input_cost + output_cost, 6)


# Global tracker instance (lazy initialization)
_tracker: CostTracker | None = None


def get_cost_tracker() -> CostTracker:
    """Get (creating on first use) the global cost tracker instance."""
    global _tracker
    if _tracker is None:
        _tracker = CostTracker()
    return _tracker


async def close_cost_tracker() -> None:
    """Close and discard the global cost tracker, if one exists."""
    global _tracker
    if _tracker is not None:
        await _tracker.close()
        _tracker = None


def reset_cost_tracker() -> None:
    """Reset the global tracker without closing it (for testing)."""
    global _tracker
    _tracker = None