From e5975fa5d01ef8d4a971e2dfea94b05208c0774f Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sat, 3 Jan 2026 11:12:41 +0100 Subject: [PATCH] feat(backend): implement MCP client infrastructure (#55) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core MCP client implementation with comprehensive tooling: **Services:** - MCPClientManager: Main facade for all MCP operations - MCPServerRegistry: Thread-safe singleton for server configs - ConnectionPool: Connection pooling with auto-reconnection - ToolRouter: Automatic tool routing with circuit breaker - AsyncCircuitBreaker: Custom async-compatible circuit breaker **Configuration:** - YAML-based config with Pydantic models - Environment variable expansion support - Transport types: HTTP, SSE, STDIO **API Endpoints:** - GET /mcp/servers - List all MCP servers - GET /mcp/servers/{name}/tools - List server tools - GET /mcp/tools - List all tools from all servers - GET /mcp/health - Health check all servers - POST /mcp/call - Execute tool (admin only) - GET /mcp/circuit-breakers - Circuit breaker status - POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker - POST /mcp/servers/{name}/reconnect - Force reconnection **Testing:** - 156 unit tests with comprehensive coverage - Tests for all services, routes, and error handling - Proper mocking and async test support **Documentation:** - MCP_CLIENT.md with usage examples - Phase 2+ workflow documentation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/api/main.py | 4 + backend/app/api/routes/mcp.py | 444 +++++++++++++ backend/app/services/mcp/__init__.py | 85 +++ backend/app/services/mcp/client_manager.py | 417 ++++++++++++ backend/app/services/mcp/config.py | 234 +++++++ backend/app/services/mcp/connection.py | 435 ++++++++++++ backend/app/services/mcp/exceptions.py | 201 ++++++ backend/app/services/mcp/registry.py | 305 +++++++++ 
backend/app/services/mcp/routing.py | 619 ++++++++++++++++++ backend/docs/MCP_CLIENT.md | 324 +++++++++ backend/mcp_servers.yaml | 60 ++ backend/pyproject.toml | 20 + backend/tests/api/routes/test_mcp.py | 491 ++++++++++++++ backend/tests/services/mcp/__init__.py | 1 + .../tests/services/mcp/test_client_manager.py | 395 +++++++++++ backend/tests/services/mcp/test_config.py | 319 +++++++++ backend/tests/services/mcp/test_connection.py | 405 ++++++++++++ backend/tests/services/mcp/test_exceptions.py | 259 ++++++++ backend/tests/services/mcp/test_registry.py | 272 ++++++++ backend/tests/services/mcp/test_routing.py | 345 ++++++++++ backend/uv.lock | 63 ++ docs/development/WORKFLOW.md | 65 ++ 22 files changed, 5763 insertions(+) create mode 100644 backend/app/api/routes/mcp.py create mode 100644 backend/app/services/mcp/__init__.py create mode 100644 backend/app/services/mcp/client_manager.py create mode 100644 backend/app/services/mcp/config.py create mode 100644 backend/app/services/mcp/connection.py create mode 100644 backend/app/services/mcp/exceptions.py create mode 100644 backend/app/services/mcp/registry.py create mode 100644 backend/app/services/mcp/routing.py create mode 100644 backend/docs/MCP_CLIENT.md create mode 100644 backend/mcp_servers.yaml create mode 100644 backend/tests/api/routes/test_mcp.py create mode 100644 backend/tests/services/mcp/__init__.py create mode 100644 backend/tests/services/mcp/test_client_manager.py create mode 100644 backend/tests/services/mcp/test_config.py create mode 100644 backend/tests/services/mcp/test_connection.py create mode 100644 backend/tests/services/mcp/test_exceptions.py create mode 100644 backend/tests/services/mcp/test_registry.py create mode 100644 backend/tests/services/mcp/test_routing.py diff --git a/backend/app/api/main.py b/backend/app/api/main.py index c083321..16e2594 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -7,6 +7,7 @@ from app.api.routes import ( auth, events, issues, + mcp, 
oauth, oauth_provider, organizations, @@ -31,6 +32,9 @@ api_router.include_router( # SSE events router - no prefix, routes define full paths api_router.include_router(events.router, tags=["Events"]) +# MCP (Model Context Protocol) router +api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"]) + # Syndarix domain routers api_router.include_router(projects.router, prefix="/projects", tags=["Projects"]) api_router.include_router( diff --git a/backend/app/api/routes/mcp.py b/backend/app/api/routes/mcp.py new file mode 100644 index 0000000..0fef4a5 --- /dev/null +++ b/backend/app/api/routes/mcp.py @@ -0,0 +1,444 @@ +""" +MCP (Model Context Protocol) API Endpoints + +Provides REST endpoints for managing MCP server connections +and executing tool calls. +""" + +import logging +import re +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException, Path, status +from pydantic import BaseModel, Field + +from app.api.dependencies.permissions import require_superuser +from app.models.user import User +from app.services.mcp import ( + MCPCircuitOpenError, + MCPClientManager, + MCPConnectionError, + MCPError, + MCPServerNotFoundError, + MCPTimeoutError, + MCPToolError, + MCPToolNotFoundError, + get_mcp_client, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars +SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$") + +# Type alias for validated server name path parameter +ServerNamePath = Annotated[ + str, + Path( + description="MCP server name", + min_length=1, + max_length=64, + pattern=r"^[a-zA-Z0-9_-]+$", + ), +] + + +# ============================================================================ +# Request/Response Schemas +# ============================================================================ + + +class ServerInfo(BaseModel): + """Information about an MCP server.""" + + name: str = Field(..., description="Server name") + 
class ToolCallRequest(BaseModel):
    """Request body for ``POST /mcp/call``.

    The server name is held to the same rule as the ``{server_name}`` path
    parameters of the other endpoints (``SERVER_NAME_PATTERN``: letters,
    digits, hyphens and underscores, 1-64 characters), so a malformed name is
    rejected during request validation instead of reaching the router and the
    audit log line.
    """

    server: str = Field(
        ...,
        description="MCP server name",
        min_length=1,
        max_length=64,
        # Reuse the module-level pattern so body and path validation agree.
        pattern=SERVER_NAME_PATTERN.pattern,
    )
    tool: str = Field(
        ...,
        description="Tool name to execute",
        # An empty tool name can never resolve to a tool.
        min_length=1,
    )
    arguments: dict[str, Any] = Field(
        default_factory=dict,
        description="Tool arguments",
    )
    timeout: float | None = Field(
        None,
        description="Optional timeout override in seconds",
        # A zero or negative override could only ever time out immediately.
        gt=0,
    )
@router.get(
    "/servers",
    response_model=ServerListResponse,
    summary="List MCP Servers",
    description="Get list of all registered MCP servers with their configurations.",
)
async def list_servers(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ServerListResponse:
    """Return a configuration summary for every registered MCP server.

    Args:
        mcp: MCP client manager injected by FastAPI.

    Returns:
        The servers whose configuration could be looked up, and their count.
    """
    entries: list[ServerInfo] = []

    for server_name in mcp.list_servers():
        try:
            cfg = mcp.get_server_config(server_name)
        except MCPServerNotFoundError:
            # The name was listed but can no longer be resolved; leave it out
            # of the response rather than failing the whole listing.
            continue

        entries.append(
            ServerInfo(
                name=server_name,
                url=cfg.url,
                enabled=cfg.enabled,
                timeout=cfg.timeout,
                transport=cfg.transport.value,
                description=cfg.description,
            )
        )

    return ServerListResponse(servers=entries, total=len(entries))
@router.get(
    "/health",
    response_model=HealthCheckResponse,
    summary="Health Check",
    description="Check health status of all MCP servers.",
)
async def health_check(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> HealthCheckResponse:
    """Run a health check across all MCP servers and summarise the outcome.

    Args:
        mcp: MCP client manager injected by FastAPI.

    Returns:
        Per-server health keyed by server name, with healthy/unhealthy totals.
    """
    report = await mcp.health_check()

    servers: dict[str, ServerHealthStatus] = {}
    # The loop variable is deliberately not called ``status`` so it cannot be
    # confused with the ``fastapi.status`` module used elsewhere in this file.
    for server_name, health in report.items():
        servers[server_name] = ServerHealthStatus(
            name=health.name,
            healthy=health.healthy,
            state=health.state,
            url=health.url,
            error=health.error,
            tools_count=health.tools_count,
        )

    total = len(servers)
    healthy_count = sum(1 for entry in servers.values() if entry.healthy)

    return HealthCheckResponse(
        servers=servers,
        healthy_count=healthy_count,
        unhealthy_count=total - healthy_count,
        total=total,
    )
+ Normal tool execution should go through agent workflows. + """ + logger.info( + "Tool call by user %s: %s.%s", + current_user.id, + request.server, + request.tool, + ) + + try: + result = await mcp.call_tool( + server=request.server, + tool=request.tool, + args=request.arguments, + timeout=request.timeout, + ) + + return ToolCallResponse( + success=result.success, + data=result.data, + error=result.error, + error_code=result.error_code, + tool_name=result.tool_name, + server_name=result.server_name, + execution_time_ms=result.execution_time_ms, + request_id=result.request_id, + ) + + except MCPCircuitOpenError as e: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=f"Server temporarily unavailable: {e.server_name}", + ) from e + except MCPToolNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Tool not found: {e.tool_name}", + ) from e + except MCPServerNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Server not found: {e.server_name}", + ) from e + except MCPTimeoutError as e: + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail=str(e), + ) from e + except MCPConnectionError as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=str(e), + ) from e + except MCPToolError as e: + # Tool errors are returned in the response, not as HTTP errors + return ToolCallResponse( + success=False, + error=str(e), + error_code=e.error_code, + tool_name=e.tool_name, + server_name=e.server_name, + ) + except MCPError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) from e + + +@router.get( + "/circuit-breakers", + response_model=CircuitBreakerListResponse, + summary="List Circuit Breakers", + description="Get status of all circuit breakers.", +) +async def list_circuit_breakers( + mcp: MCPClientManager = Depends(get_mcp_client), +) -> 
@router.post(
    "/circuit-breakers/{server_name}/reset",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Reset Circuit Breaker (Admin Only)",
    description="Manually reset a circuit breaker for a server.",
)
async def reset_circuit_breaker(
    server_name: ServerNamePath,
    current_user: User = Depends(require_superuser),
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> None:
    """Reset the circuit breaker of one MCP server on behalf of a superuser.

    Args:
        server_name: Validated name of the server whose breaker is reset.
        current_user: Authenticated superuser; its id is written to the log.
        mcp: MCP client manager injected by FastAPI.

    Raises:
        HTTPException: 404 when the manager reports that nothing was reset.
    """
    logger.info(
        "Circuit breaker reset by user %s for server %s",
        current_user.id,
        server_name,
    )

    if await mcp.reset_circuit_breaker(server_name):
        # Success: the route is declared 204, so there is no body to return.
        return

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"No circuit breaker found for server: {server_name}",
    )
reconnect: {e}", + ) from e diff --git a/backend/app/services/mcp/__init__.py b/backend/app/services/mcp/__init__.py new file mode 100644 index 0000000..db407ea --- /dev/null +++ b/backend/app/services/mcp/__init__.py @@ -0,0 +1,85 @@ +""" +MCP Client Service Package + +Provides infrastructure for communicating with MCP (Model Context Protocol) +servers. This is the foundation for AI agent tool integration. + +Usage: + from app.services.mcp import get_mcp_client, MCPClientManager + + # In FastAPI route + async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)): + result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"}) + + # Direct usage + manager = MCPClientManager() + await manager.initialize() + result = await manager.call_tool("issues", "create_issue", {...}) + await manager.shutdown() +""" + +from .client_manager import ( + MCPClientManager, + ServerHealth, + get_mcp_client, + reset_mcp_client, + shutdown_mcp_client, +) +from .config import ( + MCPConfig, + MCPServerConfig, + TransportType, + create_default_config, + load_mcp_config, +) +from .connection import ConnectionPool, ConnectionState, MCPConnection +from .exceptions import ( + MCPCircuitOpenError, + MCPConnectionError, + MCPError, + MCPServerNotFoundError, + MCPTimeoutError, + MCPToolError, + MCPToolNotFoundError, + MCPValidationError, +) +from .registry import MCPServerRegistry, ServerCapabilities, get_registry +from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter + +__all__ = [ + # Main facade + "MCPClientManager", + "get_mcp_client", + "shutdown_mcp_client", + "reset_mcp_client", + "ServerHealth", + # Configuration + "MCPConfig", + "MCPServerConfig", + "TransportType", + "load_mcp_config", + "create_default_config", + # Registry + "MCPServerRegistry", + "ServerCapabilities", + "get_registry", + # Connection + "ConnectionPool", + "ConnectionState", + "MCPConnection", + # Routing + "ToolRouter", + "ToolInfo", + "ToolResult", + 
"AsyncCircuitBreaker", + "CircuitState", + # Exceptions + "MCPError", + "MCPConnectionError", + "MCPTimeoutError", + "MCPToolError", + "MCPServerNotFoundError", + "MCPToolNotFoundError", + "MCPCircuitOpenError", + "MCPValidationError", +] diff --git a/backend/app/services/mcp/client_manager.py b/backend/app/services/mcp/client_manager.py new file mode 100644 index 0000000..7f80a7c --- /dev/null +++ b/backend/app/services/mcp/client_manager.py @@ -0,0 +1,417 @@ +""" +MCP Client Manager + +Main facade for all MCP operations. Manages server connections, +tool discovery, and provides a unified interface for tool calls. +""" + +import asyncio +import logging +from dataclasses import dataclass +from typing import Any + +from .config import MCPConfig, MCPServerConfig, load_mcp_config +from .connection import ConnectionPool, ConnectionState +from .exceptions import MCPServerNotFoundError +from .registry import MCPServerRegistry, get_registry +from .routing import ToolInfo, ToolResult, ToolRouter + +logger = logging.getLogger(__name__) + + +@dataclass +class ServerHealth: + """Health status for an MCP server.""" + + name: str + healthy: bool + state: str + url: str + error: str | None = None + tools_count: int = 0 + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "name": self.name, + "healthy": self.healthy, + "state": self.state, + "url": self.url, + "error": self.error, + "tools_count": self.tools_count, + } + + +class MCPClientManager: + """ + Central manager for all MCP client operations. + + Provides a unified interface for: + - Connecting to MCP servers + - Discovering and calling tools + - Health monitoring + - Connection lifecycle management + + This is the main entry point for MCP operations in the application. + """ + + def __init__( + self, + config: MCPConfig | None = None, + registry: MCPServerRegistry | None = None, + ) -> None: + """ + Initialize the MCP client manager. + + Args: + config: Optional MCP configuration. 
async def initialize(self, config: MCPConfig | None = None) -> None:
    """
    Initialize the MCP client manager.

    Loads configuration, creates the tool router, connects to every enabled
    server and discovers their tools. The whole sequence runs under the
    manager's lock; a call made after a successful initialization only logs
    a warning and returns.

    Args:
        config: Optional configuration to load. When omitted and the
            registry has no servers yet, the default configuration is
            loaded via ``load_mcp_config()``.
    """
    async with self._lock:
        if self._initialized:
            logger.warning("MCPClientManager already initialized")
            return

        logger.info("Initializing MCP Client Manager")

        # Explicit config wins; otherwise fall back to the default location,
        # but only when nothing has been registered so far.
        if config is not None:
            self._registry.load_config(config)
        elif not self._registry.list_servers():
            self._registry.load_config(load_mcp_config())

        router = ToolRouter(self._registry, self._pool)
        self._router = router

        # Connection failures are logged per server and do not abort startup.
        await self._connect_all_servers()

        # Discovery is unconditional: the router was created just above, so a
        # truthiness check on it could only ever skip discovery by accident.
        await router.discover_tools()

        self._initialized = True
        logger.info(
            "MCP Client Manager initialized with %d servers",
            len(self._registry.list_enabled_servers()),
        )
async def call_tool(
    self,
    server: str,
    tool: str,
    args: dict[str, Any] | None = None,
    timeout: float | None = None,
) -> ToolResult:
    """
    Execute a named tool on a named MCP server.

    The manager is initialized on demand if that has not happened yet.

    Args:
        server: Name of the MCP server
        tool: Name of the tool to call
        args: Tool arguments
        timeout: Optional timeout override

    Returns:
        Tool execution result
    """
    ready = self._initialized and self._router is not None
    if not ready:
        await self.initialize()

    router = self._router
    assert router is not None  # Guaranteed after initialize()

    result = await router.call_tool(
        server_name=server,
        tool_name=tool,
        arguments=args,
        timeout=timeout,
    )
    return result
async def health_check(self) -> dict[str, ServerHealth]:
    """
    Collect a health snapshot for every registered MCP server.

    Pool status and pool health are fetched once, then combined per server
    with the registry's configuration and cached capabilities.

    Returns:
        Dict mapping server names to health status
    """
    connection_info = self._pool.get_status()
    liveness = await self._pool.health_check_all()

    report: dict[str, ServerHealth] = {}

    for name in self._registry.list_servers():
        try:
            server_config = self._registry.get(name)
            info = connection_info.get(name, {})
            is_healthy = liveness.get(name, False)

            cached = self._registry.get_cached_capabilities(name)

            entry = ServerHealth(
                name=name,
                healthy=is_healthy,
                # No pool entry for this server is reported as disconnected.
                state=info.get("state", ConnectionState.DISCONNECTED.value),
                url=server_config.url,
                tools_count=len(cached.tools),
            )
        except MCPServerNotFoundError:
            # The server is no longer in the registry; omit it from the report.
            continue
        except Exception as exc:
            # Any other failure is reported as an unhealthy entry, not raised.
            entry = ServerHealth(
                name=name,
                healthy=False,
                state=ConnectionState.ERROR.value,
                url="unknown",
                error=str(exc),
            )

        report[name] = entry

    return report
async def get_mcp_client() -> MCPClientManager:
    """
    Get the global MCP client manager instance.

    This is the main dependency injection point for FastAPI. The check and
    the creation both run under ``_manager_lock``, so concurrent callers
    cannot create two managers.

    The new manager is stored in the module-level singleton only after
    ``initialize()`` has succeeded. If initialization raises, connections
    opened so far are closed, the exception propagates, and the next call
    starts over with a fresh manager instead of handing out a half-built one.

    Returns:
        The shared, initialized manager.
    """
    global _manager_instance

    async with _manager_lock:
        if _manager_instance is None:
            manager = MCPClientManager()
            try:
                await manager.initialize()
            except Exception:
                # shutdown() is a no-op before initialization completes, so
                # close the pool directly to avoid leaking connections on
                # every failed attempt.
                await manager.disconnect_all()
                raise
            _manager_instance = manager

        return _manager_instance
class MCPServerConfig(BaseModel):
    """Configuration for a single MCP server.

    Numeric fields carry range constraints enforced by Pydantic. The ``url``
    field has ``${VAR}`` / ``${VAR:-default}`` placeholders expanded from the
    process environment before validation.
    """

    url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
    transport: TransportType = Field(
        default=TransportType.HTTP,
        description="Transport protocol to use",
    )
    timeout: int = Field(
        default=30,
        ge=1,
        le=600,
        description="Request timeout in seconds",
    )
    retry_attempts: int = Field(
        default=3,
        ge=0,
        le=10,
        description="Number of retry attempts on failure",
    )
    retry_delay: float = Field(
        default=1.0,
        ge=0.1,
        le=60.0,
        description="Initial delay between retries in seconds",
    )
    retry_max_delay: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Maximum delay between retries in seconds",
    )
    circuit_breaker_threshold: int = Field(
        default=5,
        ge=1,
        le=50,
        description="Number of failures before opening circuit",
    )
    circuit_breaker_timeout: float = Field(
        default=30.0,
        ge=5.0,
        le=300.0,
        description="Seconds to wait before attempting to close circuit",
    )
    enabled: bool = Field(
        default=True,
        description="Whether this server is enabled",
    )
    description: str | None = Field(
        default=None,
        description="Human-readable description of the server",
    )

    @field_validator("url", mode="before")
    @classmethod
    def expand_env_vars(cls, v: Any) -> Any:
        """Expand ``${VAR}`` and ``${VAR:-default}`` placeholders in the URL.

        Substitution is done in a single pass over the original string, so
        text that came from an environment variable is never scanned for
        further placeholders. As with the shell ``:-`` operator, the default
        applies when the variable is unset or empty. A placeholder without a
        default expands to an empty string when the variable is unset.

        Args:
            v: Raw field value. Non-string values are returned unchanged so
                that Pydantic's own type validation reports them.

        Returns:
            The URL with every placeholder replaced.
        """
        if not isinstance(v, str):
            return v

        import re

        def _resolve(match: "re.Match[str]") -> str:
            # Split "NAME:-default" at the first ":-"; without a separator
            # the default is the empty string.
            var_name, _, default = match.group(1).partition(":-")
            return os.environ.get(var_name.strip()) or default

        return re.sub(r"\$\{([^}]+)\}", _resolve, v)
def create_default_config() -> MCPConfig:
    """
    Build the stock MCP configuration used for development.

    Four HTTP servers are declared (LLM gateway, knowledge base, git
    operations and issue tracker). Each URL is an environment placeholder
    with a localhost fallback, so the result also serves as a template.
    """
    http = TransportType.HTTP

    servers = {
        "llm-gateway": MCPServerConfig(
            url="${LLM_GATEWAY_URL:-http://localhost:8001}",
            transport=http,
            timeout=60,
            description="LLM Gateway for multi-provider AI interactions",
        ),
        "knowledge-base": MCPServerConfig(
            url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
            transport=http,
            timeout=30,
            description="Knowledge Base for RAG and document retrieval",
        ),
        "git-ops": MCPServerConfig(
            url="${GIT_OPS_URL:-http://localhost:8003}",
            transport=http,
            timeout=120,
            description="Git Operations for repository management",
        ),
        "issues": MCPServerConfig(
            url="${ISSUES_URL:-http://localhost:8004}",
            transport=http,
            timeout=30,
            description="Issue Tracker for Gitea/GitHub/GitLab",
        ),
    }

    return MCPConfig(
        mcp_servers=servers,
        default_timeout=30,
        default_retry_attempts=3,
        connection_pool_size=10,
        health_check_interval=30,
    )
async def connect(self) -> None:
    """
    Establish connection to the MCP server.

    Attempts are serialized under ``self._lock``. The call returns at once
    when the connection is already in the CONNECTED state. Otherwise
    ``_do_connect`` is tried up to ``retry_attempts`` times (always at least
    once), sleeping for a jittered exponential backoff between attempts.

    Raises:
        MCPConnectionError: If connection fails after all retries. The last
            underlying error is available both as ``cause`` and as the
            chained ``__cause__``.
    """
    async with self._lock:
        if self._state == ConnectionState.CONNECTED:
            return

        self._state = ConnectionState.CONNECTING
        self._connection_attempts = 0
        self._last_error = None

        # The config allows retry_attempts == 0. Without this floor the loop
        # below would never run and no connection would ever be attempted.
        max_attempts = max(1, self._max_attempts)

        while self._connection_attempts < max_attempts:
            try:
                await self._do_connect()
            except Exception as e:
                self._connection_attempts += 1
                self._last_error = e
                logger.warning(
                    "Connection attempt %d/%d failed for %s: %s",
                    self._connection_attempts,
                    max_attempts,
                    self.server_name,
                    e,
                )

                # _do_connect assigns self._client before verifying it, so a
                # failed attempt leaves an open client behind. Close it here,
                # otherwise the next attempt overwrites and leaks it.
                stale_client, self._client = self._client, None
                if stale_client is not None:
                    try:
                        await stale_client.aclose()
                    except Exception as close_error:
                        logger.debug(
                            "Error closing failed client for %s: %s",
                            self.server_name,
                            close_error,
                        )

                if self._connection_attempts < max_attempts:
                    delay = self._calculate_backoff_delay()
                    logger.debug(
                        "Retrying connection to %s in %.1fs",
                        self.server_name,
                        delay,
                    )
                    await asyncio.sleep(delay)
            else:
                self._state = ConnectionState.CONNECTED
                self._last_activity = time.time()
                logger.info(
                    "Connected to MCP server: %s at %s",
                    self.server_name,
                    self.config.url,
                )
                return

        # All attempts failed
        self._state = ConnectionState.ERROR
        raise MCPConnectionError(
            f"Failed to connect after {max_attempts} attempts",
            server_name=self.server_name,
            url=self.config.url,
            cause=self._last_error,
        ) from self._last_error
async def reconnect(self) -> None:
    """
    Reconnect to the MCP server.

    Marks the connection as RECONNECTING, closes the current client (if
    any) and then runs the normal connect-with-retries sequence.

    Raises:
        MCPConnectionError: Propagated from ``connect()`` when the server
            cannot be reached again.
    """
    # Do not hold self._lock here. disconnect() and connect() each acquire
    # it themselves and asyncio.Lock is not reentrant, so wrapping these
    # calls in `async with self._lock` would block forever.
    self._state = ConnectionState.RECONNECTING
    await self.disconnect()
    await self.connect()
class ConnectionPool:
    """
    Pool of connections to MCP servers.

    Holds one ``MCPConnection`` per server name and hands the same object
    back on every request, reconnecting it when it is no longer connected.
    Changes to the internal mapping are serialized by an ``asyncio.Lock``.
    """

    def __init__(self, max_connections_per_server: int = 10) -> None:
        """
        Initialize connection pool.

        Args:
            max_connections_per_server: Maximum connections per server.
                The value is stored but not enforced by this class.
        """
        self._connections: dict[str, MCPConnection] = {}
        self._lock = asyncio.Lock()
        self._max_per_server = max_connections_per_server

    async def get_connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> MCPConnection:
        """
        Get or create a connection to a server.

        Args:
            server_name: Name of the server
            config: Server configuration (used only when a new connection
                has to be created)

        Returns:
            Active connection

        Raises:
            MCPConnectionError: Propagated from ``MCPConnection.connect()``.
        """
        async with self._lock:
            if server_name not in self._connections:
                connection = MCPConnection(server_name, config)
                await connection.connect()
                # Only a successfully connected object is stored.
                self._connections[server_name] = connection

            connection = self._connections[server_name]

            # Reconnect if not connected
            if not connection.is_connected:
                await connection.connect()

            return connection

    async def release_connection(self, server_name: str) -> None:
        """
        Release a connection (currently a no-op).

        Args:
            server_name: Name of the server
        """
        # For now, we keep connections alive
        # Future: implement connection reaping for idle connections

    async def close_connection(self, server_name: str) -> None:
        """
        Close and remove a connection.

        Unknown server names are ignored.

        Args:
            server_name: Name of the server
        """
        async with self._lock:
            connection = self._connections.get(server_name)
            if connection is not None:
                await connection.disconnect()
                del self._connections[server_name]

    async def close_all(self) -> None:
        """Close all connections in the pool and empty it."""
        async with self._lock:
            for connection in self._connections.values():
                try:
                    await connection.disconnect()
                except Exception as e:
                    # Best effort: one failing close must not stop the rest.
                    logger.warning("Error closing connection: %s", e)

            self._connections.clear()
            logger.info("Closed all MCP connections")

    async def health_check_all(self) -> dict[str, bool]:
        """
        Perform health check on all connections.

        Returns:
            Dict mapping server names to health status
        """
        # Iterate over a snapshot. Each health check awaits, and another task
        # may add or remove a connection in the meantime; looping over the
        # live dict would then raise "dictionary changed size during
        # iteration".
        snapshot = list(self._connections.items())

        results: dict[str, bool] = {}
        for name, connection in snapshot:
            results[name] = await connection.health_check()
        return results

    def get_status(self) -> dict[str, dict[str, Any]]:
        """
        Get status of all connections.

        Returns:
            Dict mapping server names to status info (``state``,
            ``is_connected`` and ``url``)
        """
        return {
            name: {
                "state": conn.state.value,
                "is_connected": conn.is_connected,
                "url": conn.config.url,
            }
            for name, conn in self._connections.items()
        }

    @asynccontextmanager
    async def connection(
        self,
        server_name: str,
        config: MCPServerConfig,
    ) -> AsyncGenerator[MCPConnection, None]:
        """
        Context manager for getting a connection.

        The connection is released (see ``release_connection``) when the
        block exits, whether or not it raised.

        Usage:
            async with pool.connection("server", config) as conn:
                result = await conn.execute_request("POST", "/tool", data)
        """
        conn = await self.get_connection(server_name, config)
        try:
            yield conn
        finally:
            await self.release_connection(server_name)
+""" + +from typing import Any + + +class MCPError(Exception): + """Base exception for all MCP-related errors.""" + + def __init__( + self, + message: str, + *, + server_name: str | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message) + self.message = message + self.server_name = server_name + self.details = details or {} + + def __str__(self) -> str: + parts = [self.message] + if self.server_name: + parts.append(f"server={self.server_name}") + if self.details: + parts.append(f"details={self.details}") + return " | ".join(parts) + + +class MCPConnectionError(MCPError): + """Raised when connection to an MCP server fails.""" + + def __init__( + self, + message: str, + *, + server_name: str | None = None, + url: str | None = None, + cause: Exception | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message, server_name=server_name, details=details) + self.url = url + self.cause = cause + + def __str__(self) -> str: + base = super().__str__() + if self.url: + base = f"{base} | url={self.url}" + if self.cause: + base = f"{base} | cause={type(self.cause).__name__}: {self.cause}" + return base + + +class MCPTimeoutError(MCPError): + """Raised when an MCP operation times out.""" + + def __init__( + self, + message: str, + *, + server_name: str | None = None, + timeout_seconds: float | None = None, + operation: str | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message, server_name=server_name, details=details) + self.timeout_seconds = timeout_seconds + self.operation = operation + + def __str__(self) -> str: + base = super().__str__() + if self.timeout_seconds is not None: + base = f"{base} | timeout={self.timeout_seconds}s" + if self.operation: + base = f"{base} | operation={self.operation}" + return base + + +class MCPToolError(MCPError): + """Raised when a tool execution fails.""" + + def __init__( + self, + message: str, + *, + server_name: str | 
None = None, + tool_name: str | None = None, + tool_args: dict[str, Any] | None = None, + error_code: str | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message, server_name=server_name, details=details) + self.tool_name = tool_name + self.tool_args = tool_args + self.error_code = error_code + + def __str__(self) -> str: + base = super().__str__() + if self.tool_name: + base = f"{base} | tool={self.tool_name}" + if self.error_code: + base = f"{base} | error_code={self.error_code}" + return base + + +class MCPServerNotFoundError(MCPError): + """Raised when a requested MCP server is not registered.""" + + def __init__( + self, + server_name: str, + *, + available_servers: list[str] | None = None, + details: dict[str, Any] | None = None, + ) -> None: + message = f"MCP server not found: {server_name}" + super().__init__(message, server_name=server_name, details=details) + self.available_servers = available_servers or [] + + def __str__(self) -> str: + base = super().__str__() + if self.available_servers: + base = f"{base} | available={self.available_servers}" + return base + + +class MCPToolNotFoundError(MCPError): + """Raised when a requested tool is not found on any server.""" + + def __init__( + self, + tool_name: str, + *, + server_name: str | None = None, + available_tools: list[str] | None = None, + details: dict[str, Any] | None = None, + ) -> None: + message = f"Tool not found: {tool_name}" + super().__init__(message, server_name=server_name, details=details) + self.tool_name = tool_name + self.available_tools = available_tools or [] + + def __str__(self) -> str: + base = super().__str__() + if self.available_tools: + base = f"{base} | available_tools={self.available_tools[:5]}..." 
+ return base + + +class MCPCircuitOpenError(MCPError): + """Raised when a circuit breaker is open (server temporarily unavailable).""" + + def __init__( + self, + server_name: str, + *, + failure_count: int | None = None, + reset_timeout: float | None = None, + details: dict[str, Any] | None = None, + ) -> None: + message = f"Circuit breaker open for server: {server_name}" + super().__init__(message, server_name=server_name, details=details) + self.failure_count = failure_count + self.reset_timeout = reset_timeout + + def __str__(self) -> str: + base = super().__str__() + if self.failure_count is not None: + base = f"{base} | failures={self.failure_count}" + if self.reset_timeout is not None: + base = f"{base} | reset_in={self.reset_timeout}s" + return base + + +class MCPValidationError(MCPError): + """Raised when tool arguments fail validation.""" + + def __init__( + self, + message: str, + *, + tool_name: str | None = None, + field_errors: dict[str, str] | None = None, + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message, details=details) + self.tool_name = tool_name + self.field_errors = field_errors or {} + + def __str__(self) -> str: + base = super().__str__() + if self.tool_name: + base = f"{base} | tool={self.tool_name}" + if self.field_errors: + base = f"{base} | fields={list(self.field_errors.keys())}" + return base diff --git a/backend/app/services/mcp/registry.py b/backend/app/services/mcp/registry.py new file mode 100644 index 0000000..04154a5 --- /dev/null +++ b/backend/app/services/mcp/registry.py @@ -0,0 +1,305 @@ +""" +MCP Server Registry + +Thread-safe singleton registry for managing MCP server configurations +and their capabilities. 
+""" + +import asyncio +import logging +from threading import Lock +from typing import Any + +from .config import MCPConfig, MCPServerConfig, load_mcp_config +from .exceptions import MCPServerNotFoundError + +logger = logging.getLogger(__name__) + + +class ServerCapabilities: + """Cached capabilities for an MCP server.""" + + def __init__( + self, + tools: list[dict[str, Any]] | None = None, + resources: list[dict[str, Any]] | None = None, + prompts: list[dict[str, Any]] | None = None, + ) -> None: + self.tools = tools or [] + self.resources = resources or [] + self.prompts = prompts or [] + self._loaded = False + self._load_time: float | None = None + + @property + def is_loaded(self) -> bool: + """Check if capabilities have been loaded.""" + return self._loaded + + @property + def tool_names(self) -> list[str]: + """Get list of tool names.""" + return [t.get("name", "") for t in self.tools if t.get("name")] + + def mark_loaded(self) -> None: + """Mark capabilities as loaded.""" + import time + + self._loaded = True + self._load_time = time.time() + + +class MCPServerRegistry: + """ + Thread-safe singleton registry for MCP servers. + + Manages server configurations and caches their capabilities. 
def load_config(self, config: MCPConfig | None = None) -> None:
    """
    Replace the registry's configuration and drop all cached capabilities.

    Args:
        config: Configuration to install. When omitted, ``load_mcp_config()``
            is called with no arguments to obtain one.
    """
    self._config = load_mcp_config() if config is None else config
    # Capabilities cached for the previous configuration no longer apply.
    self._capabilities.clear()

    installed = self._config
    logger.info(
        "Loaded MCP configuration with %d servers",
        len(installed.mcp_servers),
    )
    for server_name in installed.list_server_names():
        logger.debug("Registered MCP server: %s", server_name)
async def get_capabilities(
    self,
    name: str,
    force_refresh: bool = False,
) -> ServerCapabilities:
    """
    Return the cached capabilities entry for a server.

    When no entry exists yet, or ``force_refresh`` is set, an empty
    ``ServerCapabilities`` placeholder is stored first; ``set_capabilities``
    fills in the real data.

    Args:
        name: Server name
        force_refresh: If True, replace the cached entry with a new
            placeholder

    Returns:
        Server capabilities

    Raises:
        MCPServerNotFoundError: If server is not registered
    """
    # Called only for its side effect: it raises for unknown servers.
    self.get(name)

    async with self._capabilities_lock:
        needs_placeholder = force_refresh or name not in self._capabilities
        if needs_placeholder:
            self._capabilities[name] = ServerCapabilities()

        return self._capabilities[name]
class AsyncCircuitBreaker:
    """
    Circuit breaker for async call sites.

    Callers report the outcome of each call through ``success()`` and
    ``failure()``. After ``fail_max`` failures without an intervening
    success the breaker opens; once ``reset_timeout`` seconds have elapsed
    since the last failure it is reported as half-open and ``is_open()``
    lets calls through again.
    """

    def __init__(
        self,
        fail_max: int = 5,
        reset_timeout: float = 30.0,
        name: str = "",
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            fail_max: Maximum failures before opening circuit
            reset_timeout: Seconds to wait before trying again
            name: Name for logging
        """
        self.name = name
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self._lock = asyncio.Lock()
        self._state = CircuitState.CLOSED
        self._fail_counter = 0
        self._last_failure_time: float | None = None

    @property
    def current_state(self) -> str:
        """Current state as a string; an expired OPEN reads as half-open."""
        if self._state == CircuitState.OPEN and self._should_try_reset():
            return CircuitState.HALF_OPEN.value
        return self._state.value

    @property
    def fail_counter(self) -> int:
        """Number of failures recorded since the last success or reset."""
        return self._fail_counter

    def _should_try_reset(self) -> bool:
        """Return True once ``reset_timeout`` has elapsed since the last failure."""
        last_failure = self._last_failure_time
        if last_failure is None:
            return True
        elapsed = time.time() - last_failure
        return elapsed >= self.reset_timeout

    async def success(self) -> None:
        """Record a successful call: close the circuit and clear the counters."""
        async with self._lock:
            self._state = CircuitState.CLOSED
            self._fail_counter = 0
            self._last_failure_time = None

    async def failure(self) -> None:
        """Record a failed call and open the circuit at the threshold."""
        async with self._lock:
            self._last_failure_time = time.time()
            self._fail_counter += 1

            if self._fail_counter < self.fail_max:
                return

            self._state = CircuitState.OPEN
            logger.warning(
                "Circuit breaker %s opened after %d failures",
                self.name,
                self._fail_counter,
            )

    def is_open(self) -> bool:
        """Return True while the circuit is open and the reset timeout has not elapsed."""
        return self._state == CircuitState.OPEN and not self._should_try_reset()

    async def reset(self) -> None:
        """Manually return the breaker to the closed state."""
        async with self._lock:
            self._last_failure_time = None
            self._fail_counter = 0
            self._state = CircuitState.CLOSED
def _get_circuit_breaker(
    self,
    server_name: str,
    config: MCPServerConfig,
) -> AsyncCircuitBreaker:
    """Return the breaker for ``server_name``, creating it on first use.

    The threshold and timeout are read from ``config`` only when the breaker
    is created; later calls return the existing breaker unchanged.
    """
    breaker = self._circuit_breakers.get(server_name)
    if breaker is None:
        breaker = AsyncCircuitBreaker(
            fail_max=config.circuit_breaker_threshold,
            reset_timeout=config.circuit_breaker_timeout,
            name=f"mcp-{server_name}",
        )
        self._circuit_breakers[server_name] = breaker
    return breaker
async def _fetch_tools_from_server(
    self,
    connection: MCPConnection,
) -> list[ToolInfo]:
    """
    Fetch available tools from an MCP server.

    Issues ``GET /mcp/tools`` on the connection and converts each entry of
    the response's ``tools`` list into a ``ToolInfo``.

    Args:
        connection: Connection to query

    Returns:
        The tools advertised by the server. Entries without a name are
        skipped. Any error is logged and yields an empty list (best-effort
        discovery).
    """
    try:
        response = await connection.execute_request(
            "GET",
            "/mcp/tools",
        )

        tools: list[ToolInfo] = []
        # `or []` also covers an explicit `"tools": null` in the response.
        for tool_data in response.get("tools") or []:
            name = tool_data.get("name")
            if not name:
                # A nameless tool cannot be routed. Registering it would put
                # an empty-string key in the tool-to-server mapping, and
                # ServerCapabilities.tool_names already filters such entries.
                logger.debug(
                    "Skipping unnamed tool entry from %s",
                    connection.server_name,
                )
                continue

            tools.append(
                ToolInfo(
                    name=name,
                    description=tool_data.get("description"),
                    server_name=connection.server_name,
                    input_schema=tool_data.get("inputSchema"),
                )
            )
        return tools
    except Exception as e:
        logger.warning(
            "Error fetching tools from %s: %s",
            connection.server_name,
            e,
        )
        return []
+ + Args: + server_name: Name of the MCP server + tool_name: Name of the tool to call + arguments: Tool arguments + timeout: Optional timeout override + + Returns: + Tool execution result + """ + start_time = time.time() + request_id = str(uuid.uuid4()) + + logger.debug( + "Tool call [%s]: %s.%s with args %s", + request_id, + server_name, + tool_name, + arguments, + ) + + try: + config = self._registry.get(server_name) + circuit_breaker = self._get_circuit_breaker(server_name, config) + + # Check circuit breaker state + if circuit_breaker.is_open(): + raise MCPCircuitOpenError( + server_name=server_name, + failure_count=circuit_breaker.fail_counter, + reset_timeout=config.circuit_breaker_timeout, + ) + + # Execute with retry logic + result = await self._execute_with_retry( + server_name=server_name, + config=config, + tool_name=tool_name, + arguments=arguments or {}, + timeout=timeout, + circuit_breaker=circuit_breaker, + ) + + execution_time = (time.time() - start_time) * 1000 + + return ToolResult( + success=True, + data=result, + tool_name=tool_name, + server_name=server_name, + execution_time_ms=execution_time, + request_id=request_id, + ) + + except MCPCircuitOpenError: + raise + except MCPError as e: + execution_time = (time.time() - start_time) * 1000 + logger.error( + "Tool call failed [%s]: %s.%s - %s", + request_id, + server_name, + tool_name, + e, + ) + return ToolResult( + success=False, + error=str(e), + error_code=type(e).__name__, + tool_name=tool_name, + server_name=server_name, + execution_time_ms=execution_time, + request_id=request_id, + ) + except Exception as e: + execution_time = (time.time() - start_time) * 1000 + logger.exception( + "Unexpected error in tool call [%s]: %s.%s", + request_id, + server_name, + tool_name, + ) + return ToolResult( + success=False, + error=str(e), + error_code="UnexpectedError", + tool_name=tool_name, + server_name=server_name, + execution_time_ms=execution_time, + request_id=request_id, + ) + + async def 
_execute_with_retry( + self, + server_name: str, + config: MCPServerConfig, + tool_name: str, + arguments: dict[str, Any], + timeout: float | None, + circuit_breaker: AsyncCircuitBreaker, + ) -> Any: + """Execute tool call with retry logic.""" + last_error: Exception | None = None + attempts = 0 + max_attempts = config.retry_attempts + 1 # +1 for initial attempt + + while attempts < max_attempts: + attempts += 1 + + try: + # Use circuit breaker to track failures + result = await self._execute_tool_call( + server_name=server_name, + config=config, + tool_name=tool_name, + arguments=arguments, + timeout=timeout, + ) + + # Success - record it + await circuit_breaker.success() + return result + + except MCPCircuitOpenError: + raise + except MCPTimeoutError: + # Timeout - don't retry + await circuit_breaker.failure() + raise + except MCPToolError: + # Tool-level error - don't retry (user error) + raise + except Exception as e: + last_error = e + await circuit_breaker.failure() + + if attempts < max_attempts: + delay = self._calculate_retry_delay(attempts, config) + logger.warning( + "Tool call attempt %d/%d failed for %s.%s: %s. 
" + "Retrying in %.1fs", + attempts, + max_attempts, + server_name, + tool_name, + e, + delay, + ) + await asyncio.sleep(delay) + + # All attempts failed + raise MCPToolError( + f"Tool call failed after {max_attempts} attempts", + server_name=server_name, + tool_name=tool_name, + tool_args=arguments, + details={"last_error": str(last_error)}, + ) + + def _calculate_retry_delay( + self, + attempt: int, + config: MCPServerConfig, + ) -> float: + """Calculate exponential backoff delay with jitter.""" + import random + + delay = config.retry_delay * (2 ** (attempt - 1)) + delay = min(delay, config.retry_max_delay) + # Add jitter (±25%) + jitter = delay * 0.25 * (random.random() * 2 - 1) + return max(0.1, delay + jitter) + + async def _execute_tool_call( + self, + server_name: str, + config: MCPServerConfig, + tool_name: str, + arguments: dict[str, Any], + timeout: float | None, + ) -> Any: + """Execute a single tool call.""" + connection = await self._pool.get_connection(server_name, config) + + # Build MCP tool call request + request_body = { + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments, + }, + "id": str(uuid.uuid4()), + } + + response = await connection.execute_request( + method="POST", + path="/mcp", + data=request_body, + timeout=timeout, + ) + + # Handle JSON-RPC response + if "error" in response: + error = response["error"] + raise MCPToolError( + error.get("message", "Tool execution failed"), + server_name=server_name, + tool_name=tool_name, + tool_args=arguments, + error_code=str(error.get("code", "UNKNOWN")), + ) + + return response.get("result") + + async def route_tool( + self, + tool_name: str, + arguments: dict[str, Any] | None = None, + timeout: float | None = None, + ) -> ToolResult: + """ + Route a tool call to the appropriate server. + + Automatically discovers which server provides the tool. 
async def list_all_tools(self) -> list[ToolInfo]:
    """
    Get all available tools from all servers.

    Reads the per-server tool dictionaries held by the registry and wraps
    each one in a ``ToolInfo`` tagged with the server it came from.

    Returns:
        List of tool information
    """
    tools_by_server = self._registry.get_all_tools()
    return [
        ToolInfo(
            name=entry.get("name", ""),
            description=entry.get("description"),
            server_name=owner,
            input_schema=entry.get("input_schema"),
        )
        for owner, entries in tools_by_server.items()
        for entry in entries
    ]
+ + Args: + server_name: Name of the server + + Returns: + True if circuit breaker was reset + """ + async with self._lock: + if server_name in self._circuit_breakers: + # Reset by removing (will be recreated on next call) + del self._circuit_breakers[server_name] + logger.info("Reset circuit breaker for %s", server_name) + return True + return False diff --git a/backend/docs/MCP_CLIENT.md b/backend/docs/MCP_CLIENT.md new file mode 100644 index 0000000..855cd02 --- /dev/null +++ b/backend/docs/MCP_CLIENT.md @@ -0,0 +1,324 @@ +# MCP Client Infrastructure + +This document describes the Model Context Protocol (MCP) client infrastructure used by Syndarix to communicate with AI agent tools. + +## Overview + +The MCP client infrastructure provides a robust, fault-tolerant layer for communicating with MCP servers. It enables AI agents to discover and execute tools provided by various services (LLM Gateway, Knowledge Base, Git Operations, Issue Tracker, etc.). + +## Architecture + +``` +┌────────────────────────────────────────────────────────────────────────┐ +│ MCPClientManager │ +│ (Main Facade Class) │ +├────────────────────────────────────────────────────────────────────────┤ +│ - initialize() / shutdown() │ +│ - call_tool() / route_tool() │ +│ - connect() / disconnect() │ +│ - health_check() / list_tools() │ +└─────────────┬────────────────────┬─────────────────┬───────────────────┘ + │ │ │ + ▼ ▼ ▼ +┌─────────────────────┐ ┌─────────────────┐ ┌──────────────────────────┐ +│ MCPServerRegistry │ │ ConnectionPool │ │ ToolRouter │ +│ (Singleton) │ │ │ │ │ +├─────────────────────┤ ├─────────────────┤ ├──────────────────────────┤ +│ - Server configs │ │ - Connection │ │ - Tool → Server mapping │ +│ - Capabilities │ │ management │ │ - Circuit breakers │ +│ - Tool discovery │ │ - Auto reconnect│ │ - Retry logic │ +└─────────────────────┘ └─────────────────┘ └──────────────────────────┘ +``` + +## Components + +### MCPClientManager + +The main entry point for all MCP 
operations. Provides a clean facade over the underlying infrastructure. + +```python +from app.services.mcp import get_mcp_client, MCPClientManager + +# In FastAPI dependency injection +async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)): + result = await mcp.call_tool( + server="llm-gateway", + tool="chat", + args={"prompt": "Hello"} + ) + return result.data + +# Direct usage +manager = MCPClientManager() +await manager.initialize() + +# Execute a tool +result = await manager.call_tool( + server="issues", + tool="create_issue", + args={"title": "New Feature", "body": "Description"} +) + +await manager.shutdown() +``` + +### Configuration + +Configuration is loaded from YAML files and supports environment variable expansion: + +```yaml +# mcp_servers.yaml +mcp_servers: + llm-gateway: + url: ${LLM_GATEWAY_URL:-http://localhost:8001} + timeout: 60 + transport: http + enabled: true + retry_attempts: 3 + circuit_breaker_threshold: 5 + circuit_breaker_timeout: 30.0 + + knowledge-base: + url: ${KNOWLEDGE_BASE_URL:-http://localhost:8002} + timeout: 30 + enabled: true + +default_timeout: 30 +connection_pool_size: 10 +health_check_interval: 30 +``` + +**Environment Variable Syntax:** +- `${VAR_NAME}` - Uses the environment variable value +- `${VAR_NAME:-default}` - Uses default if variable is not set + +### Connection Management + +The `ConnectionPool` manages connections to MCP servers with: + +- **Connection Reuse**: Connections are pooled and reused +- **Auto Reconnection**: Failed connections are automatically retried +- **Health Checks**: Periodic health checks detect unhealthy servers +- **Exponential Backoff**: Retry delays increase exponentially with jitter + +```python +from app.services.mcp import ConnectionPool, MCPConnection + +pool = ConnectionPool(max_connections_per_server=5) + +# Get a connection (creates new or reuses existing) +conn = await pool.get_connection("server-1", config) + +# Execute request +result = await 
conn.execute_request("POST", "/mcp", data={...}) + +# Health check all connections +health = await pool.health_check_all() +``` + +### Circuit Breaker Pattern + +The `AsyncCircuitBreaker` prevents cascade failures: + +| State | Description | +|-------|-------------| +| CLOSED | Normal operation, calls pass through | +| OPEN | Too many failures, calls are rejected immediately | +| HALF-OPEN | After timeout, allows one call to test if service recovered | + +```python +from app.services.mcp import AsyncCircuitBreaker + +breaker = AsyncCircuitBreaker( + fail_max=5, # Open after 5 failures + reset_timeout=30, # Try again after 30 seconds + name="my-service" +) + +if breaker.is_open(): + raise MCPCircuitOpenError(...) + +try: + result = await call_external_service() + await breaker.success() +except Exception: + await breaker.failure() + raise +``` + +### Tool Routing + +The `ToolRouter` handles: + +- **Tool Discovery**: Automatically discovers tools from connected servers +- **Routing**: Routes tool calls to the appropriate server +- **Retry Logic**: Retries failed calls with exponential backoff + +```python +from app.services.mcp import ToolRouter + +router = ToolRouter(registry, pool) + +# Discover tools from all servers +await router.discover_tools() + +# Route to the right server automatically +result = await router.route_tool( + tool_name="create_issue", + arguments={"title": "Bug fix"} +) + +# Or call a specific server +result = await router.call_tool( + server_name="issues", + tool_name="create_issue", + arguments={"title": "Bug fix"} +) +``` + +## Exception Hierarchy + +``` +MCPError +├── MCPConnectionError # Connection failures +├── MCPTimeoutError # Operation timeouts +├── MCPToolError # Tool execution errors +├── MCPServerNotFoundError # Unknown server +├── MCPToolNotFoundError # Unknown tool +├── MCPCircuitOpenError # Circuit breaker open +└── MCPValidationError # Invalid configuration +``` + +All exceptions include rich context: + +```python +except 
MCPServerNotFoundError as e: + print(f"Server: {e.server_name}") + print(f"Available: {e.available_servers}") + print(f"Suggestion: {e.suggestion}") +``` + +## REST API Endpoints + +| Method | Endpoint | Description | Auth | +|--------|----------|-------------|------| +| GET | `/api/v1/mcp/servers` | List all MCP servers | No | +| GET | `/api/v1/mcp/servers/{name}/tools` | List server tools | No | +| GET | `/api/v1/mcp/tools` | List all tools | No | +| GET | `/api/v1/mcp/health` | Health check | No | +| POST | `/api/v1/mcp/call` | Execute tool | Superuser | +| GET | `/api/v1/mcp/circuit-breakers` | List circuit breakers | No | +| POST | `/api/v1/mcp/circuit-breakers/{name}/reset` | Reset breaker | Superuser | +| POST | `/api/v1/mcp/servers/{name}/reconnect` | Force reconnect | Superuser | + +### Example: Execute a Tool + +```http +POST /api/v1/mcp/call +Authorization: Bearer +Content-Type: application/json + +{ + "server": "issues", + "tool": "create_issue", + "arguments": { + "title": "New Feature Request", + "body": "Please add dark mode support" + }, + "timeout": 30 +} +``` + +**Response:** +```json +{ + "success": true, + "data": { + "issue_id": "12345", + "url": "https://gitea.example.com/org/repo/issues/42" + }, + "tool_name": "create_issue", + "server_name": "issues", + "execution_time_ms": 234.5, + "request_id": "550e8400-e29b-41d4-a716-446655440000" +} +``` + +## Usage in Syndarix Agents + +AI agents use the MCP client to execute tools: + +```python +class IssueCreatorAgent: + def __init__(self, mcp: MCPClientManager): + self.mcp = mcp + + async def create_issue(self, title: str, body: str) -> dict: + result = await self.mcp.call_tool( + server="issues", + tool="create_issue", + args={"title": title, "body": body} + ) + + if not result.success: + raise AgentError(f"Failed to create issue: {result.error}") + + return result.data +``` + +## Testing + +The MCP infrastructure is thoroughly tested: + +- **Unit Tests**: `tests/services/mcp/` - Service layer 
tests +- **API Tests**: `tests/api/routes/test_mcp.py` - Endpoint tests + +Run tests: +```bash +# All MCP tests +IS_TEST=True uv run pytest tests/services/mcp/ tests/api/routes/test_mcp.py -v + +# With coverage +IS_TEST=True uv run pytest tests/services/mcp/ --cov=app/services/mcp +``` + +## Configuration Reference + +### MCPServerConfig + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `url` | str | Required | Server URL | +| `transport` | str | "http" | Transport type (http, stdio, sse) | +| `timeout` | int | 30 | Request timeout (1-600 seconds) | +| `retry_attempts` | int | 3 | Max retry attempts (0-10) | +| `retry_delay` | float | 1.0 | Initial retry delay (0.1-300 seconds) | +| `retry_max_delay` | float | 30.0 | Maximum retry delay | +| `circuit_breaker_threshold` | int | 5 | Failures before opening circuit | +| `circuit_breaker_timeout` | float | 30.0 | Seconds before trying again | +| `enabled` | bool | true | Whether server is enabled | +| `description` | str | None | Server description | + +### MCPConfig (Global) + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `mcp_servers` | dict | {} | Server configurations | +| `default_timeout` | int | 30 | Default request timeout | +| `default_retry_attempts` | int | 3 | Default retry attempts | +| `connection_pool_size` | int | 10 | Max connections per server | +| `health_check_interval` | int | 30 | Health check interval (seconds) | + +## Files + +| Path | Description | +|------|-------------| +| `app/services/mcp/__init__.py` | Package exports | +| `app/services/mcp/client_manager.py` | Main facade class | +| `app/services/mcp/config.py` | Configuration models | +| `app/services/mcp/registry.py` | Server registry singleton | +| `app/services/mcp/connection.py` | Connection management | +| `app/services/mcp/routing.py` | Tool routing and circuit breakers | +| `app/services/mcp/exceptions.py` | Exception classes | +| 
`app/api/routes/mcp.py` | REST API endpoints | +| `mcp_servers.yaml` | Default configuration | diff --git a/backend/mcp_servers.yaml b/backend/mcp_servers.yaml new file mode 100644 index 0000000..37af5db --- /dev/null +++ b/backend/mcp_servers.yaml @@ -0,0 +1,60 @@ +# MCP Server Configuration +# +# This file defines the MCP servers that the Syndarix backend connects to. +# Environment variables can be used with ${VAR:-default} syntax. +# +# Example: +# url: ${MY_SERVER_URL:-http://localhost:8001} +# +# For development, these servers typically run as separate Docker containers. +# See docker-compose.yml for container definitions. + +mcp_servers: + # LLM Gateway - Multi-provider AI interactions + llm-gateway: + url: ${LLM_GATEWAY_URL:-http://localhost:8001} + transport: http + timeout: 60 + retry_attempts: 3 + retry_delay: 1.0 + retry_max_delay: 30.0 + circuit_breaker_threshold: 5 + circuit_breaker_timeout: 30.0 + enabled: true + description: "LLM Gateway for Anthropic, OpenAI, Ollama, and other providers" + + # Knowledge Base - RAG and document retrieval + knowledge-base: + url: ${KNOWLEDGE_BASE_URL:-http://localhost:8002} + transport: http + timeout: 30 + retry_attempts: 3 + circuit_breaker_threshold: 5 + enabled: true + description: "Knowledge Base with pgvector for semantic search and RAG" + + # Git Operations - Repository management + git-ops: + url: ${GIT_OPS_URL:-http://localhost:8003} + transport: http + timeout: 120 + retry_attempts: 2 + circuit_breaker_threshold: 3 + enabled: true + description: "Git Operations for clone, commit, push, and repository management" + + # Issues - Issue tracker integration + issues: + url: ${ISSUES_URL:-http://localhost:8004} + transport: http + timeout: 30 + retry_attempts: 3 + circuit_breaker_threshold: 5 + enabled: true + description: "Issue Tracker integration for Gitea, GitHub, and GitLab" + +# Global defaults +default_timeout: 30 +default_retry_attempts: 3 +connection_pool_size: 10 +health_check_interval: 30 diff --git 
a/backend/pyproject.toml b/backend/pyproject.toml index 54a06d1..3572bba 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -53,6 +53,12 @@ dependencies = [ # Celery for background task processing (Syndarix agent jobs) "celery[redis]>=5.4.0", "sse-starlette>=3.1.1", + # MCP (Model Context Protocol) for AI agent tool integration + "mcp>=1.0.0", + # Circuit breaker pattern for resilient connections + "pybreaker>=1.0.0", + # YAML configuration support + "pyyaml>=6.0.0", ] # Development dependencies @@ -151,6 +157,7 @@ unfixable = [] "app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order "app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure "tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"] # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional +"app/services/mcp/*.py" = ["ASYNC109", "S311", "RUF022"] # timeout is config param not asyncio.timeout; random is ok for jitter; __all__ order is intentional for readability "app/models/__init__.py" = ["F401"] # __init__ files re-export modules "app/models/base.py" = ["F401"] # Re-exports Base for use by other models "app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention @@ -268,6 +275,14 @@ ignore_missing_imports = true module = "httpx.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "pybreaker.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "yaml.*" +ignore_missing_imports = true + # SQLAlchemy ORM models - Column descriptors cause type confusion [[tool.mypy.overrides]] module = "app.models.*" @@ -307,6 +322,11 @@ disable_error_code = ["assignment", "arg-type", "attr-defined", "unused-ignore"] module = "app.services.oauth_service" disable_error_code = ["assignment", "arg-type", "attr-defined"] +# MCP services - circuit breaker and httpx client handling +[[tool.mypy.overrides]] 
@pytest.fixture
def mock_superuser():
    """Build a ``User``-spec'd mock carrying superuser attributes."""
    # Keyword arguments other than ``spec`` are set as attributes on the mock.
    return MagicMock(
        spec=User,
        id="00000000-0000-0000-0000-000000000001",
        is_superuser=True,
        email="admin@example.com",
    )
class TestListServers:
    """Tests for GET /mcp/servers endpoint."""

    def test_list_servers_success(self, client, mock_mcp_client):
        """Every registered server is returned with its configuration."""
        first_config = MCPServerConfig(
            url="http://server1:8000",
            timeout=30,
            enabled=True,
            transport=TransportType.HTTP,
            description="Server 1",
        )
        second_config = MCPServerConfig(
            url="http://server2:8000",
            timeout=60,
            enabled=True,
            transport=TransportType.SSE,
            description="Server 2",
        )
        mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
        mock_mcp_client.get_server_config.side_effect = [
            first_config,
            second_config,
        ]

        response = client.get("/api/v1/mcp/servers")

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["total"] == 2
        servers = payload["servers"]
        assert len(servers) == 2
        assert servers[0]["name"] == "server-1"
        assert servers[0]["url"] == "http://server1:8000"
        assert servers[1]["name"] == "server-2"
        assert servers[1]["transport"] == "sse"

    def test_list_servers_empty(self, client, mock_mcp_client):
        """An empty registry yields an empty listing."""
        mock_mcp_client.list_servers.return_value = []

        response = client.get("/api/v1/mcp/servers")

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["total"] == 0
        assert payload["servers"] == []

    def test_list_servers_handles_not_found(self, client, mock_mcp_client):
        """A server whose config lookup fails is left out of the listing."""
        mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
        # Second lookup raises, simulating a name without a stored config.
        mock_mcp_client.get_server_config.side_effect = [
            MCPServerConfig(url="http://server1:8000"),
            MCPServerNotFoundError(server_name="missing"),
        ]

        response = client.get("/api/v1/mcp/servers")

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        # Should only include the successfully retrieved server
        assert payload["total"] == 1
class TestListServerTools:
    """Tests for GET /mcp/servers/{server_name}/tools endpoint."""

    def test_list_server_tools_success(self, client, mock_mcp_client):
        """The tools reported for one server are returned in order."""
        server_tools = [
            ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
            ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
        ]
        mock_mcp_client.list_tools = AsyncMock(return_value=server_tools)

        response = client.get("/api/v1/mcp/servers/server-1/tools")

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["total"] == 2
        assert payload["tools"][0]["name"] == "tool1"
        assert payload["tools"][1]["name"] == "tool2"

    def test_list_server_tools_not_found(self, client, mock_mcp_client):
        """An unknown server name is answered with 404."""
        not_found = MCPServerNotFoundError(server_name="unknown")
        mock_mcp_client.list_tools = AsyncMock(side_effect=not_found)

        response = client.get("/api/v1/mcp/servers/unknown/tools")

        assert response.status_code == status.HTTP_404_NOT_FOUND
class TestCallTool:
    """Tests for POST /mcp/call endpoint."""

    @staticmethod
    def _post_call(client, body):
        """Send ``body`` as JSON to the tool-call endpoint."""
        return client.post("/api/v1/mcp/call", json=body)

    def test_call_tool_success(self, client, mock_mcp_client):
        """A successful tool result is passed through to the response."""
        tool_result = ToolResult(
            success=True,
            data={"result": "ok"},
            tool_name="test-tool",
            server_name="server-1",
            execution_time_ms=123.45,
            request_id="test-request-id",
        )
        mock_mcp_client.call_tool = AsyncMock(return_value=tool_result)

        response = self._post_call(
            client,
            {
                "server": "server-1",
                "tool": "test-tool",
                "arguments": {"key": "value"},
            },
        )

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["success"] is True
        assert payload["data"] == {"result": "ok"}
        assert payload["tool_name"] == "test-tool"
        assert payload["server_name"] == "server-1"

    def test_call_tool_with_timeout(self, client, mock_mcp_client):
        """A timeout given in the request is forwarded as a keyword argument."""
        call_tool = AsyncMock(return_value=ToolResult(success=True, data={}))
        mock_mcp_client.call_tool = call_tool

        response = self._post_call(
            client,
            {
                "server": "server-1",
                "tool": "test-tool",
                "timeout": 60.0,
            },
        )

        assert response.status_code == status.HTTP_200_OK
        call_tool.assert_called_once()
        assert call_tool.call_args.kwargs["timeout"] == 60.0

    def test_call_tool_server_not_found(self, client, mock_mcp_client):
        """An unknown server is answered with 404."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPServerNotFoundError(server_name="unknown")
        )

        response = self._post_call(
            client, {"server": "unknown", "tool": "test-tool"}
        )

        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_call_tool_not_found(self, client, mock_mcp_client):
        """An unknown tool is answered with 404."""
        mock_mcp_client.call_tool = AsyncMock(
            side_effect=MCPToolNotFoundError(tool_name="unknown")
        )

        response = self._post_call(
            client, {"server": "server-1", "tool": "unknown"}
        )

        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_call_tool_timeout(self, client, mock_mcp_client):
        """A timed-out tool call is answered with 504."""
        timed_out = MCPTimeoutError(
            "Request timed out",
            server_name="server-1",
            timeout_seconds=30.0,
        )
        mock_mcp_client.call_tool = AsyncMock(side_effect=timed_out)

        response = self._post_call(
            client, {"server": "server-1", "tool": "slow-tool"}
        )

        assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT

    def test_call_tool_connection_error(self, client, mock_mcp_client):
        """A connection failure is answered with 502."""
        refused = MCPConnectionError(
            "Connection refused",
            server_name="server-1",
        )
        mock_mcp_client.call_tool = AsyncMock(side_effect=refused)

        response = self._post_call(
            client, {"server": "server-1", "tool": "test-tool"}
        )

        assert response.status_code == status.HTTP_502_BAD_GATEWAY

    def test_call_tool_circuit_open(self, client, mock_mcp_client):
        """An open circuit breaker is answered with 503."""
        circuit_open = MCPCircuitOpenError(
            server_name="server-1",
            failure_count=5,
        )
        mock_mcp_client.call_tool = AsyncMock(side_effect=circuit_open)

        response = self._post_call(
            client, {"server": "server-1", "tool": "test-tool"}
        )

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
test_reset_circuit_breaker_not_found(self, client, mock_mcp_client): + """Test resetting non-existent circuit breaker.""" + mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False) + + response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestReconnectServer: + """Tests for POST /mcp/servers/{server_name}/reconnect endpoint.""" + + def test_reconnect_success(self, client, mock_mcp_client): + """Test successful server reconnection.""" + mock_mcp_client.disconnect = AsyncMock() + mock_mcp_client.connect = AsyncMock() + + response = client.post("/api/v1/mcp/servers/server-1/reconnect") + + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_mcp_client.disconnect.assert_called_once_with("server-1") + mock_mcp_client.connect.assert_called_once_with("server-1") + + def test_reconnect_server_not_found(self, client, mock_mcp_client): + """Test reconnecting to non-existent server.""" + mock_mcp_client.disconnect = AsyncMock( + side_effect=MCPServerNotFoundError(server_name="unknown") + ) + + response = client.post("/api/v1/mcp/servers/unknown/reconnect") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_reconnect_connection_failure(self, client, mock_mcp_client): + """Test reconnection failure.""" + mock_mcp_client.disconnect = AsyncMock() + mock_mcp_client.connect = AsyncMock( + side_effect=MCPConnectionError( + "Connection refused", + server_name="server-1", + ) + ) + + response = client.post("/api/v1/mcp/servers/server-1/reconnect") + + assert response.status_code == status.HTTP_502_BAD_GATEWAY + + +class TestMCPEndpointsEdgeCases: + """Edge case tests for MCP endpoints.""" + + def test_servers_content_type(self, client, mock_mcp_client): + """Test that endpoints return JSON content type.""" + mock_mcp_client.list_servers.return_value = [] + + response = client.get("/api/v1/mcp/servers") + + assert "application/json" in 
response.headers["content-type"] + + def test_call_tool_validation_error(self, client): + """Test that invalid request body returns validation error.""" + response = client.post( + "/api/v1/mcp/call", + json={}, # Missing required fields + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_call_tool_missing_server(self, client): + """Test that missing server field returns validation error.""" + response = client.post( + "/api/v1/mcp/call", + json={"tool": "test-tool"}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_call_tool_missing_tool(self, client): + """Test that missing tool field returns validation error.""" + response = client.post( + "/api/v1/mcp/call", + json={"server": "server-1"}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY diff --git a/backend/tests/services/mcp/__init__.py b/backend/tests/services/mcp/__init__.py new file mode 100644 index 0000000..b397e8b --- /dev/null +++ b/backend/tests/services/mcp/__init__.py @@ -0,0 +1 @@ +"""MCP Service Tests Package.""" diff --git a/backend/tests/services/mcp/test_client_manager.py b/backend/tests/services/mcp/test_client_manager.py new file mode 100644 index 0000000..326999f --- /dev/null +++ b/backend/tests/services/mcp/test_client_manager.py @@ -0,0 +1,395 @@ +""" +Tests for MCP Client Manager +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.mcp.client_manager import ( + MCPClientManager, + ServerHealth, + get_mcp_client, + reset_mcp_client, + shutdown_mcp_client, +) +from app.services.mcp.config import MCPConfig, MCPServerConfig +from app.services.mcp.connection import ConnectionState +from app.services.mcp.exceptions import MCPServerNotFoundError +from app.services.mcp.registry import MCPServerRegistry +from app.services.mcp.routing import ToolInfo, ToolResult + + +@pytest.fixture +def reset_registry(): + """Reset the singleton registry before and 
after each test.""" + MCPServerRegistry.reset_instance() + reset_mcp_client() + yield + MCPServerRegistry.reset_instance() + reset_mcp_client() + + +@pytest.fixture +def sample_config(): + """Create a sample MCP configuration.""" + return MCPConfig( + mcp_servers={ + "server-1": MCPServerConfig( + url="http://server1:8000", + timeout=30, + enabled=True, + ), + "server-2": MCPServerConfig( + url="http://server2:8000", + timeout=60, + enabled=True, + ), + } + ) + + +class TestServerHealth: + """Tests for ServerHealth dataclass.""" + + def test_healthy_server(self): + """Test healthy server status.""" + health = ServerHealth( + name="test-server", + healthy=True, + state="connected", + url="http://test:8000", + tools_count=5, + ) + assert health.healthy is True + assert health.error is None + assert health.tools_count == 5 + + def test_unhealthy_server(self): + """Test unhealthy server status.""" + health = ServerHealth( + name="test-server", + healthy=False, + state="error", + url="http://test:8000", + error="Connection refused", + ) + assert health.healthy is False + assert health.error == "Connection refused" + + def test_to_dict(self): + """Test converting to dictionary.""" + health = ServerHealth( + name="test-server", + healthy=True, + state="connected", + url="http://test:8000", + tools_count=3, + ) + d = health.to_dict() + + assert d["name"] == "test-server" + assert d["healthy"] is True + assert d["state"] == "connected" + assert d["url"] == "http://test:8000" + assert d["tools_count"] == 3 + + +class TestMCPClientManager: + """Tests for MCPClientManager class.""" + + def test_initial_state(self, reset_registry): + """Test initial manager state.""" + manager = MCPClientManager() + assert manager.is_initialized is False + + @pytest.mark.asyncio + async def test_initialize(self, reset_registry, sample_config): + """Test manager initialization.""" + manager = MCPClientManager(config=sample_config) + + with patch.object(manager._pool, "get_connection") as 
mock_get_conn: + mock_conn = AsyncMock() + mock_conn.is_connected = True + mock_get_conn.return_value = mock_conn + + with patch.object(manager, "_router") as mock_router: + mock_router.discover_tools = AsyncMock() + + await manager.initialize() + + assert manager.is_initialized is True + + @pytest.mark.asyncio + async def test_shutdown(self, reset_registry, sample_config): + """Test manager shutdown.""" + manager = MCPClientManager(config=sample_config) + manager._initialized = True + + with patch.object(manager._pool, "close_all") as mock_close: + mock_close.return_value = None + await manager.shutdown() + + assert manager.is_initialized is False + mock_close.assert_called_once() + + @pytest.mark.asyncio + async def test_connect(self, reset_registry, sample_config): + """Test connecting to specific server.""" + manager = MCPClientManager(config=sample_config) + + with patch.object(manager._pool, "get_connection") as mock_get_conn: + mock_conn = AsyncMock() + mock_conn.is_connected = True + mock_get_conn.return_value = mock_conn + + await manager.connect("server-1") + + mock_get_conn.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_unknown_server(self, reset_registry, sample_config): + """Test connecting to unknown server raises error.""" + manager = MCPClientManager(config=sample_config) + + with pytest.raises(MCPServerNotFoundError): + await manager.connect("unknown-server") + + @pytest.mark.asyncio + async def test_disconnect(self, reset_registry, sample_config): + """Test disconnecting from server.""" + manager = MCPClientManager(config=sample_config) + + with patch.object(manager._pool, "close_connection") as mock_close: + await manager.disconnect("server-1") + mock_close.assert_called_once_with("server-1") + + @pytest.mark.asyncio + async def test_disconnect_all(self, reset_registry, sample_config): + """Test disconnecting from all servers.""" + manager = MCPClientManager(config=sample_config) + + with patch.object(manager._pool, 
"close_all") as mock_close: + await manager.disconnect_all() + mock_close.assert_called_once() + + @pytest.mark.asyncio + async def test_call_tool(self, reset_registry, sample_config): + """Test calling a tool.""" + manager = MCPClientManager(config=sample_config) + manager._initialized = True + + expected_result = ToolResult( + success=True, + data={"id": "123"}, + tool_name="create_issue", + server_name="server-1", + ) + + mock_router = MagicMock() + mock_router.call_tool = AsyncMock(return_value=expected_result) + manager._router = mock_router + + result = await manager.call_tool( + server="server-1", + tool="create_issue", + args={"title": "Test"}, + ) + + assert result.success is True + assert result.data == {"id": "123"} + mock_router.call_tool.assert_called_once() + + @pytest.mark.asyncio + async def test_route_tool(self, reset_registry, sample_config): + """Test routing a tool call.""" + manager = MCPClientManager(config=sample_config) + manager._initialized = True + + expected_result = ToolResult( + success=True, + data={"result": "ok"}, + tool_name="auto_tool", + server_name="server-2", + ) + + mock_router = MagicMock() + mock_router.route_tool = AsyncMock(return_value=expected_result) + manager._router = mock_router + + result = await manager.route_tool( + tool="auto_tool", + args={"key": "value"}, + ) + + assert result.success is True + assert result.server_name == "server-2" + mock_router.route_tool.assert_called_once() + + @pytest.mark.asyncio + async def test_list_tools(self, reset_registry, sample_config): + """Test listing tools for a server.""" + manager = MCPClientManager(config=sample_config) + + # Set up capabilities in registry + manager._registry.set_capabilities( + "server-1", + tools=[ + {"name": "tool1", "description": "Tool 1"}, + {"name": "tool2", "description": "Tool 2"}, + ], + ) + + tools = await manager.list_tools("server-1") + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + 
@pytest.mark.asyncio + async def test_list_all_tools(self, reset_registry, sample_config): + """Test listing all tools from all servers.""" + manager = MCPClientManager(config=sample_config) + manager._initialized = True + + expected_tools = [ + ToolInfo(name="tool1", server_name="server-1"), + ToolInfo(name="tool2", server_name="server-2"), + ] + + mock_router = MagicMock() + mock_router.list_all_tools = AsyncMock(return_value=expected_tools) + manager._router = mock_router + + tools = await manager.list_all_tools() + + assert len(tools) == 2 + + @pytest.mark.asyncio + async def test_health_check(self, reset_registry, sample_config): + """Test health check on all servers.""" + manager = MCPClientManager(config=sample_config) + + with patch.object(manager._pool, "get_status") as mock_status: + mock_status.return_value = { + "server-1": {"state": "connected"}, + "server-2": {"state": "disconnected"}, + } + + with patch.object(manager._pool, "health_check_all") as mock_health: + mock_health.return_value = { + "server-1": True, + "server-2": False, + } + + health = await manager.health_check() + + assert "server-1" in health + assert "server-2" in health + assert health["server-1"].healthy is True + assert health["server-2"].healthy is False + + def test_list_servers(self, reset_registry, sample_config): + """Test listing registered servers.""" + manager = MCPClientManager(config=sample_config) + servers = manager.list_servers() + + assert "server-1" in servers + assert "server-2" in servers + + def test_list_enabled_servers(self, reset_registry, sample_config): + """Test listing enabled servers.""" + manager = MCPClientManager(config=sample_config) + servers = manager.list_enabled_servers() + + assert "server-1" in servers + assert "server-2" in servers + + def test_get_server_config(self, reset_registry, sample_config): + """Test getting server configuration.""" + manager = MCPClientManager(config=sample_config) + + config = manager.get_server_config("server-1") + 
assert config.url == "http://server1:8000" + assert config.timeout == 30 + + def test_get_server_config_not_found(self, reset_registry, sample_config): + """Test getting unknown server config raises error.""" + manager = MCPClientManager(config=sample_config) + + with pytest.raises(MCPServerNotFoundError): + manager.get_server_config("unknown") + + def test_register_server(self, reset_registry, sample_config): + """Test registering new server at runtime.""" + manager = MCPClientManager(config=sample_config) + + new_config = MCPServerConfig(url="http://new:8000") + manager.register_server("new-server", new_config) + + assert "new-server" in manager.list_servers() + + def test_unregister_server(self, reset_registry, sample_config): + """Test unregistering a server.""" + manager = MCPClientManager(config=sample_config) + + result = manager.unregister_server("server-1") + assert result is True + assert "server-1" not in manager.list_servers() + + # Unregistering non-existent returns False + result = manager.unregister_server("nonexistent") + assert result is False + + def test_circuit_breaker_status(self, reset_registry, sample_config): + """Test getting circuit breaker status.""" + manager = MCPClientManager(config=sample_config) + + # No router yet + status = manager.get_circuit_breaker_status() + assert status == {} + + @pytest.mark.asyncio + async def test_reset_circuit_breaker(self, reset_registry, sample_config): + """Test resetting circuit breaker.""" + manager = MCPClientManager(config=sample_config) + + # No router yet + result = await manager.reset_circuit_breaker("server-1") + assert result is False + + +class TestModuleLevelFunctions: + """Tests for module-level convenience functions.""" + + @pytest.mark.asyncio + async def test_get_mcp_client_creates_singleton(self, reset_registry): + """Test get_mcp_client creates and returns singleton.""" + with patch( + "app.services.mcp.client_manager.MCPClientManager.initialize" + ) as mock_init: + 
mock_init.return_value = None + + client1 = await get_mcp_client() + client2 = await get_mcp_client() + + assert client1 is client2 + + @pytest.mark.asyncio + async def test_shutdown_mcp_client(self, reset_registry): + """Test shutting down the global client.""" + with patch( + "app.services.mcp.client_manager.MCPClientManager.initialize" + ) as mock_init: + mock_init.return_value = None + + client = await get_mcp_client() + + with patch.object(client, "shutdown") as mock_shutdown: + mock_shutdown.return_value = None + await shutdown_mcp_client() + + def test_reset_mcp_client(self, reset_registry): + """Test resetting the global client.""" + reset_mcp_client() + # Should not raise diff --git a/backend/tests/services/mcp/test_config.py b/backend/tests/services/mcp/test_config.py new file mode 100644 index 0000000..b8a891d --- /dev/null +++ b/backend/tests/services/mcp/test_config.py @@ -0,0 +1,319 @@ +""" +Tests for MCP Configuration System +""" + +import os +import tempfile +from pathlib import Path + +import pytest +import yaml + +from app.services.mcp.config import ( + MCPConfig, + MCPServerConfig, + TransportType, + create_default_config, + load_mcp_config, +) + + +class TestTransportType: + """Tests for TransportType enum.""" + + def test_transport_types(self): + """Test that all transport types are defined.""" + assert TransportType.HTTP == "http" + assert TransportType.STDIO == "stdio" + assert TransportType.SSE == "sse" + + def test_transport_type_from_string(self): + """Test creating transport type from string.""" + assert TransportType("http") == TransportType.HTTP + assert TransportType("stdio") == TransportType.STDIO + assert TransportType("sse") == TransportType.SSE + + +class TestMCPServerConfig: + """Tests for MCPServerConfig model.""" + + def test_minimal_config(self): + """Test creating config with only required fields.""" + config = MCPServerConfig(url="http://localhost:8000") + assert config.url == "http://localhost:8000" + assert config.transport 
== TransportType.HTTP + assert config.timeout == 30 + assert config.retry_attempts == 3 + assert config.enabled is True + + def test_full_config(self): + """Test creating config with all fields.""" + config = MCPServerConfig( + url="http://localhost:8000", + transport=TransportType.SSE, + timeout=60, + retry_attempts=5, + retry_delay=2.0, + retry_max_delay=60.0, + circuit_breaker_threshold=10, + circuit_breaker_timeout=60.0, + enabled=False, + description="Test server", + ) + assert config.timeout == 60 + assert config.transport == TransportType.SSE + assert config.retry_attempts == 5 + assert config.retry_delay == 2.0 + assert config.retry_max_delay == 60.0 + assert config.circuit_breaker_threshold == 10 + assert config.circuit_breaker_timeout == 60.0 + assert config.enabled is False + assert config.description == "Test server" + + def test_env_var_expansion_simple(self): + """Test simple environment variable expansion.""" + os.environ["TEST_SERVER_URL"] = "http://test-server:9000" + try: + config = MCPServerConfig(url="${TEST_SERVER_URL}") + assert config.url == "http://test-server:9000" + finally: + del os.environ["TEST_SERVER_URL"] + + def test_env_var_expansion_with_default(self): + """Test environment variable expansion with default.""" + # Ensure env var is not set + os.environ.pop("NONEXISTENT_URL", None) + config = MCPServerConfig(url="${NONEXISTENT_URL:-http://default:8000}") + assert config.url == "http://default:8000" + + def test_env_var_expansion_override_default(self): + """Test environment variable override of default.""" + os.environ["TEST_OVERRIDE_URL"] = "http://override:9000" + try: + config = MCPServerConfig(url="${TEST_OVERRIDE_URL:-http://default:8000}") + assert config.url == "http://override:9000" + finally: + del os.environ["TEST_OVERRIDE_URL"] + + def test_timeout_validation(self): + """Test timeout validation bounds.""" + # Valid bounds + config = MCPServerConfig(url="http://localhost", timeout=1) + assert config.timeout == 1 + + config 
= MCPServerConfig(url="http://localhost", timeout=600) + assert config.timeout == 600 + + # Invalid bounds + with pytest.raises(ValueError): + MCPServerConfig(url="http://localhost", timeout=0) + + with pytest.raises(ValueError): + MCPServerConfig(url="http://localhost", timeout=601) + + def test_retry_attempts_validation(self): + """Test retry attempts validation bounds.""" + config = MCPServerConfig(url="http://localhost", retry_attempts=0) + assert config.retry_attempts == 0 + + config = MCPServerConfig(url="http://localhost", retry_attempts=10) + assert config.retry_attempts == 10 + + with pytest.raises(ValueError): + MCPServerConfig(url="http://localhost", retry_attempts=-1) + + with pytest.raises(ValueError): + MCPServerConfig(url="http://localhost", retry_attempts=11) + + +class TestMCPConfig: + """Tests for MCPConfig model.""" + + def test_empty_config(self): + """Test creating empty config.""" + config = MCPConfig() + assert config.mcp_servers == {} + assert config.default_timeout == 30 + assert config.default_retry_attempts == 3 + assert config.connection_pool_size == 10 + assert config.health_check_interval == 30 + + def test_config_with_servers(self): + """Test creating config with servers.""" + config = MCPConfig( + mcp_servers={ + "server-1": MCPServerConfig(url="http://server1:8000"), + "server-2": MCPServerConfig(url="http://server2:8000"), + } + ) + assert len(config.mcp_servers) == 2 + assert "server-1" in config.mcp_servers + assert "server-2" in config.mcp_servers + + def test_get_server(self): + """Test getting server by name.""" + config = MCPConfig( + mcp_servers={ + "server-1": MCPServerConfig(url="http://server1:8000"), + } + ) + server = config.get_server("server-1") + assert server is not None + assert server.url == "http://server1:8000" + + missing = config.get_server("nonexistent") + assert missing is None + + def test_get_enabled_servers(self): + """Test getting only enabled servers.""" + config = MCPConfig( + mcp_servers={ + 
"enabled-1": MCPServerConfig(url="http://e1:8000", enabled=True), + "disabled-1": MCPServerConfig(url="http://d1:8000", enabled=False), + "enabled-2": MCPServerConfig(url="http://e2:8000", enabled=True), + } + ) + enabled = config.get_enabled_servers() + assert len(enabled) == 2 + assert "enabled-1" in enabled + assert "enabled-2" in enabled + assert "disabled-1" not in enabled + + def test_list_server_names(self): + """Test listing server names.""" + config = MCPConfig( + mcp_servers={ + "server-a": MCPServerConfig(url="http://a:8000"), + "server-b": MCPServerConfig(url="http://b:8000"), + } + ) + names = config.list_server_names() + assert sorted(names) == ["server-a", "server-b"] + + def test_from_dict(self): + """Test creating config from dictionary.""" + data = { + "mcp_servers": { + "test-server": { + "url": "http://test:8000", + "timeout": 45, + } + }, + "default_timeout": 60, + } + config = MCPConfig.from_dict(data) + assert config.default_timeout == 60 + assert config.mcp_servers["test-server"].timeout == 45 + + def test_from_yaml(self): + """Test loading config from YAML file.""" + yaml_content = """ +mcp_servers: + test-server: + url: http://test:8000 + timeout: 45 + transport: http + enabled: true +default_timeout: 60 +connection_pool_size: 20 +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False + ) as f: + f.write(yaml_content) + f.flush() + + try: + config = MCPConfig.from_yaml(f.name) + assert config.default_timeout == 60 + assert config.connection_pool_size == 20 + assert "test-server" in config.mcp_servers + assert config.mcp_servers["test-server"].timeout == 45 + finally: + os.unlink(f.name) + + def test_from_yaml_file_not_found(self): + """Test error when YAML file not found.""" + with pytest.raises(FileNotFoundError): + MCPConfig.from_yaml("/nonexistent/path/config.yaml") + + +class TestLoadMCPConfig: + """Tests for load_mcp_config function.""" + + def test_load_with_explicit_path(self): + """Test loading config with 
explicit path.""" + yaml_content = """ +mcp_servers: + explicit-server: + url: http://explicit:8000 +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False + ) as f: + f.write(yaml_content) + f.flush() + + try: + config = load_mcp_config(f.name) + assert "explicit-server" in config.mcp_servers + finally: + os.unlink(f.name) + + def test_load_with_env_var(self): + """Test loading config from environment variable path.""" + yaml_content = """ +mcp_servers: + env-server: + url: http://env:8000 +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False + ) as f: + f.write(yaml_content) + f.flush() + + os.environ["MCP_CONFIG_PATH"] = f.name + try: + config = load_mcp_config() + assert "env-server" in config.mcp_servers + finally: + del os.environ["MCP_CONFIG_PATH"] + os.unlink(f.name) + + def test_load_returns_empty_config_if_missing(self): + """Test that missing file returns empty config.""" + os.environ.pop("MCP_CONFIG_PATH", None) + config = load_mcp_config("/nonexistent/path/config.yaml") + assert config.mcp_servers == {} + + +class TestCreateDefaultConfig: + """Tests for create_default_config function.""" + + def test_creates_standard_servers(self): + """Test that default config has standard servers.""" + config = create_default_config() + + assert "llm-gateway" in config.mcp_servers + assert "knowledge-base" in config.mcp_servers + assert "git-ops" in config.mcp_servers + assert "issues" in config.mcp_servers + + def test_servers_have_correct_defaults(self): + """Test that servers have correct default values.""" + config = create_default_config() + + llm = config.mcp_servers["llm-gateway"] + assert llm.timeout == 60 # LLM has longer timeout + assert llm.transport == TransportType.HTTP + + git = config.mcp_servers["git-ops"] + assert git.timeout == 120 # Git ops has longest timeout + + def test_servers_are_enabled(self): + """Test that all default servers are enabled.""" + config = create_default_config() + + for 
name, server in config.mcp_servers.items(): + assert server.enabled is True, f"Server {name} should be enabled" diff --git a/backend/tests/services/mcp/test_connection.py b/backend/tests/services/mcp/test_connection.py new file mode 100644 index 0000000..25f9579 --- /dev/null +++ b/backend/tests/services/mcp/test_connection.py @@ -0,0 +1,405 @@ +""" +Tests for MCP Connection Management +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.services.mcp.config import MCPServerConfig, TransportType +from app.services.mcp.connection import ( + ConnectionPool, + ConnectionState, + MCPConnection, +) +from app.services.mcp.exceptions import MCPConnectionError, MCPTimeoutError + + +@pytest.fixture +def server_config(): + """Create a sample server configuration.""" + return MCPServerConfig( + url="http://localhost:8000", + transport=TransportType.HTTP, + timeout=30, + retry_attempts=3, + retry_delay=0.1, # Short delay for tests + retry_max_delay=1.0, + ) + + +class TestConnectionState: + """Tests for ConnectionState enum.""" + + def test_connection_states(self): + """Test all connection states are defined.""" + assert ConnectionState.DISCONNECTED == "disconnected" + assert ConnectionState.CONNECTING == "connecting" + assert ConnectionState.CONNECTED == "connected" + assert ConnectionState.RECONNECTING == "reconnecting" + assert ConnectionState.ERROR == "error" + + +class TestMCPConnection: + """Tests for MCPConnection class.""" + + def test_initial_state(self, server_config): + """Test initial connection state.""" + conn = MCPConnection("test-server", server_config) + assert conn.server_name == "test-server" + assert conn.config == server_config + assert conn.state == ConnectionState.DISCONNECTED + assert conn.is_connected is False + assert conn.last_error is None + + @pytest.mark.asyncio + async def test_connect_success(self, server_config): + """Test successful connection.""" + conn = MCPConnection("test-server", 
server_config) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + await conn.connect() + + assert conn.state == ConnectionState.CONNECTED + assert conn.is_connected is True + + @pytest.mark.asyncio + async def test_connect_404_capabilities_ok(self, server_config): + """Test connection succeeds even if /mcp/capabilities returns 404.""" + conn = MCPConnection("test-server", server_config) + + mock_response = MagicMock() + mock_response.status_code = 404 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + await conn.connect() + + assert conn.state == ConnectionState.CONNECTED + + @pytest.mark.asyncio + async def test_connect_failure_with_retry(self, server_config): + """Test connection failure with retries.""" + # Reduce retry attempts for faster test + server_config.retry_attempts = 2 + server_config.retry_delay = 0.01 + conn = MCPConnection("test-server", server_config) + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=httpx.ConnectError("Connection refused") + ) + MockClient.return_value = mock_client + + with pytest.raises(MCPConnectionError) as exc_info: + await conn.connect() + + assert conn.state == ConnectionState.ERROR + assert "Failed to connect after" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_disconnect(self, server_config): + """Test disconnection.""" + conn = MCPConnection("test-server", server_config) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.aclose = AsyncMock() + 
MockClient.return_value = mock_client + + await conn.connect() + assert conn.is_connected is True + + await conn.disconnect() + assert conn.state == ConnectionState.DISCONNECTED + assert conn.is_connected is False + + @pytest.mark.asyncio + async def test_reconnect(self, server_config): + """Test reconnection.""" + conn = MCPConnection("test-server", server_config) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.aclose = AsyncMock() + MockClient.return_value = mock_client + + await conn.connect() + assert conn.is_connected is True + + await conn.reconnect() + assert conn.is_connected is True + + @pytest.mark.asyncio + async def test_health_check_success(self, server_config): + """Test successful health check.""" + conn = MCPConnection("test-server", server_config) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + await conn.connect() + healthy = await conn.health_check() + + assert healthy is True + + @pytest.mark.asyncio + async def test_health_check_disconnected(self, server_config): + """Test health check when disconnected.""" + conn = MCPConnection("test-server", server_config) + healthy = await conn.health_check() + assert healthy is False + + @pytest.mark.asyncio + async def test_execute_request_get(self, server_config): + """Test executing GET request.""" + conn = MCPConnection("test-server", server_config) + + mock_connect_response = MagicMock() + mock_connect_response.status_code = 200 + + mock_request_response = MagicMock() + mock_request_response.status_code = 200 + mock_request_response.json.return_value = {"tools": []} + mock_request_response.raise_for_status = MagicMock() + + with 
patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[mock_connect_response, mock_request_response] + ) + MockClient.return_value = mock_client + + await conn.connect() + result = await conn.execute_request("GET", "/mcp/tools") + + assert result == {"tools": []} + + @pytest.mark.asyncio + async def test_execute_request_post(self, server_config): + """Test executing POST request.""" + conn = MCPConnection("test-server", server_config) + + mock_connect_response = MagicMock() + mock_connect_response.status_code = 200 + + mock_request_response = MagicMock() + mock_request_response.status_code = 200 + mock_request_response.json.return_value = {"result": "success"} + mock_request_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_connect_response) + mock_client.post = AsyncMock(return_value=mock_request_response) + MockClient.return_value = mock_client + + await conn.connect() + result = await conn.execute_request( + "POST", "/mcp", data={"method": "test"} + ) + + assert result == {"result": "success"} + + @pytest.mark.asyncio + async def test_execute_request_timeout(self, server_config): + """Test request timeout.""" + conn = MCPConnection("test-server", server_config) + + mock_connect_response = MagicMock() + mock_connect_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[ + mock_connect_response, + httpx.TimeoutException("Request timeout"), + ] + ) + MockClient.return_value = mock_client + + await conn.connect() + + with pytest.raises(MCPTimeoutError) as exc_info: + await conn.execute_request("GET", "/slow-endpoint") + + assert "timeout" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_execute_request_not_connected(self, server_config): + """Test request when not 
connected.""" + conn = MCPConnection("test-server", server_config) + + with pytest.raises(MCPConnectionError) as exc_info: + await conn.execute_request("GET", "/test") + + assert "Not connected" in str(exc_info.value) + + def test_backoff_delay_calculation(self, server_config): + """Test exponential backoff delay calculation.""" + conn = MCPConnection("test-server", server_config) + + # First attempt + conn._connection_attempts = 1 + delay1 = conn._calculate_backoff_delay() + + # Second attempt + conn._connection_attempts = 2 + delay2 = conn._calculate_backoff_delay() + + # Third attempt + conn._connection_attempts = 3 + delay3 = conn._calculate_backoff_delay() + + # Delays should generally increase (with some jitter) + # Base is 0.1, so rough expectations: + # Attempt 1: ~0.1s + # Attempt 2: ~0.2s + # Attempt 3: ~0.4s + assert delay1 > 0 + assert delay2 > delay1 * 0.5 # Allow for jitter + assert delay3 <= server_config.retry_max_delay * 1.25 # Within max + jitter + + +class TestConnectionPool: + """Tests for ConnectionPool class.""" + + @pytest.fixture + def pool(self): + """Create a connection pool.""" + return ConnectionPool(max_connections_per_server=5) + + @pytest.mark.asyncio + async def test_get_connection_creates_new(self, pool, server_config): + """Test getting connection creates new one if not exists.""" + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + conn = await pool.get_connection("test-server", server_config) + + assert conn is not None + assert conn.is_connected is True + assert conn.server_name == "test-server" + + @pytest.mark.asyncio + async def test_get_connection_reuses_existing(self, pool, server_config): + """Test getting connection reuses existing one.""" + mock_response = MagicMock() + mock_response.status_code = 200 + + with 
patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + conn1 = await pool.get_connection("test-server", server_config) + conn2 = await pool.get_connection("test-server", server_config) + + assert conn1 is conn2 + + @pytest.mark.asyncio + async def test_close_connection(self, pool, server_config): + """Test closing a connection.""" + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.aclose = AsyncMock() + MockClient.return_value = mock_client + + await pool.get_connection("test-server", server_config) + await pool.close_connection("test-server") + + assert "test-server" not in pool._connections + + @pytest.mark.asyncio + async def test_close_all(self, pool, server_config): + """Test closing all connections.""" + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.aclose = AsyncMock() + MockClient.return_value = mock_client + + config2 = MCPServerConfig(url="http://server2:8000") + await pool.get_connection("server-1", server_config) + await pool.get_connection("server-2", config2) + + assert len(pool._connections) == 2 + + await pool.close_all() + + assert len(pool._connections) == 0 + + @pytest.mark.asyncio + async def test_health_check_all(self, pool, server_config): + """Test health check on all connections.""" + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + await pool.get_connection("test-server", server_config) + results = await 
pool.health_check_all() + + assert "test-server" in results + assert results["test-server"] is True + + def test_get_status(self, pool, server_config): + """Test getting pool status.""" + status = pool.get_status() + assert status == {} + + @pytest.mark.asyncio + async def test_connection_context_manager(self, pool, server_config): + """Test connection context manager.""" + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch("httpx.AsyncClient") as MockClient: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + async with pool.connection("test-server", server_config) as conn: + assert conn.is_connected is True + assert conn.server_name == "test-server" diff --git a/backend/tests/services/mcp/test_exceptions.py b/backend/tests/services/mcp/test_exceptions.py new file mode 100644 index 0000000..2647067 --- /dev/null +++ b/backend/tests/services/mcp/test_exceptions.py @@ -0,0 +1,259 @@ +""" +Tests for MCP Exception Classes +""" + +import pytest + +from app.services.mcp.exceptions import ( + MCPCircuitOpenError, + MCPConnectionError, + MCPError, + MCPServerNotFoundError, + MCPTimeoutError, + MCPToolError, + MCPToolNotFoundError, + MCPValidationError, +) + + +class TestMCPError: + """Tests for base MCPError class.""" + + def test_basic_error(self): + """Test basic error creation.""" + error = MCPError("Test error") + assert error.message == "Test error" + assert error.server_name is None + assert error.details == {} + assert str(error) == "Test error" + + def test_error_with_server_name(self): + """Test error with server name.""" + error = MCPError("Test error", server_name="test-server") + assert error.server_name == "test-server" + assert "server=test-server" in str(error) + + def test_error_with_details(self): + """Test error with additional details.""" + error = MCPError( + "Test error", + server_name="test-server", + details={"key": "value"}, + ) + assert 
error.details == {"key": "value"} + assert "details={'key': 'value'}" in str(error) + + +class TestMCPConnectionError: + """Tests for MCPConnectionError class.""" + + def test_basic_connection_error(self): + """Test basic connection error.""" + error = MCPConnectionError("Connection failed") + assert error.message == "Connection failed" + assert error.url is None + assert error.cause is None + + def test_connection_error_with_url(self): + """Test connection error with URL.""" + error = MCPConnectionError( + "Connection failed", + server_name="test-server", + url="http://localhost:8000", + ) + assert error.url == "http://localhost:8000" + assert "url=http://localhost:8000" in str(error) + + def test_connection_error_with_cause(self): + """Test connection error with cause.""" + cause = ConnectionError("Network error") + error = MCPConnectionError( + "Connection failed", + cause=cause, + ) + assert error.cause is cause + assert "ConnectionError" in str(error) + + +class TestMCPTimeoutError: + """Tests for MCPTimeoutError class.""" + + def test_basic_timeout_error(self): + """Test basic timeout error.""" + error = MCPTimeoutError("Request timed out") + assert error.message == "Request timed out" + assert error.timeout_seconds is None + assert error.operation is None + + def test_timeout_error_with_details(self): + """Test timeout error with details.""" + error = MCPTimeoutError( + "Request timed out", + server_name="test-server", + timeout_seconds=30.0, + operation="POST /mcp", + ) + assert error.timeout_seconds == 30.0 + assert error.operation == "POST /mcp" + assert "timeout=30.0s" in str(error) + assert "operation=POST /mcp" in str(error) + + +class TestMCPToolError: + """Tests for MCPToolError class.""" + + def test_basic_tool_error(self): + """Test basic tool error.""" + error = MCPToolError("Tool execution failed") + assert error.message == "Tool execution failed" + assert error.tool_name is None + assert error.tool_args is None + assert error.error_code is None 
+ + def test_tool_error_with_details(self): + """Test tool error with all details.""" + error = MCPToolError( + "Tool execution failed", + server_name="llm-gateway", + tool_name="chat", + tool_args={"prompt": "Hello"}, + error_code="INVALID_ARGS", + ) + assert error.tool_name == "chat" + assert error.tool_args == {"prompt": "Hello"} + assert error.error_code == "INVALID_ARGS" + assert "tool=chat" in str(error) + assert "error_code=INVALID_ARGS" in str(error) + + +class TestMCPServerNotFoundError: + """Tests for MCPServerNotFoundError class.""" + + def test_server_not_found(self): + """Test server not found error.""" + error = MCPServerNotFoundError("unknown-server") + assert error.server_name == "unknown-server" + assert "MCP server not found: unknown-server" in error.message + assert error.available_servers == [] + + def test_server_not_found_with_available(self): + """Test server not found with available servers listed.""" + error = MCPServerNotFoundError( + "unknown-server", + available_servers=["server-1", "server-2"], + ) + assert error.available_servers == ["server-1", "server-2"] + assert "available=['server-1', 'server-2']" in str(error) + + +class TestMCPToolNotFoundError: + """Tests for MCPToolNotFoundError class.""" + + def test_tool_not_found(self): + """Test tool not found error.""" + error = MCPToolNotFoundError("unknown-tool") + assert error.tool_name == "unknown-tool" + assert "Tool not found: unknown-tool" in error.message + assert error.available_tools == [] + + def test_tool_not_found_with_available(self): + """Test tool not found with available tools listed.""" + error = MCPToolNotFoundError( + "unknown-tool", + available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"], + ) + assert len(error.available_tools) == 6 + # Should show first 5 tools with ellipsis + assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." 
in str(error) + + +class TestMCPCircuitOpenError: + """Tests for MCPCircuitOpenError class.""" + + def test_circuit_open_error(self): + """Test circuit open error.""" + error = MCPCircuitOpenError("test-server") + assert error.server_name == "test-server" + assert "Circuit breaker open for server: test-server" in error.message + assert error.failure_count is None + assert error.reset_timeout is None + + def test_circuit_open_error_with_details(self): + """Test circuit open error with details.""" + error = MCPCircuitOpenError( + "test-server", + failure_count=5, + reset_timeout=30.0, + ) + assert error.failure_count == 5 + assert error.reset_timeout == 30.0 + assert "failures=5" in str(error) + assert "reset_in=30.0s" in str(error) + + +class TestMCPValidationError: + """Tests for MCPValidationError class.""" + + def test_validation_error(self): + """Test validation error.""" + error = MCPValidationError("Validation failed") + assert error.message == "Validation failed" + assert error.tool_name is None + assert error.field_errors == {} + + def test_validation_error_with_details(self): + """Test validation error with field errors.""" + error = MCPValidationError( + "Validation failed", + tool_name="create_issue", + field_errors={ + "title": "Title is required", + "priority": "Invalid priority value", + }, + ) + assert error.tool_name == "create_issue" + assert error.field_errors == { + "title": "Title is required", + "priority": "Invalid priority value", + } + assert "tool=create_issue" in str(error) + assert "fields=['title', 'priority']" in str(error) + + +class TestExceptionInheritance: + """Tests for exception inheritance chain.""" + + def test_all_errors_inherit_from_mcp_error(self): + """Test that all custom exceptions inherit from MCPError.""" + assert issubclass(MCPConnectionError, MCPError) + assert issubclass(MCPTimeoutError, MCPError) + assert issubclass(MCPToolError, MCPError) + assert issubclass(MCPServerNotFoundError, MCPError) + assert 
issubclass(MCPToolNotFoundError, MCPError) + assert issubclass(MCPCircuitOpenError, MCPError) + assert issubclass(MCPValidationError, MCPError) + + def test_all_errors_inherit_from_exception(self): + """Test that base MCPError inherits from Exception.""" + assert issubclass(MCPError, Exception) + + def test_catch_all_with_mcp_error(self): + """Test that all errors can be caught with MCPError.""" + + def raise_connection_error(): + raise MCPConnectionError("Connection failed") + + def raise_timeout_error(): + raise MCPTimeoutError("Timeout") + + def raise_tool_error(): + raise MCPToolError("Tool failed") + + with pytest.raises(MCPError): + raise_connection_error() + + with pytest.raises(MCPError): + raise_timeout_error() + + with pytest.raises(MCPError): + raise_tool_error() diff --git a/backend/tests/services/mcp/test_registry.py b/backend/tests/services/mcp/test_registry.py new file mode 100644 index 0000000..d881c63 --- /dev/null +++ b/backend/tests/services/mcp/test_registry.py @@ -0,0 +1,272 @@ +""" +Tests for MCP Server Registry +""" + +import pytest + +from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType +from app.services.mcp.exceptions import MCPServerNotFoundError +from app.services.mcp.registry import ( + MCPServerRegistry, + ServerCapabilities, + get_registry, +) + + +@pytest.fixture +def reset_registry(): + """Reset the singleton registry before and after each test.""" + MCPServerRegistry.reset_instance() + yield + MCPServerRegistry.reset_instance() + + +@pytest.fixture +def sample_config(): + """Create a sample MCP configuration.""" + return MCPConfig( + mcp_servers={ + "server-1": MCPServerConfig( + url="http://server1:8000", + timeout=30, + enabled=True, + ), + "server-2": MCPServerConfig( + url="http://server2:8000", + timeout=60, + enabled=True, + ), + "disabled-server": MCPServerConfig( + url="http://disabled:8000", + enabled=False, + ), + } + ) + + +class TestServerCapabilities: + """Tests for ServerCapabilities class.""" 
+ + def test_empty_capabilities(self): + """Test creating empty capabilities.""" + caps = ServerCapabilities() + assert caps.tools == [] + assert caps.resources == [] + assert caps.prompts == [] + assert caps.is_loaded is False + assert caps.tool_names == [] + + def test_capabilities_with_tools(self): + """Test capabilities with tools.""" + caps = ServerCapabilities( + tools=[ + {"name": "tool1", "description": "Tool 1"}, + {"name": "tool2", "description": "Tool 2"}, + ] + ) + assert len(caps.tools) == 2 + assert caps.tool_names == ["tool1", "tool2"] + + def test_mark_loaded(self): + """Test marking capabilities as loaded.""" + caps = ServerCapabilities() + assert caps.is_loaded is False + assert caps._load_time is None + + caps.mark_loaded() + assert caps.is_loaded is True + assert caps._load_time is not None + + +class TestMCPServerRegistry: + """Tests for MCPServerRegistry singleton.""" + + def test_singleton_pattern(self, reset_registry): + """Test that registry is a singleton.""" + registry1 = MCPServerRegistry() + registry2 = MCPServerRegistry() + assert registry1 is registry2 + + def test_get_instance(self, reset_registry): + """Test get_instance class method.""" + registry = MCPServerRegistry.get_instance() + assert registry is MCPServerRegistry() + + def test_reset_instance(self, reset_registry): + """Test resetting singleton instance.""" + registry1 = MCPServerRegistry() + MCPServerRegistry.reset_instance() + registry2 = MCPServerRegistry() + assert registry1 is not registry2 + + def test_load_config(self, reset_registry, sample_config): + """Test loading configuration.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + assert len(registry.list_servers()) == 3 + assert "server-1" in registry.list_servers() + assert "server-2" in registry.list_servers() + assert "disabled-server" in registry.list_servers() + + def test_list_enabled_servers(self, reset_registry, sample_config): + """Test listing only enabled servers.""" + registry 
= MCPServerRegistry() + registry.load_config(sample_config) + + enabled = registry.list_enabled_servers() + assert len(enabled) == 2 + assert "server-1" in enabled + assert "server-2" in enabled + assert "disabled-server" not in enabled + + def test_register(self, reset_registry): + """Test registering a new server.""" + registry = MCPServerRegistry() + config = MCPServerConfig(url="http://new:8000") + + registry.register("new-server", config) + assert "new-server" in registry.list_servers() + assert registry.get("new-server").url == "http://new:8000" + + def test_unregister(self, reset_registry, sample_config): + """Test unregistering a server.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + assert registry.unregister("server-1") is True + assert "server-1" not in registry.list_servers() + + # Unregistering non-existent server returns False + assert registry.unregister("nonexistent") is False + + def test_get(self, reset_registry, sample_config): + """Test getting server config.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + config = registry.get("server-1") + assert config.url == "http://server1:8000" + assert config.timeout == 30 + + def test_get_not_found(self, reset_registry, sample_config): + """Test getting non-existent server raises error.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + with pytest.raises(MCPServerNotFoundError) as exc_info: + registry.get("nonexistent") + + assert exc_info.value.server_name == "nonexistent" + assert "server-1" in exc_info.value.available_servers + + def test_get_or_none(self, reset_registry, sample_config): + """Test get_or_none method.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + config = registry.get_or_none("server-1") + assert config is not None + + config = registry.get_or_none("nonexistent") + assert config is None + + def test_get_all_configs(self, reset_registry, sample_config): + """Test 
getting all configs.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + configs = registry.get_all_configs() + assert len(configs) == 3 + + def test_get_enabled_configs(self, reset_registry, sample_config): + """Test getting enabled configs.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + configs = registry.get_enabled_configs() + assert len(configs) == 2 + assert "disabled-server" not in configs + + @pytest.mark.asyncio + async def test_get_capabilities(self, reset_registry, sample_config): + """Test getting server capabilities.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + # Initially empty capabilities + caps = await registry.get_capabilities("server-1") + assert caps.is_loaded is False + + def test_set_capabilities(self, reset_registry, sample_config): + """Test setting server capabilities.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + registry.set_capabilities( + "server-1", + tools=[{"name": "tool1"}, {"name": "tool2"}], + resources=[{"name": "resource1"}], + ) + + caps = registry._capabilities["server-1"] + assert len(caps.tools) == 2 + assert len(caps.resources) == 1 + assert caps.is_loaded is True + + def test_find_server_for_tool(self, reset_registry, sample_config): + """Test finding server that provides a tool.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + registry.set_capabilities( + "server-1", + tools=[{"name": "tool1"}, {"name": "tool2"}], + ) + registry.set_capabilities( + "server-2", + tools=[{"name": "tool3"}], + ) + + assert registry.find_server_for_tool("tool1") == "server-1" + assert registry.find_server_for_tool("tool3") == "server-2" + assert registry.find_server_for_tool("unknown") is None + + def test_get_all_tools(self, reset_registry, sample_config): + """Test getting all tools from all servers.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + 
registry.set_capabilities( + "server-1", + tools=[{"name": "tool1"}], + ) + registry.set_capabilities( + "server-2", + tools=[{"name": "tool2"}, {"name": "tool3"}], + ) + + all_tools = registry.get_all_tools() + assert len(all_tools) == 2 + assert len(all_tools["server-1"]) == 1 + assert len(all_tools["server-2"]) == 2 + + def test_global_config_property(self, reset_registry, sample_config): + """Test accessing global config.""" + registry = MCPServerRegistry() + registry.load_config(sample_config) + + global_config = registry.global_config + assert global_config is not None + assert len(global_config.mcp_servers) == 3 + + +class TestGetRegistry: + """Tests for get_registry convenience function.""" + + def test_get_registry_returns_singleton(self, reset_registry): + """Test that get_registry returns the singleton.""" + registry1 = get_registry() + registry2 = get_registry() + assert registry1 is registry2 + assert registry1 is MCPServerRegistry() diff --git a/backend/tests/services/mcp/test_routing.py b/backend/tests/services/mcp/test_routing.py new file mode 100644 index 0000000..a039637 --- /dev/null +++ b/backend/tests/services/mcp/test_routing.py @@ -0,0 +1,345 @@ +""" +Tests for MCP Tool Call Routing +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.mcp.config import MCPConfig, MCPServerConfig +from app.services.mcp.connection import ConnectionPool +from app.services.mcp.exceptions import ( + MCPCircuitOpenError, + MCPToolError, + MCPToolNotFoundError, +) +from app.services.mcp.registry import MCPServerRegistry +from app.services.mcp.routing import ToolInfo, ToolResult, ToolRouter + + +@pytest.fixture +def reset_registry(): + """Reset the singleton registry before and after each test.""" + MCPServerRegistry.reset_instance() + yield + MCPServerRegistry.reset_instance() + + +@pytest.fixture +def registry(reset_registry): + """Create a configured registry.""" + reg = MCPServerRegistry() + reg.load_config( + 
MCPConfig( + mcp_servers={ + "server-1": MCPServerConfig( + url="http://server1:8000", + retry_attempts=1, + retry_delay=0.1, # Minimum allowed value + circuit_breaker_threshold=3, + circuit_breaker_timeout=5.0, + ), + "server-2": MCPServerConfig( + url="http://server2:8000", + retry_attempts=1, + retry_delay=0.1, # Minimum allowed value + ), + } + ) + ) + return reg + + +@pytest.fixture +def pool(): + """Create a connection pool.""" + return ConnectionPool() + + +@pytest.fixture +def router(registry, pool): + """Create a tool router.""" + return ToolRouter(registry, pool) + + +class TestToolInfo: + """Tests for ToolInfo dataclass.""" + + def test_basic_tool_info(self): + """Test creating basic tool info.""" + info = ToolInfo(name="test-tool") + assert info.name == "test-tool" + assert info.description is None + assert info.server_name is None + assert info.input_schema is None + + def test_full_tool_info(self): + """Test creating full tool info.""" + info = ToolInfo( + name="create_issue", + description="Create a new issue", + server_name="issues", + input_schema={"type": "object", "properties": {"title": {"type": "string"}}}, + ) + assert info.name == "create_issue" + assert info.description == "Create a new issue" + assert info.server_name == "issues" + assert "properties" in info.input_schema + + def test_to_dict(self): + """Test converting to dictionary.""" + info = ToolInfo( + name="test-tool", + description="A test tool", + server_name="test-server", + ) + result = info.to_dict() + + assert result["name"] == "test-tool" + assert result["description"] == "A test tool" + assert result["server_name"] == "test-server" + + +class TestToolResult: + """Tests for ToolResult dataclass.""" + + def test_success_result(self): + """Test creating success result.""" + result = ToolResult( + success=True, + data={"id": "123"}, + tool_name="create_issue", + server_name="issues", + ) + assert result.success is True + assert result.data == {"id": "123"} + assert result.error 
is None + + def test_error_result(self): + """Test creating error result.""" + result = ToolResult( + success=False, + error="Tool execution failed", + error_code="INTERNAL_ERROR", + tool_name="create_issue", + server_name="issues", + ) + assert result.success is False + assert result.error == "Tool execution failed" + assert result.error_code == "INTERNAL_ERROR" + + def test_to_dict(self): + """Test converting to dictionary.""" + result = ToolResult( + success=True, + data={"result": "ok"}, + tool_name="test", + execution_time_ms=123.45, + ) + d = result.to_dict() + + assert d["success"] is True + assert d["data"] == {"result": "ok"} + assert d["tool_name"] == "test" + assert d["execution_time_ms"] == 123.45 + assert "request_id" in d # Auto-generated + + +class TestToolRouter: + """Tests for ToolRouter class.""" + + @pytest.mark.asyncio + async def test_register_tool_mapping(self, router): + """Test registering tool mappings.""" + await router.register_tool_mapping("tool1", "server-1") + await router.register_tool_mapping("tool2", "server-2") + + assert router.find_server_for_tool("tool1") == "server-1" + assert router.find_server_for_tool("tool2") == "server-2" + + def test_find_server_for_unknown_tool(self, router): + """Test finding server for unknown tool.""" + result = router.find_server_for_tool("unknown-tool") + assert result is None + + @pytest.mark.asyncio + async def test_call_tool_success(self, router, registry): + """Test successful tool call.""" + # Set up capabilities + registry.set_capabilities( + "server-1", + tools=[{"name": "test-tool"}], + ) + await router.register_tool_mapping("test-tool", "server-1") + + # Mock the pool connection and request + mock_conn = AsyncMock() + mock_conn.execute_request = AsyncMock( + return_value={"result": {"status": "ok"}} + ) + mock_conn.is_connected = True + + with patch.object(router._pool, "get_connection", return_value=mock_conn): + result = await router.call_tool( + server_name="server-1", + 
tool_name="test-tool", + arguments={"param": "value"}, + ) + + assert result.success is True + assert result.data == {"status": "ok"} + assert result.tool_name == "test-tool" + assert result.server_name == "server-1" + assert result.execution_time_ms > 0 + + @pytest.mark.asyncio + async def test_call_tool_error_response(self, router, registry): + """Test tool call with error response.""" + registry.set_capabilities( + "server-1", + tools=[{"name": "test-tool"}], + ) + await router.register_tool_mapping("test-tool", "server-1") + + mock_conn = AsyncMock() + mock_conn.execute_request = AsyncMock( + return_value={ + "error": { + "code": -32000, + "message": "Tool execution failed", + } + } + ) + mock_conn.is_connected = True + + with patch.object(router._pool, "get_connection", return_value=mock_conn): + result = await router.call_tool( + server_name="server-1", + tool_name="test-tool", + arguments={}, + ) + + assert result.success is False + assert "Tool execution failed" in result.error + + @pytest.mark.asyncio + async def test_route_tool(self, router, registry): + """Test routing tool to correct server.""" + registry.set_capabilities( + "server-1", + tools=[{"name": "tool-on-server-1"}], + ) + await router.register_tool_mapping("tool-on-server-1", "server-1") + + mock_conn = AsyncMock() + mock_conn.execute_request = AsyncMock( + return_value={"result": "routed"} + ) + mock_conn.is_connected = True + + with patch.object(router._pool, "get_connection", return_value=mock_conn): + result = await router.route_tool( + tool_name="tool-on-server-1", + arguments={"key": "value"}, + ) + + assert result.success is True + assert result.server_name == "server-1" + + @pytest.mark.asyncio + async def test_route_tool_not_found(self, router): + """Test routing unknown tool raises error.""" + with pytest.raises(MCPToolNotFoundError) as exc_info: + await router.route_tool( + tool_name="unknown-tool", + arguments={}, + ) + + assert exc_info.value.tool_name == "unknown-tool" + + 
@pytest.mark.asyncio + async def test_list_all_tools(self, router, registry): + """Test listing all tools.""" + registry.set_capabilities( + "server-1", + tools=[ + {"name": "tool1", "description": "Tool 1"}, + {"name": "tool2", "description": "Tool 2"}, + ], + ) + registry.set_capabilities( + "server-2", + tools=[{"name": "tool3", "description": "Tool 3"}], + ) + + tools = await router.list_all_tools() + + assert len(tools) == 3 + tool_names = [t.name for t in tools] + assert "tool1" in tool_names + assert "tool2" in tool_names + assert "tool3" in tool_names + + def test_circuit_breaker_status(self, router, registry): + """Test getting circuit breaker status.""" + # Initially no circuit breakers + status = router.get_circuit_breaker_status() + assert status == {} + + @pytest.mark.asyncio + async def test_reset_circuit_breaker(self, router, registry): + """Test resetting circuit breaker.""" + # Reset non-existent returns False + result = await router.reset_circuit_breaker("server-1") + assert result is False + + @pytest.mark.asyncio + async def test_discover_tools(self, router, registry): + """Test tool discovery from servers.""" + # Create mocks for different servers + mock_conn_1 = AsyncMock() + mock_conn_1.execute_request = AsyncMock( + return_value={ + "tools": [ + {"name": "discovered-tool", "description": "A discovered tool"}, + ] + } + ) + mock_conn_1.server_name = "server-1" + mock_conn_1.is_connected = True + + mock_conn_2 = AsyncMock() + mock_conn_2.execute_request = AsyncMock( + return_value={"tools": []} # Empty for server-2 + ) + mock_conn_2.server_name = "server-2" + mock_conn_2.is_connected = True + + async def get_connection_side_effect(server_name, _config): + if server_name == "server-1": + return mock_conn_1 + return mock_conn_2 + + with patch.object( + router._pool, + "get_connection", + side_effect=get_connection_side_effect, + ): + await router.discover_tools() + + # Check that tool mapping was registered + server = 
router.find_server_for_tool("discovered-tool") + assert server == "server-1" + + def test_calculate_retry_delay(self, router, registry): + """Test retry delay calculation.""" + config = registry.get("server-1") + + delay1 = router._calculate_retry_delay(1, config) + delay2 = router._calculate_retry_delay(2, config) + delay3 = router._calculate_retry_delay(3, config) + + # Delays should increase with attempts + assert delay1 > 0 + # Allow for jitter variation + assert delay1 <= config.retry_max_delay * 1.25 diff --git a/backend/uv.lock b/backend/uv.lock index 6c26b7d..c926d28 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -582,15 +582,18 @@ dependencies = [ { name = "fastapi" }, { name = "fastapi-utils" }, { name = "httpx" }, + { name = "mcp" }, { name = "passlib" }, { name = "pillow" }, { name = "psycopg2-binary" }, + { name = "pybreaker" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "python-dotenv" }, { name = "python-jose" }, { name = "python-multipart" }, { name = "pytz" }, + { name = "pyyaml" }, { name = "slowapi" }, { name = "sqlalchemy" }, { name = "sse-starlette" }, @@ -632,10 +635,12 @@ requires-dist = [ { name = "fastapi-utils", specifier = "==0.8.0" }, { name = "freezegun", marker = "extra == 'dev'", specifier = "~=1.5.1" }, { name = "httpx", specifier = ">=0.27.0" }, + { name = "mcp", specifier = ">=1.0.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" }, { name = "passlib", specifier = "==1.7.4" }, { name = "pillow", specifier = ">=10.3.0" }, { name = "psycopg2-binary", specifier = ">=2.9.9" }, + { name = "pybreaker", specifier = ">=1.0.0" }, { name = "pydantic", specifier = ">=2.10.6" }, { name = "pydantic-settings", specifier = ">=2.2.1" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, @@ -646,6 +651,7 @@ requires-dist = [ { name = "python-jose", specifier = "==3.4.0" }, { name = "python-multipart", specifier = ">=0.0.19" }, { name = "pytz", specifier = ">=2024.1" }, + { name = 
"pyyaml", specifier = ">=6.0.0" }, { name = "requests", marker = "extra == 'dev'", specifier = ">=2.32.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" }, { name = "schemathesis", marker = "extra == 'e2e'", specifier = ">=3.30.0" }, @@ -805,6 +811,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "hypothesis" version = "6.148.2" @@ -1063,6 +1078,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, ] +[[package]] +name = "mcp" +version = "1.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = 
"python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387, upload-time = "2025-12-19T10:19:56.985Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076, upload-time = "2025-12-19T10:19:55.416Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -1294,6 +1334,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/1e/a94a8d635fa3ce4cfc7f506003548d0a2447ae76fd5ca53932970fe3053f/pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", size = 77145, upload-time = "2019-11-16T17:27:11.07Z" }, ] +[[package]] +name = "pybreaker" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/89/fbf98e383f1ec6d117af2cd983efdb3eb7018b63834c427025764194cac2/pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178", size = 15555, upload-time = "2025-09-21T15:12:04.499Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/75/e64d3d40a741e2be21d69154f4e5c43a66f0c603c5ef11f49e01429a5932/pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0", size = 12915, upload-time = "2025-09-21T15:12:02.284Z" }, +] + [[package]] name = "pycparser" version = "2.23" @@ -1412,6 
+1461,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pyrate-limiter" version = "3.9.0" diff --git a/docs/development/WORKFLOW.md b/docs/development/WORKFLOW.md index 9209f25..d8f0b4d 100644 --- a/docs/development/WORKFLOW.md +++ b/docs/development/WORKFLOW.md @@ -214,6 +214,71 @@ test(frontend): add unit tests for ProjectDashboard --- +## Phase 2+ Implementation Workflow + +**For complex infrastructure issues (Phase 2 MCP, core systems), follow this rigorous process:** + +### 1. Branch Setup +```bash +# Create feature branch from dev +git checkout dev && git pull +git checkout -b feature/- +``` + +### 2. Planning Phase +- Read the issue thoroughly, understand ALL sub-tasks +- Identify components and their dependencies +- Determine if multi-agent parallel execution is appropriate +- Create a detailed execution plan before writing any code + +### 3. 
Implementation with Continuous Testing +**After EACH sub-task:** +- [ ] Run unit tests for the component +- [ ] Run integration tests if applicable +- [ ] Verify type checking passes +- [ ] Verify linting passes +- [ ] Keep coverage high (aim for >90%) + +**Both modules must be tested:** +- Backend: `IS_TEST=True uv run pytest` + E2E tests +- Frontend: `npm test` + E2E tests (if applicable) + +### 4. Multi-Agent Review (MANDATORY before considering done) +Before closing an issue, perform deep review from multiple angles: + +| Review Type | Focus Areas | +|-------------|-------------| +| **Code Quality** | Logic errors, edge cases, race conditions, error handling | +| **Security** | OWASP Top 10, input validation, authentication, authorization | +| **Performance** | N+1 queries, memory leaks, inefficient algorithms | +| **Architecture** | Pattern adherence, separation of concerns, extensibility | +| **Testing** | Coverage completeness, test quality, edge case coverage | +| **Documentation** | Code comments, README, API docs, usage examples | + +**No stone unturned. No sloppy results. No unreviewed work.** + +### 5. Final Validation +- [ ] All tests pass (unit, integration, E2E) +- [ ] Type checking passes +- [ ] Linting passes +- [ ] Documentation updated +- [ ] Coverage meets threshold +- [ ] Issue checklist 100% complete +- [ ] Multi-agent review passed + +### When to Use Parallel Agents +Use multiple agents working in parallel when: +- Sub-tasks are independent (no shared state/dependencies) +- Different expertise areas (backend vs frontend) +- Time-critical deliveries with clear boundaries + +Do NOT use parallel agents when: +- Tasks share state or have dependencies +- Sequential testing is required +- Integration points need careful coordination + +--- + ## Quick Reference | Action | Command/Location |