feat(llm-gateway): implement LLM Gateway MCP Server (#56)

Implements complete LLM Gateway MCP Server with:
- FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens
- LiteLLM Router with multi-provider failover chains
- Circuit breaker pattern for fault tolerance
- Redis-based cost tracking per project/agent
- Comprehensive test suite (209 tests, 92% coverage)

Model groups defined per ADR-004:
- reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro
- code: claude-sonnet-4 → gpt-4.1 → deepseek-coder
- fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 20:31:19 +01:00
parent 746fb7b181
commit 6e8b0b022a
23 changed files with 9794 additions and 93 deletions

View File

@@ -0,0 +1,412 @@
"""
Tests for server module.
"""
from collections.abc import Iterator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from config import Settings
from models import ModelGroup
@pytest.fixture
def test_settings() -> Settings:
    """Build a ``Settings`` object filled with placeholder provider API keys.

    Cost tracking and the LiteLLM cache are both switched off in the
    returned settings.
    """
    return Settings(
        # Feature toggles.
        cost_tracking_enabled=False,  # Disable for most tests
        litellm_cache_enabled=False,
        # Placeholder credentials, one per provider.
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
    )
@pytest.fixture
def test_client(test_settings: Settings) -> Iterator[TestClient]:
    """Yield a test client for the server app while its dependencies are mocked.

    ``server.get_settings`` is patched to return ``test_settings`` and
    ``server.get_provider`` is patched to return a mock provider that reports
    itself as available, exposes a mock router, and lists no models.

    The client is yielded rather than returned: a ``return`` would leave the
    ``with`` block and undo both patches before the requesting test sends its
    first request, so the test would run against the unpatched functions.

    Args:
        test_settings: Settings object handed back by the patched
            ``server.get_settings``.

    Yields:
        A ``TestClient`` wrapping ``server.app``; the patches stay active
        until the requesting test has finished.
    """
    with (
        patch("server.get_settings", return_value=test_settings),
        patch("server.get_provider") as mock_provider,
    ):
        provider = MagicMock()
        provider.is_available = True
        provider.router = MagicMock()
        provider.get_available_models.return_value = {}
        mock_provider.return_value = provider

        # Import kept local to the fixture, as in the original code.
        from server import app

        yield TestClient(app)
class TestHealthEndpoint:
    """Health check endpoint tests."""

    def test_health_check(self, test_client: TestClient) -> None:
        """GET /health answers 200 with a healthy ``llm-gateway`` payload."""
        available_provider = MagicMock()
        available_provider.is_available = True

        with (
            patch("server.get_settings") as settings_patch,
            patch("server.get_provider") as provider_patch,
        ):
            settings_patch.return_value = Settings(anthropic_api_key="test-key")
            provider_patch.return_value = available_provider
            response = test_client.get("/health")

        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "healthy"
        assert payload["service"] == "llm-gateway"
class TestToolDiscoveryEndpoint:
    """Tool discovery endpoint tests."""

    def test_list_tools(self, test_client: TestClient) -> None:
        """GET /mcp/tools lists exactly four tools, including the expected names."""
        response = test_client.get("/mcp/tools")
        assert response.status_code == 200

        body = response.json()
        assert "tools" in body
        assert len(body["tools"]) == 4  # 4 tools defined

        names = [entry["name"] for entry in body["tools"]]
        for expected in ("chat_completion", "list_models", "get_usage", "count_tokens"):
            assert expected in names

    def test_tool_has_schema(self, test_client: TestClient) -> None:
        """Every listed tool carries an ``inputSchema`` of type ``object``."""
        body = test_client.get("/mcp/tools").json()
        for entry in body["tools"]:
            assert "inputSchema" in entry
            schema = entry["inputSchema"]
            assert "type" in schema
            assert schema["type"] == "object"
class TestJSONRPCEndpoint:
    """JSON-RPC endpoint tests."""

    def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
        """A request declaring version 1.0 is answered with error code -32600."""
        rpc_request = {
            "jsonrpc": "1.0",  # Invalid
            "method": "tools/list",
            "id": 1,
        }
        response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        assert body["error"]["code"] == -32600

    def test_tools_list(self, test_client: TestClient) -> None:
        """The tools/list method answers with a result containing ``tools``."""
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/list",
            "id": 1,
        }
        response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
        body = response.json()
        assert "result" in body
        assert "tools" in body["result"]

    def test_unknown_method(self, test_client: TestClient) -> None:
        """An unrecognised method is answered with error code -32601."""
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "unknown/method",
            "id": 1,
        }
        response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        assert body["error"]["code"] == -32601

    def test_unknown_tool(self, test_client: TestClient) -> None:
        """Calling a tool that does not exist yields an "Unknown tool" error."""
        call_params = {
            "name": "unknown_tool",
            "arguments": {},
        }
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": call_params,
            "id": 1,
        }
        response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        assert "Unknown tool" in body["error"]["message"]
