forked from cardosofelipe/fast-next-template
feat(backend): implement MCP client infrastructure (#55)
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
435
backend/app/services/mcp/connection.py
Normal file
435
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
MCP Connection Management
|
||||
|
||||
Handles connection lifecycle, pooling, and automatic reconnection
|
||||
for MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import MCPServerConfig, TransportType
|
||||
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionState(str, Enum):
    """Lifecycle states of an MCP connection.

    Mixes in ``str`` so each member compares equal to its lowercase
    string value.
    """

    # No client is open.
    DISCONNECTED = "disconnected"
    # An initial connection attempt (including retries) is in progress.
    CONNECTING = "connecting"
    # The connection was established.
    CONNECTED = "connected"
    # A reconnect was requested on the connection.
    RECONNECTING = "reconnecting"
    # All connection attempts were exhausted without success.
    ERROR = "error"
|
||||
|
||||
|
||||
class MCPConnection:
    """
    Manages a single connection to an MCP server.

    Handles connection lifecycle, health checking, and reconnection.
    Only the HTTP transport is implemented; any other transport raises
    ``NotImplementedError`` when a connection is attempted.
    """

    def __init__(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> None:
        """
        Initialize connection.

        No network activity happens here; call :meth:`connect` to open
        the connection.

        Args:
            server_name: Name of the MCP server
            config: Server configuration
        """
        self.server_name = server_name
        self.config = config
        self._state = ConnectionState.DISCONNECTED
        self._client: httpx.AsyncClient | None = None
        # Serializes connect/disconnect. asyncio.Lock is not re-entrant,
        # so methods holding it must not call each other.
        self._lock = asyncio.Lock()
        # time.time() of the last successful connect or completed request.
        self._last_activity: float | None = None
        self._connection_attempts = 0
        self._last_error: Exception | None = None

        # Reconnection settings
        self._base_delay = config.retry_delay
        self._max_delay = config.retry_max_delay
        self._max_attempts = config.retry_attempts

    @property
    def state(self) -> ConnectionState:
        """Get current connection state."""
        return self._state

    @property
    def is_connected(self) -> bool:
        """Check if connection is established."""
        return self._state == ConnectionState.CONNECTED

    @property
    def last_error(self) -> Exception | None:
        """Get the last error that occurred."""
        return self._last_error

    async def connect(self) -> None:
        """
        Establish connection to the MCP server.

        Returns immediately if already connected. Otherwise makes up to
        ``retry_attempts`` attempts, sleeping with exponential backoff
        between failed attempts.

        Raises:
            MCPConnectionError: If connection fails after all retries
        """
        async with self._lock:
            if self._state == ConnectionState.CONNECTED:
                return

            self._state = ConnectionState.CONNECTING
            self._connection_attempts = 0
            self._last_error = None

            while self._connection_attempts < self._max_attempts:
                try:
                    await self._do_connect()
                    self._state = ConnectionState.CONNECTED
                    self._last_activity = time.time()
                    logger.info(
                        "Connected to MCP server: %s at %s",
                        self.server_name,
                        self.config.url,
                    )
                    return
                except Exception as e:
                    self._connection_attempts += 1
                    self._last_error = e
                    logger.warning(
                        "Connection attempt %d/%d failed for %s: %s",
                        self._connection_attempts,
                        self._max_attempts,
                        self.server_name,
                        e,
                    )

                    # Only sleep when another attempt will follow.
                    if self._connection_attempts < self._max_attempts:
                        delay = self._calculate_backoff_delay()
                        logger.debug(
                            "Retrying connection to %s in %.1fs",
                            self.server_name,
                            delay,
                        )
                        await asyncio.sleep(delay)

            # All attempts failed
            self._state = ConnectionState.ERROR
            raise MCPConnectionError(
                f"Failed to connect after {self._max_attempts} attempts",
                server_name=self.server_name,
                url=self.config.url,
                cause=self._last_error,
            ) from self._last_error

    async def _do_connect(self) -> None:
        """
        Perform the actual connection (transport-specific).

        For HTTP, a client is created and verified with a GET to
        ``/mcp/capabilities``. The client is stored on ``self._client``
        only after verification succeeds; on any failure it is closed so
        that repeated retries do not leak open clients.

        Raises:
            MCPConnectionError: If the server cannot be reached
            NotImplementedError: If the transport is not HTTP
        """
        if self.config.transport != TransportType.HTTP:
            # For STDIO and SSE transports, we'll implement later
            raise NotImplementedError(
                f"Transport {self.config.transport} not yet implemented"
            )

        client = httpx.AsyncClient(
            base_url=self.config.url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "User-Agent": "Syndarix-MCP-Client/1.0",
                "Accept": "application/json",
            },
        )

        # Verify connectivity with a simple request
        try:
            # Try to hit the MCP capabilities endpoint
            response = await client.get("/mcp/capabilities")
            if response.status_code not in (200, 404):
                # 404 is acceptable - server might not have capabilities endpoint
                response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code != 404:
                await self._discard_client(client)
                raise
        except httpx.ConnectError as e:
            await self._discard_client(client)
            raise MCPConnectionError(
                "Failed to connect to server",
                server_name=self.server_name,
                url=self.config.url,
                cause=e,
            ) from e
        except BaseException:
            # Includes timeouts and cancellation: never leave an
            # unverified client open.
            await self._discard_client(client)
            raise

        self._client = client

    async def _discard_client(self, client: httpx.AsyncClient) -> None:
        """Close a client that failed verification, ignoring close errors."""
        try:
            await client.aclose()
        except Exception as e:
            logger.debug(
                "Error closing unverified client for %s: %s",
                self.server_name,
                e,
            )

    def _calculate_backoff_delay(self) -> float:
        """
        Calculate exponential backoff delay with jitter.

        The base delay doubles with each recorded attempt and is capped
        at the configured maximum before jitter is applied, so the
        returned value may exceed the cap by up to 25%.
        """
        import random

        delay = self._base_delay * (2 ** (self._connection_attempts - 1))
        delay = min(delay, self._max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        return delay + jitter

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Errors raised while closing the client are logged and swallowed;
        the state always ends up ``DISCONNECTED``.
        """
        async with self._lock:
            if self._client is not None:
                try:
                    await self._client.aclose()
                except Exception as e:
                    logger.warning(
                        "Error closing connection to %s: %s",
                        self.server_name,
                        e,
                    )
                finally:
                    self._client = None

            self._state = ConnectionState.DISCONNECTED
            logger.info("Disconnected from MCP server: %s", self.server_name)

    async def reconnect(self) -> None:
        """
        Reconnect to the MCP server.

        Raises:
            MCPConnectionError: If the new connection cannot be established
        """
        async with self._lock:
            self._state = ConnectionState.RECONNECTING
        # disconnect() and connect() each take the (non re-entrant) lock
        # themselves, so they must run after it has been released.
        await self.disconnect()
        await self.connect()

    async def health_check(self) -> bool:
        """
        Perform a health check on the connection.

        For HTTP this issues ``GET /health`` with a 5 second timeout.
        Any exception is logged and reported as unhealthy.

        Returns:
            True if connection is healthy
        """
        if not self.is_connected or self._client is None:
            return False

        try:
            if self.config.transport == TransportType.HTTP:
                response = await self._client.get(
                    "/health",
                    timeout=5.0,
                )
                return response.status_code == 200
            return True
        except Exception as e:
            logger.warning(
                "Health check failed for %s: %s",
                self.server_name,
                e,
            )
            return False

    async def execute_request(
        self,
        method: str,
        path: str,
        data: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """
        Execute an HTTP request to the MCP server.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: Request path
            data: Optional request body, sent as JSON (ignored for GET)
            timeout: Optional timeout override; a falsy value falls back
                to the configured timeout

        Returns:
            Response data (the decoded JSON body)

        Raises:
            MCPConnectionError: If not connected, on a non-success HTTP
                status, or on any other request failure
            MCPTimeoutError: If request times out
        """
        if not self.is_connected or self._client is None:
            raise MCPConnectionError(
                "Not connected to server",
                server_name=self.server_name,
            )

        effective_timeout = timeout or self.config.timeout

        try:
            if method.upper() == "GET":
                response = await self._client.get(
                    path,
                    timeout=effective_timeout,
                )
            elif method.upper() == "POST":
                response = await self._client.post(
                    path,
                    json=data,
                    timeout=effective_timeout,
                )
            else:
                response = await self._client.request(
                    method.upper(),
                    path,
                    json=data,
                    timeout=effective_timeout,
                )

            # Recorded as soon as a response arrives, even if its status
            # turns out to be an error.
            self._last_activity = time.time()
            response.raise_for_status()
            return response.json()

        except httpx.TimeoutException as e:
            raise MCPTimeoutError(
                "Request timed out",
                server_name=self.server_name,
                timeout_seconds=effective_timeout,
                operation=f"{method} {path}",
            ) from e
        except httpx.HTTPStatusError as e:
            raise MCPConnectionError(
                f"HTTP error: {e.response.status_code}",
                server_name=self.server_name,
                url=f"{self.config.url}{path}",
                cause=e,
            ) from e
        except Exception as e:
            raise MCPConnectionError(
                f"Request failed: {e}",
                server_name=self.server_name,
                cause=e,
            ) from e
|
||||
|
||||
|
||||
class ConnectionPool:
    """
    Pool of connections to MCP servers.

    Keeps one :class:`MCPConnection` per server name and reuses it
    across callers.
    """

    def __init__(self, max_connections_per_server: int = 10) -> None:
        """
        Initialize connection pool.

        Args:
            max_connections_per_server: Maximum connections per server.
                Stored but not enforced by any method of this class.
        """
        self._connections: dict[str, MCPConnection] = {}
        # Guards creation and removal of entries in ``_connections``.
        self._lock = asyncio.Lock()
        self._max_per_server = max_connections_per_server

    async def get_connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> MCPConnection:
        """
        Get or create a connection to a server.

        The pool lock is held while connecting, so concurrent callers
        wait until the connection attempt has finished.

        Args:
            server_name: Name of the server
            config: Server configuration (only used when a new
                connection has to be created)

        Returns:
            Active connection
        """
        async with self._lock:
            if server_name not in self._connections:
                connection = MCPConnection(server_name, config)
                await connection.connect()
                # Registered only after a successful connect.
                self._connections[server_name] = connection

            connection = self._connections[server_name]

            # Reconnect if not connected
            if not connection.is_connected:
                await connection.connect()

            return connection

    async def release_connection(self, server_name: str) -> None:
        """
        Release a connection (currently a no-op).

        Args:
            server_name: Name of the server
        """
        # For now, we keep connections alive
        # Future: implement connection reaping for idle connections

    async def close_connection(self, server_name: str) -> None:
        """
        Close and remove a connection.

        Does nothing if the server has no pooled connection.

        Args:
            server_name: Name of the server
        """
        async with self._lock:
            if server_name in self._connections:
                await self._connections[server_name].disconnect()
                del self._connections[server_name]

    async def close_all(self) -> None:
        """Close all connections in the pool, logging any close errors."""
        async with self._lock:
            for connection in self._connections.values():
                try:
                    await connection.disconnect()
                except Exception as e:
                    logger.warning("Error closing connection: %s", e)

            self._connections.clear()
            logger.info("Closed all MCP connections")

    async def health_check_all(self) -> dict[str, bool]:
        """
        Perform health check on all connections.

        Returns:
            Dict mapping server names to health status
        """
        # Iterate over a snapshot: each health check awaits, and another
        # task may add or remove pool entries in the meantime, which
        # would otherwise raise "dictionary changed size during iteration".
        snapshot = list(self._connections.items())
        results = {}
        for name, connection in snapshot:
            results[name] = await connection.health_check()
        return results

    def get_status(self) -> dict[str, dict[str, Any]]:
        """
        Get status of all connections.

        Returns:
            Dict mapping server names to status info with the keys
            ``state``, ``is_connected`` and ``url``
        """
        return {
            name: {
                "state": conn.state.value,
                "is_connected": conn.is_connected,
                "url": conn.config.url,
            }
            for name, conn in self._connections.items()
        }

    @asynccontextmanager
    async def connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncGenerator[MCPConnection, None]:
        """
        Context manager for getting a connection.

        The connection is released on exit, whether or not the body
        raised.

        Usage:
            async with pool.connection("server", config) as conn:
                result = await conn.execute_request("POST", "/tool", data)
        """
        conn = await self.get_connection(server_name, config)
        try:
            yield conn
        finally:
            await self.release_connection(server_name)
|
||||
Reference in New Issue
Block a user