"""Tests for the routing module."""

import asyncio
from collections.abc import Iterator

import pytest

from config import Settings
from exceptions import (
    AllProvidersFailedError,
    InvalidModelError,
    InvalidModelGroupError,
    ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, reset_circuit_registry
from models import ModelGroup
from providers import reset_provider
from routing import ModelRouter, get_model_router, reset_model_router

# Number of failures recorded against a provider circuit to force it open in
# the fallback test. NOTE(review): assumes the breaker's failure threshold is
# at most 5 — confirm against the CircuitBreakerRegistry / Settings defaults.
CIRCUIT_TRIP_FAILURES = 5


def _reset_global_state() -> None:
    """Reset the module-level singletons (circuit registry, router, provider)."""
    reset_circuit_registry()
    reset_model_router()
    reset_provider()


@pytest.fixture
def router_settings() -> Settings:
    """Settings for routing tests."""
    return Settings(
        anthropic_api_key="test-key",
        openai_api_key="test-key",
        google_api_key="test-key",
    )


@pytest.fixture
def router(router_settings: Settings) -> Iterator[ModelRouter]:
    """Create a model router with its own circuit registry for testing.

    Global singletons are reset before the router is built and again once the
    test finishes, so anything the test caches at module level does not leak
    into later tests.
    """
    _reset_global_state()
    registry = CircuitBreakerRegistry(settings=router_settings)
    yield ModelRouter(settings=router_settings, circuit_registry=registry)
    _reset_global_state()


class TestModelRouter:
    """Tests for ModelRouter class."""

    def test_parse_model_group_valid(self, router: ModelRouter) -> None:
        """Test parsing valid model groups."""
        assert router.parse_model_group("reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code") == ModelGroup.CODE
        assert router.parse_model_group("fast") == ModelGroup.FAST
        assert router.parse_model_group("REASONING") == ModelGroup.REASONING

    def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
        """Test parsing model group aliases."""
        assert router.parse_model_group("high-reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("high_reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code-generation") == ModelGroup.CODE
        assert router.parse_model_group("fast-response") == ModelGroup.FAST

    def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
        """Test parsing invalid model group."""
        with pytest.raises(InvalidModelGroupError) as exc_info:
            router.parse_model_group("invalid_group")
        assert exc_info.value.model_group == "invalid_group"
        assert exc_info.value.available_groups is not None

    def test_get_model_config_valid(self, router: ModelRouter) -> None:
        """Test getting valid model config."""
        config = router.get_model_config("claude-opus-4")
        assert config.name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_get_model_config_invalid(self, router: ModelRouter) -> None:
        """Test getting invalid model config."""
        with pytest.raises(InvalidModelError) as exc_info:
            router.get_model_config("nonexistent-model")
        assert exc_info.value.model == "nonexistent-model"

    def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for agent types."""
        assert (
            router.get_preferred_group_for_agent("product_owner")
            == ModelGroup.REASONING
        )
        assert (
            router.get_preferred_group_for_agent("software_engineer")
            == ModelGroup.CODE
        )
        assert (
            router.get_preferred_group_for_agent("devops_engineer")
            == ModelGroup.FAST
        )

    def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for unknown agent."""
        # Should default to REASONING
        assert (
            router.get_preferred_group_for_agent("unknown_type")
            == ModelGroup.REASONING
        )

    def test_select_model_by_group(self, router: ModelRouter) -> None:
        """Test selecting model by group."""
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )
        assert model_name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_select_model_by_group_string(self, router: ModelRouter) -> None:
        """Test selecting model by group string."""
        model_name, _config = asyncio.run(router.select_model(model_group="code"))
        assert model_name == "claude-sonnet-4"

    def test_select_model_with_override(self, router: ModelRouter) -> None:
        """Test selecting specific model override."""
        model_name, config = asyncio.run(
            router.select_model(
                model_group="reasoning",
                model_override="gpt-4.1",
            )
        )
        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_override_invalid(self, router: ModelRouter) -> None:
        """Test selecting invalid model override."""
        with pytest.raises(InvalidModelError):
            asyncio.run(
                router.select_model(
                    model_group="reasoning",
                    model_override="nonexistent-model",
                )
            )

    def test_select_model_override_unavailable(self, router: ModelRouter) -> None:  # noqa: ARG002
        """Test selecting unavailable model override."""
        # `router` is requested only for its global-state reset/teardown.
        # Create router without Alibaba key
        settings = Settings(
            anthropic_api_key="test-key",
            alibaba_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)

        with pytest.raises(ModelNotAvailableError):
            asyncio.run(
                limited_router.select_model(
                    model_group="reasoning",
                    model_override="qwen-max",
                )
            )

    def test_select_model_fallback_on_circuit_open(
        self,
        router: ModelRouter,
    ) -> None:
        """Test fallback when primary circuit is open."""
        # Open circuit for anthropic
        circuit = router._circuit_registry.get_circuit_sync("anthropic")

        async def trip_circuit() -> None:
            # Record all failures inside a single event loop instead of
            # creating and closing a fresh loop per failure via asyncio.run.
            for _ in range(CIRCUIT_TRIP_FAILURES):
                await circuit.record_failure()

        asyncio.run(trip_circuit())

        # Should fall back to OpenAI
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )
        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_all_unavailable(self) -> None:
        """Test when all providers are unavailable."""
        # Every provider key known to this file is disabled explicitly so an
        # API key exported in the environment cannot make a provider
        # available. NOTE(review): extend this if Settings gains providers.
        settings = Settings(
            anthropic_api_key=None,
            openai_api_key=None,
            google_api_key=None,
            alibaba_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)

        with pytest.raises(AllProvidersFailedError) as exc_info:
            asyncio.run(limited_router.select_model(model_group=ModelGroup.REASONING))

        assert exc_info.value.model_group == "reasoning"
        assert len(exc_info.value.attempted_models) > 0

    def test_get_available_models_for_group(self, router: ModelRouter) -> None:
        """Test getting available models for a group."""
        models = asyncio.run(
            router.get_available_models_for_group(ModelGroup.REASONING)
        )
        assert len(models) > 0
        # Should be (name, config, available) tuples
        for name, config, _available in models:
            assert isinstance(name, str)
            assert config is not None

    def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
        """Test getting available models with string group."""
        models = asyncio.run(router.get_available_models_for_group("code"))
        assert len(models) > 0

    def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
        """Test getting models for invalid group."""
        with pytest.raises(InvalidModelGroupError):
            asyncio.run(router.get_available_models_for_group("invalid"))

    def test_get_all_model_groups(self, router: ModelRouter) -> None:
        """Test getting all model groups info."""
        groups = router.get_all_model_groups()
        assert len(groups) == len(ModelGroup)
        assert "reasoning" in groups
        assert "code" in groups
        assert groups["reasoning"]["primary"] == "claude-opus-4"


class TestGlobalRouter:
    """Tests for global router functions."""

    @pytest.fixture(autouse=True)
    def _isolate_global_router(self) -> Iterator[None]:
        """Start each test with no global router and discard it afterwards."""
        reset_model_router()
        yield
        reset_model_router()

    def test_get_model_router(self) -> None:
        """Test getting global router."""
        router = get_model_router()
        assert isinstance(router, ModelRouter)

    def test_get_model_router_singleton(self) -> None:
        """Test router is singleton."""
        router1 = get_model_router()
        router2 = get_model_router()
        assert router1 is router2

    def test_reset_model_router(self) -> None:
        """Test resetting global router."""
        router1 = get_model_router()
        reset_model_router()
        router2 = get_model_router()
        assert router1 is not router2