forked from cardosofelipe/fast-next-template
- Add type annotations for mypy compliance - Use UTC-aware datetimes consistently (datetime.now(UTC)) - Add type: ignore comments for LiteLLM incomplete stubs - Fix import ordering and formatting - Update pyproject.toml mypy configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
416 lines
14 KiB
Python
416 lines
14 KiB
Python
"""
|
|
Tests for server module.
|
|
"""
|
|
|
|
from collections.abc import Iterator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from config import Settings
from models import ModelGroup
|
|
|
|
|
|
@pytest.fixture
def test_settings() -> Settings:
    """Provide a ``Settings`` object filled with fake provider API keys.

    Cost tracking and the LiteLLM cache are both switched off for the
    tests that use this fixture.
    """
    settings = Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        # Most tests do not exercise cost tracking, so keep it disabled.
        cost_tracking_enabled=False,
        litellm_cache_enabled=False,
    )
    return settings
|
|
|
|
|
|
@pytest.fixture
def test_client(test_settings: Settings) -> Iterator[TestClient]:
    """Yield a ``TestClient`` for ``server.app`` with settings/provider mocked.

    ``server.get_settings`` is patched to return ``test_settings``.
    ``server.get_provider`` is patched to return a stub provider with these
    properties:

    * it reports ``is_available = True``;
    * it exposes a ``MagicMock`` router;
    * ``get_available_models()`` returns an empty dict.

    The client is yielded *inside* the ``with`` block so both patches stay
    active for the whole test. Returning the client instead would exit the
    context and undo the patches before the test body runs.

    Args:
        test_settings: Settings fixture injected by pytest.

    Yields:
        A ``TestClient`` bound to ``server.app``.
    """
    with (
        patch("server.get_settings", return_value=test_settings),
        patch("server.get_provider") as mock_provider,
    ):
        provider = MagicMock()
        provider.is_available = True
        provider.router = MagicMock()
        provider.get_available_models.return_value = {}
        mock_provider.return_value = provider

        # Imported lazily so collecting this test module does not import
        # ``server`` at module level.
        from server import app

        yield TestClient(app)
|
|
|
|
|
|
class TestHealthEndpoint:
    """Checks for the ``GET /health`` endpoint."""

    def test_health_check(self, test_client: TestClient) -> None:
        """The health endpoint reports a healthy ``llm-gateway`` service."""
        with (
            patch("server.get_settings") as settings_getter,
            patch("server.get_provider") as provider_getter,
        ):
            settings_getter.return_value = Settings(
                anthropic_api_key="test-key",
            )
            provider_getter.return_value = MagicMock()
            provider_getter.return_value.is_available = True

            resp = test_client.get("/health")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "healthy"
        assert body["service"] == "llm-gateway"
|
|
|
|
|
|
class TestToolDiscoveryEndpoint:
    """Checks for the ``GET /mcp/tools`` discovery endpoint."""

    def test_list_tools(self, test_client: TestClient) -> None:
        """Exactly four tools are advertised, including each known tool."""
        resp = test_client.get("/mcp/tools")
        assert resp.status_code == 200

        payload = resp.json()
        assert "tools" in payload
        tools = payload["tools"]
        assert len(tools) == 4  # the gateway defines four tools

        advertised = {tool["name"] for tool in tools}
        expected_names = (
            "chat_completion",
            "list_models",
            "get_usage",
            "count_tokens",
        )
        for name in expected_names:
            assert name in advertised

    def test_tool_has_schema(self, test_client: TestClient) -> None:
        """Every advertised tool declares an object-typed ``inputSchema``."""
        tools = test_client.get("/mcp/tools").json()["tools"]

        for tool in tools:
            assert "inputSchema" in tool
            schema = tool["inputSchema"]
            assert "type" in schema
            assert schema["type"] == "object"
