forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
492 lines
16 KiB
Python
492 lines
16 KiB
Python
"""
Tests for MCP API Routes
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import status
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.services.mcp import (
|
|
MCPCircuitOpenError,
|
|
MCPClientManager,
|
|
MCPConnectionError,
|
|
MCPServerNotFoundError,
|
|
MCPTimeoutError,
|
|
MCPToolNotFoundError,
|
|
ServerHealth,
|
|
)
|
|
from app.services.mcp.config import MCPServerConfig, TransportType
|
|
from app.services.mcp.routing import ToolInfo, ToolResult
|
|
|
|
|
|
@pytest.fixture
def mock_mcp_client():
    """Provide an ``MCPClientManager`` stand-in that reports itself as initialized.

    The mock is spec'd against ``MCPClientManager`` so tests can only stub
    attributes the real manager exposes; ``is_initialized`` is preset to
    ``True`` via the constructor's attribute-configuration kwargs.
    """
    return MagicMock(spec=MCPClientManager, is_initialized=True)
|
|
|
|
|
|
@pytest.fixture
def mock_superuser():
    """Provide a ``User`` stand-in flagged as a superuser.

    Attributes are set through ``MagicMock``'s keyword configuration rather
    than one assignment per line; the resulting mock exposes a fixed string
    ``id``, ``is_superuser=True`` and an admin e-mail address.
    """
    return MagicMock(
        spec=User,
        id="00000000-0000-0000-0000-000000000001",
        is_superuser=True,
        email="admin@example.com",
    )
|
|
|
|
|
|
@pytest.fixture
def client(mock_mcp_client, mock_superuser):
    """Yield a FastAPI ``TestClient`` wired to the mocked MCP client and superuser.

    The ``get_mcp_client`` and ``require_superuser`` dependencies are overridden
    so route handlers receive the ``mock_mcp_client`` / ``mock_superuser``
    fixtures, and ``app.main.check_database_health`` is patched to return
    ``True`` while the test runs.

    On teardown the application's ``dependency_overrides`` mapping is restored
    to exactly what it held before this fixture ran. Any overrides installed
    elsewhere (for example by another fixture) therefore survive, instead of
    being wiped by an unconditional ``clear()``.
    """
    from app.api.dependencies.permissions import require_superuser
    from app.api.routes.mcp import get_mcp_client

    async def override_get_mcp_client():
        return mock_mcp_client

    async def override_require_superuser():
        return mock_superuser

    # Snapshot the current overrides so teardown only undoes what we add here.
    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_mcp_client] = override_get_mcp_client
    app.dependency_overrides[require_superuser] = override_require_superuser

    try:
        with patch("app.main.check_database_health", return_value=True):
            yield TestClient(app)
    finally:
        # Mutate in place: the app keeps a reference to this same dict object.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
class TestListServers:
    """Tests for GET /mcp/servers endpoint."""

    def test_list_servers_success(self, client, mock_mcp_client):
        """Test listing MCP servers returns correct data."""
        mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
        # side_effect entries are consumed in call order, so this assumes the
        # route looks configs up in the order list_servers returns the names.
        mock_mcp_client.get_server_config.side_effect = [
            MCPServerConfig(
                url="http://server1:8000",
                timeout=30,
                enabled=True,
                transport=TransportType.HTTP,
                description="Server 1",
            ),
            MCPServerConfig(
                url="http://server2:8000",
                timeout=60,
                enabled=True,
                transport=TransportType.SSE,
                description="Server 2",
            ),
        ]

        response = client.get("/api/v1/mcp/servers")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["total"] == 2
        assert len(data["servers"]) == 2
        assert data["servers"][0]["name"] == "server-1"
        assert data["servers"][0]["url"] == "http://server1:8000"
        assert data["servers"][1]["name"] == "server-2"
        assert data["servers"][1]["transport"] == "sse"

    def test_list_servers_empty(self, client, mock_mcp_client):
        """Test listing servers when none are registered."""
        mock_mcp_client.list_servers.return_value = []

        response = client.get("/api/v1/mcp/servers")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["total"] == 0
        assert data["servers"] == []

    def test_list_servers_handles_not_found(self, client, mock_mcp_client):
        """Test that missing server configs are skipped gracefully."""
        mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
        mock_mcp_client.get_server_config.side_effect = [
            MCPServerConfig(url="http://server1:8000"),
            MCPServerNotFoundError(server_name="missing"),
        ]

        response = client.get("/api/v1/mcp/servers")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        # Should only include the successfully retrieved server...
        assert data["total"] == 1
        # ...and it must be the one whose config resolved, not the missing one.
        assert [server["name"] for server in data["servers"]] == ["server-1"]
|
|
|
|
|
class TestListServerTools:
    """GET /mcp/servers/{server_name}/tools: tool listing for one server."""

    def test_list_server_tools_success(self, client, mock_mcp_client):
        """The tools the client reports for server-1 are returned in order."""
        server_tools = [
            ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
            ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
        ]
        mock_mcp_client.list_tools = AsyncMock(return_value=server_tools)

        resp = client.get("/api/v1/mcp/servers/server-1/tools")

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["total"] == 2
        leading_names = [entry["name"] for entry in body["tools"][:2]]
        assert leading_names == ["tool1", "tool2"]

    def test_list_server_tools_not_found(self, client, mock_mcp_client):
        """An MCPServerNotFoundError from the client becomes HTTP 404."""
        missing = MCPServerNotFoundError(server_name="unknown")
        mock_mcp_client.list_tools = AsyncMock(side_effect=missing)

        resp = client.get("/api/v1/mcp/servers/unknown/tools")

        assert resp.status_code == status.HTTP_404_NOT_FOUND
|
|
|
|
|
|
class TestListAllTools:
    """GET /mcp/tools: tools aggregated across every registered server."""

    def test_list_all_tools_success(self, client, mock_mcp_client):
        """Tools from several servers are all counted in the total."""
        catalog = [
            ToolInfo(name="tool1", server_name="server-1"),
            ToolInfo(name="tool2", server_name="server-1"),
            ToolInfo(name="tool3", server_name="server-2"),
        ]
        mock_mcp_client.list_all_tools = AsyncMock(return_value=catalog)

        resp = client.get("/api/v1/mcp/tools")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.json()["total"] == 3

    def test_list_all_tools_empty(self, client, mock_mcp_client):
        """An empty catalog yields a total of zero."""
        mock_mcp_client.list_all_tools = AsyncMock(return_value=[])

        resp = client.get("/api/v1/mcp/tools")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.json()["total"] == 0
|
|
|
|
|
|
class TestHealthCheck:
    """GET /mcp/health: per-server health plus healthy/unhealthy counters."""

    def test_health_check_success(self, client, mock_mcp_client):
        """One healthy and one failing server are counted and reported."""
        healthy_server = ServerHealth(
            name="server-1",
            healthy=True,
            state="connected",
            url="http://server1:8000",
            tools_count=5,
        )
        failing_server = ServerHealth(
            name="server-2",
            healthy=False,
            state="error",
            url="http://server2:8000",
            error="Connection refused",
        )
        mock_mcp_client.health_check = AsyncMock(
            return_value={"server-1": healthy_server, "server-2": failing_server}
        )

        resp = client.get("/api/v1/mcp/health")

        assert resp.status_code == status.HTTP_200_OK
        report = resp.json()
        assert report["total"] == 2
        assert (report["healthy_count"], report["unhealthy_count"]) == (1, 1)
        per_server = report["servers"]
        assert per_server["server-1"]["healthy"] is True
        assert per_server["server-2"]["healthy"] is False

    def test_health_check_all_healthy(self, client, mock_mcp_client):
        """With only healthy servers the unhealthy counter stays at zero."""
        only_server = ServerHealth(
            name="server-1",
            healthy=True,
            state="connected",
            url="http://server1:8000",
        )
        mock_mcp_client.health_check = AsyncMock(
            return_value={"server-1": only_server}
        )

        resp = client.get("/api/v1/mcp/health")

        assert resp.status_code == status.HTTP_200_OK
        report = resp.json()
        assert (report["healthy_count"], report["unhealthy_count"]) == (1, 0)
|
|
|
|
|
|
class TestCallTool:
    """Tests for POST /mcp/call endpoint.

    ``call_tool`` is stubbed with ``AsyncMock``. Await-based assertions
    (``assert_awaited_once``) are used instead of call-based ones, because a
    call-based check also passes when the route creates the coroutine but
    never awaits it.
    """

    def test_call_tool_success(self, client, mock_mcp_client):
        """Test successful tool execution."""
        mock_mcp_client.call_tool = AsyncMock(
            return_value=ToolResult(
                success=True,
                data={"result": "ok"},
                tool_name="test-tool",
                server_name="server-1",
                execution_time_ms=123.45,
                request_id="test-request-id",
            )
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={
                "server": "server-1",
                "tool": "test-tool",
                "arguments": {"key": "value"},
            },
        )

        assert response.status_code == status.HTTP_200_OK
        mock_mcp_client.call_tool.assert_awaited_once()
        data = response.json()
        assert data["success"] is True
        assert data["data"] == {"result": "ok"}
        assert data["tool_name"] == "test-tool"
        assert data["server_name"] == "server-1"

    def test_call_tool_with_timeout(self, client, mock_mcp_client):
        """Test tool execution with custom timeout."""
        mock_mcp_client.call_tool = AsyncMock(
            return_value=ToolResult(success=True, data={})
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={
                "server": "server-1",
                "tool": "test-tool",
                "timeout": 60.0,
            },
        )

        assert response.status_code == status.HTTP_200_OK
        mock_mcp_client.call_tool.assert_awaited_once()
        # The request's timeout must be forwarded as the ``timeout`` keyword.
        call_args = mock_mcp_client.call_tool.call_args
        assert call_args.kwargs["timeout"] == 60.0

    def test_call_tool_server_not_found(self, client, mock_mcp_client):
        """Test tool execution with non-existent server."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPServerNotFoundError(server_name="unknown")
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={"server": "unknown", "tool": "test-tool"},
        )

        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_call_tool_not_found(self, client, mock_mcp_client):
        """Test tool execution with non-existent tool."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPToolNotFoundError(tool_name="unknown")
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "unknown"},
        )

        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_call_tool_timeout(self, client, mock_mcp_client):
        """Test tool execution timeout maps to 504 Gateway Timeout."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPTimeoutError(
                "Request timed out",
                server_name="server-1",
                timeout_seconds=30.0,
            )
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "slow-tool"},
        )

        assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT

    def test_call_tool_connection_error(self, client, mock_mcp_client):
        """Test tool execution with connection failure maps to 502 Bad Gateway."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPConnectionError(
                "Connection refused",
                server_name="server-1",
            )
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "test-tool"},
        )

        assert response.status_code == status.HTTP_502_BAD_GATEWAY

    def test_call_tool_circuit_open(self, client, mock_mcp_client):
        """Test tool execution with open circuit breaker maps to 503."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPCircuitOpenError(
                server_name="server-1",
                failure_count=5,
            )
        )

        response = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "test-tool"},
        )

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
|
|
|
|
|
class TestCircuitBreakers:
    """Circuit-breaker status listing and manual reset endpoints."""

    def test_list_circuit_breakers(self, client, mock_mcp_client):
        """Every breaker reported by the client appears in the listing."""
        breaker_statuses = {
            "server-1": {"state": "closed", "failure_count": 0},
            "server-2": {"state": "open", "failure_count": 5},
        }
        mock_mcp_client.get_circuit_breaker_status.return_value = breaker_statuses

        resp = client.get("/api/v1/mcp/circuit-breakers")

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.json()["circuit_breakers"]) == 2

    def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
        """No breakers yields an empty list, not an error."""
        mock_mcp_client.get_circuit_breaker_status.return_value = {}

        resp = client.get("/api/v1/mcp/circuit-breakers")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.json()["circuit_breakers"] == []

    def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
        """When the client reports a reset (True), the route answers 204."""
        mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)

        resp = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")

        assert resp.status_code == status.HTTP_204_NO_CONTENT

    def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
        """When the client reports nothing was reset (False), the route answers 404."""
        mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)

        resp = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")

        assert resp.status_code == status.HTTP_404_NOT_FOUND
