syndarix/mcp-servers/llm-gateway/models.py

"""
Data models for LLM Gateway MCP Server.
Defines model groups, pricing, and request/response structures.
Per ADR-004: LLM Provider Abstraction.
"""

from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class ModelGroup(str, Enum):
"""Model groups for routing LLM requests."""
REASONING = "reasoning" # Complex analysis, architecture decisions
CODE = "code" # Code writing and refactoring
FAST = "fast" # Quick tasks, simple queries
VISION = "vision" # Multimodal image analysis
EMBEDDING = "embedding" # Vector embeddings
COST_OPTIMIZED = "cost_optimized" # High-volume, non-critical
SELF_HOSTED = "self_hosted" # Privacy-sensitive, air-gapped
# Aliases for backward compatibility with ADR-004
HIGH_REASONING = "reasoning"
CODE_GENERATION = "code"
FAST_RESPONSE = "fast"
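
# Because the alias members above reuse existing values, standard Enum
# aliasing applies: HIGH_REASONING, CODE_GENERATION, and FAST_RESPONSE are
# not distinct members but alternate names for REASONING, CODE, and FAST.
#
#   >>> ModelGroup.HIGH_REASONING is ModelGroup.REASONING
#   True
#   >>> ModelGroup("reasoning").name
#   'REASONING'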


class Provider(str, Enum):
"""Supported LLM providers."""
ANTHROPIC = "anthropic"
OPENAI = "openai"
GOOGLE = "google"
ALIBABA = "alibaba"
DEEPSEEK = "deepseek"


@dataclass
class ModelConfig:
"""Configuration for a specific model."""
name: str # Model identifier (e.g., "claude-3-opus-20240229")
litellm_name: str # LiteLLM model string (e.g., "anthropic/claude-3-opus-20240229")
provider: Provider
cost_per_1m_input: float # USD per 1M input tokens
cost_per_1m_output: float # USD per 1M output tokens
context_window: int # Max context tokens
max_output_tokens: int # Max output tokens
supports_vision: bool = False
supports_streaming: bool = True
supports_function_calling: bool = True


# Model configurations per ADR-004
MODEL_CONFIGS: dict[str, ModelConfig] = {
# Anthropic models
"claude-opus-4": ModelConfig(
name="claude-opus-4",
litellm_name="anthropic/claude-sonnet-4-20250514", # Using sonnet-4 as opus-4 placeholder
provider=Provider.ANTHROPIC,
cost_per_1m_input=15.0,
cost_per_1m_output=75.0,
context_window=200000,
max_output_tokens=8192,
supports_vision=True,
),
"claude-sonnet-4": ModelConfig(
name="claude-sonnet-4",
litellm_name="anthropic/claude-sonnet-4-20250514",
provider=Provider.ANTHROPIC,
cost_per_1m_input=3.0,
cost_per_1m_output=15.0,
context_window=200000,
max_output_tokens=8192,
supports_vision=True,
),
"claude-haiku": ModelConfig(
name="claude-haiku",
litellm_name="anthropic/claude-3-5-haiku-20241022",
provider=Provider.ANTHROPIC,
cost_per_1m_input=1.0,
cost_per_1m_output=5.0,
context_window=200000,
max_output_tokens=8192,
supports_vision=True,
),
# OpenAI models
"gpt-4.1": ModelConfig(
name="gpt-4.1",
litellm_name="openai/gpt-4.1",
provider=Provider.OPENAI,
cost_per_1m_input=2.0,
cost_per_1m_output=8.0,
context_window=1047576,
max_output_tokens=32768,
supports_vision=True,
),
"gpt-4.1-mini": ModelConfig(
name="gpt-4.1-mini",
litellm_name="openai/gpt-4.1-mini",
provider=Provider.OPENAI,
cost_per_1m_input=0.4,
cost_per_1m_output=1.6,
context_window=1047576,
max_output_tokens=32768,
supports_vision=True,
),
# Google models
"gemini-2.5-pro": ModelConfig(
name="gemini-2.5-pro",
litellm_name="gemini/gemini-2.5-pro",
provider=Provider.GOOGLE,
cost_per_1m_input=1.25,
cost_per_1m_output=10.0,
context_window=1048576,
max_output_tokens=65536,
supports_vision=True,
),
"gemini-2.0-flash": ModelConfig(
name="gemini-2.0-flash",
litellm_name="gemini/gemini-2.0-flash",
provider=Provider.GOOGLE,
cost_per_1m_input=0.1,
cost_per_1m_output=0.4,
context_window=1048576,
max_output_tokens=8192,
supports_vision=True,
),
# Alibaba models
"qwen-max": ModelConfig(
name="qwen-max",
litellm_name="alibaba/qwen-max",
provider=Provider.ALIBABA,
cost_per_1m_input=2.0,
cost_per_1m_output=6.0,
context_window=32768,
max_output_tokens=8192,
supports_vision=False,
),
# DeepSeek models
"deepseek-coder": ModelConfig(
name="deepseek-coder",
litellm_name="deepseek/deepseek-coder",
provider=Provider.DEEPSEEK,
cost_per_1m_input=0.14,
cost_per_1m_output=0.28,
context_window=128000,
max_output_tokens=8192,
supports_vision=False,
),
"deepseek-chat": ModelConfig(
name="deepseek-chat",
litellm_name="deepseek/deepseek-chat",
provider=Provider.DEEPSEEK,
cost_per_1m_input=0.14,
cost_per_1m_output=0.28,
context_window=128000,
max_output_tokens=8192,
supports_vision=False,
),
# Embedding models
"text-embedding-3-large": ModelConfig(
name="text-embedding-3-large",
litellm_name="openai/text-embedding-3-large",
provider=Provider.OPENAI,
cost_per_1m_input=0.13,
cost_per_1m_output=0.0,
context_window=8191,
max_output_tokens=0,
supports_vision=False,
supports_streaming=False,
supports_function_calling=False,
),
"voyage-3": ModelConfig(
name="voyage-3",
litellm_name="voyage/voyage-3",
provider=Provider.ANTHROPIC, # Voyage via Anthropic partnership
cost_per_1m_input=0.06,
cost_per_1m_output=0.0,
context_window=32000,
max_output_tokens=0,
supports_vision=False,
supports_streaming=False,
supports_function_calling=False,
),
}
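
# Example (illustrative): looking up a configured model's LiteLLM identifier
# and per-token pricing.
#
#   >>> cfg = MODEL_CONFIGS["claude-haiku"]
#   >>> cfg.litellm_name
#   'anthropic/claude-3-5-haiku-20241022'
#   >>> (cfg.cost_per_1m_input, cfg.cost_per_1m_output)
#   (1.0, 5.0)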


@dataclass
class ModelGroupConfig:
"""Configuration for a model group with failover chain."""
primary: str # Primary model name
fallbacks: list[str] # Fallback model names in order
description: str

    def get_all_models(self) -> list[str]:
"""Get all models in priority order."""
return [self.primary, *self.fallbacks]


# Model group configurations per ADR-004
MODEL_GROUPS: dict[ModelGroup, ModelGroupConfig] = {
ModelGroup.REASONING: ModelGroupConfig(
primary="claude-opus-4",
fallbacks=["gpt-4.1", "gemini-2.5-pro", "qwen-max"],
description="Complex analysis, architecture decisions",
),
ModelGroup.CODE: ModelGroupConfig(
primary="claude-sonnet-4",
fallbacks=["gpt-4.1", "deepseek-coder"],
description="Code writing and refactoring",
),
ModelGroup.FAST: ModelGroupConfig(
primary="claude-haiku",
fallbacks=["gpt-4.1-mini", "gemini-2.0-flash"],
description="Quick tasks, simple queries",
),
ModelGroup.VISION: ModelGroupConfig(
primary="claude-sonnet-4",
fallbacks=["gpt-4.1", "gemini-2.5-pro"],
description="Multimodal image analysis",
),
ModelGroup.EMBEDDING: ModelGroupConfig(
primary="text-embedding-3-large",
fallbacks=["voyage-3"],
description="Vector embeddings",
),
ModelGroup.COST_OPTIMIZED: ModelGroupConfig(
primary="qwen-max",
fallbacks=["deepseek-chat"],
description="High-volume, non-critical tasks",
),
ModelGroup.SELF_HOSTED: ModelGroupConfig(
primary="deepseek-chat",
fallbacks=["qwen-max"],
description="Privacy-sensitive, air-gapped",
),
}
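
# Example (illustrative): the complete failover chain for a group, primary
# model first.
#
#   >>> MODEL_GROUPS[ModelGroup.CODE].get_all_models()
#   ['claude-sonnet-4', 'gpt-4.1', 'deepseek-coder']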


# Agent type to model group mapping per ADR-004
AGENT_TYPE_MODEL_PREFERENCES: dict[str, ModelGroup] = {
"product_owner": ModelGroup.REASONING,
"software_architect": ModelGroup.REASONING,
"software_engineer": ModelGroup.CODE,
"qa_engineer": ModelGroup.CODE,
"devops_engineer": ModelGroup.FAST,
"project_manager": ModelGroup.FAST,
"business_analyst": ModelGroup.REASONING,
}
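
# Illustrative lookup (the FAST default below is an assumption for this
# sketch, not specified by ADR-004):
#
#   group = AGENT_TYPE_MODEL_PREFERENCES.get(agent_type, ModelGroup.FAST)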


class ChatMessage(BaseModel):
"""A single chat message."""
role: str = Field(..., description="Message role: system, user, assistant, tool")
content: str | list[dict[str, Any]] = Field(..., description="Message content")
name: str | None = Field(default=None, description="Optional name for the message")
tool_calls: list[dict[str, Any]] | None = Field(
default=None, description="Tool calls if role is assistant"
)
tool_call_id: str | None = Field(
default=None, description="Tool call ID if role is tool"
)


class CompletionRequest(BaseModel):
"""Request for chat completion."""
project_id: str = Field(..., description="Project ID for cost attribution")
agent_id: str = Field(..., description="Agent ID making the request")
messages: list[ChatMessage] = Field(..., description="Chat messages")
model_group: ModelGroup = Field(
default=ModelGroup.REASONING, description="Model group for routing"
)
model_override: str | None = Field(
default=None, description="Specific model to use (bypasses routing)"
)
max_tokens: int = Field(
default=4096, ge=1, le=32768, description="Max output tokens"
)
temperature: float = Field(
default=0.7, ge=0.0, le=2.0, description="Sampling temperature"
)
stream: bool = Field(default=False, description="Enable streaming response")
session_id: str | None = Field(
default=None, description="Session ID for conversation tracking"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
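
# Example (illustrative; the IDs and message content are hypothetical): a
# minimal completion request routed through the CODE group.
#
#   request = CompletionRequest(
#       project_id="proj-123",
#       agent_id="agent-swe-1",
#       messages=[ChatMessage(role="user", content="Refactor this function.")],
#       model_group=ModelGroup.CODE,
#       max_tokens=1024,
#   )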


class UsageStats(BaseModel):
"""Token usage statistics."""
prompt_tokens: int = Field(default=0, description="Input tokens used")
completion_tokens: int = Field(default=0, description="Output tokens generated")
total_tokens: int = Field(default=0, description="Total tokens")
cost_usd: float = Field(default=0.0, description="Estimated cost in USD")

    @classmethod
def from_response(
cls, prompt_tokens: int, completion_tokens: int, model_config: ModelConfig
) -> "UsageStats":
"""Create usage stats from token counts and model config."""
input_cost = (prompt_tokens / 1_000_000) * model_config.cost_per_1m_input
output_cost = (completion_tokens / 1_000_000) * model_config.cost_per_1m_output
return cls(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cost_usd=round(input_cost + output_cost, 6),
)
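
# Worked example (illustrative): 1,000 prompt tokens and 500 completion tokens
# on claude-sonnet-4 ($3.00 input / $15.00 output per 1M tokens):
# 1_000 / 1e6 * 3.0 = 0.003 and 500 / 1e6 * 15.0 = 0.0075, so:
#
#   >>> UsageStats.from_response(1_000, 500, MODEL_CONFIGS["claude-sonnet-4"]).cost_usd
#   0.0105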


class CompletionResponse(BaseModel):
"""Response from chat completion."""
id: str = Field(..., description="Unique response ID")
model: str = Field(..., description="Model that generated the response")
provider: str = Field(..., description="Provider used")
content: str = Field(..., description="Generated content")
    finish_reason: str = Field(
        default="stop", description="Reason the completion finished"
    )
usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")
created_at: datetime = Field(
default_factory=lambda: datetime.now(UTC), description="Response timestamp"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)


class StreamChunk(BaseModel):
"""A chunk from a streaming response."""
id: str = Field(..., description="Chunk ID")
delta: str = Field(default="", description="Content delta")
finish_reason: str | None = Field(default=None, description="Finish reason if done")
usage: UsageStats | None = Field(
default=None, description="Usage stats (only on final chunk)"
)
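
# Illustrative consumption pattern (`stream` and `record_cost` are
# hypothetical; the gateway's streaming endpoint would yield StreamChunk
# objects):
#
#   async for chunk in stream:
#       print(chunk.delta, end="")
#       if chunk.finish_reason is not None and chunk.usage is not None:
#           record_cost(chunk.usage.cost_usd)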


class EmbeddingRequest(BaseModel):
"""Request for text embeddings."""
project_id: str = Field(..., description="Project ID for cost attribution")
agent_id: str = Field(..., description="Agent ID making the request")
texts: list[str] = Field(..., min_length=1, description="Texts to embed")
model: str = Field(default="text-embedding-3-large", description="Embedding model")


class EmbeddingResponse(BaseModel):
"""Response from embedding generation."""
model: str = Field(..., description="Model used")
embeddings: list[list[float]] = Field(..., description="Embedding vectors")
usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")


@dataclass
class CostRecord:
"""A single cost record for tracking."""
project_id: str
agent_id: str
model: str
prompt_tokens: int
completion_tokens: int
cost_usd: float
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
session_id: str | None = None
request_id: str | None = None


class UsageReport(BaseModel):
"""Usage report for a project or agent."""
entity_id: str = Field(..., description="Project or agent ID")
entity_type: str = Field(..., description="'project' or 'agent'")
period: str = Field(..., description="Report period (hour, day, month)")
period_start: datetime = Field(..., description="Period start time")
period_end: datetime = Field(..., description="Period end time")
total_requests: int = Field(default=0, description="Total requests")
total_tokens: int = Field(default=0, description="Total tokens used")
total_cost_usd: float = Field(default=0.0, description="Total cost in USD")
by_model: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Breakdown by model"
)
by_agent: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Breakdown by agent (for project reports)"
)


class ModelInfo(BaseModel):
"""Information about an available model."""
name: str = Field(..., description="Model name")
provider: str = Field(..., description="Provider name")
cost_per_1m_input: float = Field(..., description="Input cost per 1M tokens")
cost_per_1m_output: float = Field(..., description="Output cost per 1M tokens")
context_window: int = Field(..., description="Max context tokens")
max_output_tokens: int = Field(..., description="Max output tokens")
supports_vision: bool = Field(..., description="Vision capability")
supports_streaming: bool = Field(..., description="Streaming capability")
supports_function_calling: bool = Field(..., description="Function calling")
available: bool = Field(default=True, description="Provider configured")

    @classmethod
def from_config(cls, config: ModelConfig, available: bool = True) -> "ModelInfo":
"""Create ModelInfo from ModelConfig."""
return cls(
name=config.name,
provider=config.provider.value,
cost_per_1m_input=config.cost_per_1m_input,
cost_per_1m_output=config.cost_per_1m_output,
context_window=config.context_window,
max_output_tokens=config.max_output_tokens,
supports_vision=config.supports_vision,
supports_streaming=config.supports_streaming,
supports_function_calling=config.supports_function_calling,
available=available,
)
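
# Example (illustrative): exposing a configured model through the info schema.
#
#   >>> ModelInfo.from_config(MODEL_CONFIGS["gpt-4.1-mini"]).provider
#   'openai'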


class ModelGroupInfo(BaseModel):
"""Information about a model group."""
name: str = Field(..., description="Group name")
description: str = Field(..., description="Group description")
primary_model: str = Field(..., description="Primary model")
fallback_models: list[str] = Field(..., description="Fallback models")
available: bool = Field(default=True, description="At least one model available")