"""
MCP Tool Call Routing

Routes tool calls to appropriate servers with retry logic, circuit breakers,
and request/response serialization.
"""

import asyncio
import logging
import random
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

from .config import MCPServerConfig
from .connection import ConnectionPool, MCPConnection
from .exceptions import (
    MCPCircuitOpenError,
    MCPError,
    MCPTimeoutError,
    MCPToolError,
    MCPToolNotFoundError,
)
from .registry import MCPServerRegistry

logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Circuit breaker states."""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half-open"


class AsyncCircuitBreaker:
    """
    Async-compatible circuit breaker implementation.

    Unlike pybreaker which wraps sync functions, this implementation provides
    explicit success/failure tracking for async code.

    Behaviour:
      * Failures are counted consecutively; any recorded success resets the
        count and closes the circuit.
      * After ``fail_max`` consecutive failures the circuit opens.
      * Once ``reset_timeout`` seconds have passed since the last failure, the
        circuit reports itself as half-open and ``is_open()`` returns False
        again. The number of calls let through while half-open is not limited.
        Because the failure count is still >= ``fail_max`` at that point, the
        next recorded failure re-opens the circuit immediately.

    Elapsed time is measured with ``time.monotonic()`` so wall-clock
    adjustments cannot hold the circuit open or close it early.
    """

    def __init__(
        self,
        fail_max: int = 5,
        reset_timeout: float = 30.0,
        name: str = "",
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            fail_max: Maximum failures before opening circuit
            reset_timeout: Seconds to wait before trying again
            name: Name for logging
        """
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.name = name
        self._state = CircuitState.CLOSED
        self._fail_counter = 0
        # time.monotonic() timestamp of the most recent failure, or None.
        self._last_failure_time: float | None = None
        # Serializes the state mutations done in success/failure/reset.
        self._lock = asyncio.Lock()

    @property
    def current_state(self) -> str:
        """Get current state as string ("closed", "open" or "half-open")."""
        # An OPEN circuit whose timeout has elapsed is reported as HALF_OPEN;
        # the stored state itself stays OPEN until a success/reset.
        if self._state == CircuitState.OPEN and self._should_try_reset():
            return CircuitState.HALF_OPEN.value
        return self._state.value

    @property
    def fail_counter(self) -> int:
        """Get current consecutive failure count."""
        return self._fail_counter

    def _should_try_reset(self) -> bool:
        """Check if enough time has passed since the last failure to try again."""
        if self._last_failure_time is None:
            return True
        return (time.monotonic() - self._last_failure_time) >= self.reset_timeout

    async def success(self) -> None:
        """Record a successful call: clear the failure count and close the circuit."""
        async with self._lock:
            self._fail_counter = 0
            self._state = CircuitState.CLOSED
            self._last_failure_time = None

    async def failure(self) -> None:
        """Record a failed call, opening the circuit once ``fail_max`` is reached."""
        async with self._lock:
            self._fail_counter += 1
            self._last_failure_time = time.monotonic()

            if self._fail_counter >= self.fail_max:
                self._state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker %s opened after %d failures",
                    self.name,
                    self._fail_counter,
                )

    def is_open(self) -> bool:
        """Check if circuit is open (not allowing calls).

        Returns False when the circuit is closed, or when it is open but the
        reset timeout has elapsed (half-open).
        """
        if self._state == CircuitState.OPEN:
            return not self._should_try_reset()
        return False

    async def reset(self) -> None:
        """Manually reset the circuit breaker to the closed state."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._fail_counter = 0
            self._last_failure_time = None


@dataclass
class ToolInfo:
    """Information about an available tool."""

    name: str
    description: str | None = None
    server_name: str | None = None
    input_schema: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary (note: schema is stored under "input_schema")."""
        return {
            "name": self.name,
            "description": self.description,
            "server_name": self.server_name,
            "input_schema": self.input_schema,
        }


@dataclass
class ToolResult:
    """Result of a tool execution."""

    success: bool
    data: Any = None
    error: str | None = None
    error_code: str | None = None
    tool_name: str | None = None
    server_name: str | None = None
    execution_time_ms: float = 0.0
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "error_code": self.error_code,
            "tool_name": self.tool_name,
            "server_name": self.server_name,
            "execution_time_ms": self.execution_time_ms,
            "request_id": self.request_id,
        }


