Files
Felipe Cardoso 95342cc94d fix(mcp-gateway): address critical issues from deep review
Frontend:
- Fix debounce race condition in UserListTable search handler
- Use useRef to properly track and cleanup timeout between keystrokes

Backend (LLM Gateway):
- Add thread-safe double-checked locking for global singletons
  (providers, circuit registry, cost tracker)
- Fix Redis URL parsing with proper urlparse validation
- Add explicit error handling for malformed Redis URLs
- Document circuit breaker state transition safety

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:36:55 +01:00

368 lines
11 KiB
Python

"""
LiteLLM provider configuration for LLM Gateway.
Configures the LiteLLM Router with model lists and failover chains.
"""
import logging
import os
import threading
from typing import Any
from urllib.parse import urlparse
import litellm
from litellm import Router
from config import Settings, get_settings
from models import (
MODEL_CONFIGS,
MODEL_GROUPS,
ModelConfig,
ModelGroup,
Provider,
)
logger = logging.getLogger(__name__)
def configure_litellm(settings: Settings) -> None:
    """
    Apply global LiteLLM configuration from application settings.

    Exports provider API keys into the process environment (LiteLLM reads
    credentials from env vars), sets global behavior flags, and enables the
    Redis-backed response cache when configured.

    Args:
        settings: Application settings
    """
    # LiteLLM discovers credentials via environment variables, so mirror
    # every configured key into the environment.
    env_key_pairs = (
        ("ANTHROPIC_API_KEY", settings.anthropic_api_key),
        ("OPENAI_API_KEY", settings.openai_api_key),
        ("GEMINI_API_KEY", settings.google_api_key),
        ("DASHSCOPE_API_KEY", settings.alibaba_api_key),
        ("DEEPSEEK_API_KEY", settings.deepseek_api_key),
    )
    for env_name, key_value in env_key_pairs:
        if key_value:
            os.environ[env_name] = key_value

    # Global behavior flags.
    litellm.drop_params = True  # Drop unsupported params instead of erroring
    litellm.set_verbose = settings.debug

    # Optional Redis-backed response cache.
    if settings.litellm_cache_enabled:
        litellm.cache = litellm.Cache(
            type="redis",  # type: ignore[arg-type]
            host=_parse_redis_host(settings.redis_url),
            port=_parse_redis_port(settings.redis_url),  # type: ignore[arg-type]
            ttl=settings.litellm_cache_ttl,
        )
def _parse_redis_host(redis_url: str) -> str:
"""
Extract host from Redis URL.
Args:
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
Returns:
Hostname extracted from URL
Raises:
ValueError: If URL is malformed or missing host
"""
try:
parsed = urlparse(redis_url)
if not parsed.hostname:
raise ValueError(f"Invalid Redis URL: missing hostname in '{redis_url}'")
return parsed.hostname
except Exception as e:
logger.error(f"Failed to parse Redis URL: {e}")
raise ValueError(f"Invalid Redis URL: {redis_url}") from e
def _parse_redis_port(redis_url: str) -> int:
"""
Extract port from Redis URL.
Args:
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
Returns:
Port number (defaults to 6379 if not specified)
Raises:
ValueError: If URL is malformed
"""
try:
parsed = urlparse(redis_url)
return parsed.port or 6379
except Exception as e:
logger.error(f"Failed to parse Redis URL: {e}")
raise ValueError(f"Invalid Redis URL: {redis_url}") from e
def _is_provider_available(provider: Provider, settings: Settings) -> bool:
    """Check if a provider is available (API key configured)."""
    # DeepSeek is special: a self-hosted base URL counts as "configured"
    # even without an API key.
    if provider == Provider.DEEPSEEK:
        return bool(settings.deepseek_api_key or settings.deepseek_base_url)
    attr_by_provider = {
        Provider.ANTHROPIC: "anthropic_api_key",
        Provider.OPENAI: "openai_api_key",
        Provider.GOOGLE: "google_api_key",
        Provider.ALIBABA: "alibaba_api_key",
    }
    attr_name = attr_by_provider.get(provider)
    # Unknown providers are treated as unavailable.
    return bool(getattr(settings, attr_name)) if attr_name else False
def _build_model_entry(
    model_config: ModelConfig,
    settings: Settings,
) -> dict[str, Any] | None:
    """
    Build a model entry for LiteLLM Router.

    Args:
        model_config: Model configuration
        settings: Application settings

    Returns:
        Model entry dict or None if provider unavailable
    """
    # Skip models whose provider has no credentials configured.
    if not _is_provider_available(model_config.provider, settings):
        logger.debug(
            f"Skipping model {model_config.name}: "
            f"{model_config.provider.value} provider not configured"
        )
        return None

    litellm_params: dict[str, Any] = {
        "model": model_config.litellm_name,
        "timeout": settings.litellm_timeout,
        "max_retries": settings.litellm_max_retries,
    }
    # Self-hosted DeepSeek deployments point at a custom endpoint.
    if model_config.provider == Provider.DEEPSEEK and settings.deepseek_base_url:
        litellm_params["api_base"] = settings.deepseek_base_url

    return {
        "model_name": model_config.name,
        "litellm_params": litellm_params,
    }
def build_model_list(settings: Settings | None = None) -> list[dict[str, Any]]:
    """
    Build the complete model list for LiteLLM Router.

    Args:
        settings: Application settings (uses default if None)

    Returns:
        List of model entries for Router
    """
    if settings is None:
        settings = get_settings()
    # Keep only models whose provider is configured; _build_model_entry
    # returns None for unavailable providers.
    model_list: list[dict[str, Any]] = [
        entry
        for config in MODEL_CONFIGS.values()
        if (entry := _build_model_entry(config, settings))
    ]
    logger.info(f"Built model list with {len(model_list)} models")
    return model_list
