Files
syndarix/backend/tests/api/routes/test_mcp.py
Felipe Cardoso 520c06175e refactor(safety): apply consistent formatting across services and tests
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
2026-01-03 16:23:39 +01:00

492 lines
16 KiB
Python

"""
Tests for MCP API Routes
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from app.main import app
from app.models.user import User
from app.services.mcp import (
MCPCircuitOpenError,
MCPClientManager,
MCPConnectionError,
MCPServerNotFoundError,
MCPTimeoutError,
MCPToolNotFoundError,
ServerHealth,
)
from app.services.mcp.config import MCPServerConfig, TransportType
from app.services.mcp.routing import ToolInfo, ToolResult
@pytest.fixture
def mock_mcp_client():
    """Provide a ``MagicMock`` specced to ``MCPClientManager``.

    The mock has ``is_initialized`` set to ``True``; individual tests
    configure the methods they need on it.
    """
    manager = MagicMock(spec=MCPClientManager)
    manager.is_initialized = True
    return manager
@pytest.fixture
def mock_superuser():
    """Provide a ``MagicMock`` specced to ``User`` with superuser attributes set."""
    admin = MagicMock(spec=User)
    # configure_mock performs the same attribute assignments in a single call.
    admin.configure_mock(
        id="00000000-0000-0000-0000-000000000001",
        is_superuser=True,
        email="admin@example.com",
    )
    return admin
@pytest.fixture
def client(mock_mcp_client, mock_superuser):
    """Yield a ``TestClient`` for ``app`` with MCP and auth dependencies mocked.

    ``get_mcp_client`` and ``require_superuser`` are overridden so that routes
    receive ``mock_mcp_client`` and ``mock_superuser`` respectively, and
    ``app.main.check_database_health`` is patched to return ``True`` for the
    lifetime of the fixture.

    The overrides are always cleared on exit, including when entering the
    patch or building the client raises, so they cannot leak into other tests
    through the shared module-level ``app``.
    """
    # Imported lazily inside the fixture, as in the original code.
    from app.api.dependencies.permissions import require_superuser
    from app.api.routes.mcp import get_mcp_client

    async def override_get_mcp_client():
        return mock_mcp_client

    async def override_require_superuser():
        return mock_superuser

    app.dependency_overrides[get_mcp_client] = override_get_mcp_client
    app.dependency_overrides[require_superuser] = override_require_superuser

    try:
        # NOTE(review): presumably patched so the app does not need a live
        # database during tests — confirm against app.main.
        with patch("app.main.check_database_health", return_value=True):
            yield TestClient(app)
    finally:
        # Runs on normal teardown and on setup failure after the overrides
        # were installed.
        app.dependency_overrides.clear()
class TestListServers:
    """Exercise ``GET /api/v1/mcp/servers``."""

    def test_list_servers_success(self, client, mock_mcp_client):
        """Two registered servers come back with name, URL and transport."""
        first_config = MCPServerConfig(
            url="http://server1:8000",
            timeout=30,
            enabled=True,
            transport=TransportType.HTTP,
            description="Server 1",
        )
        second_config = MCPServerConfig(
            url="http://server2:8000",
            timeout=60,
            enabled=True,
            transport=TransportType.SSE,
            description="Server 2",
        )
        mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
        # One config is handed out per get_server_config call, in order.
        mock_mcp_client.get_server_config.side_effect = [first_config, second_config]

        resp = client.get("/api/v1/mcp/servers")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["total"] == 2
        servers = payload["servers"]
        assert len(servers) == 2
        assert servers[0]["name"] == "server-1"
        assert servers[0]["url"] == "http://server1:8000"
        assert servers[1]["name"] == "server-2"
        assert servers[1]["transport"] == "sse"

    def test_list_servers_empty(self, client, mock_mcp_client):
        """An empty server list yields ``total == 0`` and no entries."""
        mock_mcp_client.list_servers.return_value = []

        resp = client.get("/api/v1/mcp/servers")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["total"] == 0
        assert payload["servers"] == []

    def test_list_servers_handles_not_found(self, client, mock_mcp_client):
        """A server whose config lookup raises not-found is left out of the count."""
        mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
        # The first lookup succeeds; the second raises MCPServerNotFoundError.
        mock_mcp_client.get_server_config.side_effect = [
            MCPServerConfig(url="http://server1:8000"),
            MCPServerNotFoundError(server_name="missing"),
        ]

        resp = client.get("/api/v1/mcp/servers")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        # Only the server whose config was retrieved is counted.
        assert payload["total"] == 1
class TestListServerTools:
    """Exercise ``GET /api/v1/mcp/servers/{server_name}/tools``."""

    def test_list_server_tools_success(self, client, mock_mcp_client):
        """Both tools reported for ``server-1`` appear in order in the response."""
        available_tools = [
            ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
            ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
        ]
        mock_mcp_client.list_tools = AsyncMock(return_value=available_tools)

        resp = client.get("/api/v1/mcp/servers/server-1/tools")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["total"] == 2
        tools = payload["tools"]
        assert tools[0]["name"] == "tool1"
        assert tools[1]["name"] == "tool2"

    def test_list_server_tools_not_found(self, client, mock_mcp_client):
        """``MCPServerNotFoundError`` from ``list_tools`` produces a 404."""
        not_found = MCPServerNotFoundError(server_name="unknown")
        mock_mcp_client.list_tools = AsyncMock(side_effect=not_found)

        resp = client.get("/api/v1/mcp/servers/unknown/tools")

        assert resp.status_code == status.HTTP_404_NOT_FOUND
class TestListAllTools:
    """Exercise ``GET /api/v1/mcp/tools``."""

    def test_list_all_tools_success(self, client, mock_mcp_client):
        """Three tools spread over two servers are all counted."""
        aggregated_tools = [
            ToolInfo(name="tool1", server_name="server-1"),
            ToolInfo(name="tool2", server_name="server-1"),
            ToolInfo(name="tool3", server_name="server-2"),
        ]
        mock_mcp_client.list_all_tools = AsyncMock(return_value=aggregated_tools)

        resp = client.get("/api/v1/mcp/tools")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["total"] == 3

    def test_list_all_tools_empty(self, client, mock_mcp_client):
        """No tools available gives ``total == 0``."""
        mock_mcp_client.list_all_tools = AsyncMock(return_value=[])

        resp = client.get("/api/v1/mcp/tools")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["total"] == 0
class TestHealthCheck:
    """Exercise ``GET /api/v1/mcp/health``."""

    def test_health_check_success(self, client, mock_mcp_client):
        """One healthy and one failing server are counted and reported separately."""
        healthy_server = ServerHealth(
            name="server-1",
            healthy=True,
            state="connected",
            url="http://server1:8000",
            tools_count=5,
        )
        failing_server = ServerHealth(
            name="server-2",
            healthy=False,
            state="error",
            url="http://server2:8000",
            error="Connection refused",
        )
        mock_mcp_client.health_check = AsyncMock(
            return_value={"server-1": healthy_server, "server-2": failing_server}
        )

        resp = client.get("/api/v1/mcp/health")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["total"] == 2
        assert payload["healthy_count"] == 1
        assert payload["unhealthy_count"] == 1
        per_server = payload["servers"]
        assert per_server["server-1"]["healthy"] is True
        assert per_server["server-2"]["healthy"] is False

    def test_health_check_all_healthy(self, client, mock_mcp_client):
        """A single healthy server gives one healthy and zero unhealthy."""
        only_server = ServerHealth(
            name="server-1",
            healthy=True,
            state="connected",
            url="http://server1:8000",
        )
        mock_mcp_client.health_check = AsyncMock(
            return_value={"server-1": only_server}
        )

        resp = client.get("/api/v1/mcp/health")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["healthy_count"] == 1
        assert payload["unhealthy_count"] == 0
class TestCallTool:
    """Exercise ``POST /api/v1/mcp/call``."""

    def test_call_tool_success(self, client, mock_mcp_client):
        """A successful ``ToolResult`` is echoed back in the response body."""
        tool_result = ToolResult(
            success=True,
            data={"result": "ok"},
            tool_name="test-tool",
            server_name="server-1",
            execution_time_ms=123.45,
            request_id="test-request-id",
        )
        mock_mcp_client.call_tool = AsyncMock(return_value=tool_result)
        request_body = {
            "server": "server-1",
            "tool": "test-tool",
            "arguments": {"key": "value"},
        }

        resp = client.post("/api/v1/mcp/call", json=request_body)

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["success"] is True
        assert payload["data"] == {"result": "ok"}
        assert payload["tool_name"] == "test-tool"
        assert payload["server_name"] == "server-1"

    def test_call_tool_with_timeout(self, client, mock_mcp_client):
        """A ``timeout`` in the request is passed to ``call_tool`` as a keyword."""
        mock_mcp_client.call_tool = AsyncMock(
            return_value=ToolResult(success=True, data={})
        )
        request_body = {
            "server": "server-1",
            "tool": "test-tool",
            "timeout": 60.0,
        }

        resp = client.post("/api/v1/mcp/call", json=request_body)

        assert resp.status_code == status.HTTP_200_OK
        mock_mcp_client.call_tool.assert_called_once()
        forwarded_kwargs = mock_mcp_client.call_tool.call_args.kwargs
        assert forwarded_kwargs["timeout"] == 60.0

    def test_call_tool_server_not_found(self, client, mock_mcp_client):
        """``MCPServerNotFoundError`` from ``call_tool`` produces a 404."""
        failure = MCPServerNotFoundError(server_name="unknown")
        mock_mcp_client.call_tool = AsyncMock(side_effect=failure)

        resp = client.post(
            "/api/v1/mcp/call",
            json={"server": "unknown", "tool": "test-tool"},
        )

        assert resp.status_code == status.HTTP_404_NOT_FOUND

    def test_call_tool_not_found(self, client, mock_mcp_client):
        """``MCPToolNotFoundError`` from ``call_tool`` produces a 404."""
        failure = MCPToolNotFoundError(tool_name="unknown")
        mock_mcp_client.call_tool = AsyncMock(side_effect=failure)

        resp = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "unknown"},
        )

        assert resp.status_code == status.HTTP_404_NOT_FOUND

    def test_call_tool_timeout(self, client, mock_mcp_client):
        """``MCPTimeoutError`` from ``call_tool`` produces a 504."""
        failure = MCPTimeoutError(
            "Request timed out",
            server_name="server-1",
            timeout_seconds=30.0,
        )
        mock_mcp_client.call_tool = AsyncMock(side_effect=failure)

        resp = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "slow-tool"},
        )

        assert resp.status_code == status.HTTP_504_GATEWAY_TIMEOUT

    def test_call_tool_connection_error(self, client, mock_mcp_client):
        """``MCPConnectionError`` from ``call_tool`` produces a 502."""
        failure = MCPConnectionError(
            "Connection refused",
            server_name="server-1",
        )
        mock_mcp_client.call_tool = AsyncMock(side_effect=failure)

        resp = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "test-tool"},
        )

        assert resp.status_code == status.HTTP_502_BAD_GATEWAY

    def test_call_tool_circuit_open(self, client, mock_mcp_client):
        """``MCPCircuitOpenError`` from ``call_tool`` produces a 503."""
        failure = MCPCircuitOpenError(
            server_name="server-1",
            failure_count=5,
        )
        mock_mcp_client.call_tool = AsyncMock(side_effect=failure)

        resp = client.post(
            "/api/v1/mcp/call",
            json={"server": "server-1", "tool": "test-tool"},
        )

        assert resp.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
class TestCircuitBreakers:
    """Exercise the ``/api/v1/mcp/circuit-breakers`` endpoints."""

    def test_list_circuit_breakers(self, client, mock_mcp_client):
        """Two breaker statuses from the client yield two entries in the response."""
        breaker_status = {
            "server-1": {"state": "closed", "failure_count": 0},
            "server-2": {"state": "open", "failure_count": 5},
        }
        mock_mcp_client.get_circuit_breaker_status.return_value = breaker_status

        resp = client.get("/api/v1/mcp/circuit-breakers")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert len(payload["circuit_breakers"]) == 2

    def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
        """An empty status mapping yields an empty ``circuit_breakers`` list."""
        mock_mcp_client.get_circuit_breaker_status.return_value = {}

        resp = client.get("/api/v1/mcp/circuit-breakers")

        assert resp.status_code == status.HTTP_200_OK
        payload = resp.json()
        assert payload["circuit_breakers"] == []

    def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
        """A reset that returns ``True`` produces a 204."""
        mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)

        resp = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")

        assert resp.status_code == status.HTTP_204_NO_CONTENT

    def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
        """A reset that returns ``False`` produces a 404."""
        mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)

        resp = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")

        assert resp.status_code == status.HTTP_404_NOT_FOUND
class TestReconnectServer:
    """Exercise ``POST /api/v1/mcp/servers/{server_name}/reconnect``."""

    def test_reconnect_success(self, client, mock_mcp_client):
        """Reconnecting calls ``disconnect`` and ``connect`` once each and returns 204."""
        mock_mcp_client.disconnect = AsyncMock()
        mock_mcp_client.connect = AsyncMock()

        resp = client.post("/api/v1/mcp/servers/server-1/reconnect")

        assert resp.status_code == status.HTTP_204_NO_CONTENT
        mock_mcp_client.disconnect.assert_called_once_with("server-1")
        mock_mcp_client.connect.assert_called_once_with("server-1")

    def test_reconnect_server_not_found(self, client, mock_mcp_client):
        """``MCPServerNotFoundError`` raised by ``disconnect`` produces a 404."""
        not_found = MCPServerNotFoundError(server_name="unknown")
        mock_mcp_client.disconnect = AsyncMock(side_effect=not_found)

        resp = client.post("/api/v1/mcp/servers/unknown/reconnect")

        assert resp.status_code == status.HTTP_404_NOT_FOUND

    def test_reconnect_connection_failure(self, client, mock_mcp_client):
        """``MCPConnectionError`` raised by ``connect`` produces a 502."""
        connect_failure = MCPConnectionError(
            "Connection refused",
            server_name="server-1",
        )
        mock_mcp_client.disconnect = AsyncMock()
        mock_mcp_client.connect = AsyncMock(side_effect=connect_failure)

        resp = client.post("/api/v1/mcp/servers/server-1/reconnect")

        assert resp.status_code == status.HTTP_502_BAD_GATEWAY
class TestMCPEndpointsEdgeCases:
    """Content-type and request-validation checks for the MCP endpoints."""

    def test_servers_content_type(self, client, mock_mcp_client):
        """The servers listing is served with a JSON content type."""
        mock_mcp_client.list_servers.return_value = []

        resp = client.get("/api/v1/mcp/servers")

        content_type = resp.headers["content-type"]
        assert "application/json" in content_type

    def test_call_tool_validation_error(self, client):
        """An empty request body is rejected with a 422."""
        # Both required fields are missing from this body.
        empty_body = {}

        resp = client.post("/api/v1/mcp/call", json=empty_body)

        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_call_tool_missing_server(self, client):
        """A body without ``server`` is rejected with a 422."""
        body_without_server = {"tool": "test-tool"}

        resp = client.post("/api/v1/mcp/call", json=body_without_server)

        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_call_tool_missing_tool(self, client):
        """A body without ``tool`` is rejected with a 422."""
        body_without_tool = {"server": "server-1"}

        resp = client.post("/api/v1/mcp/call", json=body_without_tool)

        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY