Files
syndarix/mcp-servers/llm-gateway/tests/test_server.py
Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00

416 lines
14 KiB
Python

"""
Tests for server module.
"""
from collections.abc import Iterator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from config import Settings
from models import ModelGroup
@pytest.fixture
def test_settings() -> Settings:
    """Build a ``Settings`` instance populated with placeholder API keys.

    Cost tracking and the LiteLLM cache are both switched off in the
    returned settings.
    """
    settings = Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        cost_tracking_enabled=False,  # most tests do not exercise cost tracking
        litellm_cache_enabled=False,
    )
    return settings
@pytest.fixture
def test_client(test_settings: Settings) -> Iterator[TestClient]:
    """Yield a test client whose server dependencies stay mocked during the test.

    ``server.get_settings`` is patched to return ``test_settings`` and
    ``server.get_provider`` to return a mock provider that reports itself
    as available and exposes no models.

    The client is yielded from inside the ``with`` block instead of being
    returned: a ``return`` would leave the block and undo both patches
    before the test body ever sends a request.

    Args:
        test_settings: Settings handed back by the patched ``get_settings``.

    Yields:
        A ``TestClient`` bound to the server's FastAPI ``app``.
    """
    with (
        patch("server.get_settings", return_value=test_settings),
        patch("server.get_provider") as mock_provider,
    ):
        provider = MagicMock()
        provider.is_available = True
        provider.router = MagicMock()
        provider.get_available_models.return_value = {}
        mock_provider.return_value = provider

        # NOTE(review): presumably imported here (not at module top) so the
        # import happens while the patches are in place -- confirm.
        from server import app

        # Patches are reverted only after the test using this fixture finishes.
        yield TestClient(app)
class TestHealthEndpoint:
    """Health check endpoint tests."""

    def test_health_check(self, test_client: TestClient) -> None:
        """GET /health answers 200 and reports a healthy llm-gateway."""
        with (
            patch("server.get_settings") as mock_settings,
            patch("server.get_provider") as mock_provider,
        ):
            mock_settings.return_value = Settings(
                anthropic_api_key="test-key",
            )

            # Provider mock that reports itself as available.
            available_provider = MagicMock()
            available_provider.is_available = True
            mock_provider.return_value = available_provider

            response = test_client.get("/health")

        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "healthy"
        assert payload["service"] == "llm-gateway"
class TestToolDiscoveryEndpoint:
    """Tool discovery endpoint tests."""

    def test_list_tools(self, test_client: TestClient) -> None:
        """GET /mcp/tools returns four tools, including every expected name."""
        response = test_client.get("/mcp/tools")
        assert response.status_code == 200

        body = response.json()
        assert "tools" in body
        tools = body["tools"]
        assert len(tools) == 4  # 4 tools defined

        names = [tool["name"] for tool in tools]
        expected_names = (
            "chat_completion",
            "list_models",
            "get_usage",
            "count_tokens",
        )
        for expected in expected_names:
            assert expected in names

    def test_tool_has_schema(self, test_client: TestClient) -> None:
        """Every listed tool carries an ``inputSchema`` of type ``object``."""
        body = test_client.get("/mcp/tools").json()
        for tool in body["tools"]:
            assert "inputSchema" in tool
            schema = tool["inputSchema"]
            assert "type" in schema
            assert schema["type"] == "object"
class TestJSONRPCEndpoint:
    """JSON-RPC endpoint tests.

    Each test posts to ``/mcp`` and expects HTTP 200; the error tests look
    for the failure in the ``error`` member of the response body.
    """

    def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
        """A request with ``jsonrpc`` set to 1.0 is rejected with -32600."""
        payload = {
            "jsonrpc": "1.0",  # Invalid
            "method": "tools/list",
            "id": 1,
        }
        response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        # -32600 is the JSON-RPC 2.0 "Invalid Request" code.
        assert body["error"]["code"] == -32600

    def test_tools_list(self, test_client: TestClient) -> None:
        """``tools/list`` answers with a result that contains ``tools``."""
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/list",
            "id": 1,
        }
        response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert "result" in body
        assert "tools" in body["result"]

    def test_unknown_method(self, test_client: TestClient) -> None:
        """An unrecognised method name is rejected with -32601."""
        payload = {
            "jsonrpc": "2.0",
            "method": "unknown/method",
            "id": 1,
        }
        response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        # -32601 is the JSON-RPC 2.0 "Method not found" code.
        assert body["error"]["code"] == -32601

    def test_unknown_tool(self, test_client: TestClient) -> None:
        """Calling a tool that does not exist yields an "Unknown tool" error."""
        call_params = {
            "name": "unknown_tool",
            "arguments": {},
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": call_params,
            "id": 1,
        }
        response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert "error" in body
        assert "Unknown tool" in body["error"]["message"]
