Files
syndarix/mcp-servers/llm-gateway/tests/test_routing.py
Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00

245 lines
8.9 KiB
Python

"""
Tests for routing module.
"""
import asyncio
from collections.abc import Iterator

import pytest

from config import Settings
from exceptions import (
    AllProvidersFailedError,
    InvalidModelError,
    InvalidModelGroupError,
    ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, reset_circuit_registry
from models import ModelGroup
from providers import reset_provider
from routing import ModelRouter, get_model_router, reset_model_router
@pytest.fixture
def router_settings() -> Settings:
    """Settings for routing tests.

    Anthropic, OpenAI and Google all receive the same placeholder key so the
    router treats each of those providers as configured.
    """
    placeholder = "test-key"
    return Settings(
        anthropic_api_key=placeholder,
        openai_api_key=placeholder,
        google_api_key=placeholder,
    )
@pytest.fixture
def router(router_settings: Settings) -> ModelRouter:
    """Create model router for testing.

    The global circuit registry, model router and provider are reset first
    (in that order) so nothing carries over from earlier tests. The router
    returned here is wired to its own, freshly built circuit-breaker registry.
    """
    for reset_global in (reset_circuit_registry, reset_model_router, reset_provider):
        reset_global()
    breaker_registry = CircuitBreakerRegistry(settings=router_settings)
    return ModelRouter(settings=router_settings, circuit_registry=breaker_registry)
class TestModelRouter:
    """Tests for ModelRouter class."""

    def test_parse_model_group_valid(self, router: ModelRouter) -> None:
        """Test parsing valid model groups, including upper-case input."""
        assert router.parse_model_group("reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code") == ModelGroup.CODE
        assert router.parse_model_group("fast") == ModelGroup.FAST
        assert router.parse_model_group("REASONING") == ModelGroup.REASONING

    def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
        """Test parsing model group aliases (hyphen and underscore forms)."""
        assert router.parse_model_group("high-reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("high_reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code-generation") == ModelGroup.CODE
        assert router.parse_model_group("fast-response") == ModelGroup.FAST

    def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
        """Test that an unknown group raises and reports the available groups."""
        with pytest.raises(InvalidModelGroupError) as exc_info:
            router.parse_model_group("invalid_group")
        assert exc_info.value.model_group == "invalid_group"
        assert exc_info.value.available_groups is not None

    def test_get_model_config_valid(self, router: ModelRouter) -> None:
        """Test getting valid model config."""
        config = router.get_model_config("claude-opus-4")
        assert config.name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_get_model_config_invalid(self, router: ModelRouter) -> None:
        """Test that an unknown model name raises InvalidModelError."""
        with pytest.raises(InvalidModelError) as exc_info:
            router.get_model_config("nonexistent-model")
        assert exc_info.value.model == "nonexistent-model"

    def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for agent types."""
        assert (
            router.get_preferred_group_for_agent("product_owner")
            == ModelGroup.REASONING
        )
        assert (
            router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
        )
        assert (
            router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
        )

    def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for unknown agent."""
        # Unknown agent types are expected to default to REASONING.
        assert (
            router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
        )

    def test_select_model_by_group(self, router: ModelRouter) -> None:
        """Test selecting model by group enum."""
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )
        assert model_name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_select_model_by_group_string(self, router: ModelRouter) -> None:
        """Test selecting model by group string."""
        model_name, _config = asyncio.run(router.select_model(model_group="code"))
        assert model_name == "claude-sonnet-4"

    def test_select_model_with_override(self, router: ModelRouter) -> None:
        """Test that an explicit model override wins over the group's primary."""
        model_name, config = asyncio.run(
            router.select_model(
                model_group="reasoning",
                model_override="gpt-4.1",
            )
        )
        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_override_invalid(self, router: ModelRouter) -> None:
        """Test selecting invalid model override."""
        with pytest.raises(InvalidModelError):
            asyncio.run(
                router.select_model(
                    model_group="reasoning",
                    model_override="nonexistent-model",
                )
            )

    @pytest.mark.usefixtures("router")
    def test_select_model_override_unavailable(self) -> None:
        """Test selecting an override whose provider has no API key."""
        # ``router`` is requested only for its side effect of resetting the
        # global circuit/router/provider state; this test builds its own router.
        settings = Settings(
            anthropic_api_key="test-key",
            alibaba_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)
        with pytest.raises(ModelNotAvailableError):
            asyncio.run(
                limited_router.select_model(
                    model_group="reasoning",
                    model_override="qwen-max",
                )
            )

    def test_select_model_fallback_on_circuit_open(
        self,
        router: ModelRouter,
    ) -> None:
        """Test fallback to OpenAI when the Anthropic circuit is open."""
        circuit = router._circuit_registry.get_circuit_sync("anthropic")

        async def record_failures(count: int) -> None:
            # Record every failure inside one event loop instead of starting
            # a fresh loop per failure.
            for _ in range(count):
                await circuit.record_failure()

        # Five failures are intended to trip the breaker.
        # NOTE(review): assumes the default failure threshold is <= 5; confirm
        # against Settings if the threshold changes.
        asyncio.run(record_failures(5))

        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )
        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    @pytest.mark.usefixtures("router")
    def test_select_model_all_unavailable(self) -> None:
        """Test that AllProvidersFailedError is raised when no provider is usable."""
        # ``router`` is requested only to reset global state, as in
        # test_select_model_override_unavailable.
        # Every provider key used in this module is disabled explicitly so the
        # outcome does not hinge on defaults. NOTE(review): presumably Settings
        # can also pick keys up from the environment; confirm.
        settings = Settings(
            anthropic_api_key=None,
            openai_api_key=None,
            google_api_key=None,
            alibaba_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)
        with pytest.raises(AllProvidersFailedError) as exc_info:
            asyncio.run(limited_router.select_model(model_group=ModelGroup.REASONING))
        assert exc_info.value.model_group == "reasoning"
        assert len(exc_info.value.attempted_models) > 0

    def test_get_available_models_for_group(self, router: ModelRouter) -> None:
        """Test getting available models for a group."""
        models = asyncio.run(
            router.get_available_models_for_group(ModelGroup.REASONING)
        )
        assert len(models) > 0
        # Each entry is a (name, config, available) tuple.
        for name, config, _available in models:
            assert isinstance(name, str)
            assert config is not None

    def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
        """Test getting available models with string group."""
        models = asyncio.run(router.get_available_models_for_group("code"))
        assert len(models) > 0

    def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
        """Test getting models for invalid group."""
        with pytest.raises(InvalidModelGroupError):
            asyncio.run(router.get_available_models_for_group("invalid"))

    def test_get_all_model_groups(self, router: ModelRouter) -> None:
        """Test that every ModelGroup is reported, with its primary model."""
        groups = router.get_all_model_groups()
        assert len(groups) == len(ModelGroup)
        assert "reasoning" in groups
        assert "code" in groups
        assert groups["reasoning"]["primary"] == "claude-opus-4"
class TestGlobalRouter:
    """Tests for global router functions."""

    @pytest.fixture(autouse=True)
    def _fresh_global_router(self) -> Iterator[None]:
        """Reset the global router before and after every test in this class.

        Resetting afterwards keeps the singleton created by these tests from
        leaking into tests that run later in the session.
        """
        reset_model_router()
        yield
        reset_model_router()

    def test_get_model_router(self) -> None:
        """Test getting global router."""
        router = get_model_router()
        assert isinstance(router, ModelRouter)

    def test_get_model_router_singleton(self) -> None:
        """Test that repeated calls return the same router instance."""
        router1 = get_model_router()
        router2 = get_model_router()
        assert router1 is router2

    def test_reset_model_router(self) -> None:
        """Test that resetting causes a new router instance to be created."""
        router1 = get_model_router()
        reset_model_router()
        router2 = get_model_router()
        assert router1 is not router2