feat(backend): implement MCP client infrastructure (#55)
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
491
backend/tests/api/routes/test_mcp.py
Normal file
491
backend/tests/api/routes/test_mcp.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
Tests for MCP API Routes
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.mcp import (
|
||||
MCPCircuitOpenError,
|
||||
MCPClientManager,
|
||||
MCPConnectionError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolNotFoundError,
|
||||
ServerHealth,
|
||||
)
|
||||
from app.services.mcp.config import MCPServerConfig, TransportType
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client():
|
||||
"""Create a mock MCP client manager."""
|
||||
client = MagicMock(spec=MCPClientManager)
|
||||
client.is_initialized = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_superuser():
|
||||
"""Create a mock superuser."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = "00000000-0000-0000-0000-000000000001"
|
||||
user.is_superuser = True
|
||||
user.email = "admin@example.com"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_mcp_client, mock_superuser):
|
||||
"""Create a FastAPI test client with mocked dependencies."""
|
||||
from app.api.routes.mcp import get_mcp_client
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
|
||||
# Override dependencies
|
||||
async def override_get_mcp_client():
|
||||
return mock_mcp_client
|
||||
|
||||
async def override_require_superuser():
|
||||
return mock_superuser
|
||||
|
||||
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||
|
||||
with patch("app.main.check_database_health", return_value=True):
|
||||
yield TestClient(app)
|
||||
|
||||
# Clean up
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestListServers:
|
||||
"""Tests for GET /mcp/servers endpoint."""
|
||||
|
||||
def test_list_servers_success(self, client, mock_mcp_client):
|
||||
"""Test listing MCP servers returns correct data."""
|
||||
# Setup mock
|
||||
mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
|
||||
mock_mcp_client.get_server_config.side_effect = [
|
||||
MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
transport=TransportType.HTTP,
|
||||
description="Server 1",
|
||||
),
|
||||
MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
transport=TransportType.SSE,
|
||||
description="Server 2",
|
||||
),
|
||||
]
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert len(data["servers"]) == 2
|
||||
assert data["servers"][0]["name"] == "server-1"
|
||||
assert data["servers"][0]["url"] == "http://server1:8000"
|
||||
assert data["servers"][1]["name"] == "server-2"
|
||||
assert data["servers"][1]["transport"] == "sse"
|
||||
|
||||
def test_list_servers_empty(self, client, mock_mcp_client):
|
||||
"""Test listing servers when none are registered."""
|
||||
mock_mcp_client.list_servers.return_value = []
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
assert data["servers"] == []
|
||||
|
||||
def test_list_servers_handles_not_found(self, client, mock_mcp_client):
|
||||
"""Test that missing server configs are skipped gracefully."""
|
||||
mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
|
||||
mock_mcp_client.get_server_config.side_effect = [
|
||||
MCPServerConfig(url="http://server1:8000"),
|
||||
MCPServerNotFoundError(server_name="missing"),
|
||||
]
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
# Should only include the successfully retrieved server
|
||||
assert data["total"] == 1
|
||||
|
||||
|
||||
class TestListServerTools:
|
||||
"""Tests for GET /mcp/servers/{server_name}/tools endpoint."""
|
||||
|
||||
def test_list_server_tools_success(self, client, mock_mcp_client):
|
||||
"""Test listing tools for a specific server."""
|
||||
mock_mcp_client.list_tools = AsyncMock(
|
||||
return_value=[
|
||||
ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
|
||||
]
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/servers/server-1/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert data["tools"][0]["name"] == "tool1"
|
||||
assert data["tools"][1]["name"] == "tool2"
|
||||
|
||||
def test_list_server_tools_not_found(self, client, mock_mcp_client):
|
||||
"""Test listing tools for non-existent server."""
|
||||
mock_mcp_client.list_tools = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/servers/unknown/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestListAllTools:
|
||||
"""Tests for GET /mcp/tools endpoint."""
|
||||
|
||||
def test_list_all_tools_success(self, client, mock_mcp_client):
|
||||
"""Test listing all tools from all servers."""
|
||||
mock_mcp_client.list_all_tools = AsyncMock(
|
||||
return_value=[
|
||||
ToolInfo(name="tool1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", server_name="server-1"),
|
||||
ToolInfo(name="tool3", server_name="server-2"),
|
||||
]
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
|
||||
def test_list_all_tools_empty(self, client, mock_mcp_client):
|
||||
"""Test listing tools when none are available."""
|
||||
mock_mcp_client.list_all_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = client.get("/api/v1/mcp/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for GET /mcp/health endpoint."""
|
||||
|
||||
def test_health_check_success(self, client, mock_mcp_client):
|
||||
"""Test health check returns correct data."""
|
||||
mock_mcp_client.health_check = AsyncMock(
|
||||
return_value={
|
||||
"server-1": ServerHealth(
|
||||
name="server-1",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://server1:8000",
|
||||
tools_count=5,
|
||||
),
|
||||
"server-2": ServerHealth(
|
||||
name="server-2",
|
||||
healthy=False,
|
||||
state="error",
|
||||
url="http://server2:8000",
|
||||
error="Connection refused",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert data["healthy_count"] == 1
|
||||
assert data["unhealthy_count"] == 1
|
||||
assert data["servers"]["server-1"]["healthy"] is True
|
||||
assert data["servers"]["server-2"]["healthy"] is False
|
||||
|
||||
def test_health_check_all_healthy(self, client, mock_mcp_client):
|
||||
"""Test health check when all servers are healthy."""
|
||||
mock_mcp_client.health_check = AsyncMock(
|
||||
return_value={
|
||||
"server-1": ServerHealth(
|
||||
name="server-1",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://server1:8000",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["healthy_count"] == 1
|
||||
assert data["unhealthy_count"] == 0
|
||||
|
||||
|
||||
class TestCallTool:
|
||||
"""Tests for POST /mcp/call endpoint."""
|
||||
|
||||
def test_call_tool_success(self, client, mock_mcp_client):
|
||||
"""Test successful tool execution."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
return_value=ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="test-tool",
|
||||
server_name="server-1",
|
||||
execution_time_ms=123.45,
|
||||
request_id="test-request-id",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={
|
||||
"server": "server-1",
|
||||
"tool": "test-tool",
|
||||
"arguments": {"key": "value"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"] == {"result": "ok"}
|
||||
assert data["tool_name"] == "test-tool"
|
||||
assert data["server_name"] == "server-1"
|
||||
|
||||
def test_call_tool_with_timeout(self, client, mock_mcp_client):
|
||||
"""Test tool execution with custom timeout."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
return_value=ToolResult(success=True, data={})
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={
|
||||
"server": "server-1",
|
||||
"tool": "test-tool",
|
||||
"timeout": 60.0,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
mock_mcp_client.call_tool.assert_called_once()
|
||||
call_args = mock_mcp_client.call_tool.call_args
|
||||
assert call_args.kwargs["timeout"] == 60.0
|
||||
|
||||
def test_call_tool_server_not_found(self, client, mock_mcp_client):
|
||||
"""Test tool execution with non-existent server."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "unknown", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_call_tool_not_found(self, client, mock_mcp_client):
|
||||
"""Test tool execution with non-existent tool."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPToolNotFoundError(tool_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "unknown"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_call_tool_timeout(self, client, mock_mcp_client):
|
||||
"""Test tool execution timeout."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name="server-1",
|
||||
timeout_seconds=30.0,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "slow-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||
|
||||
def test_call_tool_connection_error(self, client, mock_mcp_client):
|
||||
"""Test tool execution with connection failure."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPConnectionError(
|
||||
"Connection refused",
|
||||
server_name="server-1",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||
|
||||
def test_call_tool_circuit_open(self, client, mock_mcp_client):
|
||||
"""Test tool execution with open circuit breaker."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPCircuitOpenError(
|
||||
server_name="server-1",
|
||||
failure_count=5,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
|
||||
|
||||
class TestCircuitBreakers:
|
||||
"""Tests for circuit breaker endpoints."""
|
||||
|
||||
def test_list_circuit_breakers(self, client, mock_mcp_client):
|
||||
"""Test listing circuit breaker statuses."""
|
||||
mock_mcp_client.get_circuit_breaker_status.return_value = {
|
||||
"server-1": {"state": "closed", "failure_count": 0},
|
||||
"server-2": {"state": "open", "failure_count": 5},
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data["circuit_breakers"]) == 2
|
||||
|
||||
def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
|
||||
"""Test listing when no circuit breakers exist."""
|
||||
mock_mcp_client.get_circuit_breaker_status.return_value = {}
|
||||
|
||||
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["circuit_breakers"] == []
|
||||
|
||||
def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
|
||||
"""Test successfully resetting a circuit breaker."""
|
||||
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)
|
||||
|
||||
response = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
|
||||
"""Test resetting non-existent circuit breaker."""
|
||||
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)
|
||||
|
||||
response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestReconnectServer:
|
||||
"""Tests for POST /mcp/servers/{server_name}/reconnect endpoint."""
|
||||
|
||||
def test_reconnect_success(self, client, mock_mcp_client):
|
||||
"""Test successful server reconnection."""
|
||||
mock_mcp_client.disconnect = AsyncMock()
|
||||
mock_mcp_client.connect = AsyncMock()
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
mock_mcp_client.disconnect.assert_called_once_with("server-1")
|
||||
mock_mcp_client.connect.assert_called_once_with("server-1")
|
||||
|
||||
def test_reconnect_server_not_found(self, client, mock_mcp_client):
|
||||
"""Test reconnecting to non-existent server."""
|
||||
mock_mcp_client.disconnect = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/unknown/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_reconnect_connection_failure(self, client, mock_mcp_client):
|
||||
"""Test reconnection failure."""
|
||||
mock_mcp_client.disconnect = AsyncMock()
|
||||
mock_mcp_client.connect = AsyncMock(
|
||||
side_effect=MCPConnectionError(
|
||||
"Connection refused",
|
||||
server_name="server-1",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||
|
||||
|
||||
class TestMCPEndpointsEdgeCases:
|
||||
"""Edge case tests for MCP endpoints."""
|
||||
|
||||
def test_servers_content_type(self, client, mock_mcp_client):
|
||||
"""Test that endpoints return JSON content type."""
|
||||
mock_mcp_client.list_servers.return_value = []
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert "application/json" in response.headers["content-type"]
|
||||
|
||||
def test_call_tool_validation_error(self, client):
|
||||
"""Test that invalid request body returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={}, # Missing required fields
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_call_tool_missing_server(self, client):
|
||||
"""Test that missing server field returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_call_tool_missing_tool(self, client):
|
||||
"""Test that missing tool field returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
Reference in New Issue
Block a user