forked from cardosofelipe/fast-next-template
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
406 lines
14 KiB
Python
406 lines
14 KiB
Python
"""
|
|
Tests for MCP Connection Management
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.services.mcp.config import MCPServerConfig, TransportType
|
|
from app.services.mcp.connection import (
|
|
ConnectionPool,
|
|
ConnectionState,
|
|
MCPConnection,
|
|
)
|
|
from app.services.mcp.exceptions import MCPConnectionError, MCPTimeoutError
|
|
|
|
|
|
@pytest.fixture
def server_config():
    """Provide an HTTP ``MCPServerConfig`` pointing at localhost for the tests."""
    settings = {
        "url": "http://localhost:8000",
        "transport": TransportType.HTTP,
        "timeout": 30,
        "retry_attempts": 3,
        "retry_delay": 0.1,  # short delay for tests
        "retry_max_delay": 1.0,
    }
    return MCPServerConfig(**settings)
|
|
|
|
|
|
class TestConnectionState:
    """Checks on the ``ConnectionState`` enum."""

    def test_connection_states(self):
        """Each expected state exists and compares equal to its string value."""
        expected_values = (
            (ConnectionState.DISCONNECTED, "disconnected"),
            (ConnectionState.CONNECTING, "connecting"),
            (ConnectionState.CONNECTED, "connected"),
            (ConnectionState.RECONNECTING, "reconnecting"),
            (ConnectionState.ERROR, "error"),
        )
        for state, text in expected_values:
            assert state == text
|
|
|
|
|
|
class TestMCPConnection:
    """Tests for ``MCPConnection`` driven through a mocked ``httpx.AsyncClient``."""

    @staticmethod
    def _response(status_code=200):
        """Build a mock HTTP response carrying only a status code."""
        return MagicMock(status_code=status_code)

    @staticmethod
    def _json_response(payload):
        """Build a mock 200 response whose ``json()`` returns ``payload``."""
        response = MagicMock(status_code=200)
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    @staticmethod
    def _patched_client(**methods):
        """Patch ``httpx.AsyncClient`` so that calling it yields an ``AsyncMock``.

        Keyword arguments are set as attributes on the mocked client
        (for example ``get=AsyncMock(...)``).
        """
        return patch("httpx.AsyncClient", return_value=AsyncMock(**methods))

    def test_initial_state(self, server_config):
        """A freshly built connection is disconnected and has no recorded error."""
        conn = MCPConnection("test-server", server_config)

        assert conn.server_name == "test-server"
        assert conn.config == server_config
        assert conn.state == ConnectionState.DISCONNECTED
        assert conn.is_connected is False
        assert conn.last_error is None

    @pytest.mark.asyncio
    async def test_connect_success(self, server_config):
        """With the mocked GET answering 200, connect() ends in CONNECTED."""
        conn = MCPConnection("test-server", server_config)
        ok = self._response(200)

        with self._patched_client(get=AsyncMock(return_value=ok)):
            await conn.connect()

            assert conn.state == ConnectionState.CONNECTED
            assert conn.is_connected is True

    @pytest.mark.asyncio
    async def test_connect_404_capabilities_ok(self, server_config):
        """Connecting still succeeds when /mcp/capabilities answers 404."""
        conn = MCPConnection("test-server", server_config)
        not_found = self._response(404)

        with self._patched_client(get=AsyncMock(return_value=not_found)):
            await conn.connect()

            assert conn.state == ConnectionState.CONNECTED

    @pytest.mark.asyncio
    async def test_connect_failure_with_retry(self, server_config):
        """Repeated connect errors surface as MCPConnectionError and ERROR state."""
        # Fewer, shorter retries keep this test fast.
        server_config.retry_attempts = 2
        server_config.retry_delay = 0.01
        conn = MCPConnection("test-server", server_config)

        refused = httpx.ConnectError("Connection refused")
        with self._patched_client(get=AsyncMock(side_effect=refused)):
            with pytest.raises(MCPConnectionError) as exc_info:
                await conn.connect()

            assert conn.state == ConnectionState.ERROR
            assert "Failed to connect after" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_disconnect(self, server_config):
        """disconnect() returns a connected connection to DISCONNECTED."""
        conn = MCPConnection("test-server", server_config)
        ok = self._response(200)

        with self._patched_client(
            get=AsyncMock(return_value=ok),
            aclose=AsyncMock(),
        ):
            await conn.connect()
            assert conn.is_connected is True

            await conn.disconnect()
            assert conn.state == ConnectionState.DISCONNECTED
            assert conn.is_connected is False

    @pytest.mark.asyncio
    async def test_reconnect(self, server_config):
        """reconnect() on a connected connection leaves it connected."""
        conn = MCPConnection("test-server", server_config)
        ok = self._response(200)

        with self._patched_client(
            get=AsyncMock(return_value=ok),
            aclose=AsyncMock(),
        ):
            await conn.connect()
            assert conn.is_connected is True

            await conn.reconnect()
            assert conn.is_connected is True

    @pytest.mark.asyncio
    async def test_health_check_success(self, server_config):
        """health_check() reports True when the mocked GET answers 200."""
        conn = MCPConnection("test-server", server_config)
        ok = self._response(200)

        with self._patched_client(get=AsyncMock(return_value=ok)):
            await conn.connect()
            healthy = await conn.health_check()

            assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_disconnected(self, server_config):
        """health_check() reports False on a connection that never connected."""
        conn = MCPConnection("test-server", server_config)

        healthy = await conn.health_check()

        assert healthy is False

    @pytest.mark.asyncio
    async def test_execute_request_get(self, server_config):
        """A GET request returns the JSON body of the mocked response."""
        conn = MCPConnection("test-server", server_config)

        # First GET serves connect(); the second serves the request itself.
        connect_ok = self._response(200)
        tools_reply = self._json_response({"tools": []})

        with self._patched_client(
            get=AsyncMock(side_effect=[connect_ok, tools_reply]),
        ):
            await conn.connect()
            result = await conn.execute_request("GET", "/mcp/tools")

            assert result == {"tools": []}

    @pytest.mark.asyncio
    async def test_execute_request_post(self, server_config):
        """A POST request returns the JSON body of the mocked response."""
        conn = MCPConnection("test-server", server_config)

        connect_ok = self._response(200)
        post_reply = self._json_response({"result": "success"})

        with self._patched_client(
            get=AsyncMock(return_value=connect_ok),
            post=AsyncMock(return_value=post_reply),
        ):
            await conn.connect()
            result = await conn.execute_request(
                "POST", "/mcp", data={"method": "test"}
            )

            assert result == {"result": "success"}

    @pytest.mark.asyncio
    async def test_execute_request_timeout(self, server_config):
        """An httpx timeout during a request is raised as MCPTimeoutError."""
        conn = MCPConnection("test-server", server_config)

        # First GET serves connect(); the second raises the timeout.
        get_outcomes = [
            self._response(200),
            httpx.TimeoutException("Request timeout"),
        ]

        with self._patched_client(get=AsyncMock(side_effect=get_outcomes)):
            await conn.connect()

            with pytest.raises(MCPTimeoutError) as exc_info:
                await conn.execute_request("GET", "/slow-endpoint")

            assert "timeout" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_execute_request_not_connected(self, server_config):
        """Requests on a connection that never connected raise MCPConnectionError."""
        conn = MCPConnection("test-server", server_config)

        with pytest.raises(MCPConnectionError) as exc_info:
            await conn.execute_request("GET", "/test")

        assert "Not connected" in str(exc_info.value)

    def test_backoff_delay_calculation(self, server_config):
        """Backoff delays are positive, roughly growing, and capped near the max."""
        conn = MCPConnection("test-server", server_config)

        # Sample the delay for the first, second and third attempt in order.
        delays = []
        for attempt in (1, 2, 3):
            conn._connection_attempts = attempt
            delays.append(conn._calculate_backoff_delay())
        delay1, delay2, delay3 = delays

        # Delays should generally increase (with some jitter).
        # Base is 0.1, so rough expectations are ~0.1s, ~0.2s, ~0.4s.
        assert delay1 > 0
        assert delay2 > delay1 * 0.5  # Allow for jitter
        assert delay3 <= server_config.retry_max_delay * 1.25  # Within max + jitter
|
|
|
|
|
|
class TestConnectionPool:
    """Tests for ``ConnectionPool`` driven through a mocked ``httpx.AsyncClient``."""

    @staticmethod
    def _patched_ok_client(**extra_methods):
        """Patch ``httpx.AsyncClient`` with a client whose GET answers 200.

        Extra keyword arguments are set as further attributes on the
        mocked client (for example ``aclose=AsyncMock()``).
        """
        ok = MagicMock(status_code=200)
        client = AsyncMock(get=AsyncMock(return_value=ok), **extra_methods)
        return patch("httpx.AsyncClient", return_value=client)

    @pytest.fixture
    def pool(self):
        """Provide a pool limited to five connections per server."""
        return ConnectionPool(max_connections_per_server=5)

    @pytest.mark.asyncio
    async def test_get_connection_creates_new(self, pool, server_config):
        """Asking for an unknown server yields a new, connected connection."""
        with self._patched_ok_client():
            conn = await pool.get_connection("test-server", server_config)

            assert conn is not None
            assert conn.is_connected is True
            assert conn.server_name == "test-server"

    @pytest.mark.asyncio
    async def test_get_connection_reuses_existing(self, pool, server_config):
        """Asking twice for the same server returns the same object."""
        with self._patched_ok_client():
            first = await pool.get_connection("test-server", server_config)
            second = await pool.get_connection("test-server", server_config)

            assert first is second

    @pytest.mark.asyncio
    async def test_close_connection(self, pool, server_config):
        """Closing a connection removes its entry from the pool."""
        with self._patched_ok_client(aclose=AsyncMock()):
            await pool.get_connection("test-server", server_config)
            await pool.close_connection("test-server")

            assert "test-server" not in pool._connections

    @pytest.mark.asyncio
    async def test_close_all(self, pool, server_config):
        """close_all() empties a pool that holds two connections."""
        with self._patched_ok_client(aclose=AsyncMock()):
            second_config = MCPServerConfig(url="http://server2:8000")
            await pool.get_connection("server-1", server_config)
            await pool.get_connection("server-2", second_config)

            assert len(pool._connections) == 2

            await pool.close_all()

            assert len(pool._connections) == 0

    @pytest.mark.asyncio
    async def test_health_check_all(self, pool, server_config):
        """health_check_all() reports True for a pooled, healthy server."""
        with self._patched_ok_client():
            await pool.get_connection("test-server", server_config)
            results = await pool.health_check_all()

            assert "test-server" in results
            assert results["test-server"] is True

    def test_get_status(self, pool, server_config):
        """An empty pool reports an empty status mapping."""
        assert pool.get_status() == {}

    @pytest.mark.asyncio
    async def test_connection_context_manager(self, pool, server_config):
        """``pool.connection(...)`` yields a connected connection for the server."""
        with self._patched_ok_client():
            async with pool.connection("test-server", server_config) as conn:
                assert conn.is_connected is True
                assert conn.server_name == "test-server"
|