|
|
|
|
|
|
class TestReconnectServer:
    """Tests for POST /mcp/servers/{server_name}/reconnect endpoint.

    ``connect``/``disconnect`` are ``AsyncMock`` stubs. Await-based assertions
    are used so a route that calls them without awaiting is caught.
    """

    def test_reconnect_success(self, client, mock_mcp_client):
        """Test successful server reconnection."""
        mock_mcp_client.disconnect = AsyncMock()
        mock_mcp_client.connect = AsyncMock()

        response = client.post("/api/v1/mcp/servers/server-1/reconnect")

        assert response.status_code == status.HTTP_204_NO_CONTENT
        mock_mcp_client.disconnect.assert_awaited_once_with("server-1")
        mock_mcp_client.connect.assert_awaited_once_with("server-1")

    def test_reconnect_server_not_found(self, client, mock_mcp_client):
        """Test reconnecting to non-existent server."""
        mock_mcp_client.disconnect = AsyncMock(
            side_effect=MCPServerNotFoundError(server_name="unknown")
        )
        # Explicit AsyncMock so the await-based assertion below is available.
        mock_mcp_client.connect = AsyncMock()

        response = client.post("/api/v1/mcp/servers/unknown/reconnect")

        assert response.status_code == status.HTTP_404_NOT_FOUND
        # A failed disconnect must short-circuit: no reconnect attempt.
        mock_mcp_client.connect.assert_not_awaited()

    def test_reconnect_connection_failure(self, client, mock_mcp_client):
        """Test reconnection failure maps to 502 Bad Gateway."""
        mock_mcp_client.disconnect = AsyncMock()
        mock_mcp_client.connect = AsyncMock(
            side_effect=MCPConnectionError(
                "Connection refused",
                server_name="server-1",
            )
        )

        response = client.post("/api/v1/mcp/servers/server-1/reconnect")

        assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
|
|
|
|
|
class TestMCPEndpointsEdgeCases:
    """Response content type and request-body validation for MCP routes."""

    @staticmethod
    def _post_call(client, payload):
        """POST ``payload`` as JSON to the tool-call endpoint."""
        return client.post("/api/v1/mcp/call", json=payload)

    def test_servers_content_type(self, client, mock_mcp_client):
        """The servers listing is served with a JSON content type."""
        mock_mcp_client.list_servers.return_value = []

        content_type = client.get("/api/v1/mcp/servers").headers["content-type"]

        assert "application/json" in content_type

    def test_call_tool_validation_error(self, client):
        """An empty body (no required fields at all) is rejected with 422."""
        resp = self._post_call(client, {})
        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_call_tool_missing_server(self, client):
        """Omitting the ``server`` field is rejected with 422."""
        resp = self._post_call(client, {"tool": "test-tool"})
        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_call_tool_missing_tool(self, client):
        """Omitting the ``tool`` field is rejected with 422."""
        resp = self._post_call(client, {"server": "server-1"})
        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|