Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00


"""
Syndarix LLM Gateway MCP Server.
Provides unified LLM access with:
- Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek)
- Automatic failover chains
- Cost tracking via Redis
- Model group routing (reasoning, code, fast, vision, embedding)
- Circuit breaker protection
Per ADR-004: LLM Provider Abstraction.
"""
import json
import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
import tiktoken
from fastapi import FastAPI
from fastmcp import FastMCP
from config import get_settings
from cost_tracking import calculate_cost, get_cost_tracker
from exceptions import (
AllProvidersFailedError,
CircuitOpenError,
CostLimitExceededError,
InvalidModelError,
InvalidModelGroupError,
LLMGatewayError,
ModelNotAvailableError,
)
from failover import get_circuit_registry
from models import (
MODEL_CONFIGS,
MODEL_GROUPS,
ModelGroup,
)
from providers import get_provider
from routing import get_model_router

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Create FastMCP server
mcp = FastMCP("syndarix-llm-gateway")

@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
"""Application lifespan handler."""
settings = get_settings()
logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}")
# Initialize provider
provider = get_provider()
provider.initialize()
yield
# Cleanup
from cost_tracking import close_cost_tracker
await close_cost_tracker()
logger.info("LLM Gateway shutdown complete")
# Create FastAPI app that wraps FastMCP
app = FastAPI(
title="Syndarix LLM Gateway",
description="MCP Server for unified LLM access",
version="0.1.0",
lifespan=lifespan,
)

# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint."""
settings = get_settings()
provider = get_provider()
return {
"status": "healthy",
"service": "llm-gateway",
"providers_configured": settings.get_available_providers(),
"provider_available": provider.is_available,
    }

# Tool discovery endpoint (for MCP client compatibility)
@app.get("/mcp/tools")
async def list_tools() -> dict[str, Any]:
"""List available MCP tools."""
return {
"tools": [
{
"name": "chat_completion",
"description": "Generate a chat completion using the specified model group",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string", "description": "Project ID"},
"agent_id": {"type": "string", "description": "Agent ID"},
"messages": {
"type": "array",
"items": {
"type": "object",
"properties": {
"role": {"type": "string"},
"content": {"type": "string"},
},
"required": ["role", "content"],
},
},
"model_group": {
"type": "string",
"enum": [g.value for g in ModelGroup],
"default": "reasoning",
},
"max_tokens": {"type": "integer", "default": 4096},
"temperature": {"type": "number", "default": 0.7},
"stream": {"type": "boolean", "default": False},
},
"required": ["project_id", "agent_id", "messages"],
},
},
{
"name": "list_models",
"description": "List available models and model groups",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"model_group": {"type": "string"},
},
"required": ["project_id", "agent_id"],
},
},
{
"name": "get_usage",
"description": "Get usage statistics for a project or agent",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"period": {
"type": "string",
"enum": ["hour", "day", "month"],
"default": "day",
},
},
"required": ["project_id", "agent_id"],
},
},
{
"name": "count_tokens",
"description": "Count tokens in text",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"text": {"type": "string"},
"model": {"type": "string"},
},
"required": ["project_id", "agent_id", "text"],
},
},
]
    }

# JSON-RPC endpoint (for MCP client compatibility)
@app.post("/mcp")
async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]:
"""Handle JSON-RPC 2.0 requests for MCP tools."""
# Validate JSON-RPC structure
if request.get("jsonrpc") != "2.0":
return {
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid JSON-RPC version"},
"id": request.get("id"),
}
method = request.get("method")
params = request.get("params", {})
request_id = request.get("id")
# Handle tool calls
if method == "tools/call":
tool_name = params.get("name")
tool_args = params.get("arguments", {})
try:
if tool_name == "chat_completion":
result = await _impl_chat_completion(**tool_args)
elif tool_name == "list_models":
result = await _impl_list_models(**tool_args)
elif tool_name == "get_usage":
result = await _impl_get_usage(**tool_args)
elif tool_name == "count_tokens":
result = await _impl_count_tokens(**tool_args)
else:
return {
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
"id": request_id,
}
            # Serialize the result as JSON so MCP clients can parse the text content
            return {
                "jsonrpc": "2.0",
                "result": {"content": [{"type": "text", "text": json.dumps(result)}]},
"id": request_id,
}
except LLMGatewayError as e:
return {
"jsonrpc": "2.0",
"error": {"code": -32000, "message": str(e), "data": e.to_dict()},
"id": request_id,
}
except Exception as e:
logger.exception(f"Error executing tool {tool_name}")
return {
"jsonrpc": "2.0",
"error": {"code": -32603, "message": str(e)},
"id": request_id,
}
# Handle tool listing
elif method == "tools/list":
tools_response = await list_tools()
return {
"jsonrpc": "2.0",
"result": tools_response,
"id": request_id,
}
else:
return {
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Unknown method: {method}"},
"id": request_id,
}
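
# Example exchange against the handler above (illustrative IDs and values,
# not captured output):
#
#   -> {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#       "params": {"name": "count_tokens",
#                  "arguments": {"project_id": "p-1", "agent_id": "a-1",
#                                "text": "hello world"}}}
#   <- {"jsonrpc": "2.0", "id": 1,
#       "result": {"content": [{"type": "text", "text": "{\"token_count\": 2, ...}"}]}}
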
# ============================================================================
# Core Implementation Functions
# ============================================================================
async def _impl_chat_completion(
project_id: str,
agent_id: str,
messages: list[dict[str, Any]],
model_group: str = "reasoning",
model_override: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
stream: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
"""Core implementation for chat completion."""
settings = get_settings()
router = get_model_router()
tracker = get_cost_tracker()
circuit_registry = get_circuit_registry()
# Check budget before making request
if settings.cost_tracking_enabled:
within_budget, current_cost, limit = await tracker.check_budget(project_id)
if not within_budget:
raise CostLimitExceededError(
entity_type="project",
entity_id=project_id,
current_cost=current_cost,
limit=limit,
)
    # Select model
    try:
        model_name, model_config = await router.select_model(
            model_group=model_group,
            model_override=model_override,
        )
    except (
        InvalidModelGroupError,
        InvalidModelError,
        ModelNotAvailableError,
        AllProvidersFailedError,
    ):
        # Routing errors already carry structured details; re-raise unchanged.
        raise
# Get provider
provider = get_provider()
if not provider.router:
raise AllProvidersFailedError(
model_group=model_group,
attempted_models=[model_name],
errors=[{"error": "No providers configured"}],
)
# Generate request ID
request_id = str(uuid.uuid4())
try:
# Get circuit breaker for this provider
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
# Make completion request
if stream:
# Return streaming response info
# Actual streaming would be handled by a separate endpoint
return {
"status": "streaming_not_supported_via_tool",
"message": "Use /stream endpoint for streaming responses",
"request_id": request_id,
}
# Non-streaming completion
response = await provider.router.acompletion(
model=model_name,
messages=messages, # type: ignore[arg-type]
max_tokens=max_tokens,
temperature=temperature,
)
# Record success
await circuit.record_success()
# Extract response data
content = response.choices[0].message.content or "" # type: ignore[union-attr]
finish_reason = response.choices[0].finish_reason or "stop"
# Get usage stats
prompt_tokens = response.usage.prompt_tokens if response.usage else 0 # type: ignore[attr-defined]
completion_tokens = response.usage.completion_tokens if response.usage else 0 # type: ignore[attr-defined]
# Calculate cost
cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens)
# Record usage
await tracker.record_usage(
project_id=project_id,
agent_id=agent_id,
model=model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cost_usd=cost_usd,
session_id=session_id,
request_id=request_id,
)
return {
"id": request_id,
"model": model_name,
"provider": model_config.provider.value,
"content": content,
"finish_reason": finish_reason,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"cost_usd": cost_usd,
},
}
except CircuitOpenError:
raise
except Exception as e:
# Record failure
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
await circuit.record_failure(e)
logger.error(f"Completion failed: {e}")
        raise AllProvidersFailedError(
            model_group=model_group,
            attempted_models=[model_name],
            errors=[{"model": model_name, "error": str(e)}],
        ) from e
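
# Hypothetical caller-side handling of the failure modes above (a sketch, not
# part of this module; names mirror the exceptions imported at the top):
#
#   try:
#       result = await _impl_chat_completion(
#           project_id=pid, agent_id=aid, messages=msgs
#       )
#   except CostLimitExceededError:
#       ...  # budget exhausted: surface current spend and limit to the user
#   except CircuitOpenError:
#       ...  # provider breaker is open: back off or switch model groups
#   except AllProvidersFailedError:
#       ...  # all candidate models failed: report attempted_models and errors
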
async def _impl_list_models(
project_id: str,
agent_id: str,
model_group: str | None = None,
) -> dict[str, Any]:
"""Core implementation for list_models."""
settings = get_settings()
provider = get_provider()
router = get_model_router()
# Get available providers
available_providers = settings.get_available_providers()
result: dict[str, Any] = {
"project_id": project_id,
"agent_id": agent_id,
"available_providers": available_providers,
}
if model_group:
# List models for specific group
try:
parsed_group = router.parse_model_group(model_group)
models = await router.get_available_models_for_group(parsed_group)
result["model_group"] = model_group
result["models"] = [
{
"name": name,
"provider": config.provider.value,
"available": available,
"cost_per_1m_input": config.cost_per_1m_input,
"cost_per_1m_output": config.cost_per_1m_output,
}
for name, config, available in models
]
except InvalidModelGroupError as e:
result["error"] = e.to_dict()
else:
# List all model groups
groups: dict[str, Any] = {}
for group in ModelGroup:
group_config = MODEL_GROUPS.get(group)
if group_config:
models = await router.get_available_models_for_group(group)
available_count = sum(1 for _, _, avail in models if avail)
groups[group.value] = {
"description": group_config.description,
"primary": group_config.primary,
"fallbacks": group_config.fallbacks,
"available_models": available_count,
"total_models": len(models),
}
result["model_groups"] = groups
# List all models
all_models: list[dict[str, Any]] = []
available_models = provider.get_available_models()
for name, config in MODEL_CONFIGS.items():
all_models.append(
{
"name": name,
"provider": config.provider.value,
"available": name in available_models,
"cost_per_1m_input": config.cost_per_1m_input,
"cost_per_1m_output": config.cost_per_1m_output,
"context_window": config.context_window,
"max_output_tokens": config.max_output_tokens,
"supports_vision": config.supports_vision,
"supports_streaming": config.supports_streaming,
}
)
result["models"] = all_models
    return result

async def _impl_get_usage(
project_id: str,
agent_id: str,
period: str = "day",
) -> dict[str, Any]:
"""Core implementation for get_usage."""
tracker = get_cost_tracker()
# Get project usage
project_report = await tracker.get_project_usage(project_id, period=period)
# Get agent usage
agent_report = await tracker.get_agent_usage(agent_id, period=period)
return {
"project_id": project_id,
"agent_id": agent_id,
"period": period,
"project_usage": {
"total_requests": project_report.total_requests,
"total_tokens": project_report.total_tokens,
"total_cost_usd": project_report.total_cost_usd,
"by_model": project_report.by_model,
"period_start": project_report.period_start.isoformat(),
"period_end": project_report.period_end.isoformat(),
},
"agent_usage": {
"total_requests": agent_report.total_requests,
"total_tokens": agent_report.total_tokens,
"total_cost_usd": agent_report.total_cost_usd,
"by_model": agent_report.by_model,
"period_start": agent_report.period_start.isoformat(),
"period_end": agent_report.period_end.isoformat(),
},
    }

async def _impl_count_tokens(
project_id: str,
agent_id: str,
text: str,
model: str | None = None,
) -> dict[str, Any]:
"""Core implementation for count_tokens."""
    # Use tiktoken for token counting.
    # Default to cl100k_base (the GPT-4 encoding); other providers use their own
    # tokenizers, so this is an approximation for non-GPT models.
    encoding_name = "cl100k_base"
    try:
        if model and model.startswith("gpt"):
            encoding = tiktoken.encoding_for_model(model)
            encoding_name = encoding.name
        else:
            encoding = tiktoken.get_encoding("cl100k_base")
        token_count = len(encoding.encode(text))
    except Exception as e:
        logger.warning(f"Token counting failed: {e}, using estimate")
        # Fallback: rough estimate of ~4 chars per token
        token_count = len(text) // 4
        encoding_name = "estimate"
# Estimate costs for different models
cost_estimates: dict[str, float] = {}
for model_name, config in MODEL_CONFIGS.items():
if config.cost_per_1m_input > 0:
cost = (token_count / 1_000_000) * config.cost_per_1m_input
cost_estimates[model_name] = round(cost, 6)
return {
"project_id": project_id,
"agent_id": agent_id,
"token_count": token_count,
"text_length": len(text),
"encoding": "cl100k_base",
"cost_estimates": cost_estimates,
}
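
# Worked example of the estimate above (illustrative price; real values come
# from MODEL_CONFIGS): 2,000 tokens against a model priced at $3.00 per 1M
# input tokens cost (2_000 / 1_000_000) * 3.00 = $0.006.
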
# ============================================================================
# MCP Tools (wrappers around core implementations)
# ============================================================================
@mcp.tool()
async def chat_completion(
project_id: str,
agent_id: str,
messages: list[dict[str, Any]],
model_group: str = "reasoning",
model_override: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
stream: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
"""
Generate a chat completion using the specified model group.
Args:
project_id: UUID of the project (required for cost attribution)
agent_id: UUID of the agent instance making the request
messages: List of message dicts with 'role' and 'content'
model_group: Model routing group (reasoning, code, fast, vision, embedding)
model_override: Specific model to use (bypasses group routing)
max_tokens: Maximum tokens to generate
temperature: Sampling temperature (0.0-2.0)
stream: Enable streaming response
session_id: Optional session ID for tracking
Returns:
Completion response with content and usage statistics
"""
return await _impl_chat_completion(
project_id=project_id,
agent_id=agent_id,
messages=messages,
model_group=model_group,
model_override=model_override,
max_tokens=max_tokens,
temperature=temperature,
stream=stream,
session_id=session_id,
)
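
# Example arguments an MCP client might send to this tool (illustrative IDs):
#   {"project_id": "7f3a...", "agent_id": "c2d9...",
#    "messages": [{"role": "user", "content": "Explain ADR-004 in one line"}],
#    "model_group": "fast", "max_tokens": 256}
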
@mcp.tool()
async def list_models(
project_id: str,
agent_id: str,
model_group: str | None = None,
) -> dict[str, Any]:
"""
List available models and model groups.
Args:
project_id: UUID of the project
agent_id: UUID of the agent instance
model_group: Optional specific group to list
Returns:
Dictionary of available models and groups
"""
return await _impl_list_models(
project_id=project_id,
agent_id=agent_id,
model_group=model_group,
    )

@mcp.tool()
async def get_usage(
project_id: str,
agent_id: str,
period: str = "day",
) -> dict[str, Any]:
"""
Get usage statistics for a project or agent.
Args:
project_id: UUID of the project
agent_id: UUID of the agent
period: Time period (hour, day, month)
Returns:
Usage statistics including tokens and costs
"""
return await _impl_get_usage(
project_id=project_id,
agent_id=agent_id,
period=period,
    )

@mcp.tool()
async def count_tokens(
project_id: str,
agent_id: str,
text: str,
model: str | None = None,
) -> dict[str, Any]:
"""
Count tokens in text.
Args:
project_id: UUID of the project
agent_id: UUID of the agent
text: Text to count tokens in
model: Optional model for tokenizer selection
Returns:
Token count and estimation details
"""
return await _impl_count_tokens(
project_id=project_id,
agent_id=agent_id,
text=text,
model=model,
    )

def main() -> None:
"""Run the server."""
import uvicorn
settings = get_settings()
uvicorn.run(
"server:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
    )

if __name__ == "__main__":
main()
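
# Local run sketch: host, port, and reload behavior come from
# config.get_settings(), as wired in main() above.
#   python server.py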