"""
|
|
Data models for LLM Gateway MCP Server.
|
|
|
|
Defines model groups, pricing, request/response structures.
|
|
Per ADR-004: LLM Provider Abstraction.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|


class ModelGroup(str, Enum):
    """Model groups for routing LLM requests."""

    REASONING = "reasoning"  # Complex analysis, architecture decisions
    CODE = "code"  # Code writing and refactoring
    FAST = "fast"  # Quick tasks, simple queries
    VISION = "vision"  # Multimodal image analysis
    EMBEDDING = "embedding"  # Vector embeddings
    COST_OPTIMIZED = "cost_optimized"  # High-volume, non-critical
    SELF_HOSTED = "self_hosted"  # Privacy-sensitive, air-gapped

    # Aliases for backward compatibility with ADR-004
    HIGH_REASONING = "reasoning"
    CODE_GENERATION = "code"
    FAST_RESPONSE = "fast"
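

# Note (illustrative): because the compatibility members reuse existing values,
# Enum treats them as aliases of the canonical members rather than as new
# entries, so lookups on either name resolve to the same member:
#
#     >>> ModelGroup.HIGH_REASONING is ModelGroup.REASONING
#     True
#     >>> ModelGroup("reasoning")
#     <ModelGroup.REASONING: 'reasoning'>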


class Provider(str, Enum):
    """Supported LLM providers."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GOOGLE = "google"
    ALIBABA = "alibaba"
    DEEPSEEK = "deepseek"


@dataclass
class ModelConfig:
    """Configuration for a specific model."""

    name: str  # Model identifier (e.g., "claude-3-opus-20240229")
    litellm_name: str  # LiteLLM model string (e.g., "anthropic/claude-3-opus-20240229")
    provider: Provider
    cost_per_1m_input: float  # USD per 1M input tokens
    cost_per_1m_output: float  # USD per 1M output tokens
    context_window: int  # Max context tokens
    max_output_tokens: int  # Max output tokens
    supports_vision: bool = False
    supports_streaming: bool = True
    supports_function_calling: bool = True


# Model configurations per ADR-004
MODEL_CONFIGS: dict[str, ModelConfig] = {
    # Anthropic models
    "claude-opus-4": ModelConfig(
        name="claude-opus-4",
        litellm_name="anthropic/claude-sonnet-4-20250514",  # Using sonnet-4 as opus-4 placeholder
        provider=Provider.ANTHROPIC,
        cost_per_1m_input=15.0,
        cost_per_1m_output=75.0,
        context_window=200000,
        max_output_tokens=8192,
        supports_vision=True,
    ),
    "claude-sonnet-4": ModelConfig(
        name="claude-sonnet-4",
        litellm_name="anthropic/claude-sonnet-4-20250514",
        provider=Provider.ANTHROPIC,
        cost_per_1m_input=3.0,
        cost_per_1m_output=15.0,
        context_window=200000,
        max_output_tokens=8192,
        supports_vision=True,
    ),
    "claude-haiku": ModelConfig(
        name="claude-haiku",
        litellm_name="anthropic/claude-3-5-haiku-20241022",
        provider=Provider.ANTHROPIC,
        cost_per_1m_input=1.0,
        cost_per_1m_output=5.0,
        context_window=200000,
        max_output_tokens=8192,
        supports_vision=True,
    ),
    # OpenAI models
    "gpt-4.1": ModelConfig(
        name="gpt-4.1",
        litellm_name="openai/gpt-4.1",
        provider=Provider.OPENAI,
        cost_per_1m_input=2.0,
        cost_per_1m_output=8.0,
        context_window=1047576,
        max_output_tokens=32768,
        supports_vision=True,
    ),
    "gpt-4.1-mini": ModelConfig(
        name="gpt-4.1-mini",
        litellm_name="openai/gpt-4.1-mini",
        provider=Provider.OPENAI,
        cost_per_1m_input=0.4,
        cost_per_1m_output=1.6,
        context_window=1047576,
        max_output_tokens=32768,
        supports_vision=True,
    ),
    # Google models
    "gemini-2.5-pro": ModelConfig(
        name="gemini-2.5-pro",
        litellm_name="gemini/gemini-2.5-pro",
        provider=Provider.GOOGLE,
        cost_per_1m_input=1.25,
        cost_per_1m_output=10.0,
        context_window=1048576,
        max_output_tokens=65536,
        supports_vision=True,
    ),
    "gemini-2.0-flash": ModelConfig(
        name="gemini-2.0-flash",
        litellm_name="gemini/gemini-2.0-flash",
        provider=Provider.GOOGLE,
        cost_per_1m_input=0.1,
        cost_per_1m_output=0.4,
        context_window=1048576,
        max_output_tokens=8192,
        supports_vision=True,
    ),
    # Alibaba models
    "qwen-max": ModelConfig(
        name="qwen-max",
        litellm_name="alibaba/qwen-max",
        provider=Provider.ALIBABA,
        cost_per_1m_input=2.0,
        cost_per_1m_output=6.0,
        context_window=32768,
        max_output_tokens=8192,
        supports_vision=False,
    ),
    # DeepSeek models
    "deepseek-coder": ModelConfig(
        name="deepseek-coder",
        litellm_name="deepseek/deepseek-coder",
        provider=Provider.DEEPSEEK,
        cost_per_1m_input=0.14,
        cost_per_1m_output=0.28,
        context_window=128000,
        max_output_tokens=8192,
        supports_vision=False,
    ),
    "deepseek-chat": ModelConfig(
        name="deepseek-chat",
        litellm_name="deepseek/deepseek-chat",
        provider=Provider.DEEPSEEK,
        cost_per_1m_input=0.14,
        cost_per_1m_output=0.28,
        context_window=128000,
        max_output_tokens=8192,
        supports_vision=False,
    ),
    # Embedding models
    "text-embedding-3-large": ModelConfig(
        name="text-embedding-3-large",
        litellm_name="openai/text-embedding-3-large",
        provider=Provider.OPENAI,
        cost_per_1m_input=0.13,
        cost_per_1m_output=0.0,
        context_window=8191,
        max_output_tokens=0,
        supports_vision=False,
        supports_streaming=False,
        supports_function_calling=False,
    ),
    "voyage-3": ModelConfig(
        name="voyage-3",
        litellm_name="voyage/voyage-3",
        provider=Provider.ANTHROPIC,  # Voyage via Anthropic partnership
        cost_per_1m_input=0.06,
        cost_per_1m_output=0.0,
        context_window=32000,
        max_output_tokens=0,
        supports_vision=False,
        supports_streaming=False,
        supports_function_calling=False,
    ),
}


@dataclass
class ModelGroupConfig:
    """Configuration for a model group with failover chain."""

    primary: str  # Primary model name
    fallbacks: list[str]  # Fallback model names in order
    description: str

    def get_all_models(self) -> list[str]:
        """Get all models in priority order."""
        return [self.primary, *self.fallbacks]


# Model group configurations per ADR-004
MODEL_GROUPS: dict[ModelGroup, ModelGroupConfig] = {
    ModelGroup.REASONING: ModelGroupConfig(
        primary="claude-opus-4",
        fallbacks=["gpt-4.1", "gemini-2.5-pro", "qwen-max"],
        description="Complex analysis, architecture decisions",
    ),
    ModelGroup.CODE: ModelGroupConfig(
        primary="claude-sonnet-4",
        fallbacks=["gpt-4.1", "deepseek-coder"],
        description="Code writing and refactoring",
    ),
    ModelGroup.FAST: ModelGroupConfig(
        primary="claude-haiku",
        fallbacks=["gpt-4.1-mini", "gemini-2.0-flash"],
        description="Quick tasks, simple queries",
    ),
    ModelGroup.VISION: ModelGroupConfig(
        primary="claude-sonnet-4",
        fallbacks=["gpt-4.1", "gemini-2.5-pro"],
        description="Multimodal image analysis",
    ),
    ModelGroup.EMBEDDING: ModelGroupConfig(
        primary="text-embedding-3-large",
        fallbacks=["voyage-3"],
        description="Vector embeddings",
    ),
    ModelGroup.COST_OPTIMIZED: ModelGroupConfig(
        primary="qwen-max",
        fallbacks=["deepseek-chat"],
        description="High-volume, non-critical tasks",
    ),
    ModelGroup.SELF_HOSTED: ModelGroupConfig(
        primary="deepseek-chat",
        fallbacks=["qwen-max"],
        description="Privacy-sensitive, air-gapped",
    ),
}
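

# Example (illustrative): resolving the failover chain for a group. The router
# tries the primary first, then each fallback in order:
#
#     >>> MODEL_GROUPS[ModelGroup.CODE].get_all_models()
#     ['claude-sonnet-4', 'gpt-4.1', 'deepseek-coder']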


# Agent type to model group mapping per ADR-004
AGENT_TYPE_MODEL_PREFERENCES: dict[str, ModelGroup] = {
    "product_owner": ModelGroup.REASONING,
    "software_architect": ModelGroup.REASONING,
    "software_engineer": ModelGroup.CODE,
    "qa_engineer": ModelGroup.CODE,
    "devops_engineer": ModelGroup.FAST,
    "project_manager": ModelGroup.FAST,
    "business_analyst": ModelGroup.REASONING,
}
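

# Example (illustrative): picking a group for an agent type. Falling back to
# ModelGroup.FAST for unknown agent types is an assumption of this sketch,
# not something specified by ADR-004:
#
#     >>> AGENT_TYPE_MODEL_PREFERENCES.get("software_engineer", ModelGroup.FAST)
#     <ModelGroup.CODE: 'code'>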


class ChatMessage(BaseModel):
    """A single chat message."""

    role: str = Field(..., description="Message role: system, user, assistant, tool")
    content: str | list[dict[str, Any]] = Field(..., description="Message content")
    name: str | None = Field(default=None, description="Optional name for the message")
    tool_calls: list[dict[str, Any]] | None = Field(
        default=None, description="Tool calls if role is assistant"
    )
    tool_call_id: str | None = Field(
        default=None, description="Tool call ID if role is tool"
    )


class CompletionRequest(BaseModel):
    """Request for chat completion."""

    project_id: str = Field(..., description="Project ID for cost attribution")
    agent_id: str = Field(..., description="Agent ID making the request")
    messages: list[ChatMessage] = Field(..., description="Chat messages")
    model_group: ModelGroup = Field(
        default=ModelGroup.REASONING, description="Model group for routing"
    )
    model_override: str | None = Field(
        default=None, description="Specific model to use (bypasses routing)"
    )
    max_tokens: int = Field(default=4096, ge=1, le=32768, description="Max output tokens")
    temperature: float = Field(
        default=0.7, ge=0.0, le=2.0, description="Sampling temperature"
    )
    stream: bool = Field(default=False, description="Enable streaming response")
    session_id: str | None = Field(
        default=None, description="Session ID for conversation tracking"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )
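

# Example (illustrative): building a minimal request. The IDs below are made
# up for the sketch; pydantic validates field types and bounds on construction:
#
#     >>> req = CompletionRequest(
#     ...     project_id="proj-123",
#     ...     agent_id="agent-456",
#     ...     messages=[ChatMessage(role="user", content="Summarize ADR-004.")],
#     ...     model_group=ModelGroup.FAST,
#     ... )
#     >>> req.max_tokens
#     4096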


class UsageStats(BaseModel):
    """Token usage statistics."""

    prompt_tokens: int = Field(default=0, description="Input tokens used")
    completion_tokens: int = Field(default=0, description="Output tokens generated")
    total_tokens: int = Field(default=0, description="Total tokens")
    cost_usd: float = Field(default=0.0, description="Estimated cost in USD")

    @classmethod
    def from_response(
        cls, prompt_tokens: int, completion_tokens: int, model_config: ModelConfig
    ) -> "UsageStats":
        """Create usage stats from token counts and model config."""
        input_cost = (prompt_tokens / 1_000_000) * model_config.cost_per_1m_input
        output_cost = (completion_tokens / 1_000_000) * model_config.cost_per_1m_output
        return cls(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            cost_usd=round(input_cost + output_cost, 6),
        )
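

# Worked example (illustrative): 1,000 prompt tokens and 500 completion tokens
# on claude-sonnet-4 ($3.00 input / $15.00 output per 1M tokens):
#
#     >>> stats = UsageStats.from_response(1000, 500, MODEL_CONFIGS["claude-sonnet-4"])
#     >>> stats.cost_usd  # (1000 / 1e6) * 3.0 + (500 / 1e6) * 15.0
#     0.0105
#     >>> stats.total_tokens
#     1500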


class CompletionResponse(BaseModel):
    """Response from chat completion."""

    id: str = Field(..., description="Unique response ID")
    model: str = Field(..., description="Model that generated the response")
    provider: str = Field(..., description="Provider used")
    content: str = Field(..., description="Generated content")
    finish_reason: str = Field(
        default="stop", description="Reason for completion finish"
    )
    usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")
    created_at: datetime = Field(
        default_factory=datetime.utcnow, description="Response timestamp"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )


class StreamChunk(BaseModel):
    """A chunk from a streaming response."""

    id: str = Field(..., description="Chunk ID")
    delta: str = Field(default="", description="Content delta")
    finish_reason: str | None = Field(default=None, description="Finish reason if done")
    usage: UsageStats | None = Field(
        default=None, description="Usage stats (only on final chunk)"
    )
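

# Sketch (illustrative): reassembling a streamed completion. `chunks` stands in
# for whatever iterable of StreamChunk objects the gateway yields; only the
# final chunk carries usage stats:
#
#     parts: list[str] = []
#     usage = None
#     for chunk in chunks:
#         parts.append(chunk.delta)
#         if chunk.finish_reason is not None:
#             usage = chunk.usage
#     content = "".join(parts)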


class EmbeddingRequest(BaseModel):
    """Request for text embeddings."""

    project_id: str = Field(..., description="Project ID for cost attribution")
    agent_id: str = Field(..., description="Agent ID making the request")
    texts: list[str] = Field(..., min_length=1, description="Texts to embed")
    model: str = Field(
        default="text-embedding-3-large", description="Embedding model"
    )


class EmbeddingResponse(BaseModel):
    """Response from embedding generation."""

    model: str = Field(..., description="Model used")
    embeddings: list[list[float]] = Field(..., description="Embedding vectors")
    usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")


@dataclass
class CostRecord:
    """A single cost record for tracking."""

    project_id: str
    agent_id: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    timestamp: datetime = field(default_factory=datetime.utcnow)
    session_id: str | None = None
    request_id: str | None = None


class UsageReport(BaseModel):
    """Usage report for a project or agent."""

    entity_id: str = Field(..., description="Project or agent ID")
    entity_type: str = Field(..., description="'project' or 'agent'")
    period: str = Field(..., description="Report period (hour, day, month)")
    period_start: datetime = Field(..., description="Period start time")
    period_end: datetime = Field(..., description="Period end time")
    total_requests: int = Field(default=0, description="Total requests")
    total_tokens: int = Field(default=0, description="Total tokens used")
    total_cost_usd: float = Field(default=0.0, description="Total cost in USD")
    by_model: dict[str, dict[str, Any]] = Field(
        default_factory=dict, description="Breakdown by model"
    )
    by_agent: dict[str, dict[str, Any]] = Field(
        default_factory=dict, description="Breakdown by agent (for project reports)"
    )


class ModelInfo(BaseModel):
    """Information about an available model."""

    name: str = Field(..., description="Model name")
    provider: str = Field(..., description="Provider name")
    cost_per_1m_input: float = Field(..., description="Input cost per 1M tokens")
    cost_per_1m_output: float = Field(..., description="Output cost per 1M tokens")
    context_window: int = Field(..., description="Max context tokens")
    max_output_tokens: int = Field(..., description="Max output tokens")
    supports_vision: bool = Field(..., description="Vision capability")
    supports_streaming: bool = Field(..., description="Streaming capability")
    supports_function_calling: bool = Field(..., description="Function calling")
    available: bool = Field(default=True, description="Provider configured")

    @classmethod
    def from_config(cls, config: ModelConfig, available: bool = True) -> "ModelInfo":
        """Create ModelInfo from ModelConfig."""
        return cls(
            name=config.name,
            provider=config.provider.value,
            cost_per_1m_input=config.cost_per_1m_input,
            cost_per_1m_output=config.cost_per_1m_output,
            context_window=config.context_window,
            max_output_tokens=config.max_output_tokens,
            supports_vision=config.supports_vision,
            supports_streaming=config.supports_streaming,
            supports_function_calling=config.supports_function_calling,
            available=available,
        )
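

# Example (illustrative): projecting a registry entry into the API-friendly
# record that a tool like list_models can return. The default availability
# flag is assumed here rather than computed from provider credentials:
#
#     >>> info = ModelInfo.from_config(MODEL_CONFIGS["gemini-2.0-flash"])
#     >>> (info.provider, info.supports_vision)
#     ('google', True)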


class ModelGroupInfo(BaseModel):
    """Information about a model group."""

    name: str = Field(..., description="Group name")
    description: str = Field(..., description="Group description")
    primary_model: str = Field(..., description="Primary model")
    fallback_models: list[str] = Field(..., description="Fallback models")
    available: bool = Field(default=True, description="At least one model available")