Files
Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00

322 lines
9.4 KiB
Python

"""
Model routing for LLM Gateway.
Handles model selection based on:
- Model group configuration
- Circuit breaker availability
- Agent type preferences
"""
import logging
from typing import Any
from config import Settings, get_settings
from exceptions import (
AllProvidersFailedError,
InvalidModelError,
InvalidModelGroupError,
ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, get_circuit_registry
from models import (
AGENT_TYPE_MODEL_PREFERENCES,
MODEL_CONFIGS,
MODEL_GROUPS,
ModelConfig,
ModelGroup,
)
from providers import get_available_models
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelRouter:
    """
    Routes requests to appropriate models based on configuration.

    Considers:
    - Model group preferences
    - Circuit breaker states
    - Agent type defaults
    - Provider availability
    """

    def __init__(
        self,
        settings: Settings | None = None,
        circuit_registry: CircuitBreakerRegistry | None = None,
    ) -> None:
        """
        Initialize model router.

        Args:
            settings: Application settings; ``get_settings()`` is used when
                omitted.
            circuit_registry: Circuit breaker registry;
                ``get_circuit_registry()`` is used when omitted.
        """
        self._settings = settings or get_settings()
        self._circuit_registry = circuit_registry or get_circuit_registry()

    def parse_model_group(self, group_str: str) -> ModelGroup:
        """
        Parse model group from string.

        Matching is case-insensitive and ignores surrounding whitespace.
        Enum values are tried first, then the alias table below.

        Args:
            group_str: Group name string

        Returns:
            ModelGroup enum value

        Raises:
            InvalidModelGroupError: If group is unknown
        """
        # Alternate spellings accepted in addition to the enum values.
        aliases = {
            "high-reasoning": ModelGroup.REASONING,
            "high_reasoning": ModelGroup.REASONING,
            "code-generation": ModelGroup.CODE,
            "code_generation": ModelGroup.CODE,
            "fast-response": ModelGroup.FAST,
            "fast_response": ModelGroup.FAST,
        }

        # Normalise once instead of lowercasing for every lookup.
        key = group_str.strip().lower()

        try:
            return ModelGroup(key)
        except ValueError:
            # Not a direct enum value; fall through to the alias table.
            pass

        alias = aliases.get(key)
        if alias is not None:
            return alias

        # Report the caller's original string, not the normalised key.
        raise InvalidModelGroupError(
            model_group=group_str,
            available_groups=[g.value for g in ModelGroup],
        )

    def get_model_config(self, model_name: str) -> ModelConfig:
        """
        Get configuration for a specific model.

        Args:
            model_name: Model name

        Returns:
            Model configuration

        Raises:
            InvalidModelError: If model is unknown
        """
        config = MODEL_CONFIGS.get(model_name)
        if config is None:
            raise InvalidModelError(
                model=model_name,
                reason="Unknown model",
            )
        return config

    def get_preferred_group_for_agent(self, agent_type: str) -> ModelGroup:
        """
        Get preferred model group for an agent type.

        The lookup is case-insensitive; agent types without an entry in
        ``AGENT_TYPE_MODEL_PREFERENCES`` get ``ModelGroup.REASONING``.

        Args:
            agent_type: Agent type identifier

        Returns:
            Preferred ModelGroup
        """
        return AGENT_TYPE_MODEL_PREFERENCES.get(
            agent_type.lower(),
            ModelGroup.REASONING,  # Default to reasoning
        )

    def _circuit_is_available(self, config: ModelConfig) -> bool:
        """Return whether the circuit breaker for the model's provider is available."""
        circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
        return circuit.is_available()

    def _select_override(self, model_override: str) -> tuple[str, ModelConfig]:
        """Validate an explicitly requested model and return it with its config."""
        config = self.get_model_config(model_override)

        # Check if model's provider is available (using router's settings)
        if model_override not in get_available_models(self._settings):
            raise ModelNotAvailableError(
                model=model_override,
                provider=config.provider.value,
            )

        if not self._circuit_is_available(config):
            raise ModelNotAvailableError(
                model=model_override,
                provider=f"{config.provider.value} (circuit open)",
            )

        return model_override, config

    def _select_from_group(self, model_group: ModelGroup) -> tuple[str, ModelConfig]:
        """Return the first usable model of a group, trying them in priority order."""
        group_config = MODEL_GROUPS.get(model_group)
        if group_config is None:
            raise InvalidModelGroupError(
                model_group=model_group.value,
                available_groups=[g.value for g in ModelGroup],
            )

        available_models = get_available_models(self._settings)

        # One entry per rejected candidate, handed to AllProvidersFailedError.
        errors: list[dict[str, Any]] = []
        attempted: list[str] = []

        for model_name in group_config.get_all_models():
            attempted.append(model_name)

            config = MODEL_CONFIGS.get(model_name)
            if config is None:
                errors.append({"model": model_name, "error": "Unknown model"})
                continue

            provider = config.provider.value

            if model_name not in available_models:
                errors.append(
                    {
                        "model": model_name,
                        "error": f"Provider {provider} not configured",
                    }
                )
                continue

            if not self._circuit_is_available(config):
                errors.append(
                    {
                        "model": model_name,
                        "error": f"Circuit open for {provider}",
                    }
                )
                continue

            logger.debug(
                "Selected model %s for group %s", model_name, model_group.value
            )
            return model_name, config

        # No models available
        raise AllProvidersFailedError(
            model_group=model_group.value,
            attempted_models=attempted,
            errors=errors,
        )

    async def select_model(
        self,
        model_group: ModelGroup | str,
        model_override: str | None = None,
        agent_type: str | None = None,
    ) -> tuple[str, ModelConfig]:
        """
        Select the best available model.

        A non-empty ``model_override`` bypasses group routing entirely.
        Otherwise the group's models are tried in the order returned by the
        group configuration and the first one that is known, configured and
        not blocked by its provider's circuit breaker is returned.

        Args:
            model_group: Desired model group
            model_override: Specific model to use (bypasses group routing)
            agent_type: Agent type; its preferred group is only written to
                the debug log and does not change which group is routed

        Returns:
            Tuple of (model_name, model_config)

        Raises:
            InvalidModelError: If override model is invalid
            ModelNotAvailableError: If the override model is not among the
                available models or its provider's circuit is not available
            InvalidModelGroupError: If group is invalid
            AllProvidersFailedError: If no models are available
        """
        if model_override:
            return self._select_override(model_override)

        # Parse model group if string
        if isinstance(model_group, str):
            model_group = self.parse_model_group(model_group)

        # Informational only: the requested group always wins.
        if agent_type:
            preferred = self.get_preferred_group_for_agent(agent_type)
            logger.debug(
                "Agent type %s prefers %s, requested %s",
                agent_type,
                preferred.value,
                model_group.value,
            )

        return self._select_from_group(model_group)

    async def get_available_models_for_group(
        self,
        model_group: ModelGroup | str,
    ) -> list[tuple[str, ModelConfig, bool]]:
        """
        Get all models for a group with availability status.

        Models without an entry in ``MODEL_CONFIGS`` are omitted. A group
        without an entry in ``MODEL_GROUPS`` yields an empty list.

        Args:
            model_group: Model group

        Returns:
            List of (model_name, config, is_available) tuples

        Raises:
            InvalidModelGroupError: If a string group name cannot be parsed
        """
        # Parse model group if string
        if isinstance(model_group, str):
            model_group = self.parse_model_group(model_group)

        group_config = MODEL_GROUPS.get(model_group)
        if group_config is None:
            return []

        available_models = get_available_models(self._settings)
        result: list[tuple[str, ModelConfig, bool]] = []

        for model_name in group_config.get_all_models():
            config = MODEL_CONFIGS.get(model_name)
            if config is None:
                continue

            # The circuit breaker is only consulted for configured models.
            is_available = (
                model_name in available_models
                and self._circuit_is_available(config)
            )
            result.append((model_name, config, is_available))

        return result

    def get_all_model_groups(self) -> dict[str, dict[str, Any]]:
        """
        Get information about all model groups.

        Returns:
            Dict keyed by group value; each entry holds the group's
            ``description``, ``primary`` and ``fallbacks``
        """
        return {
            group.value: {
                "description": config.description,
                "primary": config.primary,
                "fallbacks": config.fallbacks,
            }
            for group, config in MODEL_GROUPS.items()
        }
# Global router instance (lazy initialization): stays None until
# get_model_router() creates it; reset_model_router() sets it back to None.
_router: ModelRouter | None = None
def get_model_router() -> ModelRouter:
    """Return the global model router, creating it on first use."""
    global _router
    router = _router
    if router is None:
        # First call (or first call after a reset): build and cache it.
        router = ModelRouter()
        _router = router
    return router
def reset_model_router() -> None:
    """Clear the cached global router (for testing)."""
    global _router
    _router = None