forked from cardosofelipe/fast-next-template
feat(backend): implement MCP client infrastructure (#55)
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
85
backend/app/services/mcp/__init__.py
Normal file
85
backend/app/services/mcp/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
MCP Client Service Package
|
||||
|
||||
Provides infrastructure for communicating with MCP (Model Context Protocol)
|
||||
servers. This is the foundation for AI agent tool integration.
|
||||
|
||||
Usage:
|
||||
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||
|
||||
# In FastAPI route
|
||||
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||
result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"})
|
||||
|
||||
# Direct usage
|
||||
manager = MCPClientManager()
|
||||
await manager.initialize()
|
||||
result = await manager.call_tool("issues", "create_issue", {...})
|
||||
await manager.shutdown()
|
||||
"""
|
||||
|
||||
from .client_manager import (
|
||||
MCPClientManager,
|
||||
ServerHealth,
|
||||
get_mcp_client,
|
||||
reset_mcp_client,
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from .config import (
|
||||
MCPConfig,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
create_default_config,
|
||||
load_mcp_config,
|
||||
)
|
||||
from .connection import ConnectionPool, ConnectionState, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
MCPValidationError,
|
||||
)
|
||||
from .registry import MCPServerRegistry, ServerCapabilities, get_registry
|
||||
from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
__all__ = [
|
||||
# Main facade
|
||||
"MCPClientManager",
|
||||
"get_mcp_client",
|
||||
"shutdown_mcp_client",
|
||||
"reset_mcp_client",
|
||||
"ServerHealth",
|
||||
# Configuration
|
||||
"MCPConfig",
|
||||
"MCPServerConfig",
|
||||
"TransportType",
|
||||
"load_mcp_config",
|
||||
"create_default_config",
|
||||
# Registry
|
||||
"MCPServerRegistry",
|
||||
"ServerCapabilities",
|
||||
"get_registry",
|
||||
# Connection
|
||||
"ConnectionPool",
|
||||
"ConnectionState",
|
||||
"MCPConnection",
|
||||
# Routing
|
||||
"ToolRouter",
|
||||
"ToolInfo",
|
||||
"ToolResult",
|
||||
"AsyncCircuitBreaker",
|
||||
"CircuitState",
|
||||
# Exceptions
|
||||
"MCPError",
|
||||
"MCPConnectionError",
|
||||
"MCPTimeoutError",
|
||||
"MCPToolError",
|
||||
"MCPServerNotFoundError",
|
||||
"MCPToolNotFoundError",
|
||||
"MCPCircuitOpenError",
|
||||
"MCPValidationError",
|
||||
]
|
||||
417
backend/app/services/mcp/client_manager.py
Normal file
417
backend/app/services/mcp/client_manager.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
MCP Client Manager
|
||||
|
||||
Main facade for all MCP operations. Manages server connections,
|
||||
tool discovery, and provides a unified interface for tool calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .connection import ConnectionPool, ConnectionState
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
from .registry import MCPServerRegistry, get_registry
|
||||
from .routing import ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerHealth:
|
||||
"""Health status for an MCP server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"healthy": self.healthy,
|
||||
"state": self.state,
|
||||
"url": self.url,
|
||||
"error": self.error,
|
||||
"tools_count": self.tools_count,
|
||||
}
|
||||
|
||||
|
||||
class MCPClientManager:
|
||||
"""
|
||||
Central manager for all MCP client operations.
|
||||
|
||||
Provides a unified interface for:
|
||||
- Connecting to MCP servers
|
||||
- Discovering and calling tools
|
||||
- Health monitoring
|
||||
- Connection lifecycle management
|
||||
|
||||
This is the main entry point for MCP operations in the application.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MCPConfig | None = None,
|
||||
registry: MCPServerRegistry | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Args:
|
||||
config: Optional MCP configuration. If None, loads from default.
|
||||
registry: Optional registry instance. If None, uses singleton.
|
||||
"""
|
||||
self._registry = registry or get_registry()
|
||||
self._pool = ConnectionPool()
|
||||
self._router: ToolRouter | None = None
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load configuration if provided
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the manager is initialized."""
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Loads configuration, creates connections, and discovers tools.
|
||||
|
||||
Args:
|
||||
config: Optional configuration to load
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
logger.warning("MCPClientManager already initialized")
|
||||
return
|
||||
|
||||
logger.info("Initializing MCP Client Manager")
|
||||
|
||||
# Load configuration
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
elif len(self._registry.list_servers()) == 0:
|
||||
# Try to load from default location
|
||||
self._registry.load_config(load_mcp_config())
|
||||
|
||||
# Create router
|
||||
self._router = ToolRouter(self._registry, self._pool)
|
||||
|
||||
# Connect to all enabled servers
|
||||
await self._connect_all_servers()
|
||||
|
||||
# Discover tools from all servers
|
||||
if self._router:
|
||||
await self._router.discover_tools()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
"MCP Client Manager initialized with %d servers",
|
||||
len(self._registry.list_enabled_servers()),
|
||||
)
|
||||
|
||||
async def _connect_all_servers(self) -> None:
|
||||
"""Connect to all enabled MCP servers."""
|
||||
enabled_servers = self._registry.get_enabled_configs()
|
||||
|
||||
for name, config in enabled_servers.items():
|
||||
try:
|
||||
await self._pool.get_connection(name, config)
|
||||
logger.info("Connected to MCP server: %s", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the MCP client manager.
|
||||
|
||||
Closes all connections and cleans up resources.
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Shutting down MCP Client Manager")
|
||||
|
||||
await self._pool.close_all()
|
||||
self._initialized = False
|
||||
|
||||
logger.info("MCP Client Manager shutdown complete")
|
||||
|
||||
async def connect(self, server_name: str) -> None:
|
||||
"""
|
||||
Connect to a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to connect to
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._registry.get(server_name)
|
||||
await self._pool.get_connection(server_name, config)
|
||||
logger.info("Connected to MCP server: %s", server_name)
|
||||
|
||||
async def disconnect(self, server_name: str) -> None:
|
||||
"""
|
||||
Disconnect from a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to disconnect from
|
||||
"""
|
||||
await self._pool.close_connection(server_name)
|
||||
logger.info("Disconnected from MCP server: %s", server_name)
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
"""Disconnect from all MCP servers."""
|
||||
await self._pool.close_all()
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server: str,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific MCP server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.call_tool(
|
||||
server_name=server,
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server automatically.
|
||||
|
||||
Args:
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.route_tool(
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_tools(self, server: str) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools available on a specific server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
capabilities = await self._registry.get_capabilities(server)
|
||||
return [
|
||||
ToolInfo(
|
||||
name=t.get("name", ""),
|
||||
description=t.get("description"),
|
||||
server_name=server,
|
||||
input_schema=t.get("input_schema"),
|
||||
)
|
||||
for t in capabilities.tools
|
||||
]
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.list_all_tools()
|
||||
|
||||
async def health_check(self) -> dict[str, ServerHealth]:
|
||||
"""
|
||||
Perform health check on all MCP servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
results: dict[str, ServerHealth] = {}
|
||||
pool_status = self._pool.get_status()
|
||||
pool_health = await self._pool.health_check_all()
|
||||
|
||||
for server_name in self._registry.list_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
status = pool_status.get(server_name, {})
|
||||
healthy = pool_health.get(server_name, False)
|
||||
|
||||
capabilities = self._registry.get_cached_capabilities(server_name)
|
||||
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=healthy,
|
||||
state=status.get("state", ConnectionState.DISCONNECTED.value),
|
||||
url=config.url,
|
||||
tools_count=len(capabilities.tools),
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=False,
|
||||
state=ConnectionState.ERROR.value,
|
||||
url="unknown",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._registry.list_servers()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return self._registry.list_enabled_servers()
|
||||
|
||||
def get_server_config(self, server_name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get configuration for a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
return self._registry.get(server_name)
|
||||
|
||||
def register_server(
|
||||
self,
|
||||
name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new MCP server at runtime.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._registry.register(name, config)
|
||||
|
||||
def unregister_server(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
return self._registry.unregister(name)
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
if self._router is None:
|
||||
return {}
|
||||
return self._router.get_circuit_breaker_status()
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Reset a circuit breaker for a server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
if self._router is None:
|
||||
return False
|
||||
return await self._router.reset_circuit_breaker(server_name)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_manager_instance: MCPClientManager | None = None
|
||||
_manager_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_mcp_client() -> MCPClientManager:
|
||||
"""
|
||||
Get the global MCP client manager instance.
|
||||
|
||||
This is the main dependency injection point for FastAPI.
|
||||
Uses proper locking to avoid race conditions in async contexts.
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock for the entire check-and-create operation to avoid race conditions
|
||||
async with _manager_lock:
|
||||
if _manager_instance is None:
|
||||
_manager_instance = MCPClientManager()
|
||||
await _manager_instance.initialize()
|
||||
|
||||
return _manager_instance
|
||||
|
||||
|
||||
async def shutdown_mcp_client() -> None:
|
||||
"""Shutdown the global MCP client manager."""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock to prevent race with get_mcp_client()
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
await _manager_instance.shutdown()
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
def reset_mcp_client() -> None:
|
||||
"""Reset the global MCP client manager (for testing)."""
|
||||
global _manager_instance
|
||||
_manager_instance = None
|
||||
234
backend/app/services/mcp/config.py
Normal file
234
backend/app/services/mcp/config.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
MCP Configuration System
|
||||
|
||||
Pydantic models for MCP server configuration with YAML file loading
|
||||
and environment variable overrides.
|
||||
"""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class TransportType(str, Enum):
|
||||
"""Supported MCP transport types."""
|
||||
|
||||
HTTP = "http"
|
||||
STDIO = "stdio"
|
||||
SSE = "sse"
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server."""
|
||||
|
||||
url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
|
||||
transport: TransportType = Field(
|
||||
default=TransportType.HTTP,
|
||||
description="Transport protocol to use",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=600,
|
||||
description="Request timeout in seconds",
|
||||
)
|
||||
retry_attempts: int = Field(
|
||||
default=3,
|
||||
ge=0,
|
||||
le=10,
|
||||
description="Number of retry attempts on failure",
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
ge=0.1,
|
||||
le=60.0,
|
||||
description="Initial delay between retries in seconds",
|
||||
)
|
||||
retry_max_delay: float = Field(
|
||||
default=30.0,
|
||||
ge=1.0,
|
||||
le=300.0,
|
||||
description="Maximum delay between retries in seconds",
|
||||
)
|
||||
circuit_breaker_threshold: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Number of failures before opening circuit",
|
||||
)
|
||||
circuit_breaker_timeout: float = Field(
|
||||
default=30.0,
|
||||
ge=5.0,
|
||||
le=300.0,
|
||||
description="Seconds to wait before attempting to close circuit",
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether this server is enabled",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable description of the server",
|
||||
)
|
||||
|
||||
@field_validator("url", mode="before")
|
||||
@classmethod
|
||||
def expand_env_vars(cls, v: str) -> str:
|
||||
"""Expand environment variables in URL using ${VAR:-default} syntax."""
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
|
||||
result = v
|
||||
# Find all ${VAR} or ${VAR:-default} patterns
|
||||
import re
|
||||
|
||||
pattern = r"\$\{([^}]+)\}"
|
||||
matches = re.findall(pattern, v)
|
||||
|
||||
for match in matches:
|
||||
if ":-" in match:
|
||||
var_name, default = match.split(":-", 1)
|
||||
else:
|
||||
var_name, default = match, ""
|
||||
|
||||
env_value = os.environ.get(var_name.strip(), default)
|
||||
result = result.replace(f"${{{match}}}", env_value)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Root configuration for all MCP servers."""
|
||||
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of server names to their configurations",
|
||||
)
|
||||
|
||||
# Global defaults
|
||||
default_timeout: int = Field(
|
||||
default=30,
|
||||
description="Default timeout for all servers",
|
||||
)
|
||||
default_retry_attempts: int = Field(
|
||||
default=3,
|
||||
description="Default retry attempts for all servers",
|
||||
)
|
||||
connection_pool_size: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum connections per server",
|
||||
)
|
||||
health_check_interval: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=300,
|
||||
description="Seconds between health checks",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> "MCPConfig":
|
||||
"""Load configuration from a YAML file."""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||||
|
||||
with path.open("r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MCPConfig":
|
||||
"""Load configuration from a dictionary."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
def get_server(self, name: str) -> MCPServerConfig | None:
|
||||
"""Get a server configuration by name."""
|
||||
return self.mcp_servers.get(name)
|
||||
|
||||
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return {
|
||||
name: config
|
||||
for name, config in self.mcp_servers.items()
|
||||
if config.enabled
|
||||
}
|
||||
|
||||
def list_server_names(self) -> list[str]:
|
||||
"""Get list of all configured server names."""
|
||||
return list(self.mcp_servers.keys())
|
||||
|
||||
|
||||
# Default configuration path
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "mcp_servers.yaml"
|
||||
|
||||
|
||||
def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||
"""
|
||||
Load MCP configuration from file or environment.
|
||||
|
||||
Priority:
|
||||
1. Explicit path parameter
|
||||
2. MCP_CONFIG_PATH environment variable
|
||||
3. Default path (backend/mcp_servers.yaml)
|
||||
4. Empty config if no file exists
|
||||
"""
|
||||
if path is None:
|
||||
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
# Return empty config if no file exists (allows runtime registration)
|
||||
return MCPConfig()
|
||||
|
||||
return MCPConfig.from_yaml(path)
|
||||
|
||||
|
||||
def create_default_config() -> MCPConfig:
|
||||
"""
|
||||
Create a default MCP configuration with standard servers.
|
||||
|
||||
This is useful for development and as a template.
|
||||
"""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"llm-gateway": MCPServerConfig(
|
||||
url="${LLM_GATEWAY_URL:-http://localhost:8001}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=60,
|
||||
description="LLM Gateway for multi-provider AI interactions",
|
||||
),
|
||||
"knowledge-base": MCPServerConfig(
|
||||
url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Knowledge Base for RAG and document retrieval",
|
||||
),
|
||||
"git-ops": MCPServerConfig(
|
||||
url="${GIT_OPS_URL:-http://localhost:8003}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=120,
|
||||
description="Git Operations for repository management",
|
||||
),
|
||||
"issues": MCPServerConfig(
|
||||
url="${ISSUES_URL:-http://localhost:8004}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Issue Tracker for Gitea/GitHub/GitLab",
|
||||
),
|
||||
},
|
||||
default_timeout=30,
|
||||
default_retry_attempts=3,
|
||||
connection_pool_size=10,
|
||||
health_check_interval=30,
|
||||
)
|
||||
435
backend/app/services/mcp/connection.py
Normal file
435
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
MCP Connection Management
|
||||
|
||||
Handles connection lifecycle, pooling, and automatic reconnection
|
||||
for MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import MCPServerConfig, TransportType
|
||||
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionState(str, Enum):
|
||||
"""Connection state enumeration."""
|
||||
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
RECONNECTING = "reconnecting"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MCPConnection:
|
||||
"""
|
||||
Manages a single connection to an MCP server.
|
||||
|
||||
Handles connection lifecycle, health checking, and automatic reconnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
config: Server configuration
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.config = config
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._last_activity: float | None = None
|
||||
self._connection_attempts = 0
|
||||
self._last_error: Exception | None = None
|
||||
|
||||
# Reconnection settings
|
||||
self._base_delay = config.retry_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._max_attempts = config.retry_attempts
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
"""Get current connection state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connection is established."""
|
||||
return self._state == ConnectionState.CONNECTED
|
||||
|
||||
@property
|
||||
def last_error(self) -> Exception | None:
|
||||
"""Get the last error that occurred."""
|
||||
return self._last_error
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the MCP server.
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If connection fails after all retries
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._state == ConnectionState.CONNECTED:
|
||||
return
|
||||
|
||||
self._state = ConnectionState.CONNECTING
|
||||
self._connection_attempts = 0
|
||||
self._last_error = None
|
||||
|
||||
while self._connection_attempts < self._max_attempts:
|
||||
try:
|
||||
await self._do_connect()
|
||||
self._state = ConnectionState.CONNECTED
|
||||
self._last_activity = time.time()
|
||||
logger.info(
|
||||
"Connected to MCP server: %s at %s",
|
||||
self.server_name,
|
||||
self.config.url,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._connection_attempts += 1
|
||||
self._last_error = e
|
||||
logger.warning(
|
||||
"Connection attempt %d/%d failed for %s: %s",
|
||||
self._connection_attempts,
|
||||
self._max_attempts,
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
if self._connection_attempts < self._max_attempts:
|
||||
delay = self._calculate_backoff_delay()
|
||||
logger.debug(
|
||||
"Retrying connection to %s in %.1fs",
|
||||
self.server_name,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
self._state = ConnectionState.ERROR
|
||||
raise MCPConnectionError(
|
||||
f"Failed to connect after {self._max_attempts} attempts",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=self._last_error,
|
||||
)
|
||||
|
||||
async def _do_connect(self) -> None:
|
||||
"""Perform the actual connection (transport-specific)."""
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.config.url,
|
||||
timeout=httpx.Timeout(self.config.timeout),
|
||||
headers={
|
||||
"User-Agent": "Syndarix-MCP-Client/1.0",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
# Verify connectivity with a simple request
|
||||
try:
|
||||
# Try to hit the MCP capabilities endpoint
|
||||
response = await self._client.get("/mcp/capabilities")
|
||||
if response.status_code not in (200, 404):
|
||||
# 404 is acceptable - server might not have capabilities endpoint
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
except httpx.ConnectError as e:
|
||||
raise MCPConnectionError(
|
||||
"Failed to connect to server",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
)
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
f"Transport {self.config.transport} not yet implemented"
|
||||
)
|
||||
|
||||
def _calculate_backoff_delay(self) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = self._base_delay * (2 ** (self._connection_attempts - 1))
|
||||
delay = min(delay, self._max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return delay + jitter
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
async with self._lock:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error closing connection to %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
self._client = None
|
||||
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
logger.info("Disconnected from MCP server: %s", self.server_name)
|
||||
|
||||
async def reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server."""
|
||||
async with self._lock:
|
||||
self._state = ConnectionState.RECONNECTING
|
||||
await self.disconnect()
|
||||
await self.connect()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Perform a health check on the connection.
|
||||
|
||||
Returns:
|
||||
True if connection is healthy
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
response = await self._client.get(
|
||||
"/health",
|
||||
timeout=5.0,
|
||||
)
|
||||
return response.status_code == 200
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Health check failed for %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute an HTTP request to the MCP server.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: Request path
|
||||
data: Optional request body
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Response data
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If not connected
|
||||
MCPTimeoutError: If request times out
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
raise MCPConnectionError(
|
||||
"Not connected to server",
|
||||
server_name=self.server_name,
|
||||
)
|
||||
|
||||
effective_timeout = timeout or self.config.timeout
|
||||
|
||||
try:
|
||||
if method.upper() == "GET":
|
||||
response = await self._client.get(
|
||||
path,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = await self._client.post(
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
else:
|
||||
response = await self._client.request(
|
||||
method.upper(),
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
|
||||
self._last_activity = time.time()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name=self.server_name,
|
||||
timeout_seconds=effective_timeout,
|
||||
operation=f"{method} {path}",
|
||||
) from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise MCPConnectionError(
|
||||
f"HTTP error: {e.response.status_code}",
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""
|
||||
Pool of connections to MCP servers.
|
||||
|
||||
Manages connection lifecycle and provides connection reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, max_connections_per_server: int = 10) -> None:
|
||||
"""
|
||||
Initialize connection pool.
|
||||
|
||||
Args:
|
||||
max_connections_per_server: Maximum connections per server
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> MCPConnection:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name not in self._connections:
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
self._connections[server_name] = connection
|
||||
|
||||
connection = self._connections[server_name]
|
||||
|
||||
# Reconnect if not connected
|
||||
if not connection.is_connected:
|
||||
await connection.connect()
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Release a connection (currently just tracks usage).
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
# For now, we keep connections alive
|
||||
# Future: implement connection reaping for idle connections
|
||||
|
||||
async def close_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Close and remove a connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
async with self._lock:
|
||||
for connection in self._connections.values():
|
||||
try:
|
||||
await connection.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
"""
|
||||
Perform health check on all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
results = {}
|
||||
for name, connection in self._connections.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
def get_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get status of all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to status info
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
"state": conn.state.value,
|
||||
"is_connected": conn.is_connected,
|
||||
"url": conn.config.url,
|
||||
}
|
||||
for name, conn in self._connections.items()
|
||||
}
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncGenerator[MCPConnection, None]:
|
||||
"""
|
||||
Context manager for getting a connection.
|
||||
|
||||
Usage:
|
||||
async with pool.connection("server", config) as conn:
|
||||
result = await conn.execute_request("POST", "/tool", data)
|
||||
"""
|
||||
conn = await self.get_connection(server_name, config)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await self.release_connection(server_name)
|
||||
201
backend/app/services/mcp/exceptions.py
Normal file
201
backend/app/services/mcp/exceptions.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
MCP Exception Classes
|
||||
|
||||
Custom exceptions for MCP client operations with detailed error context.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""Base exception for all MCP-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.server_name = server_name
|
||||
self.details = details or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = [self.message]
|
||||
if self.server_name:
|
||||
parts.append(f"server={self.server_name}")
|
||||
if self.details:
|
||||
parts.append(f"details={self.details}")
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
"""Raised when connection to an MCP server fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
url: str | None = None,
|
||||
cause: Exception | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.url = url
|
||||
self.cause = cause
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.url:
|
||||
base = f"{base} | url={self.url}"
|
||||
if self.cause:
|
||||
base = f"{base} | cause={type(self.cause).__name__}: {self.cause}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPTimeoutError(MCPError):
|
||||
"""Raised when an MCP operation times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
operation: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.operation = operation
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.timeout_seconds is not None:
|
||||
base = f"{base} | timeout={self.timeout_seconds}s"
|
||||
if self.operation:
|
||||
base = f"{base} | operation={self.operation}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolError(MCPError):
|
||||
"""Raised when a tool execution fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
error_code: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.tool_args = tool_args
|
||||
self.error_code = error_code
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.error_code:
|
||||
base = f"{base} | error_code={self.error_code}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPServerNotFoundError(MCPError):
|
||||
"""Raised when a requested MCP server is not registered."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
available_servers: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"MCP server not found: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.available_servers = available_servers or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_servers:
|
||||
base = f"{base} | available={self.available_servers}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolNotFoundError(MCPError):
|
||||
"""Raised when a requested tool is not found on any server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
available_tools: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Tool not found: {tool_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.available_tools = available_tools or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_tools:
|
||||
base = f"{base} | available_tools={self.available_tools[:5]}..."
|
||||
return base
|
||||
|
||||
|
||||
class MCPCircuitOpenError(MCPError):
|
||||
"""Raised when a circuit breaker is open (server temporarily unavailable)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
failure_count: int | None = None,
|
||||
reset_timeout: float | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Circuit breaker open for server: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.failure_count = failure_count
|
||||
self.reset_timeout = reset_timeout
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.failure_count is not None:
|
||||
base = f"{base} | failures={self.failure_count}"
|
||||
if self.reset_timeout is not None:
|
||||
base = f"{base} | reset_in={self.reset_timeout}s"
|
||||
return base
|
||||
|
||||
|
||||
class MCPValidationError(MCPError):
|
||||
"""Raised when tool arguments fail validation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
tool_name: str | None = None,
|
||||
field_errors: dict[str, str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.field_errors = field_errors or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.field_errors:
|
||||
base = f"{base} | fields={list(self.field_errors.keys())}"
|
||||
return base
|
||||
305
backend/app/services/mcp/registry.py
Normal file
305
backend/app/services/mcp/registry.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
MCP Server Registry
|
||||
|
||||
Thread-safe singleton registry for managing MCP server configurations
|
||||
and their capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerCapabilities:
|
||||
"""Cached capabilities for an MCP server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self.tools = tools or []
|
||||
self.resources = resources or []
|
||||
self.prompts = prompts or []
|
||||
self._loaded = False
|
||||
self._load_time: float | None = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if capabilities have been loaded."""
|
||||
return self._loaded
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of tool names."""
|
||||
return [t.get("name", "") for t in self.tools if t.get("name")]
|
||||
|
||||
def mark_loaded(self) -> None:
|
||||
"""Mark capabilities as loaded."""
|
||||
import time
|
||||
|
||||
self._loaded = True
|
||||
self._load_time = time.time()
|
||||
|
||||
|
||||
class MCPServerRegistry:
|
||||
"""
|
||||
Thread-safe singleton registry for MCP servers.
|
||||
|
||||
Manages server configurations and caches their capabilities.
|
||||
"""
|
||||
|
||||
_instance: "MCPServerRegistry | None" = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls) -> "MCPServerRegistry":
|
||||
"""Ensure singleton pattern."""
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize registry (only runs once due to singleton)."""
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
|
||||
self._config: MCPConfig = MCPConfig()
|
||||
self._capabilities: dict[str, ServerCapabilities] = {}
|
||||
self._capabilities_lock = asyncio.Lock()
|
||||
self._initialized = True
|
||||
|
||||
logger.info("MCP Server Registry initialized")
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPServerRegistry":
|
||||
"""Get the singleton registry instance."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the singleton (for testing)."""
|
||||
with cls._lock:
|
||||
cls._instance = None
|
||||
|
||||
def load_config(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Load configuration into the registry.
|
||||
|
||||
Args:
|
||||
config: Optional config to load. If None, loads from default path.
|
||||
"""
|
||||
if config is None:
|
||||
config = load_mcp_config()
|
||||
|
||||
self._config = config
|
||||
self._capabilities.clear()
|
||||
|
||||
logger.info(
|
||||
"Loaded MCP configuration with %d servers",
|
||||
len(config.mcp_servers),
|
||||
)
|
||||
for name in config.list_server_names():
|
||||
logger.debug("Registered MCP server: %s", name)
|
||||
|
||||
def register(self, name: str, config: MCPServerConfig) -> None:
|
||||
"""
|
||||
Register a new MCP server.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._config.mcp_servers[name] = config
|
||||
self._capabilities.pop(name, None) # Clear any cached capabilities
|
||||
|
||||
logger.info("Registered MCP server: %s at %s", name, config.url)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
if name in self._config.mcp_servers:
|
||||
del self._config.mcp_servers[name]
|
||||
self._capabilities.pop(name, None)
|
||||
logger.info("Unregistered MCP server: %s", name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get a server configuration by name.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._config.get_server(name)
|
||||
if config is None:
|
||||
raise MCPServerNotFoundError(
|
||||
server_name=name,
|
||||
available_servers=self.list_servers(),
|
||||
)
|
||||
return config
|
||||
|
||||
def get_or_none(self, name: str) -> MCPServerConfig | None:
|
||||
"""
|
||||
Get a server configuration by name, or None if not found.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration or None
|
||||
"""
|
||||
return self._config.get_server(name)
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._config.list_server_names()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return list(self._config.get_enabled_servers().keys())
|
||||
|
||||
def get_all_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all server configurations."""
|
||||
return dict(self._config.mcp_servers)
|
||||
|
||||
def get_enabled_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return self._config.get_enabled_servers()
|
||||
|
||||
async def get_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
force_refresh: bool = False,
|
||||
) -> ServerCapabilities:
|
||||
"""
|
||||
Get capabilities for a server (lazy-loaded and cached).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
force_refresh: If True, refresh cached capabilities
|
||||
|
||||
Returns:
|
||||
Server capabilities
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
# Verify server exists
|
||||
self.get(name)
|
||||
|
||||
async with self._capabilities_lock:
|
||||
if name not in self._capabilities or force_refresh:
|
||||
# Will be populated by connection manager when connecting
|
||||
self._capabilities[name] = ServerCapabilities()
|
||||
|
||||
return self._capabilities[name]
|
||||
|
||||
def set_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set capabilities for a server (called by connection manager).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
tools: List of tool definitions
|
||||
resources: List of resource definitions
|
||||
prompts: List of prompt definitions
|
||||
"""
|
||||
capabilities = ServerCapabilities(
|
||||
tools=tools,
|
||||
resources=resources,
|
||||
prompts=prompts,
|
||||
)
|
||||
capabilities.mark_loaded()
|
||||
self._capabilities[name] = capabilities
|
||||
|
||||
logger.debug(
|
||||
"Updated capabilities for %s: %d tools, %d resources, %d prompts",
|
||||
name,
|
||||
len(capabilities.tools),
|
||||
len(capabilities.resources),
|
||||
len(capabilities.prompts),
|
||||
)
|
||||
|
||||
def get_cached_capabilities(self, name: str) -> ServerCapabilities:
|
||||
"""
|
||||
Get cached capabilities without async loading.
|
||||
|
||||
Use this for synchronous access when you only need
|
||||
cached values (e.g., for health check responses).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Cached capabilities or empty ServerCapabilities
|
||||
"""
|
||||
return self._capabilities.get(name, ServerCapabilities())
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to find
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
for name, caps in self._capabilities.items():
|
||||
if tool_name in caps.tool_names:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_all_tools(self) -> dict[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
Get all tools from all servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to list of tool definitions
|
||||
"""
|
||||
return {
|
||||
name: caps.tools
|
||||
for name, caps in self._capabilities.items()
|
||||
if caps.is_loaded
|
||||
}
|
||||
|
||||
@property
|
||||
def global_config(self) -> MCPConfig:
|
||||
"""Get the global MCP configuration."""
|
||||
return self._config
|
||||
|
||||
|
||||
# Module-level convenience function
|
||||
def get_registry() -> MCPServerRegistry:
|
||||
"""Get the global MCP server registry instance."""
|
||||
return MCPServerRegistry.get_instance()
|
||||
619
backend/app/services/mcp/routing.py
Normal file
619
backend/app/services/mcp/routing.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
MCP Tool Call Routing
|
||||
|
||||
Routes tool calls to appropriate servers with retry logic,
|
||||
circuit breakers, and request/response serialization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPServerConfig
|
||||
from .connection import ConnectionPool, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from .registry import MCPServerRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
|
||||
CLOSED = "closed"
|
||||
OPEN = "open"
|
||||
HALF_OPEN = "half-open"
|
||||
|
||||
|
||||
class AsyncCircuitBreaker:
|
||||
"""
|
||||
Async-compatible circuit breaker implementation.
|
||||
|
||||
Unlike pybreaker which wraps sync functions, this implementation
|
||||
provides explicit success/failure tracking for async code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fail_max: int = 5,
|
||||
reset_timeout: float = 30.0,
|
||||
name: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
fail_max: Maximum failures before opening circuit
|
||||
reset_timeout: Seconds to wait before trying again
|
||||
name: Name for logging
|
||||
"""
|
||||
self.fail_max = fail_max
|
||||
self.reset_timeout = reset_timeout
|
||||
self.name = name
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time: float | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_state(self) -> str:
|
||||
"""Get current state as string."""
|
||||
# Check if we should transition from OPEN to HALF_OPEN
|
||||
if self._state == CircuitState.OPEN:
|
||||
if self._should_try_reset():
|
||||
return CircuitState.HALF_OPEN.value
|
||||
return self._state.value
|
||||
|
||||
@property
|
||||
def fail_counter(self) -> int:
|
||||
"""Get current failure count."""
|
||||
return self._fail_counter
|
||||
|
||||
def _should_try_reset(self) -> bool:
|
||||
"""Check if enough time has passed to try resetting."""
|
||||
if self._last_failure_time is None:
|
||||
return True
|
||||
return (time.time() - self._last_failure_time) >= self.reset_timeout
|
||||
|
||||
async def success(self) -> None:
|
||||
"""Record a successful call."""
|
||||
async with self._lock:
|
||||
self._fail_counter = 0
|
||||
self._state = CircuitState.CLOSED
|
||||
self._last_failure_time = None
|
||||
|
||||
async def failure(self) -> None:
|
||||
"""Record a failed call."""
|
||||
async with self._lock:
|
||||
self._fail_counter += 1
|
||||
self._last_failure_time = time.time()
|
||||
|
||||
if self._fail_counter >= self.fail_max:
|
||||
self._state = CircuitState.OPEN
|
||||
logger.warning(
|
||||
"Circuit breaker %s opened after %d failures",
|
||||
self.name,
|
||||
self._fail_counter,
|
||||
)
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (not allowing calls)."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
return not self._should_try_reset()
|
||||
return False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Manually reset the circuit breaker."""
|
||||
async with self._lock:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInfo:
|
||||
"""Information about an available tool."""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
server_name: str | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"server_name": self.server_name,
|
||||
"input_schema": self.input_schema,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of a tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"data": self.data,
|
||||
"error": self.error,
|
||||
"error_code": self.error_code,
|
||||
"tool_name": self.tool_name,
|
||||
"server_name": self.server_name,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"request_id": self.request_id,
|
||||
}
|
||||
|
||||
|
||||
class ToolRouter:
|
||||
"""
|
||||
Routes tool calls to the appropriate MCP server.
|
||||
|
||||
Features:
|
||||
- Tool name to server mapping
|
||||
- Retry logic with exponential backoff
|
||||
- Circuit breaker pattern for fault tolerance
|
||||
- Request/response serialization
|
||||
- Execution timing and metrics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry: MCPServerRegistry,
|
||||
connection_pool: ConnectionPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the tool router.
|
||||
|
||||
Args:
|
||||
registry: MCP server registry
|
||||
connection_pool: Connection pool for servers
|
||||
"""
|
||||
self._registry = registry
|
||||
self._pool = connection_pool
|
||||
self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
|
||||
self._tool_to_server: dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _get_circuit_breaker(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a server."""
|
||||
if server_name not in self._circuit_breakers:
|
||||
self._circuit_breakers[server_name] = AsyncCircuitBreaker(
|
||||
fail_max=config.circuit_breaker_threshold,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
name=f"mcp-{server_name}",
|
||||
)
|
||||
return self._circuit_breakers[server_name]
|
||||
|
||||
async def register_tool_mapping(
|
||||
self,
|
||||
tool_name: str,
|
||||
server_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Register a mapping from tool name to server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
server_name: Name of the server providing the tool
|
||||
"""
|
||||
async with self._lock:
|
||||
self._tool_to_server[tool_name] = server_name
|
||||
logger.debug("Registered tool %s -> server %s", tool_name, server_name)
|
||||
|
||||
async def discover_tools(self) -> None:
|
||||
"""
|
||||
Discover all tools from registered servers and build mappings.
|
||||
"""
|
||||
for server_name in self._registry.list_enabled_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Fetch tools from server
|
||||
tools = await self._fetch_tools_from_server(connection)
|
||||
|
||||
# Update registry with capabilities
|
||||
self._registry.set_capabilities(
|
||||
server_name,
|
||||
tools=[t.to_dict() for t in tools],
|
||||
)
|
||||
|
||||
# Update tool mappings
|
||||
for tool in tools:
|
||||
await self.register_tool_mapping(tool.name, server_name)
|
||||
|
||||
logger.info(
|
||||
"Discovered %d tools from server %s",
|
||||
len(tools),
|
||||
server_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to discover tools from %s: %s",
|
||||
server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
async def _fetch_tools_from_server(
|
||||
self,
|
||||
connection: MCPConnection,
|
||||
) -> list[ToolInfo]:
|
||||
"""Fetch available tools from an MCP server."""
|
||||
try:
|
||||
response = await connection.execute_request(
|
||||
"GET",
|
||||
"/mcp/tools",
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_data in response.get("tools", []):
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=connection.server_name,
|
||||
input_schema=tool_data.get("inputSchema"),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error fetching tools from %s: %s",
|
||||
connection.server_name,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
return self._tool_to_server.get(tool_name)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
start_time = time.time()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
logger.debug(
|
||||
"Tool call [%s]: %s.%s with args %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
circuit_breaker = self._get_circuit_breaker(server_name, config)
|
||||
|
||||
# Check circuit breaker state
|
||||
if circuit_breaker.is_open():
|
||||
raise MCPCircuitOpenError(
|
||||
server_name=server_name,
|
||||
failure_count=circuit_breaker.fail_counter,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
)
|
||||
|
||||
# Execute with retry logic
|
||||
result = await self._execute_with_retry(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments or {},
|
||||
timeout=timeout,
|
||||
circuit_breaker=circuit_breaker,
|
||||
)
|
||||
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
data=result,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPError as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
"Tool call failed [%s]: %s.%s - %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=type(e).__name__,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.exception(
|
||||
"Unexpected error in tool call [%s]: %s.%s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code="UnexpectedError",
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
async def _execute_with_retry(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
circuit_breaker: AsyncCircuitBreaker,
|
||||
) -> Any:
|
||||
"""Execute tool call with retry logic."""
|
||||
last_error: Exception | None = None
|
||||
attempts = 0
|
||||
max_attempts = config.retry_attempts + 1 # +1 for initial attempt
|
||||
|
||||
while attempts < max_attempts:
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# Use circuit breaker to track failures
|
||||
result = await self._execute_tool_call(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Success - record it
|
||||
await circuit_breaker.success()
|
||||
return result
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPTimeoutError:
|
||||
# Timeout - don't retry
|
||||
await circuit_breaker.failure()
|
||||
raise
|
||||
except MCPToolError:
|
||||
# Tool-level error - don't retry (user error)
|
||||
raise
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
await circuit_breaker.failure()
|
||||
|
||||
if attempts < max_attempts:
|
||||
delay = self._calculate_retry_delay(attempts, config)
|
||||
logger.warning(
|
||||
"Tool call attempt %d/%d failed for %s.%s: %s. "
|
||||
"Retrying in %.1fs",
|
||||
attempts,
|
||||
max_attempts,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
raise MCPToolError(
|
||||
f"Tool call failed after {max_attempts} attempts",
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
details={"last_error": str(last_error)},
|
||||
)
|
||||
|
||||
def _calculate_retry_delay(
|
||||
self,
|
||||
attempt: int,
|
||||
config: MCPServerConfig,
|
||||
) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = config.retry_delay * (2 ** (attempt - 1))
|
||||
delay = min(delay, config.retry_max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return max(0.1, delay + jitter)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a single tool call."""
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Build MCP tool call request
|
||||
request_body = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
response = await connection.execute_request(
|
||||
method="POST",
|
||||
path="/mcp",
|
||||
data=request_body,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Handle JSON-RPC response
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPToolError(
|
||||
error.get("message", "Tool execution failed"),
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
error_code=str(error.get("code", "UNKNOWN")),
|
||||
)
|
||||
|
||||
return response.get("result")
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server.
|
||||
|
||||
Automatically discovers which server provides the tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
MCPToolNotFoundError: If no server provides the tool
|
||||
"""
|
||||
server_name = self.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
# Try to find from registry
|
||||
server_name = self._registry.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
raise MCPToolNotFoundError(
|
||||
tool_name=tool_name,
|
||||
available_tools=list(self._tool_to_server.keys()),
|
||||
)
|
||||
|
||||
return await self.call_tool(
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
Get all available tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
tools = []
|
||||
all_server_tools = self._registry.get_all_tools()
|
||||
|
||||
for server_name, server_tools in all_server_tools.items():
|
||||
for tool_data in server_tools:
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=server_name,
|
||||
input_schema=tool_data.get("input_schema"),
|
||||
)
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
return {
|
||||
name: {
|
||||
"state": cb.current_state,
|
||||
"failure_count": cb.fail_counter,
|
||||
}
|
||||
for name, cb in self._circuit_breakers.items()
|
||||
}
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Manually reset a circuit breaker.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._circuit_breakers:
|
||||
# Reset by removing (will be recreated on next call)
|
||||
del self._circuit_breakers[server_name]
|
||||
logger.info("Reset circuit breaker for %s", server_name)
|
||||
return True
|
||||
return False
|
||||
Reference in New Issue
Block a user