- Connect to MCP servers concurrently instead of sequentially
- Reduce retry settings in test mode (IS_TEST=True):
  - 1 attempt instead of 3
  - 100ms retry delay instead of 1s
  - 2s timeout instead of 30-120s

Reduces MCP E2E test time from ~16s to under 1s.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
439 lines
13 KiB
Python
439 lines
13 KiB
Python
"""
|
|
MCP Client Manager
|
|
|
|
Main facade for all MCP operations. Manages server connections,
|
|
tool discovery, and provides a unified interface for tool calls.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
|
from .connection import ConnectionPool, ConnectionState
|
|
from .exceptions import MCPServerNotFoundError
|
|
from .registry import MCPServerRegistry, get_registry
|
|
from .routing import ToolInfo, ToolResult, ToolRouter
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class ServerHealth:
|
|
"""Health status for an MCP server."""
|
|
|
|
name: str
|
|
healthy: bool
|
|
state: str
|
|
url: str
|
|
error: str | None = None
|
|
tools_count: int = 0
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Convert to dictionary."""
|
|
return {
|
|
"name": self.name,
|
|
"healthy": self.healthy,
|
|
"state": self.state,
|
|
"url": self.url,
|
|
"error": self.error,
|
|
"tools_count": self.tools_count,
|
|
}
|
|
|
|
|
|
class MCPClientManager:
    """
    Central manager for all MCP client operations.

    Provides a unified interface for:
    - Connecting to MCP servers
    - Discovering and calling tools
    - Health monitoring
    - Connection lifecycle management

    This is the main entry point for MCP operations in the application.
    """

    def __init__(
        self,
        config: MCPConfig | None = None,
        registry: MCPServerRegistry | None = None,
    ) -> None:
        """
        Initialize the MCP client manager.

        No connections are opened here; that happens in ``initialize()``.

        Args:
            config: Optional MCP configuration. If given, it is loaded into
                the registry immediately.
            registry: Optional registry instance. If None, uses singleton.
        """
        # Compare against None explicitly: a registry object that happens to
        # be falsy (e.g. it defines __len__ and is empty) must still be used.
        self._registry = registry if registry is not None else get_registry()
        self._pool = ConnectionPool()
        self._router: ToolRouter | None = None
        self._initialized = False
        # Serializes initialize() and shutdown() against each other.
        self._lock = asyncio.Lock()

        # Load configuration if provided
        if config is not None:
            self._registry.load_config(config)

    @property
    def is_initialized(self) -> bool:
        """Check if the manager is initialized."""
        return self._initialized

    async def initialize(self, config: MCPConfig | None = None) -> None:
        """
        Initialize the MCP client manager.

        Loads configuration, creates the router, connects to all enabled
        servers and discovers their tools. Calling this on an already
        initialized manager logs a warning and returns.

        Args:
            config: Optional configuration to load
        """
        async with self._lock:
            if self._initialized:
                logger.warning("MCPClientManager already initialized")
                return

            logger.info("Initializing MCP Client Manager")

            # Load configuration
            if config is not None:
                self._registry.load_config(config)
            elif not self._registry.list_servers():
                # Nothing registered yet: try to load from default location
                self._registry.load_config(load_mcp_config())

            # Create router
            router = ToolRouter(self._registry, self._pool)
            self._router = router

            # Connect to all enabled servers
            await self._connect_all_servers()

            # Discover tools from all servers. The freshly created router is
            # used directly rather than truth-tested, so discovery cannot be
            # skipped by a router object that evaluates as falsy.
            await router.discover_tools()

            self._initialized = True
            logger.info(
                "MCP Client Manager initialized with %d servers",
                len(self._registry.list_enabled_servers()),
            )

    async def _ensure_router(self) -> ToolRouter:
        """
        Return the tool router, initializing the manager on first use.

        Returns:
            The router created by ``initialize()``.

        Raises:
            RuntimeError: If no router exists even after ``initialize()``.
        """
        if not self._initialized or self._router is None:
            await self.initialize()

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self._router is None:
            raise RuntimeError("MCP tool router is unavailable after initialize()")
        return self._router

    async def _connect_all_servers(self) -> None:
        """
        Connect to all enabled MCP servers concurrently.

        A failure to connect to one server is logged and does not prevent
        the other connections from being attempted.
        """
        enabled_servers = self._registry.get_enabled_configs()

        async def connect_server(name: str, config: "MCPServerConfig") -> None:
            try:
                await self._pool.get_connection(name, config)
            except Exception as e:
                logger.error("Failed to connect to MCP server %s: %s", name, e)
            else:
                logger.info("Connected to MCP server: %s", name)

        # Connect to all servers concurrently for faster startup
        await asyncio.gather(
            *(connect_server(name, config) for name, config in enabled_servers.items()),
            return_exceptions=True,
        )

    async def shutdown(self) -> None:
        """
        Shutdown the MCP client manager.

        Closes all pooled connections and marks the manager as not
        initialized. Does nothing if the manager is not initialized.
        """
        async with self._lock:
            if not self._initialized:
                return

            logger.info("Shutting down MCP Client Manager")

            await self._pool.close_all()
            self._initialized = False

            logger.info("MCP Client Manager shutdown complete")

    async def connect(self, server_name: str) -> None:
        """
        Connect to a specific MCP server.

        Args:
            server_name: Name of the server to connect to

        Raises:
            MCPServerNotFoundError: If server is not registered
        """
        config = self._registry.get(server_name)
        await self._pool.get_connection(server_name, config)
        logger.info("Connected to MCP server: %s", server_name)

    async def disconnect(self, server_name: str) -> None:
        """
        Disconnect from a specific MCP server.

        Args:
            server_name: Name of the server to disconnect from
        """
        await self._pool.close_connection(server_name)
        logger.info("Disconnected from MCP server: %s", server_name)

    async def disconnect_all(self) -> None:
        """Disconnect from all MCP servers."""
        await self._pool.close_all()

    async def call_tool(
        self,
        server: str,
        tool: str,
        args: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Call a tool on a specific MCP server.

        Initializes the manager first if that has not happened yet.

        Args:
            server: Name of the MCP server
            tool: Name of the tool to call
            args: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result
        """
        router = await self._ensure_router()
        return await router.call_tool(
            server_name=server,
            tool_name=tool,
            arguments=args,
            timeout=timeout,
        )

    async def route_tool(
        self,
        tool: str,
        args: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> ToolResult:
        """
        Route a tool call to the appropriate server automatically.

        Initializes the manager first if that has not happened yet.

        Args:
            tool: Name of the tool to call
            args: Tool arguments
            timeout: Optional timeout override

        Returns:
            Tool execution result
        """
        router = await self._ensure_router()
        return await router.route_tool(
            tool_name=tool,
            arguments=args,
            timeout=timeout,
        )

    async def list_tools(self, server: str) -> list[ToolInfo]:
        """
        List all tools available on a specific server.

        Args:
            server: Name of the MCP server

        Returns:
            List of tool information
        """
        capabilities = await self._registry.get_capabilities(server)
        return [
            ToolInfo(
                name=t.get("name", ""),
                description=t.get("description"),
                server_name=server,
                input_schema=t.get("input_schema"),
            )
            for t in capabilities.tools
        ]

    async def list_all_tools(self) -> list[ToolInfo]:
        """
        List all tools from all servers.

        Initializes the manager first if that has not happened yet.

        Returns:
            List of tool information
        """
        router = await self._ensure_router()
        return await router.list_all_tools()

    async def health_check(self) -> dict[str, ServerHealth]:
        """
        Perform health check on all MCP servers.

        Returns:
            Dict mapping server names to health status. A server whose
            lookup raises MCPServerNotFoundError is omitted; any other
            error yields an unhealthy entry carrying the error text.
        """
        results: dict[str, ServerHealth] = {}
        pool_status = self._pool.get_status()
        pool_health = await self._pool.health_check_all()

        for server_name in self._registry.list_servers():
            # Remember the URL once known so an error entry can still report it.
            url = "unknown"
            try:
                config = self._registry.get(server_name)
                url = config.url
                status = pool_status.get(server_name, {})
                healthy = pool_health.get(server_name, False)

                capabilities = self._registry.get_cached_capabilities(server_name)

                results[server_name] = ServerHealth(
                    name=server_name,
                    healthy=healthy,
                    state=status.get("state", ConnectionState.DISCONNECTED.value),
                    url=url,
                    tools_count=len(capabilities.tools),
                )
            except MCPServerNotFoundError:
                # The registry no longer knows this server; leave it out.
                continue
            except Exception as e:
                results[server_name] = ServerHealth(
                    name=server_name,
                    healthy=False,
                    state=ConnectionState.ERROR.value,
                    url=url,
                    error=str(e),
                )

        return results

    def list_servers(self) -> list[str]:
        """Get list of all registered server names."""
        return self._registry.list_servers()

    def list_enabled_servers(self) -> list[str]:
        """Get list of enabled server names."""
        return self._registry.list_enabled_servers()

    def get_server_config(self, server_name: str) -> MCPServerConfig:
        """
        Get configuration for a specific server.

        Args:
            server_name: Name of the server

        Returns:
            Server configuration

        Raises:
            MCPServerNotFoundError: If server is not registered
        """
        return self._registry.get(server_name)

    def register_server(
        self,
        name: str,
        config: MCPServerConfig,
    ) -> None:
        """
        Register a new MCP server at runtime.

        Only the registry is updated; no connection is opened here.

        Args:
            name: Unique server name
            config: Server configuration
        """
        self._registry.register(name, config)

    def unregister_server(self, name: str) -> bool:
        """
        Unregister an MCP server.

        Args:
            name: Server name to unregister

        Returns:
            True if server was found and removed
        """
        # NOTE(review): only the registry entry is removed; a pooled
        # connection for this server is not closed here. Callers that need
        # that should await disconnect(name) as well - confirm intent.
        return self._registry.unregister(name)

    def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
        """Get status of all circuit breakers (empty if no router exists yet)."""
        if self._router is None:
            return {}
        return self._router.get_circuit_breaker_status()

    async def reset_circuit_breaker(self, server_name: str) -> bool:
        """
        Reset a circuit breaker for a server.

        Args:
            server_name: Name of the server

        Returns:
            True if circuit breaker was reset; False when no router exists yet
        """
        if self._router is None:
            return False
        return await self._router.reset_circuit_breaker(server_name)
|
|
|
|
|
|
# Singleton instance
|
|
# Process-wide manager; stays None until get_mcp_client() creates one and is
# set back to None by shutdown_mcp_client() / reset_mcp_client().
_manager_instance: MCPClientManager | None = None
|
|
# Held by get_mcp_client(), shutdown_mcp_client() and reset_mcp_client()
# around every check-and-update of _manager_instance.
_manager_lock = asyncio.Lock()
|
|
|
|
|
|
async def get_mcp_client() -> MCPClientManager:
    """
    Get the global MCP client manager instance.

    This is the main dependency injection point for FastAPI.
    Uses proper locking to avoid race conditions in async contexts.

    The global is only published after ``initialize()`` has succeeded, so a
    failed start-up is retried on the next call instead of handing out a
    manager that was never initialized.

    Returns:
        The shared, initialized MCPClientManager.

    Raises:
        Exception: Whatever ``MCPClientManager.initialize()`` raises.
    """
    global _manager_instance

    # Use lock for the entire check-and-create operation to avoid race conditions
    async with _manager_lock:
        if _manager_instance is None:
            manager = MCPClientManager()
            try:
                await manager.initialize()
            except Exception:
                # Initialization may have opened some connections before
                # failing; release them best-effort, then surface the error.
                try:
                    await manager.disconnect_all()
                except Exception:
                    logger.exception(
                        "Failed to clean up MCP connections after initialization error"
                    )
                raise
            _manager_instance = manager

        return _manager_instance
|
|
|
|
|
|
async def shutdown_mcp_client() -> None:
    """Shut down the global MCP client manager and drop the reference to it.

    Does nothing when no global manager exists.
    """
    global _manager_instance

    # Hold the module lock so this cannot interleave with get_mcp_client().
    async with _manager_lock:
        manager = _manager_instance
        if manager is None:
            return

        await manager.shutdown()
        _manager_instance = None
|
|
|
|
|
|
async def reset_mcp_client() -> None:
    """
    Reset the global MCP client manager (for testing).

    This is an async function to properly acquire the manager lock
    and avoid race conditions with get_mcp_client().

    Shutdown is best-effort: an error raised while shutting the manager
    down is logged at debug level and the global is cleared regardless.
    """
    global _manager_instance

    async with _manager_lock:
        if _manager_instance is None:
            return

        # Shutdown gracefully before resetting
        try:
            await _manager_instance.shutdown()
        except Exception:
            # Keep test cleanup from failing, but leave a trace for debugging.
            logger.debug(
                "Ignoring error while shutting down MCP client manager during reset",
                exc_info=True,
            )
        _manager_instance = None
|