- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation. - Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming. - Refactored `CostController`'s session and daily budget alert handling for improved clarity. - Implemented test suites for `CostController` and `SafetyGuardian` to validate changes. - Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
474 lines
15 KiB
Python
474 lines
15 KiB
Python
"""
|
|
MCP Connection Management
|
|
|
|
Handles connection lifecycle, pooling, and automatic reconnection
|
|
for MCP servers.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from .config import MCPServerConfig, TransportType
|
|
from .exceptions import MCPConnectionError, MCPTimeoutError
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConnectionState(str, Enum):
    """
    Lifecycle states of a connection to an MCP server.

    The enum also derives from ``str``, so every member compares equal to
    its plain string value (for example ``ConnectionState.CONNECTED ==
    "connected"``).
    """

    # No connection is open.
    DISCONNECTED = "disconnected"
    # A connection attempt is in progress.
    CONNECTING = "connecting"
    # The connection is open and usable.
    CONNECTED = "connected"
    # An existing connection is being torn down and re-established.
    RECONNECTING = "reconnecting"
    # Connecting failed.
    ERROR = "error"
|
|
|
|
|
|
class MCPConnection:
    """
    Manages a single connection to an MCP server.

    Handles connection lifecycle, health checking, and retrying the initial
    connection with exponential backoff. Only the HTTP transport is
    implemented; for any other transport every connection attempt raises
    ``NotImplementedError`` and ``connect()`` eventually fails.
    """

    def __init__(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> None:
        """
        Initialize connection.

        No network activity happens here; call ``connect()`` to open the
        connection.

        Args:
            server_name: Name of the MCP server
            config: Server configuration
        """
        self.server_name = server_name
        self.config = config
        self._state = ConnectionState.DISCONNECTED
        self._client: httpx.AsyncClient | None = None
        # Taken by connect() and disconnect(); asyncio.Lock is not reentrant.
        self._lock = asyncio.Lock()
        # time.time() of the last successful connect or received response.
        self._last_activity: float | None = None
        self._connection_attempts = 0
        self._last_error: Exception | None = None

        # Reconnection settings
        self._base_delay = config.retry_delay
        self._max_delay = config.retry_max_delay
        self._max_attempts = config.retry_attempts

    @property
    def state(self) -> ConnectionState:
        """Get current connection state."""
        return self._state

    @property
    def is_connected(self) -> bool:
        """Check if connection is established."""
        return self._state == ConnectionState.CONNECTED

    @property
    def last_error(self) -> Exception | None:
        """Get the last error recorded by a failed connection attempt."""
        return self._last_error

    async def connect(self) -> None:
        """
        Establish connection to the MCP server.

        Returns immediately when already connected. Otherwise makes up to
        ``config.retry_attempts`` attempts, sleeping for a jittered
        exponential backoff delay between failed attempts.

        Raises:
            MCPConnectionError: If connection fails after all retries
        """
        async with self._lock:
            if self._state == ConnectionState.CONNECTED:
                return

            self._state = ConnectionState.CONNECTING
            self._connection_attempts = 0
            self._last_error = None

            while self._connection_attempts < self._max_attempts:
                try:
                    await self._do_connect()
                except Exception as e:
                    self._connection_attempts += 1
                    self._last_error = e
                    logger.warning(
                        "Connection attempt %d/%d failed for %s: %s",
                        self._connection_attempts,
                        self._max_attempts,
                        self.server_name,
                        e,
                    )

                    # Only wait when another attempt will actually follow.
                    if self._connection_attempts < self._max_attempts:
                        delay = self._calculate_backoff_delay()
                        logger.debug(
                            "Retrying connection to %s in %.1fs",
                            self.server_name,
                            delay,
                        )
                        await asyncio.sleep(delay)
                else:
                    self._state = ConnectionState.CONNECTED
                    self._last_activity = time.time()
                    logger.info(
                        "Connected to MCP server: %s at %s",
                        self.server_name,
                        self.config.url,
                    )
                    return

            # All attempts failed
            self._state = ConnectionState.ERROR
            raise MCPConnectionError(
                f"Failed to connect after {self._max_attempts} attempts",
                server_name=self.server_name,
                url=self.config.url,
                cause=self._last_error,
            )

    async def _do_connect(self) -> None:
        """
        Perform the actual connection (transport-specific).

        For the HTTP transport this creates the ``httpx.AsyncClient`` and
        probes the server once. If the probe fails, the freshly created
        client is closed and ``self._client`` is reset to ``None`` before
        the error propagates, so failed attempts do not leak open clients.

        Raises:
            NotImplementedError: If the configured transport is not HTTP
            MCPConnectionError: If the server cannot be reached
            httpx.HTTPStatusError: If the probe gets an error status
                other than 404
        """
        if self.config.transport != TransportType.HTTP:
            # For STDIO and SSE transports, we'll implement later
            raise NotImplementedError(
                f"Transport {self.config.transport} not yet implemented"
            )

        # Never overwrite a client that is still open.
        await self._close_client()

        client = httpx.AsyncClient(
            base_url=self.config.url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "User-Agent": "Syndarix-MCP-Client/1.0",
                "Accept": "application/json",
            },
        )
        self._client = client

        try:
            await self._verify_connectivity(client)
        except BaseException:
            # Covers probe errors and task cancellation alike: release the
            # client's resources instead of leaving it dangling.
            await self._close_client()
            raise

    async def _verify_connectivity(self, client: httpx.AsyncClient) -> None:
        """Probe the server's MCP capabilities endpoint with one GET."""
        try:
            # Try to hit the MCP capabilities endpoint
            response = await client.get("/mcp/capabilities")
            if response.status_code not in (200, 404):
                # 404 is acceptable - server might not have capabilities endpoint
                response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code != 404:
                raise
        except httpx.ConnectError as e:
            raise MCPConnectionError(
                "Failed to connect to server",
                server_name=self.server_name,
                url=self.config.url,
                cause=e,
            ) from e

    async def _close_client(self) -> None:
        """Close and forget the current client, logging (not raising) errors."""
        client, self._client = self._client, None
        if client is None:
            return
        try:
            await client.aclose()
        except Exception as e:
            logger.warning(
                "Error closing connection to %s: %s",
                self.server_name,
                e,
            )

    def _calculate_backoff_delay(self) -> float:
        """
        Calculate exponential backoff delay with jitter.

        The base delay doubles with every recorded attempt, is capped at the
        configured maximum, and is then shifted by a random amount of up to
        25% in either direction.

        Returns:
            Delay in seconds before the next attempt
        """
        import random

        delay = self._base_delay * (2 ** (self._connection_attempts - 1))
        delay = min(delay, self._max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        return delay + jitter

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Errors raised while closing the client are logged and suppressed;
        the connection always ends up in the DISCONNECTED state.
        """
        async with self._lock:
            await self._close_client()
            self._state = ConnectionState.DISCONNECTED
            logger.info("Disconnected from MCP server: %s", self.server_name)

    async def reconnect(self) -> None:
        """
        Reconnect to the MCP server.

        Raises:
            MCPConnectionError: If the new connection cannot be established
        """
        async with self._lock:
            self._state = ConnectionState.RECONNECTING

        # disconnect() and connect() acquire self._lock themselves, and
        # asyncio.Lock is not reentrant, so they run after it is released.
        await self.disconnect()
        await self.connect()

    async def health_check(self) -> bool:
        """
        Perform a health check on the connection.

        For the HTTP transport this issues ``GET /health`` with a 5 second
        timeout. Any exception is logged and reported as unhealthy.

        Returns:
            True if connection is healthy
        """
        if not self.is_connected or self._client is None:
            return False

        try:
            if self.config.transport == TransportType.HTTP:
                response = await self._client.get(
                    "/health",
                    timeout=5.0,
                )
                return response.status_code == 200
            return True
        except Exception as e:
            logger.warning(
                "Health check failed for %s: %s",
                self.server_name,
                e,
            )
            return False

    async def execute_request(
        self,
        method: str,
        path: str,
        data: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """
        Execute an HTTP request to the MCP server.

        Args:
            method: HTTP method (GET, POST, etc.); case-insensitive
            path: Request path
            data: Optional request body, sent as JSON (ignored for GET)
            timeout: Optional timeout override; a falsy value falls back to
                the configured timeout

        Returns:
            Response data (the decoded JSON body)

        Raises:
            MCPConnectionError: If not connected, if the server answers with
                an error status, or if the request fails for any other
                reason (including an undecodable response body)
            MCPTimeoutError: If request times out
        """
        if not self.is_connected or self._client is None:
            raise MCPConnectionError(
                "Not connected to server",
                server_name=self.server_name,
            )

        effective_timeout = timeout or self.config.timeout

        try:
            if method.upper() == "GET":
                response = await self._client.get(
                    path,
                    timeout=effective_timeout,
                )
            elif method.upper() == "POST":
                response = await self._client.post(
                    path,
                    json=data,
                    timeout=effective_timeout,
                )
            else:
                response = await self._client.request(
                    method.upper(),
                    path,
                    json=data,
                    timeout=effective_timeout,
                )

            # Any received response counts as activity, even an error status.
            self._last_activity = time.time()
            response.raise_for_status()
            return response.json()

        except httpx.TimeoutException as e:
            raise MCPTimeoutError(
                "Request timed out",
                server_name=self.server_name,
                timeout_seconds=effective_timeout,
                operation=f"{method} {path}",
            ) from e
        except httpx.HTTPStatusError as e:
            raise MCPConnectionError(
                f"HTTP error: {e.response.status_code}",
                server_name=self.server_name,
                url=f"{self.config.url}{path}",
                cause=e,
            ) from e
        except Exception as e:
            raise MCPConnectionError(
                f"Request failed: {e}",
                server_name=self.server_name,
                cause=e,
            ) from e
|
|
|
|
|
|
class ConnectionPool:
|
|
"""
|
|
Pool of connections to MCP servers.
|
|
|
|
Manages connection lifecycle and provides connection reuse.
|
|
"""
|
|
|
|
def __init__(self, max_connections_per_server: int = 10) -> None:
|
|
"""
|
|
Initialize connection pool.
|
|
|
|
Args:
|
|
max_connections_per_server: Maximum connections per server
|
|
"""
|
|
self._connections: dict[str, MCPConnection] = {}
|
|
self._lock = asyncio.Lock()
|
|
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
|
self._max_per_server = max_connections_per_server
|
|
|
|
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
|
"""Get or create a lock for a specific server.
|
|
|
|
Uses setdefault for atomic dict access to prevent race conditions
|
|
where two coroutines could create different locks for the same server.
|
|
"""
|
|
# setdefault is atomic - if key exists, returns existing value
|
|
# if key doesn't exist, inserts new value and returns it
|
|
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
|
|
|
async def get_connection(
|
|
self,
|
|
server_name: str,
|
|
config: MCPServerConfig,
|
|
) -> MCPConnection:
|
|
"""
|
|
Get or create a connection to a server.
|
|
|
|
Uses per-server locking to avoid blocking all connections
|
|
when establishing a new connection.
|
|
|
|
Args:
|
|
server_name: Name of the server
|
|
config: Server configuration
|
|
|
|
Returns:
|
|
Active connection
|
|
"""
|
|
# Quick check without lock - if connection exists and is connected, return it
|
|
if server_name in self._connections:
|
|
connection = self._connections[server_name]
|
|
if connection.is_connected:
|
|
return connection
|
|
|
|
# Need to create or reconnect - use per-server lock to avoid blocking others
|
|
async with self._lock:
|
|
server_lock = self._get_server_lock(server_name)
|
|
|
|
async with server_lock:
|
|
# Double-check after acquiring per-server lock
|
|
if server_name in self._connections:
|
|
connection = self._connections[server_name]
|
|
if connection.is_connected:
|
|
return connection
|
|
# Connection exists but not connected - reconnect
|
|
await connection.connect()
|
|
return connection
|
|
|
|
# Create new connection (outside global lock, under per-server lock)
|
|
connection = MCPConnection(server_name, config)
|
|
await connection.connect()
|
|
|
|
# Store connection under global lock
|
|
async with self._lock:
|
|
self._connections[server_name] = connection
|
|
|
|
return connection
|
|
|
|
async def release_connection(self, server_name: str) -> None:
|
|
"""
|
|
Release a connection (currently just tracks usage).
|
|
|
|
Args:
|
|
server_name: Name of the server
|
|
"""
|
|
# For now, we keep connections alive
|
|
# Future: implement connection reaping for idle connections
|
|
|
|
async def close_connection(self, server_name: str) -> None:
|
|
"""
|
|
Close and remove a connection.
|
|
|
|
Args:
|
|
server_name: Name of the server
|
|
"""
|
|
async with self._lock:
|
|
if server_name in self._connections:
|
|
await self._connections[server_name].disconnect()
|
|
del self._connections[server_name]
|
|
# Clean up per-server lock
|
|
if server_name in self._per_server_locks:
|
|
del self._per_server_locks[server_name]
|
|
|
|
async def close_all(self) -> None:
|
|
"""Close all connections in the pool."""
|
|
async with self._lock:
|
|
for connection in self._connections.values():
|
|
try:
|
|
await connection.disconnect()
|
|
except Exception as e:
|
|
logger.warning("Error closing connection: %s", e)
|
|
|
|
self._connections.clear()
|
|
self._per_server_locks.clear()
|
|
logger.info("Closed all MCP connections")
|
|
|
|
async def health_check_all(self) -> dict[str, bool]:
|
|
"""
|
|
Perform health check on all connections.
|
|
|
|
Returns:
|
|
Dict mapping server names to health status
|
|
"""
|
|
# Copy connections under lock to prevent modification during iteration
|
|
async with self._lock:
|
|
connections_snapshot = dict(self._connections)
|
|
|
|
results = {}
|
|
for name, connection in connections_snapshot.items():
|
|
results[name] = await connection.health_check()
|
|
return results
|
|
|
|
def get_status(self) -> dict[str, dict[str, Any]]:
|
|
"""
|
|
Get status of all connections.
|
|
|
|
Returns:
|
|
Dict mapping server names to status info
|
|
"""
|
|
return {
|
|
name: {
|
|
"state": conn.state.value,
|
|
"is_connected": conn.is_connected,
|
|
"url": conn.config.url,
|
|
}
|
|
for name, conn in self._connections.items()
|
|
}
|
|
|
|
@asynccontextmanager
|
|
async def connection(
|
|
self,
|
|
server_name: str,
|
|
config: MCPServerConfig,
|
|
) -> AsyncGenerator[MCPConnection, None]:
|
|
"""
|
|
Context manager for getting a connection.
|
|
|
|
Usage:
|
|
async with pool.connection("server", config) as conn:
|
|
result = await conn.execute_request("POST", "/tool", data)
|
|
"""
|
|
conn = await self.get_connection(server_name, config)
|
|
try:
|
|
yield conn
|
|
finally:
|
|
await self.release_connection(server_name)
|