"""
|
|
Syndarix LLM Gateway MCP Server.
|
|
|
|
Provides unified LLM access with:
|
|
- Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek)
|
|
- Automatic failover chains
|
|
- Cost tracking via Redis
|
|
- Model group routing (reasoning, code, fast, vision, embedding)
|
|
- Circuit breaker protection
|
|
|
|
Per ADR-004: LLM Provider Abstraction.
|
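
Example (illustrative) request to the ``POST /mcp`` JSON-RPC endpoint defined
below; the IDs are placeholders::

    {
        "jsonrpc": "2.0",
        "method": "tools/call",
        "params": {
            "name": "chat_completion",
            "arguments": {
                "project_id": "...",
                "agent_id": "...",
                "messages": [{"role": "user", "content": "Hello"}]
            }
        },
        "id": 1
    }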
"""

import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

import tiktoken
from fastapi import FastAPI
from fastmcp import FastMCP

from config import get_settings
from cost_tracking import calculate_cost, get_cost_tracker
from exceptions import (
    AllProvidersFailedError,
    CircuitOpenError,
    CostLimitExceededError,
    InvalidModelError,
    InvalidModelGroupError,
    LLMGatewayError,
    ModelNotAvailableError,
)
from failover import get_circuit_registry
from models import (
    MODEL_CONFIGS,
    MODEL_GROUPS,
    ModelGroup,
)
from providers import get_provider
from routing import get_model_router

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# Create FastMCP server
mcp = FastMCP("syndarix-llm-gateway")


@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan handler."""
    settings = get_settings()
    logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}")

    # Initialize provider
    provider = get_provider()
    provider.initialize()

    yield

    # Cleanup: close the cost tracker's Redis connection on shutdown
    from cost_tracking import close_cost_tracker

    await close_cost_tracker()
    logger.info("LLM Gateway shutdown complete")


# Create FastAPI app that wraps FastMCP
app = FastAPI(
    title="Syndarix LLM Gateway",
    description="MCP Server for unified LLM access",
    version="0.1.0",
    lifespan=lifespan,
)


# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint."""
    settings = get_settings()
    provider = get_provider()

    return {
        "status": "healthy",
        "service": "llm-gateway",
        "providers_configured": settings.get_available_providers(),
        "provider_available": provider.is_available,
    }


# Tool discovery endpoint (for MCP client compatibility)
@app.get("/mcp/tools")
async def list_tools() -> dict[str, Any]:
    """List available MCP tools."""
    return {
        "tools": [
            {
                "name": "chat_completion",
                "description": "Generate a chat completion using the specified model group",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "project_id": {"type": "string", "description": "Project ID"},
                        "agent_id": {"type": "string", "description": "Agent ID"},
                        "messages": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "role": {"type": "string"},
                                    "content": {"type": "string"},
                                },
                                "required": ["role", "content"],
                            },
                        },
                        "model_group": {
                            "type": "string",
                            "enum": [g.value for g in ModelGroup],
                            "default": "reasoning",
                        },
                        "max_tokens": {"type": "integer", "default": 4096},
                        "temperature": {"type": "number", "default": 0.7},
                        "stream": {"type": "boolean", "default": False},
                    },
                    "required": ["project_id", "agent_id", "messages"],
                },
            },
            {
                "name": "list_models",
                "description": "List available models and model groups",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "project_id": {"type": "string"},
                        "agent_id": {"type": "string"},
                        "model_group": {"type": "string"},
                    },
                    "required": ["project_id", "agent_id"],
                },
            },
            {
                "name": "get_usage",
                "description": "Get usage statistics for a project or agent",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "project_id": {"type": "string"},
                        "agent_id": {"type": "string"},
                        "period": {
                            "type": "string",
                            "enum": ["hour", "day", "month"],
                            "default": "day",
                        },
                    },
                    "required": ["project_id", "agent_id"],
                },
            },
            {
                "name": "count_tokens",
                "description": "Count tokens in text",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "project_id": {"type": "string"},
                        "agent_id": {"type": "string"},
                        "text": {"type": "string"},
                        "model": {"type": "string"},
                    },
                    "required": ["project_id", "agent_id", "text"],
                },
            },
        ]
    }


# JSON-RPC endpoint (for MCP client compatibility)
@app.post("/mcp")
async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]:
    """Handle JSON-RPC 2.0 requests for MCP tools."""
    # Validate JSON-RPC structure
    if request.get("jsonrpc") != "2.0":
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32600, "message": "Invalid JSON-RPC version"},
            "id": request.get("id"),
        }

    method = request.get("method")
    params = request.get("params", {})
    request_id = request.get("id")

    # Handle tool calls
    if method == "tools/call":
        tool_name = params.get("name")
        tool_args = params.get("arguments", {})

        try:
            if tool_name == "chat_completion":
                result = await _impl_chat_completion(**tool_args)
            elif tool_name == "list_models":
                result = await _impl_list_models(**tool_args)
            elif tool_name == "get_usage":
                result = await _impl_get_usage(**tool_args)
            elif tool_name == "count_tokens":
                result = await _impl_count_tokens(**tool_args)
            else:
                return {
                    "jsonrpc": "2.0",
                    "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
                    "id": request_id,
                }

            return {
                "jsonrpc": "2.0",
                "result": {"content": [{"type": "text", "text": str(result)}]},
                "id": request_id,
            }

        except LLMGatewayError as e:
            return {
                "jsonrpc": "2.0",
                "error": {"code": -32000, "message": str(e), "data": e.to_dict()},
                "id": request_id,
            }
        except Exception as e:
            logger.exception(f"Error executing tool {tool_name}")
            return {
                "jsonrpc": "2.0",
                "error": {"code": -32603, "message": str(e)},
                "id": request_id,
            }

    # Handle tool listing
    elif method == "tools/list":
        tools_response = await list_tools()
        return {
            "jsonrpc": "2.0",
            "result": tools_response,
            "id": request_id,
        }

    else:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32601, "message": f"Unknown method: {method}"},
            "id": request_id,
        }


# ============================================================================
# Core Implementation Functions
# ============================================================================


async def _impl_chat_completion(
    project_id: str,
    agent_id: str,
    messages: list[dict[str, Any]],
    model_group: str = "reasoning",
    model_override: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    stream: bool = False,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Core implementation for chat completion."""
    settings = get_settings()
    router = get_model_router()
    tracker = get_cost_tracker()
    circuit_registry = get_circuit_registry()

    # Check budget before making request
    if settings.cost_tracking_enabled:
        within_budget, current_cost, limit = await tracker.check_budget(project_id)
        if not within_budget:
            raise CostLimitExceededError(
                entity_type="project",
                entity_id=project_id,
                current_cost=current_cost,
                limit=limit,
            )

    # Select model; routing errors are already domain-specific and propagate as-is
    try:
        model_name, model_config = await router.select_model(
            model_group=model_group,
            model_override=model_override,
        )
    except (
        InvalidModelGroupError,
        InvalidModelError,
        ModelNotAvailableError,
        AllProvidersFailedError,
    ):
        raise

    # Get provider
    provider = get_provider()
    if not provider.router:
        raise AllProvidersFailedError(
            model_group=model_group,
            attempted_models=[model_name],
            errors=[{"error": "No providers configured"}],
        )

    # Generate request ID
    request_id = str(uuid.uuid4())

    try:
        # Get circuit breaker for this provider
        circuit = circuit_registry.get_circuit_sync(model_config.provider.value)

        # Make completion request
        if stream:
            # Return streaming response info
            # Actual streaming would be handled by a separate endpoint
            return {
                "status": "streaming_not_supported_via_tool",
                "message": "Use /stream endpoint for streaming responses",
                "request_id": request_id,
            }

        # Non-streaming completion
        response = await provider.router.acompletion(
            model=model_name,
            messages=messages,  # type: ignore[arg-type]
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # Record success
        await circuit.record_success()

        # Extract response data
        content = response.choices[0].message.content or ""  # type: ignore[union-attr]
        finish_reason = response.choices[0].finish_reason or "stop"

        # Get usage stats
        prompt_tokens = response.usage.prompt_tokens if response.usage else 0  # type: ignore[attr-defined]
        completion_tokens = response.usage.completion_tokens if response.usage else 0  # type: ignore[attr-defined]

        # Calculate cost
        cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens)

        # Record usage
        await tracker.record_usage(
            project_id=project_id,
            agent_id=agent_id,
            model=model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost_usd=cost_usd,
            session_id=session_id,
            request_id=request_id,
        )

        return {
            "id": request_id,
            "model": model_name,
            "provider": model_config.provider.value,
            "content": content,
            "finish_reason": finish_reason,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
                "cost_usd": cost_usd,
            },
        }

    except CircuitOpenError:
        raise
    except Exception as e:
        # Record failure
        circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
        await circuit.record_failure(e)

        logger.error(f"Completion failed: {e}")
        raise AllProvidersFailedError(
            model_group=model_group,
            attempted_models=[model_name],
            errors=[{"model": model_name, "error": str(e)}],
        ) from e


async def _impl_list_models(
    project_id: str,
    agent_id: str,
    model_group: str | None = None,
) -> dict[str, Any]:
    """Core implementation for list_models."""
    settings = get_settings()
    provider = get_provider()
    router = get_model_router()

    # Get available providers
    available_providers = settings.get_available_providers()

    result: dict[str, Any] = {
        "project_id": project_id,
        "agent_id": agent_id,
        "available_providers": available_providers,
    }

    if model_group:
        # List models for specific group
        try:
            parsed_group = router.parse_model_group(model_group)
            models = await router.get_available_models_for_group(parsed_group)

            result["model_group"] = model_group
            result["models"] = [
                {
                    "name": name,
                    "provider": config.provider.value,
                    "available": available,
                    "cost_per_1m_input": config.cost_per_1m_input,
                    "cost_per_1m_output": config.cost_per_1m_output,
                }
                for name, config, available in models
            ]
        except InvalidModelGroupError as e:
            result["error"] = e.to_dict()
    else:
        # List all model groups
        groups: dict[str, Any] = {}
        for group in ModelGroup:
            group_config = MODEL_GROUPS.get(group)
            if group_config:
                models = await router.get_available_models_for_group(group)
                available_count = sum(1 for _, _, avail in models if avail)
                groups[group.value] = {
                    "description": group_config.description,
                    "primary": group_config.primary,
                    "fallbacks": group_config.fallbacks,
                    "available_models": available_count,
                    "total_models": len(models),
                }
        result["model_groups"] = groups

        # List all models
        all_models: list[dict[str, Any]] = []
        available_models = provider.get_available_models()
        for name, config in MODEL_CONFIGS.items():
            all_models.append(
                {
                    "name": name,
                    "provider": config.provider.value,
                    "available": name in available_models,
                    "cost_per_1m_input": config.cost_per_1m_input,
                    "cost_per_1m_output": config.cost_per_1m_output,
                    "context_window": config.context_window,
                    "max_output_tokens": config.max_output_tokens,
                    "supports_vision": config.supports_vision,
                    "supports_streaming": config.supports_streaming,
                }
            )
        result["models"] = all_models

    return result


async def _impl_get_usage(
    project_id: str,
    agent_id: str,
    period: str = "day",
) -> dict[str, Any]:
    """Core implementation for get_usage."""
    tracker = get_cost_tracker()

    # Get project usage
    project_report = await tracker.get_project_usage(project_id, period=period)

    # Get agent usage
    agent_report = await tracker.get_agent_usage(agent_id, period=period)

    return {
        "project_id": project_id,
        "agent_id": agent_id,
        "period": period,
        "project_usage": {
            "total_requests": project_report.total_requests,
            "total_tokens": project_report.total_tokens,
            "total_cost_usd": project_report.total_cost_usd,
            "by_model": project_report.by_model,
            "period_start": project_report.period_start.isoformat(),
            "period_end": project_report.period_end.isoformat(),
        },
        "agent_usage": {
            "total_requests": agent_report.total_requests,
            "total_tokens": agent_report.total_tokens,
            "total_cost_usd": agent_report.total_cost_usd,
            "by_model": agent_report.by_model,
            "period_start": agent_report.period_start.isoformat(),
            "period_end": agent_report.period_end.isoformat(),
        },
    }


async def _impl_count_tokens(
    project_id: str,
    agent_id: str,
    text: str,
    model: str | None = None,
) -> dict[str, Any]:
    """Core implementation for count_tokens."""
    # Use tiktoken for token counting.
    # Default to cl100k_base (the GPT-4 tokenizer); counts for non-OpenAI
    # models are approximations.
    encoding_name = "cl100k_base"
    try:
        if model and model.startswith("gpt"):
            encoding = tiktoken.encoding_for_model(model)
            encoding_name = encoding.name
        else:
            encoding = tiktoken.get_encoding("cl100k_base")

        token_count = len(encoding.encode(text))
    except Exception as e:
        logger.warning(f"Token counting failed: {e}, using estimate")
        # Fallback: rough estimate of ~4 chars per token
        token_count = len(text) // 4
        encoding_name = "estimate"

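    # Worked example (hypothetical price): 1,000 tokens at $3.00 per 1M input
    # tokens costs (1_000 / 1_000_000) * 3.00 = $0.003.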
    # Estimate costs for different models
    cost_estimates: dict[str, float] = {}
    for model_name, config in MODEL_CONFIGS.items():
        if config.cost_per_1m_input > 0:
            cost = (token_count / 1_000_000) * config.cost_per_1m_input
            cost_estimates[model_name] = round(cost, 6)

    return {
        "project_id": project_id,
        "agent_id": agent_id,
        "token_count": token_count,
        "text_length": len(text),
        "encoding": encoding_name,
        "cost_estimates": cost_estimates,
    }


# ============================================================================
# MCP Tools (wrappers around core implementations)
# ============================================================================


@mcp.tool()
async def chat_completion(
    project_id: str,
    agent_id: str,
    messages: list[dict[str, Any]],
    model_group: str = "reasoning",
    model_override: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    stream: bool = False,
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Generate a chat completion using the specified model group.

    Args:
        project_id: UUID of the project (required for cost attribution)
        agent_id: UUID of the agent instance making the request
        messages: List of message dicts with 'role' and 'content'
        model_group: Model routing group (reasoning, code, fast, vision, embedding)
        model_override: Specific model to use (bypasses group routing)
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature (0.0-2.0)
        stream: Enable streaming response
        session_id: Optional session ID for tracking

    Returns:
        Completion response with content and usage statistics
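
    Example (illustrative; IDs are placeholders):
        >>> await chat_completion(
        ...     project_id="...",
        ...     agent_id="...",
        ...     messages=[{"role": "user", "content": "Hello"}],
        ...     model_group="fast",
        ... )  # doctest: +SKIP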
"""
|
|
return await _impl_chat_completion(
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
messages=messages,
|
|
model_group=model_group,
|
|
model_override=model_override,
|
|
max_tokens=max_tokens,
|
|
temperature=temperature,
|
|
stream=stream,
|
|
session_id=session_id,
|
|
)
|
|
|
|
|
|
@mcp.tool()
async def list_models(
    project_id: str,
    agent_id: str,
    model_group: str | None = None,
) -> dict[str, Any]:
    """
    List available models and model groups.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent instance
        model_group: Optional specific group to list

    Returns:
        Dictionary of available models and groups
    """
    return await _impl_list_models(
        project_id=project_id,
        agent_id=agent_id,
        model_group=model_group,
    )


@mcp.tool()
async def get_usage(
    project_id: str,
    agent_id: str,
    period: str = "day",
) -> dict[str, Any]:
    """
    Get usage statistics for a project or agent.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent
        period: Time period (hour, day, month)

    Returns:
        Usage statistics including tokens and costs
    """
    return await _impl_get_usage(
        project_id=project_id,
        agent_id=agent_id,
        period=period,
    )


@mcp.tool()
async def count_tokens(
    project_id: str,
    agent_id: str,
    text: str,
    model: str | None = None,
) -> dict[str, Any]:
    """
    Count tokens in text.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent
        text: Text to count tokens in
        model: Optional model for tokenizer selection

    Returns:
        Token count and estimation details
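
    Example (illustrative; IDs are placeholders):
        >>> await count_tokens(
        ...     project_id="...",
        ...     agent_id="...",
        ...     text="Hello, world!",
        ... )  # doctest: +SKIP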
"""
|
|
return await _impl_count_tokens(
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
text=text,
|
|
model=model,
|
|
)
|
|
|
|
|
|
def main() -> None:
    """Run the server."""
    import uvicorn

    settings = get_settings()
    uvicorn.run(
        "server:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )


if __name__ == "__main__":
    main()