feat(llm-gateway): implement LLM Gateway MCP Server (#56)

Implements a complete LLM Gateway MCP Server with:
- FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens
- LiteLLM Router with multi-provider failover chains
- Circuit breaker pattern for fault tolerance
- Redis-based cost tracking per project/agent (both sketched below)
- Comprehensive test suite (209 tests, 92% coverage)
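
The circuit-breaker and cost-tracking pieces are covered in detail by the test files below; as a quick orientation, here is a minimal sketch of how they compose, using only APIs exercised by those tests (`CircuitBreakerRegistry`, `CostTracker`, `calculate_cost`) and `fakeredis` standing in for Redis. The actual gateway wiring is not part of this excerpt, so treat the flow as illustrative:

```python
import asyncio

import fakeredis.aioredis

from config import Settings
from cost_tracking import CostTracker, calculate_cost
from failover import CircuitBreakerRegistry


async def main() -> None:
    settings = Settings(
        anthropic_api_key="test-anthropic-key",
        cost_tracking_enabled=True,
        circuit_failure_threshold=3,
    )
    registry = CircuitBreakerRegistry(settings=settings)
    tracker = CostTracker(
        redis_client=fakeredis.aioredis.FakeRedis(decode_responses=True),
        settings=settings,
    )

    # Skip providers whose circuit is open; failover moves to the next model in the chain.
    if await registry.is_available("anthropic"):
        try:
            # ... provider call happens here (omitted in this sketch) ...
            await registry.record_success("anthropic")
            cost = calculate_cost(
                model="claude-opus-4", prompt_tokens=1000, completion_tokens=500
            )  # 0.0525 at $15/$75 per 1M tokens
            await tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=1000,
                completion_tokens=500,
                cost_usd=cost,
            )
        except Exception:
            await registry.record_failure("anthropic")

    report = await tracker.get_project_usage("proj-123", period="day")
    print(report.total_requests, report.total_cost_usd)
    await tracker.close()


asyncio.run(main())
```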

Model groups defined per ADR-004 (sketched below):
- reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro
- code: claude-sonnet-4 → gpt-4.1 → deepseek-coder
- fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash
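
Each group maps a primary model to an ordered fallback chain. A minimal sketch using the `ModelGroupConfig` type from the `models` module in this commit (the chain matches the reasoning group above; the description string is illustrative):

```python
from models import MODEL_GROUPS, ModelGroup, ModelGroupConfig

# Reasoning chain from ADR-004 expressed as a ModelGroupConfig.
reasoning = ModelGroupConfig(
    primary="claude-opus-4",
    fallbacks=["gpt-4.1", "gemini-2.5-pro"],
    description="Complex analysis and planning",  # illustrative text
)
assert reasoning.get_all_models() == ["claude-opus-4", "gpt-4.1", "gemini-2.5-pro"]

# The shipped mapping is keyed by the ModelGroup enum.
assert ModelGroup.REASONING in MODEL_GROUPS
```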

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:31:19 +01:00
parent 746fb7b181
commit 6e8b0b022a
23 changed files with 9794 additions and 93 deletions


@@ -0,0 +1 @@
"""Tests for LLM Gateway MCP Server."""


@@ -0,0 +1,204 @@
"""
Pytest fixtures for LLM Gateway tests.
"""
import os
import sys
from collections.abc import AsyncIterator, Iterator
from typing import Any
from unittest.mock import MagicMock, patch
import fakeredis.aioredis
import pytest
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import Settings
from cost_tracking import CostTracker, reset_cost_tracker
from failover import CircuitBreakerRegistry, reset_circuit_registry
from providers import LLMProvider, reset_provider
from routing import ModelRouter, reset_model_router
@pytest.fixture
def test_settings() -> Settings:
"""Create test settings with mock API keys."""
return Settings(
host="127.0.0.1",
port=8001,
debug=True,
redis_url="redis://localhost:6379/0",
anthropic_api_key="test-anthropic-key",
openai_api_key="test-openai-key",
google_api_key="test-google-key",
litellm_timeout=30,
litellm_max_retries=2,
litellm_cache_enabled=False,
cost_tracking_enabled=True,
circuit_failure_threshold=3,
circuit_recovery_timeout=10,
)
@pytest.fixture
def settings_no_providers() -> Settings:
"""Create settings with no providers configured."""
return Settings(
host="127.0.0.1",
port=8001,
debug=False,
redis_url="redis://localhost:6379/0",
anthropic_api_key=None,
openai_api_key=None,
google_api_key=None,
alibaba_api_key=None,
deepseek_api_key=None,
)
@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
"""Create a fake Redis instance for testing."""
return fakeredis.aioredis.FakeRedis(decode_responses=True)
@pytest.fixture
async def cost_tracker(
fake_redis: fakeredis.aioredis.FakeRedis,
test_settings: Settings,
) -> AsyncIterator[CostTracker]:
"""Create a cost tracker with fake Redis."""
reset_cost_tracker()
tracker = CostTracker(redis_client=fake_redis, settings=test_settings)
yield tracker
await tracker.close()
reset_cost_tracker()
@pytest.fixture
def circuit_registry(test_settings: Settings) -> Iterator[CircuitBreakerRegistry]:
"""Create a circuit breaker registry for testing."""
reset_circuit_registry()
registry = CircuitBreakerRegistry(settings=test_settings)
yield registry
reset_circuit_registry()
@pytest.fixture
def model_router(
test_settings: Settings,
circuit_registry: CircuitBreakerRegistry,
) -> Iterator[ModelRouter]:
"""Create a model router for testing."""
reset_model_router()
router = ModelRouter(settings=test_settings, circuit_registry=circuit_registry)
yield router
reset_model_router()
@pytest.fixture
def llm_provider(test_settings: Settings) -> Iterator[LLMProvider]:
"""Create an LLM provider for testing."""
reset_provider()
provider = LLMProvider(settings=test_settings)
yield provider
reset_provider()
@pytest.fixture
def mock_litellm_response() -> dict[str, Any]:
"""Create a mock LiteLLM response."""
return {
"id": "test-response-id",
"model": "claude-opus-4",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response.",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
},
}
@pytest.fixture
def mock_completion_response() -> MagicMock:
"""Create a mock completion response object."""
response = MagicMock()
response.id = "test-response-id"
response.model = "claude-opus-4"
choice = MagicMock()
choice.index = 0
choice.message = MagicMock()
choice.message.content = "This is a test response."
choice.finish_reason = "stop"
response.choices = [choice]
response.usage = MagicMock()
response.usage.prompt_tokens = 10
response.usage.completion_tokens = 20
response.usage.total_tokens = 30
return response
@pytest.fixture
def sample_messages() -> list[dict[str, str]]:
"""Sample chat messages for testing."""
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
]
@pytest.fixture
def sample_project_id() -> str:
"""Sample project ID."""
return "proj-12345678-1234-1234-1234-123456789abc"
@pytest.fixture
def sample_agent_id() -> str:
"""Sample agent ID."""
return "agent-87654321-4321-4321-4321-cba987654321"
@pytest.fixture
def sample_session_id() -> str:
"""Sample session ID."""
return "session-11111111-2222-3333-4444-555555555555"
# Reset all global state after each test
@pytest.fixture(autouse=True)
def reset_globals() -> Iterator[None]:
"""Reset all global state after each test."""
yield
reset_cost_tracker()
reset_circuit_registry()
reset_model_router()
reset_provider()
# Mock environment variables for tests
# Note: Not autouse=True to avoid affecting default value tests
@pytest.fixture
def mock_env_vars() -> Iterator[None]:
"""Set test environment variables."""
env_vars = {
"LLM_GATEWAY_HOST": "127.0.0.1",
"LLM_GATEWAY_PORT": "8001",
"LLM_GATEWAY_DEBUG": "true",
}
with patch.dict(os.environ, env_vars, clear=False):
yield


@@ -0,0 +1,200 @@
"""
Tests for config module.
"""
import os
from unittest.mock import patch
import pytest
from config import Settings, get_settings
class TestSettings:
"""Tests for Settings class."""
def test_default_values(self) -> None:
"""Test default configuration values."""
settings = Settings()
assert settings.host == "0.0.0.0"
assert settings.port == 8001
assert settings.debug is False
assert settings.redis_url == "redis://localhost:6379/0"
assert settings.litellm_timeout == 120
assert settings.circuit_failure_threshold == 5
def test_custom_values(self) -> None:
"""Test custom configuration values."""
settings = Settings(
host="127.0.0.1",
port=9000,
debug=True,
redis_url="redis://custom:6380/1",
litellm_timeout=60,
)
assert settings.host == "127.0.0.1"
assert settings.port == 9000
assert settings.debug is True
assert settings.redis_url == "redis://custom:6380/1"
assert settings.litellm_timeout == 60
def test_port_validation_valid(self) -> None:
"""Test valid port numbers."""
settings = Settings(port=1)
assert settings.port == 1
settings = Settings(port=65535)
assert settings.port == 65535
settings = Settings(port=8080)
assert settings.port == 8080
def test_port_validation_invalid(self) -> None:
"""Test invalid port numbers."""
with pytest.raises(ValueError, match="Port must be between"):
Settings(port=0)
with pytest.raises(ValueError, match="Port must be between"):
Settings(port=65536)
with pytest.raises(ValueError, match="Port must be between"):
Settings(port=-1)
def test_ttl_validation_valid(self) -> None:
"""Test valid TTL values."""
settings = Settings(redis_ttl_hours=1)
assert settings.redis_ttl_hours == 1
settings = Settings(redis_ttl_hours=168) # 1 week
assert settings.redis_ttl_hours == 168
def test_ttl_validation_invalid(self) -> None:
"""Test invalid TTL values."""
with pytest.raises(ValueError, match="Redis TTL must be positive"):
Settings(redis_ttl_hours=0)
with pytest.raises(ValueError, match="Redis TTL must be positive"):
Settings(redis_ttl_hours=-1)
def test_failure_threshold_validation(self) -> None:
"""Test circuit failure threshold validation."""
settings = Settings(circuit_failure_threshold=1)
assert settings.circuit_failure_threshold == 1
settings = Settings(circuit_failure_threshold=100)
assert settings.circuit_failure_threshold == 100
with pytest.raises(ValueError, match="Failure threshold must be between"):
Settings(circuit_failure_threshold=0)
with pytest.raises(ValueError, match="Failure threshold must be between"):
Settings(circuit_failure_threshold=101)
def test_timeout_validation(self) -> None:
"""Test timeout validation."""
settings = Settings(litellm_timeout=1)
assert settings.litellm_timeout == 1
settings = Settings(litellm_timeout=600)
assert settings.litellm_timeout == 600
with pytest.raises(ValueError, match="Timeout must be between"):
Settings(litellm_timeout=0)
with pytest.raises(ValueError, match="Timeout must be between"):
Settings(litellm_timeout=601)
def test_get_available_providers_none(self) -> None:
"""Test getting available providers with none configured."""
settings = Settings()
providers = settings.get_available_providers()
assert providers == []
def test_get_available_providers_some(self) -> None:
"""Test getting available providers with some configured."""
settings = Settings(
anthropic_api_key="test-key",
openai_api_key="test-key",
)
providers = settings.get_available_providers()
assert "anthropic" in providers
assert "openai" in providers
assert "google" not in providers
assert len(providers) == 2
def test_get_available_providers_all(self) -> None:
"""Test getting available providers with all configured."""
settings = Settings(
anthropic_api_key="test-key",
openai_api_key="test-key",
google_api_key="test-key",
alibaba_api_key="test-key",
deepseek_api_key="test-key",
)
providers = settings.get_available_providers()
assert len(providers) == 5
assert "anthropic" in providers
assert "openai" in providers
assert "google" in providers
assert "alibaba" in providers
assert "deepseek" in providers
def test_has_any_provider_false(self) -> None:
"""Test has_any_provider when none configured."""
settings = Settings()
assert settings.has_any_provider() is False
def test_has_any_provider_true(self) -> None:
"""Test has_any_provider when at least one configured."""
settings = Settings(anthropic_api_key="test-key")
assert settings.has_any_provider() is True
def test_deepseek_base_url_counts_as_provider(self) -> None:
"""Test that DeepSeek base URL alone counts as provider."""
settings = Settings(deepseek_base_url="http://localhost:8000")
providers = settings.get_available_providers()
assert "deepseek" in providers
class TestGetSettings:
"""Tests for get_settings function."""
def test_get_settings_returns_settings(self) -> None:
"""Test that get_settings returns a Settings instance."""
# Clear the cache first
get_settings.cache_clear()
settings = get_settings()
assert isinstance(settings, Settings)
def test_get_settings_is_cached(self) -> None:
"""Test that get_settings returns cached instance."""
get_settings.cache_clear()
settings1 = get_settings()
settings2 = get_settings()
assert settings1 is settings2
def test_env_var_override(self) -> None:
"""Test that environment variables override defaults."""
get_settings.cache_clear()
with patch.dict(
os.environ,
{
"LLM_GATEWAY_HOST": "192.168.1.1",
"LLM_GATEWAY_PORT": "9999",
"LLM_GATEWAY_DEBUG": "true",
},
):
get_settings.cache_clear()
settings = get_settings()
assert settings.host == "192.168.1.1"
assert settings.port == 9999
assert settings.debug is True


@@ -0,0 +1,417 @@
"""
Tests for cost_tracking module.
"""
import asyncio
import fakeredis.aioredis
import pytest
from config import Settings
from cost_tracking import (
CostTracker,
calculate_cost,
close_cost_tracker,
get_cost_tracker,
reset_cost_tracker,
)
@pytest.fixture
def tracker_settings() -> Settings:
"""Settings for cost tracker tests."""
return Settings(
redis_url="redis://localhost:6379/0",
redis_prefix="test_llm_gateway",
cost_tracking_enabled=True,
cost_alert_threshold=100.0,
default_budget_limit=1000.0,
)
@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
"""Create fake Redis for testing."""
return fakeredis.aioredis.FakeRedis(decode_responses=True)
@pytest.fixture
def tracker(
fake_redis: fakeredis.aioredis.FakeRedis,
tracker_settings: Settings,
) -> CostTracker:
"""Create cost tracker with fake Redis."""
return CostTracker(redis_client=fake_redis, settings=tracker_settings)
class TestCalculateCost:
"""Tests for calculate_cost function."""
def test_calculate_cost_known_model(self) -> None:
"""Test calculating cost for known model."""
# claude-opus-4: $15/1M input, $75/1M output
cost = calculate_cost(
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
)
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
assert cost == pytest.approx(0.0525, rel=0.001)
def test_calculate_cost_unknown_model(self) -> None:
"""Test calculating cost for unknown model."""
cost = calculate_cost(
model="unknown-model",
prompt_tokens=1000,
completion_tokens=500,
)
assert cost == 0.0
def test_calculate_cost_zero_tokens(self) -> None:
"""Test calculating cost with zero tokens."""
cost = calculate_cost(
model="claude-opus-4",
prompt_tokens=0,
completion_tokens=0,
)
assert cost == 0.0
def test_calculate_cost_large_token_counts(self) -> None:
"""Test calculating cost with large token counts."""
cost = calculate_cost(
model="claude-opus-4",
prompt_tokens=1_000_000,
completion_tokens=500_000,
)
# 1M * 15/1M + 500K * 75/1M = 15 + 37.5 = 52.5
assert cost == pytest.approx(52.5, rel=0.001)
class TestCostTracker:
"""Tests for CostTracker class."""
def test_record_usage(self, tracker: CostTracker) -> None:
"""Test recording usage."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
# Verify by getting usage report
report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
assert report.total_requests == 1
assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)
def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
"""Test recording is skipped when disabled."""
settings = Settings(**{
**tracker_settings.model_dump(),
"cost_tracking_enabled": False,
})
fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)
# This should not raise and should not record
asyncio.run(
disabled_tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
# Usage should be empty
report = asyncio.run(
disabled_tracker.get_project_usage("proj-123", period="day")
)
assert report.total_requests == 0
def test_record_usage_with_session(self, tracker: CostTracker) -> None:
"""Test recording usage with session ID."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
session_id="session-789",
)
)
# Verify session usage
session_usage = asyncio.run(
tracker.get_session_usage("session-789")
)
assert session_usage["session_id"] == "session-789"
assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)
def test_get_project_usage_empty(self, tracker: CostTracker) -> None:
"""Test getting usage for project with no data."""
report = asyncio.run(
tracker.get_project_usage("nonexistent-project", period="day")
)
assert report.entity_id == "nonexistent-project"
assert report.entity_type == "project"
assert report.total_requests == 0
assert report.total_cost_usd == 0.0
def test_get_project_usage_multiple_models(self, tracker: CostTracker) -> None:
"""Test usage tracking across multiple models."""
# Record usage for different models
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="gpt-4.1",
prompt_tokens=200,
completion_tokens=100,
cost_usd=0.02,
)
)
report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
assert report.total_requests == 2
assert len(report.by_model) == 2
assert "claude-opus-4" in report.by_model
assert "gpt-4.1" in report.by_model
def test_get_agent_usage(self, tracker: CostTracker) -> None:
"""Test getting agent usage."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
report = asyncio.run(
tracker.get_agent_usage("agent-456", period="day")
)
assert report.entity_id == "agent-456"
assert report.entity_type == "agent"
assert report.total_requests == 1
def test_usage_periods(self, tracker: CostTracker) -> None:
"""Test different usage periods."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
# Check different periods
hour_report = asyncio.run(
tracker.get_project_usage("proj-123", period="hour")
)
day_report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
month_report = asyncio.run(
tracker.get_project_usage("proj-123", period="month")
)
assert hour_report.period == "hour"
assert day_report.period == "day"
assert month_report.period == "month"
# All should have the same data
assert hour_report.total_requests == 1
assert day_report.total_requests == 1
assert month_report.total_requests == 1
def test_check_budget_within(self, tracker: CostTracker) -> None:
"""Test budget check when within limit."""
# Record some usage
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=50.0,
)
)
within, current, limit = asyncio.run(
tracker.check_budget("proj-123", budget_limit=100.0)
)
assert within is True
assert current == pytest.approx(50.0, rel=0.01)
assert limit == 100.0
def test_check_budget_exceeded(self, tracker: CostTracker) -> None:
"""Test budget check when exceeded."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=150.0,
)
)
within, current, limit = asyncio.run(
tracker.check_budget("proj-123", budget_limit=100.0)
)
assert within is False
assert current >= limit
def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
"""Test budget check with default limit."""
within, current, limit = asyncio.run(
tracker.check_budget("proj-123")
)
assert limit == 1000.0 # Default from settings
def test_estimate_request_cost_known_model(self, tracker: CostTracker) -> None:
"""Test estimating cost for known model."""
cost = asyncio.run(
tracker.estimate_request_cost(
model="claude-opus-4",
prompt_tokens=1000,
max_completion_tokens=500,
)
)
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
assert cost == pytest.approx(0.0525, rel=0.01)
def test_estimate_request_cost_unknown_model(self, tracker: CostTracker) -> None:
"""Test estimating cost for unknown model."""
cost = asyncio.run(
tracker.estimate_request_cost(
model="unknown-model",
prompt_tokens=1000,
max_completion_tokens=500,
)
)
# Uses fallback estimate
assert cost > 0
def test_should_alert_below_threshold(self, tracker: CostTracker) -> None:
"""Test alert check when below threshold."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=50.0,
)
)
should_alert, current = asyncio.run(
tracker.should_alert("proj-123", threshold=100.0)
)
assert should_alert is False
assert current == pytest.approx(50.0, rel=0.01)
def test_should_alert_above_threshold(self, tracker: CostTracker) -> None:
"""Test alert check when above threshold."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=150.0,
)
)
should_alert, current = asyncio.run(
tracker.should_alert("proj-123", threshold=100.0)
)
assert should_alert is True
def test_close(self, tracker: CostTracker) -> None:
"""Test closing tracker."""
asyncio.run(tracker.close())
assert tracker._redis is None
class TestGlobalTracker:
"""Tests for global tracker functions."""
def test_get_cost_tracker(self) -> None:
"""Test getting global tracker."""
reset_cost_tracker()
tracker = get_cost_tracker()
assert isinstance(tracker, CostTracker)
def test_get_cost_tracker_singleton(self) -> None:
"""Test tracker is singleton."""
reset_cost_tracker()
tracker1 = get_cost_tracker()
tracker2 = get_cost_tracker()
assert tracker1 is tracker2
def test_reset_cost_tracker(self) -> None:
"""Test resetting global tracker."""
reset_cost_tracker()
tracker1 = get_cost_tracker()
reset_cost_tracker()
tracker2 = get_cost_tracker()
assert tracker1 is not tracker2
def test_close_cost_tracker(self) -> None:
"""Test closing global tracker."""
reset_cost_tracker()
_ = get_cost_tracker()
asyncio.run(close_cost_tracker())
# Getting again should create a new one
tracker2 = get_cost_tracker()
assert tracker2 is not None


@@ -0,0 +1,377 @@
"""
Tests for exceptions module.
"""
from exceptions import (
AllProvidersFailedError,
CircuitOpenError,
ConfigurationError,
ContextTooLongError,
CostLimitExceededError,
ErrorCode,
InvalidModelError,
InvalidModelGroupError,
LLMGatewayError,
ModelNotAvailableError,
ProviderError,
RateLimitError,
StreamError,
TokenLimitExceededError,
)
class TestErrorCode:
"""Tests for ErrorCode enum."""
def test_error_code_values(self) -> None:
"""Test error code values."""
assert ErrorCode.UNKNOWN_ERROR.value == "LLM_UNKNOWN_ERROR"
assert ErrorCode.PROVIDER_ERROR.value == "LLM_PROVIDER_ERROR"
assert ErrorCode.CIRCUIT_OPEN.value == "LLM_CIRCUIT_OPEN"
assert ErrorCode.COST_LIMIT_EXCEEDED.value == "LLM_COST_LIMIT_EXCEEDED"
class TestLLMGatewayError:
"""Tests for LLMGatewayError base class."""
def test_basic_error(self) -> None:
"""Test basic error creation."""
error = LLMGatewayError("Something went wrong")
assert str(error) == "[LLM_UNKNOWN_ERROR] Something went wrong"
assert error.message == "Something went wrong"
assert error.code == ErrorCode.UNKNOWN_ERROR
assert error.details == {}
assert error.cause is None
def test_error_with_code(self) -> None:
"""Test error with custom code."""
error = LLMGatewayError(
"Provider failed",
code=ErrorCode.PROVIDER_ERROR,
)
assert error.code == ErrorCode.PROVIDER_ERROR
def test_error_with_details(self) -> None:
"""Test error with details."""
error = LLMGatewayError(
"Error",
details={"key": "value"},
)
assert error.details == {"key": "value"}
def test_error_with_cause(self) -> None:
"""Test error with cause exception."""
cause = ValueError("Original error")
error = LLMGatewayError("Wrapped error", cause=cause)
assert error.cause is cause
def test_to_dict(self) -> None:
"""Test converting error to dict."""
error = LLMGatewayError(
"Test error",
code=ErrorCode.INVALID_REQUEST,
details={"field": "value"},
)
result = error.to_dict()
assert result["error"] == "LLM_INVALID_REQUEST"
assert result["message"] == "Test error"
assert result["details"] == {"field": "value"}
def test_to_dict_no_details(self) -> None:
"""Test to_dict without details."""
error = LLMGatewayError("Test error")
result = error.to_dict()
assert "details" not in result
def test_repr(self) -> None:
"""Test error repr."""
error = LLMGatewayError("Test", details={"key": "val"})
repr_str = repr(error)
assert "LLMGatewayError" in repr_str
assert "Test" in repr_str
class TestProviderError:
"""Tests for ProviderError."""
def test_basic_provider_error(self) -> None:
"""Test basic provider error."""
error = ProviderError(
message="API call failed",
provider="anthropic",
)
assert error.provider == "anthropic"
assert error.model is None
assert error.status_code is None
assert error.code == ErrorCode.PROVIDER_ERROR
assert "provider" in error.details
def test_provider_error_with_model(self) -> None:
"""Test provider error with model info."""
error = ProviderError(
message="Model not found",
provider="openai",
model="gpt-5",
status_code=404,
)
assert error.model == "gpt-5"
assert error.status_code == 404
assert error.details["model"] == "gpt-5"
assert error.details["status_code"] == 404
def test_provider_error_with_cause(self) -> None:
"""Test provider error with cause."""
cause = ConnectionError("Network down")
error = ProviderError(
message="Connection failed",
provider="google",
cause=cause,
)
assert error.cause is cause
class TestRateLimitError:
"""Tests for RateLimitError."""
def test_internal_rate_limit(self) -> None:
"""Test internal rate limit error."""
error = RateLimitError(
message="Too many requests",
retry_after=60,
)
assert error.code == ErrorCode.RATE_LIMIT_EXCEEDED
assert error.provider is None
assert error.retry_after == 60
assert error.details["retry_after_seconds"] == 60
def test_provider_rate_limit(self) -> None:
"""Test provider rate limit error."""
error = RateLimitError(
message="OpenAI rate limit",
provider="openai",
retry_after=30,
)
assert error.code == ErrorCode.PROVIDER_RATE_LIMIT
assert error.provider == "openai"
assert error.details["provider"] == "openai"
class TestCircuitOpenError:
"""Tests for CircuitOpenError."""
def test_circuit_open_error(self) -> None:
"""Test circuit open error."""
error = CircuitOpenError(
provider="anthropic",
recovery_time=45,
)
assert error.provider == "anthropic"
assert error.recovery_time == 45
assert error.code == ErrorCode.CIRCUIT_OPEN
assert "Circuit breaker open" in error.message
assert error.details["recovery_time_seconds"] == 45
def test_circuit_open_no_recovery_time(self) -> None:
"""Test circuit open without recovery time."""
error = CircuitOpenError(provider="openai")
assert error.recovery_time is None
assert "recovery_time_seconds" not in error.details
class TestCostLimitExceededError:
"""Tests for CostLimitExceededError."""
def test_project_cost_limit(self) -> None:
"""Test project cost limit error."""
error = CostLimitExceededError(
entity_type="project",
entity_id="proj-123",
current_cost=150.0,
limit=100.0,
)
assert error.entity_type == "project"
assert error.entity_id == "proj-123"
assert error.current_cost == 150.0
assert error.limit == 100.0
assert error.code == ErrorCode.COST_LIMIT_EXCEEDED
assert "$150.00" in error.message
assert "$100.00" in error.message
def test_agent_cost_limit(self) -> None:
"""Test agent cost limit error."""
error = CostLimitExceededError(
entity_type="agent",
entity_id="agent-456",
current_cost=50.0,
limit=25.0,
)
assert error.entity_type == "agent"
assert error.details["entity_type"] == "agent"
class TestInvalidModelGroupError:
"""Tests for InvalidModelGroupError."""
def test_invalid_group_error(self) -> None:
"""Test invalid model group error."""
error = InvalidModelGroupError(
model_group="invalid_group",
available_groups=["reasoning", "code", "fast"],
)
assert error.model_group == "invalid_group"
assert error.available_groups == ["reasoning", "code", "fast"]
assert error.code == ErrorCode.INVALID_MODEL_GROUP
assert "invalid_group" in error.message
def test_invalid_group_no_available(self) -> None:
"""Test invalid group without available list."""
error = InvalidModelGroupError(model_group="unknown")
assert error.available_groups is None
assert "available_groups" not in error.details
class TestInvalidModelError:
"""Tests for InvalidModelError."""
def test_invalid_model_error(self) -> None:
"""Test invalid model error."""
error = InvalidModelError(
model="gpt-99",
reason="Model does not exist",
)
assert error.model == "gpt-99"
assert error.code == ErrorCode.INVALID_MODEL
assert "gpt-99" in error.message
assert "Model does not exist" in error.message
def test_invalid_model_no_reason(self) -> None:
"""Test invalid model without reason."""
error = InvalidModelError(model="unknown-model")
assert "reason" not in error.details
class TestModelNotAvailableError:
"""Tests for ModelNotAvailableError."""
def test_model_not_available(self) -> None:
"""Test model not available error."""
error = ModelNotAvailableError(
model="claude-opus-4",
provider="anthropic",
)
assert error.model == "claude-opus-4"
assert error.provider == "anthropic"
assert error.code == ErrorCode.MODEL_NOT_AVAILABLE
assert "not configured" in error.message
class TestAllProvidersFailedError:
"""Tests for AllProvidersFailedError."""
def test_all_providers_failed(self) -> None:
"""Test all providers failed error."""
errors = [
{"model": "claude-opus-4", "error": "Rate limited"},
{"model": "gpt-4.1", "error": "Timeout"},
]
error = AllProvidersFailedError(
model_group="reasoning",
attempted_models=["claude-opus-4", "gpt-4.1"],
errors=errors,
)
assert error.model_group == "reasoning"
assert error.attempted_models == ["claude-opus-4", "gpt-4.1"]
assert error.errors == errors
assert error.code == ErrorCode.ALL_PROVIDERS_FAILED
class TestStreamError:
"""Tests for StreamError."""
def test_stream_error(self) -> None:
"""Test stream error."""
cause = OSError("Connection reset")
error = StreamError(
message="Stream interrupted",
chunks_received=10,
cause=cause,
)
assert error.chunks_received == 10
assert error.cause is cause
assert error.code == ErrorCode.STREAM_ERROR
class TestTokenLimitExceededError:
"""Tests for TokenLimitExceededError."""
def test_token_limit_exceeded(self) -> None:
"""Test token limit exceeded error."""
error = TokenLimitExceededError(
model="claude-haiku",
token_count=10000,
limit=8192,
)
assert error.model == "claude-haiku"
assert error.token_count == 10000
assert error.limit == 8192
assert error.code == ErrorCode.TOKEN_LIMIT_EXCEEDED
class TestContextTooLongError:
"""Tests for ContextTooLongError."""
def test_context_too_long(self) -> None:
"""Test context too long error."""
error = ContextTooLongError(
model="gpt-4.1-mini",
context_length=150000,
max_context=100000,
)
assert error.model == "gpt-4.1-mini"
assert error.context_length == 150000
assert error.max_context == 100000
assert error.code == ErrorCode.CONTEXT_TOO_LONG
class TestConfigurationError:
"""Tests for ConfigurationError."""
def test_configuration_error(self) -> None:
"""Test configuration error."""
error = ConfigurationError(
message="Missing API key",
config_key="ANTHROPIC_API_KEY",
)
assert error.config_key == "ANTHROPIC_API_KEY"
assert error.code == ErrorCode.CONFIGURATION_ERROR
assert error.details["config_key"] == "ANTHROPIC_API_KEY"
def test_configuration_error_no_key(self) -> None:
"""Test configuration error without key."""
error = ConfigurationError(message="Invalid configuration")
assert error.config_key is None
assert "config_key" not in error.details


@@ -0,0 +1,407 @@
"""
Tests for failover module (circuit breaker).
"""
import asyncio
import time
import pytest
from config import Settings
from exceptions import CircuitOpenError
from failover import (
CircuitBreaker,
CircuitBreakerRegistry,
CircuitState,
CircuitStats,
get_circuit_registry,
reset_circuit_registry,
)
class TestCircuitState:
"""Tests for CircuitState enum."""
def test_circuit_states(self) -> None:
"""Test circuit state values."""
assert CircuitState.CLOSED.value == "closed"
assert CircuitState.OPEN.value == "open"
assert CircuitState.HALF_OPEN.value == "half_open"
class TestCircuitStats:
"""Tests for CircuitStats dataclass."""
def test_default_stats(self) -> None:
"""Test default stats values."""
stats = CircuitStats()
assert stats.failures == 0
assert stats.successes == 0
assert stats.last_failure_time is None
assert stats.last_success_time is None
assert stats.half_open_calls == 0
class TestCircuitBreaker:
"""Tests for CircuitBreaker class."""
def test_initial_state(self) -> None:
"""Test circuit breaker initial state."""
cb = CircuitBreaker(name="test", failure_threshold=5)
assert cb.name == "test"
assert cb.state == CircuitState.CLOSED
assert cb.failure_threshold == 5
assert cb.is_available() is True
def test_state_remains_closed_below_threshold(self) -> None:
"""Test circuit stays closed below failure threshold."""
cb = CircuitBreaker(name="test", failure_threshold=3)
# Record 2 failures (below threshold)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.CLOSED
assert cb.stats.failures == 2
assert cb.is_available() is True
def test_state_opens_at_threshold(self) -> None:
"""Test circuit opens at failure threshold."""
cb = CircuitBreaker(name="test", failure_threshold=3)
# Record 3 failures (at threshold)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_failure())
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
assert cb.is_available() is False
def test_success_resets_in_closed(self) -> None:
"""Test success in closed state records properly."""
cb = CircuitBreaker(name="test", failure_threshold=3)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_success())
assert cb.state == CircuitState.CLOSED
assert cb.stats.successes == 1
assert cb.stats.last_success_time is not None
def test_half_open_transition(self) -> None:
"""Test transition to half-open after recovery timeout."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=1, # 1 second
)
# Open the circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Wait for recovery timeout
time.sleep(1.1)
# State should transition to half-open
assert cb.state == CircuitState.HALF_OPEN
assert cb.is_available() is True
def test_half_open_success_closes(self) -> None:
"""Test success in half-open closes circuit."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=0, # Immediate recovery for testing
)
# Open and transition to half-open
asyncio.run(cb.record_failure())
time.sleep(0.1)
_ = cb.state # Trigger state check
assert cb.state == CircuitState.HALF_OPEN
# Success should close
asyncio.run(cb.record_success())
assert cb.state == CircuitState.CLOSED
def test_half_open_failure_reopens(self) -> None:
"""Test failure in half-open reopens circuit."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=0.05, # Small but non-zero for reliable timing
)
# Open and transition to half-open
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Wait for recovery timeout
time.sleep(0.1)
assert cb.state == CircuitState.HALF_OPEN
# Failure should reopen
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
def test_half_open_call_limit(self) -> None:
"""Test half-open call limit."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=0,
half_open_max_calls=2,
)
# Open and transition to half-open
asyncio.run(cb.record_failure())
time.sleep(0.1)
_ = cb.state
assert cb.is_available() is True
# Simulate calls in half-open
cb._stats.half_open_calls = 1
assert cb.is_available() is True
cb._stats.half_open_calls = 2
assert cb.is_available() is False
def test_time_until_recovery(self) -> None:
"""Test time until recovery calculation."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=60,
)
# Closed circuit has no recovery time
assert cb.time_until_recovery() is None
# Open circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Should have recovery time
remaining = cb.time_until_recovery()
assert remaining is not None
assert 0 <= remaining <= 60
def test_execute_success(self) -> None:
"""Test execute with successful function."""
cb = CircuitBreaker(name="test", failure_threshold=3)
async def success_func() -> str:
return "success"
result = asyncio.run(cb.execute(success_func))
assert result == "success"
assert cb.stats.successes == 1
def test_execute_failure(self) -> None:
"""Test execute with failing function."""
cb = CircuitBreaker(name="test", failure_threshold=3)
async def fail_func() -> None:
raise ValueError("Error")
with pytest.raises(ValueError):
asyncio.run(cb.execute(fail_func))
assert cb.stats.failures == 1
def test_execute_when_open(self) -> None:
"""Test execute raises when circuit is open."""
cb = CircuitBreaker(name="test", failure_threshold=1)
# Open the circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
async def success_func() -> str:
return "success"
with pytest.raises(CircuitOpenError) as exc_info:
asyncio.run(cb.execute(success_func))
assert exc_info.value.provider == "test"
def test_reset(self) -> None:
"""Test circuit reset."""
cb = CircuitBreaker(name="test", failure_threshold=1)
# Open the circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Reset
cb.reset()
assert cb.state == CircuitState.CLOSED
assert cb.stats.failures == 0
assert cb.stats.successes == 0
def test_to_dict(self) -> None:
"""Test converting circuit to dict."""
cb = CircuitBreaker(name="test", failure_threshold=3)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_success())
result = cb.to_dict()
assert result["name"] == "test"
assert result["state"] == "closed"
assert result["failures"] == 1
assert result["successes"] == 1
assert result["is_available"] is True
class TestCircuitBreakerRegistry:
"""Tests for CircuitBreakerRegistry class."""
def test_get_circuit_creates_new(self) -> None:
"""Test getting a new circuit."""
settings = Settings(circuit_failure_threshold=5)
registry = CircuitBreakerRegistry(settings=settings)
circuit = asyncio.run(registry.get_circuit("anthropic"))
assert circuit.name == "anthropic"
assert circuit.failure_threshold == 5
def test_get_circuit_returns_same(self) -> None:
"""Test getting same circuit twice."""
registry = CircuitBreakerRegistry()
circuit1 = asyncio.run(registry.get_circuit("openai"))
circuit2 = asyncio.run(registry.get_circuit("openai"))
assert circuit1 is circuit2
def test_get_circuit_sync(self) -> None:
"""Test sync circuit getter."""
registry = CircuitBreakerRegistry()
circuit = registry.get_circuit_sync("google")
assert circuit.name == "google"
def test_is_available(self) -> None:
"""Test checking if circuit is available."""
registry = CircuitBreakerRegistry()
assert asyncio.run(registry.is_available("test")) is True
# Open the circuit
circuit = asyncio.run(registry.get_circuit("test"))
for _ in range(5):
asyncio.run(circuit.record_failure())
assert asyncio.run(registry.is_available("test")) is False
def test_record_success(self) -> None:
"""Test recording success through registry."""
registry = CircuitBreakerRegistry()
asyncio.run(registry.record_success("test"))
circuit = asyncio.run(registry.get_circuit("test"))
assert circuit.stats.successes == 1
def test_record_failure(self) -> None:
"""Test recording failure through registry."""
registry = CircuitBreakerRegistry()
asyncio.run(registry.record_failure("test"))
circuit = asyncio.run(registry.get_circuit("test"))
assert circuit.stats.failures == 1
def test_reset(self) -> None:
"""Test resetting a specific circuit."""
registry = CircuitBreakerRegistry()
# Create and fail a circuit
asyncio.run(registry.record_failure("test"))
asyncio.run(registry.reset("test"))
circuit = asyncio.run(registry.get_circuit("test"))
assert circuit.stats.failures == 0
def test_reset_all(self) -> None:
"""Test resetting all circuits."""
registry = CircuitBreakerRegistry()
# Create multiple circuits with failures
asyncio.run(registry.record_failure("circuit1"))
asyncio.run(registry.record_failure("circuit2"))
asyncio.run(registry.reset_all())
circuit1 = asyncio.run(registry.get_circuit("circuit1"))
circuit2 = asyncio.run(registry.get_circuit("circuit2"))
assert circuit1.stats.failures == 0
assert circuit2.stats.failures == 0
def test_get_all_states(self) -> None:
"""Test getting all circuit states."""
registry = CircuitBreakerRegistry()
asyncio.run(registry.get_circuit("circuit1"))
asyncio.run(registry.get_circuit("circuit2"))
states = registry.get_all_states()
assert "circuit1" in states
assert "circuit2" in states
assert states["circuit1"]["state"] == "closed"
def test_get_open_circuits(self) -> None:
"""Test getting open circuits."""
settings = Settings(circuit_failure_threshold=1)
registry = CircuitBreakerRegistry(settings=settings)
asyncio.run(registry.get_circuit("healthy"))
asyncio.run(registry.record_failure("failing"))
open_circuits = registry.get_open_circuits()
assert "failing" in open_circuits
assert "healthy" not in open_circuits
def test_get_available_circuits(self) -> None:
"""Test getting available circuits."""
settings = Settings(circuit_failure_threshold=1)
registry = CircuitBreakerRegistry(settings=settings)
asyncio.run(registry.get_circuit("healthy"))
asyncio.run(registry.record_failure("failing"))
available = registry.get_available_circuits()
assert "healthy" in available
assert "failing" not in available
class TestGlobalRegistry:
"""Tests for global registry functions."""
def test_get_circuit_registry(self) -> None:
"""Test getting global registry."""
reset_circuit_registry()
registry = get_circuit_registry()
assert isinstance(registry, CircuitBreakerRegistry)
def test_get_circuit_registry_singleton(self) -> None:
"""Test registry is singleton."""
reset_circuit_registry()
registry1 = get_circuit_registry()
registry2 = get_circuit_registry()
assert registry1 is registry2
def test_reset_circuit_registry(self) -> None:
"""Test resetting global registry."""
reset_circuit_registry()
registry1 = get_circuit_registry()
reset_circuit_registry()
registry2 = get_circuit_registry()
assert registry1 is not registry2


@@ -0,0 +1,408 @@
"""
Tests for models module.
"""
from datetime import UTC, datetime
import pytest
from models import (
AGENT_TYPE_MODEL_PREFERENCES,
MODEL_CONFIGS,
MODEL_GROUPS,
ChatMessage,
CompletionRequest,
CompletionResponse,
CostRecord,
EmbeddingRequest,
ModelConfig,
ModelGroup,
ModelGroupConfig,
ModelGroupInfo,
ModelInfo,
Provider,
StreamChunk,
UsageReport,
UsageStats,
)
class TestModelGroup:
"""Tests for ModelGroup enum."""
def test_model_group_values(self) -> None:
"""Test model group enum values."""
assert ModelGroup.REASONING.value == "reasoning"
assert ModelGroup.CODE.value == "code"
assert ModelGroup.FAST.value == "fast"
assert ModelGroup.VISION.value == "vision"
assert ModelGroup.EMBEDDING.value == "embedding"
assert ModelGroup.COST_OPTIMIZED.value == "cost_optimized"
assert ModelGroup.SELF_HOSTED.value == "self_hosted"
def test_model_group_from_string(self) -> None:
"""Test creating ModelGroup from string."""
assert ModelGroup("reasoning") == ModelGroup.REASONING
assert ModelGroup("code") == ModelGroup.CODE
assert ModelGroup("fast") == ModelGroup.FAST
def test_model_group_invalid(self) -> None:
"""Test invalid model group value."""
with pytest.raises(ValueError):
ModelGroup("invalid_group")
class TestProvider:
"""Tests for Provider enum."""
def test_provider_values(self) -> None:
"""Test provider enum values."""
assert Provider.ANTHROPIC.value == "anthropic"
assert Provider.OPENAI.value == "openai"
assert Provider.GOOGLE.value == "google"
assert Provider.ALIBABA.value == "alibaba"
assert Provider.DEEPSEEK.value == "deepseek"
class TestModelConfig:
"""Tests for ModelConfig dataclass."""
def test_model_config_creation(self) -> None:
"""Test creating a ModelConfig."""
config = ModelConfig(
name="test-model",
litellm_name="provider/test-model",
provider=Provider.ANTHROPIC,
cost_per_1m_input=10.0,
cost_per_1m_output=30.0,
context_window=100000,
max_output_tokens=4096,
supports_vision=True,
)
assert config.name == "test-model"
assert config.provider == Provider.ANTHROPIC
assert config.cost_per_1m_input == 10.0
assert config.supports_vision is True
assert config.supports_streaming is True # default
def test_model_configs_exist(self) -> None:
"""Test that model configs are defined."""
assert len(MODEL_CONFIGS) > 0
assert "claude-opus-4" in MODEL_CONFIGS
assert "gpt-4.1" in MODEL_CONFIGS
assert "gemini-2.5-pro" in MODEL_CONFIGS
class TestModelGroupConfig:
"""Tests for ModelGroupConfig dataclass."""
def test_model_group_config_creation(self) -> None:
"""Test creating a ModelGroupConfig."""
config = ModelGroupConfig(
primary="model-a",
fallbacks=["model-b", "model-c"],
description="Test group",
)
assert config.primary == "model-a"
assert config.fallbacks == ["model-b", "model-c"]
assert config.description == "Test group"
def test_get_all_models(self) -> None:
"""Test getting all models in order."""
config = ModelGroupConfig(
primary="model-a",
fallbacks=["model-b", "model-c"],
description="Test group",
)
models = config.get_all_models()
assert models == ["model-a", "model-b", "model-c"]
def test_model_groups_exist(self) -> None:
"""Test that model groups are defined."""
assert len(MODEL_GROUPS) > 0
assert ModelGroup.REASONING in MODEL_GROUPS
assert ModelGroup.CODE in MODEL_GROUPS
assert ModelGroup.FAST in MODEL_GROUPS
class TestAgentTypePreferences:
"""Tests for agent type model preferences."""
def test_agent_preferences_exist(self) -> None:
"""Test that agent preferences are defined."""
assert len(AGENT_TYPE_MODEL_PREFERENCES) > 0
assert "product_owner" in AGENT_TYPE_MODEL_PREFERENCES
assert "software_engineer" in AGENT_TYPE_MODEL_PREFERENCES
def test_agent_preference_values(self) -> None:
"""Test agent preference values."""
assert AGENT_TYPE_MODEL_PREFERENCES["product_owner"] == ModelGroup.REASONING
assert AGENT_TYPE_MODEL_PREFERENCES["software_engineer"] == ModelGroup.CODE
assert AGENT_TYPE_MODEL_PREFERENCES["devops_engineer"] == ModelGroup.FAST
class TestChatMessage:
"""Tests for ChatMessage model."""
def test_chat_message_creation(self) -> None:
"""Test creating a ChatMessage."""
msg = ChatMessage(role="user", content="Hello")
assert msg.role == "user"
assert msg.content == "Hello"
assert msg.name is None
assert msg.tool_calls is None
def test_chat_message_with_optional(self) -> None:
"""Test ChatMessage with optional fields."""
msg = ChatMessage(
role="assistant",
content="Response",
name="assistant_1",
tool_calls=[{"id": "call_1", "function": {"name": "test"}}],
)
assert msg.name == "assistant_1"
assert msg.tool_calls is not None
def test_chat_message_list_content(self) -> None:
"""Test ChatMessage with list content (for images)."""
msg = ChatMessage(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/img.jpg"}},
],
)
assert isinstance(msg.content, list)
class TestCompletionRequest:
"""Tests for CompletionRequest model."""
def test_completion_request_minimal(self) -> None:
"""Test minimal CompletionRequest."""
req = CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
)
assert req.project_id == "proj-123"
assert req.agent_id == "agent-456"
assert len(req.messages) == 1
assert req.model_group == ModelGroup.REASONING # default
assert req.max_tokens == 4096 # default
assert req.temperature == 0.7 # default
def test_completion_request_full(self) -> None:
"""Test full CompletionRequest."""
req = CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
model_group=ModelGroup.CODE,
model_override="claude-sonnet-4",
max_tokens=8192,
temperature=0.5,
stream=True,
session_id="session-789",
metadata={"key": "value"},
)
assert req.model_group == ModelGroup.CODE
assert req.model_override == "claude-sonnet-4"
assert req.max_tokens == 8192
assert req.stream is True
def test_completion_request_validation(self) -> None:
"""Test CompletionRequest validation."""
with pytest.raises(ValueError):
CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
max_tokens=0, # Invalid
)
with pytest.raises(ValueError):
CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
temperature=-0.1, # Invalid
)
class TestUsageStats:
"""Tests for UsageStats model."""
def test_usage_stats_default(self) -> None:
"""Test default UsageStats."""
stats = UsageStats()
assert stats.prompt_tokens == 0
assert stats.completion_tokens == 0
assert stats.total_tokens == 0
assert stats.cost_usd == 0.0
def test_usage_stats_custom(self) -> None:
"""Test custom UsageStats."""
stats = UsageStats(
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_usd=0.001,
)
assert stats.prompt_tokens == 100
assert stats.total_tokens == 150
def test_usage_stats_from_response(self) -> None:
"""Test creating UsageStats from response."""
config = MODEL_CONFIGS["claude-opus-4"]
stats = UsageStats.from_response(
prompt_tokens=1000,
completion_tokens=500,
model_config=config,
)
assert stats.prompt_tokens == 1000
assert stats.completion_tokens == 500
assert stats.total_tokens == 1500
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
assert stats.cost_usd == pytest.approx(0.0525, rel=0.01)
class TestCompletionResponse:
"""Tests for CompletionResponse model."""
def test_completion_response_creation(self) -> None:
"""Test creating a CompletionResponse."""
response = CompletionResponse(
id="resp-123",
model="claude-opus-4",
provider="anthropic",
content="Hello, world!",
)
assert response.id == "resp-123"
assert response.model == "claude-opus-4"
assert response.provider == "anthropic"
assert response.content == "Hello, world!"
assert response.finish_reason == "stop"
class TestStreamChunk:
"""Tests for StreamChunk model."""
def test_stream_chunk_creation(self) -> None:
"""Test creating a StreamChunk."""
chunk = StreamChunk(id="chunk-1", delta="Hello")
assert chunk.id == "chunk-1"
assert chunk.delta == "Hello"
assert chunk.finish_reason is None
def test_stream_chunk_final(self) -> None:
"""Test final StreamChunk."""
chunk = StreamChunk(
id="chunk-last",
delta="",
finish_reason="stop",
usage=UsageStats(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
assert chunk.finish_reason == "stop"
assert chunk.usage is not None
class TestEmbeddingRequest:
"""Tests for EmbeddingRequest model."""
def test_embedding_request_creation(self) -> None:
"""Test creating an EmbeddingRequest."""
req = EmbeddingRequest(
project_id="proj-123",
agent_id="agent-456",
texts=["Hello", "World"],
)
assert req.project_id == "proj-123"
assert len(req.texts) == 2
assert req.model == "text-embedding-3-large" # default
def test_embedding_request_validation(self) -> None:
"""Test EmbeddingRequest validation."""
with pytest.raises(ValueError):
EmbeddingRequest(
project_id="proj-123",
agent_id="agent-456",
texts=[], # Invalid - must have at least 1
)
class TestCostRecord:
"""Tests for CostRecord dataclass."""
def test_cost_record_creation(self) -> None:
"""Test creating a CostRecord."""
record = CostRecord(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
assert record.project_id == "proj-123"
assert record.cost_usd == 0.01
assert record.timestamp is not None
class TestUsageReport:
"""Tests for UsageReport model."""
def test_usage_report_creation(self) -> None:
"""Test creating a UsageReport."""
now = datetime.now(UTC)
report = UsageReport(
entity_id="proj-123",
entity_type="project",
period="day",
period_start=now,
period_end=now,
)
assert report.entity_id == "proj-123"
assert report.entity_type == "project"
assert report.total_requests == 0
assert report.total_cost_usd == 0.0
class TestModelInfo:
"""Tests for ModelInfo model."""
def test_model_info_from_config(self) -> None:
"""Test creating ModelInfo from ModelConfig."""
config = MODEL_CONFIGS["claude-opus-4"]
info = ModelInfo.from_config(config, available=True)
assert info.name == "claude-opus-4"
assert info.provider == "anthropic"
assert info.available is True
assert info.supports_vision is True
class TestModelGroupInfo:
"""Tests for ModelGroupInfo model."""
def test_model_group_info_creation(self) -> None:
"""Test creating ModelGroupInfo."""
info = ModelGroupInfo(
name="reasoning",
description="Complex analysis",
primary_model="claude-opus-4",
fallback_models=["gpt-4.1"],
)
assert info.name == "reasoning"
assert len(info.fallback_models) == 1


@@ -0,0 +1,308 @@
"""
Tests for providers module.
"""
import os
from unittest.mock import patch
import pytest
from config import Settings
from models import MODEL_CONFIGS, ModelGroup, Provider
from providers import (
LLMProvider,
build_fallback_config,
build_model_list,
configure_litellm,
get_available_model_groups,
get_available_models,
get_provider,
reset_provider,
)
@pytest.fixture
def full_settings() -> Settings:
"""Settings with all providers configured."""
return Settings(
anthropic_api_key="test-anthropic-key",
openai_api_key="test-openai-key",
google_api_key="test-google-key",
alibaba_api_key="test-alibaba-key",
deepseek_api_key="test-deepseek-key",
litellm_timeout=60,
litellm_cache_enabled=False,
)
@pytest.fixture
def partial_settings() -> Settings:
"""Settings with only some providers configured."""
return Settings(
anthropic_api_key="test-anthropic-key",
openai_api_key=None,
google_api_key=None,
alibaba_api_key=None,
deepseek_api_key=None,
)
@pytest.fixture
def empty_settings() -> Settings:
"""Settings with no providers configured."""
return Settings(
anthropic_api_key=None,
openai_api_key=None,
google_api_key=None,
alibaba_api_key=None,
deepseek_api_key=None,
)
class TestConfigureLiteLLM:
"""Tests for configure_litellm function."""
def test_sets_api_keys(self, full_settings: Settings) -> None:
"""Test that API keys are set in environment."""
with patch.dict(os.environ, {}, clear=True):
configure_litellm(full_settings)
assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
assert os.environ.get("OPENAI_API_KEY") == "test-openai-key"
assert os.environ.get("GEMINI_API_KEY") == "test-google-key"
def test_skips_none_keys(self, partial_settings: Settings) -> None:
"""Test that None keys are not set."""
with patch.dict(os.environ, {}, clear=True):
configure_litellm(partial_settings)
assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
assert "OPENAI_API_KEY" not in os.environ
class TestBuildModelList:
"""Tests for build_model_list function."""
def test_build_with_all_providers(self, full_settings: Settings) -> None:
"""Test building model list with all providers."""
model_list = build_model_list(full_settings)
assert len(model_list) > 0
# Check structure
for entry in model_list:
assert "model_name" in entry
assert "litellm_params" in entry
assert "model" in entry["litellm_params"]
assert "timeout" in entry["litellm_params"]
def test_build_with_partial_providers(self, partial_settings: Settings) -> None:
"""Test building model list with partial providers."""
model_list = build_model_list(partial_settings)
# Should only include Anthropic models
providers = set()
for entry in model_list:
model_name = entry["model_name"]
config = MODEL_CONFIGS.get(model_name)
if config:
providers.add(config.provider)
assert Provider.ANTHROPIC in providers
assert Provider.OPENAI not in providers
def test_build_with_no_providers(self, empty_settings: Settings) -> None:
"""Test building model list with no providers."""
model_list = build_model_list(empty_settings)
assert len(model_list) == 0
def test_build_includes_timeout(self, full_settings: Settings) -> None:
"""Test that model entries include timeout."""
model_list = build_model_list(full_settings)
for entry in model_list:
assert entry["litellm_params"]["timeout"] == 60
class TestBuildFallbackConfig:
"""Tests for build_fallback_config function."""
def test_build_fallbacks_full(self, full_settings: Settings) -> None:
"""Test building fallback config with all providers."""
fallbacks = build_fallback_config(full_settings)
assert len(fallbacks) > 0
# Primary models should have fallbacks
for _primary, chain in fallbacks.items():
assert isinstance(chain, list)
assert len(chain) > 0
def test_build_fallbacks_partial(self, partial_settings: Settings) -> None:
"""Test building fallback config with partial providers."""
fallbacks = build_fallback_config(partial_settings)
# With only Anthropic, there should be no fallbacks
# (fallbacks require at least 2 available models)
for primary, chain in fallbacks.items():
# All models in chain should be from Anthropic
for model in [primary] + chain:
config = MODEL_CONFIGS.get(model)
if config:
assert config.provider == Provider.ANTHROPIC
class TestGetAvailableModels:
"""Tests for get_available_models function."""
def test_get_available_full(self, full_settings: Settings) -> None:
"""Test getting available models with all providers."""
models = get_available_models(full_settings)
assert len(models) > 0
assert "claude-opus-4" in models
assert "gpt-4.1" in models
def test_get_available_partial(self, partial_settings: Settings) -> None:
"""Test getting available models with partial providers."""
models = get_available_models(partial_settings)
assert "claude-opus-4" in models
assert "gpt-4.1" not in models
def test_get_available_empty(self, empty_settings: Settings) -> None:
"""Test getting available models with no providers."""
models = get_available_models(empty_settings)
assert len(models) == 0
class TestGetAvailableModelGroups:
"""Tests for get_available_model_groups function."""
def test_get_groups_full(self, full_settings: Settings) -> None:
"""Test getting groups with all providers."""
groups = get_available_model_groups(full_settings)
assert len(groups) == len(ModelGroup)
assert ModelGroup.REASONING in groups
assert len(groups[ModelGroup.REASONING]) > 0
def test_get_groups_partial(self, partial_settings: Settings) -> None:
"""Test getting groups with partial providers."""
groups = get_available_model_groups(partial_settings)
# Only Anthropic models should be available
for _group, models in groups.items():
for model in models:
config = MODEL_CONFIGS.get(model)
if config:
assert config.provider == Provider.ANTHROPIC
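# Shape sketch, not a test: get_available_model_groups appears to map every
# ModelGroup to the model names usable with the configured keys, e.g. with all
# providers set (values illustrative, taken from other assertions in this file):
#     {ModelGroup.REASONING: ["claude-opus-4", "gpt-4.1", ...],
#      ModelGroup.CODE: ["claude-sonnet-4", ...], ...}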
class TestLLMProvider:
"""Tests for LLMProvider class."""
def test_initialization(self, full_settings: Settings) -> None:
"""Test provider initialization."""
provider = LLMProvider(settings=full_settings)
assert provider._initialized is False
assert provider._router is None
def test_initialize(self, full_settings: Settings) -> None:
"""Test provider initialize."""
with patch("providers.Router") as mock_router:
provider = LLMProvider(settings=full_settings)
provider.initialize()
assert provider._initialized is True
mock_router.assert_called_once()
def test_initialize_idempotent(self, full_settings: Settings) -> None:
"""Test that initialize is idempotent."""
with patch("providers.Router") as mock_router:
provider = LLMProvider(settings=full_settings)
provider.initialize()
provider.initialize()
# Should only be called once
assert mock_router.call_count == 1
def test_initialize_no_providers(self, empty_settings: Settings) -> None:
"""Test initialization with no providers."""
provider = LLMProvider(settings=empty_settings)
provider.initialize()
assert provider._initialized is True
assert provider._router is None
def test_router_property(self, full_settings: Settings) -> None:
"""Test router property triggers initialization."""
with patch("providers.Router"):
provider = LLMProvider(settings=full_settings)
_ = provider.router
assert provider._initialized is True
def test_is_available(self, full_settings: Settings) -> None:
"""Test is_available property."""
with patch("providers.Router"):
provider = LLMProvider(settings=full_settings)
assert provider.is_available is True
def test_is_not_available(self, empty_settings: Settings) -> None:
"""Test is_available when no providers."""
provider = LLMProvider(settings=empty_settings)
assert provider.is_available is False
def test_get_model_config(self, full_settings: Settings) -> None:
"""Test getting model config."""
provider = LLMProvider(settings=full_settings)
config = provider.get_model_config("claude-opus-4")
assert config is not None
assert config.name == "claude-opus-4"
assert provider.get_model_config("nonexistent") is None
def test_get_available_models(self, full_settings: Settings) -> None:
"""Test getting available models."""
provider = LLMProvider(settings=full_settings)
models = provider.get_available_models()
assert "claude-opus-4" in models
assert "gpt-4.1" in models
def test_is_model_available(self, full_settings: Settings) -> None:
"""Test checking model availability."""
provider = LLMProvider(settings=full_settings)
assert provider.is_model_available("claude-opus-4") is True
assert provider.is_model_available("nonexistent") is False
class TestGlobalProvider:
"""Tests for global provider functions."""
def test_get_provider(self) -> None:
"""Test getting global provider."""
reset_provider()
provider = get_provider()
assert isinstance(provider, LLMProvider)
def test_get_provider_singleton(self) -> None:
"""Test provider is singleton."""
reset_provider()
provider1 = get_provider()
provider2 = get_provider()
assert provider1 is provider2
def test_reset_provider(self) -> None:
"""Test resetting global provider."""
reset_provider()
provider1 = get_provider()
reset_provider()
provider2 = get_provider()
assert provider1 is not provider2
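# Usage sketch, not collected by pytest: the accessor pattern the tests above
# rely on, a lazily initialised, module-level LLMProvider singleton.
def _example_provider_usage() -> None:
    """Hypothetical caller-side flow using only the accessors tested above."""
    reset_provider()               # drop any cached instance (test isolation)
    provider = get_provider()      # returns the shared LLMProvider
    if provider.is_available:      # true when at least one API key is configured
        _ = provider.router        # first access triggers initialize()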

View File

@@ -0,0 +1,243 @@
"""
Tests for routing module.
"""
import asyncio
import pytest
from config import Settings
from exceptions import (
AllProvidersFailedError,
InvalidModelError,
InvalidModelGroupError,
ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, reset_circuit_registry
from models import ModelGroup
from providers import reset_provider
from routing import ModelRouter, get_model_router, reset_model_router
@pytest.fixture
def router_settings() -> Settings:
"""Settings for routing tests."""
return Settings(
anthropic_api_key="test-key",
openai_api_key="test-key",
google_api_key="test-key",
)
@pytest.fixture
def router(router_settings: Settings) -> ModelRouter:
"""Create model router for testing."""
reset_circuit_registry()
reset_model_router()
reset_provider()
registry = CircuitBreakerRegistry(settings=router_settings)
return ModelRouter(settings=router_settings, circuit_registry=registry)
class TestModelRouter:
"""Tests for ModelRouter class."""
def test_parse_model_group_valid(self, router: ModelRouter) -> None:
"""Test parsing valid model groups."""
assert router.parse_model_group("reasoning") == ModelGroup.REASONING
assert router.parse_model_group("code") == ModelGroup.CODE
assert router.parse_model_group("fast") == ModelGroup.FAST
assert router.parse_model_group("REASONING") == ModelGroup.REASONING
def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
"""Test parsing model group aliases."""
assert router.parse_model_group("high-reasoning") == ModelGroup.REASONING
assert router.parse_model_group("high_reasoning") == ModelGroup.REASONING
assert router.parse_model_group("code-generation") == ModelGroup.CODE
assert router.parse_model_group("fast-response") == ModelGroup.FAST
def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
"""Test parsing invalid model group."""
with pytest.raises(InvalidModelGroupError) as exc_info:
router.parse_model_group("invalid_group")
assert exc_info.value.model_group == "invalid_group"
assert exc_info.value.available_groups is not None
def test_get_model_config_valid(self, router: ModelRouter) -> None:
"""Test getting valid model config."""
config = router.get_model_config("claude-opus-4")
assert config.name == "claude-opus-4"
assert config.provider.value == "anthropic"
def test_get_model_config_invalid(self, router: ModelRouter) -> None:
"""Test getting invalid model config."""
with pytest.raises(InvalidModelError) as exc_info:
router.get_model_config("nonexistent-model")
assert exc_info.value.model == "nonexistent-model"
def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
"""Test getting preferred group for agent types."""
assert router.get_preferred_group_for_agent("product_owner") == ModelGroup.REASONING
assert router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
assert router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
"""Test getting preferred group for unknown agent."""
# Should default to REASONING
assert router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
def test_select_model_by_group(self, router: ModelRouter) -> None:
"""Test selecting model by group."""
model_name, config = asyncio.run(
router.select_model(model_group=ModelGroup.REASONING)
)
assert model_name == "claude-opus-4"
assert config.provider.value == "anthropic"
def test_select_model_by_group_string(self, router: ModelRouter) -> None:
"""Test selecting model by group string."""
model_name, config = asyncio.run(
router.select_model(model_group="code")
)
assert model_name == "claude-sonnet-4"
def test_select_model_with_override(self, router: ModelRouter) -> None:
"""Test selecting specific model override."""
model_name, config = asyncio.run(
router.select_model(
model_group="reasoning",
model_override="gpt-4.1",
)
)
assert model_name == "gpt-4.1"
assert config.provider.value == "openai"
def test_select_model_override_invalid(self, router: ModelRouter) -> None:
"""Test selecting invalid model override."""
with pytest.raises(InvalidModelError):
asyncio.run(
router.select_model(
model_group="reasoning",
model_override="nonexistent-model",
)
)
def test_select_model_override_unavailable(self, router: ModelRouter) -> None: # noqa: ARG002
"""Test selecting unavailable model override."""
# Create router without Alibaba key
settings = Settings(
anthropic_api_key="test-key",
alibaba_api_key=None,
)
registry = CircuitBreakerRegistry(settings=settings)
limited_router = ModelRouter(settings=settings, circuit_registry=registry)
with pytest.raises(ModelNotAvailableError):
asyncio.run(
limited_router.select_model(
model_group="reasoning",
model_override="qwen-max",
)
)
def test_select_model_fallback_on_circuit_open(
self,
router: ModelRouter,
) -> None:
"""Test fallback when primary circuit is open."""
# Open circuit for anthropic
circuit = router._circuit_registry.get_circuit_sync("anthropic")
for _ in range(5):
asyncio.run(circuit.record_failure())
# Should fall back to OpenAI
model_name, config = asyncio.run(
router.select_model(model_group=ModelGroup.REASONING)
)
assert model_name == "gpt-4.1"
assert config.provider.value == "openai"
def test_select_model_all_unavailable(self) -> None:
"""Test when all providers are unavailable."""
settings = Settings(
anthropic_api_key=None,
openai_api_key=None,
google_api_key=None,
)
registry = CircuitBreakerRegistry(settings=settings)
limited_router = ModelRouter(settings=settings, circuit_registry=registry)
with pytest.raises(AllProvidersFailedError) as exc_info:
asyncio.run(
limited_router.select_model(model_group=ModelGroup.REASONING)
)
assert exc_info.value.model_group == "reasoning"
assert len(exc_info.value.attempted_models) > 0
def test_get_available_models_for_group(self, router: ModelRouter) -> None:
"""Test getting available models for a group."""
models = asyncio.run(
router.get_available_models_for_group(ModelGroup.REASONING)
)
assert len(models) > 0
# Should be (name, config, available) tuples
for name, config, _available in models:
assert isinstance(name, str)
assert config is not None
def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
"""Test getting available models with string group."""
models = asyncio.run(
router.get_available_models_for_group("code")
)
assert len(models) > 0
def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
"""Test getting models for invalid group."""
with pytest.raises(InvalidModelGroupError):
asyncio.run(
router.get_available_models_for_group("invalid")
)
def test_get_all_model_groups(self, router: ModelRouter) -> None:
"""Test getting all model groups info."""
groups = router.get_all_model_groups()
assert len(groups) == len(ModelGroup)
assert "reasoning" in groups
assert "code" in groups
assert groups["reasoning"]["primary"] == "claude-opus-4"
class TestGlobalRouter:
"""Tests for global router functions."""
def test_get_model_router(self) -> None:
"""Test getting global router."""
reset_model_router()
router = get_model_router()
assert isinstance(router, ModelRouter)
def test_get_model_router_singleton(self) -> None:
"""Test router is singleton."""
reset_model_router()
router1 = get_model_router()
router2 = get_model_router()
assert router1 is router2
def test_reset_model_router(self) -> None:
"""Test resetting global router."""
reset_model_router()
router1 = get_model_router()
reset_model_router()
router2 = get_model_router()
assert router1 is not router2
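# Usage sketch, not collected by pytest: the selection flow implied by the
# tests above: resolve a group name or alias, then pick the first available
# model, falling back across providers when a circuit is open.
def _example_select_code_model(router: ModelRouter) -> str:
    """Hypothetical helper returning the model selected for the 'code' group."""
    group = router.parse_model_group("code")  # -> ModelGroup.CODE
    model_name, _config = asyncio.run(router.select_model(model_group=group))
    return model_name  # "claude-sonnet-4" when Anthropic is available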

View File

@@ -0,0 +1,412 @@
"""
Tests for server module.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from config import Settings
from models import ModelGroup
@pytest.fixture
def test_settings() -> Settings:
"""Test settings with mock API keys."""
return Settings(
anthropic_api_key="test-anthropic-key",
openai_api_key="test-openai-key",
google_api_key="test-google-key",
cost_tracking_enabled=False, # Disable for most tests
litellm_cache_enabled=False,
)
@pytest.fixture
def test_client(test_settings: Settings) -> TestClient:
"""Create test client with mocked dependencies."""
with (
patch("server.get_settings", return_value=test_settings),
patch("server.get_provider") as mock_provider,
):
mock_provider.return_value = MagicMock()
mock_provider.return_value.is_available = True
mock_provider.return_value.router = MagicMock()
mock_provider.return_value.get_available_models.return_value = {}
from server import app
return TestClient(app)
class TestHealthEndpoint:
"""Tests for health check endpoint."""
def test_health_check(self, test_client: TestClient) -> None:
"""Test health check returns healthy status."""
with patch("server.get_settings") as mock_settings:
mock_settings.return_value = Settings(
anthropic_api_key="test-key",
)
with patch("server.get_provider") as mock_provider:
mock_provider.return_value = MagicMock()
mock_provider.return_value.is_available = True
response = test_client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["service"] == "llm-gateway"
class TestToolDiscoveryEndpoint:
"""Tests for tool discovery endpoint."""
def test_list_tools(self, test_client: TestClient) -> None:
"""Test listing available tools."""
response = test_client.get("/mcp/tools")
assert response.status_code == 200
data = response.json()
assert "tools" in data
assert len(data["tools"]) == 4 # 4 tools defined
tool_names = [t["name"] for t in data["tools"]]
assert "chat_completion" in tool_names
assert "list_models" in tool_names
assert "get_usage" in tool_names
assert "count_tokens" in tool_names
def test_tool_has_schema(self, test_client: TestClient) -> None:
"""Test that tools have input schemas."""
response = test_client.get("/mcp/tools")
data = response.json()
for tool in data["tools"]:
assert "inputSchema" in tool
assert "type" in tool["inputSchema"]
assert tool["inputSchema"]["type"] == "object"
class TestJSONRPCEndpoint:
"""Tests for JSON-RPC endpoint."""
def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
"""Test invalid JSON-RPC version."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "1.0", # Invalid
"method": "tools/list",
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "error" in data
assert data["error"]["code"] == -32600
def test_tools_list(self, test_client: TestClient) -> None:
"""Test tools/list method."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/list",
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
assert "tools" in data["result"]
def test_unknown_method(self, test_client: TestClient) -> None:
"""Test unknown method."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "unknown/method",
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "error" in data
assert data["error"]["code"] == -32601
def test_unknown_tool(self, test_client: TestClient) -> None:
"""Test unknown tool."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "unknown_tool",
"arguments": {},
},
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "error" in data
assert "Unknown tool" in data["error"]["message"]
class TestCountTokensTool:
"""Tests for count_tokens tool."""
def test_count_tokens(self, test_client: TestClient) -> None:
"""Test counting tokens."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "count_tokens",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"text": "Hello, world!",
},
},
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
def test_count_tokens_with_model(self, test_client: TestClient) -> None:
"""Test counting tokens with specific model."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "count_tokens",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"text": "Hello, world!",
"model": "gpt-4",
},
},
"id": 1,
},
)
assert response.status_code == 200
class TestListModelsTool:
"""Tests for list_models tool."""
def test_list_all_models(self, test_client: TestClient) -> None:
"""Test listing all models."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.get_available_models_for_group = AsyncMock(
return_value=[]
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "list_models",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
},
},
"id": 1,
},
)
assert response.status_code == 200
def test_list_models_by_group(self, test_client: TestClient) -> None:
"""Test listing models by group."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.parse_model_group.return_value = ModelGroup.REASONING
mock_router.return_value.get_available_models_for_group = AsyncMock(
return_value=[]
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "list_models",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"model_group": "reasoning",
},
},
"id": 1,
},
)
assert response.status_code == 200
class TestGetUsageTool:
"""Tests for get_usage tool."""
def test_get_usage(self, test_client: TestClient) -> None:
"""Test getting usage."""
with patch("server.get_cost_tracker") as mock_tracker:
mock_report = MagicMock()
mock_report.total_requests = 10
mock_report.total_tokens = 1000
mock_report.total_cost_usd = 0.50
mock_report.by_model = {}
mock_report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
mock_report.period_end.isoformat.return_value = "2024-01-02T00:00:00"
mock_tracker.return_value = MagicMock()
mock_tracker.return_value.get_project_usage = AsyncMock(
return_value=mock_report
)
mock_tracker.return_value.get_agent_usage = AsyncMock(
return_value=mock_report
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "get_usage",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"period": "day",
},
},
"id": 1,
},
)
assert response.status_code == 200
class TestChatCompletionTool:
"""Tests for chat_completion tool."""
def test_chat_completion_streaming_not_supported(
self,
test_client: TestClient,
) -> None:
"""Test that streaming returns info message."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.select_model = AsyncMock(
return_value=("claude-opus-4", MagicMock())
)
with patch("server.get_cost_tracker") as mock_tracker:
mock_tracker.return_value = MagicMock()
mock_tracker.return_value.check_budget = AsyncMock(
return_value=(True, 0.0, 100.0)
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "chat_completion",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": True,
},
},
"id": 1,
},
)
assert response.status_code == 200
def test_chat_completion_success(self, test_client: TestClient) -> None:
"""Test successful chat completion."""
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello, world!"
mock_response.choices[0].finish_reason = "stop"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
with patch("server.get_model_router") as mock_router:
mock_model_config = MagicMock()
mock_model_config.provider.value = "anthropic"
mock_router.return_value = MagicMock()
mock_router.return_value.select_model = AsyncMock(
return_value=("claude-opus-4", mock_model_config)
)
with patch("server.get_cost_tracker") as mock_tracker:
mock_tracker.return_value = MagicMock()
mock_tracker.return_value.check_budget = AsyncMock(
return_value=(True, 0.0, 100.0)
)
mock_tracker.return_value.record_usage = AsyncMock()
with patch("server.get_provider") as mock_prov:
mock_prov.return_value = MagicMock()
mock_prov.return_value.router = MagicMock()
mock_prov.return_value.router.acompletion = AsyncMock(
return_value=mock_response
)
with patch("server.get_circuit_registry") as mock_reg:
mock_circuit = MagicMock()
mock_circuit.record_success = AsyncMock()
mock_reg.return_value = MagicMock()
mock_reg.return_value.get_circuit_sync.return_value = mock_circuit
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "chat_completion",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"messages": [
{"role": "user", "content": "Hello"}
],
},
},
"id": 1,
},
)
assert response.status_code == 200

View File

@@ -0,0 +1,312 @@
"""
Tests for streaming module.
"""
import asyncio
import json
from unittest.mock import MagicMock
import pytest
from models import StreamChunk, UsageStats
from streaming import (
StreamAccumulator,
StreamBuffer,
format_sse_chunk,
format_sse_done,
format_sse_error,
stream_to_string,
wrap_litellm_stream,
)
class TestStreamAccumulator:
"""Tests for StreamAccumulator class."""
def test_initial_state(self) -> None:
"""Test initial accumulator state."""
acc = StreamAccumulator()
assert acc.request_id is not None
assert acc.content == ""
assert acc.chunks_received == 0
assert acc.prompt_tokens == 0
assert acc.completion_tokens == 0
assert acc.model is None
assert acc.finish_reason is None
def test_custom_request_id(self) -> None:
"""Test accumulator with custom request ID."""
acc = StreamAccumulator(request_id="custom-id")
assert acc.request_id == "custom-id"
def test_add_chunk_text(self) -> None:
"""Test adding text chunks."""
acc = StreamAccumulator()
acc.add_chunk("Hello")
acc.add_chunk(", ")
acc.add_chunk("world!")
assert acc.content == "Hello, world!"
assert acc.chunks_received == 3
def test_add_chunk_with_finish_reason(self) -> None:
"""Test adding chunk with finish reason."""
acc = StreamAccumulator()
acc.add_chunk("Final", finish_reason="stop")
assert acc.finish_reason == "stop"
def test_add_chunk_with_model(self) -> None:
"""Test adding chunk with model info."""
acc = StreamAccumulator()
acc.add_chunk("Text", model="claude-opus-4")
assert acc.model == "claude-opus-4"
def test_add_chunk_with_usage(self) -> None:
"""Test adding chunk with usage stats."""
acc = StreamAccumulator()
acc.add_chunk(
"Text",
usage={"prompt_tokens": 10, "completion_tokens": 5},
)
assert acc.prompt_tokens == 10
assert acc.completion_tokens == 5
assert acc.total_tokens == 15
def test_start_and_finish(self) -> None:
"""Test start and finish timing."""
acc = StreamAccumulator()
assert acc.duration_seconds is None
acc.start()
acc.finish()
assert acc.duration_seconds is not None
assert acc.duration_seconds >= 0
def test_get_usage_stats(self) -> None:
"""Test getting usage stats."""
acc = StreamAccumulator()
acc.add_chunk("", usage={"prompt_tokens": 100, "completion_tokens": 50})
stats = acc.get_usage_stats(cost_usd=0.01)
assert stats.prompt_tokens == 100
assert stats.completion_tokens == 50
assert stats.total_tokens == 150
assert stats.cost_usd == 0.01
class TestWrapLiteLLMStream:
"""Tests for wrap_litellm_stream function."""
async def test_wrap_stream_basic(self) -> None:
"""Test wrapping a basic stream."""
# Create mock stream chunks
async def mock_stream():
chunk1 = MagicMock()
chunk1.choices = [MagicMock()]
chunk1.choices[0].delta = MagicMock()
chunk1.choices[0].delta.content = "Hello"
chunk1.choices[0].finish_reason = None
chunk1.model = "test-model"
chunk1.usage = None
yield chunk1
chunk2 = MagicMock()
chunk2.choices = [MagicMock()]
chunk2.choices[0].delta = MagicMock()
chunk2.choices[0].delta.content = " World"
chunk2.choices[0].finish_reason = "stop"
chunk2.model = "test-model"
chunk2.usage = MagicMock()
chunk2.usage.prompt_tokens = 5
chunk2.usage.completion_tokens = 2
yield chunk2
accumulator = StreamAccumulator()
chunks = []
async for chunk in wrap_litellm_stream(mock_stream(), accumulator):
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].delta == "Hello"
assert chunks[1].delta == " World"
assert chunks[1].finish_reason == "stop"
assert accumulator.content == "Hello World"
async def test_wrap_stream_without_accumulator(self) -> None:
"""Test wrapping stream without accumulator."""
async def mock_stream():
chunk = MagicMock()
chunk.choices = [MagicMock()]
chunk.choices[0].delta = MagicMock()
chunk.choices[0].delta.content = "Test"
chunk.choices[0].finish_reason = None
chunk.model = None
chunk.usage = None
yield chunk
chunks = []
async for chunk in wrap_litellm_stream(mock_stream()):
chunks.append(chunk)
assert len(chunks) == 1
class TestSSEFormatting:
"""Tests for SSE formatting functions."""
def test_format_sse_chunk_basic(self) -> None:
"""Test formatting basic chunk."""
chunk = StreamChunk(id="chunk-1", delta="Hello")
result = format_sse_chunk(chunk)
assert result.startswith("data: ")
assert result.endswith("\n\n")
# Parse the JSON
data = json.loads(result[6:-2])
assert data["id"] == "chunk-1"
assert data["delta"] == "Hello"
def test_format_sse_chunk_with_finish(self) -> None:
"""Test formatting chunk with finish reason."""
chunk = StreamChunk(
id="chunk-2",
delta="",
finish_reason="stop",
)
result = format_sse_chunk(chunk)
data = json.loads(result[6:-2])
assert data["finish_reason"] == "stop"
def test_format_sse_chunk_with_usage(self) -> None:
"""Test formatting chunk with usage."""
chunk = StreamChunk(
id="chunk-3",
delta="",
finish_reason="stop",
usage=UsageStats(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
cost_usd=0.001,
),
)
result = format_sse_chunk(chunk)
data = json.loads(result[6:-2])
assert "usage" in data
assert data["usage"]["prompt_tokens"] == 10
def test_format_sse_done(self) -> None:
"""Test formatting done message."""
result = format_sse_done()
assert result == "data: [DONE]\n\n"
def test_format_sse_error(self) -> None:
"""Test formatting error message."""
result = format_sse_error("Something went wrong", code="ERROR_CODE")
data = json.loads(result[6:-2])
assert data["error"] == "Something went wrong"
assert data["code"] == "ERROR_CODE"
def test_format_sse_error_without_code(self) -> None:
"""Test formatting error without code."""
result = format_sse_error("Error message")
data = json.loads(result[6:-2])
assert data["error"] == "Error message"
assert "code" not in data
class TestStreamBuffer:
"""Tests for StreamBuffer class."""
async def test_buffer_basic(self) -> None:
"""Test basic buffer operations."""
buffer = StreamBuffer(max_size=10)
# Producer
async def produce():
await buffer.put(StreamChunk(id="1", delta="Hello"))
await buffer.put(StreamChunk(id="2", delta=" World"))
await buffer.done()
        # Consumer; keep a reference to the producer task so it is not garbage-collected
        chunks = []
        producer = asyncio.create_task(produce())
        async for chunk in buffer:
            chunks.append(chunk)
        await producer
assert len(chunks) == 2
assert chunks[0].delta == "Hello"
assert chunks[1].delta == " World"
async def test_buffer_error(self) -> None:
"""Test buffer with error."""
buffer = StreamBuffer()
async def produce():
await buffer.put(StreamChunk(id="1", delta="Hello"))
await buffer.error(ValueError("Test error"))
        producer = asyncio.create_task(produce())
        with pytest.raises(ValueError, match="Test error"):
            async for _ in buffer:
                pass
        await producer
async def test_buffer_put_after_done(self) -> None:
"""Test putting after done raises."""
buffer = StreamBuffer()
await buffer.done()
with pytest.raises(RuntimeError, match="closed"):
await buffer.put(StreamChunk(id="1", delta="Test"))
class TestStreamToString:
"""Tests for stream_to_string function."""
async def test_stream_to_string_basic(self) -> None:
"""Test converting stream to string."""
async def mock_stream():
yield StreamChunk(id="1", delta="Hello")
yield StreamChunk(id="2", delta=" ")
yield StreamChunk(id="3", delta="World")
yield StreamChunk(
id="4",
delta="",
finish_reason="stop",
usage=UsageStats(prompt_tokens=5, completion_tokens=3),
)
content, usage = await stream_to_string(mock_stream())
assert content == "Hello World"
assert usage is not None
assert usage.prompt_tokens == 5
async def test_stream_to_string_no_usage(self) -> None:
"""Test stream without usage stats."""
async def mock_stream():
yield StreamChunk(id="1", delta="Test")
content, usage = await stream_to_string(mock_stream())
assert content == "Test"
assert usage is None