Files
fast-next-template/backend/app/services/mcp/connection.py
Felipe Cardoso caf283bed2 feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
2026-01-03 17:55:34 +01:00

474 lines
15 KiB
Python

"""
MCP Connection Management
Handles connection lifecycle, pooling, and automatic reconnection
for MCP servers.
"""
import asyncio
import logging
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any
import httpx
from .config import MCPServerConfig, TransportType
from .exceptions import MCPConnectionError, MCPTimeoutError
logger = logging.getLogger(__name__)
class ConnectionState(str, Enum):
    """Lifecycle states an MCP server connection can be in.

    Members are also plain ``str`` instances, so each one compares equal
    to its lowercase string value.
    """

    # No live connection (initial state, and the state after disconnect()).
    DISCONNECTED = "disconnected"
    # A connect() call is in progress (including its retry loop).
    CONNECTING = "connecting"
    # Connectivity was verified and requests may be issued.
    CONNECTED = "connected"
    # A reconnect() call has started tearing down / re-establishing.
    RECONNECTING = "reconnecting"
    # The last connect() call gave up without establishing a connection.
    ERROR = "error"
class MCPConnection:
    """
    Manages a single connection to an MCP server.

    Handles connection lifecycle, health checking, and automatic reconnection.
    Only the HTTP transport is implemented; any other transport makes
    ``connect()`` fail immediately with ``MCPConnectionError``.
    """

    def __init__(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> None:
        """
        Initialize connection.

        No network I/O happens here; call ``connect()`` to establish the
        connection.

        Args:
            server_name: Name of the MCP server
            config: Server configuration
        """
        self.server_name = server_name
        self.config = config
        self._state = ConnectionState.DISCONNECTED
        self._client: httpx.AsyncClient | None = None
        # Serializes connect/disconnect so state transitions never interleave.
        self._lock = asyncio.Lock()
        self._last_activity: float | None = None
        self._connection_attempts = 0
        self._last_error: Exception | None = None

        # Reconnection settings
        self._base_delay = config.retry_delay
        self._max_delay = config.retry_max_delay
        self._max_attempts = config.retry_attempts

    @property
    def state(self) -> ConnectionState:
        """Get current connection state."""
        return self._state

    @property
    def is_connected(self) -> bool:
        """Check if connection is established."""
        return self._state == ConnectionState.CONNECTED

    @property
    def last_error(self) -> Exception | None:
        """Get the last error that occurred."""
        return self._last_error

    async def connect(self) -> None:
        """
        Establish connection to the MCP server.

        Retries up to ``config.retry_attempts`` times with exponential
        backoff between attempts. An unsupported transport is not retried,
        because it can never succeed.

        Raises:
            MCPConnectionError: If connection fails after all retries, or
                the configured transport is not implemented
        """
        async with self._lock:
            if self._state == ConnectionState.CONNECTED:
                return

            self._state = ConnectionState.CONNECTING
            self._connection_attempts = 0
            self._last_error = None

            while self._connection_attempts < self._max_attempts:
                try:
                    await self._do_connect()
                except NotImplementedError as e:
                    # Deterministic failure: retrying with backoff would
                    # only delay the inevitable while holding the lock.
                    self._connection_attempts += 1
                    self._last_error = e
                    self._state = ConnectionState.ERROR
                    raise MCPConnectionError(
                        f"Failed to connect: {e}",
                        server_name=self.server_name,
                        url=self.config.url,
                        cause=e,
                    ) from e
                except Exception as e:
                    self._connection_attempts += 1
                    self._last_error = e
                    logger.warning(
                        "Connection attempt %d/%d failed for %s: %s",
                        self._connection_attempts,
                        self._max_attempts,
                        self.server_name,
                        e,
                    )

                    # Only sleep when another attempt will actually follow.
                    if self._connection_attempts < self._max_attempts:
                        delay = self._calculate_backoff_delay()
                        logger.debug(
                            "Retrying connection to %s in %.1fs",
                            self.server_name,
                            delay,
                        )
                        await asyncio.sleep(delay)
                else:
                    self._state = ConnectionState.CONNECTED
                    self._last_activity = time.time()
                    logger.info(
                        "Connected to MCP server: %s at %s",
                        self.server_name,
                        self.config.url,
                    )
                    return

            # All attempts failed
            self._state = ConnectionState.ERROR
            raise MCPConnectionError(
                f"Failed to connect after {self._max_attempts} attempts",
                server_name=self.server_name,
                url=self.config.url,
                cause=self._last_error,
            )

    async def _do_connect(self) -> None:
        """
        Perform the actual connection (transport-specific).

        For HTTP, a client is created and connectivity is verified with a
        GET to ``/mcp/capabilities``. The client is stored on the instance
        only after verification succeeds; on any failure it is closed so
        repeated retries do not leak clients.

        Raises:
            NotImplementedError: If the transport is not HTTP
            MCPConnectionError: If the server cannot be reached
            httpx.HTTPStatusError: If the probe returns an error status
                other than 404
        """
        if self.config.transport != TransportType.HTTP:
            # For STDIO and SSE transports, we'll implement later
            raise NotImplementedError(
                f"Transport {self.config.transport} not yet implemented"
            )

        client = httpx.AsyncClient(
            base_url=self.config.url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "User-Agent": "Syndarix-MCP-Client/1.0",
                "Accept": "application/json",
            },
        )

        verified = False
        try:
            # Try to hit the MCP capabilities endpoint
            response = await client.get("/mcp/capabilities")
            if response.status_code not in (200, 404):
                # 404 is acceptable - server might not have capabilities endpoint
                response.raise_for_status()
            verified = True
        except httpx.ConnectError as e:
            raise MCPConnectionError(
                "Failed to connect to server",
                server_name=self.server_name,
                url=self.config.url,
                cause=e,
            ) from e
        finally:
            if not verified:
                await self._discard_client(client)

        self._client = client

    async def _discard_client(self, client: httpx.AsyncClient) -> None:
        """Close a client that failed verification, ignoring close errors."""
        try:
            await client.aclose()
        except Exception as e:
            # Best effort: the original connection failure is what matters.
            logger.debug(
                "Error discarding client for %s: %s",
                self.server_name,
                e,
            )

    def _calculate_backoff_delay(self) -> float:
        """
        Calculate exponential backoff delay with jitter.

        The delay doubles with each failed attempt, starting from the base
        delay, and receives ±25% random jitter. The final value is clamped
        so it is never negative and never exceeds the configured maximum.

        Returns:
            Delay in seconds before the next connection attempt
        """
        import random

        delay = self._base_delay * (2 ** (self._connection_attempts - 1))
        delay = min(delay, self._max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        # Clamp after jitter so the configured maximum is a real upper bound.
        return max(0.0, min(delay + jitter, self._max_delay))

    async def disconnect(self) -> None:
        """
        Disconnect from the MCP server.

        Errors while closing the client are logged, not raised; the
        connection always ends up in the DISCONNECTED state.
        """
        async with self._lock:
            if self._client is not None:
                try:
                    await self._client.aclose()
                except Exception as e:
                    logger.warning(
                        "Error closing connection to %s: %s",
                        self.server_name,
                        e,
                    )
                finally:
                    self._client = None

            self._state = ConnectionState.DISCONNECTED
            logger.info("Disconnected from MCP server: %s", self.server_name)

    async def reconnect(self) -> None:
        """
        Reconnect to the MCP server.

        Raises:
            MCPConnectionError: If the new connection cannot be established
        """
        async with self._lock:
            self._state = ConnectionState.RECONNECTING
        # The lock must be released before calling disconnect()/connect():
        # both acquire it themselves and asyncio.Lock is not re-entrant.
        await self.disconnect()
        await self.connect()

    async def health_check(self) -> bool:
        """
        Perform a health check on the connection.

        For HTTP, issues a GET to ``/health`` with a 5 second timeout.
        Failures are logged and reported as unhealthy rather than raised.
        The connection state is not modified.

        Returns:
            True if connection is healthy
        """
        if not self.is_connected or self._client is None:
            return False

        try:
            if self.config.transport == TransportType.HTTP:
                response = await self._client.get(
                    "/health",
                    timeout=5.0,
                )
                return response.status_code == 200
            return True
        except Exception as e:
            logger.warning(
                "Health check failed for %s: %s",
                self.server_name,
                e,
            )
            return False

    async def execute_request(
        self,
        method: str,
        path: str,
        data: dict[str, Any] | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """
        Execute an HTTP request to the MCP server.

        Args:
            method: HTTP method (GET, POST, etc.), case-insensitive
            path: Request path
            data: Optional request body, sent as JSON (ignored for GET)
            timeout: Optional timeout override; ``None`` uses the
                configured timeout

        Returns:
            Response data (the decoded JSON body)

        Raises:
            MCPConnectionError: If not connected, the server returns an
                error status, or the request fails for any other reason
            MCPTimeoutError: If request times out
        """
        if not self.is_connected or self._client is None:
            raise MCPConnectionError(
                "Not connected to server",
                server_name=self.server_name,
            )

        # Compare against None so an explicit 0 is not silently replaced
        # by the configured default.
        effective_timeout = self.config.timeout if timeout is None else timeout
        verb = method.upper()

        try:
            if verb == "GET":
                response = await self._client.get(
                    path,
                    timeout=effective_timeout,
                )
            elif verb == "POST":
                response = await self._client.post(
                    path,
                    json=data,
                    timeout=effective_timeout,
                )
            else:
                response = await self._client.request(
                    verb,
                    path,
                    json=data,
                    timeout=effective_timeout,
                )

            # Any response, even an error status, counts as activity.
            self._last_activity = time.time()
            response.raise_for_status()
            return response.json()

        except httpx.TimeoutException as e:
            raise MCPTimeoutError(
                "Request timed out",
                server_name=self.server_name,
                timeout_seconds=effective_timeout,
                operation=f"{method} {path}",
            ) from e
        except httpx.HTTPStatusError as e:
            raise MCPConnectionError(
                f"HTTP error: {e.response.status_code}",
                server_name=self.server_name,
                url=f"{self.config.url}{path}",
                cause=e,
            ) from e
        except Exception as e:
            raise MCPConnectionError(
                f"Request failed: {e}",
                server_name=self.server_name,
                cause=e,
            ) from e
class ConnectionPool:
    """
    Pool of connections to MCP servers.

    Manages connection lifecycle and provides connection reuse. One
    connection is kept per server name.
    """

    def __init__(self, max_connections_per_server: int = 10) -> None:
        """
        Initialize connection pool.

        Args:
            max_connections_per_server: Maximum connections per server
                (stored, but not enforced by any method of this class)
        """
        self._connections: dict[str, MCPConnection] = {}
        # Guards the two dicts; per-server locks guard connection setup.
        self._lock = asyncio.Lock()
        self._per_server_locks: dict[str, asyncio.Lock] = {}
        self._max_per_server = max_connections_per_server

    def _get_server_lock(self, server_name: str) -> asyncio.Lock:
        """Get or create a lock for a specific server.

        There is no ``await`` between the lookup and the insert, so two
        coroutines can never end up with different locks for the same
        server. A lock is only allocated when one is actually missing.
        """
        server_lock = self._per_server_locks.get(server_name)
        if server_lock is None:
            server_lock = asyncio.Lock()
            self._per_server_locks[server_name] = server_lock
        return server_lock

    async def get_connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> MCPConnection:
        """
        Get or create a connection to a server.

        Uses per-server locking to avoid blocking all connections
        when establishing a new connection.

        Args:
            server_name: Name of the server
            config: Server configuration (used only when a new connection
                has to be created)

        Returns:
            Active connection

        Raises:
            MCPConnectionError: If the connection cannot be established
        """
        # Quick check without lock - if connection exists and is connected, return it
        existing = self._connections.get(server_name)
        if existing is not None and existing.is_connected:
            return existing

        # Need to create or reconnect - use per-server lock to avoid blocking others
        async with self._lock:
            server_lock = self._get_server_lock(server_name)

        async with server_lock:
            # Double-check after acquiring per-server lock
            existing = self._connections.get(server_name)
            if existing is not None:
                if not existing.is_connected:
                    # Connection exists but not connected - reconnect
                    await existing.connect()
                return existing

            # Create new connection (outside global lock, under per-server lock)
            connection = MCPConnection(server_name, config)
            await connection.connect()

            # Store connection under global lock
            async with self._lock:
                self._connections[server_name] = connection

            return connection

    async def release_connection(self, server_name: str) -> None:
        """
        Release a connection (currently a no-op).

        Args:
            server_name: Name of the server
        """
        # For now, we keep connections alive
        # Future: implement connection reaping for idle connections

    async def close_connection(self, server_name: str) -> None:
        """
        Close and remove a connection.

        Unknown server names are ignored.

        Args:
            server_name: Name of the server
        """
        async with self._lock:
            connection = self._connections.pop(server_name, None)
            if connection is not None:
                await connection.disconnect()

            # Clean up per-server lock, but only when nobody holds it.
            # Dropping a held lock would let the next get_connection()
            # create a second lock and open a duplicate connection while
            # the first one is still being established.
            server_lock = self._per_server_locks.get(server_name)
            if server_lock is not None and not server_lock.locked():
                del self._per_server_locks[server_name]

    async def close_all(self) -> None:
        """Close all connections in the pool.

        Errors from individual connections are logged and do not stop
        the remaining connections from being closed.
        """
        async with self._lock:
            for connection in self._connections.values():
                try:
                    await connection.disconnect()
                except Exception as e:
                    logger.warning("Error closing connection: %s", e)

            self._connections.clear()
            self._per_server_locks.clear()
            logger.info("Closed all MCP connections")

    async def health_check_all(self) -> dict[str, bool]:
        """
        Perform health check on all connections.

        Returns:
            Dict mapping server names to health status
        """
        # Copy connections under lock to prevent modification during iteration
        async with self._lock:
            connections_snapshot = dict(self._connections)

        # The checks run outside the lock so slow servers do not block the pool.
        results: dict[str, bool] = {}
        for name, connection in connections_snapshot.items():
            results[name] = await connection.health_check()
        return results

    def get_status(self) -> dict[str, dict[str, Any]]:
        """
        Get status of all connections.

        Returns:
            Dict mapping server names to status info with the keys
            ``state``, ``is_connected`` and ``url``
        """
        return {
            name: {
                "state": conn.state.value,
                "is_connected": conn.is_connected,
                "url": conn.config.url,
            }
            for name, conn in self._connections.items()
        }

    @asynccontextmanager
    async def connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncGenerator[MCPConnection, None]:
        """
        Context manager for getting a connection.

        The connection is released (not closed) when the block exits.

        Usage:
            async with pool.connection("server", config) as conn:
                result = await conn.execute_request("POST", "/tool", data)
        """
        conn = await self.get_connection(server_name, config)
        try:
            yield conn
        finally:
            await self.release_connection(server_name)