Files
syndarix/backend/app/services/mcp/routing.py
Felipe Cardoso e5975fa5d0 feat(backend): implement MCP client infrastructure (#55)
Core MCP client implementation with comprehensive tooling:

**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker

**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO

**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection

**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support

**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 11:12:41 +01:00

620 lines
18 KiB
Python

"""
MCP Tool Call Routing
Routes tool calls to appropriate servers with retry logic,
circuit breakers, and request/response serialization.
"""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from .config import MCPServerConfig
from .connection import ConnectionPool, MCPConnection
from .exceptions import (
MCPCircuitOpenError,
MCPError,
MCPTimeoutError,
MCPToolError,
MCPToolNotFoundError,
)
from .registry import MCPServerRegistry
logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Circuit breaker states (values are the strings reported by ``current_state``)."""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half-open"


class AsyncCircuitBreaker:
    """
    Async-compatible circuit breaker implementation.

    Unlike pybreaker which wraps sync functions, this implementation
    provides explicit success/failure tracking for async code: callers check
    ``is_open()`` before a call and report the outcome with ``success()`` or
    ``failure()``.

    State model:
    - CLOSED: calls allowed; failures since the last success/reset are counted.
    - OPEN: entered when the failure count reaches ``fail_max``; calls are
      rejected until ``reset_timeout`` seconds have passed since the most
      recent failure.
    - HALF_OPEN: never stored; reported by ``current_state`` for an OPEN
      circuit whose timeout has elapsed. Calls are allowed again; a success
      closes the circuit, a failure re-opens it and restarts the timeout.
    """

    def __init__(
        self,
        fail_max: int = 5,
        reset_timeout: float = 30.0,
        name: str = "",
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            fail_max: Number of failures (since the last success or reset) at
                which the circuit opens
            reset_timeout: Seconds after the most recent failure before calls
                are allowed again (half-open)
            name: Name for logging
        """
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.name = name
        self._state = CircuitState.CLOSED
        self._fail_counter = 0
        # time.monotonic() timestamp of the most recent failure, or None.
        self._last_failure_time: float | None = None
        # Serializes state mutations; the read-only checks below are lock-free.
        self._lock = asyncio.Lock()

    @property
    def current_state(self) -> str:
        """Get current state as string: "closed", "open" or "half-open"."""
        # HALF_OPEN is derived: an OPEN circuit whose reset timeout has elapsed.
        if self._state == CircuitState.OPEN and self._should_try_reset():
            return CircuitState.HALF_OPEN.value
        return self._state.value

    @property
    def fail_counter(self) -> int:
        """Get current failure count."""
        return self._fail_counter

    def _should_try_reset(self) -> bool:
        """Check if enough time has passed since the last failure to try again."""
        if self._last_failure_time is None:
            return True
        # Monotonic clock: wall-clock jumps (NTP sync, manual changes) must not
        # shorten or extend the open period.
        return (time.monotonic() - self._last_failure_time) >= self.reset_timeout

    async def success(self) -> None:
        """Record a successful call: close the circuit and clear the failure count."""
        async with self._lock:
            self._fail_counter = 0
            self._state = CircuitState.CLOSED
            self._last_failure_time = None

    async def failure(self) -> None:
        """
        Record a failed call.

        Opens the circuit once the failure count reaches ``fail_max``. Because
        only ``success()``/``reset()`` clear the counter, a failure while
        half-open re-opens the circuit and restarts the reset timeout.
        """
        async with self._lock:
            self._fail_counter += 1
            self._last_failure_time = time.monotonic()
            if self._fail_counter >= self.fail_max:
                self._state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker %s opened after %d failures",
                    self.name,
                    self._fail_counter,
                )

    def is_open(self) -> bool:
        """
        Check if circuit is open (not allowing calls).

        Returns False once the reset timeout has elapsed (half-open), so a
        trial call may proceed.
        """
        if self._state == CircuitState.OPEN:
            return not self._should_try_reset()
        return False

    async def reset(self) -> None:
        """Manually reset the circuit breaker to CLOSED with no recorded failures."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._fail_counter = 0
            self._last_failure_time = None
@dataclass
class ToolInfo:
"""Information about an available tool."""
name: str
description: str | None = None
server_name: str | None = None
input_schema: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"name": self.name,
"description": self.description,
"server_name": self.server_name,
"input_schema": self.input_schema,
}
@dataclass
class ToolResult:
"""Result of a tool execution."""
success: bool
data: Any = None
error: str | None = None
error_code: str | None = None
tool_name: str | None = None
server_name: str | None = None
execution_time_ms: float = 0.0
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"success": self.success,
"data": self.data,
"error": self.error,
"error_code": self.error_code,
"tool_name": self.tool_name,
"server_name": self.server_name,
"execution_time_ms": self.execution_time_ms,
"request_id": self.request_id,
}
class ToolRouter:
    """
    Routes tool calls to the appropriate MCP server.

    Features:
    - Tool name to server mapping (built by ``discover_tools`` /
      ``register_tool_mapping``, with a registry fallback in ``route_tool``)
    - Retry logic with exponential backoff and jitter
    - One circuit breaker per server for fault tolerance
    - JSON-RPC request/response serialization
    - Execution timing for every call
    """

    def __init__(
        self,
        registry: MCPServerRegistry,
        connection_pool: ConnectionPool,
    ) -> None:
        """
        Initialize the tool router.

        Args:
            registry: MCP server registry (server configs and capabilities)
            connection_pool: Connection pool for servers
        """
        self._registry = registry
        self._pool = connection_pool
        # Created lazily, one breaker per server name.
        self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
        # Tool name -> name of the server providing it.
        self._tool_to_server: dict[str, str] = {}
        self._lock = asyncio.Lock()

    def _get_circuit_breaker(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncCircuitBreaker:
        """
        Get or create the circuit breaker for a server.

        ``config`` thresholds are read only when the breaker is first created;
        later calls return the existing breaker unchanged.
        """
        breaker = self._circuit_breakers.get(server_name)
        if breaker is None:
            # No await between lookup and insert, so coroutines on the same
            # event loop cannot create duplicate breakers.
            breaker = AsyncCircuitBreaker(
                fail_max=config.circuit_breaker_threshold,
                reset_timeout=config.circuit_breaker_timeout,
                name=f"mcp-{server_name}",
            )
            self._circuit_breakers[server_name] = breaker
        return breaker

    async def register_tool_mapping(
        self,
        tool_name: str,
        server_name: str,
    ) -> None:
        """
        Register a mapping from tool name to server.

        A later registration for the same tool name replaces the earlier one.

        Args:
            tool_name: Name of the tool
            server_name: Name of the server providing the tool
        """
        async with self._lock:
            self._tool_to_server[tool_name] = server_name
        logger.debug("Registered tool %s -> server %s", tool_name, server_name)

    async def discover_tools(self) -> None:
        """
        Discover all tools from enabled servers and build mappings.

        For each enabled server: fetch its tool list, store it in the registry
        as capabilities, and map every tool name to that server. A failure for
        one server is logged and does not stop discovery of the others. When
        two servers expose the same tool name, the one processed last wins.
        """
        for server_name in self._registry.list_enabled_servers():
            try:
                config = self._registry.get(server_name)
                connection = await self._pool.get_connection(server_name, config)
                tools = await self._fetch_tools_from_server(connection)
                self._registry.set_capabilities(
                    server_name,
                    tools=[t.to_dict() for t in tools],
                )
                for tool in tools:
                    await self.register_tool_mapping(tool.name, server_name)
                logger.info(
                    "Discovered %d tools from server %s",
                    len(tools),
                    server_name,
                )
            except Exception as e:
                logger.warning(
                    "Failed to discover tools from %s: %s",
                    server_name,
                    e,
                )

    async def _fetch_tools_from_server(
        self,
        connection: MCPConnection,
    ) -> list[ToolInfo]:
        """
        Fetch available tools from an MCP server.

        Any error is logged and an empty list is returned. Entries without a
        name are skipped because they could never be routed.
        """
        try:
            response = await connection.execute_request(
                "GET",
                "/mcp/tools",
            )
            tools: list[ToolInfo] = []
            for tool_data in response.get("tools", []):
                name = tool_data.get("name")
                if not name:
                    logger.debug(
                        "Skipping nameless tool entry from %s",
                        connection.server_name,
                    )
                    continue
                tools.append(
                    ToolInfo(
                        name=name,
                        description=tool_data.get("description"),
                        server_name=connection.server_name,
                        # Server sends camelCase; ToolInfo uses snake_case.
                        input_schema=tool_data.get("inputSchema"),
                    )
                )
            return tools
        except Exception as e:
            logger.warning(
                "Error fetching tools from %s: %s",
                connection.server_name,
                e,
            )
            return []

    def find_server_for_tool(self, tool_name: str) -> str | None:
        """
        Find which server provides a specific tool (local mapping only).

        Args:
            tool_name: Name of the tool

        Returns:
            Server name or None if not found
        """
        return self._tool_to_server.get(tool_name)

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Call a tool on a specific server.

        Args:
            server_name: Name of the MCP server
            tool_name: Name of the tool to call
            arguments: Tool arguments (``None`` is sent as ``{}``)
            timeout: Optional timeout override passed to the connection

        Returns:
            Tool execution result. Failures other than an open circuit are
            reported as ``ToolResult(success=False, ...)``; ``error_code`` is
            the exception class name for ``MCPError`` subclasses and
            ``"UnexpectedError"`` for anything else.

        Raises:
            MCPCircuitOpenError: If the server's circuit breaker is open.
        """
        # perf_counter is monotonic and high-resolution, so durations are not
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        request_id = str(uuid.uuid4())
        logger.debug(
            "Tool call [%s]: %s.%s with args %s",
            request_id,
            server_name,
            tool_name,
            arguments,
        )
        try:
            config = self._registry.get(server_name)
            circuit_breaker = self._get_circuit_breaker(server_name, config)
            if circuit_breaker.is_open():
                raise MCPCircuitOpenError(
                    server_name=server_name,
                    failure_count=circuit_breaker.fail_counter,
                    reset_timeout=config.circuit_breaker_timeout,
                )
            result = await self._execute_with_retry(
                server_name=server_name,
                config=config,
                tool_name=tool_name,
                arguments=arguments or {},
                timeout=timeout,
                circuit_breaker=circuit_breaker,
            )
            return ToolResult(
                success=True,
                data=result,
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=(time.perf_counter() - start_time) * 1000,
                request_id=request_id,
            )
        except MCPCircuitOpenError:
            raise
        except MCPError as e:
            execution_time = (time.perf_counter() - start_time) * 1000
            logger.error(
                "Tool call failed [%s]: %s.%s - %s",
                request_id,
                server_name,
                tool_name,
                e,
            )
            return ToolResult(
                success=False,
                error=str(e),
                error_code=type(e).__name__,
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=execution_time,
                request_id=request_id,
            )
        except Exception as e:
            execution_time = (time.perf_counter() - start_time) * 1000
            logger.exception(
                "Unexpected error in tool call [%s]: %s.%s",
                request_id,
                server_name,
                tool_name,
            )
            return ToolResult(
                success=False,
                error=str(e),
                error_code="UnexpectedError",
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=execution_time,
                request_id=request_id,
            )

    async def _execute_with_retry(
        self,
        server_name: str,
        config: MCPServerConfig,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float | None,
        circuit_breaker: AsyncCircuitBreaker,
    ) -> Any:
        """
        Execute a tool call, retrying transient failures.

        The initial attempt is followed by up to ``config.retry_attempts``
        retries (at least one attempt is always made). Each outcome is recorded
        on ``circuit_breaker``. Retries stop early once the recorded failures
        open the breaker, so a failing server is not hammered further.

        Raises:
            MCPTimeoutError: On timeout (recorded as a failure, not retried).
            MCPToolError: On a tool-level error reported by the server (not
                retried), or when the retries are exhausted or cut short by
                the circuit breaker (chained to the last underlying error).
        """
        last_error: Exception | None = None
        max_attempts = max(1, config.retry_attempts + 1)  # +1 for initial attempt
        attempts_made = 0
        while attempts_made < max_attempts:
            attempts_made += 1
            try:
                result = await self._execute_tool_call(
                    server_name=server_name,
                    config=config,
                    tool_name=tool_name,
                    arguments=arguments,
                    timeout=timeout,
                )
            except MCPCircuitOpenError:
                # Propagate unchanged if a lower layer rejects the call.
                raise
            except MCPTimeoutError:
                # Timeout - don't retry
                await circuit_breaker.failure()
                raise
            except MCPToolError:
                # Tool-level error - don't retry (user error)
                raise
            except Exception as e:
                last_error = e
                await circuit_breaker.failure()
                if attempts_made >= max_attempts:
                    break
                if circuit_breaker.is_open():
                    logger.warning(
                        "Circuit breaker for %s opened; not retrying %s.%s "
                        "after attempt %d/%d: %s",
                        server_name,
                        server_name,
                        tool_name,
                        attempts_made,
                        max_attempts,
                        e,
                    )
                    break
                delay = self._calculate_retry_delay(attempts_made, config)
                logger.warning(
                    "Tool call attempt %d/%d failed for %s.%s: %s. "
                    "Retrying in %.1fs",
                    attempts_made,
                    max_attempts,
                    server_name,
                    tool_name,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                await circuit_breaker.success()
                return result
        raise MCPToolError(
            f"Tool call failed after {attempts_made} attempts",
            server_name=server_name,
            tool_name=tool_name,
            tool_args=arguments,
            details={"last_error": str(last_error)},
        ) from last_error

    def _calculate_retry_delay(
        self,
        attempt: int,
        config: MCPServerConfig,
    ) -> float:
        """
        Calculate exponential backoff delay with jitter.

        The base delay is ``retry_delay * 2**(attempt - 1)``, jittered by ±25%,
        capped at ``retry_max_delay``, and never below 0.1 seconds.
        """
        import random

        delay = config.retry_delay * (2 ** (attempt - 1))
        delay = min(delay, config.retry_max_delay)
        # Jitter (±25%) spreads out retries from concurrent callers.
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        # Re-apply the cap after jitter so retry_max_delay is a true upper
        # bound; the 0.1s floor still takes precedence.
        return max(0.1, min(delay + jitter, config.retry_max_delay))

    async def _execute_tool_call(
        self,
        server_name: str,
        config: MCPServerConfig,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float | None,
    ) -> Any:
        """
        Execute a single JSON-RPC ``tools/call`` request (no retries).

        Returns:
            The response's ``result`` member (None if absent).

        Raises:
            MCPToolError: If the response carries a non-null ``error`` member.
        """
        connection = await self._pool.get_connection(server_name, config)
        request_body = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments,
            },
            "id": str(uuid.uuid4()),
        }
        response = await connection.execute_request(
            method="POST",
            path="/mcp",
            data=request_body,
            timeout=timeout,
        )
        # A null "error" means no error; tolerate non-object error values from
        # non-conforming servers instead of failing with AttributeError.
        error = response.get("error")
        if error is not None:
            if isinstance(error, dict):
                message = error.get("message") or "Tool execution failed"
                code = error.get("code", "UNKNOWN")
            else:
                message = str(error) or "Tool execution failed"
                code = "UNKNOWN"
            raise MCPToolError(
                message,
                server_name=server_name,
                tool_name=tool_name,
                tool_args=arguments,
                error_code=str(code),
            )
        return response.get("result")

    async def route_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Route a tool call to the appropriate server.

        The server is looked up in the local tool mapping first, then in the
        registry.

        Args:
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result

        Raises:
            MCPToolNotFoundError: If no server provides the tool
            MCPCircuitOpenError: If the chosen server's circuit is open
        """
        server_name = self.find_server_for_tool(tool_name)
        if server_name is None:
            server_name = self._registry.find_server_for_tool(tool_name)
        if server_name is None:
            raise MCPToolNotFoundError(
                tool_name=tool_name,
                available_tools=list(self._tool_to_server.keys()),
            )
        return await self.call_tool(
            server_name=server_name,
            tool_name=tool_name,
            arguments=arguments,
            timeout=timeout,
        )

    async def list_all_tools(self) -> list[ToolInfo]:
        """
        Get all available tools from all servers, as known to the registry.

        Returns:
            List of tool information
        """
        # Registry entries are presumably the ToolInfo.to_dict() output stored
        # by discover_tools, hence the snake_case "input_schema" key.
        return [
            ToolInfo(
                name=tool_data.get("name", ""),
                description=tool_data.get("description"),
                server_name=server_name,
                input_schema=tool_data.get("input_schema"),
            )
            for server_name, server_tools in self._registry.get_all_tools().items()
            for tool_data in server_tools
        ]

    def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
        """Get state and failure count of every circuit breaker created so far."""
        return {
            name: {
                "state": cb.current_state,
                "failure_count": cb.fail_counter,
            }
            for name, cb in self._circuit_breakers.items()
        }

    async def reset_circuit_breaker(self, server_name: str) -> bool:
        """
        Manually reset a circuit breaker.

        The breaker is discarded; the next call creates a fresh one from the
        server's current config.

        Args:
            server_name: Name of the server

        Returns:
            True if circuit breaker was reset, False if none existed
        """
        async with self._lock:
            if server_name in self._circuit_breakers:
                del self._circuit_breakers[server_name]
                logger.info("Reset circuit breaker for %s", server_name)
                return True
            return False