forked from cardosofelipe/fast-next-template
- Add type annotations for mypy compliance - Use UTC-aware datetimes consistently (datetime.now(UTC)) - Add type: ignore comments for LiteLLM incomplete stubs - Fix import ordering and formatting - Update pyproject.toml mypy configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
322 lines
9.4 KiB
Python
322 lines
9.4 KiB
Python
"""
|
|
Model routing for LLM Gateway.
|
|
|
|
Handles model selection based on:
|
|
- Model group configuration
|
|
- Circuit breaker availability
|
|
- Agent type preferences
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from config import Settings, get_settings
|
|
from exceptions import (
|
|
AllProvidersFailedError,
|
|
InvalidModelError,
|
|
InvalidModelGroupError,
|
|
ModelNotAvailableError,
|
|
)
|
|
from failover import CircuitBreakerRegistry, get_circuit_registry
|
|
from models import (
|
|
AGENT_TYPE_MODEL_PREFERENCES,
|
|
MODEL_CONFIGS,
|
|
MODEL_GROUPS,
|
|
ModelConfig,
|
|
ModelGroup,
|
|
)
|
|
from providers import get_available_models
|
|
|
|
# Module-scoped logger; named after this module so handlers/levels can be
# configured per-module through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelRouter:
    """
    Routes requests to appropriate models based on configuration.

    Considers:
    - Model group preferences (models tried in the group's priority order)
    - Circuit breaker states
    - Agent type defaults (see get_preferred_group_for_agent)
    - Provider availability
    """

    # Alternate spellings accepted by parse_model_group(), keyed by the
    # lowercased input. Built once here instead of on every call.
    _GROUP_ALIASES = {
        "high-reasoning": ModelGroup.REASONING,
        "high_reasoning": ModelGroup.REASONING,
        "code-generation": ModelGroup.CODE,
        "code_generation": ModelGroup.CODE,
        "fast-response": ModelGroup.FAST,
        "fast_response": ModelGroup.FAST,
    }

    def __init__(
        self,
        settings: Settings | None = None,
        circuit_registry: CircuitBreakerRegistry | None = None,
    ) -> None:
        """
        Initialize model router.

        Args:
            settings: Application settings; get_settings() is used when
                not provided.
            circuit_registry: Circuit breaker registry; get_circuit_registry()
                is used when not provided.
        """
        # Compare against None rather than relying on truthiness, so an
        # explicitly supplied object is never silently swapped for the
        # global default (e.g. a registry that is "empty" and falsy).
        self._settings = settings if settings is not None else get_settings()
        self._circuit_registry = (
            circuit_registry
            if circuit_registry is not None
            else get_circuit_registry()
        )

    def _is_circuit_available(self, config: ModelConfig) -> bool:
        """Return whether the circuit breaker for config's provider allows requests."""
        circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
        return circuit.is_available()

    def parse_model_group(self, group_str: str) -> ModelGroup:
        """
        Parse model group from string.

        Matching is case-insensitive: the lowercased input is tried first as
        a ModelGroup value, then against the known aliases.

        Args:
            group_str: Group name string

        Returns:
            ModelGroup enum value

        Raises:
            InvalidModelGroupError: If group is unknown
        """
        key = group_str.lower()

        # Try direct enum value
        try:
            return ModelGroup(key)
        except ValueError:
            pass

        # Try aliases
        alias = self._GROUP_ALIASES.get(key)
        if alias is not None:
            return alias

        # Unknown group: report the original (un-lowercased) input.
        raise InvalidModelGroupError(
            model_group=group_str,
            available_groups=[g.value for g in ModelGroup],
        )

    def get_model_config(self, model_name: str) -> ModelConfig:
        """
        Get configuration for a specific model.

        Args:
            model_name: Model name (exact key in MODEL_CONFIGS)

        Returns:
            Model configuration

        Raises:
            InvalidModelError: If model is unknown
        """
        config = MODEL_CONFIGS.get(model_name)
        if config is None:
            raise InvalidModelError(
                model=model_name,
                reason="Unknown model",
            )
        return config

    def get_preferred_group_for_agent(self, agent_type: str) -> ModelGroup:
        """
        Get preferred model group for an agent type.

        Args:
            agent_type: Agent type identifier (matched case-insensitively)

        Returns:
            Preferred ModelGroup; ModelGroup.REASONING when the agent type
            has no configured preference.
        """
        return AGENT_TYPE_MODEL_PREFERENCES.get(
            agent_type.lower(),
            ModelGroup.REASONING,  # Default to reasoning
        )

    def _select_override(self, model_name: str) -> tuple[str, ModelConfig]:
        """Validate an explicitly requested model and return it, or raise."""
        config = self.get_model_config(model_name)

        # Check if model's provider is available (using router's settings)
        if model_name not in get_available_models(self._settings):
            raise ModelNotAvailableError(
                model=model_name,
                provider=config.provider.value,
            )

        # Check circuit breaker
        if not self._is_circuit_available(config):
            raise ModelNotAvailableError(
                model=model_name,
                provider=f"{config.provider.value} (circuit open)",
            )

        return model_name, config

    async def select_model(
        self,
        model_group: ModelGroup | str,
        model_override: str | None = None,
        agent_type: str | None = None,
    ) -> tuple[str, ModelConfig]:
        """
        Select the best available model.

        Args:
            model_group: Desired model group (enum member, value, or alias).
                Not parsed or validated when model_override is given.
            model_override: Specific model to use (bypasses group routing)
            agent_type: Agent type. Its preferred group is only logged at
                DEBUG level; it does not change which model is selected.

        Returns:
            Tuple of (model_name, model_config)

        Raises:
            InvalidModelError: If override model is unknown
            ModelNotAvailableError: If the override model's provider is not
                configured or its circuit breaker is open
            InvalidModelGroupError: If group is invalid
            AllProvidersFailedError: If no model in the group is available
        """
        # Handle model override (an empty string means "no override")
        if model_override:
            return self._select_override(model_override)

        # Parse model group if string
        if isinstance(model_group, str):
            model_group = self.parse_model_group(model_group)

        # Agent preference is informational only; routing always follows the
        # requested group. Skip the lookup entirely unless DEBUG is enabled.
        if agent_type and logger.isEnabledFor(logging.DEBUG):
            preferred = self.get_preferred_group_for_agent(agent_type)
            logger.debug(
                "Agent type %s prefers %s, requested %s",
                agent_type,
                preferred.value,
                model_group.value,
            )

        # Get group configuration
        group_config = MODEL_GROUPS.get(model_group)
        if group_config is None:
            raise InvalidModelGroupError(
                model_group=model_group.value,
                available_groups=[g.value for g in ModelGroup],
            )

        # Get available models
        available_models = get_available_models(self._settings)

        # Try models in priority order, recording why each one was skipped
        errors: list[dict[str, Any]] = []
        attempted: list[str] = []

        for model_name in group_config.get_all_models():
            attempted.append(model_name)

            config = MODEL_CONFIGS.get(model_name)
            if config is None:
                errors.append({"model": model_name, "error": "Unknown model"})
                continue

            # Check if model provider is configured
            if model_name not in available_models:
                errors.append(
                    {
                        "model": model_name,
                        "error": f"Provider {config.provider.value} not configured",
                    }
                )
                continue

            # Check circuit breaker
            if not self._is_circuit_available(config):
                errors.append(
                    {
                        "model": model_name,
                        "error": f"Circuit open for {config.provider.value}",
                    }
                )
                continue

            # Model is available
            logger.debug(
                "Selected model %s for group %s", model_name, model_group.value
            )
            return model_name, config

        # No models available
        raise AllProvidersFailedError(
            model_group=model_group.value,
            attempted_models=attempted,
            errors=errors,
        )

    async def get_available_models_for_group(
        self,
        model_group: ModelGroup | str,
    ) -> list[tuple[str, ModelConfig, bool]]:
        """
        Get all models for a group with availability status.

        Models missing from MODEL_CONFIGS are omitted. A model counts as
        available only if its provider is configured AND its circuit breaker
        allows requests.

        Args:
            model_group: Model group (enum member, value, or alias)

        Returns:
            List of (model_name, config, is_available) tuples in the group's
            priority order; empty if the group has no configuration.

        Raises:
            InvalidModelGroupError: If a string group cannot be parsed
        """
        # Parse model group if string
        if isinstance(model_group, str):
            model_group = self.parse_model_group(model_group)

        group_config = MODEL_GROUPS.get(model_group)
        if group_config is None:
            return []

        available_models = get_available_models(self._settings)
        result: list[tuple[str, ModelConfig, bool]] = []

        for model_name in group_config.get_all_models():
            config = MODEL_CONFIGS.get(model_name)
            if config is None:
                continue

            # Short-circuit: the breaker is only consulted for configured providers.
            is_available = model_name in available_models and self._is_circuit_available(
                config
            )
            result.append((model_name, config, is_available))

        return result

    def get_all_model_groups(self) -> dict[str, dict[str, Any]]:
        """
        Get information about all model groups.

        Returns:
            Dict mapping each group's value to its description, primary
            model, and fallbacks.
        """
        return {
            group.value: {
                "description": config.description,
                "primary": config.primary,
                "fallbacks": config.fallbacks,
            }
            for group, config in MODEL_GROUPS.items()
        }
|
|
|
|
|
|
# Process-wide router, created on first call to get_model_router() and
# cleared by reset_model_router().
_router: ModelRouter | None = None


def get_model_router() -> ModelRouter:
    """Return the shared ModelRouter, constructing it with defaults on first use."""
    global _router
    if _router is not None:
        return _router
    _router = ModelRouter()
    return _router
|
|
|
|
|
|
def reset_model_router() -> None:
    """
    Discard the cached global router (intended for tests).

    The next get_model_router() call will build a fresh instance.
    """
    global _router
    _router = None
|