forked from cardosofelipe/fast-next-template
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
620 lines
18 KiB
Python
620 lines
18 KiB
Python
"""
|
|
MCP Tool Call Routing
|
|
|
|
Routes tool calls to appropriate servers with retry logic,
|
|
circuit breakers, and request/response serialization.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from .config import MCPServerConfig
|
|
from .connection import ConnectionPool, MCPConnection
|
|
from .exceptions import (
|
|
MCPCircuitOpenError,
|
|
MCPError,
|
|
MCPTimeoutError,
|
|
MCPToolError,
|
|
MCPToolNotFoundError,
|
|
)
|
|
from .registry import MCPServerRegistry
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CircuitState(Enum):
    """Circuit breaker states."""

    # Calls are allowed; this is the state a breaker starts in.
    CLOSED = "closed"
    # Too many failures were recorded; calls are rejected until the
    # breaker's reset timeout has elapsed.
    OPEN = "open"
    # Reported once an open breaker's reset timeout has elapsed and calls
    # are being let through again.
    HALF_OPEN = "half-open"
|
|
|
|
|
|
class AsyncCircuitBreaker:
    """
    Async-compatible circuit breaker implementation.

    Unlike pybreaker which wraps sync functions, this implementation
    provides explicit success/failure tracking for async code.

    Callers check ``is_open()`` before making a call and then report the
    outcome through ``success()`` or ``failure()``. State mutations happen
    under an ``asyncio.Lock``; the read-only accessors do not take the lock.
    """

    def __init__(
        self,
        fail_max: int = 5,
        reset_timeout: float = 30.0,
        name: str = "",
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            fail_max: Maximum failures before opening circuit
            reset_timeout: Seconds to wait before trying again
            name: Name for logging
        """
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.name = name
        self._state = CircuitState.CLOSED
        self._fail_counter = 0
        # Monotonic-clock timestamp of the most recent failure (None when
        # no failure has been recorded since the last success/reset).
        self._last_failure_time: float | None = None
        self._lock = asyncio.Lock()

    @property
    def current_state(self) -> str:
        """Get current state as string.

        An open breaker whose reset timeout has elapsed is reported as
        ``"half-open"``; the stored state itself is not modified here.
        """
        if self._state == CircuitState.OPEN and self._should_try_reset():
            return CircuitState.HALF_OPEN.value
        return self._state.value

    @property
    def fail_counter(self) -> int:
        """Get current failure count."""
        return self._fail_counter

    def _should_try_reset(self) -> bool:
        """Check if enough time has passed to try resetting."""
        if self._last_failure_time is None:
            return True
        # A monotonic clock is used so that wall-clock adjustments (NTP,
        # manual changes, DST) can neither shorten nor extend the timeout.
        elapsed = time.monotonic() - self._last_failure_time
        return elapsed >= self.reset_timeout

    async def success(self) -> None:
        """Record a successful call: clear the failure count and close the circuit."""
        async with self._lock:
            self._fail_counter = 0
            self._state = CircuitState.CLOSED
            self._last_failure_time = None

    async def failure(self) -> None:
        """Record a failed call, opening the circuit once ``fail_max`` is reached."""
        async with self._lock:
            self._fail_counter += 1
            self._last_failure_time = time.monotonic()

            if self._fail_counter >= self.fail_max:
                self._state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker %s opened after %d failures",
                    self.name,
                    self._fail_counter,
                )

    def is_open(self) -> bool:
        """Check if circuit is open (not allowing calls).

        Returns False for an open circuit whose reset timeout has elapsed,
        so that calls are allowed through again.
        """
        if self._state == CircuitState.OPEN:
            return not self._should_try_reset()
        return False

    async def reset(self) -> None:
        """Manually reset the circuit breaker."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._fail_counter = 0
            self._last_failure_time = None
|
|
|
|
|
|
@dataclass
|
|
class ToolInfo:
|
|
"""Information about an available tool."""
|
|
|
|
name: str
|
|
description: str | None = None
|
|
server_name: str | None = None
|
|
input_schema: dict[str, Any] | None = None
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Convert to dictionary."""
|
|
return {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"server_name": self.server_name,
|
|
"input_schema": self.input_schema,
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class ToolResult:
|
|
"""Result of a tool execution."""
|
|
|
|
success: bool
|
|
data: Any = None
|
|
error: str | None = None
|
|
error_code: str | None = None
|
|
tool_name: str | None = None
|
|
server_name: str | None = None
|
|
execution_time_ms: float = 0.0
|
|
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Convert to dictionary."""
|
|
return {
|
|
"success": self.success,
|
|
"data": self.data,
|
|
"error": self.error,
|
|
"error_code": self.error_code,
|
|
"tool_name": self.tool_name,
|
|
"server_name": self.server_name,
|
|
"execution_time_ms": self.execution_time_ms,
|
|
"request_id": self.request_id,
|
|
}
|
|
|
|
|
|
class ToolRouter:
    """
    Routes tool calls to the appropriate MCP server.

    Features:
    - Tool name to server mapping
    - Retry logic with exponential backoff
    - Circuit breaker pattern for fault tolerance
    - Request/response serialization
    - Execution timing and metrics
    """

    def __init__(
        self,
        registry: MCPServerRegistry,
        connection_pool: ConnectionPool,
    ) -> None:
        """
        Initialize the tool router.

        Args:
            registry: MCP server registry
            connection_pool: Connection pool for servers
        """
        self._registry = registry
        self._pool = connection_pool
        # One circuit breaker per server name, created lazily.
        self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
        # Tool name -> name of the server that provides it.
        self._tool_to_server: dict[str, str] = {}
        self._lock = asyncio.Lock()

    def _get_circuit_breaker(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncCircuitBreaker:
        """Get or create a circuit breaker for a server."""
        if server_name not in self._circuit_breakers:
            self._circuit_breakers[server_name] = AsyncCircuitBreaker(
                fail_max=config.circuit_breaker_threshold,
                reset_timeout=config.circuit_breaker_timeout,
                name=f"mcp-{server_name}",
            )
        return self._circuit_breakers[server_name]

    async def register_tool_mapping(
        self,
        tool_name: str,
        server_name: str,
    ) -> None:
        """
        Register a mapping from tool name to server.

        An existing mapping for the same tool name is overwritten.

        Args:
            tool_name: Name of the tool
            server_name: Name of the server providing the tool
        """
        async with self._lock:
            self._tool_to_server[tool_name] = server_name
            logger.debug("Registered tool %s -> server %s", tool_name, server_name)

    async def discover_tools(self) -> None:
        """
        Discover all tools from registered servers and build mappings.

        Servers are processed one at a time. A failure on one server is
        logged as a warning and does not stop discovery on the others.
        """
        for server_name in self._registry.list_enabled_servers():
            try:
                config = self._registry.get(server_name)
                connection = await self._pool.get_connection(server_name, config)

                # Fetch tools from server
                tools = await self._fetch_tools_from_server(connection)

                # Update registry with capabilities
                self._registry.set_capabilities(
                    server_name,
                    tools=[t.to_dict() for t in tools],
                )

                # Update tool mappings
                for tool in tools:
                    await self.register_tool_mapping(tool.name, server_name)

                logger.info(
                    "Discovered %d tools from server %s",
                    len(tools),
                    server_name,
                )
            except Exception as e:
                # Best-effort: keep going with the remaining servers.
                logger.warning(
                    "Failed to discover tools from %s: %s",
                    server_name,
                    e,
                )

    async def _fetch_tools_from_server(
        self,
        connection: MCPConnection,
    ) -> list[ToolInfo]:
        """Fetch available tools from an MCP server.

        Returns an empty list (after logging a warning) if the request or
        the parsing of its response fails.
        """
        try:
            response = await connection.execute_request(
                "GET",
                "/mcp/tools",
            )
            # The server response uses the camelCase key "inputSchema".
            return [
                ToolInfo(
                    name=tool_data.get("name", ""),
                    description=tool_data.get("description"),
                    server_name=connection.server_name,
                    input_schema=tool_data.get("inputSchema"),
                )
                for tool_data in response.get("tools", [])
            ]
        except Exception as e:
            logger.warning(
                "Error fetching tools from %s: %s",
                connection.server_name,
                e,
            )
            return []

    def find_server_for_tool(self, tool_name: str) -> str | None:
        """
        Find which server provides a specific tool.

        Args:
            tool_name: Name of the tool

        Returns:
            Server name or None if not found
        """
        return self._tool_to_server.get(tool_name)

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Call a tool on a specific server.

        Args:
            server_name: Name of the MCP server
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result. Failures other than an open circuit are
            reported as a result with ``success=False`` rather than raised.

        Raises:
            MCPCircuitOpenError: If the server's circuit breaker is open
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the wall clock is adjusted.
        started = time.perf_counter()
        request_id = str(uuid.uuid4())

        def elapsed_ms() -> float:
            return (time.perf_counter() - started) * 1000

        logger.debug(
            "Tool call [%s]: %s.%s with args %s",
            request_id,
            server_name,
            tool_name,
            arguments,
        )

        try:
            config = self._registry.get(server_name)
            circuit_breaker = self._get_circuit_breaker(server_name, config)

            # Check circuit breaker state
            if circuit_breaker.is_open():
                raise MCPCircuitOpenError(
                    server_name=server_name,
                    failure_count=circuit_breaker.fail_counter,
                    reset_timeout=config.circuit_breaker_timeout,
                )

            # Execute with retry logic
            result = await self._execute_with_retry(
                server_name=server_name,
                config=config,
                tool_name=tool_name,
                arguments=arguments or {},
                timeout=timeout,
                circuit_breaker=circuit_breaker,
            )

            return ToolResult(
                success=True,
                data=result,
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=elapsed_ms(),
                request_id=request_id,
            )

        except MCPCircuitOpenError:
            # An open circuit is surfaced to the caller as an exception.
            raise
        except MCPError as e:
            execution_time = elapsed_ms()
            logger.error(
                "Tool call failed [%s]: %s.%s - %s",
                request_id,
                server_name,
                tool_name,
                e,
            )
            return ToolResult(
                success=False,
                error=str(e),
                error_code=type(e).__name__,
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=execution_time,
                request_id=request_id,
            )
        except Exception as e:
            execution_time = elapsed_ms()
            logger.exception(
                "Unexpected error in tool call [%s]: %s.%s",
                request_id,
                server_name,
                tool_name,
            )
            return ToolResult(
                success=False,
                error=str(e),
                error_code="UnexpectedError",
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=execution_time,
                request_id=request_id,
            )

    async def _execute_with_retry(
        self,
        server_name: str,
        config: MCPServerConfig,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float | None,
        circuit_breaker: AsyncCircuitBreaker,
    ) -> Any:
        """Execute tool call with retry logic.

        Makes up to ``config.retry_attempts + 1`` attempts. Outcomes are
        reported to ``circuit_breaker`` as follows:

        - success: recorded, result returned
        - MCPCircuitOpenError: re-raised untouched
        - MCPTimeoutError: recorded as a failure, re-raised, not retried
        - MCPToolError: re-raised, not retried, not recorded as a failure
        - any other exception: recorded as a failure and retried after a
          backoff delay

        Raises:
            MCPToolError: When every attempt failed; the last underlying
                error is attached as the exception's cause.
        """
        last_error: Exception | None = None
        max_attempts = config.retry_attempts + 1  # +1 for initial attempt

        for attempt in range(1, max_attempts + 1):
            try:
                result = await self._execute_tool_call(
                    server_name=server_name,
                    config=config,
                    tool_name=tool_name,
                    arguments=arguments,
                    timeout=timeout,
                )

                # Success - record it
                await circuit_breaker.success()
                return result

            except MCPCircuitOpenError:
                raise
            except MCPTimeoutError:
                # Timeout - don't retry
                await circuit_breaker.failure()
                raise
            except MCPToolError:
                # Tool-level error - don't retry (user error)
                raise
            except Exception as e:
                last_error = e
                await circuit_breaker.failure()

                if attempt < max_attempts:
                    delay = self._calculate_retry_delay(attempt, config)
                    logger.warning(
                        "Tool call attempt %d/%d failed for %s.%s: %s. "
                        "Retrying in %.1fs",
                        attempt,
                        max_attempts,
                        server_name,
                        tool_name,
                        e,
                        delay,
                    )
                    await asyncio.sleep(delay)

        # All attempts failed; chain the last error so its traceback is
        # preserved for whoever handles the MCPToolError.
        raise MCPToolError(
            f"Tool call failed after {max_attempts} attempts",
            server_name=server_name,
            tool_name=tool_name,
            tool_args=arguments,
            details={"last_error": str(last_error)},
        ) from last_error

    def _calculate_retry_delay(
        self,
        attempt: int,
        config: MCPServerConfig,
    ) -> float:
        """Calculate exponential backoff delay with jitter.

        Args:
            attempt: 1-based number of the attempt that just failed
            config: Supplies ``retry_delay`` (base) and ``retry_max_delay`` (cap)

        Returns:
            Delay in seconds: the base delay doubled per attempt, capped at
            ``retry_max_delay``, with ±25% jitter, never below 0.1 and
            (apart from that floor) never above ``retry_max_delay``.
        """
        import random

        delay = config.retry_delay * (2 ** (attempt - 1))
        delay = min(delay, config.retry_max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        # Re-apply the cap after jitter so the configured maximum is a hard
        # ceiling; the 0.1s floor still takes precedence.
        return max(0.1, min(delay + jitter, config.retry_max_delay))

    async def _execute_tool_call(
        self,
        server_name: str,
        config: MCPServerConfig,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float | None,
    ) -> Any:
        """Execute a single tool call.

        Sends a JSON-RPC 2.0 ``tools/call`` request via POST to ``/mcp`` on
        a pooled connection and returns the response's ``result`` member.

        Raises:
            MCPToolError: If the response contains an ``error`` member
        """
        connection = await self._pool.get_connection(server_name, config)

        # Build MCP tool call request
        request_body = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments,
            },
            "id": str(uuid.uuid4()),
        }

        response = await connection.execute_request(
            method="POST",
            path="/mcp",
            data=request_body,
            timeout=timeout,
        )

        # Handle JSON-RPC response
        if "error" in response:
            error = response["error"]
            raise MCPToolError(
                error.get("message", "Tool execution failed"),
                server_name=server_name,
                tool_name=tool_name,
                tool_args=arguments,
                error_code=str(error.get("code", "UNKNOWN")),
            )

        return response.get("result")

    async def route_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Route a tool call to the appropriate server.

        Automatically discovers which server provides the tool, checking
        the router's own mapping first and the registry second.

        Args:
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result

        Raises:
            MCPToolNotFoundError: If no server provides the tool
        """
        server_name = self.find_server_for_tool(tool_name)

        if server_name is None:
            # Try to find from registry
            server_name = self._registry.find_server_for_tool(tool_name)

        if server_name is None:
            raise MCPToolNotFoundError(
                tool_name=tool_name,
                available_tools=list(self._tool_to_server.keys()),
            )

        return await self.call_tool(
            server_name=server_name,
            tool_name=tool_name,
            arguments=arguments,
            timeout=timeout,
        )

    async def list_all_tools(self) -> list[ToolInfo]:
        """
        Get all available tools from all servers.

        The tool data is read from the registry; no server is contacted.

        Returns:
            List of tool information
        """
        all_server_tools = self._registry.get_all_tools()

        return [
            ToolInfo(
                name=tool_data.get("name", ""),
                description=tool_data.get("description"),
                server_name=server_name,
                input_schema=tool_data.get("input_schema"),
            )
            for server_name, server_tools in all_server_tools.items()
            for tool_data in server_tools
        ]

    def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
        """Get status of all circuit breakers.

        Returns:
            Mapping of server name to ``{"state": ..., "failure_count": ...}``
            for every breaker created so far.
        """
        return {
            name: {
                "state": cb.current_state,
                "failure_count": cb.fail_counter,
            }
            for name, cb in self._circuit_breakers.items()
        }

    async def reset_circuit_breaker(self, server_name: str) -> bool:
        """
        Manually reset a circuit breaker.

        Args:
            server_name: Name of the server

        Returns:
            True if circuit breaker was reset, False if the server has no
            circuit breaker yet
        """
        async with self._lock:
            if server_name in self._circuit_breakers:
                # Reset by removing (will be recreated on next call)
                del self._circuit_breakers[server_name]
                logger.info("Reset circuit breaker for %s", server_name)
                return True
            return False
|