""" Syndarix LLM Gateway MCP Server. Provides unified LLM access with: - Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek) - Automatic failover chains - Cost tracking via Redis - Model group routing (reasoning, code, fast, vision, embedding) - Circuit breaker protection Per ADR-004: LLM Provider Abstraction. """ import logging import uuid from contextlib import asynccontextmanager from typing import Any import tiktoken from fastapi import FastAPI from fastmcp import FastMCP from config import get_settings from cost_tracking import calculate_cost, get_cost_tracker from exceptions import ( AllProvidersFailedError, CircuitOpenError, CostLimitExceededError, InvalidModelError, InvalidModelGroupError, LLMGatewayError, ModelNotAvailableError, ) from failover import get_circuit_registry from models import ( MODEL_CONFIGS, MODEL_GROUPS, ModelGroup, ) from providers import get_provider from routing import get_model_router # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Create FastMCP server mcp = FastMCP("syndarix-llm-gateway") @asynccontextmanager async def lifespan(_app: FastAPI): """Application lifespan handler.""" settings = get_settings() logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}") # Initialize provider provider = get_provider() provider.initialize() yield # Cleanup from cost_tracking import close_cost_tracker await close_cost_tracker() logger.info("LLM Gateway shutdown complete") # Create FastAPI app that wraps FastMCP app = FastAPI( title="Syndarix LLM Gateway", description="MCP Server for unified LLM access", version="0.1.0", lifespan=lifespan, ) # Health endpoint @app.get("/health") async def health_check() -> dict[str, Any]: """Health check endpoint.""" settings = get_settings() provider = get_provider() return { "status": "healthy", "service": "llm-gateway", "providers_configured": settings.get_available_providers(), "provider_available": provider.is_available, } # Tool discovery endpoint (for MCP client compatibility) @app.get("/mcp/tools") async def list_tools() -> dict[str, Any]: """List available MCP tools.""" return { "tools": [ { "name": "chat_completion", "description": "Generate a chat completion using the specified model group", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string", "description": "Project ID"}, "agent_id": {"type": "string", "description": "Agent ID"}, "messages": { "type": "array", "items": { "type": "object", "properties": { "role": {"type": "string"}, "content": {"type": "string"}, }, "required": ["role", "content"], }, }, "model_group": { "type": "string", "enum": [g.value for g in ModelGroup], "default": "reasoning", }, "max_tokens": {"type": "integer", "default": 4096}, "temperature": {"type": "number", "default": 0.7}, "stream": {"type": "boolean", "default": False}, }, "required": ["project_id", "agent_id", "messages"], }, }, { "name": "list_models", "description": "List available models and model groups", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string"}, "agent_id": {"type": "string"}, "model_group": {"type": "string"}, }, "required": ["project_id", "agent_id"], }, }, { "name": "get_usage", "description": "Get usage statistics for a project or agent", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string"}, "agent_id": {"type": "string"}, "period": { "type": "string", "enum": ["hour", "day", "month"], 
"default": "day", }, }, "required": ["project_id", "agent_id"], }, }, { "name": "count_tokens", "description": "Count tokens in text", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string"}, "agent_id": {"type": "string"}, "text": {"type": "string"}, "model": {"type": "string"}, }, "required": ["project_id", "agent_id", "text"], }, }, ] } # JSON-RPC endpoint (for MCP client compatibility) @app.post("/mcp") async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]: """Handle JSON-RPC 2.0 requests for MCP tools.""" # Validate JSON-RPC structure if request.get("jsonrpc") != "2.0": return { "jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid JSON-RPC version"}, "id": request.get("id"), } method = request.get("method") params = request.get("params", {}) request_id = request.get("id") # Handle tool calls if method == "tools/call": tool_name = params.get("name") tool_args = params.get("arguments", {}) try: if tool_name == "chat_completion": result = await _impl_chat_completion(**tool_args) elif tool_name == "list_models": result = await _impl_list_models(**tool_args) elif tool_name == "get_usage": result = await _impl_get_usage(**tool_args) elif tool_name == "count_tokens": result = await _impl_count_tokens(**tool_args) else: return { "jsonrpc": "2.0", "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, "id": request_id, } return { "jsonrpc": "2.0", "result": {"content": [{"type": "text", "text": str(result)}]}, "id": request_id, } except LLMGatewayError as e: return { "jsonrpc": "2.0", "error": {"code": -32000, "message": str(e), "data": e.to_dict()}, "id": request_id, } except Exception as e: logger.exception(f"Error executing tool {tool_name}") return { "jsonrpc": "2.0", "error": {"code": -32603, "message": str(e)}, "id": request_id, } # Handle tool listing elif method == "tools/list": tools_response = await list_tools() return { "jsonrpc": "2.0", "result": tools_response, "id": request_id, } else: return { "jsonrpc": "2.0", "error": {"code": -32601, "message": f"Unknown method: {method}"}, "id": request_id, } # ============================================================================ # Core Implementation Functions # ============================================================================ async def _impl_chat_completion( project_id: str, agent_id: str, messages: list[dict[str, Any]], model_group: str = "reasoning", model_override: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, stream: bool = False, session_id: str | None = None, ) -> dict[str, Any]: """Core implementation for chat completion.""" settings = get_settings() router = get_model_router() tracker = get_cost_tracker() circuit_registry = get_circuit_registry() # Check budget before making request if settings.cost_tracking_enabled: within_budget, current_cost, limit = await tracker.check_budget(project_id) if not within_budget: raise CostLimitExceededError( entity_type="project", entity_id=project_id, current_cost=current_cost, limit=limit, ) # Select model try: model_name, model_config = await router.select_model( model_group=model_group, model_override=model_override, ) except (InvalidModelGroupError, InvalidModelError, AllProvidersFailedError): raise except ModelNotAvailableError: raise # Get provider provider = get_provider() if not provider.router: raise AllProvidersFailedError( model_group=model_group, attempted_models=[model_name], errors=[{"error": "No providers configured"}], ) # Generate request ID request_id = str(uuid.uuid4()) 
    try:
        # Get circuit breaker for this provider
        circuit = circuit_registry.get_circuit_sync(model_config.provider.value)

        # Make completion request
        if stream:
            # Return streaming response info
            # Actual streaming would be handled by a separate endpoint
            return {
                "status": "streaming_not_supported_via_tool",
                "message": "Use /stream endpoint for streaming responses",
                "request_id": request_id,
            }

        # Non-streaming completion
        response = await provider.router.acompletion(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # Record success
        await circuit.record_success()

        # Extract response data
        content = response.choices[0].message.content or ""
        finish_reason = response.choices[0].finish_reason or "stop"

        # Get usage stats
        prompt_tokens = response.usage.prompt_tokens if response.usage else 0
        completion_tokens = response.usage.completion_tokens if response.usage else 0

        # Calculate cost
        cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens)

        # Record usage
        await tracker.record_usage(
            project_id=project_id,
            agent_id=agent_id,
            model=model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost_usd=cost_usd,
            session_id=session_id,
            request_id=request_id,
        )

        return {
            "id": request_id,
            "model": model_name,
            "provider": model_config.provider.value,
            "content": content,
            "finish_reason": finish_reason,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
                "cost_usd": cost_usd,
            },
        }

    except CircuitOpenError:
        raise
    except Exception as e:
        # Record failure
        circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
        await circuit.record_failure(e)

        logger.error(f"Completion failed: {e}")
        raise AllProvidersFailedError(
            model_group=model_group,
            attempted_models=[model_name],
            errors=[{"model": model_name, "error": str(e)}],
        ) from e


async def _impl_list_models(
    project_id: str,
    agent_id: str,
    model_group: str | None = None,
) -> dict[str, Any]:
    """Core implementation for list_models."""
    settings = get_settings()
    provider = get_provider()
    router = get_model_router()

    # Get available providers
    available_providers = settings.get_available_providers()

    result: dict[str, Any] = {
        "project_id": project_id,
        "agent_id": agent_id,
        "available_providers": available_providers,
    }

    if model_group:
        # List models for specific group
        try:
            parsed_group = router.parse_model_group(model_group)
            models = await router.get_available_models_for_group(parsed_group)
            result["model_group"] = model_group
            result["models"] = [
                {
                    "name": name,
                    "provider": config.provider.value,
                    "available": available,
                    "cost_per_1m_input": config.cost_per_1m_input,
                    "cost_per_1m_output": config.cost_per_1m_output,
                }
                for name, config, available in models
            ]
        except InvalidModelGroupError as e:
            result["error"] = e.to_dict()
    else:
        # List all model groups
        groups: dict[str, Any] = {}
        for group in ModelGroup:
            group_config = MODEL_GROUPS.get(group)
            if group_config:
                models = await router.get_available_models_for_group(group)
                available_count = sum(1 for _, _, avail in models if avail)
                groups[group.value] = {
                    "description": group_config.description,
                    "primary": group_config.primary,
                    "fallbacks": group_config.fallbacks,
                    "available_models": available_count,
                    "total_models": len(models),
                }
        result["model_groups"] = groups

        # List all models
        all_models: list[dict[str, Any]] = []
        available_models = provider.get_available_models()
        for name, config in MODEL_CONFIGS.items():
            all_models.append({
                "name": name,
                "provider": config.provider.value,
"available": name in available_models, "cost_per_1m_input": config.cost_per_1m_input, "cost_per_1m_output": config.cost_per_1m_output, "context_window": config.context_window, "max_output_tokens": config.max_output_tokens, "supports_vision": config.supports_vision, "supports_streaming": config.supports_streaming, }) result["models"] = all_models return result async def _impl_get_usage( project_id: str, agent_id: str, period: str = "day", ) -> dict[str, Any]: """Core implementation for get_usage.""" tracker = get_cost_tracker() # Get project usage project_report = await tracker.get_project_usage(project_id, period=period) # Get agent usage agent_report = await tracker.get_agent_usage(agent_id, period=period) return { "project_id": project_id, "agent_id": agent_id, "period": period, "project_usage": { "total_requests": project_report.total_requests, "total_tokens": project_report.total_tokens, "total_cost_usd": project_report.total_cost_usd, "by_model": project_report.by_model, "period_start": project_report.period_start.isoformat(), "period_end": project_report.period_end.isoformat(), }, "agent_usage": { "total_requests": agent_report.total_requests, "total_tokens": agent_report.total_tokens, "total_cost_usd": agent_report.total_cost_usd, "by_model": agent_report.by_model, "period_start": agent_report.period_start.isoformat(), "period_end": agent_report.period_end.isoformat(), }, } async def _impl_count_tokens( project_id: str, agent_id: str, text: str, model: str | None = None, ) -> dict[str, Any]: """Core implementation for count_tokens.""" # Use tiktoken for token counting # Default to cl100k_base (used by GPT-4, Claude, etc.) try: if model and model.startswith("gpt"): encoding = tiktoken.encoding_for_model(model) else: encoding = tiktoken.get_encoding("cl100k_base") token_count = len(encoding.encode(text)) except Exception as e: logger.warning(f"Token counting failed: {e}, using estimate") # Fallback: rough estimate of ~4 chars per token token_count = len(text) // 4 # Estimate costs for different models cost_estimates: dict[str, float] = {} for model_name, config in MODEL_CONFIGS.items(): if config.cost_per_1m_input > 0: cost = (token_count / 1_000_000) * config.cost_per_1m_input cost_estimates[model_name] = round(cost, 6) return { "project_id": project_id, "agent_id": agent_id, "token_count": token_count, "text_length": len(text), "encoding": "cl100k_base", "cost_estimates": cost_estimates, } # ============================================================================ # MCP Tools (wrappers around core implementations) # ============================================================================ @mcp.tool() async def chat_completion( project_id: str, agent_id: str, messages: list[dict[str, Any]], model_group: str = "reasoning", model_override: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, stream: bool = False, session_id: str | None = None, ) -> dict[str, Any]: """ Generate a chat completion using the specified model group. 

    Args:
        project_id: UUID of the project (required for cost attribution)
        agent_id: UUID of the agent instance making the request
        messages: List of message dicts with 'role' and 'content'
        model_group: Model routing group (reasoning, code, fast, vision, embedding)
        model_override: Specific model to use (bypasses group routing)
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature (0.0-2.0)
        stream: Enable streaming response
        session_id: Optional session ID for tracking

    Returns:
        Completion response with content and usage statistics
    """
    return await _impl_chat_completion(
        project_id=project_id,
        agent_id=agent_id,
        messages=messages,
        model_group=model_group,
        model_override=model_override,
        max_tokens=max_tokens,
        temperature=temperature,
        stream=stream,
        session_id=session_id,
    )


@mcp.tool()
async def list_models(
    project_id: str,
    agent_id: str,
    model_group: str | None = None,
) -> dict[str, Any]:
    """
    List available models and model groups.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent instance
        model_group: Optional specific group to list

    Returns:
        Dictionary of available models and groups
    """
    return await _impl_list_models(
        project_id=project_id,
        agent_id=agent_id,
        model_group=model_group,
    )


@mcp.tool()
async def get_usage(
    project_id: str,
    agent_id: str,
    period: str = "day",
) -> dict[str, Any]:
    """
    Get usage statistics for a project or agent.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent
        period: Time period (hour, day, month)

    Returns:
        Usage statistics including tokens and costs
    """
    return await _impl_get_usage(
        project_id=project_id,
        agent_id=agent_id,
        period=period,
    )


@mcp.tool()
async def count_tokens(
    project_id: str,
    agent_id: str,
    text: str,
    model: str | None = None,
) -> dict[str, Any]:
    """
    Count tokens in text.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent
        text: Text to count tokens in
        model: Optional model for tokenizer selection

    Returns:
        Token count and estimation details
    """
    return await _impl_count_tokens(
        project_id=project_id,
        agent_id=agent_id,
        text=text,
        model=model,
    )


def main() -> None:
    """Run the server."""
    import uvicorn

    settings = get_settings()
    uvicorn.run(
        "server:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )


if __name__ == "__main__":
    main()
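

# ----------------------------------------------------------------------------
# Example client call (illustrative sketch only; assumes the gateway is
# reachable at http://localhost:8000 and that `httpx` is installed -- the
# host, port, and IDs below are placeholders, not values from this repo):
#
#     import httpx
#
#     payload = {
#         "jsonrpc": "2.0",
#         "id": 1,
#         "method": "tools/call",
#         "params": {
#             "name": "count_tokens",
#             "arguments": {
#                 "project_id": "<project-uuid>",
#                 "agent_id": "<agent-uuid>",
#                 "text": "Hello, world!",
#             },
#         },
#     }
#     response = httpx.post("http://localhost:8000/mcp", json=payload)
#     print(response.json()["result"]["content"][0]["text"])
# ----------------------------------------------------------------------------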