class TestCountTokensTool:
    """count_tokens tool tests."""

    def test_count_tokens(self, test_client: TestClient) -> None:
        """Counting tokens for a short text produces a JSON-RPC result."""
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "text": "Hello, world!",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "count_tokens", "arguments": arguments},
            "id": 1,
        }
        response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert "result" in body

    def test_count_tokens_with_model(self, test_client: TestClient) -> None:
        """Counting tokens with an explicit model is answered with HTTP 200."""
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "text": "Hello, world!",
            "model": "gpt-4",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "count_tokens", "arguments": arguments},
            "id": 1,
        }
        response = test_client.post("/mcp", json=payload)

        # NOTE(review): only the HTTP status is checked here; the body is
        # not inspected for a JSON-RPC "result" or "error" member.
        assert response.status_code == 200
class TestListModelsTool:
    """list_models tool tests."""

    def test_list_all_models(self, test_client: TestClient) -> None:
        """Listing models without a group filter is answered with HTTP 200."""
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "list_models", "arguments": arguments},
            "id": 1,
        }

        with patch("server.get_model_router") as mock_router:
            # Router mock whose async group lookup returns no models.
            router = MagicMock()
            router.get_available_models_for_group = AsyncMock(return_value=[])
            mock_router.return_value = router

            response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200

    def test_list_models_by_group(self, test_client: TestClient) -> None:
        """Listing models for the "reasoning" group is answered with HTTP 200."""
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "model_group": "reasoning",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "list_models", "arguments": arguments},
            "id": 1,
        }

        with patch("server.get_model_router") as mock_router:
            # Router mock that parses any group to REASONING and returns
            # no models for it.
            router = MagicMock()
            router.parse_model_group.return_value = ModelGroup.REASONING
            router.get_available_models_for_group = AsyncMock(return_value=[])
            mock_router.return_value = router

            response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
class TestGetUsageTool:
    """get_usage tool tests."""

    def test_get_usage(self, test_client: TestClient) -> None:
        """Requesting daily usage is answered with HTTP 200."""
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "period": "day",
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "get_usage", "arguments": arguments},
            "id": 1,
        }

        with patch("server.get_cost_tracker") as mock_tracker:
            # Canned usage report served for both project and agent lookups.
            report = MagicMock()
            report.total_requests = 10
            report.total_tokens = 1000
            report.total_cost_usd = 0.50
            report.by_model = {}
            report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
            report.period_end.isoformat.return_value = "2024-01-02T00:00:00"

            tracker = MagicMock()
            tracker.get_project_usage = AsyncMock(return_value=report)
            tracker.get_agent_usage = AsyncMock(return_value=report)
            mock_tracker.return_value = tracker

            response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200
class TestChatCompletionTool:
    """chat_completion tool tests."""

    def test_chat_completion_streaming_not_supported(
        self,
        test_client: TestClient,
    ) -> None:
        """A chat_completion call with ``stream`` set is answered with HTTP 200."""
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "chat_completion", "arguments": arguments},
            "id": 1,
        }

        with (
            patch("server.get_model_router") as mock_router,
            patch("server.get_cost_tracker") as mock_tracker,
        ):
            router = MagicMock()
            router.select_model = AsyncMock(
                return_value=("claude-opus-4", MagicMock())
            )
            mock_router.return_value = router

            tracker = MagicMock()
            tracker.check_budget = AsyncMock(return_value=(True, 0.0, 100.0))
            mock_tracker.return_value = tracker

            response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200

    def test_chat_completion_success(self, test_client: TestClient) -> None:
        """A non-streaming chat_completion call is answered with HTTP 200."""
        # Completion object handed back by the mocked provider router.
        choice = MagicMock()
        choice.message.content = "Hello, world!"
        choice.finish_reason = "stop"
        usage = MagicMock()
        usage.prompt_tokens = 10
        usage.completion_tokens = 5
        mock_response = MagicMock()
        mock_response.choices = [choice]
        mock_response.usage = usage

        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "messages": [{"role": "user", "content": "Hello"}],
        }
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "chat_completion", "arguments": arguments},
            "id": 1,
        }

        with (
            patch("server.get_model_router") as mock_router,
            patch("server.get_cost_tracker") as mock_tracker,
            patch("server.get_provider") as mock_prov,
            patch("server.get_circuit_registry") as mock_reg,
        ):
            # Model selection resolves to an Anthropic-backed model config.
            model_config = MagicMock()
            model_config.provider.value = "anthropic"
            router = MagicMock()
            router.select_model = AsyncMock(
                return_value=("claude-opus-4", model_config)
            )
            mock_router.return_value = router

            tracker = MagicMock()
            tracker.check_budget = AsyncMock(return_value=(True, 0.0, 100.0))
            tracker.record_usage = AsyncMock()
            mock_tracker.return_value = tracker

            # Provider whose router completes with the canned response.
            provider = MagicMock()
            provider.router = MagicMock()
            provider.router.acompletion = AsyncMock(return_value=mock_response)
            mock_prov.return_value = provider

            circuit = MagicMock()
            circuit.record_success = AsyncMock()
            registry = MagicMock()
            registry.get_circuit_sync.return_value = circuit
            mock_reg.return_value = registry

            response = test_client.post("/mcp", json=payload)

        assert response.status_code == 200