forked from cardosofelipe/fast-next-template
- Add type annotations for mypy compliance - Use UTC-aware datetimes consistently (datetime.now(UTC)) - Add type: ignore comments for LiteLLM incomplete stubs - Fix import ordering and formatting - Update pyproject.toml mypy configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
245 lines
8.9 KiB
Python
245 lines
8.9 KiB
Python
"""
|
|
Tests for routing module.
|
|
"""
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
|
|
from config import Settings
|
|
from exceptions import (
|
|
AllProvidersFailedError,
|
|
InvalidModelError,
|
|
InvalidModelGroupError,
|
|
ModelNotAvailableError,
|
|
)
|
|
from failover import CircuitBreakerRegistry, reset_circuit_registry
|
|
from models import ModelGroup
|
|
from providers import reset_provider
|
|
from routing import ModelRouter, get_model_router, reset_model_router
|
|
|
|
|
|
@pytest.fixture
def router_settings() -> Settings:
    """Build Settings with placeholder Anthropic, OpenAI and Google API keys."""
    # Every provider key gets the same dummy value.
    placeholder_keys = dict.fromkeys(
        ("anthropic_api_key", "openai_api_key", "google_api_key"),
        "test-key",
    )
    return Settings(**placeholder_keys)
|
|
|
|
|
|
@pytest.fixture
def router(router_settings: Settings) -> ModelRouter:
    """Return a ModelRouter backed by its own circuit breaker registry.

    The module-level circuit registry, model router and provider are reset
    first, in that order, before the new router is constructed.
    """
    reset_circuit_registry()
    reset_model_router()
    reset_provider()
    return ModelRouter(
        settings=router_settings,
        circuit_registry=CircuitBreakerRegistry(settings=router_settings),
    )
