""" Syndarix LLM Gateway MCP Server. Provides unified LLM access with: - Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek) - Automatic failover chains - Cost tracking via Redis - Model group routing (reasoning, code, fast, vision, embedding) - Circuit breaker protection Per ADR-004: LLM Provider Abstraction. """ import logging import uuid from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any import tiktoken from fastapi import FastAPI from fastmcp import FastMCP from config import get_settings from cost_tracking import calculate_cost, get_cost_tracker from exceptions import ( AllProvidersFailedError, CircuitOpenError, CostLimitExceededError, InvalidModelError, InvalidModelGroupError, LLMGatewayError, ModelNotAvailableError, ) from failover import get_circuit_registry from models import ( MODEL_CONFIGS, MODEL_GROUPS, ModelGroup, ) from providers import get_provider from routing import get_model_router # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Create FastMCP server mcp = FastMCP("syndarix-llm-gateway") @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncIterator[None]: """Application lifespan handler.""" settings = get_settings() logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}") # Initialize provider provider = get_provider() provider.initialize() yield # Cleanup from cost_tracking import close_cost_tracker await close_cost_tracker() logger.info("LLM Gateway shutdown complete") # Create FastAPI app that wraps FastMCP app = FastAPI( title="Syndarix LLM Gateway", description="MCP Server for unified LLM access", version="0.1.0", lifespan=lifespan, ) # Health endpoint @app.get("/health") async def health_check() -> dict[str, Any]: """Health check endpoint.""" settings = get_settings() provider = get_provider() return { "status": "healthy", "service": "llm-gateway", "providers_configured": settings.get_available_providers(), "provider_available": provider.is_available, } # Tool discovery endpoint (for MCP client compatibility) @app.get("/mcp/tools") async def list_tools() -> dict[str, Any]: """List available MCP tools.""" return { "tools": [ { "name": "chat_completion", "description": "Generate a chat completion using the specified model group", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string", "description": "Project ID"}, "agent_id": {"type": "string", "description": "Agent ID"}, "messages": { "type": "array", "items": { "type": "object", "properties": { "role": {"type": "string"}, "content": {"type": "string"}, }, "required": ["role", "content"], }, }, "model_group": { "type": "string", "enum": [g.value for g in ModelGroup], "default": "reasoning", }, "max_tokens": {"type": "integer", "default": 4096}, "temperature": {"type": "number", "default": 0.7}, "stream": {"type": "boolean", "default": False}, }, "required": ["project_id", "agent_id", "messages"], }, }, { "name": "list_models", "description": "List available models and model groups", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string"}, "agent_id": {"type": "string"}, "model_group": {"type": "string"}, }, "required": ["project_id", "agent_id"], }, }, { "name": "get_usage", "description": "Get usage statistics for a project or agent", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string"}, "agent_id": {"type": "string"}, 
"period": { "type": "string", "enum": ["hour", "day", "month"], "default": "day", }, }, "required": ["project_id", "agent_id"], }, }, { "name": "count_tokens", "description": "Count tokens in text", "inputSchema": { "type": "object", "properties": { "project_id": {"type": "string"}, "agent_id": {"type": "string"}, "text": {"type": "string"}, "model": {"type": "string"}, }, "required": ["project_id", "agent_id", "text"], }, }, ] } # JSON-RPC endpoint (for MCP client compatibility) @app.post("/mcp") async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]: """Handle JSON-RPC 2.0 requests for MCP tools.""" # Validate JSON-RPC structure if request.get("jsonrpc") != "2.0": return { "jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid JSON-RPC version"}, "id": request.get("id"), } method = request.get("method") params = request.get("params", {}) request_id = request.get("id") # Handle tool calls if method == "tools/call": tool_name = params.get("name") tool_args = params.get("arguments", {}) try: if tool_name == "chat_completion": result = await _impl_chat_completion(**tool_args) elif tool_name == "list_models": result = await _impl_list_models(**tool_args) elif tool_name == "get_usage": result = await _impl_get_usage(**tool_args) elif tool_name == "count_tokens": result = await _impl_count_tokens(**tool_args) else: return { "jsonrpc": "2.0", "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, "id": request_id, } return { "jsonrpc": "2.0", "result": {"content": [{"type": "text", "text": str(result)}]}, "id": request_id, } except LLMGatewayError as e: return { "jsonrpc": "2.0", "error": {"code": -32000, "message": str(e), "data": e.to_dict()}, "id": request_id, } except Exception as e: logger.exception(f"Error executing tool {tool_name}") return { "jsonrpc": "2.0", "error": {"code": -32603, "message": str(e)}, "id": request_id, } # Handle tool listing elif method == "tools/list": tools_response = await list_tools() return { "jsonrpc": "2.0", "result": tools_response, "id": request_id, } else: return { "jsonrpc": "2.0", "error": {"code": -32601, "message": f"Unknown method: {method}"}, "id": request_id, } # ============================================================================ # Core Implementation Functions # ============================================================================ async def _impl_chat_completion( project_id: str, agent_id: str, messages: list[dict[str, Any]], model_group: str = "reasoning", model_override: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, stream: bool = False, session_id: str | None = None, ) -> dict[str, Any]: """Core implementation for chat completion.""" settings = get_settings() router = get_model_router() tracker = get_cost_tracker() circuit_registry = get_circuit_registry() # Check budget before making request if settings.cost_tracking_enabled: within_budget, current_cost, limit = await tracker.check_budget(project_id) if not within_budget: raise CostLimitExceededError( entity_type="project", entity_id=project_id, current_cost=current_cost, limit=limit, ) # Select model try: model_name, model_config = await router.select_model( model_group=model_group, model_override=model_override, ) except (InvalidModelGroupError, InvalidModelError, AllProvidersFailedError): raise except ModelNotAvailableError: raise # Get provider provider = get_provider() if not provider.router: raise AllProvidersFailedError( model_group=model_group, attempted_models=[model_name], errors=[{"error": "No providers 
configured"}], ) # Generate request ID request_id = str(uuid.uuid4()) try: # Get circuit breaker for this provider circuit = circuit_registry.get_circuit_sync(model_config.provider.value) # Make completion request if stream: # Return streaming response info # Actual streaming would be handled by a separate endpoint return { "status": "streaming_not_supported_via_tool", "message": "Use /stream endpoint for streaming responses", "request_id": request_id, } # Non-streaming completion response = await provider.router.acompletion( model=model_name, messages=messages, # type: ignore[arg-type] max_tokens=max_tokens, temperature=temperature, ) # Record success await circuit.record_success() # Extract response data content = response.choices[0].message.content or "" # type: ignore[union-attr] finish_reason = response.choices[0].finish_reason or "stop" # Get usage stats prompt_tokens = response.usage.prompt_tokens if response.usage else 0 # type: ignore[attr-defined] completion_tokens = response.usage.completion_tokens if response.usage else 0 # type: ignore[attr-defined] # Calculate cost cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens) # Record usage await tracker.record_usage( project_id=project_id, agent_id=agent_id, model=model_name, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost_usd=cost_usd, session_id=session_id, request_id=request_id, ) return { "id": request_id, "model": model_name, "provider": model_config.provider.value, "content": content, "finish_reason": finish_reason, "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, "cost_usd": cost_usd, }, } except CircuitOpenError: raise except Exception as e: # Record failure circuit = circuit_registry.get_circuit_sync(model_config.provider.value) await circuit.record_failure(e) logger.error(f"Completion failed: {e}") raise AllProvidersFailedError( model_group=model_group, attempted_models=[model_name], errors=[{"model": model_name, "error": str(e)}], ) async def _impl_list_models( project_id: str, agent_id: str, model_group: str | None = None, ) -> dict[str, Any]: """Core implementation for list_models.""" settings = get_settings() provider = get_provider() router = get_model_router() # Get available providers available_providers = settings.get_available_providers() result: dict[str, Any] = { "project_id": project_id, "agent_id": agent_id, "available_providers": available_providers, } if model_group: # List models for specific group try: parsed_group = router.parse_model_group(model_group) models = await router.get_available_models_for_group(parsed_group) result["model_group"] = model_group result["models"] = [ { "name": name, "provider": config.provider.value, "available": available, "cost_per_1m_input": config.cost_per_1m_input, "cost_per_1m_output": config.cost_per_1m_output, } for name, config, available in models ] except InvalidModelGroupError as e: result["error"] = e.to_dict() else: # List all model groups groups: dict[str, Any] = {} for group in ModelGroup: group_config = MODEL_GROUPS.get(group) if group_config: models = await router.get_available_models_for_group(group) available_count = sum(1 for _, _, avail in models if avail) groups[group.value] = { "description": group_config.description, "primary": group_config.primary, "fallbacks": group_config.fallbacks, "available_models": available_count, "total_models": len(models), } result["model_groups"] = groups # List all models all_models: list[dict[str, 
async def _impl_list_models(
    project_id: str,
    agent_id: str,
    model_group: str | None = None,
) -> dict[str, Any]:
    """Core implementation for list_models."""
    settings = get_settings()
    provider = get_provider()
    router = get_model_router()

    # Get available providers
    available_providers = settings.get_available_providers()

    result: dict[str, Any] = {
        "project_id": project_id,
        "agent_id": agent_id,
        "available_providers": available_providers,
    }

    if model_group:
        # List models for a specific group
        try:
            parsed_group = router.parse_model_group(model_group)
            models = await router.get_available_models_for_group(parsed_group)
            result["model_group"] = model_group
            result["models"] = [
                {
                    "name": name,
                    "provider": config.provider.value,
                    "available": available,
                    "cost_per_1m_input": config.cost_per_1m_input,
                    "cost_per_1m_output": config.cost_per_1m_output,
                }
                for name, config, available in models
            ]
        except InvalidModelGroupError as e:
            result["error"] = e.to_dict()
    else:
        # List all model groups
        groups: dict[str, Any] = {}
        for group in ModelGroup:
            group_config = MODEL_GROUPS.get(group)
            if group_config:
                models = await router.get_available_models_for_group(group)
                available_count = sum(1 for _, _, avail in models if avail)
                groups[group.value] = {
                    "description": group_config.description,
                    "primary": group_config.primary,
                    "fallbacks": group_config.fallbacks,
                    "available_models": available_count,
                    "total_models": len(models),
                }
        result["model_groups"] = groups

        # List all models
        all_models: list[dict[str, Any]] = []
        available_models = provider.get_available_models()
        for name, config in MODEL_CONFIGS.items():
            all_models.append(
                {
                    "name": name,
                    "provider": config.provider.value,
                    "available": name in available_models,
                    "cost_per_1m_input": config.cost_per_1m_input,
                    "cost_per_1m_output": config.cost_per_1m_output,
                    "context_window": config.context_window,
                    "max_output_tokens": config.max_output_tokens,
                    "supports_vision": config.supports_vision,
                    "supports_streaming": config.supports_streaming,
                }
            )
        result["models"] = all_models

    return result


async def _impl_get_usage(
    project_id: str,
    agent_id: str,
    period: str = "day",
) -> dict[str, Any]:
    """Core implementation for get_usage."""
    tracker = get_cost_tracker()

    # Get project usage
    project_report = await tracker.get_project_usage(project_id, period=period)

    # Get agent usage
    agent_report = await tracker.get_agent_usage(agent_id, period=period)

    return {
        "project_id": project_id,
        "agent_id": agent_id,
        "period": period,
        "project_usage": {
            "total_requests": project_report.total_requests,
            "total_tokens": project_report.total_tokens,
            "total_cost_usd": project_report.total_cost_usd,
            "by_model": project_report.by_model,
            "period_start": project_report.period_start.isoformat(),
            "period_end": project_report.period_end.isoformat(),
        },
        "agent_usage": {
            "total_requests": agent_report.total_requests,
            "total_tokens": agent_report.total_tokens,
            "total_cost_usd": agent_report.total_cost_usd,
            "by_model": agent_report.by_model,
            "period_start": agent_report.period_start.isoformat(),
            "period_end": agent_report.period_end.isoformat(),
        },
    }


async def _impl_count_tokens(
    project_id: str,
    agent_id: str,
    text: str,
    model: str | None = None,
) -> dict[str, Any]:
    """Core implementation for count_tokens."""
    # Use tiktoken for token counting. cl100k_base is the GPT-4-family
    # encoding; for non-OpenAI models it is only an approximation.
    encoding_name = "cl100k_base"
    try:
        if model and model.startswith("gpt"):
            encoding = tiktoken.encoding_for_model(model)
            encoding_name = encoding.name
        else:
            encoding = tiktoken.get_encoding("cl100k_base")
        token_count = len(encoding.encode(text))
    except Exception as e:
        logger.warning(f"Token counting failed: {e}, using estimate")
        # Fallback: rough estimate of ~4 chars per token
        token_count = len(text) // 4
        encoding_name = "estimate"

    # Estimate costs for different models
    cost_estimates: dict[str, float] = {}
    for model_name, config in MODEL_CONFIGS.items():
        if config.cost_per_1m_input > 0:
            cost = (token_count / 1_000_000) * config.cost_per_1m_input
            cost_estimates[model_name] = round(cost, 6)

    return {
        "project_id": project_id,
        "agent_id": agent_id,
        "token_count": token_count,
        "text_length": len(text),
        # Report the encoding actually used (or "estimate" for the fallback),
        # not a hardcoded value.
        "encoding": encoding_name,
        "cost_estimates": cost_estimates,
    }
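# Example (illustrative) of the underlying tiktoken round-trip used above:
#   >>> import tiktoken
#   >>> enc = tiktoken.get_encoding("cl100k_base")
#   >>> ids = enc.encode("Hello, world")  # list of integer token ids
#   >>> len(ids)                          # the token count reported above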
# ============================================================================
# MCP Tools (wrappers around core implementations)
# ============================================================================


@mcp.tool()
async def chat_completion(
    project_id: str,
    agent_id: str,
    messages: list[dict[str, Any]],
    model_group: str = "reasoning",
    model_override: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    stream: bool = False,
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Generate a chat completion using the specified model group.

    Args:
        project_id: UUID of the project (required for cost attribution)
        agent_id: UUID of the agent instance making the request
        messages: List of message dicts with 'role' and 'content'
        model_group: Model routing group (reasoning, code, fast, vision, embedding)
        model_override: Specific model to use (bypasses group routing)
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature (0.0-2.0)
        stream: Enable streaming response
        session_id: Optional session ID for tracking

    Returns:
        Completion response with content and usage statistics
    """
    return await _impl_chat_completion(
        project_id=project_id,
        agent_id=agent_id,
        messages=messages,
        model_group=model_group,
        model_override=model_override,
        max_tokens=max_tokens,
        temperature=temperature,
        stream=stream,
        session_id=session_id,
    )


@mcp.tool()
async def list_models(
    project_id: str,
    agent_id: str,
    model_group: str | None = None,
) -> dict[str, Any]:
    """
    List available models and model groups.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent instance
        model_group: Optional specific group to list

    Returns:
        Dictionary of available models and groups
    """
    return await _impl_list_models(
        project_id=project_id,
        agent_id=agent_id,
        model_group=model_group,
    )


@mcp.tool()
async def get_usage(
    project_id: str,
    agent_id: str,
    period: str = "day",
) -> dict[str, Any]:
    """
    Get usage statistics for a project or agent.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent
        period: Time period (hour, day, month)

    Returns:
        Usage statistics including tokens and costs
    """
    return await _impl_get_usage(
        project_id=project_id,
        agent_id=agent_id,
        period=period,
    )


@mcp.tool()
async def count_tokens(
    project_id: str,
    agent_id: str,
    text: str,
    model: str | None = None,
) -> dict[str, Any]:
    """
    Count tokens in text.

    Args:
        project_id: UUID of the project
        agent_id: UUID of the agent
        text: Text to count tokens in
        model: Optional model for tokenizer selection

    Returns:
        Token count and estimation details
    """
    return await _impl_count_tokens(
        project_id=project_id,
        agent_id=agent_id,
        text=text,
        model=model,
    )


def main() -> None:
    """Run the server."""
    import uvicorn

    settings = get_settings()
    uvicorn.run(
        "server:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )


if __name__ == "__main__":
    main()
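# Running locally (illustrative; host, port, and reload come from settings,
# and "server:app" assumes this module is saved as server.py):
#   $ python server.py
# or, equivalently, via uvicorn directly:
#   $ uvicorn server:app --host 0.0.0.0 --port 8000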