""" Data models for LLM Gateway MCP Server. Defines model groups, pricing, request/response structures. Per ADR-004: LLM Provider Abstraction. """ from dataclasses import dataclass, field from datetime import UTC, datetime from enum import Enum from typing import Any from pydantic import BaseModel, Field class ModelGroup(str, Enum): """Model groups for routing LLM requests.""" REASONING = "reasoning" # Complex analysis, architecture decisions CODE = "code" # Code writing and refactoring FAST = "fast" # Quick tasks, simple queries VISION = "vision" # Multimodal image analysis EMBEDDING = "embedding" # Vector embeddings COST_OPTIMIZED = "cost_optimized" # High-volume, non-critical SELF_HOSTED = "self_hosted" # Privacy-sensitive, air-gapped # Aliases for backward compatibility with ADR-004 HIGH_REASONING = "reasoning" CODE_GENERATION = "code" FAST_RESPONSE = "fast" class Provider(str, Enum): """Supported LLM providers.""" ANTHROPIC = "anthropic" OPENAI = "openai" GOOGLE = "google" ALIBABA = "alibaba" DEEPSEEK = "deepseek" @dataclass class ModelConfig: """Configuration for a specific model.""" name: str # Model identifier (e.g., "claude-3-opus-20240229") litellm_name: str # LiteLLM model string (e.g., "anthropic/claude-3-opus-20240229") provider: Provider cost_per_1m_input: float # USD per 1M input tokens cost_per_1m_output: float # USD per 1M output tokens context_window: int # Max context tokens max_output_tokens: int # Max output tokens supports_vision: bool = False supports_streaming: bool = True supports_function_calling: bool = True # Model configurations per ADR-004 MODEL_CONFIGS: dict[str, ModelConfig] = { # Anthropic models "claude-opus-4": ModelConfig( name="claude-opus-4", litellm_name="anthropic/claude-sonnet-4-20250514", # Using sonnet-4 as opus-4 placeholder provider=Provider.ANTHROPIC, cost_per_1m_input=15.0, cost_per_1m_output=75.0, context_window=200000, max_output_tokens=8192, supports_vision=True, ), "claude-sonnet-4": ModelConfig( name="claude-sonnet-4", litellm_name="anthropic/claude-sonnet-4-20250514", provider=Provider.ANTHROPIC, cost_per_1m_input=3.0, cost_per_1m_output=15.0, context_window=200000, max_output_tokens=8192, supports_vision=True, ), "claude-haiku": ModelConfig( name="claude-haiku", litellm_name="anthropic/claude-3-5-haiku-20241022", provider=Provider.ANTHROPIC, cost_per_1m_input=1.0, cost_per_1m_output=5.0, context_window=200000, max_output_tokens=8192, supports_vision=True, ), # OpenAI models "gpt-4.1": ModelConfig( name="gpt-4.1", litellm_name="openai/gpt-4.1", provider=Provider.OPENAI, cost_per_1m_input=2.0, cost_per_1m_output=8.0, context_window=1047576, max_output_tokens=32768, supports_vision=True, ), "gpt-4.1-mini": ModelConfig( name="gpt-4.1-mini", litellm_name="openai/gpt-4.1-mini", provider=Provider.OPENAI, cost_per_1m_input=0.4, cost_per_1m_output=1.6, context_window=1047576, max_output_tokens=32768, supports_vision=True, ), # Google models "gemini-2.5-pro": ModelConfig( name="gemini-2.5-pro", litellm_name="gemini/gemini-2.5-pro", provider=Provider.GOOGLE, cost_per_1m_input=1.25, cost_per_1m_output=10.0, context_window=1048576, max_output_tokens=65536, supports_vision=True, ), "gemini-2.0-flash": ModelConfig( name="gemini-2.0-flash", litellm_name="gemini/gemini-2.0-flash", provider=Provider.GOOGLE, cost_per_1m_input=0.1, cost_per_1m_output=0.4, context_window=1048576, max_output_tokens=8192, supports_vision=True, ), # Alibaba models "qwen-max": ModelConfig( name="qwen-max", litellm_name="alibaba/qwen-max", provider=Provider.ALIBABA, cost_per_1m_input=2.0, cost_per_1m_output=6.0, context_window=32768, max_output_tokens=8192, supports_vision=False, ), # DeepSeek models "deepseek-coder": ModelConfig( name="deepseek-coder", litellm_name="deepseek/deepseek-coder", provider=Provider.DEEPSEEK, cost_per_1m_input=0.14, cost_per_1m_output=0.28, context_window=128000, max_output_tokens=8192, supports_vision=False, ), "deepseek-chat": ModelConfig( name="deepseek-chat", litellm_name="deepseek/deepseek-chat", provider=Provider.DEEPSEEK, cost_per_1m_input=0.14, cost_per_1m_output=0.28, context_window=128000, max_output_tokens=8192, supports_vision=False, ), # Embedding models "text-embedding-3-large": ModelConfig( name="text-embedding-3-large", litellm_name="openai/text-embedding-3-large", provider=Provider.OPENAI, cost_per_1m_input=0.13, cost_per_1m_output=0.0, context_window=8191, max_output_tokens=0, supports_vision=False, supports_streaming=False, supports_function_calling=False, ), "voyage-3": ModelConfig( name="voyage-3", litellm_name="voyage/voyage-3", provider=Provider.ANTHROPIC, # Voyage via Anthropic partnership cost_per_1m_input=0.06, cost_per_1m_output=0.0, context_window=32000, max_output_tokens=0, supports_vision=False, supports_streaming=False, supports_function_calling=False, ), } @dataclass class ModelGroupConfig: """Configuration for a model group with failover chain.""" primary: str # Primary model name fallbacks: list[str] # Fallback model names in order description: str def get_all_models(self) -> list[str]: """Get all models in priority order.""" return [self.primary, *self.fallbacks] # Model group configurations per ADR-004 MODEL_GROUPS: dict[ModelGroup, ModelGroupConfig] = { ModelGroup.REASONING: ModelGroupConfig( primary="claude-opus-4", fallbacks=["gpt-4.1", "gemini-2.5-pro", "qwen-max"], description="Complex analysis, architecture decisions", ), ModelGroup.CODE: ModelGroupConfig( primary="claude-sonnet-4", fallbacks=["gpt-4.1", "deepseek-coder"], description="Code writing and refactoring", ), ModelGroup.FAST: ModelGroupConfig( primary="claude-haiku", fallbacks=["gpt-4.1-mini", "gemini-2.0-flash"], description="Quick tasks, simple queries", ), ModelGroup.VISION: ModelGroupConfig( primary="claude-sonnet-4", fallbacks=["gpt-4.1", "gemini-2.5-pro"], description="Multimodal image analysis", ), ModelGroup.EMBEDDING: ModelGroupConfig( primary="text-embedding-3-large", fallbacks=["voyage-3"], description="Vector embeddings", ), ModelGroup.COST_OPTIMIZED: ModelGroupConfig( primary="qwen-max", fallbacks=["deepseek-chat"], description="High-volume, non-critical tasks", ), ModelGroup.SELF_HOSTED: ModelGroupConfig( primary="deepseek-chat", fallbacks=["qwen-max"], description="Privacy-sensitive, air-gapped", ), } # Agent type to model group mapping per ADR-004 AGENT_TYPE_MODEL_PREFERENCES: dict[str, ModelGroup] = { "product_owner": ModelGroup.REASONING, "software_architect": ModelGroup.REASONING, "software_engineer": ModelGroup.CODE, "qa_engineer": ModelGroup.CODE, "devops_engineer": ModelGroup.FAST, "project_manager": ModelGroup.FAST, "business_analyst": ModelGroup.REASONING, } class ChatMessage(BaseModel): """A single chat message.""" role: str = Field(..., description="Message role: system, user, assistant, tool") content: str | list[dict[str, Any]] = Field(..., description="Message content") name: str | None = Field(default=None, description="Optional name for the message") tool_calls: list[dict[str, Any]] | None = Field( default=None, description="Tool calls if role is assistant" ) tool_call_id: str | None = Field( default=None, description="Tool call ID if role is tool" ) class CompletionRequest(BaseModel): """Request for chat completion.""" project_id: str = Field(..., description="Project ID for cost attribution") agent_id: str = Field(..., description="Agent ID making the request") messages: list[ChatMessage] = Field(..., description="Chat messages") model_group: ModelGroup = Field( default=ModelGroup.REASONING, description="Model group for routing" ) model_override: str | None = Field( default=None, description="Specific model to use (bypasses routing)" ) max_tokens: int = Field( default=4096, ge=1, le=32768, description="Max output tokens" ) temperature: float = Field( default=0.7, ge=0.0, le=2.0, description="Sampling temperature" ) stream: bool = Field(default=False, description="Enable streaming response") session_id: str | None = Field( default=None, description="Session ID for conversation tracking" ) metadata: dict[str, Any] = Field( default_factory=dict, description="Additional metadata" ) class UsageStats(BaseModel): """Token usage statistics.""" prompt_tokens: int = Field(default=0, description="Input tokens used") completion_tokens: int = Field(default=0, description="Output tokens generated") total_tokens: int = Field(default=0, description="Total tokens") cost_usd: float = Field(default=0.0, description="Estimated cost in USD") @classmethod def from_response( cls, prompt_tokens: int, completion_tokens: int, model_config: ModelConfig ) -> "UsageStats": """Create usage stats from token counts and model config.""" input_cost = (prompt_tokens / 1_000_000) * model_config.cost_per_1m_input output_cost = (completion_tokens / 1_000_000) * model_config.cost_per_1m_output return cls( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, cost_usd=round(input_cost + output_cost, 6), ) class CompletionResponse(BaseModel): """Response from chat completion.""" id: str = Field(..., description="Unique response ID") model: str = Field(..., description="Model that generated the response") provider: str = Field(..., description="Provider used") content: str = Field(..., description="Generated content") finish_reason: str = Field( default="stop", description="Reason for completion finish" ) usage: UsageStats = Field(default_factory=UsageStats, description="Token usage") created_at: datetime = Field( default_factory=lambda: datetime.now(UTC), description="Response timestamp" ) metadata: dict[str, Any] = Field( default_factory=dict, description="Additional metadata" ) class StreamChunk(BaseModel): """A chunk from a streaming response.""" id: str = Field(..., description="Chunk ID") delta: str = Field(default="", description="Content delta") finish_reason: str | None = Field(default=None, description="Finish reason if done") usage: UsageStats | None = Field( default=None, description="Usage stats (only on final chunk)" ) class EmbeddingRequest(BaseModel): """Request for text embeddings.""" project_id: str = Field(..., description="Project ID for cost attribution") agent_id: str = Field(..., description="Agent ID making the request") texts: list[str] = Field(..., min_length=1, description="Texts to embed") model: str = Field(default="text-embedding-3-large", description="Embedding model") class EmbeddingResponse(BaseModel): """Response from embedding generation.""" model: str = Field(..., description="Model used") embeddings: list[list[float]] = Field(..., description="Embedding vectors") usage: UsageStats = Field(default_factory=UsageStats, description="Token usage") @dataclass class CostRecord: """A single cost record for tracking.""" project_id: str agent_id: str model: str prompt_tokens: int completion_tokens: int cost_usd: float timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) session_id: str | None = None request_id: str | None = None class UsageReport(BaseModel): """Usage report for a project or agent.""" entity_id: str = Field(..., description="Project or agent ID") entity_type: str = Field(..., description="'project' or 'agent'") period: str = Field(..., description="Report period (hour, day, month)") period_start: datetime = Field(..., description="Period start time") period_end: datetime = Field(..., description="Period end time") total_requests: int = Field(default=0, description="Total requests") total_tokens: int = Field(default=0, description="Total tokens used") total_cost_usd: float = Field(default=0.0, description="Total cost in USD") by_model: dict[str, dict[str, Any]] = Field( default_factory=dict, description="Breakdown by model" ) by_agent: dict[str, dict[str, Any]] = Field( default_factory=dict, description="Breakdown by agent (for project reports)" ) class ModelInfo(BaseModel): """Information about an available model.""" name: str = Field(..., description="Model name") provider: str = Field(..., description="Provider name") cost_per_1m_input: float = Field(..., description="Input cost per 1M tokens") cost_per_1m_output: float = Field(..., description="Output cost per 1M tokens") context_window: int = Field(..., description="Max context tokens") max_output_tokens: int = Field(..., description="Max output tokens") supports_vision: bool = Field(..., description="Vision capability") supports_streaming: bool = Field(..., description="Streaming capability") supports_function_calling: bool = Field(..., description="Function calling") available: bool = Field(default=True, description="Provider configured") @classmethod def from_config(cls, config: ModelConfig, available: bool = True) -> "ModelInfo": """Create ModelInfo from ModelConfig.""" return cls( name=config.name, provider=config.provider.value, cost_per_1m_input=config.cost_per_1m_input, cost_per_1m_output=config.cost_per_1m_output, context_window=config.context_window, max_output_tokens=config.max_output_tokens, supports_vision=config.supports_vision, supports_streaming=config.supports_streaming, supports_function_calling=config.supports_function_calling, available=available, ) class ModelGroupInfo(BaseModel): """Information about a model group.""" name: str = Field(..., description="Group name") description: str = Field(..., description="Group description") primary_model: str = Field(..., description="Primary model") fallback_models: list[str] = Field(..., description="Fallback models") available: bool = Field(default=True, description="At least one model available")