""" Model routing for LLM Gateway. Handles model selection based on: - Model group configuration - Circuit breaker availability - Agent type preferences """ import logging from typing import Any from config import Settings, get_settings from exceptions import ( AllProvidersFailedError, InvalidModelError, InvalidModelGroupError, ModelNotAvailableError, ) from failover import CircuitBreakerRegistry, get_circuit_registry from models import ( AGENT_TYPE_MODEL_PREFERENCES, MODEL_CONFIGS, MODEL_GROUPS, ModelConfig, ModelGroup, ) from providers import get_available_models logger = logging.getLogger(__name__) class ModelRouter: """ Routes requests to appropriate models based on configuration. Considers: - Model group preferences - Circuit breaker states - Agent type defaults - Provider availability """ def __init__( self, settings: Settings | None = None, circuit_registry: CircuitBreakerRegistry | None = None, ) -> None: """ Initialize model router. Args: settings: Application settings circuit_registry: Circuit breaker registry """ self._settings = settings or get_settings() self._circuit_registry = circuit_registry or get_circuit_registry() def parse_model_group(self, group_str: str) -> ModelGroup: """ Parse model group from string. Args: group_str: Group name string Returns: ModelGroup enum value Raises: InvalidModelGroupError: If group is unknown """ # Handle aliases aliases = { "high-reasoning": ModelGroup.REASONING, "high_reasoning": ModelGroup.REASONING, "code-generation": ModelGroup.CODE, "code_generation": ModelGroup.CODE, "fast-response": ModelGroup.FAST, "fast_response": ModelGroup.FAST, } # Try direct enum value try: return ModelGroup(group_str.lower()) except ValueError: pass # Try aliases if group_str.lower() in aliases: return aliases[group_str.lower()] # Unknown group available = [g.value for g in ModelGroup] raise InvalidModelGroupError( model_group=group_str, available_groups=available, ) def get_model_config(self, model_name: str) -> ModelConfig: """ Get configuration for a specific model. Args: model_name: Model name Returns: Model configuration Raises: InvalidModelError: If model is unknown """ config = MODEL_CONFIGS.get(model_name) if not config: raise InvalidModelError( model=model_name, reason="Unknown model", ) return config def get_preferred_group_for_agent(self, agent_type: str) -> ModelGroup: """ Get preferred model group for an agent type. Args: agent_type: Agent type identifier Returns: Preferred ModelGroup """ return AGENT_TYPE_MODEL_PREFERENCES.get( agent_type.lower(), ModelGroup.REASONING, # Default to reasoning ) async def select_model( self, model_group: ModelGroup | str, model_override: str | None = None, agent_type: str | None = None, ) -> tuple[str, ModelConfig]: """ Select the best available model. Args: model_group: Desired model group model_override: Specific model to use (bypasses group routing) agent_type: Agent type for preference lookup Returns: Tuple of (model_name, model_config) Raises: InvalidModelError: If override model is invalid InvalidModelGroupError: If group is invalid AllProvidersFailedError: If no models are available """ # Handle model override if model_override: config = MODEL_CONFIGS.get(model_override) if not config: raise InvalidModelError( model=model_override, reason="Unknown model", ) # Check if model's provider is available (using router's settings) available_models = get_available_models(self._settings) if model_override not in available_models: raise ModelNotAvailableError( model=model_override, provider=config.provider.value, ) # Check circuit breaker circuit = self._circuit_registry.get_circuit_sync(config.provider.value) if not circuit.is_available(): raise ModelNotAvailableError( model=model_override, provider=f"{config.provider.value} (circuit open)", ) return model_override, config # Parse model group if string if isinstance(model_group, str): model_group = self.parse_model_group(model_group) # Get agent type preference if no explicit group if agent_type: preferred = self.get_preferred_group_for_agent(agent_type) logger.debug( f"Agent type {agent_type} prefers {preferred.value}, " f"requested {model_group.value}" ) # Get group configuration group_config = MODEL_GROUPS.get(model_group) if not group_config: raise InvalidModelGroupError( model_group=model_group.value, available_groups=[g.value for g in ModelGroup], ) # Get available models available_models = get_available_models(self._settings) # Try models in priority order errors: list[dict[str, Any]] = [] attempted: list[str] = [] for model_name in group_config.get_all_models(): attempted.append(model_name) # Check if model provider is configured config = MODEL_CONFIGS.get(model_name) if not config: errors.append({"model": model_name, "error": "Unknown model"}) continue if model_name not in available_models: errors.append( { "model": model_name, "error": f"Provider {config.provider.value} not configured", } ) continue # Check circuit breaker circuit = self._circuit_registry.get_circuit_sync(config.provider.value) if not circuit.is_available(): errors.append( { "model": model_name, "error": f"Circuit open for {config.provider.value}", } ) continue # Model is available logger.debug(f"Selected model {model_name} for group {model_group.value}") return model_name, config # No models available raise AllProvidersFailedError( model_group=model_group.value, attempted_models=attempted, errors=errors, ) async def get_available_models_for_group( self, model_group: ModelGroup | str, ) -> list[tuple[str, ModelConfig, bool]]: """ Get all models for a group with availability status. Args: model_group: Model group Returns: List of (model_name, config, is_available) tuples """ # Parse model group if string if isinstance(model_group, str): model_group = self.parse_model_group(model_group) group_config = MODEL_GROUPS.get(model_group) if not group_config: return [] available_models = get_available_models(self._settings) result: list[tuple[str, ModelConfig, bool]] = [] for model_name in group_config.get_all_models(): config = MODEL_CONFIGS.get(model_name) if not config: continue is_available = model_name in available_models if is_available: # Also check circuit breaker circuit = self._circuit_registry.get_circuit_sync(config.provider.value) is_available = circuit.is_available() result.append((model_name, config, is_available)) return result def get_all_model_groups(self) -> dict[str, dict[str, Any]]: """ Get information about all model groups. Returns: Dict of group info """ result: dict[str, dict[str, Any]] = {} for group, config in MODEL_GROUPS.items(): result[group.value] = { "description": config.description, "primary": config.primary, "fallbacks": config.fallbacks, } return result # Global router instance (lazy initialization) _router: ModelRouter | None = None def get_model_router() -> ModelRouter: """Get the global model router instance.""" global _router if _router is None: _router = ModelRouter() return _router def reset_model_router() -> None: """Reset the global router (for testing).""" global _router _router = None