"""
MCP Connection Management

Handles connection lifecycle, pooling, and automatic reconnection for MCP servers.
"""

import asyncio
import logging
import random
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any

import httpx

from .config import MCPServerConfig, TransportType
from .exceptions import MCPConnectionError, MCPTimeoutError

logger = logging.getLogger(__name__)


class ConnectionState(str, Enum):
    """Connection state enumeration."""

    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    RECONNECTING = "reconnecting"
    ERROR = "error"


class MCPConnection:
    """
    Manages a single connection to an MCP server.

    Handles connection lifecycle, health checking, and automatic reconnection.

    State transitions (connect / disconnect / reconnect) are serialized by a
    per-connection ``asyncio.Lock``. ``asyncio.Lock`` is not reentrant, so each
    public method acquires it exactly once and delegates to a ``_*_locked``
    helper that assumes the lock is already held.
    """

    def __init__(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> None:
        """
        Initialize connection. No network I/O happens here; call connect().

        Args:
            server_name: Name of the MCP server
            config: Server configuration (url, transport, timeout, and the
                retry_delay / retry_max_delay / retry_attempts backoff settings)
        """
        self.server_name = server_name
        self.config = config
        self._state = ConnectionState.DISCONNECTED
        self._client: httpx.AsyncClient | None = None
        self._lock = asyncio.Lock()
        self._last_activity: float | None = None
        self._connection_attempts = 0
        self._last_error: Exception | None = None

        # Reconnection settings
        self._base_delay = config.retry_delay
        self._max_delay = config.retry_max_delay
        self._max_attempts = config.retry_attempts

    @property
    def state(self) -> ConnectionState:
        """Get current connection state."""
        return self._state

    @property
    def is_connected(self) -> bool:
        """Check if connection is established."""
        return self._state == ConnectionState.CONNECTED

    @property
    def last_error(self) -> Exception | None:
        """Get the last error that occurred."""
        return self._last_error

    async def connect(self) -> None:
        """
        Establish connection to the MCP server.

        Returns immediately if already connected. Otherwise attempts to connect
        up to ``config.retry_attempts`` times, sleeping with exponential backoff
        (plus jitter) between attempts.

        Raises:
            MCPConnectionError: If connection fails after all retries
        """
        async with self._lock:
            await self._connect_locked()

    async def _connect_locked(self) -> None:
        """Run the connect/retry loop. The caller must hold ``self._lock``."""
        if self._state == ConnectionState.CONNECTED:
            return

        self._state = ConnectionState.CONNECTING
        self._connection_attempts = 0
        self._last_error = None

        while self._connection_attempts < self._max_attempts:
            try:
                await self._do_connect()
            except Exception as e:
                self._connection_attempts += 1
                self._last_error = e
                logger.warning(
                    "Connection attempt %d/%d failed for %s: %s",
                    self._connection_attempts,
                    self._max_attempts,
                    self.server_name,
                    e,
                )
                # Only sleep when another attempt will follow.
                if self._connection_attempts < self._max_attempts:
                    delay = self._calculate_backoff_delay()
                    logger.debug(
                        "Retrying connection to %s in %.1fs",
                        self.server_name,
                        delay,
                    )
                    await asyncio.sleep(delay)
            else:
                self._state = ConnectionState.CONNECTED
                self._last_activity = time.time()
                logger.info(
                    "Connected to MCP server: %s at %s",
                    self.server_name,
                    self.config.url,
                )
                return

        # All attempts failed
        self._state = ConnectionState.ERROR
        raise MCPConnectionError(
            f"Failed to connect after {self._max_attempts} attempts",
            server_name=self.server_name,
            url=self.config.url,
            cause=self._last_error,
        )

    async def _do_connect(self) -> None:
        """
        Perform a single connection attempt (transport-specific).

        On success ``self._client`` holds a ready ``httpx.AsyncClient``. On any
        failure the freshly created client is closed before the exception
        propagates, so repeated retries do not leak clients.

        Raises:
            NotImplementedError: For transports other than HTTP
            MCPConnectionError: If the server cannot be reached
            httpx.HTTPStatusError: If the probe returns an error status other
                than 404
        """
        if self.config.transport != TransportType.HTTP:
            # For STDIO and SSE transports, we'll implement later
            raise NotImplementedError(
                f"Transport {self.config.transport} not yet implemented"
            )

        client = httpx.AsyncClient(
            base_url=self.config.url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "User-Agent": "Syndarix-MCP-Client/1.0",
                "Accept": "application/json",
            },
        )
        try:
            # Verify connectivity by hitting the MCP capabilities endpoint.
            try:
                response = await client.get("/mcp/capabilities")
            except httpx.ConnectError as e:
                raise MCPConnectionError(
                    "Failed to connect to server",
                    server_name=self.server_name,
                    url=self.config.url,
                    cause=e,
                ) from e
            # 404 is acceptable - server might not have capabilities endpoint
            if response.status_code not in (200, 404):
                response.raise_for_status()
        except BaseException:
            # Don't leak the client (and its connection pool) on a failed attempt.
            await client.aclose()
            raise

        self._client = client

    def _calculate_backoff_delay(self) -> float:
        """
        Calculate exponential backoff delay with jitter.

        The base delay doubles with each failed attempt, is capped at
        ``retry_max_delay``, and then receives +/-25% random jitter.
        """
        delay = self._base_delay * (2 ** (self._connection_attempts - 1))
        delay = min(delay, self._max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        return delay + jitter

    async def disconnect(self) -> None:
        """Disconnect from the MCP server. Safe to call when not connected."""
        async with self._lock:
            await self._disconnect_locked()

    async def _disconnect_locked(self) -> None:
        """Close the HTTP client, if any. The caller must hold ``self._lock``."""
        if self._client is not None:
            try:
                await self._client.aclose()
            except Exception as e:
                logger.warning(
                    "Error closing connection to %s: %s",
                    self.server_name,
                    e,
                )
            finally:
                self._client = None
        self._state = ConnectionState.DISCONNECTED
        logger.info("Disconnected from MCP server: %s", self.server_name)

    async def reconnect(self) -> None:
        """
        Reconnect to the MCP server.

        Raises:
            MCPConnectionError: If the new connection fails after all retries
        """
        async with self._lock:
            # asyncio.Lock is not reentrant: awaiting self.disconnect() or
            # self.connect() here would deadlock on self._lock, so call the
            # lock-held helpers directly.
            self._state = ConnectionState.RECONNECTING
            await self._disconnect_locked()
            await self._connect_locked()

    async def health_check(self) -> bool:
        """
        Perform a health check on the connection.

        For HTTP transport this issues ``GET /health`` with a 5 second timeout.
        Errors are logged and reported as unhealthy rather than raised.

        Returns:
            True if connection is healthy
        """
        if not self.is_connected or self._client is None:
            return False

        try:
            if self.config.transport == TransportType.HTTP:
                response = await self._client.get(
                    "/health",
                    timeout=5.0,
                )
                return response.status_code == 200
            return True
        except Exception as e:
            logger.warning(
                "Health check failed for %s: %s",
                self.server_name,
                e,
            )
            return False

    async def execute_request(
        self,
        method: str,
        path: str,
        data: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """
        Execute an HTTP request to the MCP server.

        Args:
            method: HTTP method (GET, POST, etc.); case-insensitive
            path: Request path
            data: Optional JSON request body (ignored for GET)
            timeout: Optional timeout override; a falsy value falls back to
                ``config.timeout``

        Returns:
            Response data (the decoded JSON body)

        Raises:
            MCPConnectionError: If not connected, on HTTP error status, or on
                any other request failure
            MCPTimeoutError: If request times out
        """
        if not self.is_connected or self._client is None:
            raise MCPConnectionError(
                "Not connected to server",
                server_name=self.server_name,
            )

        effective_timeout = timeout or self.config.timeout
        verb = method.upper()

        try:
            if verb == "GET":
                response = await self._client.get(
                    path,
                    timeout=effective_timeout,
                )
            elif verb == "POST":
                response = await self._client.post(
                    path,
                    json=data,
                    timeout=effective_timeout,
                )
            else:
                response = await self._client.request(
                    verb,
                    path,
                    json=data,
                    timeout=effective_timeout,
                )

            self._last_activity = time.time()
            response.raise_for_status()
            return response.json()

        except httpx.TimeoutException as e:
            raise MCPTimeoutError(
                "Request timed out",
                server_name=self.server_name,
                timeout_seconds=effective_timeout,
                operation=f"{method} {path}",
            ) from e
        except httpx.HTTPStatusError as e:
            raise MCPConnectionError(
                f"HTTP error: {e.response.status_code}",
                server_name=self.server_name,
                url=f"{self.config.url}{path}",
                cause=e,
            ) from e
        except Exception as e:
            raise MCPConnectionError(
                f"Request failed: {e}",
                server_name=self.server_name,
                cause=e,
            ) from e


class ConnectionPool:
    """
    Pool of connections to MCP servers.

    Manages connection lifecycle and provides connection reuse. Currently a
    single shared MCPConnection is kept per server name.
    """

    def __init__(self, max_connections_per_server: int = 10) -> None:
        """
        Initialize connection pool.

        Args:
            max_connections_per_server: Maximum connections per server.
                NOTE: stored but not yet enforced; the pool keeps one
                connection per server.
        """
        self._connections: dict[str, MCPConnection] = {}
        self._lock = asyncio.Lock()
        self._max_per_server = max_connections_per_server

    async def get_connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> MCPConnection:
        """
        Get or create a connection to a server.

        If a connection for ``server_name`` already exists it is reused (and
        reconnected if needed); ``config`` is only used when creating a new one.

        Args:
            server_name: Name of the server
            config: Server configuration

        Returns:
            Active connection

        Raises:
            MCPConnectionError: If connecting fails after all retries
        """
        # NOTE: the pool lock is held across connect() (including retry
        # backoff), so a slow server delays lookups for every server.
        async with self._lock:
            connection = self._connections.get(server_name)
            if connection is None:
                connection = MCPConnection(server_name, config)
                await connection.connect()
                # Only register connections that connected successfully.
                self._connections[server_name] = connection
            elif not connection.is_connected:
                # Reconnect if not connected
                await connection.connect()
            return connection

    async def release_connection(self, server_name: str) -> None:
        """
        Release a connection (currently a no-op).

        Args:
            server_name: Name of the server
        """
        # For now, we keep connections alive
        # Future: implement connection reaping for idle connections

    async def close_connection(self, server_name: str) -> None:
        """
        Close and remove a connection. Unknown server names are ignored.

        Args:
            server_name: Name of the server
        """
        async with self._lock:
            # Remove first so the entry is dropped even if disconnect fails or
            # the task is cancelled mid-close.
            connection = self._connections.pop(server_name, None)
            if connection is not None:
                await connection.disconnect()

    async def close_all(self) -> None:
        """Close all connections in the pool, logging (not raising) errors."""
        async with self._lock:
            for connection in self._connections.values():
                try:
                    await connection.disconnect()
                except Exception as e:
                    logger.warning("Error closing connection: %s", e)
            self._connections.clear()
            logger.info("Closed all MCP connections")

    async def health_check_all(self) -> dict[str, bool]:
        """
        Perform health check on all connections.

        Returns:
            Dict mapping server names to health status
        """
        # Snapshot the entries: other coroutines may add/remove connections
        # while we await, and mutating a dict during iteration raises
        # RuntimeError.
        snapshot = list(self._connections.items())
        results: dict[str, bool] = {}
        for name, connection in snapshot:
            results[name] = await connection.health_check()
        return results

    def get_status(self) -> dict[str, dict[str, Any]]:
        """
        Get status of all connections.

        Returns:
            Dict mapping server names to status info
        """
        return {
            name: {
                "state": conn.state.value,
                "is_connected": conn.is_connected,
                "url": conn.config.url,
            }
            for name, conn in self._connections.items()
        }

    @asynccontextmanager
    async def connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncGenerator[MCPConnection, None]:
        """
        Context manager for getting a connection.

        Usage:
            async with pool.connection("server", config) as conn:
                result = await conn.execute_request("POST", "/tool", data)
        """
        conn = await self.get_connection(server_name, config)
        try:
            yield conn
        finally:
            await self.release_connection(server_name)