syndarix/mcp-servers/llm-gateway/server.py
Felipe Cardoso 6e8b0b022a feat(llm-gateway): implement LLM Gateway MCP Server (#56)
Implements complete LLM Gateway MCP Server with:
- FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens
- LiteLLM Router with multi-provider failover chains
- Circuit breaker pattern for fault tolerance
- Redis-based cost tracking per project/agent
- Comprehensive test suite (209 tests, 92% coverage)

Model groups defined per ADR-004:
- reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro
- code: claude-sonnet-4 → gpt-4.1 → deepseek-coder
- fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:31:19 +01:00


"""
Syndarix LLM Gateway MCP Server.
Provides unified LLM access with:
- Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek)
- Automatic failover chains
- Cost tracking via Redis
- Model group routing (reasoning, code, fast, vision, embedding)
- Circuit breaker protection
Per ADR-004: LLM Provider Abstraction.
"""
import logging
import uuid
from contextlib import asynccontextmanager
from typing import Any
import tiktoken
from fastapi import FastAPI
from fastmcp import FastMCP
from config import get_settings
from cost_tracking import calculate_cost, get_cost_tracker
from exceptions import (
AllProvidersFailedError,
CircuitOpenError,
CostLimitExceededError,
InvalidModelError,
InvalidModelGroupError,
LLMGatewayError,
ModelNotAvailableError,
)
from failover import get_circuit_registry
from models import (
MODEL_CONFIGS,
MODEL_GROUPS,
ModelGroup,
)
from providers import get_provider
from routing import get_model_router
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Create FastMCP server
mcp = FastMCP("syndarix-llm-gateway")
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Application lifespan handler."""
settings = get_settings()
logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}")
# Initialize provider
provider = get_provider()
provider.initialize()
yield
# Cleanup
from cost_tracking import close_cost_tracker
await close_cost_tracker()
logger.info("LLM Gateway shutdown complete")
# Create FastAPI app that wraps FastMCP
app = FastAPI(
title="Syndarix LLM Gateway",
description="MCP Server for unified LLM access",
version="0.1.0",
lifespan=lifespan,
)
# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint."""
settings = get_settings()
provider = get_provider()
return {
"status": "healthy",
"service": "llm-gateway",
"providers_configured": settings.get_available_providers(),
"provider_available": provider.is_available,
}
# Tool discovery endpoint (for MCP client compatibility)
@app.get("/mcp/tools")
async def list_tools() -> dict[str, Any]:
"""List available MCP tools."""
return {
"tools": [
{
"name": "chat_completion",
"description": "Generate a chat completion using the specified model group",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string", "description": "Project ID"},
"agent_id": {"type": "string", "description": "Agent ID"},
"messages": {
"type": "array",
"items": {
"type": "object",
"properties": {
"role": {"type": "string"},
"content": {"type": "string"},
},
"required": ["role", "content"],
},
},
"model_group": {
"type": "string",
"enum": [g.value for g in ModelGroup],
"default": "reasoning",
},
"max_tokens": {"type": "integer", "default": 4096},
"temperature": {"type": "number", "default": 0.7},
"stream": {"type": "boolean", "default": False},
},
"required": ["project_id", "agent_id", "messages"],
},
},
{
"name": "list_models",
"description": "List available models and model groups",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"model_group": {"type": "string"},
},
"required": ["project_id", "agent_id"],
},
},
{
"name": "get_usage",
"description": "Get usage statistics for a project or agent",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"period": {
"type": "string",
"enum": ["hour", "day", "month"],
"default": "day",
},
},
"required": ["project_id", "agent_id"],
},
},
{
"name": "count_tokens",
"description": "Count tokens in text",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"text": {"type": "string"},
"model": {"type": "string"},
},
"required": ["project_id", "agent_id", "text"],
},
},
]
}
# JSON-RPC endpoint (for MCP client compatibility)
@app.post("/mcp")
async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]:
"""Handle JSON-RPC 2.0 requests for MCP tools."""
# Validate JSON-RPC structure
if request.get("jsonrpc") != "2.0":
return {
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid JSON-RPC version"},
"id": request.get("id"),
}
method = request.get("method")
params = request.get("params", {})
request_id = request.get("id")
# Handle tool calls
if method == "tools/call":
tool_name = params.get("name")
tool_args = params.get("arguments", {})
try:
if tool_name == "chat_completion":
result = await _impl_chat_completion(**tool_args)
elif tool_name == "list_models":
result = await _impl_list_models(**tool_args)
elif tool_name == "get_usage":
result = await _impl_get_usage(**tool_args)
elif tool_name == "count_tokens":
result = await _impl_count_tokens(**tool_args)
else:
return {
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
"id": request_id,
}
return {
"jsonrpc": "2.0",
"result": {"content": [{"type": "text", "text": str(result)}]},
"id": request_id,
}
except LLMGatewayError as e:
return {
"jsonrpc": "2.0",
"error": {"code": -32000, "message": str(e), "data": e.to_dict()},
"id": request_id,
}
except Exception as e:
logger.exception(f"Error executing tool {tool_name}")
return {
"jsonrpc": "2.0",
"error": {"code": -32603, "message": str(e)},
"id": request_id,
}
# Handle tool listing
elif method == "tools/list":
tools_response = await list_tools()
return {
"jsonrpc": "2.0",
"result": tools_response,
"id": request_id,
}
else:
return {
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Unknown method: {method}"},
"id": request_id,
}
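# Example JSON-RPC 2.0 exchange against POST /mcp (payload values are
# illustrative; the result text is the str() of the tool's result dict):
#
#   request:
#     {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#      "params": {"name": "count_tokens",
#                 "arguments": {"project_id": "p-1", "agent_id": "a-1",
#                               "text": "hello world"}}}
#
#   response:
#     {"jsonrpc": "2.0", "id": 1,
#      "result": {"content": [{"type": "text", "text": "{'project_id': ...}"}]}}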
# ============================================================================
# Core Implementation Functions
# ============================================================================
async def _impl_chat_completion(
project_id: str,
agent_id: str,
messages: list[dict[str, Any]],
model_group: str = "reasoning",
model_override: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
stream: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
"""Core implementation for chat completion."""
settings = get_settings()
router = get_model_router()
tracker = get_cost_tracker()
circuit_registry = get_circuit_registry()
# Check budget before making request
if settings.cost_tracking_enabled:
within_budget, current_cost, limit = await tracker.check_budget(project_id)
if not within_budget:
raise CostLimitExceededError(
entity_type="project",
entity_id=project_id,
current_cost=current_cost,
limit=limit,
)
# Select model; routing errors are raised to the caller unchanged
try:
model_name, model_config = await router.select_model(
model_group=model_group,
model_override=model_override,
)
except (
InvalidModelGroupError,
InvalidModelError,
ModelNotAvailableError,
AllProvidersFailedError,
):
raise
# Get provider
provider = get_provider()
if not provider.router:
raise AllProvidersFailedError(
model_group=model_group,
attempted_models=[model_name],
errors=[{"error": "No providers configured"}],
)
# Generate request ID
request_id = str(uuid.uuid4())
try:
# Get circuit breaker for this provider
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
# Make completion request
if stream:
# Return streaming response info
# Actual streaming would be handled by a separate endpoint
return {
"status": "streaming_not_supported_via_tool",
"message": "Use /stream endpoint for streaming responses",
"request_id": request_id,
}
# Non-streaming completion
response = await provider.router.acompletion(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
# Record success
await circuit.record_success()
# Extract response data
content = response.choices[0].message.content or ""
finish_reason = response.choices[0].finish_reason or "stop"
# Get usage stats
prompt_tokens = response.usage.prompt_tokens if response.usage else 0
completion_tokens = response.usage.completion_tokens if response.usage else 0
# Calculate cost
cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens)
# Record usage
await tracker.record_usage(
project_id=project_id,
agent_id=agent_id,
model=model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cost_usd=cost_usd,
session_id=session_id,
request_id=request_id,
)
return {
"id": request_id,
"model": model_name,
"provider": model_config.provider.value,
"content": content,
"finish_reason": finish_reason,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"cost_usd": cost_usd,
},
}
except CircuitOpenError:
raise
except Exception as e:
# Record failure
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
await circuit.record_failure(e)
logger.error(f"Completion failed: {e}")
raise AllProvidersFailedError(
model_group=model_group,
attempted_models=[model_name],
errors=[{"model": model_name, "error": str(e)}],
)
async def _impl_list_models(
project_id: str,
agent_id: str,
model_group: str | None = None,
) -> dict[str, Any]:
"""Core implementation for list_models."""
settings = get_settings()
provider = get_provider()
router = get_model_router()
# Get available providers
available_providers = settings.get_available_providers()
result: dict[str, Any] = {
"project_id": project_id,
"agent_id": agent_id,
"available_providers": available_providers,
}
if model_group:
# List models for specific group
try:
parsed_group = router.parse_model_group(model_group)
models = await router.get_available_models_for_group(parsed_group)
result["model_group"] = model_group
result["models"] = [
{
"name": name,
"provider": config.provider.value,
"available": available,
"cost_per_1m_input": config.cost_per_1m_input,
"cost_per_1m_output": config.cost_per_1m_output,
}
for name, config, available in models
]
except InvalidModelGroupError as e:
result["error"] = e.to_dict()
else:
# List all model groups
groups: dict[str, Any] = {}
for group in ModelGroup:
group_config = MODEL_GROUPS.get(group)
if group_config:
models = await router.get_available_models_for_group(group)
available_count = sum(1 for _, _, avail in models if avail)
groups[group.value] = {
"description": group_config.description,
"primary": group_config.primary,
"fallbacks": group_config.fallbacks,
"available_models": available_count,
"total_models": len(models),
}
result["model_groups"] = groups
# List all models
all_models: list[dict[str, Any]] = []
available_models = provider.get_available_models()
for name, config in MODEL_CONFIGS.items():
all_models.append({
"name": name,
"provider": config.provider.value,
"available": name in available_models,
"cost_per_1m_input": config.cost_per_1m_input,
"cost_per_1m_output": config.cost_per_1m_output,
"context_window": config.context_window,
"max_output_tokens": config.max_output_tokens,
"supports_vision": config.supports_vision,
"supports_streaming": config.supports_streaming,
})
result["models"] = all_models
return result
async def _impl_get_usage(
project_id: str,
agent_id: str,
period: str = "day",
) -> dict[str, Any]:
"""Core implementation for get_usage."""
tracker = get_cost_tracker()
# Get project usage
project_report = await tracker.get_project_usage(project_id, period=period)
# Get agent usage
agent_report = await tracker.get_agent_usage(agent_id, period=period)
return {
"project_id": project_id,
"agent_id": agent_id,
"period": period,
"project_usage": {
"total_requests": project_report.total_requests,
"total_tokens": project_report.total_tokens,
"total_cost_usd": project_report.total_cost_usd,
"by_model": project_report.by_model,
"period_start": project_report.period_start.isoformat(),
"period_end": project_report.period_end.isoformat(),
},
"agent_usage": {
"total_requests": agent_report.total_requests,
"total_tokens": agent_report.total_tokens,
"total_cost_usd": agent_report.total_cost_usd,
"by_model": agent_report.by_model,
"period_start": agent_report.period_start.isoformat(),
"period_end": agent_report.period_end.isoformat(),
},
}
async def _impl_count_tokens(
project_id: str,
agent_id: str,
text: str,
model: str | None = None,
) -> dict[str, Any]:
"""Core implementation for count_tokens."""
# Use tiktoken for token counting.
# cl100k_base is the GPT-4 tokenizer; for other providers (Claude, Gemini,
# DeepSeek, ...) it is only an approximation, since they use their own tokenizers.
encoding_name = "cl100k_base"
try:
if model and model.startswith("gpt"):
encoding = tiktoken.encoding_for_model(model)
encoding_name = encoding.name
else:
encoding = tiktoken.get_encoding("cl100k_base")
token_count = len(encoding.encode(text))
except Exception as e:
logger.warning(f"Token counting failed: {e}, using estimate")
# Fallback: rough estimate of ~4 characters per token
token_count = len(text) // 4
encoding_name = "estimate"
# Estimate costs for different models
cost_estimates: dict[str, float] = {}
for model_name, config in MODEL_CONFIGS.items():
if config.cost_per_1m_input > 0:
cost = (token_count / 1_000_000) * config.cost_per_1m_input
cost_estimates[model_name] = round(cost, 6)
return {
"project_id": project_id,
"agent_id": agent_id,
"token_count": token_count,
"text_length": len(text),
"encoding": "cl100k_base",
"cost_estimates": cost_estimates,
}
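# Worked example of the cost estimate above (numbers are illustrative): a text
# that tokenizes to 1,200 tokens, estimated against a model priced at $3.00 per
# 1M input tokens, costs (1_200 / 1_000_000) * 3.00 = $0.0036, which survives
# the round(..., 6) unchanged.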
# ============================================================================
# MCP Tools (wrappers around core implementations)
# ============================================================================
@mcp.tool()
async def chat_completion(
project_id: str,
agent_id: str,
messages: list[dict[str, Any]],
model_group: str = "reasoning",
model_override: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
stream: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
"""
Generate a chat completion using the specified model group.
Args:
project_id: UUID of the project (required for cost attribution)
agent_id: UUID of the agent instance making the request
messages: List of message dicts with 'role' and 'content'
model_group: Model routing group (reasoning, code, fast, vision, embedding)
model_override: Specific model to use (bypasses group routing)
max_tokens: Maximum tokens to generate
temperature: Sampling temperature (0.0-2.0)
stream: Enable streaming response
session_id: Optional session ID for tracking
Returns:
Completion response with content and usage statistics
"""
return await _impl_chat_completion(
project_id=project_id,
agent_id=agent_id,
messages=messages,
model_group=model_group,
model_override=model_override,
max_tokens=max_tokens,
temperature=temperature,
stream=stream,
session_id=session_id,
)
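# Example arguments for the chat_completion tool (values are illustrative):
#
#   {
#       "project_id": "3f2b9c8e-...",
#       "agent_id": "9a7c41d0-...",
#       "messages": [
#           {"role": "system", "content": "You are a code reviewer."},
#           {"role": "user", "content": "Review this diff for correctness."},
#       ],
#       "model_group": "code",
#       "max_tokens": 1024,
#   }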
@mcp.tool()
async def list_models(
project_id: str,
agent_id: str,
model_group: str | None = None,
) -> dict[str, Any]:
"""
List available models and model groups.
Args:
project_id: UUID of the project
agent_id: UUID of the agent instance
model_group: Optional specific group to list
Returns:
Dictionary of available models and groups
"""
return await _impl_list_models(
project_id=project_id,
agent_id=agent_id,
model_group=model_group,
)
@mcp.tool()
async def get_usage(
project_id: str,
agent_id: str,
period: str = "day",
) -> dict[str, Any]:
"""
Get usage statistics for a project and its agent.
Args:
project_id: UUID of the project
agent_id: UUID of the agent
period: Time period (hour, day, month)
Returns:
Usage statistics including tokens and costs
"""
return await _impl_get_usage(
project_id=project_id,
agent_id=agent_id,
period=period,
)
@mcp.tool()
async def count_tokens(
project_id: str,
agent_id: str,
text: str,
model: str | None = None,
) -> dict[str, Any]:
"""
Count tokens in text.
Args:
project_id: UUID of the project
agent_id: UUID of the agent
text: Text to count tokens in
model: Optional model for tokenizer selection
Returns:
Token count and estimation details
"""
return await _impl_count_tokens(
project_id=project_id,
agent_id=agent_id,
text=text,
model=model,
)
def main() -> None:
"""Run the server."""
import uvicorn
settings = get_settings()
uvicorn.run(
"server:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
)
if __name__ == "__main__":
main()
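# Local run options (host/port values are illustrative; actual defaults come
# from config.get_settings()):
#   python server.py
#   uvicorn server:app --host 0.0.0.0 --port 8000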