|
|
|
|
|
|
class TestModelRouter:
    """Unit tests for ModelRouter group parsing, config lookup and selection."""

    def test_parse_model_group_valid(self, router: ModelRouter) -> None:
        """Group names, including an upper-case spelling, parse to ModelGroup."""
        expected_by_name = {
            "reasoning": ModelGroup.REASONING,
            "code": ModelGroup.CODE,
            "fast": ModelGroup.FAST,
            "REASONING": ModelGroup.REASONING,
        }
        for raw_name, expected_group in expected_by_name.items():
            assert router.parse_model_group(raw_name) == expected_group

    def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
        """Alias spellings parse to the same ModelGroup as the short names."""
        expected_by_alias = {
            "high-reasoning": ModelGroup.REASONING,
            "high_reasoning": ModelGroup.REASONING,
            "code-generation": ModelGroup.CODE,
            "fast-response": ModelGroup.FAST,
        }
        for alias, expected_group in expected_by_alias.items():
            assert router.parse_model_group(alias) == expected_group

    def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
        """An unknown group name raises InvalidModelGroupError with details."""
        bad_name = "invalid_group"
        with pytest.raises(InvalidModelGroupError) as raised:
            router.parse_model_group(bad_name)

        error = raised.value
        assert error.model_group == bad_name
        assert error.available_groups is not None

    def test_get_model_config_valid(self, router: ModelRouter) -> None:
        """A known model name returns its config with the expected provider."""
        model_id = "claude-opus-4"
        model_config = router.get_model_config(model_id)
        assert model_config.name == model_id
        assert model_config.provider.value == "anthropic"

    def test_get_model_config_invalid(self, router: ModelRouter) -> None:
        """An unknown model name raises InvalidModelError carrying that name."""
        missing_model = "nonexistent-model"
        with pytest.raises(InvalidModelError) as raised:
            router.get_model_config(missing_model)

        assert raised.value.model == missing_model

    def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
        """Each checked agent type maps to its expected preferred group."""
        expected_by_agent = {
            "product_owner": ModelGroup.REASONING,
            "software_engineer": ModelGroup.CODE,
            "devops_engineer": ModelGroup.FAST,
        }
        for agent_type, expected_group in expected_by_agent.items():
            assert router.get_preferred_group_for_agent(agent_type) == expected_group

    def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
        """An unrecognised agent type gets the REASONING group."""
        # Should default to REASONING
        fallback_group = router.get_preferred_group_for_agent("unknown_type")
        assert fallback_group == ModelGroup.REASONING

    def test_select_model_by_group(self, router: ModelRouter) -> None:
        """Selecting by ModelGroup member yields the expected model."""
        selection = router.select_model(model_group=ModelGroup.REASONING)
        chosen_name, chosen_config = asyncio.run(selection)

        assert chosen_name == "claude-opus-4"
        assert chosen_config.provider.value == "anthropic"

    def test_select_model_by_group_string(self, router: ModelRouter) -> None:
        """Selecting by group name string yields the expected model."""
        selection = router.select_model(model_group="code")
        chosen_name, _config = asyncio.run(selection)

        assert chosen_name == "claude-sonnet-4"

    def test_select_model_with_override(self, router: ModelRouter) -> None:
        """An explicit model override is returned instead of the group pick."""
        selection = router.select_model(
            model_group="reasoning",
            model_override="gpt-4.1",
        )
        chosen_name, chosen_config = asyncio.run(selection)

        assert chosen_name == "gpt-4.1"
        assert chosen_config.provider.value == "openai"

    def test_select_model_override_invalid(self, router: ModelRouter) -> None:
        """An override naming an unknown model raises InvalidModelError."""
        with pytest.raises(InvalidModelError):
            asyncio.run(
                router.select_model(
                    model_group="reasoning",
                    model_override="nonexistent-model",
                )
            )

    def test_select_model_override_unavailable(self, router: ModelRouter) -> None:  # noqa: ARG002
        """An override whose provider has no key raises ModelNotAvailableError."""
        # Create router without Alibaba key
        no_alibaba_settings = Settings(
            anthropic_api_key="test-key",
            alibaba_api_key=None,
        )
        limited_router = ModelRouter(
            settings=no_alibaba_settings,
            circuit_registry=CircuitBreakerRegistry(settings=no_alibaba_settings),
        )

        with pytest.raises(ModelNotAvailableError):
            asyncio.run(
                limited_router.select_model(
                    model_group="reasoning",
                    model_override="qwen-max",
                )
            )

    def test_select_model_fallback_on_circuit_open(
        self,
        router: ModelRouter,
    ) -> None:
        """After repeated anthropic failures, selection moves to another model."""
        # Open circuit for anthropic
        anthropic_circuit = router._circuit_registry.get_circuit_sync("anthropic")
        for _ in range(5):
            asyncio.run(anthropic_circuit.record_failure())

        # Should fall back to OpenAI
        selection = router.select_model(model_group=ModelGroup.REASONING)
        chosen_name, chosen_config = asyncio.run(selection)

        assert chosen_name == "gpt-4.1"
        assert chosen_config.provider.value == "openai"

    def test_select_model_all_unavailable(self) -> None:
        """Selection raises AllProvidersFailedError when no provider is usable."""
        # The three provider keys are explicitly set to None.
        missing_keys = dict.fromkeys(
            ("anthropic_api_key", "openai_api_key", "google_api_key"),
        )
        keyless_settings = Settings(**missing_keys)
        limited_router = ModelRouter(
            settings=keyless_settings,
            circuit_registry=CircuitBreakerRegistry(settings=keyless_settings),
        )

        with pytest.raises(AllProvidersFailedError) as raised:
            asyncio.run(limited_router.select_model(model_group=ModelGroup.REASONING))

        error = raised.value
        assert error.model_group == "reasoning"
        assert len(error.attempted_models) > 0

    def test_get_available_models_for_group(self, router: ModelRouter) -> None:
        """The listing for a ModelGroup member is non-empty and well-formed."""
        listing = asyncio.run(
            router.get_available_models_for_group(ModelGroup.REASONING)
        )

        assert len(listing) > 0
        # Should be (name, config, available) tuples
        for entry_name, entry_config, _is_available in listing:
            assert isinstance(entry_name, str)
            assert entry_config is not None

    def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
        """The listing for a group name string is non-empty."""
        listing = asyncio.run(router.get_available_models_for_group("code"))

        assert len(listing) > 0

    def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
        """Listing models for an unknown group raises InvalidModelGroupError."""
        with pytest.raises(InvalidModelGroupError):
            asyncio.run(router.get_available_models_for_group("invalid"))

    def test_get_all_model_groups(self, router: ModelRouter) -> None:
        """The group summary has one entry per ModelGroup member."""
        group_info = router.get_all_model_groups()

        assert len(group_info) == len(ModelGroup)
        for group_key in ("reasoning", "code"):
            assert group_key in group_info
        assert group_info["reasoning"]["primary"] == "claude-opus-4"
|
|
|
|
|
|
class TestGlobalRouter:
    """Tests for the module-level router accessors.

    The global router is reset before and after every test, so a router
    created here is not left behind for tests that run later.
    """

    def setup_method(self) -> None:
        """Clear the global router before each test."""
        reset_model_router()

    def teardown_method(self) -> None:
        """Clear the global router after each test, even if the test failed."""
        reset_model_router()

    def test_get_model_router(self) -> None:
        """get_model_router() returns a ModelRouter instance."""
        assert isinstance(get_model_router(), ModelRouter)

    def test_get_model_router_singleton(self) -> None:
        """Two calls without a reset in between return the same object."""
        first = get_model_router()
        second = get_model_router()
        assert first is second

    def test_reset_model_router(self) -> None:
        """After reset_model_router(), the accessor returns a different object."""
        before_reset = get_model_router()
        reset_model_router()
        after_reset = get_model_router()
        assert before_reset is not after_reset
|