def build_fallback_config(settings: Settings | None = None) -> dict[str, list[str]]:
    """
    Build fallback configuration based on model groups.

    Args:
        settings: Application settings (uses default if None)

    Returns:
        Dict mapping model names to their fallback chains
    """
    if settings is None:
        settings = get_settings()
    fallbacks: dict[str, list[str]] = {}
    for group_config in MODEL_GROUPS.values():
        # Filter the group's chain down to models whose provider is configured.
        chain = [
            name
            for name in group_config.get_all_models()
            if (cfg := MODEL_CONFIGS.get(name))
            and _is_provider_available(cfg.provider, settings)
        ]
        # A fallback chain only makes sense with at least one alternative:
        # the primary model falls back to the rest of the chain.
        if len(chain) > 1:
            fallbacks[chain[0]] = chain[1:]
    return fallbacks
def get_available_models(settings: Settings | None = None) -> dict[str, ModelConfig]:
    """
    Get all available models (with configured providers).

    Args:
        settings: Application settings (uses default if None)

    Returns:
        Dict of available model configs
    """
    if settings is None:
        settings = get_settings()
    return {
        name: config
        for name, config in MODEL_CONFIGS.items()
        if _is_provider_available(config.provider, settings)
    }
def get_available_model_groups(
    settings: Settings | None = None,
) -> dict[ModelGroup, list[str]]:
    """
    Get available models for each model group.

    Args:
        settings: Application settings (uses default if None)

    Returns:
        Dict mapping model groups to available models
    """
    if settings is None:
        settings = get_settings()
    # Every group is present in the result, even if its list is empty.
    return {
        group: [
            name
            for name in group_config.get_all_models()
            if (cfg := MODEL_CONFIGS.get(name))
            and _is_provider_available(cfg.provider, settings)
        ]
        for group, group_config in MODEL_GROUPS.items()
    }
class LLMProvider:
    """
    LLM Provider wrapper around LiteLLM Router.

    Provides a high-level interface for LLM operations with
    automatic failover and configuration management.

    Initialization is lazy and thread-safe: the first access to
    ``router`` or ``is_available`` triggers ``initialize()``, which is
    guarded by a per-instance lock so concurrent callers cannot both
    configure LiteLLM and build duplicate Routers.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """
        Initialize LLM Provider.

        Args:
            settings: Application settings (uses default if None)
        """
        self._settings = settings or get_settings()
        self._router: Router | None = None
        self._initialized = False
        # Guards lazy initialize(): without this, two threads hitting the
        # `router`/`is_available` properties at the same time could both
        # pass the `_initialized` check and build two Routers. This matches
        # the double-checked locking used by the module-level get_provider().
        self._init_lock = threading.Lock()

    def initialize(self) -> None:
        """Initialize the provider and LiteLLM Router (idempotent, thread-safe)."""
        # Fast path: no lock once initialization has completed.
        if self._initialized:
            return
        with self._init_lock:
            # Double-check after acquiring the lock.
            if self._initialized:
                return
            # Configure LiteLLM global settings
            configure_litellm(self._settings)
            # Build model list
            model_list = build_model_list(self._settings)
            if not model_list:
                # Provider stays usable but router remains None;
                # is_available will report False.
                logger.warning("No models available - no providers configured")
                self._initialized = True
                return
            # Build fallback config
            fallbacks = build_fallback_config(self._settings)
            # Create Router
            self._router = Router(
                model_list=model_list,
                fallbacks=list(fallbacks.items()) if fallbacks else None,  # type: ignore[arg-type]
                routing_strategy="latency-based-routing",
                num_retries=self._settings.litellm_max_retries,
                timeout=self._settings.litellm_timeout,
                retry_after=5,  # Retry after 5 seconds
                allowed_fails=2,  # Fail after 2 consecutive failures
            )
            self._initialized = True
            logger.info(
                f"LLM Provider initialized with {len(model_list)} models, "
                f"{len(fallbacks)} fallback chains"
            )

    @property
    def router(self) -> Router | None:
        """Get the LiteLLM Router, lazily initializing on first access."""
        if not self._initialized:
            self.initialize()
        return self._router

    @property
    def is_available(self) -> bool:
        """Check if provider is available (at least one model configured)."""
        if not self._initialized:
            self.initialize()
        return self._router is not None

    def get_model_config(self, model_name: str) -> ModelConfig | None:
        """Get configuration for a specific model."""
        return MODEL_CONFIGS.get(model_name)

    def get_available_models(self) -> dict[str, ModelConfig]:
        """Get all available models."""
        return get_available_models(self._settings)

    def is_model_available(self, model_name: str) -> bool:
        """Check if a specific model is available."""
        model_config = MODEL_CONFIGS.get(model_name)
        if not model_config:
            return False
        return _is_provider_available(model_config.provider, self._settings)
# Module-level singleton; the lock serializes first-time construction only.
_provider: LLMProvider | None = None
_provider_lock = threading.Lock()


def get_provider() -> LLMProvider:
    """
    Get the global LLM Provider instance.

    Thread-safe with double-checked locking pattern: the unlocked fast
    path serves every call after the first, and the lock-protected
    re-check ensures exactly one instance is ever constructed.
    """
    global _provider
    if _provider is not None:
        return _provider
    with _provider_lock:
        if _provider is None:
            _provider = LLMProvider()
    return _provider
def reset_provider() -> None:
    """Reset the global provider (for testing).

    Takes the same lock as get_provider so a concurrent reset cannot
    interleave with singleton construction.
    """
    global _provider
    with _provider_lock:
        _provider = None