feat(llm-gateway): implement LLM Gateway MCP Server (#56)
Implements complete LLM Gateway MCP Server with:

- FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens
- LiteLLM Router with multi-provider failover chains
- Circuit breaker pattern for fault tolerance
- Redis-based cost tracking per project/agent
- Comprehensive test suite (209 tests, 92% coverage)

Model groups defined per ADR-004:

- reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro
- code: claude-sonnet-4 → gpt-4.1 → deepseek-coder
- fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
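Each model group resolves to an ordered failover chain: the primary model is tried first and a fallback is attempted only when the previous provider fails or its circuit is open. A minimal illustrative sketch of that selection follows; the names mirror the commit message, while the committed definitions live in models.py/routing.py (exercised by the tests below) and may differ in structure.

# Illustrative sketch only -- not the committed implementation.
FALLBACK_CHAINS: dict[str, list[str]] = {
    "reasoning": ["claude-opus-4", "gpt-4.1", "gemini-2.5-pro"],
    "code": ["claude-sonnet-4", "gpt-4.1", "deepseek-coder"],
    "fast": ["claude-haiku", "gpt-4.1-mini", "gemini-2.0-flash"],
}


def pick_model(group: str, unavailable: set[str]) -> str | None:
    """Return the first model in the group's chain that is still available."""
    for model in FALLBACK_CHAINS.get(group, []):
        if model not in unavailable:
            return model
    return None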
mcp-servers/llm-gateway/tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Tests for LLM Gateway MCP Server."""
mcp-servers/llm-gateway/tests/conftest.py (new file, 204 lines)
@@ -0,0 +1,204 @@
"""
Pytest fixtures for LLM Gateway tests.
"""

import os
import sys
from collections.abc import AsyncIterator, Iterator
from typing import Any
from unittest.mock import MagicMock, patch

import fakeredis.aioredis
import pytest

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import Settings
from cost_tracking import CostTracker, reset_cost_tracker
from failover import CircuitBreakerRegistry, reset_circuit_registry
from providers import LLMProvider, reset_provider
from routing import ModelRouter, reset_model_router


@pytest.fixture
def test_settings() -> Settings:
    """Create test settings with mock API keys."""
    return Settings(
        host="127.0.0.1",
        port=8001,
        debug=True,
        redis_url="redis://localhost:6379/0",
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        litellm_timeout=30,
        litellm_max_retries=2,
        litellm_cache_enabled=False,
        cost_tracking_enabled=True,
        circuit_failure_threshold=3,
        circuit_recovery_timeout=10,
    )


@pytest.fixture
def settings_no_providers() -> Settings:
    """Create settings with no providers configured."""
    return Settings(
        host="127.0.0.1",
        port=8001,
        debug=False,
        redis_url="redis://localhost:6379/0",
        anthropic_api_key=None,
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )


@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
    """Create a fake Redis instance for testing."""
    return fakeredis.aioredis.FakeRedis(decode_responses=True)


@pytest.fixture
async def cost_tracker(
    fake_redis: fakeredis.aioredis.FakeRedis,
    test_settings: Settings,
) -> AsyncIterator[CostTracker]:
    """Create a cost tracker with fake Redis."""
    reset_cost_tracker()
    tracker = CostTracker(redis_client=fake_redis, settings=test_settings)
    yield tracker
    await tracker.close()
    reset_cost_tracker()


@pytest.fixture
def circuit_registry(test_settings: Settings) -> Iterator[CircuitBreakerRegistry]:
    """Create a circuit breaker registry for testing."""
    reset_circuit_registry()
    registry = CircuitBreakerRegistry(settings=test_settings)
    yield registry
    reset_circuit_registry()


@pytest.fixture
def model_router(
    test_settings: Settings,
    circuit_registry: CircuitBreakerRegistry,
) -> Iterator[ModelRouter]:
    """Create a model router for testing."""
    reset_model_router()
    router = ModelRouter(settings=test_settings, circuit_registry=circuit_registry)
    yield router
    reset_model_router()


@pytest.fixture
def llm_provider(test_settings: Settings) -> Iterator[LLMProvider]:
    """Create an LLM provider for testing."""
    reset_provider()
    provider = LLMProvider(settings=test_settings)
    yield provider
    reset_provider()


@pytest.fixture
def mock_litellm_response() -> dict[str, Any]:
    """Create a mock LiteLLM response."""
    return {
        "id": "test-response-id",
        "model": "claude-opus-4",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "This is a test response.",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        },
    }


@pytest.fixture
def mock_completion_response() -> MagicMock:
    """Create a mock completion response object."""
    response = MagicMock()
    response.id = "test-response-id"
    response.model = "claude-opus-4"

    choice = MagicMock()
    choice.index = 0
    choice.message = MagicMock()
    choice.message.content = "This is a test response."
    choice.finish_reason = "stop"
    response.choices = [choice]

    response.usage = MagicMock()
    response.usage.prompt_tokens = 10
    response.usage.completion_tokens = 20
    response.usage.total_tokens = 30

    return response


@pytest.fixture
def sample_messages() -> list[dict[str, str]]:
    """Sample chat messages for testing."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ]


@pytest.fixture
def sample_project_id() -> str:
    """Sample project ID."""
    return "proj-12345678-1234-1234-1234-123456789abc"


@pytest.fixture
def sample_agent_id() -> str:
    """Sample agent ID."""
    return "agent-87654321-4321-4321-4321-cba987654321"


@pytest.fixture
def sample_session_id() -> str:
    """Sample session ID."""
    return "session-11111111-2222-3333-4444-555555555555"


# Reset all global state after each test
@pytest.fixture(autouse=True)
def reset_globals() -> Iterator[None]:
    """Reset all global state after each test."""
    yield
    reset_cost_tracker()
    reset_circuit_registry()
    reset_model_router()
    reset_provider()


# Mock environment variables for tests
# Note: Not autouse=True to avoid affecting default value tests
@pytest.fixture
def mock_env_vars() -> Iterator[None]:
    """Set test environment variables."""
    env_vars = {
        "LLM_GATEWAY_HOST": "127.0.0.1",
        "LLM_GATEWAY_PORT": "8001",
        "LLM_GATEWAY_DEBUG": "true",
    }
    with patch.dict(os.environ, env_vars, clear=False):
        yield
mcp-servers/llm-gateway/tests/test_config.py (new file, 200 lines)
@@ -0,0 +1,200 @@
"""
Tests for config module.
"""

import os
from unittest.mock import patch

import pytest

from config import Settings, get_settings


class TestSettings:
    """Tests for Settings class."""

    def test_default_values(self) -> None:
        """Test default configuration values."""
        settings = Settings()

        assert settings.host == "0.0.0.0"
        assert settings.port == 8001
        assert settings.debug is False
        assert settings.redis_url == "redis://localhost:6379/0"
        assert settings.litellm_timeout == 120
        assert settings.circuit_failure_threshold == 5

    def test_custom_values(self) -> None:
        """Test custom configuration values."""
        settings = Settings(
            host="127.0.0.1",
            port=9000,
            debug=True,
            redis_url="redis://custom:6380/1",
            litellm_timeout=60,
        )

        assert settings.host == "127.0.0.1"
        assert settings.port == 9000
        assert settings.debug is True
        assert settings.redis_url == "redis://custom:6380/1"
        assert settings.litellm_timeout == 60

    def test_port_validation_valid(self) -> None:
        """Test valid port numbers."""
        settings = Settings(port=1)
        assert settings.port == 1

        settings = Settings(port=65535)
        assert settings.port == 65535

        settings = Settings(port=8080)
        assert settings.port == 8080

    def test_port_validation_invalid(self) -> None:
        """Test invalid port numbers."""
        with pytest.raises(ValueError, match="Port must be between"):
            Settings(port=0)

        with pytest.raises(ValueError, match="Port must be between"):
            Settings(port=65536)

        with pytest.raises(ValueError, match="Port must be between"):
            Settings(port=-1)

    def test_ttl_validation_valid(self) -> None:
        """Test valid TTL values."""
        settings = Settings(redis_ttl_hours=1)
        assert settings.redis_ttl_hours == 1

        settings = Settings(redis_ttl_hours=168)  # 1 week
        assert settings.redis_ttl_hours == 168

    def test_ttl_validation_invalid(self) -> None:
        """Test invalid TTL values."""
        with pytest.raises(ValueError, match="Redis TTL must be positive"):
            Settings(redis_ttl_hours=0)

        with pytest.raises(ValueError, match="Redis TTL must be positive"):
            Settings(redis_ttl_hours=-1)

    def test_failure_threshold_validation(self) -> None:
        """Test circuit failure threshold validation."""
        settings = Settings(circuit_failure_threshold=1)
        assert settings.circuit_failure_threshold == 1

        settings = Settings(circuit_failure_threshold=100)
        assert settings.circuit_failure_threshold == 100

        with pytest.raises(ValueError, match="Failure threshold must be between"):
            Settings(circuit_failure_threshold=0)

        with pytest.raises(ValueError, match="Failure threshold must be between"):
            Settings(circuit_failure_threshold=101)

    def test_timeout_validation(self) -> None:
        """Test timeout validation."""
        settings = Settings(litellm_timeout=1)
        assert settings.litellm_timeout == 1

        settings = Settings(litellm_timeout=600)
        assert settings.litellm_timeout == 600

        with pytest.raises(ValueError, match="Timeout must be between"):
            Settings(litellm_timeout=0)

        with pytest.raises(ValueError, match="Timeout must be between"):
            Settings(litellm_timeout=601)

    def test_get_available_providers_none(self) -> None:
        """Test getting available providers with none configured."""
        settings = Settings()
        providers = settings.get_available_providers()
        assert providers == []

    def test_get_available_providers_some(self) -> None:
        """Test getting available providers with some configured."""
        settings = Settings(
            anthropic_api_key="test-key",
            openai_api_key="test-key",
        )
        providers = settings.get_available_providers()

        assert "anthropic" in providers
        assert "openai" in providers
        assert "google" not in providers
        assert len(providers) == 2

    def test_get_available_providers_all(self) -> None:
        """Test getting available providers with all configured."""
        settings = Settings(
            anthropic_api_key="test-key",
            openai_api_key="test-key",
            google_api_key="test-key",
            alibaba_api_key="test-key",
            deepseek_api_key="test-key",
        )
        providers = settings.get_available_providers()

        assert len(providers) == 5
        assert "anthropic" in providers
        assert "openai" in providers
        assert "google" in providers
        assert "alibaba" in providers
        assert "deepseek" in providers

    def test_has_any_provider_false(self) -> None:
        """Test has_any_provider when none configured."""
        settings = Settings()
        assert settings.has_any_provider() is False

    def test_has_any_provider_true(self) -> None:
        """Test has_any_provider when at least one configured."""
        settings = Settings(anthropic_api_key="test-key")
        assert settings.has_any_provider() is True

    def test_deepseek_base_url_counts_as_provider(self) -> None:
        """Test that DeepSeek base URL alone counts as provider."""
        settings = Settings(deepseek_base_url="http://localhost:8000")
        providers = settings.get_available_providers()
        assert "deepseek" in providers


class TestGetSettings:
    """Tests for get_settings function."""

    def test_get_settings_returns_settings(self) -> None:
        """Test that get_settings returns a Settings instance."""
        # Clear the cache first
        get_settings.cache_clear()

        settings = get_settings()
        assert isinstance(settings, Settings)

    def test_get_settings_is_cached(self) -> None:
        """Test that get_settings returns cached instance."""
        get_settings.cache_clear()

        settings1 = get_settings()
        settings2 = get_settings()

        assert settings1 is settings2

    def test_env_var_override(self) -> None:
        """Test that environment variables override defaults."""
        get_settings.cache_clear()

        with patch.dict(
            os.environ,
            {
                "LLM_GATEWAY_HOST": "192.168.1.1",
                "LLM_GATEWAY_PORT": "9999",
                "LLM_GATEWAY_DEBUG": "true",
            },
        ):
            get_settings.cache_clear()
            settings = get_settings()

            assert settings.host == "192.168.1.1"
            assert settings.port == 9999
            assert settings.debug is True
mcp-servers/llm-gateway/tests/test_cost_tracking.py (new file, 417 lines)
@@ -0,0 +1,417 @@
"""
Tests for cost_tracking module.
"""

import asyncio

import fakeredis.aioredis
import pytest

from config import Settings
from cost_tracking import (
    CostTracker,
    calculate_cost,
    close_cost_tracker,
    get_cost_tracker,
    reset_cost_tracker,
)


@pytest.fixture
def tracker_settings() -> Settings:
    """Settings for cost tracker tests."""
    return Settings(
        redis_url="redis://localhost:6379/0",
        redis_prefix="test_llm_gateway",
        cost_tracking_enabled=True,
        cost_alert_threshold=100.0,
        default_budget_limit=1000.0,
    )


@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
    """Create fake Redis for testing."""
    return fakeredis.aioredis.FakeRedis(decode_responses=True)


@pytest.fixture
def tracker(
    fake_redis: fakeredis.aioredis.FakeRedis,
    tracker_settings: Settings,
) -> CostTracker:
    """Create cost tracker with fake Redis."""
    return CostTracker(redis_client=fake_redis, settings=tracker_settings)


class TestCalculateCost:
    """Tests for calculate_cost function."""

    def test_calculate_cost_known_model(self) -> None:
        """Test calculating cost for known model."""
        # claude-opus-4: $15/1M input, $75/1M output
        cost = calculate_cost(
            model="claude-opus-4",
            prompt_tokens=1000,
            completion_tokens=500,
        )

        # 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        assert cost == pytest.approx(0.0525, rel=0.001)

    def test_calculate_cost_unknown_model(self) -> None:
        """Test calculating cost for unknown model."""
        cost = calculate_cost(
            model="unknown-model",
            prompt_tokens=1000,
            completion_tokens=500,
        )

        assert cost == 0.0

    def test_calculate_cost_zero_tokens(self) -> None:
        """Test calculating cost with zero tokens."""
        cost = calculate_cost(
            model="claude-opus-4",
            prompt_tokens=0,
            completion_tokens=0,
        )

        assert cost == 0.0

    def test_calculate_cost_large_token_counts(self) -> None:
        """Test calculating cost with large token counts."""
        cost = calculate_cost(
            model="claude-opus-4",
            prompt_tokens=1_000_000,
            completion_tokens=500_000,
        )

        # 1M * 15/1M + 500K * 75/1M = 15 + 37.5 = 52.5
        assert cost == pytest.approx(52.5, rel=0.001)


class TestCostTracker:
    """Tests for CostTracker class."""

    def test_record_usage(self, tracker: CostTracker) -> None:
        """Test recording usage."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
            )
        )

        # Verify by getting usage report
        report = asyncio.run(
            tracker.get_project_usage("proj-123", period="day")
        )

        assert report.total_requests == 1
        assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)

    def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
        """Test recording is skipped when disabled."""
        settings = Settings(**{
            **tracker_settings.model_dump(),
            "cost_tracking_enabled": False,
        })
        fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
        disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)

        # This should not raise and should not record
        asyncio.run(
            disabled_tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
            )
        )

        # Usage should be empty
        report = asyncio.run(
            disabled_tracker.get_project_usage("proj-123", period="day")
        )
        assert report.total_requests == 0

    def test_record_usage_with_session(self, tracker: CostTracker) -> None:
        """Test recording usage with session ID."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
                session_id="session-789",
            )
        )

        # Verify session usage
        session_usage = asyncio.run(
            tracker.get_session_usage("session-789")
        )

        assert session_usage["session_id"] == "session-789"
        assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)

    def test_get_project_usage_empty(self, tracker: CostTracker) -> None:
        """Test getting usage for project with no data."""
        report = asyncio.run(
            tracker.get_project_usage("nonexistent-project", period="day")
        )

        assert report.entity_id == "nonexistent-project"
        assert report.entity_type == "project"
        assert report.total_requests == 0
        assert report.total_cost_usd == 0.0

    def test_get_project_usage_multiple_models(self, tracker: CostTracker) -> None:
        """Test usage tracking across multiple models."""
        # Record usage for different models
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
            )
        )
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="gpt-4.1",
                prompt_tokens=200,
                completion_tokens=100,
                cost_usd=0.02,
            )
        )

        report = asyncio.run(
            tracker.get_project_usage("proj-123", period="day")
        )

        assert report.total_requests == 2
        assert len(report.by_model) == 2
        assert "claude-opus-4" in report.by_model
        assert "gpt-4.1" in report.by_model

    def test_get_agent_usage(self, tracker: CostTracker) -> None:
        """Test getting agent usage."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
            )
        )

        report = asyncio.run(
            tracker.get_agent_usage("agent-456", period="day")
        )

        assert report.entity_id == "agent-456"
        assert report.entity_type == "agent"
        assert report.total_requests == 1

    def test_usage_periods(self, tracker: CostTracker) -> None:
        """Test different usage periods."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
            )
        )

        # Check different periods
        hour_report = asyncio.run(
            tracker.get_project_usage("proj-123", period="hour")
        )
        day_report = asyncio.run(
            tracker.get_project_usage("proj-123", period="day")
        )
        month_report = asyncio.run(
            tracker.get_project_usage("proj-123", period="month")
        )

        assert hour_report.period == "hour"
        assert day_report.period == "day"
        assert month_report.period == "month"

        # All should have the same data
        assert hour_report.total_requests == 1
        assert day_report.total_requests == 1
        assert month_report.total_requests == 1

    def test_check_budget_within(self, tracker: CostTracker) -> None:
        """Test budget check when within limit."""
        # Record some usage
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=1000,
                completion_tokens=500,
                cost_usd=50.0,
            )
        )

        within, current, limit = asyncio.run(
            tracker.check_budget("proj-123", budget_limit=100.0)
        )

        assert within is True
        assert current == pytest.approx(50.0, rel=0.01)
        assert limit == 100.0

    def test_check_budget_exceeded(self, tracker: CostTracker) -> None:
        """Test budget check when exceeded."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=1000,
                completion_tokens=500,
                cost_usd=150.0,
            )
        )

        within, current, limit = asyncio.run(
            tracker.check_budget("proj-123", budget_limit=100.0)
        )

        assert within is False
        assert current >= limit

    def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
        """Test budget check with default limit."""
        within, current, limit = asyncio.run(
            tracker.check_budget("proj-123")
        )

        assert limit == 1000.0  # Default from settings

    def test_estimate_request_cost_known_model(self, tracker: CostTracker) -> None:
        """Test estimating cost for known model."""
        cost = asyncio.run(
            tracker.estimate_request_cost(
                model="claude-opus-4",
                prompt_tokens=1000,
                max_completion_tokens=500,
            )
        )

        # 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        assert cost == pytest.approx(0.0525, rel=0.01)

    def test_estimate_request_cost_unknown_model(self, tracker: CostTracker) -> None:
        """Test estimating cost for unknown model."""
        cost = asyncio.run(
            tracker.estimate_request_cost(
                model="unknown-model",
                prompt_tokens=1000,
                max_completion_tokens=500,
            )
        )

        # Uses fallback estimate
        assert cost > 0

    def test_should_alert_below_threshold(self, tracker: CostTracker) -> None:
        """Test alert check when below threshold."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=1000,
                completion_tokens=500,
                cost_usd=50.0,
            )
        )

        should_alert, current = asyncio.run(
            tracker.should_alert("proj-123", threshold=100.0)
        )

        assert should_alert is False
        assert current == pytest.approx(50.0, rel=0.01)

    def test_should_alert_above_threshold(self, tracker: CostTracker) -> None:
        """Test alert check when above threshold."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=1000,
                completion_tokens=500,
                cost_usd=150.0,
            )
        )

        should_alert, current = asyncio.run(
            tracker.should_alert("proj-123", threshold=100.0)
        )

        assert should_alert is True

    def test_close(self, tracker: CostTracker) -> None:
        """Test closing tracker."""
        asyncio.run(tracker.close())
        assert tracker._redis is None


class TestGlobalTracker:
    """Tests for global tracker functions."""

    def test_get_cost_tracker(self) -> None:
        """Test getting global tracker."""
        reset_cost_tracker()
        tracker = get_cost_tracker()
        assert isinstance(tracker, CostTracker)

    def test_get_cost_tracker_singleton(self) -> None:
        """Test tracker is singleton."""
        reset_cost_tracker()
        tracker1 = get_cost_tracker()
        tracker2 = get_cost_tracker()
        assert tracker1 is tracker2

    def test_reset_cost_tracker(self) -> None:
        """Test resetting global tracker."""
        reset_cost_tracker()
        tracker1 = get_cost_tracker()
        reset_cost_tracker()
        tracker2 = get_cost_tracker()
        assert tracker1 is not tracker2

    def test_close_cost_tracker(self) -> None:
        """Test closing global tracker."""
        reset_cost_tracker()
        _ = get_cost_tracker()
        asyncio.run(close_cost_tracker())
        # Getting again should create a new one
        tracker2 = get_cost_tracker()
        assert tracker2 is not None
mcp-servers/llm-gateway/tests/test_exceptions.py (new file, 377 lines)
@@ -0,0 +1,377 @@
"""
Tests for exceptions module.
"""

from exceptions import (
    AllProvidersFailedError,
    CircuitOpenError,
    ConfigurationError,
    ContextTooLongError,
    CostLimitExceededError,
    ErrorCode,
    InvalidModelError,
    InvalidModelGroupError,
    LLMGatewayError,
    ModelNotAvailableError,
    ProviderError,
    RateLimitError,
    StreamError,
    TokenLimitExceededError,
)


class TestErrorCode:
    """Tests for ErrorCode enum."""

    def test_error_code_values(self) -> None:
        """Test error code values."""
        assert ErrorCode.UNKNOWN_ERROR.value == "LLM_UNKNOWN_ERROR"
        assert ErrorCode.PROVIDER_ERROR.value == "LLM_PROVIDER_ERROR"
        assert ErrorCode.CIRCUIT_OPEN.value == "LLM_CIRCUIT_OPEN"
        assert ErrorCode.COST_LIMIT_EXCEEDED.value == "LLM_COST_LIMIT_EXCEEDED"


class TestLLMGatewayError:
    """Tests for LLMGatewayError base class."""

    def test_basic_error(self) -> None:
        """Test basic error creation."""
        error = LLMGatewayError("Something went wrong")
        assert str(error) == "[LLM_UNKNOWN_ERROR] Something went wrong"
        assert error.message == "Something went wrong"
        assert error.code == ErrorCode.UNKNOWN_ERROR
        assert error.details == {}
        assert error.cause is None

    def test_error_with_code(self) -> None:
        """Test error with custom code."""
        error = LLMGatewayError(
            "Provider failed",
            code=ErrorCode.PROVIDER_ERROR,
        )
        assert error.code == ErrorCode.PROVIDER_ERROR

    def test_error_with_details(self) -> None:
        """Test error with details."""
        error = LLMGatewayError(
            "Error",
            details={"key": "value"},
        )
        assert error.details == {"key": "value"}

    def test_error_with_cause(self) -> None:
        """Test error with cause exception."""
        cause = ValueError("Original error")
        error = LLMGatewayError("Wrapped error", cause=cause)
        assert error.cause is cause

    def test_to_dict(self) -> None:
        """Test converting error to dict."""
        error = LLMGatewayError(
            "Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"field": "value"},
        )
        result = error.to_dict()

        assert result["error"] == "LLM_INVALID_REQUEST"
        assert result["message"] == "Test error"
        assert result["details"] == {"field": "value"}

    def test_to_dict_no_details(self) -> None:
        """Test to_dict without details."""
        error = LLMGatewayError("Test error")
        result = error.to_dict()

        assert "details" not in result

    def test_repr(self) -> None:
        """Test error repr."""
        error = LLMGatewayError("Test", details={"key": "val"})
        repr_str = repr(error)

        assert "LLMGatewayError" in repr_str
        assert "Test" in repr_str


class TestProviderError:
    """Tests for ProviderError."""

    def test_basic_provider_error(self) -> None:
        """Test basic provider error."""
        error = ProviderError(
            message="API call failed",
            provider="anthropic",
        )

        assert error.provider == "anthropic"
        assert error.model is None
        assert error.status_code is None
        assert error.code == ErrorCode.PROVIDER_ERROR
        assert "provider" in error.details

    def test_provider_error_with_model(self) -> None:
        """Test provider error with model info."""
        error = ProviderError(
            message="Model not found",
            provider="openai",
            model="gpt-5",
            status_code=404,
        )

        assert error.model == "gpt-5"
        assert error.status_code == 404
        assert error.details["model"] == "gpt-5"
        assert error.details["status_code"] == 404

    def test_provider_error_with_cause(self) -> None:
        """Test provider error with cause."""
        cause = ConnectionError("Network down")
        error = ProviderError(
            message="Connection failed",
            provider="google",
            cause=cause,
        )

        assert error.cause is cause


class TestRateLimitError:
    """Tests for RateLimitError."""

    def test_internal_rate_limit(self) -> None:
        """Test internal rate limit error."""
        error = RateLimitError(
            message="Too many requests",
            retry_after=60,
        )

        assert error.code == ErrorCode.RATE_LIMIT_EXCEEDED
        assert error.provider is None
        assert error.retry_after == 60
        assert error.details["retry_after_seconds"] == 60

    def test_provider_rate_limit(self) -> None:
        """Test provider rate limit error."""
        error = RateLimitError(
            message="OpenAI rate limit",
            provider="openai",
            retry_after=30,
        )

        assert error.code == ErrorCode.PROVIDER_RATE_LIMIT
        assert error.provider == "openai"
        assert error.details["provider"] == "openai"


class TestCircuitOpenError:
    """Tests for CircuitOpenError."""

    def test_circuit_open_error(self) -> None:
        """Test circuit open error."""
        error = CircuitOpenError(
            provider="anthropic",
            recovery_time=45,
        )

        assert error.provider == "anthropic"
        assert error.recovery_time == 45
        assert error.code == ErrorCode.CIRCUIT_OPEN
        assert "Circuit breaker open" in error.message
        assert error.details["recovery_time_seconds"] == 45

    def test_circuit_open_no_recovery_time(self) -> None:
        """Test circuit open without recovery time."""
        error = CircuitOpenError(provider="openai")

        assert error.recovery_time is None
        assert "recovery_time_seconds" not in error.details


class TestCostLimitExceededError:
    """Tests for CostLimitExceededError."""

    def test_project_cost_limit(self) -> None:
        """Test project cost limit error."""
        error = CostLimitExceededError(
            entity_type="project",
            entity_id="proj-123",
            current_cost=150.0,
            limit=100.0,
        )

        assert error.entity_type == "project"
        assert error.entity_id == "proj-123"
        assert error.current_cost == 150.0
        assert error.limit == 100.0
        assert error.code == ErrorCode.COST_LIMIT_EXCEEDED
        assert "$150.00" in error.message
        assert "$100.00" in error.message

    def test_agent_cost_limit(self) -> None:
        """Test agent cost limit error."""
        error = CostLimitExceededError(
            entity_type="agent",
            entity_id="agent-456",
            current_cost=50.0,
            limit=25.0,
        )

        assert error.entity_type == "agent"
        assert error.details["entity_type"] == "agent"


class TestInvalidModelGroupError:
    """Tests for InvalidModelGroupError."""

    def test_invalid_group_error(self) -> None:
        """Test invalid model group error."""
        error = InvalidModelGroupError(
            model_group="invalid_group",
            available_groups=["reasoning", "code", "fast"],
        )

        assert error.model_group == "invalid_group"
        assert error.available_groups == ["reasoning", "code", "fast"]
        assert error.code == ErrorCode.INVALID_MODEL_GROUP
        assert "invalid_group" in error.message

    def test_invalid_group_no_available(self) -> None:
        """Test invalid group without available list."""
        error = InvalidModelGroupError(model_group="unknown")

        assert error.available_groups is None
        assert "available_groups" not in error.details


class TestInvalidModelError:
    """Tests for InvalidModelError."""

    def test_invalid_model_error(self) -> None:
        """Test invalid model error."""
        error = InvalidModelError(
            model="gpt-99",
            reason="Model does not exist",
        )

        assert error.model == "gpt-99"
        assert error.code == ErrorCode.INVALID_MODEL
        assert "gpt-99" in error.message
        assert "Model does not exist" in error.message

    def test_invalid_model_no_reason(self) -> None:
        """Test invalid model without reason."""
        error = InvalidModelError(model="unknown-model")

        assert "reason" not in error.details


class TestModelNotAvailableError:
    """Tests for ModelNotAvailableError."""

    def test_model_not_available(self) -> None:
        """Test model not available error."""
        error = ModelNotAvailableError(
            model="claude-opus-4",
            provider="anthropic",
        )

        assert error.model == "claude-opus-4"
        assert error.provider == "anthropic"
        assert error.code == ErrorCode.MODEL_NOT_AVAILABLE
        assert "not configured" in error.message


class TestAllProvidersFailedError:
    """Tests for AllProvidersFailedError."""

    def test_all_providers_failed(self) -> None:
        """Test all providers failed error."""
        errors = [
            {"model": "claude-opus-4", "error": "Rate limited"},
            {"model": "gpt-4.1", "error": "Timeout"},
        ]
        error = AllProvidersFailedError(
            model_group="reasoning",
            attempted_models=["claude-opus-4", "gpt-4.1"],
            errors=errors,
        )

        assert error.model_group == "reasoning"
        assert error.attempted_models == ["claude-opus-4", "gpt-4.1"]
        assert error.errors == errors
        assert error.code == ErrorCode.ALL_PROVIDERS_FAILED


class TestStreamError:
    """Tests for StreamError."""

    def test_stream_error(self) -> None:
        """Test stream error."""
        cause = OSError("Connection reset")
        error = StreamError(
            message="Stream interrupted",
            chunks_received=10,
            cause=cause,
        )

        assert error.chunks_received == 10
        assert error.cause is cause
        assert error.code == ErrorCode.STREAM_ERROR


class TestTokenLimitExceededError:
    """Tests for TokenLimitExceededError."""

    def test_token_limit_exceeded(self) -> None:
        """Test token limit exceeded error."""
        error = TokenLimitExceededError(
            model="claude-haiku",
            token_count=10000,
            limit=8192,
        )

        assert error.model == "claude-haiku"
        assert error.token_count == 10000
        assert error.limit == 8192
        assert error.code == ErrorCode.TOKEN_LIMIT_EXCEEDED


class TestContextTooLongError:
    """Tests for ContextTooLongError."""

    def test_context_too_long(self) -> None:
        """Test context too long error."""
        error = ContextTooLongError(
            model="gpt-4.1-mini",
            context_length=150000,
            max_context=100000,
        )

        assert error.model == "gpt-4.1-mini"
        assert error.context_length == 150000
        assert error.max_context == 100000
        assert error.code == ErrorCode.CONTEXT_TOO_LONG


class TestConfigurationError:
    """Tests for ConfigurationError."""

    def test_configuration_error(self) -> None:
        """Test configuration error."""
        error = ConfigurationError(
            message="Missing API key",
            config_key="ANTHROPIC_API_KEY",
        )

        assert error.config_key == "ANTHROPIC_API_KEY"
        assert error.code == ErrorCode.CONFIGURATION_ERROR
        assert error.details["config_key"] == "ANTHROPIC_API_KEY"

    def test_configuration_error_no_key(self) -> None:
        """Test configuration error without key."""
        error = ConfigurationError(message="Invalid configuration")

        assert error.config_key is None
        assert "config_key" not in error.details
mcp-servers/llm-gateway/tests/test_failover.py (new file, 407 lines)
@@ -0,0 +1,407 @@
"""
Tests for failover module (circuit breaker).
"""

import asyncio
import time

import pytest

from config import Settings
from exceptions import CircuitOpenError
from failover import (
    CircuitBreaker,
    CircuitBreakerRegistry,
    CircuitState,
    CircuitStats,
    get_circuit_registry,
    reset_circuit_registry,
)


class TestCircuitState:
    """Tests for CircuitState enum."""

    def test_circuit_states(self) -> None:
        """Test circuit state values."""
        assert CircuitState.CLOSED.value == "closed"
        assert CircuitState.OPEN.value == "open"
        assert CircuitState.HALF_OPEN.value == "half_open"


class TestCircuitStats:
    """Tests for CircuitStats dataclass."""

    def test_default_stats(self) -> None:
        """Test default stats values."""
        stats = CircuitStats()
        assert stats.failures == 0
        assert stats.successes == 0
        assert stats.last_failure_time is None
        assert stats.last_success_time is None
        assert stats.half_open_calls == 0


class TestCircuitBreaker:
    """Tests for CircuitBreaker class."""

    def test_initial_state(self) -> None:
        """Test circuit breaker initial state."""
        cb = CircuitBreaker(name="test", failure_threshold=5)

        assert cb.name == "test"
        assert cb.state == CircuitState.CLOSED
        assert cb.failure_threshold == 5
        assert cb.is_available() is True

    def test_state_remains_closed_below_threshold(self) -> None:
        """Test circuit stays closed below failure threshold."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        # Record 2 failures (below threshold)
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_failure())

        assert cb.state == CircuitState.CLOSED
        assert cb.stats.failures == 2
        assert cb.is_available() is True

    def test_state_opens_at_threshold(self) -> None:
        """Test circuit opens at failure threshold."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        # Record 3 failures (at threshold)
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_failure())

        assert cb.state == CircuitState.OPEN
        assert cb.is_available() is False

    def test_success_resets_in_closed(self) -> None:
        """Test success in closed state records properly."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_success())

        assert cb.state == CircuitState.CLOSED
        assert cb.stats.successes == 1
        assert cb.stats.last_success_time is not None

    def test_half_open_transition(self) -> None:
        """Test transition to half-open after recovery timeout."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=1,  # 1 second
        )

        # Open the circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Wait for recovery timeout
        time.sleep(1.1)

        # State should transition to half-open
        assert cb.state == CircuitState.HALF_OPEN
        assert cb.is_available() is True

    def test_half_open_success_closes(self) -> None:
        """Test success in half-open closes circuit."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0,  # Immediate recovery for testing
        )

        # Open and transition to half-open
        asyncio.run(cb.record_failure())
        time.sleep(0.1)
        _ = cb.state  # Trigger state check

        assert cb.state == CircuitState.HALF_OPEN

        # Success should close
        asyncio.run(cb.record_success())
        assert cb.state == CircuitState.CLOSED

    def test_half_open_failure_reopens(self) -> None:
        """Test failure in half-open reopens circuit."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0.05,  # Small but non-zero for reliable timing
        )

        # Open and transition to half-open
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Wait for recovery timeout
        time.sleep(0.1)
        assert cb.state == CircuitState.HALF_OPEN

        # Failure should reopen
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

    def test_half_open_call_limit(self) -> None:
        """Test half-open call limit."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0,
            half_open_max_calls=2,
        )

        # Open and transition to half-open
        asyncio.run(cb.record_failure())
        time.sleep(0.1)
        _ = cb.state

        assert cb.is_available() is True

        # Simulate calls in half-open
        cb._stats.half_open_calls = 1
        assert cb.is_available() is True

        cb._stats.half_open_calls = 2
        assert cb.is_available() is False

    def test_time_until_recovery(self) -> None:
        """Test time until recovery calculation."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=60,
        )

        # Closed circuit has no recovery time
        assert cb.time_until_recovery() is None

        # Open circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Should have recovery time
        remaining = cb.time_until_recovery()
        assert remaining is not None
        assert 0 <= remaining <= 60

    def test_execute_success(self) -> None:
        """Test execute with successful function."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        async def success_func() -> str:
            return "success"

        result = asyncio.run(cb.execute(success_func))
        assert result == "success"
        assert cb.stats.successes == 1

    def test_execute_failure(self) -> None:
        """Test execute with failing function."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        async def fail_func() -> None:
            raise ValueError("Error")

        with pytest.raises(ValueError):
            asyncio.run(cb.execute(fail_func))

        assert cb.stats.failures == 1

    def test_execute_when_open(self) -> None:
        """Test execute raises when circuit is open."""
        cb = CircuitBreaker(name="test", failure_threshold=1)

        # Open the circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        async def success_func() -> str:
            return "success"

        with pytest.raises(CircuitOpenError) as exc_info:
            asyncio.run(cb.execute(success_func))

        assert exc_info.value.provider == "test"

    def test_reset(self) -> None:
        """Test circuit reset."""
        cb = CircuitBreaker(name="test", failure_threshold=1)

        # Open the circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Reset
        cb.reset()
        assert cb.state == CircuitState.CLOSED
        assert cb.stats.failures == 0
        assert cb.stats.successes == 0

    def test_to_dict(self) -> None:
        """Test converting circuit to dict."""
        cb = CircuitBreaker(name="test", failure_threshold=3)
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_success())

        result = cb.to_dict()

        assert result["name"] == "test"
        assert result["state"] == "closed"
        assert result["failures"] == 1
        assert result["successes"] == 1
        assert result["is_available"] is True


class TestCircuitBreakerRegistry:
    """Tests for CircuitBreakerRegistry class."""

    def test_get_circuit_creates_new(self) -> None:
        """Test getting a new circuit."""
        settings = Settings(circuit_failure_threshold=5)
        registry = CircuitBreakerRegistry(settings=settings)

        circuit = asyncio.run(registry.get_circuit("anthropic"))

        assert circuit.name == "anthropic"
        assert circuit.failure_threshold == 5

    def test_get_circuit_returns_same(self) -> None:
        """Test getting same circuit twice."""
        registry = CircuitBreakerRegistry()

        circuit1 = asyncio.run(registry.get_circuit("openai"))
        circuit2 = asyncio.run(registry.get_circuit("openai"))

        assert circuit1 is circuit2

    def test_get_circuit_sync(self) -> None:
        """Test sync circuit getter."""
        registry = CircuitBreakerRegistry()

        circuit = registry.get_circuit_sync("google")
        assert circuit.name == "google"

    def test_is_available(self) -> None:
        """Test checking if circuit is available."""
        registry = CircuitBreakerRegistry()

        assert asyncio.run(registry.is_available("test")) is True

        # Open the circuit
        circuit = asyncio.run(registry.get_circuit("test"))
        for _ in range(5):
            asyncio.run(circuit.record_failure())

        assert asyncio.run(registry.is_available("test")) is False

    def test_record_success(self) -> None:
        """Test recording success through registry."""
        registry = CircuitBreakerRegistry()

        asyncio.run(registry.record_success("test"))

        circuit = asyncio.run(registry.get_circuit("test"))
        assert circuit.stats.successes == 1

    def test_record_failure(self) -> None:
        """Test recording failure through registry."""
        registry = CircuitBreakerRegistry()

        asyncio.run(registry.record_failure("test"))

        circuit = asyncio.run(registry.get_circuit("test"))
        assert circuit.stats.failures == 1

    def test_reset(self) -> None:
        """Test resetting a specific circuit."""
        registry = CircuitBreakerRegistry()

        # Create and fail a circuit
        asyncio.run(registry.record_failure("test"))
        asyncio.run(registry.reset("test"))

        circuit = asyncio.run(registry.get_circuit("test"))
        assert circuit.stats.failures == 0

    def test_reset_all(self) -> None:
        """Test resetting all circuits."""
        registry = CircuitBreakerRegistry()

        # Create multiple circuits with failures
        asyncio.run(registry.record_failure("circuit1"))
        asyncio.run(registry.record_failure("circuit2"))

        asyncio.run(registry.reset_all())

        circuit1 = asyncio.run(registry.get_circuit("circuit1"))
        circuit2 = asyncio.run(registry.get_circuit("circuit2"))
        assert circuit1.stats.failures == 0
        assert circuit2.stats.failures == 0

    def test_get_all_states(self) -> None:
        """Test getting all circuit states."""
        registry = CircuitBreakerRegistry()

        asyncio.run(registry.get_circuit("circuit1"))
        asyncio.run(registry.get_circuit("circuit2"))

        states = registry.get_all_states()

        assert "circuit1" in states
        assert "circuit2" in states
        assert states["circuit1"]["state"] == "closed"

    def test_get_open_circuits(self) -> None:
        """Test getting open circuits."""
        settings = Settings(circuit_failure_threshold=1)
        registry = CircuitBreakerRegistry(settings=settings)

        asyncio.run(registry.get_circuit("healthy"))
        asyncio.run(registry.record_failure("failing"))

        open_circuits = registry.get_open_circuits()
        assert "failing" in open_circuits
        assert "healthy" not in open_circuits

    def test_get_available_circuits(self) -> None:
        """Test getting available circuits."""
        settings = Settings(circuit_failure_threshold=1)
        registry = CircuitBreakerRegistry(settings=settings)

        asyncio.run(registry.get_circuit("healthy"))
        asyncio.run(registry.record_failure("failing"))

        available = registry.get_available_circuits()
        assert "healthy" in available
        assert "failing" not in available


class TestGlobalRegistry:
    """Tests for global registry functions."""

    def test_get_circuit_registry(self) -> None:
        """Test getting global registry."""
        reset_circuit_registry()
        registry = get_circuit_registry()
        assert isinstance(registry, CircuitBreakerRegistry)

    def test_get_circuit_registry_singleton(self) -> None:
        """Test registry is singleton."""
        reset_circuit_registry()
        registry1 = get_circuit_registry()
        registry2 = get_circuit_registry()
        assert registry1 is registry2

    def test_reset_circuit_registry(self) -> None:
        """Test resetting global registry."""
        reset_circuit_registry()
        registry1 = get_circuit_registry()
        reset_circuit_registry()
        registry2 = get_circuit_registry()
        assert registry1 is not registry2
mcp-servers/llm-gateway/tests/test_models.py (new file, 408 lines)
@@ -0,0 +1,408 @@
"""
Tests for models module.
"""

from datetime import UTC, datetime

import pytest

from models import (
    AGENT_TYPE_MODEL_PREFERENCES,
    MODEL_CONFIGS,
    MODEL_GROUPS,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CostRecord,
    EmbeddingRequest,
    ModelConfig,
    ModelGroup,
    ModelGroupConfig,
    ModelGroupInfo,
    ModelInfo,
    Provider,
    StreamChunk,
    UsageReport,
    UsageStats,
)


class TestModelGroup:
    """Tests for ModelGroup enum."""

    def test_model_group_values(self) -> None:
        """Test model group enum values."""
        assert ModelGroup.REASONING.value == "reasoning"
        assert ModelGroup.CODE.value == "code"
        assert ModelGroup.FAST.value == "fast"
        assert ModelGroup.VISION.value == "vision"
        assert ModelGroup.EMBEDDING.value == "embedding"
        assert ModelGroup.COST_OPTIMIZED.value == "cost_optimized"
        assert ModelGroup.SELF_HOSTED.value == "self_hosted"

    def test_model_group_from_string(self) -> None:
        """Test creating ModelGroup from string."""
        assert ModelGroup("reasoning") == ModelGroup.REASONING
        assert ModelGroup("code") == ModelGroup.CODE
        assert ModelGroup("fast") == ModelGroup.FAST

    def test_model_group_invalid(self) -> None:
        """Test invalid model group value."""
        with pytest.raises(ValueError):
            ModelGroup("invalid_group")


class TestProvider:
    """Tests for Provider enum."""

    def test_provider_values(self) -> None:
        """Test provider enum values."""
        assert Provider.ANTHROPIC.value == "anthropic"
        assert Provider.OPENAI.value == "openai"
        assert Provider.GOOGLE.value == "google"
        assert Provider.ALIBABA.value == "alibaba"
        assert Provider.DEEPSEEK.value == "deepseek"


class TestModelConfig:
    """Tests for ModelConfig dataclass."""

    def test_model_config_creation(self) -> None:
        """Test creating a ModelConfig."""
        config = ModelConfig(
            name="test-model",
            litellm_name="provider/test-model",
            provider=Provider.ANTHROPIC,
            cost_per_1m_input=10.0,
            cost_per_1m_output=30.0,
            context_window=100000,
            max_output_tokens=4096,
            supports_vision=True,
        )

        assert config.name == "test-model"
        assert config.provider == Provider.ANTHROPIC
        assert config.cost_per_1m_input == 10.0
        assert config.supports_vision is True
        assert config.supports_streaming is True  # default

    def test_model_configs_exist(self) -> None:
        """Test that model configs are defined."""
        assert len(MODEL_CONFIGS) > 0
        assert "claude-opus-4" in MODEL_CONFIGS
        assert "gpt-4.1" in MODEL_CONFIGS
        assert "gemini-2.5-pro" in MODEL_CONFIGS


class TestModelGroupConfig:
    """Tests for ModelGroupConfig dataclass."""

    def test_model_group_config_creation(self) -> None:
        """Test creating a ModelGroupConfig."""
        config = ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

        assert config.primary == "model-a"
        assert config.fallbacks == ["model-b", "model-c"]
        assert config.description == "Test group"

    def test_get_all_models(self) -> None:
        """Test getting all models in order."""
        config = ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

        models = config.get_all_models()
        assert models == ["model-a", "model-b", "model-c"]

    def test_model_groups_exist(self) -> None:
        """Test that model groups are defined."""
        assert len(MODEL_GROUPS) > 0
        assert ModelGroup.REASONING in MODEL_GROUPS
        assert ModelGroup.CODE in MODEL_GROUPS
        assert ModelGroup.FAST in MODEL_GROUPS


class TestAgentTypePreferences:
    """Tests for agent type model preferences."""

    def test_agent_preferences_exist(self) -> None:
        """Test that agent preferences are defined."""
        assert len(AGENT_TYPE_MODEL_PREFERENCES) > 0
        assert "product_owner" in AGENT_TYPE_MODEL_PREFERENCES
        assert "software_engineer" in AGENT_TYPE_MODEL_PREFERENCES

    def test_agent_preference_values(self) -> None:
        """Test agent preference values."""
        assert AGENT_TYPE_MODEL_PREFERENCES["product_owner"] == ModelGroup.REASONING
        assert AGENT_TYPE_MODEL_PREFERENCES["software_engineer"] == ModelGroup.CODE
        assert AGENT_TYPE_MODEL_PREFERENCES["devops_engineer"] == ModelGroup.FAST


class TestChatMessage:
    """Tests for ChatMessage model."""

    def test_chat_message_creation(self) -> None:
        """Test creating a ChatMessage."""
        msg = ChatMessage(role="user", content="Hello")
        assert msg.role == "user"
        assert msg.content == "Hello"
        assert msg.name is None
        assert msg.tool_calls is None

    def test_chat_message_with_optional(self) -> None:
        """Test ChatMessage with optional fields."""
        msg = ChatMessage(
            role="assistant",
            content="Response",
            name="assistant_1",
            tool_calls=[{"id": "call_1", "function": {"name": "test"}}],
        )
        assert msg.name == "assistant_1"
        assert msg.tool_calls is not None

    def test_chat_message_list_content(self) -> None:
        """Test ChatMessage with list content (for images)."""
        msg = ChatMessage(
            role="user",
            content=[
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_url", "image_url": {"url": "http://example.com/img.jpg"}},
            ],
        )
        assert isinstance(msg.content, list)


class TestCompletionRequest:
    """Tests for CompletionRequest model."""

    def test_completion_request_minimal(self) -> None:
        """Test minimal CompletionRequest."""
        req = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
        )

        assert req.project_id == "proj-123"
        assert req.agent_id == "agent-456"
        assert len(req.messages) == 1
        assert req.model_group == ModelGroup.REASONING  # default
        assert req.max_tokens == 4096  # default
        assert req.temperature == 0.7  # default

    def test_completion_request_full(self) -> None:
        """Test full CompletionRequest."""
        req = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
            model_group=ModelGroup.CODE,
            model_override="claude-sonnet-4",
            max_tokens=8192,
            temperature=0.5,
            stream=True,
            session_id="session-789",
            metadata={"key": "value"},
        )

        assert req.model_group == ModelGroup.CODE
        assert req.model_override == "claude-sonnet-4"
|
||||
assert req.max_tokens == 8192
|
||||
assert req.stream is True
|
||||
|
||||
def test_completion_request_validation(self) -> None:
|
||||
"""Test CompletionRequest validation."""
|
||||
with pytest.raises(ValueError):
|
||||
CompletionRequest(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
messages=[ChatMessage(role="user", content="Hi")],
|
||||
max_tokens=0, # Invalid
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
CompletionRequest(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
messages=[ChatMessage(role="user", content="Hi")],
|
||||
temperature=-0.1, # Invalid
|
||||
)
|
||||
|
||||
|
||||
class TestUsageStats:
|
||||
"""Tests for UsageStats model."""
|
||||
|
||||
def test_usage_stats_default(self) -> None:
|
||||
"""Test default UsageStats."""
|
||||
stats = UsageStats()
|
||||
assert stats.prompt_tokens == 0
|
||||
assert stats.completion_tokens == 0
|
||||
assert stats.total_tokens == 0
|
||||
assert stats.cost_usd == 0.0
|
||||
|
||||
def test_usage_stats_custom(self) -> None:
|
||||
"""Test custom UsageStats."""
|
||||
stats = UsageStats(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
total_tokens=150,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
assert stats.prompt_tokens == 100
|
||||
assert stats.total_tokens == 150
|
||||
|
||||
def test_usage_stats_from_response(self) -> None:
|
||||
"""Test creating UsageStats from response."""
|
||||
config = MODEL_CONFIGS["claude-opus-4"]
|
||||
stats = UsageStats.from_response(
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
model_config=config,
|
||||
)
|
||||
|
||||
assert stats.prompt_tokens == 1000
|
||||
assert stats.completion_tokens == 500
|
||||
assert stats.total_tokens == 1500
|
||||
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
|
||||
assert stats.cost_usd == pytest.approx(0.0525, rel=0.01)
|
||||
|
||||
|
||||
class TestCompletionResponse:
|
||||
"""Tests for CompletionResponse model."""
|
||||
|
||||
def test_completion_response_creation(self) -> None:
|
||||
"""Test creating a CompletionResponse."""
|
||||
response = CompletionResponse(
|
||||
id="resp-123",
|
||||
model="claude-opus-4",
|
||||
provider="anthropic",
|
||||
content="Hello, world!",
|
||||
)
|
||||
|
||||
assert response.id == "resp-123"
|
||||
assert response.model == "claude-opus-4"
|
||||
assert response.provider == "anthropic"
|
||||
assert response.content == "Hello, world!"
|
||||
assert response.finish_reason == "stop"
|
||||
|
||||
|
||||
class TestStreamChunk:
|
||||
"""Tests for StreamChunk model."""
|
||||
|
||||
def test_stream_chunk_creation(self) -> None:
|
||||
"""Test creating a StreamChunk."""
|
||||
chunk = StreamChunk(id="chunk-1", delta="Hello")
|
||||
assert chunk.id == "chunk-1"
|
||||
assert chunk.delta == "Hello"
|
||||
assert chunk.finish_reason is None
|
||||
|
||||
def test_stream_chunk_final(self) -> None:
|
||||
"""Test final StreamChunk."""
|
||||
chunk = StreamChunk(
|
||||
id="chunk-last",
|
||||
delta="",
|
||||
finish_reason="stop",
|
||||
usage=UsageStats(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
)
|
||||
assert chunk.finish_reason == "stop"
|
||||
assert chunk.usage is not None
|
||||
|
||||
|
||||
class TestEmbeddingRequest:
|
||||
"""Tests for EmbeddingRequest model."""
|
||||
|
||||
def test_embedding_request_creation(self) -> None:
|
||||
"""Test creating an EmbeddingRequest."""
|
||||
req = EmbeddingRequest(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
texts=["Hello", "World"],
|
||||
)
|
||||
|
||||
assert req.project_id == "proj-123"
|
||||
assert len(req.texts) == 2
|
||||
assert req.model == "text-embedding-3-large" # default
|
||||
|
||||
def test_embedding_request_validation(self) -> None:
|
||||
"""Test EmbeddingRequest validation."""
|
||||
with pytest.raises(ValueError):
|
||||
EmbeddingRequest(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
texts=[], # Invalid - must have at least 1
|
||||
)
|
||||
|
||||
|
||||
class TestCostRecord:
|
||||
"""Tests for CostRecord dataclass."""
|
||||
|
||||
def test_cost_record_creation(self) -> None:
|
||||
"""Test creating a CostRecord."""
|
||||
record = CostRecord(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
|
||||
assert record.project_id == "proj-123"
|
||||
assert record.cost_usd == 0.01
|
||||
assert record.timestamp is not None
|
||||
|
||||
|
||||
class TestUsageReport:
|
||||
"""Tests for UsageReport model."""
|
||||
|
||||
def test_usage_report_creation(self) -> None:
|
||||
"""Test creating a UsageReport."""
|
||||
now = datetime.now(UTC)
|
||||
report = UsageReport(
|
||||
entity_id="proj-123",
|
||||
entity_type="project",
|
||||
period="day",
|
||||
period_start=now,
|
||||
period_end=now,
|
||||
)
|
||||
|
||||
assert report.entity_id == "proj-123"
|
||||
assert report.entity_type == "project"
|
||||
assert report.total_requests == 0
|
||||
assert report.total_cost_usd == 0.0
|
||||
|
||||
|
||||
class TestModelInfo:
|
||||
"""Tests for ModelInfo model."""
|
||||
|
||||
def test_model_info_from_config(self) -> None:
|
||||
"""Test creating ModelInfo from ModelConfig."""
|
||||
config = MODEL_CONFIGS["claude-opus-4"]
|
||||
info = ModelInfo.from_config(config, available=True)
|
||||
|
||||
assert info.name == "claude-opus-4"
|
||||
assert info.provider == "anthropic"
|
||||
assert info.available is True
|
||||
assert info.supports_vision is True
|
||||
|
||||
|
||||
class TestModelGroupInfo:
|
||||
"""Tests for ModelGroupInfo model."""
|
||||
|
||||
def test_model_group_info_creation(self) -> None:
|
||||
"""Test creating ModelGroupInfo."""
|
||||
info = ModelGroupInfo(
|
||||
name="reasoning",
|
||||
description="Complex analysis",
|
||||
primary_model="claude-opus-4",
|
||||
fallback_models=["gpt-4.1"],
|
||||
)
|
||||
|
||||
assert info.name == "reasoning"
|
||||
assert len(info.fallback_models) == 1
|
||||
308
mcp-servers/llm-gateway/tests/test_providers.py
Normal file
@@ -0,0 +1,308 @@
"""
Tests for providers module.
"""

import os
from unittest.mock import patch

import pytest

from config import Settings
from models import MODEL_CONFIGS, ModelGroup, Provider
from providers import (
    LLMProvider,
    build_fallback_config,
    build_model_list,
    configure_litellm,
    get_available_model_groups,
    get_available_models,
    get_provider,
    reset_provider,
)


@pytest.fixture
def full_settings() -> Settings:
    """Settings with all providers configured."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        alibaba_api_key="test-alibaba-key",
        deepseek_api_key="test-deepseek-key",
        litellm_timeout=60,
        litellm_cache_enabled=False,
    )


@pytest.fixture
def partial_settings() -> Settings:
    """Settings with only some providers configured."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )


@pytest.fixture
def empty_settings() -> Settings:
    """Settings with no providers configured."""
    return Settings(
        anthropic_api_key=None,
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )


class TestConfigureLiteLLM:
    """Tests for configure_litellm function."""

    def test_sets_api_keys(self, full_settings: Settings) -> None:
        """Test that API keys are set in environment."""
        with patch.dict(os.environ, {}, clear=True):
            configure_litellm(full_settings)

            assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
            assert os.environ.get("OPENAI_API_KEY") == "test-openai-key"
            assert os.environ.get("GEMINI_API_KEY") == "test-google-key"

    def test_skips_none_keys(self, partial_settings: Settings) -> None:
        """Test that None keys are not set."""
        with patch.dict(os.environ, {}, clear=True):
            configure_litellm(partial_settings)

            assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
            assert "OPENAI_API_KEY" not in os.environ


class TestBuildModelList:
    """Tests for build_model_list function."""

    def test_build_with_all_providers(self, full_settings: Settings) -> None:
        """Test building model list with all providers."""
        model_list = build_model_list(full_settings)

        assert len(model_list) > 0

        # Check structure
        for entry in model_list:
            assert "model_name" in entry
            assert "litellm_params" in entry
            assert "model" in entry["litellm_params"]
            assert "timeout" in entry["litellm_params"]

    def test_build_with_partial_providers(self, partial_settings: Settings) -> None:
        """Test building model list with partial providers."""
        model_list = build_model_list(partial_settings)

        # Should only include Anthropic models
        providers = set()
        for entry in model_list:
            model_name = entry["model_name"]
            config = MODEL_CONFIGS.get(model_name)
            if config:
                providers.add(config.provider)

        assert Provider.ANTHROPIC in providers
        assert Provider.OPENAI not in providers

    def test_build_with_no_providers(self, empty_settings: Settings) -> None:
        """Test building model list with no providers."""
        model_list = build_model_list(empty_settings)

        assert len(model_list) == 0

    def test_build_includes_timeout(self, full_settings: Settings) -> None:
        """Test that model entries include timeout."""
        model_list = build_model_list(full_settings)

        for entry in model_list:
            assert entry["litellm_params"]["timeout"] == 60


class TestBuildFallbackConfig:
    """Tests for build_fallback_config function."""

    def test_build_fallbacks_full(self, full_settings: Settings) -> None:
        """Test building fallback config with all providers."""
        fallbacks = build_fallback_config(full_settings)

        assert len(fallbacks) > 0

        # Primary models should have fallbacks
        for _primary, chain in fallbacks.items():
            assert isinstance(chain, list)
            assert len(chain) > 0

    def test_build_fallbacks_partial(self, partial_settings: Settings) -> None:
        """Test building fallback config with partial providers."""
        fallbacks = build_fallback_config(partial_settings)

        # With only Anthropic, there should be no cross-provider fallbacks
        # (fallbacks require at least 2 available models)
        for primary, chain in fallbacks.items():
            # All models in chain should be from Anthropic
            for model in [primary] + chain:
                config = MODEL_CONFIGS.get(model)
                if config:
                    assert config.provider == Provider.ANTHROPIC


class TestGetAvailableModels:
    """Tests for get_available_models function."""

    def test_get_available_full(self, full_settings: Settings) -> None:
        """Test getting available models with all providers."""
        models = get_available_models(full_settings)

        assert len(models) > 0
        assert "claude-opus-4" in models
        assert "gpt-4.1" in models

    def test_get_available_partial(self, partial_settings: Settings) -> None:
        """Test getting available models with partial providers."""
        models = get_available_models(partial_settings)

        assert "claude-opus-4" in models
        assert "gpt-4.1" not in models

    def test_get_available_empty(self, empty_settings: Settings) -> None:
        """Test getting available models with no providers."""
        models = get_available_models(empty_settings)

        assert len(models) == 0


class TestGetAvailableModelGroups:
    """Tests for get_available_model_groups function."""

    def test_get_groups_full(self, full_settings: Settings) -> None:
        """Test getting groups with all providers."""
        groups = get_available_model_groups(full_settings)

        assert len(groups) == len(ModelGroup)
        assert ModelGroup.REASONING in groups
        assert len(groups[ModelGroup.REASONING]) > 0

    def test_get_groups_partial(self, partial_settings: Settings) -> None:
        """Test getting groups with partial providers."""
        groups = get_available_model_groups(partial_settings)

        # Only Anthropic models should be available
        for _group, models in groups.items():
            for model in models:
                config = MODEL_CONFIGS.get(model)
                if config:
                    assert config.provider == Provider.ANTHROPIC


class TestLLMProvider:
    """Tests for LLMProvider class."""

    def test_initialization(self, full_settings: Settings) -> None:
        """Test provider initialization."""
        provider = LLMProvider(settings=full_settings)

        assert provider._initialized is False
        assert provider._router is None

    def test_initialize(self, full_settings: Settings) -> None:
        """Test provider initialize."""
        with patch("providers.Router") as mock_router:
            provider = LLMProvider(settings=full_settings)
            provider.initialize()

            assert provider._initialized is True
            mock_router.assert_called_once()

    def test_initialize_idempotent(self, full_settings: Settings) -> None:
        """Test that initialize is idempotent."""
        with patch("providers.Router") as mock_router:
            provider = LLMProvider(settings=full_settings)
            provider.initialize()
            provider.initialize()

            # Should only be called once
            assert mock_router.call_count == 1

    def test_initialize_no_providers(self, empty_settings: Settings) -> None:
        """Test initialization with no providers."""
        provider = LLMProvider(settings=empty_settings)
        provider.initialize()

        assert provider._initialized is True
        assert provider._router is None

    def test_router_property(self, full_settings: Settings) -> None:
        """Test router property triggers initialization."""
        with patch("providers.Router"):
            provider = LLMProvider(settings=full_settings)
            _ = provider.router

            assert provider._initialized is True

    def test_is_available(self, full_settings: Settings) -> None:
        """Test is_available property."""
        with patch("providers.Router"):
            provider = LLMProvider(settings=full_settings)
            assert provider.is_available is True

    def test_is_not_available(self, empty_settings: Settings) -> None:
        """Test is_available when no providers."""
        provider = LLMProvider(settings=empty_settings)
        assert provider.is_available is False

    def test_get_model_config(self, full_settings: Settings) -> None:
        """Test getting model config."""
        provider = LLMProvider(settings=full_settings)

        config = provider.get_model_config("claude-opus-4")
        assert config is not None
        assert config.name == "claude-opus-4"

        assert provider.get_model_config("nonexistent") is None

    def test_get_available_models(self, full_settings: Settings) -> None:
        """Test getting available models."""
        provider = LLMProvider(settings=full_settings)
        models = provider.get_available_models()

        assert "claude-opus-4" in models
        assert "gpt-4.1" in models

    def test_is_model_available(self, full_settings: Settings) -> None:
        """Test checking model availability."""
        provider = LLMProvider(settings=full_settings)

        assert provider.is_model_available("claude-opus-4") is True
        assert provider.is_model_available("nonexistent") is False


class TestGlobalProvider:
    """Tests for global provider functions."""

    def test_get_provider(self) -> None:
        """Test getting global provider."""
        reset_provider()
        provider = get_provider()
        assert isinstance(provider, LLMProvider)

    def test_get_provider_singleton(self) -> None:
        """Test provider is singleton."""
        reset_provider()
        provider1 = get_provider()
        provider2 = get_provider()
        assert provider1 is provider2

    def test_reset_provider(self) -> None:
        """Test resetting global provider."""
        reset_provider()
        provider1 = get_provider()
        reset_provider()
        provider2 = get_provider()
        assert provider1 is not provider2
243
mcp-servers/llm-gateway/tests/test_routing.py
Normal file
@@ -0,0 +1,243 @@
"""
Tests for routing module.
"""

import asyncio

import pytest

from config import Settings
from exceptions import (
    AllProvidersFailedError,
    InvalidModelError,
    InvalidModelGroupError,
    ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, reset_circuit_registry
from models import ModelGroup
from providers import reset_provider
from routing import ModelRouter, get_model_router, reset_model_router


@pytest.fixture
def router_settings() -> Settings:
    """Settings for routing tests."""
    return Settings(
        anthropic_api_key="test-key",
        openai_api_key="test-key",
        google_api_key="test-key",
    )


@pytest.fixture
def router(router_settings: Settings) -> ModelRouter:
    """Create model router for testing."""
    reset_circuit_registry()
    reset_model_router()
    reset_provider()
    registry = CircuitBreakerRegistry(settings=router_settings)
    return ModelRouter(settings=router_settings, circuit_registry=registry)


class TestModelRouter:
    """Tests for ModelRouter class."""

    def test_parse_model_group_valid(self, router: ModelRouter) -> None:
        """Test parsing valid model groups."""
        assert router.parse_model_group("reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code") == ModelGroup.CODE
        assert router.parse_model_group("fast") == ModelGroup.FAST
        assert router.parse_model_group("REASONING") == ModelGroup.REASONING

    def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
        """Test parsing model group aliases."""
        assert router.parse_model_group("high-reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("high_reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code-generation") == ModelGroup.CODE
        assert router.parse_model_group("fast-response") == ModelGroup.FAST

    def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
        """Test parsing invalid model group."""
        with pytest.raises(InvalidModelGroupError) as exc_info:
            router.parse_model_group("invalid_group")

        assert exc_info.value.model_group == "invalid_group"
        assert exc_info.value.available_groups is not None

    def test_get_model_config_valid(self, router: ModelRouter) -> None:
        """Test getting valid model config."""
        config = router.get_model_config("claude-opus-4")
        assert config.name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_get_model_config_invalid(self, router: ModelRouter) -> None:
        """Test getting invalid model config."""
        with pytest.raises(InvalidModelError) as exc_info:
            router.get_model_config("nonexistent-model")

        assert exc_info.value.model == "nonexistent-model"

    def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for agent types."""
        assert router.get_preferred_group_for_agent("product_owner") == ModelGroup.REASONING
        assert router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
        assert router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST

    def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for unknown agent."""
        # Should default to REASONING
        assert router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING

    def test_select_model_by_group(self, router: ModelRouter) -> None:
        """Test selecting model by group."""
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )

        assert model_name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_select_model_by_group_string(self, router: ModelRouter) -> None:
        """Test selecting model by group string."""
        model_name, config = asyncio.run(
            router.select_model(model_group="code")
        )

        assert model_name == "claude-sonnet-4"

    def test_select_model_with_override(self, router: ModelRouter) -> None:
        """Test selecting specific model override."""
        model_name, config = asyncio.run(
            router.select_model(
                model_group="reasoning",
                model_override="gpt-4.1",
            )
        )

        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_override_invalid(self, router: ModelRouter) -> None:
        """Test selecting invalid model override."""
        with pytest.raises(InvalidModelError):
            asyncio.run(
                router.select_model(
                    model_group="reasoning",
                    model_override="nonexistent-model",
                )
            )

    def test_select_model_override_unavailable(self, router: ModelRouter) -> None:  # noqa: ARG002
        """Test selecting unavailable model override."""
        # Create router without Alibaba key
        settings = Settings(
            anthropic_api_key="test-key",
            alibaba_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)

        with pytest.raises(ModelNotAvailableError):
            asyncio.run(
                limited_router.select_model(
                    model_group="reasoning",
                    model_override="qwen-max",
                )
            )

    def test_select_model_fallback_on_circuit_open(
        self,
        router: ModelRouter,
    ) -> None:
        """Test fallback when primary circuit is open."""
        # Open circuit for anthropic
        circuit = router._circuit_registry.get_circuit_sync("anthropic")
        for _ in range(5):
            asyncio.run(circuit.record_failure())

        # Should fall back to OpenAI
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )

        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_all_unavailable(self) -> None:
        """Test when all providers are unavailable."""
        settings = Settings(
            anthropic_api_key=None,
            openai_api_key=None,
            google_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)

        with pytest.raises(AllProvidersFailedError) as exc_info:
            asyncio.run(
                limited_router.select_model(model_group=ModelGroup.REASONING)
            )

        assert exc_info.value.model_group == "reasoning"
        assert len(exc_info.value.attempted_models) > 0

    def test_get_available_models_for_group(self, router: ModelRouter) -> None:
        """Test getting available models for a group."""
        models = asyncio.run(
            router.get_available_models_for_group(ModelGroup.REASONING)
        )

        assert len(models) > 0
        # Should be (name, config, available) tuples
        for name, config, _available in models:
            assert isinstance(name, str)
            assert config is not None

    def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
        """Test getting available models with string group."""
        models = asyncio.run(
            router.get_available_models_for_group("code")
        )

        assert len(models) > 0

    def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
        """Test getting models for invalid group."""
        with pytest.raises(InvalidModelGroupError):
            asyncio.run(
                router.get_available_models_for_group("invalid")
            )

    def test_get_all_model_groups(self, router: ModelRouter) -> None:
        """Test getting all model groups info."""
        groups = router.get_all_model_groups()

        assert len(groups) == len(ModelGroup)
        assert "reasoning" in groups
        assert "code" in groups
        assert groups["reasoning"]["primary"] == "claude-opus-4"


class TestGlobalRouter:
    """Tests for global router functions."""

    def test_get_model_router(self) -> None:
        """Test getting global router."""
        reset_model_router()
        router = get_model_router()
        assert isinstance(router, ModelRouter)

    def test_get_model_router_singleton(self) -> None:
        """Test router is singleton."""
        reset_model_router()
        router1 = get_model_router()
        router2 = get_model_router()
        assert router1 is router2

    def test_reset_model_router(self) -> None:
        """Test resetting global router."""
        reset_model_router()
        router1 = get_model_router()
        reset_model_router()
        router2 = get_model_router()
        assert router1 is not router2
412
mcp-servers/llm-gateway/tests/test_server.py
Normal file
@@ -0,0 +1,412 @@
"""
Tests for server module.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from config import Settings
from models import ModelGroup


@pytest.fixture
def test_settings() -> Settings:
    """Test settings with mock API keys."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        cost_tracking_enabled=False,  # Disable for most tests
        litellm_cache_enabled=False,
    )


@pytest.fixture
def test_client(test_settings: Settings) -> TestClient:
    """Create test client with mocked dependencies."""
    with (
        patch("server.get_settings", return_value=test_settings),
        patch("server.get_provider") as mock_provider,
    ):
        mock_provider.return_value = MagicMock()
        mock_provider.return_value.is_available = True
        mock_provider.return_value.router = MagicMock()
        mock_provider.return_value.get_available_models.return_value = {}

        from server import app
        return TestClient(app)


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_check(self, test_client: TestClient) -> None:
        """Test health check returns healthy status."""
        with patch("server.get_settings") as mock_settings:
            mock_settings.return_value = Settings(
                anthropic_api_key="test-key",
            )
            with patch("server.get_provider") as mock_provider:
                mock_provider.return_value = MagicMock()
                mock_provider.return_value.is_available = True

                response = test_client.get("/health")

                assert response.status_code == 200
                data = response.json()
                assert data["status"] == "healthy"
                assert data["service"] == "llm-gateway"


class TestToolDiscoveryEndpoint:
    """Tests for tool discovery endpoint."""

    def test_list_tools(self, test_client: TestClient) -> None:
        """Test listing available tools."""
        response = test_client.get("/mcp/tools")

        assert response.status_code == 200
        data = response.json()
        assert "tools" in data
        assert len(data["tools"]) == 4  # 4 tools defined

        tool_names = [t["name"] for t in data["tools"]]
        assert "chat_completion" in tool_names
        assert "list_models" in tool_names
        assert "get_usage" in tool_names
        assert "count_tokens" in tool_names

    def test_tool_has_schema(self, test_client: TestClient) -> None:
        """Test that tools have input schemas."""
        response = test_client.get("/mcp/tools")
        data = response.json()

        for tool in data["tools"]:
            assert "inputSchema" in tool
            assert "type" in tool["inputSchema"]
            assert tool["inputSchema"]["type"] == "object"


class TestJSONRPCEndpoint:
    """Tests for JSON-RPC endpoint."""

    def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
        """Test invalid JSON-RPC version."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "1.0",  # Invalid
                "method": "tools/list",
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert data["error"]["code"] == -32600

    def test_tools_list(self, test_client: TestClient) -> None:
        """Test tools/list method."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/list",
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data
        assert "tools" in data["result"]

    def test_unknown_method(self, test_client: TestClient) -> None:
        """Test unknown method."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "unknown/method",
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert data["error"]["code"] == -32601

    def test_unknown_tool(self, test_client: TestClient) -> None:
        """Test unknown tool."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "unknown_tool",
                    "arguments": {},
                },
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert "Unknown tool" in data["error"]["message"]


class TestCountTokensTool:
    """Tests for count_tokens tool."""

    def test_count_tokens(self, test_client: TestClient) -> None:
        """Test counting tokens."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "count_tokens",
                    "arguments": {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "text": "Hello, world!",
                    },
                },
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_count_tokens_with_model(self, test_client: TestClient) -> None:
        """Test counting tokens with specific model."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "count_tokens",
                    "arguments": {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "text": "Hello, world!",
                        "model": "gpt-4",
                    },
                },
                "id": 1,
            },
        )

        assert response.status_code == 200


class TestListModelsTool:
    """Tests for list_models tool."""

    def test_list_all_models(self, test_client: TestClient) -> None:
        """Test listing all models."""
        with patch("server.get_model_router") as mock_router:
            mock_router.return_value = MagicMock()
            mock_router.return_value.get_available_models_for_group = AsyncMock(
                return_value=[]
            )

            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "list_models",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200

    def test_list_models_by_group(self, test_client: TestClient) -> None:
        """Test listing models by group."""
        with patch("server.get_model_router") as mock_router:
            mock_router.return_value = MagicMock()
            mock_router.return_value.parse_model_group.return_value = ModelGroup.REASONING
            mock_router.return_value.get_available_models_for_group = AsyncMock(
                return_value=[]
            )

            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "list_models",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                            "model_group": "reasoning",
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200


class TestGetUsageTool:
    """Tests for get_usage tool."""

    def test_get_usage(self, test_client: TestClient) -> None:
        """Test getting usage."""
        with patch("server.get_cost_tracker") as mock_tracker:
            mock_report = MagicMock()
            mock_report.total_requests = 10
            mock_report.total_tokens = 1000
            mock_report.total_cost_usd = 0.50
            mock_report.by_model = {}
            mock_report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
            mock_report.period_end.isoformat.return_value = "2024-01-02T00:00:00"

            mock_tracker.return_value = MagicMock()
            mock_tracker.return_value.get_project_usage = AsyncMock(
                return_value=mock_report
            )
            mock_tracker.return_value.get_agent_usage = AsyncMock(
                return_value=mock_report
            )

            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "get_usage",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                            "period": "day",
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200


class TestChatCompletionTool:
    """Tests for chat_completion tool."""

    def test_chat_completion_streaming_not_supported(
        self,
        test_client: TestClient,
    ) -> None:
        """Test that streaming returns info message."""
        with patch("server.get_model_router") as mock_router:
            mock_router.return_value = MagicMock()
            mock_router.return_value.select_model = AsyncMock(
                return_value=("claude-opus-4", MagicMock())
            )

            with patch("server.get_cost_tracker") as mock_tracker:
                mock_tracker.return_value = MagicMock()
                mock_tracker.return_value.check_budget = AsyncMock(
                    return_value=(True, 0.0, 100.0)
                )

                response = test_client.post(
                    "/mcp",
                    json={
                        "jsonrpc": "2.0",
                        "method": "tools/call",
                        "params": {
                            "name": "chat_completion",
                            "arguments": {
                                "project_id": "proj-123",
                                "agent_id": "agent-456",
                                "messages": [
                                    {"role": "user", "content": "Hello"}
                                ],
                                "stream": True,
                            },
                        },
                        "id": 1,
                    },
                )

                assert response.status_code == 200

    def test_chat_completion_success(self, test_client: TestClient) -> None:
        """Test successful chat completion."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello, world!"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        with patch("server.get_model_router") as mock_router:
            mock_model_config = MagicMock()
            mock_model_config.provider.value = "anthropic"
            mock_router.return_value = MagicMock()
            mock_router.return_value.select_model = AsyncMock(
                return_value=("claude-opus-4", mock_model_config)
            )

            with patch("server.get_cost_tracker") as mock_tracker:
                mock_tracker.return_value = MagicMock()
                mock_tracker.return_value.check_budget = AsyncMock(
                    return_value=(True, 0.0, 100.0)
                )
                mock_tracker.return_value.record_usage = AsyncMock()

                with patch("server.get_provider") as mock_prov:
                    mock_prov.return_value = MagicMock()
                    mock_prov.return_value.router = MagicMock()
                    mock_prov.return_value.router.acompletion = AsyncMock(
                        return_value=mock_response
                    )

                    with patch("server.get_circuit_registry") as mock_reg:
                        mock_circuit = MagicMock()
                        mock_circuit.record_success = AsyncMock()
                        mock_reg.return_value = MagicMock()
                        mock_reg.return_value.get_circuit_sync.return_value = mock_circuit

                        response = test_client.post(
                            "/mcp",
                            json={
                                "jsonrpc": "2.0",
                                "method": "tools/call",
                                "params": {
                                    "name": "chat_completion",
                                    "arguments": {
                                        "project_id": "proj-123",
                                        "agent_id": "agent-456",
                                        "messages": [
                                            {"role": "user", "content": "Hello"}
                                        ],
                                    },
                                },
                                "id": 1,
                            },
                        )

                        assert response.status_code == 200
312
mcp-servers/llm-gateway/tests/test_streaming.py
Normal file
@@ -0,0 +1,312 @@
"""
Tests for streaming module.
"""

import asyncio
import json
from unittest.mock import MagicMock

import pytest

from models import StreamChunk, UsageStats
from streaming import (
    StreamAccumulator,
    StreamBuffer,
    format_sse_chunk,
    format_sse_done,
    format_sse_error,
    stream_to_string,
    wrap_litellm_stream,
)


class TestStreamAccumulator:
    """Tests for StreamAccumulator class."""

    def test_initial_state(self) -> None:
        """Test initial accumulator state."""
        acc = StreamAccumulator()

        assert acc.request_id is not None
        assert acc.content == ""
        assert acc.chunks_received == 0
        assert acc.prompt_tokens == 0
        assert acc.completion_tokens == 0
        assert acc.model is None
        assert acc.finish_reason is None

    def test_custom_request_id(self) -> None:
        """Test accumulator with custom request ID."""
        acc = StreamAccumulator(request_id="custom-id")
        assert acc.request_id == "custom-id"

    def test_add_chunk_text(self) -> None:
        """Test adding text chunks."""
        acc = StreamAccumulator()

        acc.add_chunk("Hello")
        acc.add_chunk(", ")
        acc.add_chunk("world!")

        assert acc.content == "Hello, world!"
        assert acc.chunks_received == 3

    def test_add_chunk_with_finish_reason(self) -> None:
        """Test adding chunk with finish reason."""
        acc = StreamAccumulator()

        acc.add_chunk("Final", finish_reason="stop")

        assert acc.finish_reason == "stop"

    def test_add_chunk_with_model(self) -> None:
        """Test adding chunk with model info."""
        acc = StreamAccumulator()

        acc.add_chunk("Text", model="claude-opus-4")

        assert acc.model == "claude-opus-4"

    def test_add_chunk_with_usage(self) -> None:
        """Test adding chunk with usage stats."""
        acc = StreamAccumulator()

        acc.add_chunk(
            "Text",
            usage={"prompt_tokens": 10, "completion_tokens": 5},
        )

        assert acc.prompt_tokens == 10
        assert acc.completion_tokens == 5
        assert acc.total_tokens == 15

    def test_start_and_finish(self) -> None:
        """Test start and finish timing."""
        acc = StreamAccumulator()

        assert acc.duration_seconds is None

        acc.start()
        acc.finish()

        assert acc.duration_seconds is not None
        assert acc.duration_seconds >= 0

    def test_get_usage_stats(self) -> None:
        """Test getting usage stats."""
        acc = StreamAccumulator()
        acc.add_chunk("", usage={"prompt_tokens": 100, "completion_tokens": 50})

        stats = acc.get_usage_stats(cost_usd=0.01)

        assert stats.prompt_tokens == 100
        assert stats.completion_tokens == 50
        assert stats.total_tokens == 150
        assert stats.cost_usd == 0.01


class TestWrapLiteLLMStream:
    """Tests for wrap_litellm_stream function."""

    async def test_wrap_stream_basic(self) -> None:
        """Test wrapping a basic stream."""
        # Create mock stream chunks
        async def mock_stream():
            chunk1 = MagicMock()
            chunk1.choices = [MagicMock()]
            chunk1.choices[0].delta = MagicMock()
            chunk1.choices[0].delta.content = "Hello"
            chunk1.choices[0].finish_reason = None
            chunk1.model = "test-model"
            chunk1.usage = None
            yield chunk1

            chunk2 = MagicMock()
            chunk2.choices = [MagicMock()]
            chunk2.choices[0].delta = MagicMock()
            chunk2.choices[0].delta.content = " World"
            chunk2.choices[0].finish_reason = "stop"
            chunk2.model = "test-model"
            chunk2.usage = MagicMock()
            chunk2.usage.prompt_tokens = 5
            chunk2.usage.completion_tokens = 2
            yield chunk2

        accumulator = StreamAccumulator()
        chunks = []

        async for chunk in wrap_litellm_stream(mock_stream(), accumulator):
            chunks.append(chunk)

        assert len(chunks) == 2
        assert chunks[0].delta == "Hello"
        assert chunks[1].delta == " World"
        assert chunks[1].finish_reason == "stop"
        assert accumulator.content == "Hello World"

    async def test_wrap_stream_without_accumulator(self) -> None:
        """Test wrapping stream without accumulator."""
        async def mock_stream():
            chunk = MagicMock()
            chunk.choices = [MagicMock()]
            chunk.choices[0].delta = MagicMock()
            chunk.choices[0].delta.content = "Test"
            chunk.choices[0].finish_reason = None
            chunk.model = None
            chunk.usage = None
            yield chunk

        chunks = []
        async for chunk in wrap_litellm_stream(mock_stream()):
            chunks.append(chunk)

        assert len(chunks) == 1


class TestSSEFormatting:
    """Tests for SSE formatting functions."""

    def test_format_sse_chunk_basic(self) -> None:
        """Test formatting basic chunk."""
        chunk = StreamChunk(id="chunk-1", delta="Hello")
        result = format_sse_chunk(chunk)

        assert result.startswith("data: ")
        assert result.endswith("\n\n")

        # Parse the JSON
        data = json.loads(result[6:-2])
        assert data["id"] == "chunk-1"
        assert data["delta"] == "Hello"

    def test_format_sse_chunk_with_finish(self) -> None:
        """Test formatting chunk with finish reason."""
        chunk = StreamChunk(
            id="chunk-2",
            delta="",
            finish_reason="stop",
        )
        result = format_sse_chunk(chunk)
        data = json.loads(result[6:-2])

        assert data["finish_reason"] == "stop"

    def test_format_sse_chunk_with_usage(self) -> None:
        """Test formatting chunk with usage."""
        chunk = StreamChunk(
            id="chunk-3",
            delta="",
            finish_reason="stop",
            usage=UsageStats(
                prompt_tokens=10,
                completion_tokens=5,
                total_tokens=15,
                cost_usd=0.001,
            ),
        )
        result = format_sse_chunk(chunk)
        data = json.loads(result[6:-2])

        assert "usage" in data
        assert data["usage"]["prompt_tokens"] == 10

    def test_format_sse_done(self) -> None:
        """Test formatting done message."""
        result = format_sse_done()
        assert result == "data: [DONE]\n\n"

    def test_format_sse_error(self) -> None:
        """Test formatting error message."""
        result = format_sse_error("Something went wrong", code="ERROR_CODE")
        data = json.loads(result[6:-2])

        assert data["error"] == "Something went wrong"
        assert data["code"] == "ERROR_CODE"

    def test_format_sse_error_without_code(self) -> None:
        """Test formatting error without code."""
        result = format_sse_error("Error message")
        data = json.loads(result[6:-2])

        assert data["error"] == "Error message"
        assert "code" not in data


class TestStreamBuffer:
    """Tests for StreamBuffer class."""

    async def test_buffer_basic(self) -> None:
        """Test basic buffer operations."""
        buffer = StreamBuffer(max_size=10)

        # Producer
        async def produce():
            await buffer.put(StreamChunk(id="1", delta="Hello"))
            await buffer.put(StreamChunk(id="2", delta=" World"))
            await buffer.done()

        # Consumer
        chunks = []
        asyncio.create_task(produce())

        async for chunk in buffer:
            chunks.append(chunk)

        assert len(chunks) == 2
        assert chunks[0].delta == "Hello"
        assert chunks[1].delta == " World"

    async def test_buffer_error(self) -> None:
        """Test buffer with error."""
        buffer = StreamBuffer()

        async def produce():
            await buffer.put(StreamChunk(id="1", delta="Hello"))
            await buffer.error(ValueError("Test error"))

        asyncio.create_task(produce())

        with pytest.raises(ValueError, match="Test error"):
            async for _ in buffer:
                pass

    async def test_buffer_put_after_done(self) -> None:
        """Test putting after done raises."""
        buffer = StreamBuffer()
        await buffer.done()

        with pytest.raises(RuntimeError, match="closed"):
            await buffer.put(StreamChunk(id="1", delta="Test"))


class TestStreamToString:
    """Tests for stream_to_string function."""

    async def test_stream_to_string_basic(self) -> None:
        """Test converting stream to string."""
        async def mock_stream():
            yield StreamChunk(id="1", delta="Hello")
            yield StreamChunk(id="2", delta=" ")
            yield StreamChunk(id="3", delta="World")
            yield StreamChunk(
                id="4",
                delta="",
                finish_reason="stop",
                usage=UsageStats(prompt_tokens=5, completion_tokens=3),
            )

        content, usage = await stream_to_string(mock_stream())

        assert content == "Hello World"
        assert usage is not None
        assert usage.prompt_tokens == 5

    async def test_stream_to_string_no_usage(self) -> None:
        """Test stream without usage stats."""
        async def mock_stream():
            yield StreamChunk(id="1", delta="Test")

        content, usage = await stream_to_string(mock_stream())

        assert content == "Test"
        assert usage is None