|
|
|
|
|
|
class TestJSONRPCEndpoint:
    """Checks for request validation and dispatch on ``POST /mcp``."""

    @staticmethod
    def _request(
        method: str,
        *,
        version: str = "2.0",
        params: dict[str, object] | None = None,
    ) -> dict[str, object]:
        """Build a JSON-RPC request body with a fixed ``id`` of 1.

        Keys are inserted in the order ``jsonrpc``, ``method``, then
        ``params`` (only when given), then ``id``.
        """
        body: dict[str, object] = {"jsonrpc": version, "method": method}
        if params is not None:
            body["params"] = params
        body["id"] = 1
        return body

    def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
        """A JSON-RPC ``1.0`` request is rejected with code -32600."""
        resp = test_client.post(
            "/mcp", json=self._request("tools/list", version="1.0")
        )

        assert resp.status_code == 200
        body = resp.json()
        assert "error" in body
        assert body["error"]["code"] == -32600

    def test_tools_list(self, test_client: TestClient) -> None:
        """``tools/list`` answers with a result containing ``tools``."""
        resp = test_client.post("/mcp", json=self._request("tools/list"))

        assert resp.status_code == 200
        body = resp.json()
        assert "result" in body
        assert "tools" in body["result"]

    def test_unknown_method(self, test_client: TestClient) -> None:
        """An unrecognised method is rejected with code -32601."""
        resp = test_client.post("/mcp", json=self._request("unknown/method"))

        assert resp.status_code == 200
        body = resp.json()
        assert "error" in body
        assert body["error"]["code"] == -32601

    def test_unknown_tool(self, test_client: TestClient) -> None:
        """Calling a tool that does not exist reports ``Unknown tool``."""
        resp = test_client.post(
            "/mcp",
            json=self._request(
                "tools/call",
                params={"name": "unknown_tool", "arguments": {}},
            ),
        )

        assert resp.status_code == 200
        body = resp.json()
        assert "error" in body
        assert "Unknown tool" in body["error"]["message"]
|
|
|
|
|
|
class TestCountTokensTool:
    """Tests for the ``count_tokens`` tool invoked via ``tools/call``."""

    @staticmethod
    def _request(**extra_arguments: str) -> dict[str, object]:
        """Build a ``tools/call`` request body for ``count_tokens``.

        The tool arguments always include a fixed ``project_id``,
        ``agent_id`` and ``text``. Any ``extra_arguments`` (e.g. ``model``)
        are appended after them.
        """
        arguments = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "text": "Hello, world!",
            **extra_arguments,
        }
        return {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "count_tokens", "arguments": arguments},
            "id": 1,
        }

    def test_count_tokens(self, test_client: TestClient) -> None:
        """Counting tokens for plain text yields a JSON-RPC result."""
        response = test_client.post("/mcp", json=self._request())

        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_count_tokens_with_model(self, test_client: TestClient) -> None:
        """Counting tokens for an explicit model yields a JSON-RPC result.

        The HTTP status alone proves nothing here: this endpoint also
        returns HTTP 200 for JSON-RPC errors, as the invalid-version and
        unknown-method tests in this module show. The body must therefore
        carry ``result``. On failure, the ``error`` payload is surfaced in
        the assertion message.
        """
        response = test_client.post("/mcp", json=self._request(model="gpt-4"))

        assert response.status_code == 200
        data = response.json()
        assert "result" in data, data.get("error")
|
|
|
|
|
|
class TestListModelsTool:
    """Tests for the ``list_models`` tool invoked via ``tools/call``."""

    @staticmethod
    def _request(arguments: dict[str, str]) -> dict[str, object]:
        """Wrap ``arguments`` in a ``tools/call`` request for ``list_models``."""
        return {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "list_models", "arguments": arguments},
            "id": 1,
        }

    def test_list_all_models(self, test_client: TestClient) -> None:
        """Listing models without a group returns HTTP 200."""
        with patch("server.get_model_router") as get_router:
            router = MagicMock()
            router.get_available_models_for_group = AsyncMock(return_value=[])
            get_router.return_value = router

            response = test_client.post(
                "/mcp",
                json=self._request(
                    {"project_id": "proj-123", "agent_id": "agent-456"}
                ),
            )

            # NOTE(review): only the HTTP status is checked, but this
            # endpoint also answers JSON-RPC errors with HTTP 200.
            assert response.status_code == 200

    def test_list_models_by_group(self, test_client: TestClient) -> None:
        """Listing models for the ``reasoning`` group returns HTTP 200."""
        with patch("server.get_model_router") as get_router:
            router = MagicMock()
            router.parse_model_group.return_value = ModelGroup.REASONING
            router.get_available_models_for_group = AsyncMock(return_value=[])
            get_router.return_value = router

            response = test_client.post(
                "/mcp",
                json=self._request(
                    {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "model_group": "reasoning",
                    }
                ),
            )

            assert response.status_code == 200