class TestCountTokensTool:
    """count_tokens tool tests."""

    def test_count_tokens(self, test_client: TestClient) -> None:
        """A count_tokens call without a model answers 200 with a result."""
        tool_arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "text": "Hello, world!",
        }
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "count_tokens", "arguments": tool_arguments},
            "id": 1,
        }
        response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
        body = response.json()
        assert "result" in body

    def test_count_tokens_with_model(self, test_client: TestClient) -> None:
        """A count_tokens call naming the ``gpt-4`` model answers 200."""
        tool_arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "text": "Hello, world!",
            "model": "gpt-4",
        }
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "count_tokens", "arguments": tool_arguments},
            "id": 1,
        }
        response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
class TestListModelsTool:
    """list_models tool tests."""

    def test_list_all_models(self, test_client: TestClient) -> None:
        """A list_models call without a group filter answers 200."""
        router = MagicMock()
        router.get_available_models_for_group = AsyncMock(return_value=[])

        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "list_models",
                "arguments": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                },
            },
            "id": 1,
        }
        with patch("server.get_model_router", return_value=router):
            response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200

    def test_list_models_by_group(self, test_client: TestClient) -> None:
        """A list_models call filtered to the ``reasoning`` group answers 200."""
        router = MagicMock()
        router.parse_model_group.return_value = ModelGroup.REASONING
        router.get_available_models_for_group = AsyncMock(return_value=[])

        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "list_models",
                "arguments": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "model_group": "reasoning",
                },
            },
            "id": 1,
        }
        with patch("server.get_model_router", return_value=router):
            response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
class TestGetUsageTool:
    """get_usage tool tests."""

    def test_get_usage(self, test_client: TestClient) -> None:
        """A get_usage call for the ``day`` period answers 200."""
        # Canned usage report served for both project and agent lookups.
        usage_report = MagicMock()
        usage_report.total_requests = 10
        usage_report.total_tokens = 1000
        usage_report.total_cost_usd = 0.50
        usage_report.by_model = {}
        usage_report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
        usage_report.period_end.isoformat.return_value = "2024-01-02T00:00:00"

        tracker = MagicMock()
        tracker.get_project_usage = AsyncMock(return_value=usage_report)
        tracker.get_agent_usage = AsyncMock(return_value=usage_report)

        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "get_usage",
                "arguments": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "period": "day",
                },
            },
            "id": 1,
        }
        with patch("server.get_cost_tracker", return_value=tracker):
            response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200
class TestChatCompletionTool:
    """chat_completion tool tests."""

    def test_chat_completion_streaming_not_supported(
        self,
        test_client: TestClient,
    ) -> None:
        """A chat_completion call with ``stream`` set to True answers 200."""
        router = MagicMock()
        router.select_model = AsyncMock(return_value=("claude-opus-4", MagicMock()))

        tracker = MagicMock()
        tracker.check_budget = AsyncMock(return_value=(True, 0.0, 100.0))

        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "chat_completion",
                "arguments": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "stream": True,
                },
            },
            "id": 1,
        }
        with (
            patch("server.get_model_router", return_value=router),
            patch("server.get_cost_tracker", return_value=tracker),
        ):
            response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200

    def test_chat_completion_success(self, test_client: TestClient) -> None:
        """A non-streaming chat_completion call with every collaborator mocked answers 200."""
        # Canned completion handed back by the mocked provider router.
        completion = MagicMock()
        completion.choices = [MagicMock()]
        completion.choices[0].message.content = "Hello, world!"
        completion.choices[0].finish_reason = "stop"
        completion.usage = MagicMock()
        completion.usage.prompt_tokens = 10
        completion.usage.completion_tokens = 5

        model_config = MagicMock()
        model_config.provider.value = "anthropic"
        router = MagicMock()
        router.select_model = AsyncMock(return_value=("claude-opus-4", model_config))

        tracker = MagicMock()
        tracker.check_budget = AsyncMock(return_value=(True, 0.0, 100.0))
        tracker.record_usage = AsyncMock()

        provider = MagicMock()
        provider.router = MagicMock()
        provider.router.acompletion = AsyncMock(return_value=completion)

        circuit = MagicMock()
        circuit.record_success = AsyncMock()
        registry = MagicMock()
        registry.get_circuit_sync.return_value = circuit

        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "chat_completion",
                "arguments": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "messages": [{"role": "user", "content": "Hello"}],
                },
            },
            "id": 1,
        }
        with (
            patch("server.get_model_router", return_value=router),
            patch("server.get_cost_tracker", return_value=tracker),
            patch("server.get_provider", return_value=provider),
            patch("server.get_circuit_registry", return_value=registry),
        ):
            response = test_client.post("/mcp", json=rpc_request)

        assert response.status_code == 200