class ToolRouter:
    """
    Routes tool calls to the appropriate MCP server.

    Features:
    - Tool name to server mapping
    - Retry logic with exponential backoff
    - Circuit breaker pattern for fault tolerance
    - Request/response serialization
    - Execution timing and metrics
    """

    def __init__(
        self,
        registry: MCPServerRegistry,
        connection_pool: ConnectionPool,
    ) -> None:
        """
        Initialize the tool router.

        Args:
            registry: MCP server registry
            connection_pool: Connection pool for servers
        """
        self._registry = registry
        self._pool = connection_pool
        # One breaker per server name, created lazily on first call.
        self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
        self._tool_to_server: dict[str, str] = {}
        # Guards writes to _tool_to_server and removal of circuit breakers.
        self._lock = asyncio.Lock()

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds elapsed since ``start`` (a perf_counter value)."""
        return (time.perf_counter() - start) * 1000

    def _get_circuit_breaker(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncCircuitBreaker:
        """Get or create a circuit breaker for a server.

        The breaker's thresholds are taken from ``config`` at creation time;
        later config changes only apply after ``reset_circuit_breaker``
        removes the existing breaker.
        """
        if server_name not in self._circuit_breakers:
            self._circuit_breakers[server_name] = AsyncCircuitBreaker(
                fail_max=config.circuit_breaker_threshold,
                reset_timeout=config.circuit_breaker_timeout,
                name=f"mcp-{server_name}",
            )
        return self._circuit_breakers[server_name]

    async def register_tool_mapping(
        self,
        tool_name: str,
        server_name: str,
    ) -> None:
        """
        Register a mapping from tool name to server.

        A later registration for the same tool name overwrites the earlier one.

        Args:
            tool_name: Name of the tool
            server_name: Name of the server providing the tool
        """
        async with self._lock:
            self._tool_to_server[tool_name] = server_name
            logger.debug("Registered tool %s -> server %s", tool_name, server_name)

    async def discover_tools(self) -> None:
        """
        Discover all tools from registered servers and build mappings.

        Failures are logged per server and do not stop discovery of the
        remaining servers. A server whose connection or tool listing fails
        keeps its previously stored capabilities and mappings.
        """
        for server_name in self._registry.list_enabled_servers():
            try:
                config = self._registry.get(server_name)
                connection = await self._pool.get_connection(server_name, config)

                # raise_on_error=True: a transient fetch failure must not
                # overwrite the registry's known tools with an empty list.
                tools = await self._fetch_tools_from_server(
                    connection,
                    raise_on_error=True,
                )

                # Update registry with capabilities
                self._registry.set_capabilities(
                    server_name,
                    tools=[t.to_dict() for t in tools],
                )

                # Update tool mappings
                for tool in tools:
                    await self.register_tool_mapping(tool.name, server_name)

                logger.info(
                    "Discovered %d tools from server %s",
                    len(tools),
                    server_name,
                )

            except Exception as e:
                logger.warning(
                    "Failed to discover tools from %s: %s",
                    server_name,
                    e,
                )

    async def _fetch_tools_from_server(
        self,
        connection: MCPConnection,
        *,
        raise_on_error: bool = False,
    ) -> list[ToolInfo]:
        """Fetch available tools from an MCP server.

        Args:
            connection: Connection to the server to query.
            raise_on_error: If True, propagate any request/parse error instead
                of logging it and returning an empty list.

        Returns:
            The server's tools. Entries without a ``name`` are skipped. On
            error (when ``raise_on_error`` is False), an empty list.
        """
        try:
            response = await connection.execute_request(
                "GET",
                "/mcp/tools",
            )
            # Treat a missing or null "tools" member as "no tools".
            return [
                ToolInfo(
                    name=tool_data["name"],
                    description=tool_data.get("description"),
                    server_name=connection.server_name,
                    input_schema=tool_data.get("inputSchema"),
                )
                for tool_data in (response.get("tools") or [])
                if tool_data.get("name")
            ]
        except Exception as e:
            if raise_on_error:
                raise
            logger.warning(
                "Error fetching tools from %s: %s",
                connection.server_name,
                e,
            )
            return []

    def find_server_for_tool(self, tool_name: str) -> str | None:
        """
        Find which server provides a specific tool.

        Only consults mappings registered on this router; see ``route_tool``
        for the registry fallback.

        Args:
            tool_name: Name of the tool

        Returns:
            Server name or None if not found
        """
        return self._tool_to_server.get(tool_name)

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Call a tool on a specific server.

        Args:
            server_name: Name of the MCP server
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result. Errors other than an open circuit are
            reported as ``ToolResult(success=False, ...)`` rather than raised.

        Raises:
            MCPCircuitOpenError: If the server's circuit breaker is open.
        """
        start_time = time.perf_counter()
        request_id = str(uuid.uuid4())

        logger.debug(
            "Tool call [%s]: %s.%s with args %s",
            request_id,
            server_name,
            tool_name,
            arguments,
        )

        try:
            config = self._registry.get(server_name)
            circuit_breaker = self._get_circuit_breaker(server_name, config)

            # Fail fast without touching the network while the circuit is open.
            if circuit_breaker.is_open():
                raise MCPCircuitOpenError(
                    server_name=server_name,
                    failure_count=circuit_breaker.fail_counter,
                    reset_timeout=config.circuit_breaker_timeout,
                )

            # Execute with retry logic
            result = await self._execute_with_retry(
                server_name=server_name,
                config=config,
                tool_name=tool_name,
                arguments=arguments or {},
                timeout=timeout,
                circuit_breaker=circuit_breaker,
            )

            return ToolResult(
                success=True,
                data=result,
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=self._elapsed_ms(start_time),
                request_id=request_id,
            )

        except MCPCircuitOpenError:
            raise

        except MCPError as e:
            execution_time = self._elapsed_ms(start_time)
            logger.error(
                "Tool call failed [%s]: %s.%s - %s",
                request_id,
                server_name,
                tool_name,
                e,
            )
            return ToolResult(
                success=False,
                error=str(e),
                error_code=type(e).__name__,
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=execution_time,
                request_id=request_id,
            )

        except Exception as e:
            execution_time = self._elapsed_ms(start_time)
            logger.exception(
                "Unexpected error in tool call [%s]: %s.%s",
                request_id,
                server_name,
                tool_name,
            )
            return ToolResult(
                success=False,
                error=str(e),
                error_code="UnexpectedError",
                tool_name=tool_name,
                server_name=server_name,
                execution_time_ms=execution_time,
                request_id=request_id,
            )

    async def _execute_with_retry(
        self,
        server_name: str,
        config: MCPServerConfig,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float | None,
        circuit_breaker: AsyncCircuitBreaker,
    ) -> Any:
        """Execute tool call with retry logic.

        Retry policy:
          * MCPCircuitOpenError and MCPToolError are re-raised immediately
            without touching the breaker (tool errors are treated as caller
            errors).
          * MCPTimeoutError records a breaker failure and is re-raised without
            retrying.
          * Any other exception records a breaker failure and is retried with
            exponential backoff, up to ``config.retry_attempts`` retries, but
            retrying stops as soon as the breaker is open.

        Raises:
            MCPToolError: When every attempt made failed with a retryable
                error; chained to the last underlying exception.
        """
        last_error: Exception | None = None
        attempts = 0
        # +1 for the initial attempt; always make at least one attempt even if
        # retry_attempts is misconfigured as negative.
        max_attempts = max(1, config.retry_attempts + 1)

        while attempts < max_attempts:
            attempts += 1
            try:
                result = await self._execute_tool_call(
                    server_name=server_name,
                    config=config,
                    tool_name=tool_name,
                    arguments=arguments,
                    timeout=timeout,
                )
            except MCPCircuitOpenError:
                raise
            except MCPTimeoutError:
                # Timeout - don't retry
                await circuit_breaker.failure()
                raise
            except MCPToolError:
                # Tool-level error - don't retry (user error)
                raise
            except Exception as e:
                last_error = e
                await circuit_breaker.failure()

                if attempts >= max_attempts:
                    break

                # Continuing to hammer a server whose circuit just opened
                # defeats the breaker, so give up early.
                if circuit_breaker.is_open():
                    logger.warning(
                        "Circuit breaker for %s is open; not retrying %s.%s",
                        server_name,
                        server_name,
                        tool_name,
                    )
                    break

                delay = self._calculate_retry_delay(attempts, config)
                logger.warning(
                    "Tool call attempt %d/%d failed for %s.%s: %s. "
                    "Retrying in %.1fs",
                    attempts,
                    max_attempts,
                    server_name,
                    tool_name,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                # Success - record it
                await circuit_breaker.success()
                return result

        # All attempts made have failed
        raise MCPToolError(
            f"Tool call failed after {attempts} attempts",
            server_name=server_name,
            tool_name=tool_name,
            tool_args=arguments,
            details={"last_error": str(last_error)},
        ) from last_error

    def _calculate_retry_delay(
        self,
        attempt: int,
        config: MCPServerConfig,
    ) -> float:
        """Calculate exponential backoff delay with jitter.

        The base delay is ``retry_delay * 2**(attempt - 1)``, capped at
        ``retry_max_delay``, then jittered by up to ±25%, with a floor of
        0.1 seconds.
        """
        delay = config.retry_delay * (2 ** (attempt - 1))
        delay = min(delay, config.retry_max_delay)

        # Add jitter (±25%) to avoid synchronized retries across callers.
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        return max(0.1, delay + jitter)

    async def _execute_tool_call(
        self,
        server_name: str,
        config: MCPServerConfig,
        tool_name: str,
        arguments: dict[str, Any],
        timeout: float | None,
    ) -> Any:
        """Execute a single tool call as a JSON-RPC ``tools/call`` request.

        Returns:
            The response's ``result`` member (None if absent).

        Raises:
            MCPToolError: If the response carries a non-null ``error`` member.
        """
        connection = await self._pool.get_connection(server_name, config)

        # Build MCP tool call request
        request_body = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments,
            },
            "id": str(uuid.uuid4()),
        }

        response = await connection.execute_request(
            method="POST",
            path="/mcp",
            data=request_body,
            timeout=timeout,
        )

        # Handle JSON-RPC response. Some servers send "error": null on
        # success, so only a non-null error is treated as a failure.
        error = response.get("error")
        if error is not None:
            if isinstance(error, dict):
                message = error.get("message", "Tool execution failed")
                code = error.get("code", "UNKNOWN")
            else:
                # Non-standard error payload (e.g. a bare string).
                message = str(error)
                code = "UNKNOWN"
            raise MCPToolError(
                message,
                server_name=server_name,
                tool_name=tool_name,
                tool_args=arguments,
                error_code=str(code),
            )

        return response.get("result")

    async def route_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Route a tool call to the appropriate server.

        Looks up the server in this router's mappings first, then falls back
        to the registry.

        Args:
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result

        Raises:
            MCPToolNotFoundError: If no server provides the tool
            MCPCircuitOpenError: If the chosen server's circuit is open
        """
        server_name = self.find_server_for_tool(tool_name)

        if server_name is None:
            # Try to find from registry
            server_name = self._registry.find_server_for_tool(tool_name)

        if server_name is None:
            raise MCPToolNotFoundError(
                tool_name=tool_name,
                available_tools=list(self._tool_to_server),
            )

        return await self.call_tool(
            server_name=server_name,
            tool_name=tool_name,
            arguments=arguments,
            timeout=timeout,
        )

    async def list_all_tools(self) -> list[ToolInfo]:
        """
        Get all available tools from all servers.

        Built from the capabilities stored in the registry (as written by
        ``discover_tools`` via ``ToolInfo.to_dict``).

        Returns:
            List of tool information
        """
        all_server_tools = self._registry.get_all_tools()
        return [
            ToolInfo(
                name=tool_data.get("name", ""),
                description=tool_data.get("description"),
                server_name=server_name,
                input_schema=tool_data.get("input_schema"),
            )
            for server_name, server_tools in all_server_tools.items()
            for tool_data in server_tools
        ]

    def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
        """Get status of all circuit breakers, keyed by server name."""
        return {
            name: {
                "state": cb.current_state,
                "failure_count": cb.fail_counter,
            }
            for name, cb in self._circuit_breakers.items()
        }

    async def reset_circuit_breaker(self, server_name: str) -> bool:
        """
        Manually reset a circuit breaker.

        The breaker is removed and recreated from the current server config on
        the next call. Calls already in flight keep updating the old instance.

        Args:
            server_name: Name of the server

        Returns:
            True if circuit breaker was reset, False if none existed
        """
        async with self._lock:
            if server_name in self._circuit_breakers:
                # Reset by removing (will be recreated on next call)
                del self._circuit_breakers[server_name]
                logger.info("Reset circuit breaker for %s", server_name)
                return True
            return False