|
|
|
|
|
|
class TestGetUsageTool:
    """Tests for the ``get_usage`` tool invoked via ``tools/call``."""

    def test_get_usage(self, test_client: TestClient) -> None:
        """Requesting one day of project/agent usage returns HTTP 200."""
        # Canned report returned by both project- and agent-level lookups.
        report = MagicMock()
        report.total_requests = 10
        report.total_tokens = 1000
        report.total_cost_usd = 0.50
        report.by_model = {}
        report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
        report.period_end.isoformat.return_value = "2024-01-02T00:00:00"

        tracker = MagicMock()
        tracker.get_project_usage = AsyncMock(return_value=report)
        tracker.get_agent_usage = AsyncMock(return_value=report)

        request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "get_usage",
                "arguments": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "period": "day",
                },
            },
            "id": 1,
        }

        with patch("server.get_cost_tracker", return_value=tracker):
            response = test_client.post("/mcp", json=request)
            assert response.status_code == 200
|
|
|
|
|
|
class TestChatCompletionTool:
    """Tests for the ``chat_completion`` tool invoked via ``tools/call``."""

    def test_chat_completion_streaming_not_supported(
        self,
        test_client: TestClient,
    ) -> None:
        """A ``stream=True`` request returns HTTP 200.

        The body (expected to carry an informational message) is not
        inspected here.
        """
        router = MagicMock()
        router.select_model = AsyncMock(return_value=("claude-opus-4", MagicMock()))
        tracker = MagicMock()
        tracker.check_budget = AsyncMock(return_value=(True, 0.0, 100.0))

        with (
            patch("server.get_model_router", return_value=router),
            patch("server.get_cost_tracker", return_value=tracker),
        ):
            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "chat_completion",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                            "messages": [{"role": "user", "content": "Hello"}],
                            "stream": True,
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200

    def test_chat_completion_success(self, test_client: TestClient) -> None:
        """An in-budget, non-streaming completion returns HTTP 200."""
        # Fake LiteLLM completion: one "stop" choice plus token usage.
        completion = MagicMock()
        completion.choices = [MagicMock()]
        completion.choices[0].message.content = "Hello, world!"
        completion.choices[0].finish_reason = "stop"
        completion.usage = MagicMock()
        completion.usage.prompt_tokens = 10
        completion.usage.completion_tokens = 5

        model_config = MagicMock()
        model_config.provider.value = "anthropic"
        router = MagicMock()
        router.select_model = AsyncMock(return_value=("claude-opus-4", model_config))

        tracker = MagicMock()
        tracker.check_budget = AsyncMock(return_value=(True, 0.0, 100.0))
        tracker.record_usage = AsyncMock()

        provider = MagicMock()
        provider.router = MagicMock()
        provider.router.acompletion = AsyncMock(return_value=completion)

        circuit = MagicMock()
        circuit.record_success = AsyncMock()
        registry = MagicMock()
        registry.get_circuit_sync.return_value = circuit

        with (
            patch("server.get_model_router", return_value=router),
            patch("server.get_cost_tracker", return_value=tracker),
            patch("server.get_provider", return_value=provider),
            patch("server.get_circuit_registry", return_value=registry),
        ):
            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "chat_completion",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                            "messages": [{"role": "user", "content": "Hello"}],
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200
|