forked from cardosofelipe/fast-next-template
feat(backend): implement MCP client infrastructure (#55)
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from app.api.routes import (
|
||||
auth,
|
||||
events,
|
||||
issues,
|
||||
mcp,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
@@ -31,6 +32,9 @@ api_router.include_router(
|
||||
# SSE events router - no prefix, routes define full paths
|
||||
api_router.include_router(events.router, tags=["Events"])
|
||||
|
||||
# MCP (Model Context Protocol) router
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||
|
||||
# Syndarix domain routers
|
||||
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||
api_router.include_router(
|
||||
|
||||
444
backend/app/api/routes/mcp.py
Normal file
444
backend/app/api/routes/mcp.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) API Endpoints
|
||||
|
||||
Provides REST endpoints for managing MCP server connections
|
||||
and executing tool calls.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.mcp import (
|
||||
MCPCircuitOpenError,
|
||||
MCPClientManager,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
get_mcp_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars
|
||||
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
|
||||
# Type alias for validated server name path parameter
|
||||
ServerNamePath = Annotated[
|
||||
str,
|
||||
Path(
|
||||
description="MCP server name",
|
||||
min_length=1,
|
||||
max_length=64,
|
||||
pattern=r"^[a-zA-Z0-9_-]+$",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ServerInfo(BaseModel):
|
||||
"""Information about an MCP server."""
|
||||
|
||||
name: str = Field(..., description="Server name")
|
||||
url: str = Field(..., description="Server URL")
|
||||
enabled: bool = Field(..., description="Whether server is enabled")
|
||||
timeout: int = Field(..., description="Request timeout in seconds")
|
||||
transport: str = Field(..., description="Transport type (http, stdio, sse)")
|
||||
description: str | None = Field(None, description="Server description")
|
||||
|
||||
|
||||
class ServerListResponse(BaseModel):
|
||||
"""Response containing list of MCP servers."""
|
||||
|
||||
servers: list[ServerInfo]
|
||||
total: int
|
||||
|
||||
|
||||
class ToolInfoResponse(BaseModel):
|
||||
"""Information about an MCP tool."""
|
||||
|
||||
name: str = Field(..., description="Tool name")
|
||||
description: str | None = Field(None, description="Tool description")
|
||||
server_name: str | None = Field(None, description="Server providing the tool")
|
||||
input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input")
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""Response containing list of tools."""
|
||||
|
||||
tools: list[ToolInfoResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class ServerHealthStatus(BaseModel):
|
||||
"""Health status for a server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
"""Response containing health status of all servers."""
|
||||
|
||||
servers: dict[str, ServerHealthStatus]
|
||||
healthy_count: int
|
||||
unhealthy_count: int
|
||||
total: int
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""Request to execute a tool."""
|
||||
|
||||
server: str = Field(..., description="MCP server name")
|
||||
tool: str = Field(..., description="Tool name to execute")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Tool arguments",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
None,
|
||||
description="Optional timeout override in seconds",
|
||||
)
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Response from tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any | None = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
class CircuitBreakerStatus(BaseModel):
|
||||
"""Status of a circuit breaker."""
|
||||
|
||||
server_name: str
|
||||
state: str
|
||||
failure_count: int
|
||||
|
||||
|
||||
class CircuitBreakerListResponse(BaseModel):
|
||||
"""Response containing circuit breaker statuses."""
|
||||
|
||||
circuit_breakers: list[CircuitBreakerStatus]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/servers",
|
||||
response_model=ServerListResponse,
|
||||
summary="List MCP Servers",
|
||||
description="Get list of all registered MCP servers with their configurations.",
|
||||
)
|
||||
async def list_servers(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ServerListResponse:
|
||||
"""List all registered MCP servers."""
|
||||
servers = []
|
||||
|
||||
for name in mcp.list_servers():
|
||||
try:
|
||||
config = mcp.get_server_config(name)
|
||||
servers.append(
|
||||
ServerInfo(
|
||||
name=name,
|
||||
url=config.url,
|
||||
enabled=config.enabled,
|
||||
timeout=config.timeout,
|
||||
transport=config.transport.value,
|
||||
description=config.description,
|
||||
)
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
continue
|
||||
|
||||
return ServerListResponse(
|
||||
servers=servers,
|
||||
total=len(servers),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/servers/{server_name}/tools",
|
||||
response_model=ToolListResponse,
|
||||
summary="List Server Tools",
|
||||
description="Get list of tools available on a specific MCP server.",
|
||||
)
|
||||
async def list_server_tools(
|
||||
server_name: ServerNamePath,
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolListResponse:
|
||||
"""List all tools available on a specific server."""
|
||||
try:
|
||||
tools = await mcp.list_tools(server_name)
|
||||
return ToolListResponse(
|
||||
tools=[
|
||||
ToolInfoResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
server_name=t.server_name,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
total=len(tools),
|
||||
)
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {server_name}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tools",
|
||||
response_model=ToolListResponse,
|
||||
summary="List All Tools",
|
||||
description="Get list of all tools from all MCP servers.",
|
||||
)
|
||||
async def list_all_tools(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolListResponse:
|
||||
"""List all tools from all servers."""
|
||||
tools = await mcp.list_all_tools()
|
||||
return ToolListResponse(
|
||||
tools=[
|
||||
ToolInfoResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
server_name=t.server_name,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
total=len(tools),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
response_model=HealthCheckResponse,
|
||||
summary="Health Check",
|
||||
description="Check health status of all MCP servers.",
|
||||
)
|
||||
async def health_check(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> HealthCheckResponse:
|
||||
"""Perform health check on all MCP servers."""
|
||||
health_results = await mcp.health_check()
|
||||
|
||||
servers = {
|
||||
name: ServerHealthStatus(
|
||||
name=status.name,
|
||||
healthy=status.healthy,
|
||||
state=status.state,
|
||||
url=status.url,
|
||||
error=status.error,
|
||||
tools_count=status.tools_count,
|
||||
)
|
||||
for name, status in health_results.items()
|
||||
}
|
||||
|
||||
healthy_count = sum(1 for s in servers.values() if s.healthy)
|
||||
unhealthy_count = len(servers) - healthy_count
|
||||
|
||||
return HealthCheckResponse(
|
||||
servers=servers,
|
||||
healthy_count=healthy_count,
|
||||
unhealthy_count=unhealthy_count,
|
||||
total=len(servers),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/call",
|
||||
response_model=ToolCallResponse,
|
||||
summary="Execute Tool (Admin Only)",
|
||||
description="Execute a tool on an MCP server. Requires superuser privileges.",
|
||||
)
|
||||
async def call_tool(
|
||||
request: ToolCallRequest,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolCallResponse:
|
||||
"""
|
||||
Execute a tool on an MCP server.
|
||||
|
||||
This endpoint is restricted to superusers for direct tool execution.
|
||||
Normal tool execution should go through agent workflows.
|
||||
"""
|
||||
logger.info(
|
||||
"Tool call by user %s: %s.%s",
|
||||
current_user.id,
|
||||
request.server,
|
||||
request.tool,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await mcp.call_tool(
|
||||
server=request.server,
|
||||
tool=request.tool,
|
||||
args=request.arguments,
|
||||
timeout=request.timeout,
|
||||
)
|
||||
|
||||
return ToolCallResponse(
|
||||
success=result.success,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
tool_name=result.tool_name,
|
||||
server_name=result.server_name,
|
||||
execution_time_ms=result.execution_time_ms,
|
||||
request_id=result.request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Server temporarily unavailable: {e.server_name}",
|
||||
) from e
|
||||
except MCPToolNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tool not found: {e.tool_name}",
|
||||
) from e
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {e.server_name}",
|
||||
) from e
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except MCPConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except MCPToolError as e:
|
||||
# Tool errors are returned in the response, not as HTTP errors
|
||||
return ToolCallResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=e.error_code,
|
||||
tool_name=e.tool_name,
|
||||
server_name=e.server_name,
|
||||
)
|
||||
except MCPError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/circuit-breakers",
|
||||
response_model=CircuitBreakerListResponse,
|
||||
summary="List Circuit Breakers",
|
||||
description="Get status of all circuit breakers.",
|
||||
)
|
||||
async def list_circuit_breakers(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> CircuitBreakerListResponse:
|
||||
"""Get status of all circuit breakers."""
|
||||
status_dict = mcp.get_circuit_breaker_status()
|
||||
|
||||
return CircuitBreakerListResponse(
|
||||
circuit_breakers=[
|
||||
CircuitBreakerStatus(
|
||||
server_name=name,
|
||||
state=info.get("state", "unknown"),
|
||||
failure_count=info.get("failure_count", 0),
|
||||
)
|
||||
for name, info in status_dict.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/circuit-breakers/{server_name}/reset",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Reset Circuit Breaker (Admin Only)",
|
||||
description="Manually reset a circuit breaker for a server.",
|
||||
)
|
||||
async def reset_circuit_breaker(
|
||||
server_name: ServerNamePath,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> None:
|
||||
"""Manually reset a circuit breaker."""
|
||||
logger.info(
|
||||
"Circuit breaker reset by user %s for server %s",
|
||||
current_user.id,
|
||||
server_name,
|
||||
)
|
||||
|
||||
success = await mcp.reset_circuit_breaker(server_name)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No circuit breaker found for server: {server_name}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/servers/{server_name}/reconnect",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Reconnect to Server (Admin Only)",
|
||||
description="Force reconnection to an MCP server.",
|
||||
)
|
||||
async def reconnect_server(
|
||||
server_name: ServerNamePath,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> None:
|
||||
"""Force reconnection to an MCP server."""
|
||||
logger.info(
|
||||
"Reconnect requested by user %s for server %s",
|
||||
current_user.id,
|
||||
server_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await mcp.disconnect(server_name)
|
||||
await mcp.connect(server_name)
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {server_name}",
|
||||
) from e
|
||||
except MCPConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to reconnect: {e}",
|
||||
) from e
|
||||
85
backend/app/services/mcp/__init__.py
Normal file
85
backend/app/services/mcp/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
MCP Client Service Package
|
||||
|
||||
Provides infrastructure for communicating with MCP (Model Context Protocol)
|
||||
servers. This is the foundation for AI agent tool integration.
|
||||
|
||||
Usage:
|
||||
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||
|
||||
# In FastAPI route
|
||||
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||
result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"})
|
||||
|
||||
# Direct usage
|
||||
manager = MCPClientManager()
|
||||
await manager.initialize()
|
||||
result = await manager.call_tool("issues", "create_issue", {...})
|
||||
await manager.shutdown()
|
||||
"""
|
||||
|
||||
from .client_manager import (
|
||||
MCPClientManager,
|
||||
ServerHealth,
|
||||
get_mcp_client,
|
||||
reset_mcp_client,
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from .config import (
|
||||
MCPConfig,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
create_default_config,
|
||||
load_mcp_config,
|
||||
)
|
||||
from .connection import ConnectionPool, ConnectionState, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
MCPValidationError,
|
||||
)
|
||||
from .registry import MCPServerRegistry, ServerCapabilities, get_registry
|
||||
from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
__all__ = [
|
||||
# Main facade
|
||||
"MCPClientManager",
|
||||
"get_mcp_client",
|
||||
"shutdown_mcp_client",
|
||||
"reset_mcp_client",
|
||||
"ServerHealth",
|
||||
# Configuration
|
||||
"MCPConfig",
|
||||
"MCPServerConfig",
|
||||
"TransportType",
|
||||
"load_mcp_config",
|
||||
"create_default_config",
|
||||
# Registry
|
||||
"MCPServerRegistry",
|
||||
"ServerCapabilities",
|
||||
"get_registry",
|
||||
# Connection
|
||||
"ConnectionPool",
|
||||
"ConnectionState",
|
||||
"MCPConnection",
|
||||
# Routing
|
||||
"ToolRouter",
|
||||
"ToolInfo",
|
||||
"ToolResult",
|
||||
"AsyncCircuitBreaker",
|
||||
"CircuitState",
|
||||
# Exceptions
|
||||
"MCPError",
|
||||
"MCPConnectionError",
|
||||
"MCPTimeoutError",
|
||||
"MCPToolError",
|
||||
"MCPServerNotFoundError",
|
||||
"MCPToolNotFoundError",
|
||||
"MCPCircuitOpenError",
|
||||
"MCPValidationError",
|
||||
]
|
||||
417
backend/app/services/mcp/client_manager.py
Normal file
417
backend/app/services/mcp/client_manager.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
MCP Client Manager
|
||||
|
||||
Main facade for all MCP operations. Manages server connections,
|
||||
tool discovery, and provides a unified interface for tool calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .connection import ConnectionPool, ConnectionState
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
from .registry import MCPServerRegistry, get_registry
|
||||
from .routing import ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerHealth:
|
||||
"""Health status for an MCP server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"healthy": self.healthy,
|
||||
"state": self.state,
|
||||
"url": self.url,
|
||||
"error": self.error,
|
||||
"tools_count": self.tools_count,
|
||||
}
|
||||
|
||||
|
||||
class MCPClientManager:
|
||||
"""
|
||||
Central manager for all MCP client operations.
|
||||
|
||||
Provides a unified interface for:
|
||||
- Connecting to MCP servers
|
||||
- Discovering and calling tools
|
||||
- Health monitoring
|
||||
- Connection lifecycle management
|
||||
|
||||
This is the main entry point for MCP operations in the application.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MCPConfig | None = None,
|
||||
registry: MCPServerRegistry | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Args:
|
||||
config: Optional MCP configuration. If None, loads from default.
|
||||
registry: Optional registry instance. If None, uses singleton.
|
||||
"""
|
||||
self._registry = registry or get_registry()
|
||||
self._pool = ConnectionPool()
|
||||
self._router: ToolRouter | None = None
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load configuration if provided
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the manager is initialized."""
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Loads configuration, creates connections, and discovers tools.
|
||||
|
||||
Args:
|
||||
config: Optional configuration to load
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
logger.warning("MCPClientManager already initialized")
|
||||
return
|
||||
|
||||
logger.info("Initializing MCP Client Manager")
|
||||
|
||||
# Load configuration
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
elif len(self._registry.list_servers()) == 0:
|
||||
# Try to load from default location
|
||||
self._registry.load_config(load_mcp_config())
|
||||
|
||||
# Create router
|
||||
self._router = ToolRouter(self._registry, self._pool)
|
||||
|
||||
# Connect to all enabled servers
|
||||
await self._connect_all_servers()
|
||||
|
||||
# Discover tools from all servers
|
||||
if self._router:
|
||||
await self._router.discover_tools()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
"MCP Client Manager initialized with %d servers",
|
||||
len(self._registry.list_enabled_servers()),
|
||||
)
|
||||
|
||||
async def _connect_all_servers(self) -> None:
|
||||
"""Connect to all enabled MCP servers."""
|
||||
enabled_servers = self._registry.get_enabled_configs()
|
||||
|
||||
for name, config in enabled_servers.items():
|
||||
try:
|
||||
await self._pool.get_connection(name, config)
|
||||
logger.info("Connected to MCP server: %s", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the MCP client manager.
|
||||
|
||||
Closes all connections and cleans up resources.
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Shutting down MCP Client Manager")
|
||||
|
||||
await self._pool.close_all()
|
||||
self._initialized = False
|
||||
|
||||
logger.info("MCP Client Manager shutdown complete")
|
||||
|
||||
async def connect(self, server_name: str) -> None:
|
||||
"""
|
||||
Connect to a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to connect to
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._registry.get(server_name)
|
||||
await self._pool.get_connection(server_name, config)
|
||||
logger.info("Connected to MCP server: %s", server_name)
|
||||
|
||||
async def disconnect(self, server_name: str) -> None:
|
||||
"""
|
||||
Disconnect from a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to disconnect from
|
||||
"""
|
||||
await self._pool.close_connection(server_name)
|
||||
logger.info("Disconnected from MCP server: %s", server_name)
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
"""Disconnect from all MCP servers."""
|
||||
await self._pool.close_all()
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server: str,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific MCP server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.call_tool(
|
||||
server_name=server,
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server automatically.
|
||||
|
||||
Args:
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.route_tool(
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_tools(self, server: str) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools available on a specific server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
capabilities = await self._registry.get_capabilities(server)
|
||||
return [
|
||||
ToolInfo(
|
||||
name=t.get("name", ""),
|
||||
description=t.get("description"),
|
||||
server_name=server,
|
||||
input_schema=t.get("input_schema"),
|
||||
)
|
||||
for t in capabilities.tools
|
||||
]
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.list_all_tools()
|
||||
|
||||
async def health_check(self) -> dict[str, ServerHealth]:
|
||||
"""
|
||||
Perform health check on all MCP servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
results: dict[str, ServerHealth] = {}
|
||||
pool_status = self._pool.get_status()
|
||||
pool_health = await self._pool.health_check_all()
|
||||
|
||||
for server_name in self._registry.list_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
status = pool_status.get(server_name, {})
|
||||
healthy = pool_health.get(server_name, False)
|
||||
|
||||
capabilities = self._registry.get_cached_capabilities(server_name)
|
||||
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=healthy,
|
||||
state=status.get("state", ConnectionState.DISCONNECTED.value),
|
||||
url=config.url,
|
||||
tools_count=len(capabilities.tools),
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=False,
|
||||
state=ConnectionState.ERROR.value,
|
||||
url="unknown",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._registry.list_servers()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return self._registry.list_enabled_servers()
|
||||
|
||||
def get_server_config(self, server_name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get configuration for a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
return self._registry.get(server_name)
|
||||
|
||||
def register_server(
|
||||
self,
|
||||
name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new MCP server at runtime.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._registry.register(name, config)
|
||||
|
||||
def unregister_server(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
return self._registry.unregister(name)
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
if self._router is None:
|
||||
return {}
|
||||
return self._router.get_circuit_breaker_status()
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Reset a circuit breaker for a server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
if self._router is None:
|
||||
return False
|
||||
return await self._router.reset_circuit_breaker(server_name)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_manager_instance: MCPClientManager | None = None
|
||||
_manager_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_mcp_client() -> MCPClientManager:
|
||||
"""
|
||||
Get the global MCP client manager instance.
|
||||
|
||||
This is the main dependency injection point for FastAPI.
|
||||
Uses proper locking to avoid race conditions in async contexts.
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock for the entire check-and-create operation to avoid race conditions
|
||||
async with _manager_lock:
|
||||
if _manager_instance is None:
|
||||
_manager_instance = MCPClientManager()
|
||||
await _manager_instance.initialize()
|
||||
|
||||
return _manager_instance
|
||||
|
||||
|
||||
async def shutdown_mcp_client() -> None:
|
||||
"""Shutdown the global MCP client manager."""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock to prevent race with get_mcp_client()
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
await _manager_instance.shutdown()
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
def reset_mcp_client() -> None:
|
||||
"""Reset the global MCP client manager (for testing)."""
|
||||
global _manager_instance
|
||||
_manager_instance = None
|
||||
234
backend/app/services/mcp/config.py
Normal file
234
backend/app/services/mcp/config.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
MCP Configuration System
|
||||
|
||||
Pydantic models for MCP server configuration with YAML file loading
|
||||
and environment variable overrides.
|
||||
"""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class TransportType(str, Enum):
|
||||
"""Supported MCP transport types."""
|
||||
|
||||
HTTP = "http"
|
||||
STDIO = "stdio"
|
||||
SSE = "sse"
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server."""
|
||||
|
||||
url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
|
||||
transport: TransportType = Field(
|
||||
default=TransportType.HTTP,
|
||||
description="Transport protocol to use",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=600,
|
||||
description="Request timeout in seconds",
|
||||
)
|
||||
retry_attempts: int = Field(
|
||||
default=3,
|
||||
ge=0,
|
||||
le=10,
|
||||
description="Number of retry attempts on failure",
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
ge=0.1,
|
||||
le=60.0,
|
||||
description="Initial delay between retries in seconds",
|
||||
)
|
||||
retry_max_delay: float = Field(
|
||||
default=30.0,
|
||||
ge=1.0,
|
||||
le=300.0,
|
||||
description="Maximum delay between retries in seconds",
|
||||
)
|
||||
circuit_breaker_threshold: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Number of failures before opening circuit",
|
||||
)
|
||||
circuit_breaker_timeout: float = Field(
|
||||
default=30.0,
|
||||
ge=5.0,
|
||||
le=300.0,
|
||||
description="Seconds to wait before attempting to close circuit",
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether this server is enabled",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable description of the server",
|
||||
)
|
||||
|
||||
@field_validator("url", mode="before")
|
||||
@classmethod
|
||||
def expand_env_vars(cls, v: str) -> str:
|
||||
"""Expand environment variables in URL using ${VAR:-default} syntax."""
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
|
||||
result = v
|
||||
# Find all ${VAR} or ${VAR:-default} patterns
|
||||
import re
|
||||
|
||||
pattern = r"\$\{([^}]+)\}"
|
||||
matches = re.findall(pattern, v)
|
||||
|
||||
for match in matches:
|
||||
if ":-" in match:
|
||||
var_name, default = match.split(":-", 1)
|
||||
else:
|
||||
var_name, default = match, ""
|
||||
|
||||
env_value = os.environ.get(var_name.strip(), default)
|
||||
result = result.replace(f"${{{match}}}", env_value)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Root configuration for all MCP servers."""
|
||||
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of server names to their configurations",
|
||||
)
|
||||
|
||||
# Global defaults
|
||||
default_timeout: int = Field(
|
||||
default=30,
|
||||
description="Default timeout for all servers",
|
||||
)
|
||||
default_retry_attempts: int = Field(
|
||||
default=3,
|
||||
description="Default retry attempts for all servers",
|
||||
)
|
||||
connection_pool_size: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum connections per server",
|
||||
)
|
||||
health_check_interval: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=300,
|
||||
description="Seconds between health checks",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> "MCPConfig":
|
||||
"""Load configuration from a YAML file."""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||||
|
||||
with path.open("r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MCPConfig":
|
||||
"""Load configuration from a dictionary."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
def get_server(self, name: str) -> MCPServerConfig | None:
|
||||
"""Get a server configuration by name."""
|
||||
return self.mcp_servers.get(name)
|
||||
|
||||
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return {
|
||||
name: config
|
||||
for name, config in self.mcp_servers.items()
|
||||
if config.enabled
|
||||
}
|
||||
|
||||
def list_server_names(self) -> list[str]:
|
||||
"""Get list of all configured server names."""
|
||||
return list(self.mcp_servers.keys())
|
||||
|
||||
|
||||
# Default configuration path
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "mcp_servers.yaml"
|
||||
|
||||
|
||||
def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||
"""
|
||||
Load MCP configuration from file or environment.
|
||||
|
||||
Priority:
|
||||
1. Explicit path parameter
|
||||
2. MCP_CONFIG_PATH environment variable
|
||||
3. Default path (backend/mcp_servers.yaml)
|
||||
4. Empty config if no file exists
|
||||
"""
|
||||
if path is None:
|
||||
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
# Return empty config if no file exists (allows runtime registration)
|
||||
return MCPConfig()
|
||||
|
||||
return MCPConfig.from_yaml(path)
|
||||
|
||||
|
||||
def create_default_config() -> MCPConfig:
|
||||
"""
|
||||
Create a default MCP configuration with standard servers.
|
||||
|
||||
This is useful for development and as a template.
|
||||
"""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"llm-gateway": MCPServerConfig(
|
||||
url="${LLM_GATEWAY_URL:-http://localhost:8001}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=60,
|
||||
description="LLM Gateway for multi-provider AI interactions",
|
||||
),
|
||||
"knowledge-base": MCPServerConfig(
|
||||
url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Knowledge Base for RAG and document retrieval",
|
||||
),
|
||||
"git-ops": MCPServerConfig(
|
||||
url="${GIT_OPS_URL:-http://localhost:8003}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=120,
|
||||
description="Git Operations for repository management",
|
||||
),
|
||||
"issues": MCPServerConfig(
|
||||
url="${ISSUES_URL:-http://localhost:8004}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Issue Tracker for Gitea/GitHub/GitLab",
|
||||
),
|
||||
},
|
||||
default_timeout=30,
|
||||
default_retry_attempts=3,
|
||||
connection_pool_size=10,
|
||||
health_check_interval=30,
|
||||
)
|
||||
435
backend/app/services/mcp/connection.py
Normal file
435
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
MCP Connection Management
|
||||
|
||||
Handles connection lifecycle, pooling, and automatic reconnection
|
||||
for MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import MCPServerConfig, TransportType
|
||||
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionState(str, Enum):
|
||||
"""Connection state enumeration."""
|
||||
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
RECONNECTING = "reconnecting"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MCPConnection:
|
||||
"""
|
||||
Manages a single connection to an MCP server.
|
||||
|
||||
Handles connection lifecycle, health checking, and automatic reconnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
config: Server configuration
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.config = config
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._last_activity: float | None = None
|
||||
self._connection_attempts = 0
|
||||
self._last_error: Exception | None = None
|
||||
|
||||
# Reconnection settings
|
||||
self._base_delay = config.retry_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._max_attempts = config.retry_attempts
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
"""Get current connection state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connection is established."""
|
||||
return self._state == ConnectionState.CONNECTED
|
||||
|
||||
@property
|
||||
def last_error(self) -> Exception | None:
|
||||
"""Get the last error that occurred."""
|
||||
return self._last_error
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the MCP server.
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If connection fails after all retries
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._state == ConnectionState.CONNECTED:
|
||||
return
|
||||
|
||||
self._state = ConnectionState.CONNECTING
|
||||
self._connection_attempts = 0
|
||||
self._last_error = None
|
||||
|
||||
while self._connection_attempts < self._max_attempts:
|
||||
try:
|
||||
await self._do_connect()
|
||||
self._state = ConnectionState.CONNECTED
|
||||
self._last_activity = time.time()
|
||||
logger.info(
|
||||
"Connected to MCP server: %s at %s",
|
||||
self.server_name,
|
||||
self.config.url,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._connection_attempts += 1
|
||||
self._last_error = e
|
||||
logger.warning(
|
||||
"Connection attempt %d/%d failed for %s: %s",
|
||||
self._connection_attempts,
|
||||
self._max_attempts,
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
if self._connection_attempts < self._max_attempts:
|
||||
delay = self._calculate_backoff_delay()
|
||||
logger.debug(
|
||||
"Retrying connection to %s in %.1fs",
|
||||
self.server_name,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
self._state = ConnectionState.ERROR
|
||||
raise MCPConnectionError(
|
||||
f"Failed to connect after {self._max_attempts} attempts",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=self._last_error,
|
||||
)
|
||||
|
||||
async def _do_connect(self) -> None:
|
||||
"""Perform the actual connection (transport-specific)."""
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.config.url,
|
||||
timeout=httpx.Timeout(self.config.timeout),
|
||||
headers={
|
||||
"User-Agent": "Syndarix-MCP-Client/1.0",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
# Verify connectivity with a simple request
|
||||
try:
|
||||
# Try to hit the MCP capabilities endpoint
|
||||
response = await self._client.get("/mcp/capabilities")
|
||||
if response.status_code not in (200, 404):
|
||||
# 404 is acceptable - server might not have capabilities endpoint
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
except httpx.ConnectError as e:
|
||||
raise MCPConnectionError(
|
||||
"Failed to connect to server",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
)
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
f"Transport {self.config.transport} not yet implemented"
|
||||
)
|
||||
|
||||
def _calculate_backoff_delay(self) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = self._base_delay * (2 ** (self._connection_attempts - 1))
|
||||
delay = min(delay, self._max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return delay + jitter
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
async with self._lock:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error closing connection to %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
self._client = None
|
||||
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
logger.info("Disconnected from MCP server: %s", self.server_name)
|
||||
|
||||
async def reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server."""
|
||||
async with self._lock:
|
||||
self._state = ConnectionState.RECONNECTING
|
||||
await self.disconnect()
|
||||
await self.connect()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Perform a health check on the connection.
|
||||
|
||||
Returns:
|
||||
True if connection is healthy
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
response = await self._client.get(
|
||||
"/health",
|
||||
timeout=5.0,
|
||||
)
|
||||
return response.status_code == 200
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Health check failed for %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute an HTTP request to the MCP server.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: Request path
|
||||
data: Optional request body
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Response data
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If not connected
|
||||
MCPTimeoutError: If request times out
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
raise MCPConnectionError(
|
||||
"Not connected to server",
|
||||
server_name=self.server_name,
|
||||
)
|
||||
|
||||
effective_timeout = timeout or self.config.timeout
|
||||
|
||||
try:
|
||||
if method.upper() == "GET":
|
||||
response = await self._client.get(
|
||||
path,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = await self._client.post(
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
else:
|
||||
response = await self._client.request(
|
||||
method.upper(),
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
|
||||
self._last_activity = time.time()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name=self.server_name,
|
||||
timeout_seconds=effective_timeout,
|
||||
operation=f"{method} {path}",
|
||||
) from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise MCPConnectionError(
|
||||
f"HTTP error: {e.response.status_code}",
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""
|
||||
Pool of connections to MCP servers.
|
||||
|
||||
Manages connection lifecycle and provides connection reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, max_connections_per_server: int = 10) -> None:
|
||||
"""
|
||||
Initialize connection pool.
|
||||
|
||||
Args:
|
||||
max_connections_per_server: Maximum connections per server
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> MCPConnection:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name not in self._connections:
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
self._connections[server_name] = connection
|
||||
|
||||
connection = self._connections[server_name]
|
||||
|
||||
# Reconnect if not connected
|
||||
if not connection.is_connected:
|
||||
await connection.connect()
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Release a connection (currently just tracks usage).
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
# For now, we keep connections alive
|
||||
# Future: implement connection reaping for idle connections
|
||||
|
||||
async def close_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Close and remove a connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
async with self._lock:
|
||||
for connection in self._connections.values():
|
||||
try:
|
||||
await connection.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
"""
|
||||
Perform health check on all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
results = {}
|
||||
for name, connection in self._connections.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
def get_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get status of all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to status info
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
"state": conn.state.value,
|
||||
"is_connected": conn.is_connected,
|
||||
"url": conn.config.url,
|
||||
}
|
||||
for name, conn in self._connections.items()
|
||||
}
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncGenerator[MCPConnection, None]:
|
||||
"""
|
||||
Context manager for getting a connection.
|
||||
|
||||
Usage:
|
||||
async with pool.connection("server", config) as conn:
|
||||
result = await conn.execute_request("POST", "/tool", data)
|
||||
"""
|
||||
conn = await self.get_connection(server_name, config)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await self.release_connection(server_name)
|
||||
201
backend/app/services/mcp/exceptions.py
Normal file
201
backend/app/services/mcp/exceptions.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
MCP Exception Classes
|
||||
|
||||
Custom exceptions for MCP client operations with detailed error context.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""Base exception for all MCP-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.server_name = server_name
|
||||
self.details = details or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = [self.message]
|
||||
if self.server_name:
|
||||
parts.append(f"server={self.server_name}")
|
||||
if self.details:
|
||||
parts.append(f"details={self.details}")
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
"""Raised when connection to an MCP server fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
url: str | None = None,
|
||||
cause: Exception | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.url = url
|
||||
self.cause = cause
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.url:
|
||||
base = f"{base} | url={self.url}"
|
||||
if self.cause:
|
||||
base = f"{base} | cause={type(self.cause).__name__}: {self.cause}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPTimeoutError(MCPError):
|
||||
"""Raised when an MCP operation times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
operation: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.operation = operation
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.timeout_seconds is not None:
|
||||
base = f"{base} | timeout={self.timeout_seconds}s"
|
||||
if self.operation:
|
||||
base = f"{base} | operation={self.operation}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolError(MCPError):
|
||||
"""Raised when a tool execution fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
error_code: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.tool_args = tool_args
|
||||
self.error_code = error_code
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.error_code:
|
||||
base = f"{base} | error_code={self.error_code}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPServerNotFoundError(MCPError):
|
||||
"""Raised when a requested MCP server is not registered."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
available_servers: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"MCP server not found: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.available_servers = available_servers or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_servers:
|
||||
base = f"{base} | available={self.available_servers}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolNotFoundError(MCPError):
|
||||
"""Raised when a requested tool is not found on any server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
available_tools: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Tool not found: {tool_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.available_tools = available_tools or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_tools:
|
||||
base = f"{base} | available_tools={self.available_tools[:5]}..."
|
||||
return base
|
||||
|
||||
|
||||
class MCPCircuitOpenError(MCPError):
|
||||
"""Raised when a circuit breaker is open (server temporarily unavailable)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
failure_count: int | None = None,
|
||||
reset_timeout: float | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Circuit breaker open for server: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.failure_count = failure_count
|
||||
self.reset_timeout = reset_timeout
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.failure_count is not None:
|
||||
base = f"{base} | failures={self.failure_count}"
|
||||
if self.reset_timeout is not None:
|
||||
base = f"{base} | reset_in={self.reset_timeout}s"
|
||||
return base
|
||||
|
||||
|
||||
class MCPValidationError(MCPError):
|
||||
"""Raised when tool arguments fail validation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
tool_name: str | None = None,
|
||||
field_errors: dict[str, str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.field_errors = field_errors or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.field_errors:
|
||||
base = f"{base} | fields={list(self.field_errors.keys())}"
|
||||
return base
|
||||
305
backend/app/services/mcp/registry.py
Normal file
305
backend/app/services/mcp/registry.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
MCP Server Registry
|
||||
|
||||
Thread-safe singleton registry for managing MCP server configurations
|
||||
and their capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerCapabilities:
|
||||
"""Cached capabilities for an MCP server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self.tools = tools or []
|
||||
self.resources = resources or []
|
||||
self.prompts = prompts or []
|
||||
self._loaded = False
|
||||
self._load_time: float | None = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if capabilities have been loaded."""
|
||||
return self._loaded
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of tool names."""
|
||||
return [t.get("name", "") for t in self.tools if t.get("name")]
|
||||
|
||||
def mark_loaded(self) -> None:
|
||||
"""Mark capabilities as loaded."""
|
||||
import time
|
||||
|
||||
self._loaded = True
|
||||
self._load_time = time.time()
|
||||
|
||||
|
||||
class MCPServerRegistry:
|
||||
"""
|
||||
Thread-safe singleton registry for MCP servers.
|
||||
|
||||
Manages server configurations and caches their capabilities.
|
||||
"""
|
||||
|
||||
_instance: "MCPServerRegistry | None" = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls) -> "MCPServerRegistry":
|
||||
"""Ensure singleton pattern."""
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize registry (only runs once due to singleton)."""
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
|
||||
self._config: MCPConfig = MCPConfig()
|
||||
self._capabilities: dict[str, ServerCapabilities] = {}
|
||||
self._capabilities_lock = asyncio.Lock()
|
||||
self._initialized = True
|
||||
|
||||
logger.info("MCP Server Registry initialized")
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPServerRegistry":
|
||||
"""Get the singleton registry instance."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the singleton (for testing)."""
|
||||
with cls._lock:
|
||||
cls._instance = None
|
||||
|
||||
def load_config(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Load configuration into the registry.
|
||||
|
||||
Args:
|
||||
config: Optional config to load. If None, loads from default path.
|
||||
"""
|
||||
if config is None:
|
||||
config = load_mcp_config()
|
||||
|
||||
self._config = config
|
||||
self._capabilities.clear()
|
||||
|
||||
logger.info(
|
||||
"Loaded MCP configuration with %d servers",
|
||||
len(config.mcp_servers),
|
||||
)
|
||||
for name in config.list_server_names():
|
||||
logger.debug("Registered MCP server: %s", name)
|
||||
|
||||
def register(self, name: str, config: MCPServerConfig) -> None:
|
||||
"""
|
||||
Register a new MCP server.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._config.mcp_servers[name] = config
|
||||
self._capabilities.pop(name, None) # Clear any cached capabilities
|
||||
|
||||
logger.info("Registered MCP server: %s at %s", name, config.url)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
if name in self._config.mcp_servers:
|
||||
del self._config.mcp_servers[name]
|
||||
self._capabilities.pop(name, None)
|
||||
logger.info("Unregistered MCP server: %s", name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get a server configuration by name.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._config.get_server(name)
|
||||
if config is None:
|
||||
raise MCPServerNotFoundError(
|
||||
server_name=name,
|
||||
available_servers=self.list_servers(),
|
||||
)
|
||||
return config
|
||||
|
||||
def get_or_none(self, name: str) -> MCPServerConfig | None:
|
||||
"""
|
||||
Get a server configuration by name, or None if not found.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration or None
|
||||
"""
|
||||
return self._config.get_server(name)
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._config.list_server_names()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return list(self._config.get_enabled_servers().keys())
|
||||
|
||||
def get_all_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all server configurations."""
|
||||
return dict(self._config.mcp_servers)
|
||||
|
||||
def get_enabled_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return self._config.get_enabled_servers()
|
||||
|
||||
async def get_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
force_refresh: bool = False,
|
||||
) -> ServerCapabilities:
|
||||
"""
|
||||
Get capabilities for a server (lazy-loaded and cached).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
force_refresh: If True, refresh cached capabilities
|
||||
|
||||
Returns:
|
||||
Server capabilities
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
# Verify server exists
|
||||
self.get(name)
|
||||
|
||||
async with self._capabilities_lock:
|
||||
if name not in self._capabilities or force_refresh:
|
||||
# Will be populated by connection manager when connecting
|
||||
self._capabilities[name] = ServerCapabilities()
|
||||
|
||||
return self._capabilities[name]
|
||||
|
||||
def set_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set capabilities for a server (called by connection manager).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
tools: List of tool definitions
|
||||
resources: List of resource definitions
|
||||
prompts: List of prompt definitions
|
||||
"""
|
||||
capabilities = ServerCapabilities(
|
||||
tools=tools,
|
||||
resources=resources,
|
||||
prompts=prompts,
|
||||
)
|
||||
capabilities.mark_loaded()
|
||||
self._capabilities[name] = capabilities
|
||||
|
||||
logger.debug(
|
||||
"Updated capabilities for %s: %d tools, %d resources, %d prompts",
|
||||
name,
|
||||
len(capabilities.tools),
|
||||
len(capabilities.resources),
|
||||
len(capabilities.prompts),
|
||||
)
|
||||
|
||||
def get_cached_capabilities(self, name: str) -> ServerCapabilities:
|
||||
"""
|
||||
Get cached capabilities without async loading.
|
||||
|
||||
Use this for synchronous access when you only need
|
||||
cached values (e.g., for health check responses).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Cached capabilities or empty ServerCapabilities
|
||||
"""
|
||||
return self._capabilities.get(name, ServerCapabilities())
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to find
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
for name, caps in self._capabilities.items():
|
||||
if tool_name in caps.tool_names:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_all_tools(self) -> dict[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
Get all tools from all servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to list of tool definitions
|
||||
"""
|
||||
return {
|
||||
name: caps.tools
|
||||
for name, caps in self._capabilities.items()
|
||||
if caps.is_loaded
|
||||
}
|
||||
|
||||
@property
|
||||
def global_config(self) -> MCPConfig:
|
||||
"""Get the global MCP configuration."""
|
||||
return self._config
|
||||
|
||||
|
||||
# Module-level convenience function
|
||||
def get_registry() -> MCPServerRegistry:
|
||||
"""Get the global MCP server registry instance."""
|
||||
return MCPServerRegistry.get_instance()
|
||||
619
backend/app/services/mcp/routing.py
Normal file
619
backend/app/services/mcp/routing.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
MCP Tool Call Routing
|
||||
|
||||
Routes tool calls to appropriate servers with retry logic,
|
||||
circuit breakers, and request/response serialization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPServerConfig
|
||||
from .connection import ConnectionPool, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from .registry import MCPServerRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
|
||||
CLOSED = "closed"
|
||||
OPEN = "open"
|
||||
HALF_OPEN = "half-open"
|
||||
|
||||
|
||||
class AsyncCircuitBreaker:
|
||||
"""
|
||||
Async-compatible circuit breaker implementation.
|
||||
|
||||
Unlike pybreaker which wraps sync functions, this implementation
|
||||
provides explicit success/failure tracking for async code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fail_max: int = 5,
|
||||
reset_timeout: float = 30.0,
|
||||
name: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
fail_max: Maximum failures before opening circuit
|
||||
reset_timeout: Seconds to wait before trying again
|
||||
name: Name for logging
|
||||
"""
|
||||
self.fail_max = fail_max
|
||||
self.reset_timeout = reset_timeout
|
||||
self.name = name
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time: float | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_state(self) -> str:
|
||||
"""Get current state as string."""
|
||||
# Check if we should transition from OPEN to HALF_OPEN
|
||||
if self._state == CircuitState.OPEN:
|
||||
if self._should_try_reset():
|
||||
return CircuitState.HALF_OPEN.value
|
||||
return self._state.value
|
||||
|
||||
@property
|
||||
def fail_counter(self) -> int:
|
||||
"""Get current failure count."""
|
||||
return self._fail_counter
|
||||
|
||||
def _should_try_reset(self) -> bool:
|
||||
"""Check if enough time has passed to try resetting."""
|
||||
if self._last_failure_time is None:
|
||||
return True
|
||||
return (time.time() - self._last_failure_time) >= self.reset_timeout
|
||||
|
||||
async def success(self) -> None:
|
||||
"""Record a successful call."""
|
||||
async with self._lock:
|
||||
self._fail_counter = 0
|
||||
self._state = CircuitState.CLOSED
|
||||
self._last_failure_time = None
|
||||
|
||||
async def failure(self) -> None:
|
||||
"""Record a failed call."""
|
||||
async with self._lock:
|
||||
self._fail_counter += 1
|
||||
self._last_failure_time = time.time()
|
||||
|
||||
if self._fail_counter >= self.fail_max:
|
||||
self._state = CircuitState.OPEN
|
||||
logger.warning(
|
||||
"Circuit breaker %s opened after %d failures",
|
||||
self.name,
|
||||
self._fail_counter,
|
||||
)
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (not allowing calls)."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
return not self._should_try_reset()
|
||||
return False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Manually reset the circuit breaker."""
|
||||
async with self._lock:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInfo:
|
||||
"""Information about an available tool."""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
server_name: str | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"server_name": self.server_name,
|
||||
"input_schema": self.input_schema,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of a tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"data": self.data,
|
||||
"error": self.error,
|
||||
"error_code": self.error_code,
|
||||
"tool_name": self.tool_name,
|
||||
"server_name": self.server_name,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"request_id": self.request_id,
|
||||
}
|
||||
|
||||
|
||||
class ToolRouter:
|
||||
"""
|
||||
Routes tool calls to the appropriate MCP server.
|
||||
|
||||
Features:
|
||||
- Tool name to server mapping
|
||||
- Retry logic with exponential backoff
|
||||
- Circuit breaker pattern for fault tolerance
|
||||
- Request/response serialization
|
||||
- Execution timing and metrics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry: MCPServerRegistry,
|
||||
connection_pool: ConnectionPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the tool router.
|
||||
|
||||
Args:
|
||||
registry: MCP server registry
|
||||
connection_pool: Connection pool for servers
|
||||
"""
|
||||
self._registry = registry
|
||||
self._pool = connection_pool
|
||||
self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
|
||||
self._tool_to_server: dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _get_circuit_breaker(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a server."""
|
||||
if server_name not in self._circuit_breakers:
|
||||
self._circuit_breakers[server_name] = AsyncCircuitBreaker(
|
||||
fail_max=config.circuit_breaker_threshold,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
name=f"mcp-{server_name}",
|
||||
)
|
||||
return self._circuit_breakers[server_name]
|
||||
|
||||
async def register_tool_mapping(
|
||||
self,
|
||||
tool_name: str,
|
||||
server_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Register a mapping from tool name to server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
server_name: Name of the server providing the tool
|
||||
"""
|
||||
async with self._lock:
|
||||
self._tool_to_server[tool_name] = server_name
|
||||
logger.debug("Registered tool %s -> server %s", tool_name, server_name)
|
||||
|
||||
async def discover_tools(self) -> None:
|
||||
"""
|
||||
Discover all tools from registered servers and build mappings.
|
||||
"""
|
||||
for server_name in self._registry.list_enabled_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Fetch tools from server
|
||||
tools = await self._fetch_tools_from_server(connection)
|
||||
|
||||
# Update registry with capabilities
|
||||
self._registry.set_capabilities(
|
||||
server_name,
|
||||
tools=[t.to_dict() for t in tools],
|
||||
)
|
||||
|
||||
# Update tool mappings
|
||||
for tool in tools:
|
||||
await self.register_tool_mapping(tool.name, server_name)
|
||||
|
||||
logger.info(
|
||||
"Discovered %d tools from server %s",
|
||||
len(tools),
|
||||
server_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to discover tools from %s: %s",
|
||||
server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
async def _fetch_tools_from_server(
|
||||
self,
|
||||
connection: MCPConnection,
|
||||
) -> list[ToolInfo]:
|
||||
"""Fetch available tools from an MCP server."""
|
||||
try:
|
||||
response = await connection.execute_request(
|
||||
"GET",
|
||||
"/mcp/tools",
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_data in response.get("tools", []):
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=connection.server_name,
|
||||
input_schema=tool_data.get("inputSchema"),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error fetching tools from %s: %s",
|
||||
connection.server_name,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
return self._tool_to_server.get(tool_name)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
start_time = time.time()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
logger.debug(
|
||||
"Tool call [%s]: %s.%s with args %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
circuit_breaker = self._get_circuit_breaker(server_name, config)
|
||||
|
||||
# Check circuit breaker state
|
||||
if circuit_breaker.is_open():
|
||||
raise MCPCircuitOpenError(
|
||||
server_name=server_name,
|
||||
failure_count=circuit_breaker.fail_counter,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
)
|
||||
|
||||
# Execute with retry logic
|
||||
result = await self._execute_with_retry(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments or {},
|
||||
timeout=timeout,
|
||||
circuit_breaker=circuit_breaker,
|
||||
)
|
||||
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
data=result,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPError as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
"Tool call failed [%s]: %s.%s - %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=type(e).__name__,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.exception(
|
||||
"Unexpected error in tool call [%s]: %s.%s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code="UnexpectedError",
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
async def _execute_with_retry(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
circuit_breaker: AsyncCircuitBreaker,
|
||||
) -> Any:
|
||||
"""Execute tool call with retry logic."""
|
||||
last_error: Exception | None = None
|
||||
attempts = 0
|
||||
max_attempts = config.retry_attempts + 1 # +1 for initial attempt
|
||||
|
||||
while attempts < max_attempts:
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# Use circuit breaker to track failures
|
||||
result = await self._execute_tool_call(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Success - record it
|
||||
await circuit_breaker.success()
|
||||
return result
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPTimeoutError:
|
||||
# Timeout - don't retry
|
||||
await circuit_breaker.failure()
|
||||
raise
|
||||
except MCPToolError:
|
||||
# Tool-level error - don't retry (user error)
|
||||
raise
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
await circuit_breaker.failure()
|
||||
|
||||
if attempts < max_attempts:
|
||||
delay = self._calculate_retry_delay(attempts, config)
|
||||
logger.warning(
|
||||
"Tool call attempt %d/%d failed for %s.%s: %s. "
|
||||
"Retrying in %.1fs",
|
||||
attempts,
|
||||
max_attempts,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
raise MCPToolError(
|
||||
f"Tool call failed after {max_attempts} attempts",
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
details={"last_error": str(last_error)},
|
||||
)
|
||||
|
||||
def _calculate_retry_delay(
|
||||
self,
|
||||
attempt: int,
|
||||
config: MCPServerConfig,
|
||||
) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = config.retry_delay * (2 ** (attempt - 1))
|
||||
delay = min(delay, config.retry_max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return max(0.1, delay + jitter)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a single tool call."""
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Build MCP tool call request
|
||||
request_body = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
response = await connection.execute_request(
|
||||
method="POST",
|
||||
path="/mcp",
|
||||
data=request_body,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Handle JSON-RPC response
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPToolError(
|
||||
error.get("message", "Tool execution failed"),
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
error_code=str(error.get("code", "UNKNOWN")),
|
||||
)
|
||||
|
||||
return response.get("result")
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server.
|
||||
|
||||
Automatically discovers which server provides the tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
MCPToolNotFoundError: If no server provides the tool
|
||||
"""
|
||||
server_name = self.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
# Try to find from registry
|
||||
server_name = self._registry.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
raise MCPToolNotFoundError(
|
||||
tool_name=tool_name,
|
||||
available_tools=list(self._tool_to_server.keys()),
|
||||
)
|
||||
|
||||
return await self.call_tool(
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
Get all available tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
tools = []
|
||||
all_server_tools = self._registry.get_all_tools()
|
||||
|
||||
for server_name, server_tools in all_server_tools.items():
|
||||
for tool_data in server_tools:
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=server_name,
|
||||
input_schema=tool_data.get("input_schema"),
|
||||
)
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
return {
|
||||
name: {
|
||||
"state": cb.current_state,
|
||||
"failure_count": cb.fail_counter,
|
||||
}
|
||||
for name, cb in self._circuit_breakers.items()
|
||||
}
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Manually reset a circuit breaker.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._circuit_breakers:
|
||||
# Reset by removing (will be recreated on next call)
|
||||
del self._circuit_breakers[server_name]
|
||||
logger.info("Reset circuit breaker for %s", server_name)
|
||||
return True
|
||||
return False
|
||||
324
backend/docs/MCP_CLIENT.md
Normal file
324
backend/docs/MCP_CLIENT.md
Normal file
@@ -0,0 +1,324 @@
|
||||
# MCP Client Infrastructure
|
||||
|
||||
This document describes the Model Context Protocol (MCP) client infrastructure used by Syndarix to communicate with AI agent tools.
|
||||
|
||||
## Overview
|
||||
|
||||
The MCP client infrastructure provides a robust, fault-tolerant layer for communicating with MCP servers. It enables AI agents to discover and execute tools provided by various services (LLM Gateway, Knowledge Base, Git Operations, Issue Tracker, etc.).
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌────────────────────────────────────────────────────────────────────────┐
|
||||
│ MCPClientManager │
|
||||
│ (Main Facade Class) │
|
||||
├────────────────────────────────────────────────────────────────────────┤
|
||||
│ - initialize() / shutdown() │
|
||||
│ - call_tool() / route_tool() │
|
||||
│ - connect() / disconnect() │
|
||||
│ - health_check() / list_tools() │
|
||||
└─────────────┬────────────────────┬─────────────────┬───────────────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────────────────┐ ┌─────────────────┐ ┌──────────────────────────┐
|
||||
│ MCPServerRegistry │ │ ConnectionPool │ │ ToolRouter │
|
||||
│ (Singleton) │ │ │ │ │
|
||||
├─────────────────────┤ ├─────────────────┤ ├──────────────────────────┤
|
||||
│ - Server configs │ │ - Connection │ │ - Tool → Server mapping │
|
||||
│ - Capabilities │ │ management │ │ - Circuit breakers │
|
||||
│ - Tool discovery │ │ - Auto reconnect│ │ - Retry logic │
|
||||
└─────────────────────┘ └─────────────────┘ └──────────────────────────┘
|
||||
```
|
||||
|
||||
## Components
|
||||
|
||||
### MCPClientManager
|
||||
|
||||
The main entry point for all MCP operations. Provides a clean facade over the underlying infrastructure.
|
||||
|
||||
```python
|
||||
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||
|
||||
# In FastAPI dependency injection
|
||||
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||
result = await mcp.call_tool(
|
||||
server="llm-gateway",
|
||||
tool="chat",
|
||||
args={"prompt": "Hello"}
|
||||
)
|
||||
return result.data
|
||||
|
||||
# Direct usage
|
||||
manager = MCPClientManager()
|
||||
await manager.initialize()
|
||||
|
||||
# Execute a tool
|
||||
result = await manager.call_tool(
|
||||
server="issues",
|
||||
tool="create_issue",
|
||||
args={"title": "New Feature", "body": "Description"}
|
||||
)
|
||||
|
||||
await manager.shutdown()
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Configuration is loaded from YAML files and supports environment variable expansion:
|
||||
|
||||
```yaml
|
||||
# mcp_servers.yaml
|
||||
mcp_servers:
|
||||
llm-gateway:
|
||||
url: ${LLM_GATEWAY_URL:-http://localhost:8001}
|
||||
timeout: 60
|
||||
transport: http
|
||||
enabled: true
|
||||
retry_attempts: 3
|
||||
circuit_breaker_threshold: 5
|
||||
circuit_breaker_timeout: 30.0
|
||||
|
||||
knowledge-base:
|
||||
url: ${KNOWLEDGE_BASE_URL:-http://localhost:8002}
|
||||
timeout: 30
|
||||
enabled: true
|
||||
|
||||
default_timeout: 30
|
||||
connection_pool_size: 10
|
||||
health_check_interval: 30
|
||||
```
|
||||
|
||||
**Environment Variable Syntax:**
|
||||
- `${VAR_NAME}` - Uses the environment variable value
|
||||
- `${VAR_NAME:-default}` - Uses default if variable is not set
|
||||
|
||||
### Connection Management
|
||||
|
||||
The `ConnectionPool` manages connections to MCP servers with:
|
||||
|
||||
- **Connection Reuse**: Connections are pooled and reused
|
||||
- **Auto Reconnection**: Failed connections are automatically retried
|
||||
- **Health Checks**: Periodic health checks detect unhealthy servers
|
||||
- **Exponential Backoff**: Retry delays increase exponentially with jitter
|
||||
|
||||
```python
|
||||
from app.services.mcp import ConnectionPool, MCPConnection
|
||||
|
||||
pool = ConnectionPool(max_connections_per_server=5)
|
||||
|
||||
# Get a connection (creates new or reuses existing)
|
||||
conn = await pool.get_connection("server-1", config)
|
||||
|
||||
# Execute request
|
||||
result = await conn.execute_request("POST", "/mcp", data={...})
|
||||
|
||||
# Health check all connections
|
||||
health = await pool.health_check_all()
|
||||
```
|
||||
|
||||
### Circuit Breaker Pattern
|
||||
|
||||
The `AsyncCircuitBreaker` prevents cascade failures:
|
||||
|
||||
| State | Description |
|
||||
|-------|-------------|
|
||||
| CLOSED | Normal operation, calls pass through |
|
||||
| OPEN | Too many failures, calls are rejected immediately |
|
||||
| HALF-OPEN | After timeout, allows one call to test if service recovered |
|
||||
|
||||
```python
|
||||
from app.services.mcp import AsyncCircuitBreaker
|
||||
|
||||
breaker = AsyncCircuitBreaker(
|
||||
fail_max=5, # Open after 5 failures
|
||||
reset_timeout=30, # Try again after 30 seconds
|
||||
name="my-service"
|
||||
)
|
||||
|
||||
if breaker.is_open():
|
||||
raise MCPCircuitOpenError(...)
|
||||
|
||||
try:
|
||||
result = await call_external_service()
|
||||
await breaker.success()
|
||||
except Exception:
|
||||
await breaker.failure()
|
||||
raise
|
||||
```
|
||||
|
||||
### Tool Routing
|
||||
|
||||
The `ToolRouter` handles:
|
||||
|
||||
- **Tool Discovery**: Automatically discovers tools from connected servers
|
||||
- **Routing**: Routes tool calls to the appropriate server
|
||||
- **Retry Logic**: Retries failed calls with exponential backoff
|
||||
|
||||
```python
|
||||
from app.services.mcp import ToolRouter
|
||||
|
||||
router = ToolRouter(registry, pool)
|
||||
|
||||
# Discover tools from all servers
|
||||
await router.discover_tools()
|
||||
|
||||
# Route to the right server automatically
|
||||
result = await router.route_tool(
|
||||
tool_name="create_issue",
|
||||
arguments={"title": "Bug fix"}
|
||||
)
|
||||
|
||||
# Or call a specific server
|
||||
result = await router.call_tool(
|
||||
server_name="issues",
|
||||
tool_name="create_issue",
|
||||
arguments={"title": "Bug fix"}
|
||||
)
|
||||
```
|
||||
|
||||
## Exception Hierarchy
|
||||
|
||||
```
|
||||
MCPError
|
||||
├── MCPConnectionError # Connection failures
|
||||
├── MCPTimeoutError # Operation timeouts
|
||||
├── MCPToolError # Tool execution errors
|
||||
├── MCPServerNotFoundError # Unknown server
|
||||
├── MCPToolNotFoundError # Unknown tool
|
||||
├── MCPCircuitOpenError # Circuit breaker open
|
||||
└── MCPValidationError # Invalid configuration
|
||||
```
|
||||
|
||||
All exceptions include rich context:
|
||||
|
||||
```python
|
||||
except MCPServerNotFoundError as e:
|
||||
print(f"Server: {e.server_name}")
|
||||
print(f"Available: {e.available_servers}")
|
||||
print(f"Suggestion: {e.suggestion}")
|
||||
```
|
||||
|
||||
## REST API Endpoints
|
||||
|
||||
| Method | Endpoint | Description | Auth |
|
||||
|--------|----------|-------------|------|
|
||||
| GET | `/api/v1/mcp/servers` | List all MCP servers | No |
|
||||
| GET | `/api/v1/mcp/servers/{name}/tools` | List server tools | No |
|
||||
| GET | `/api/v1/mcp/tools` | List all tools | No |
|
||||
| GET | `/api/v1/mcp/health` | Health check | No |
|
||||
| POST | `/api/v1/mcp/call` | Execute tool | Superuser |
|
||||
| GET | `/api/v1/mcp/circuit-breakers` | List circuit breakers | No |
|
||||
| POST | `/api/v1/mcp/circuit-breakers/{name}/reset` | Reset breaker | Superuser |
|
||||
| POST | `/api/v1/mcp/servers/{name}/reconnect` | Force reconnect | Superuser |
|
||||
|
||||
### Example: Execute a Tool
|
||||
|
||||
```http
|
||||
POST /api/v1/mcp/call
|
||||
Authorization: Bearer <token>
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"server": "issues",
|
||||
"tool": "create_issue",
|
||||
"arguments": {
|
||||
"title": "New Feature Request",
|
||||
"body": "Please add dark mode support"
|
||||
},
|
||||
"timeout": 30
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"issue_id": "12345",
|
||||
"url": "https://gitea.example.com/org/repo/issues/42"
|
||||
},
|
||||
"tool_name": "create_issue",
|
||||
"server_name": "issues",
|
||||
"execution_time_ms": 234.5,
|
||||
"request_id": "550e8400-e29b-41d4-a716-446655440000"
|
||||
}
|
||||
```
|
||||
|
||||
## Usage in Syndarix Agents
|
||||
|
||||
AI agents use the MCP client to execute tools:
|
||||
|
||||
```python
|
||||
class IssueCreatorAgent:
|
||||
def __init__(self, mcp: MCPClientManager):
|
||||
self.mcp = mcp
|
||||
|
||||
async def create_issue(self, title: str, body: str) -> dict:
|
||||
result = await self.mcp.call_tool(
|
||||
server="issues",
|
||||
tool="create_issue",
|
||||
args={"title": title, "body": body}
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise AgentError(f"Failed to create issue: {result.error}")
|
||||
|
||||
return result.data
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The MCP infrastructure is thoroughly tested:
|
||||
|
||||
- **Unit Tests**: `tests/services/mcp/` - Service layer tests
|
||||
- **API Tests**: `tests/api/routes/test_mcp.py` - Endpoint tests
|
||||
|
||||
Run tests:
|
||||
```bash
|
||||
# All MCP tests
|
||||
IS_TEST=True uv run pytest tests/services/mcp/ tests/api/routes/test_mcp.py -v
|
||||
|
||||
# With coverage
|
||||
IS_TEST=True uv run pytest tests/services/mcp/ --cov=app/services/mcp
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### MCPServerConfig
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `url` | str | Required | Server URL |
|
||||
| `transport` | str | "http" | Transport type (http, stdio, sse) |
|
||||
| `timeout` | int | 30 | Request timeout (1-600 seconds) |
|
||||
| `retry_attempts` | int | 3 | Max retry attempts (0-10) |
|
||||
| `retry_delay` | float | 1.0 | Initial retry delay (0.1-300 seconds) |
|
||||
| `retry_max_delay` | float | 30.0 | Maximum retry delay |
|
||||
| `circuit_breaker_threshold` | int | 5 | Failures before opening circuit |
|
||||
| `circuit_breaker_timeout` | float | 30.0 | Seconds before trying again |
|
||||
| `enabled` | bool | true | Whether server is enabled |
|
||||
| `description` | str | None | Server description |
|
||||
|
||||
### MCPConfig (Global)
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `mcp_servers` | dict | {} | Server configurations |
|
||||
| `default_timeout` | int | 30 | Default request timeout |
|
||||
| `default_retry_attempts` | int | 3 | Default retry attempts |
|
||||
| `connection_pool_size` | int | 10 | Max connections per server |
|
||||
| `health_check_interval` | int | 30 | Health check interval (seconds) |
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Description |
|
||||
|------|-------------|
|
||||
| `app/services/mcp/__init__.py` | Package exports |
|
||||
| `app/services/mcp/client_manager.py` | Main facade class |
|
||||
| `app/services/mcp/config.py` | Configuration models |
|
||||
| `app/services/mcp/registry.py` | Server registry singleton |
|
||||
| `app/services/mcp/connection.py` | Connection management |
|
||||
| `app/services/mcp/routing.py` | Tool routing and circuit breakers |
|
||||
| `app/services/mcp/exceptions.py` | Exception classes |
|
||||
| `app/api/routes/mcp.py` | REST API endpoints |
|
||||
| `mcp_servers.yaml` | Default configuration |
|
||||
60
backend/mcp_servers.yaml
Normal file
60
backend/mcp_servers.yaml
Normal file
@@ -0,0 +1,60 @@
|
||||
# MCP Server Configuration
|
||||
#
|
||||
# This file defines the MCP servers that the Syndarix backend connects to.
|
||||
# Environment variables can be used with ${VAR:-default} syntax.
|
||||
#
|
||||
# Example:
|
||||
# url: ${MY_SERVER_URL:-http://localhost:8001}
|
||||
#
|
||||
# For development, these servers typically run as separate Docker containers.
|
||||
# See docker-compose.yml for container definitions.
|
||||
|
||||
mcp_servers:
|
||||
# LLM Gateway - Multi-provider AI interactions
|
||||
llm-gateway:
|
||||
url: ${LLM_GATEWAY_URL:-http://localhost:8001}
|
||||
transport: http
|
||||
timeout: 60
|
||||
retry_attempts: 3
|
||||
retry_delay: 1.0
|
||||
retry_max_delay: 30.0
|
||||
circuit_breaker_threshold: 5
|
||||
circuit_breaker_timeout: 30.0
|
||||
enabled: true
|
||||
description: "LLM Gateway for Anthropic, OpenAI, Ollama, and other providers"
|
||||
|
||||
# Knowledge Base - RAG and document retrieval
|
||||
knowledge-base:
|
||||
url: ${KNOWLEDGE_BASE_URL:-http://localhost:8002}
|
||||
transport: http
|
||||
timeout: 30
|
||||
retry_attempts: 3
|
||||
circuit_breaker_threshold: 5
|
||||
enabled: true
|
||||
description: "Knowledge Base with pgvector for semantic search and RAG"
|
||||
|
||||
# Git Operations - Repository management
|
||||
git-ops:
|
||||
url: ${GIT_OPS_URL:-http://localhost:8003}
|
||||
transport: http
|
||||
timeout: 120
|
||||
retry_attempts: 2
|
||||
circuit_breaker_threshold: 3
|
||||
enabled: true
|
||||
description: "Git Operations for clone, commit, push, and repository management"
|
||||
|
||||
# Issues - Issue tracker integration
|
||||
issues:
|
||||
url: ${ISSUES_URL:-http://localhost:8004}
|
||||
transport: http
|
||||
timeout: 30
|
||||
retry_attempts: 3
|
||||
circuit_breaker_threshold: 5
|
||||
enabled: true
|
||||
description: "Issue Tracker integration for Gitea, GitHub, and GitLab"
|
||||
|
||||
# Global defaults
|
||||
default_timeout: 30
|
||||
default_retry_attempts: 3
|
||||
connection_pool_size: 10
|
||||
health_check_interval: 30
|
||||
@@ -53,6 +53,12 @@ dependencies = [
|
||||
# Celery for background task processing (Syndarix agent jobs)
|
||||
"celery[redis]>=5.4.0",
|
||||
"sse-starlette>=3.1.1",
|
||||
# MCP (Model Context Protocol) for AI agent tool integration
|
||||
"mcp>=1.0.0",
|
||||
# Circuit breaker pattern for resilient connections
|
||||
"pybreaker>=1.0.0",
|
||||
# YAML configuration support
|
||||
"pyyaml>=6.0.0",
|
||||
]
|
||||
|
||||
# Development dependencies
|
||||
@@ -151,6 +157,7 @@ unfixable = []
|
||||
"app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order
|
||||
"app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure
|
||||
"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"] # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional
|
||||
"app/services/mcp/*.py" = ["ASYNC109", "S311", "RUF022"] # timeout is config param not asyncio.timeout; random is ok for jitter; __all__ order is intentional for readability
|
||||
"app/models/__init__.py" = ["F401"] # __init__ files re-export modules
|
||||
"app/models/base.py" = ["F401"] # Re-exports Base for use by other models
|
||||
"app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention
|
||||
@@ -268,6 +275,14 @@ ignore_missing_imports = true
|
||||
module = "httpx.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "pybreaker.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "yaml.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
# SQLAlchemy ORM models - Column descriptors cause type confusion
|
||||
[[tool.mypy.overrides]]
|
||||
module = "app.models.*"
|
||||
@@ -307,6 +322,11 @@ disable_error_code = ["assignment", "arg-type", "attr-defined", "unused-ignore"]
|
||||
module = "app.services.oauth_service"
|
||||
disable_error_code = ["assignment", "arg-type", "attr-defined"]
|
||||
|
||||
# MCP services - circuit breaker and httpx client handling
|
||||
[[tool.mypy.overrides]]
|
||||
module = "app.services.mcp.*"
|
||||
disable_error_code = ["attr-defined", "arg-type"]
|
||||
|
||||
# Test utils - Testing patterns
|
||||
[[tool.mypy.overrides]]
|
||||
module = "app.utils.auth_test_utils"
|
||||
|
||||
491
backend/tests/api/routes/test_mcp.py
Normal file
491
backend/tests/api/routes/test_mcp.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
Tests for MCP API Routes
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.mcp import (
|
||||
MCPCircuitOpenError,
|
||||
MCPClientManager,
|
||||
MCPConnectionError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolNotFoundError,
|
||||
ServerHealth,
|
||||
)
|
||||
from app.services.mcp.config import MCPServerConfig, TransportType
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client():
|
||||
"""Create a mock MCP client manager."""
|
||||
client = MagicMock(spec=MCPClientManager)
|
||||
client.is_initialized = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_superuser():
|
||||
"""Create a mock superuser."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = "00000000-0000-0000-0000-000000000001"
|
||||
user.is_superuser = True
|
||||
user.email = "admin@example.com"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_mcp_client, mock_superuser):
|
||||
"""Create a FastAPI test client with mocked dependencies."""
|
||||
from app.api.routes.mcp import get_mcp_client
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
|
||||
# Override dependencies
|
||||
async def override_get_mcp_client():
|
||||
return mock_mcp_client
|
||||
|
||||
async def override_require_superuser():
|
||||
return mock_superuser
|
||||
|
||||
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||
|
||||
with patch("app.main.check_database_health", return_value=True):
|
||||
yield TestClient(app)
|
||||
|
||||
# Clean up
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestListServers:
|
||||
"""Tests for GET /mcp/servers endpoint."""
|
||||
|
||||
def test_list_servers_success(self, client, mock_mcp_client):
|
||||
"""Test listing MCP servers returns correct data."""
|
||||
# Setup mock
|
||||
mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
|
||||
mock_mcp_client.get_server_config.side_effect = [
|
||||
MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
transport=TransportType.HTTP,
|
||||
description="Server 1",
|
||||
),
|
||||
MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
transport=TransportType.SSE,
|
||||
description="Server 2",
|
||||
),
|
||||
]
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert len(data["servers"]) == 2
|
||||
assert data["servers"][0]["name"] == "server-1"
|
||||
assert data["servers"][0]["url"] == "http://server1:8000"
|
||||
assert data["servers"][1]["name"] == "server-2"
|
||||
assert data["servers"][1]["transport"] == "sse"
|
||||
|
||||
def test_list_servers_empty(self, client, mock_mcp_client):
|
||||
"""Test listing servers when none are registered."""
|
||||
mock_mcp_client.list_servers.return_value = []
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
assert data["servers"] == []
|
||||
|
||||
def test_list_servers_handles_not_found(self, client, mock_mcp_client):
|
||||
"""Test that missing server configs are skipped gracefully."""
|
||||
mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
|
||||
mock_mcp_client.get_server_config.side_effect = [
|
||||
MCPServerConfig(url="http://server1:8000"),
|
||||
MCPServerNotFoundError(server_name="missing"),
|
||||
]
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
# Should only include the successfully retrieved server
|
||||
assert data["total"] == 1
|
||||
|
||||
|
||||
class TestListServerTools:
|
||||
"""Tests for GET /mcp/servers/{server_name}/tools endpoint."""
|
||||
|
||||
def test_list_server_tools_success(self, client, mock_mcp_client):
|
||||
"""Test listing tools for a specific server."""
|
||||
mock_mcp_client.list_tools = AsyncMock(
|
||||
return_value=[
|
||||
ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
|
||||
]
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/servers/server-1/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert data["tools"][0]["name"] == "tool1"
|
||||
assert data["tools"][1]["name"] == "tool2"
|
||||
|
||||
def test_list_server_tools_not_found(self, client, mock_mcp_client):
|
||||
"""Test listing tools for non-existent server."""
|
||||
mock_mcp_client.list_tools = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/servers/unknown/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestListAllTools:
|
||||
"""Tests for GET /mcp/tools endpoint."""
|
||||
|
||||
def test_list_all_tools_success(self, client, mock_mcp_client):
|
||||
"""Test listing all tools from all servers."""
|
||||
mock_mcp_client.list_all_tools = AsyncMock(
|
||||
return_value=[
|
||||
ToolInfo(name="tool1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", server_name="server-1"),
|
||||
ToolInfo(name="tool3", server_name="server-2"),
|
||||
]
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
|
||||
def test_list_all_tools_empty(self, client, mock_mcp_client):
|
||||
"""Test listing tools when none are available."""
|
||||
mock_mcp_client.list_all_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = client.get("/api/v1/mcp/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for GET /mcp/health endpoint."""
|
||||
|
||||
def test_health_check_success(self, client, mock_mcp_client):
|
||||
"""Test health check returns correct data."""
|
||||
mock_mcp_client.health_check = AsyncMock(
|
||||
return_value={
|
||||
"server-1": ServerHealth(
|
||||
name="server-1",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://server1:8000",
|
||||
tools_count=5,
|
||||
),
|
||||
"server-2": ServerHealth(
|
||||
name="server-2",
|
||||
healthy=False,
|
||||
state="error",
|
||||
url="http://server2:8000",
|
||||
error="Connection refused",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert data["healthy_count"] == 1
|
||||
assert data["unhealthy_count"] == 1
|
||||
assert data["servers"]["server-1"]["healthy"] is True
|
||||
assert data["servers"]["server-2"]["healthy"] is False
|
||||
|
||||
def test_health_check_all_healthy(self, client, mock_mcp_client):
|
||||
"""Test health check when all servers are healthy."""
|
||||
mock_mcp_client.health_check = AsyncMock(
|
||||
return_value={
|
||||
"server-1": ServerHealth(
|
||||
name="server-1",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://server1:8000",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["healthy_count"] == 1
|
||||
assert data["unhealthy_count"] == 0
|
||||
|
||||
|
||||
class TestCallTool:
|
||||
"""Tests for POST /mcp/call endpoint."""
|
||||
|
||||
def test_call_tool_success(self, client, mock_mcp_client):
|
||||
"""Test successful tool execution."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
return_value=ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="test-tool",
|
||||
server_name="server-1",
|
||||
execution_time_ms=123.45,
|
||||
request_id="test-request-id",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={
|
||||
"server": "server-1",
|
||||
"tool": "test-tool",
|
||||
"arguments": {"key": "value"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"] == {"result": "ok"}
|
||||
assert data["tool_name"] == "test-tool"
|
||||
assert data["server_name"] == "server-1"
|
||||
|
||||
def test_call_tool_with_timeout(self, client, mock_mcp_client):
|
||||
"""Test tool execution with custom timeout."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
return_value=ToolResult(success=True, data={})
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={
|
||||
"server": "server-1",
|
||||
"tool": "test-tool",
|
||||
"timeout": 60.0,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
mock_mcp_client.call_tool.assert_called_once()
|
||||
call_args = mock_mcp_client.call_tool.call_args
|
||||
assert call_args.kwargs["timeout"] == 60.0
|
||||
|
||||
def test_call_tool_server_not_found(self, client, mock_mcp_client):
|
||||
"""Test tool execution with non-existent server."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "unknown", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_call_tool_not_found(self, client, mock_mcp_client):
|
||||
"""Test tool execution with non-existent tool."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPToolNotFoundError(tool_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "unknown"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_call_tool_timeout(self, client, mock_mcp_client):
|
||||
"""Test tool execution timeout."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name="server-1",
|
||||
timeout_seconds=30.0,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "slow-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||
|
||||
def test_call_tool_connection_error(self, client, mock_mcp_client):
|
||||
"""Test tool execution with connection failure."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPConnectionError(
|
||||
"Connection refused",
|
||||
server_name="server-1",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||
|
||||
def test_call_tool_circuit_open(self, client, mock_mcp_client):
|
||||
"""Test tool execution with open circuit breaker."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPCircuitOpenError(
|
||||
server_name="server-1",
|
||||
failure_count=5,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
|
||||
|
||||
class TestCircuitBreakers:
|
||||
"""Tests for circuit breaker endpoints."""
|
||||
|
||||
def test_list_circuit_breakers(self, client, mock_mcp_client):
|
||||
"""Test listing circuit breaker statuses."""
|
||||
mock_mcp_client.get_circuit_breaker_status.return_value = {
|
||||
"server-1": {"state": "closed", "failure_count": 0},
|
||||
"server-2": {"state": "open", "failure_count": 5},
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data["circuit_breakers"]) == 2
|
||||
|
||||
def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
|
||||
"""Test listing when no circuit breakers exist."""
|
||||
mock_mcp_client.get_circuit_breaker_status.return_value = {}
|
||||
|
||||
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["circuit_breakers"] == []
|
||||
|
||||
def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
|
||||
"""Test successfully resetting a circuit breaker."""
|
||||
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)
|
||||
|
||||
response = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
|
||||
"""Test resetting non-existent circuit breaker."""
|
||||
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)
|
||||
|
||||
response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestReconnectServer:
|
||||
"""Tests for POST /mcp/servers/{server_name}/reconnect endpoint."""
|
||||
|
||||
def test_reconnect_success(self, client, mock_mcp_client):
|
||||
"""Test successful server reconnection."""
|
||||
mock_mcp_client.disconnect = AsyncMock()
|
||||
mock_mcp_client.connect = AsyncMock()
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
mock_mcp_client.disconnect.assert_called_once_with("server-1")
|
||||
mock_mcp_client.connect.assert_called_once_with("server-1")
|
||||
|
||||
def test_reconnect_server_not_found(self, client, mock_mcp_client):
|
||||
"""Test reconnecting to non-existent server."""
|
||||
mock_mcp_client.disconnect = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/unknown/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_reconnect_connection_failure(self, client, mock_mcp_client):
|
||||
"""Test reconnection failure."""
|
||||
mock_mcp_client.disconnect = AsyncMock()
|
||||
mock_mcp_client.connect = AsyncMock(
|
||||
side_effect=MCPConnectionError(
|
||||
"Connection refused",
|
||||
server_name="server-1",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||
|
||||
|
||||
class TestMCPEndpointsEdgeCases:
|
||||
"""Edge case tests for MCP endpoints."""
|
||||
|
||||
def test_servers_content_type(self, client, mock_mcp_client):
|
||||
"""Test that endpoints return JSON content type."""
|
||||
mock_mcp_client.list_servers.return_value = []
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert "application/json" in response.headers["content-type"]
|
||||
|
||||
def test_call_tool_validation_error(self, client):
|
||||
"""Test that invalid request body returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={}, # Missing required fields
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_call_tool_missing_server(self, client):
|
||||
"""Test that missing server field returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_call_tool_missing_tool(self, client):
|
||||
"""Test that missing tool field returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
1
backend/tests/services/mcp/__init__.py
Normal file
1
backend/tests/services/mcp/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""MCP Service Tests Package."""
|
||||
395
backend/tests/services/mcp/test_client_manager.py
Normal file
395
backend/tests/services/mcp/test_client_manager.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Tests for MCP Client Manager
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.client_manager import (
|
||||
MCPClientManager,
|
||||
ServerHealth,
|
||||
get_mcp_client,
|
||||
reset_mcp_client,
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.connection import ConnectionState
|
||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||
from app.services.mcp.registry import MCPServerRegistry
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
reset_mcp_client()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
reset_mcp_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
"""Create a sample MCP configuration."""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
),
|
||||
"server-2": MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestServerHealth:
|
||||
"""Tests for ServerHealth dataclass."""
|
||||
|
||||
def test_healthy_server(self):
|
||||
"""Test healthy server status."""
|
||||
health = ServerHealth(
|
||||
name="test-server",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://test:8000",
|
||||
tools_count=5,
|
||||
)
|
||||
assert health.healthy is True
|
||||
assert health.error is None
|
||||
assert health.tools_count == 5
|
||||
|
||||
def test_unhealthy_server(self):
|
||||
"""Test unhealthy server status."""
|
||||
health = ServerHealth(
|
||||
name="test-server",
|
||||
healthy=False,
|
||||
state="error",
|
||||
url="http://test:8000",
|
||||
error="Connection refused",
|
||||
)
|
||||
assert health.healthy is False
|
||||
assert health.error == "Connection refused"
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting to dictionary."""
|
||||
health = ServerHealth(
|
||||
name="test-server",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://test:8000",
|
||||
tools_count=3,
|
||||
)
|
||||
d = health.to_dict()
|
||||
|
||||
assert d["name"] == "test-server"
|
||||
assert d["healthy"] is True
|
||||
assert d["state"] == "connected"
|
||||
assert d["url"] == "http://test:8000"
|
||||
assert d["tools_count"] == 3
|
||||
|
||||
|
||||
class TestMCPClientManager:
|
||||
"""Tests for MCPClientManager class."""
|
||||
|
||||
def test_initial_state(self, reset_registry):
|
||||
"""Test initial manager state."""
|
||||
manager = MCPClientManager()
|
||||
assert manager.is_initialized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, reset_registry, sample_config):
|
||||
"""Test manager initialization."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "get_connection") as mock_get_conn:
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.is_connected = True
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
with patch.object(manager, "_router") as mock_router:
|
||||
mock_router.discover_tools = AsyncMock()
|
||||
|
||||
await manager.initialize()
|
||||
|
||||
assert manager.is_initialized is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown(self, reset_registry, sample_config):
|
||||
"""Test manager shutdown."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
with patch.object(manager._pool, "close_all") as mock_close:
|
||||
mock_close.return_value = None
|
||||
await manager.shutdown()
|
||||
|
||||
assert manager.is_initialized is False
|
||||
mock_close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(self, reset_registry, sample_config):
|
||||
"""Test connecting to specific server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "get_connection") as mock_get_conn:
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.is_connected = True
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
await manager.connect("server-1")
|
||||
|
||||
mock_get_conn.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_unknown_server(self, reset_registry, sample_config):
|
||||
"""Test connecting to unknown server raises error."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with pytest.raises(MCPServerNotFoundError):
|
||||
await manager.connect("unknown-server")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect(self, reset_registry, sample_config):
|
||||
"""Test disconnecting from server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "close_connection") as mock_close:
|
||||
await manager.disconnect("server-1")
|
||||
mock_close.assert_called_once_with("server-1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all(self, reset_registry, sample_config):
|
||||
"""Test disconnecting from all servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "close_all") as mock_close:
|
||||
await manager.disconnect_all()
|
||||
mock_close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool(self, reset_registry, sample_config):
|
||||
"""Test calling a tool."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
expected_result = ToolResult(
|
||||
success=True,
|
||||
data={"id": "123"},
|
||||
tool_name="create_issue",
|
||||
server_name="server-1",
|
||||
)
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.call_tool = AsyncMock(return_value=expected_result)
|
||||
manager._router = mock_router
|
||||
|
||||
result = await manager.call_tool(
|
||||
server="server-1",
|
||||
tool="create_issue",
|
||||
args={"title": "Test"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data == {"id": "123"}
|
||||
mock_router.call_tool.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_tool(self, reset_registry, sample_config):
|
||||
"""Test routing a tool call."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
expected_result = ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="auto_tool",
|
||||
server_name="server-2",
|
||||
)
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.route_tool = AsyncMock(return_value=expected_result)
|
||||
manager._router = mock_router
|
||||
|
||||
result = await manager.route_tool(
|
||||
tool="auto_tool",
|
||||
args={"key": "value"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.server_name == "server-2"
|
||||
mock_router.route_tool.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(self, reset_registry, sample_config):
|
||||
"""Test listing tools for a server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
# Set up capabilities in registry
|
||||
manager._registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[
|
||||
{"name": "tool1", "description": "Tool 1"},
|
||||
{"name": "tool2", "description": "Tool 2"},
|
||||
],
|
||||
)
|
||||
|
||||
tools = await manager.list_tools("server-1")
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "tool1"
|
||||
assert tools[1].name == "tool2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_tools(self, reset_registry, sample_config):
|
||||
"""Test listing all tools from all servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
expected_tools = [
|
||||
ToolInfo(name="tool1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", server_name="server-2"),
|
||||
]
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.list_all_tools = AsyncMock(return_value=expected_tools)
|
||||
manager._router = mock_router
|
||||
|
||||
tools = await manager.list_all_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, reset_registry, sample_config):
|
||||
"""Test health check on all servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "get_status") as mock_status:
|
||||
mock_status.return_value = {
|
||||
"server-1": {"state": "connected"},
|
||||
"server-2": {"state": "disconnected"},
|
||||
}
|
||||
|
||||
with patch.object(manager._pool, "health_check_all") as mock_health:
|
||||
mock_health.return_value = {
|
||||
"server-1": True,
|
||||
"server-2": False,
|
||||
}
|
||||
|
||||
health = await manager.health_check()
|
||||
|
||||
assert "server-1" in health
|
||||
assert "server-2" in health
|
||||
assert health["server-1"].healthy is True
|
||||
assert health["server-2"].healthy is False
|
||||
|
||||
def test_list_servers(self, reset_registry, sample_config):
|
||||
"""Test listing registered servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
servers = manager.list_servers()
|
||||
|
||||
assert "server-1" in servers
|
||||
assert "server-2" in servers
|
||||
|
||||
def test_list_enabled_servers(self, reset_registry, sample_config):
|
||||
"""Test listing enabled servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
servers = manager.list_enabled_servers()
|
||||
|
||||
assert "server-1" in servers
|
||||
assert "server-2" in servers
|
||||
|
||||
def test_get_server_config(self, reset_registry, sample_config):
|
||||
"""Test getting server configuration."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
config = manager.get_server_config("server-1")
|
||||
assert config.url == "http://server1:8000"
|
||||
assert config.timeout == 30
|
||||
|
||||
def test_get_server_config_not_found(self, reset_registry, sample_config):
|
||||
"""Test getting unknown server config raises error."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with pytest.raises(MCPServerNotFoundError):
|
||||
manager.get_server_config("unknown")
|
||||
|
||||
def test_register_server(self, reset_registry, sample_config):
|
||||
"""Test registering new server at runtime."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
new_config = MCPServerConfig(url="http://new:8000")
|
||||
manager.register_server("new-server", new_config)
|
||||
|
||||
assert "new-server" in manager.list_servers()
|
||||
|
||||
def test_unregister_server(self, reset_registry, sample_config):
|
||||
"""Test unregistering a server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
result = manager.unregister_server("server-1")
|
||||
assert result is True
|
||||
assert "server-1" not in manager.list_servers()
|
||||
|
||||
# Unregistering non-existent returns False
|
||||
result = manager.unregister_server("nonexistent")
|
||||
assert result is False
|
||||
|
||||
def test_circuit_breaker_status(self, reset_registry, sample_config):
|
||||
"""Test getting circuit breaker status."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
# No router yet
|
||||
status = manager.get_circuit_breaker_status()
|
||||
assert status == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_circuit_breaker(self, reset_registry, sample_config):
|
||||
"""Test resetting circuit breaker."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
# No router yet
|
||||
result = await manager.reset_circuit_breaker("server-1")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
"""Tests for module-level convenience functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_mcp_client_creates_singleton(self, reset_registry):
|
||||
"""Test get_mcp_client creates and returns singleton."""
|
||||
with patch(
|
||||
"app.services.mcp.client_manager.MCPClientManager.initialize"
|
||||
) as mock_init:
|
||||
mock_init.return_value = None
|
||||
|
||||
client1 = await get_mcp_client()
|
||||
client2 = await get_mcp_client()
|
||||
|
||||
assert client1 is client2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_mcp_client(self, reset_registry):
|
||||
"""Test shutting down the global client."""
|
||||
with patch(
|
||||
"app.services.mcp.client_manager.MCPClientManager.initialize"
|
||||
) as mock_init:
|
||||
mock_init.return_value = None
|
||||
|
||||
client = await get_mcp_client()
|
||||
|
||||
with patch.object(client, "shutdown") as mock_shutdown:
|
||||
mock_shutdown.return_value = None
|
||||
await shutdown_mcp_client()
|
||||
|
||||
def test_reset_mcp_client(self, reset_registry):
|
||||
"""Test resetting the global client."""
|
||||
reset_mcp_client()
|
||||
# Should not raise
|
||||
319
backend/tests/services/mcp/test_config.py
Normal file
319
backend/tests/services/mcp/test_config.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Tests for MCP Configuration System
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.services.mcp.config import (
|
||||
MCPConfig,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
create_default_config,
|
||||
load_mcp_config,
|
||||
)
|
||||
|
||||
|
||||
class TestTransportType:
|
||||
"""Tests for TransportType enum."""
|
||||
|
||||
def test_transport_types(self):
|
||||
"""Test that all transport types are defined."""
|
||||
assert TransportType.HTTP == "http"
|
||||
assert TransportType.STDIO == "stdio"
|
||||
assert TransportType.SSE == "sse"
|
||||
|
||||
def test_transport_type_from_string(self):
|
||||
"""Test creating transport type from string."""
|
||||
assert TransportType("http") == TransportType.HTTP
|
||||
assert TransportType("stdio") == TransportType.STDIO
|
||||
assert TransportType("sse") == TransportType.SSE
|
||||
|
||||
|
||||
class TestMCPServerConfig:
|
||||
"""Tests for MCPServerConfig model."""
|
||||
|
||||
def test_minimal_config(self):
|
||||
"""Test creating config with only required fields."""
|
||||
config = MCPServerConfig(url="http://localhost:8000")
|
||||
assert config.url == "http://localhost:8000"
|
||||
assert config.transport == TransportType.HTTP
|
||||
assert config.timeout == 30
|
||||
assert config.retry_attempts == 3
|
||||
assert config.enabled is True
|
||||
|
||||
def test_full_config(self):
|
||||
"""Test creating config with all fields."""
|
||||
config = MCPServerConfig(
|
||||
url="http://localhost:8000",
|
||||
transport=TransportType.SSE,
|
||||
timeout=60,
|
||||
retry_attempts=5,
|
||||
retry_delay=2.0,
|
||||
retry_max_delay=60.0,
|
||||
circuit_breaker_threshold=10,
|
||||
circuit_breaker_timeout=60.0,
|
||||
enabled=False,
|
||||
description="Test server",
|
||||
)
|
||||
assert config.timeout == 60
|
||||
assert config.transport == TransportType.SSE
|
||||
assert config.retry_attempts == 5
|
||||
assert config.retry_delay == 2.0
|
||||
assert config.retry_max_delay == 60.0
|
||||
assert config.circuit_breaker_threshold == 10
|
||||
assert config.circuit_breaker_timeout == 60.0
|
||||
assert config.enabled is False
|
||||
assert config.description == "Test server"
|
||||
|
||||
def test_env_var_expansion_simple(self):
|
||||
"""Test simple environment variable expansion."""
|
||||
os.environ["TEST_SERVER_URL"] = "http://test-server:9000"
|
||||
try:
|
||||
config = MCPServerConfig(url="${TEST_SERVER_URL}")
|
||||
assert config.url == "http://test-server:9000"
|
||||
finally:
|
||||
del os.environ["TEST_SERVER_URL"]
|
||||
|
||||
def test_env_var_expansion_with_default(self):
|
||||
"""Test environment variable expansion with default."""
|
||||
# Ensure env var is not set
|
||||
os.environ.pop("NONEXISTENT_URL", None)
|
||||
config = MCPServerConfig(url="${NONEXISTENT_URL:-http://default:8000}")
|
||||
assert config.url == "http://default:8000"
|
||||
|
||||
def test_env_var_expansion_override_default(self):
|
||||
"""Test environment variable override of default."""
|
||||
os.environ["TEST_OVERRIDE_URL"] = "http://override:9000"
|
||||
try:
|
||||
config = MCPServerConfig(url="${TEST_OVERRIDE_URL:-http://default:8000}")
|
||||
assert config.url == "http://override:9000"
|
||||
finally:
|
||||
del os.environ["TEST_OVERRIDE_URL"]
|
||||
|
||||
def test_timeout_validation(self):
|
||||
"""Test timeout validation bounds."""
|
||||
# Valid bounds
|
||||
config = MCPServerConfig(url="http://localhost", timeout=1)
|
||||
assert config.timeout == 1
|
||||
|
||||
config = MCPServerConfig(url="http://localhost", timeout=600)
|
||||
assert config.timeout == 600
|
||||
|
||||
# Invalid bounds
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", timeout=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", timeout=601)
|
||||
|
||||
def test_retry_attempts_validation(self):
|
||||
"""Test retry attempts validation bounds."""
|
||||
config = MCPServerConfig(url="http://localhost", retry_attempts=0)
|
||||
assert config.retry_attempts == 0
|
||||
|
||||
config = MCPServerConfig(url="http://localhost", retry_attempts=10)
|
||||
assert config.retry_attempts == 10
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", retry_attempts=-1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", retry_attempts=11)
|
||||
|
||||
|
||||
class TestMCPConfig:
|
||||
"""Tests for MCPConfig model."""
|
||||
|
||||
def test_empty_config(self):
|
||||
"""Test creating empty config."""
|
||||
config = MCPConfig()
|
||||
assert config.mcp_servers == {}
|
||||
assert config.default_timeout == 30
|
||||
assert config.default_retry_attempts == 3
|
||||
assert config.connection_pool_size == 10
|
||||
assert config.health_check_interval == 30
|
||||
|
||||
def test_config_with_servers(self):
|
||||
"""Test creating config with servers."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(url="http://server1:8000"),
|
||||
"server-2": MCPServerConfig(url="http://server2:8000"),
|
||||
}
|
||||
)
|
||||
assert len(config.mcp_servers) == 2
|
||||
assert "server-1" in config.mcp_servers
|
||||
assert "server-2" in config.mcp_servers
|
||||
|
||||
def test_get_server(self):
|
||||
"""Test getting server by name."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(url="http://server1:8000"),
|
||||
}
|
||||
)
|
||||
server = config.get_server("server-1")
|
||||
assert server is not None
|
||||
assert server.url == "http://server1:8000"
|
||||
|
||||
missing = config.get_server("nonexistent")
|
||||
assert missing is None
|
||||
|
||||
def test_get_enabled_servers(self):
|
||||
"""Test getting only enabled servers."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"enabled-1": MCPServerConfig(url="http://e1:8000", enabled=True),
|
||||
"disabled-1": MCPServerConfig(url="http://d1:8000", enabled=False),
|
||||
"enabled-2": MCPServerConfig(url="http://e2:8000", enabled=True),
|
||||
}
|
||||
)
|
||||
enabled = config.get_enabled_servers()
|
||||
assert len(enabled) == 2
|
||||
assert "enabled-1" in enabled
|
||||
assert "enabled-2" in enabled
|
||||
assert "disabled-1" not in enabled
|
||||
|
||||
def test_list_server_names(self):
|
||||
"""Test listing server names."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"server-a": MCPServerConfig(url="http://a:8000"),
|
||||
"server-b": MCPServerConfig(url="http://b:8000"),
|
||||
}
|
||||
)
|
||||
names = config.list_server_names()
|
||||
assert sorted(names) == ["server-a", "server-b"]
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
data = {
|
||||
"mcp_servers": {
|
||||
"test-server": {
|
||||
"url": "http://test:8000",
|
||||
"timeout": 45,
|
||||
}
|
||||
},
|
||||
"default_timeout": 60,
|
||||
}
|
||||
config = MCPConfig.from_dict(data)
|
||||
assert config.default_timeout == 60
|
||||
assert config.mcp_servers["test-server"].timeout == 45
|
||||
|
||||
def test_from_yaml(self):
|
||||
"""Test loading config from YAML file."""
|
||||
yaml_content = """
|
||||
mcp_servers:
|
||||
test-server:
|
||||
url: http://test:8000
|
||||
timeout: 45
|
||||
transport: http
|
||||
enabled: true
|
||||
default_timeout: 60
|
||||
connection_pool_size: 20
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
config = MCPConfig.from_yaml(f.name)
|
||||
assert config.default_timeout == 60
|
||||
assert config.connection_pool_size == 20
|
||||
assert "test-server" in config.mcp_servers
|
||||
assert config.mcp_servers["test-server"].timeout == 45
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_from_yaml_file_not_found(self):
|
||||
"""Test error when YAML file not found."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
MCPConfig.from_yaml("/nonexistent/path/config.yaml")
|
||||
|
||||
|
||||
class TestLoadMCPConfig:
|
||||
"""Tests for load_mcp_config function."""
|
||||
|
||||
def test_load_with_explicit_path(self):
|
||||
"""Test loading config with explicit path."""
|
||||
yaml_content = """
|
||||
mcp_servers:
|
||||
explicit-server:
|
||||
url: http://explicit:8000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
config = load_mcp_config(f.name)
|
||||
assert "explicit-server" in config.mcp_servers
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_with_env_var(self):
|
||||
"""Test loading config from environment variable path."""
|
||||
yaml_content = """
|
||||
mcp_servers:
|
||||
env-server:
|
||||
url: http://env:8000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
os.environ["MCP_CONFIG_PATH"] = f.name
|
||||
try:
|
||||
config = load_mcp_config()
|
||||
assert "env-server" in config.mcp_servers
|
||||
finally:
|
||||
del os.environ["MCP_CONFIG_PATH"]
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_returns_empty_config_if_missing(self):
|
||||
"""Test that missing file returns empty config."""
|
||||
os.environ.pop("MCP_CONFIG_PATH", None)
|
||||
config = load_mcp_config("/nonexistent/path/config.yaml")
|
||||
assert config.mcp_servers == {}
|
||||
|
||||
|
||||
class TestCreateDefaultConfig:
|
||||
"""Tests for create_default_config function."""
|
||||
|
||||
def test_creates_standard_servers(self):
|
||||
"""Test that default config has standard servers."""
|
||||
config = create_default_config()
|
||||
|
||||
assert "llm-gateway" in config.mcp_servers
|
||||
assert "knowledge-base" in config.mcp_servers
|
||||
assert "git-ops" in config.mcp_servers
|
||||
assert "issues" in config.mcp_servers
|
||||
|
||||
def test_servers_have_correct_defaults(self):
|
||||
"""Test that servers have correct default values."""
|
||||
config = create_default_config()
|
||||
|
||||
llm = config.mcp_servers["llm-gateway"]
|
||||
assert llm.timeout == 60 # LLM has longer timeout
|
||||
assert llm.transport == TransportType.HTTP
|
||||
|
||||
git = config.mcp_servers["git-ops"]
|
||||
assert git.timeout == 120 # Git ops has longest timeout
|
||||
|
||||
def test_servers_are_enabled(self):
|
||||
"""Test that all default servers are enabled."""
|
||||
config = create_default_config()
|
||||
|
||||
for name, server in config.mcp_servers.items():
|
||||
assert server.enabled is True, f"Server {name} should be enabled"
|
||||
405
backend/tests/services/mcp/test_connection.py
Normal file
405
backend/tests/services/mcp/test_connection.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Tests for MCP Connection Management
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPServerConfig, TransportType
|
||||
from app.services.mcp.connection import (
|
||||
ConnectionPool,
|
||||
ConnectionState,
|
||||
MCPConnection,
|
||||
)
|
||||
from app.services.mcp.exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
"""Create a sample server configuration."""
|
||||
return MCPServerConfig(
|
||||
url="http://localhost:8000",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
retry_attempts=3,
|
||||
retry_delay=0.1, # Short delay for tests
|
||||
retry_max_delay=1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestConnectionState:
|
||||
"""Tests for ConnectionState enum."""
|
||||
|
||||
def test_connection_states(self):
|
||||
"""Test all connection states are defined."""
|
||||
assert ConnectionState.DISCONNECTED == "disconnected"
|
||||
assert ConnectionState.CONNECTING == "connecting"
|
||||
assert ConnectionState.CONNECTED == "connected"
|
||||
assert ConnectionState.RECONNECTING == "reconnecting"
|
||||
assert ConnectionState.ERROR == "error"
|
||||
|
||||
|
||||
class TestMCPConnection:
|
||||
"""Tests for MCPConnection class."""
|
||||
|
||||
def test_initial_state(self, server_config):
|
||||
"""Test initial connection state."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
assert conn.server_name == "test-server"
|
||||
assert conn.config == server_config
|
||||
assert conn.state == ConnectionState.DISCONNECTED
|
||||
assert conn.is_connected is False
|
||||
assert conn.last_error is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_success(self, server_config):
|
||||
"""Test successful connection."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
|
||||
assert conn.state == ConnectionState.CONNECTED
|
||||
assert conn.is_connected is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_404_capabilities_ok(self, server_config):
|
||||
"""Test connection succeeds even if /mcp/capabilities returns 404."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
|
||||
assert conn.state == ConnectionState.CONNECTED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_failure_with_retry(self, server_config):
|
||||
"""Test connection failure with retries."""
|
||||
# Reduce retry attempts for faster test
|
||||
server_config.retry_attempts = 2
|
||||
server_config.retry_delay = 0.01
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=httpx.ConnectError("Connection refused")
|
||||
)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
with pytest.raises(MCPConnectionError) as exc_info:
|
||||
await conn.connect()
|
||||
|
||||
assert conn.state == ConnectionState.ERROR
|
||||
assert "Failed to connect after" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect(self, server_config):
|
||||
"""Test disconnection."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
assert conn.is_connected is True
|
||||
|
||||
await conn.disconnect()
|
||||
assert conn.state == ConnectionState.DISCONNECTED
|
||||
assert conn.is_connected is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect(self, server_config):
|
||||
"""Test reconnection."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
assert conn.is_connected is True
|
||||
|
||||
await conn.reconnect()
|
||||
assert conn.is_connected is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self, server_config):
|
||||
"""Test successful health check."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
healthy = await conn.health_check()
|
||||
|
||||
assert healthy is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_disconnected(self, server_config):
|
||||
"""Test health check when disconnected."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
healthy = await conn.health_check()
|
||||
assert healthy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_get(self, server_config):
|
||||
"""Test executing GET request."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_connect_response = MagicMock()
|
||||
mock_connect_response.status_code = 200
|
||||
|
||||
mock_request_response = MagicMock()
|
||||
mock_request_response.status_code = 200
|
||||
mock_request_response.json.return_value = {"tools": []}
|
||||
mock_request_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[mock_connect_response, mock_request_response]
|
||||
)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
result = await conn.execute_request("GET", "/mcp/tools")
|
||||
|
||||
assert result == {"tools": []}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_post(self, server_config):
|
||||
"""Test executing POST request."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_connect_response = MagicMock()
|
||||
mock_connect_response.status_code = 200
|
||||
|
||||
mock_request_response = MagicMock()
|
||||
mock_request_response.status_code = 200
|
||||
mock_request_response.json.return_value = {"result": "success"}
|
||||
mock_request_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_connect_response)
|
||||
mock_client.post = AsyncMock(return_value=mock_request_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
result = await conn.execute_request(
|
||||
"POST", "/mcp", data={"method": "test"}
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_timeout(self, server_config):
|
||||
"""Test request timeout."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_connect_response = MagicMock()
|
||||
mock_connect_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[
|
||||
mock_connect_response,
|
||||
httpx.TimeoutException("Request timeout"),
|
||||
]
|
||||
)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
|
||||
with pytest.raises(MCPTimeoutError) as exc_info:
|
||||
await conn.execute_request("GET", "/slow-endpoint")
|
||||
|
||||
assert "timeout" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_not_connected(self, server_config):
|
||||
"""Test request when not connected."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
with pytest.raises(MCPConnectionError) as exc_info:
|
||||
await conn.execute_request("GET", "/test")
|
||||
|
||||
assert "Not connected" in str(exc_info.value)
|
||||
|
||||
def test_backoff_delay_calculation(self, server_config):
|
||||
"""Test exponential backoff delay calculation."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
# First attempt
|
||||
conn._connection_attempts = 1
|
||||
delay1 = conn._calculate_backoff_delay()
|
||||
|
||||
# Second attempt
|
||||
conn._connection_attempts = 2
|
||||
delay2 = conn._calculate_backoff_delay()
|
||||
|
||||
# Third attempt
|
||||
conn._connection_attempts = 3
|
||||
delay3 = conn._calculate_backoff_delay()
|
||||
|
||||
# Delays should generally increase (with some jitter)
|
||||
# Base is 0.1, so rough expectations:
|
||||
# Attempt 1: ~0.1s
|
||||
# Attempt 2: ~0.2s
|
||||
# Attempt 3: ~0.4s
|
||||
assert delay1 > 0
|
||||
assert delay2 > delay1 * 0.5 # Allow for jitter
|
||||
assert delay3 <= server_config.retry_max_delay * 1.25 # Within max + jitter
|
||||
|
||||
|
||||
class TestConnectionPool:
|
||||
"""Tests for ConnectionPool class."""
|
||||
|
||||
@pytest.fixture
|
||||
def pool(self):
|
||||
"""Create a connection pool."""
|
||||
return ConnectionPool(max_connections_per_server=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_creates_new(self, pool, server_config):
|
||||
"""Test getting connection creates new one if not exists."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
conn = await pool.get_connection("test-server", server_config)
|
||||
|
||||
assert conn is not None
|
||||
assert conn.is_connected is True
|
||||
assert conn.server_name == "test-server"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_reuses_existing(self, pool, server_config):
|
||||
"""Test getting connection reuses existing one."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
conn1 = await pool.get_connection("test-server", server_config)
|
||||
conn2 = await pool.get_connection("test-server", server_config)
|
||||
|
||||
assert conn1 is conn2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_connection(self, pool, server_config):
|
||||
"""Test closing a connection."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await pool.get_connection("test-server", server_config)
|
||||
await pool.close_connection("test-server")
|
||||
|
||||
assert "test-server" not in pool._connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all(self, pool, server_config):
|
||||
"""Test closing all connections."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
config2 = MCPServerConfig(url="http://server2:8000")
|
||||
await pool.get_connection("server-1", server_config)
|
||||
await pool.get_connection("server-2", config2)
|
||||
|
||||
assert len(pool._connections) == 2
|
||||
|
||||
await pool.close_all()
|
||||
|
||||
assert len(pool._connections) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_all(self, pool, server_config):
|
||||
"""Test health check on all connections."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await pool.get_connection("test-server", server_config)
|
||||
results = await pool.health_check_all()
|
||||
|
||||
assert "test-server" in results
|
||||
assert results["test-server"] is True
|
||||
|
||||
def test_get_status(self, pool, server_config):
|
||||
"""Test getting pool status."""
|
||||
status = pool.get_status()
|
||||
assert status == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_context_manager(self, pool, server_config):
|
||||
"""Test connection context manager."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
async with pool.connection("test-server", server_config) as conn:
|
||||
assert conn.is_connected is True
|
||||
assert conn.server_name == "test-server"
|
||||
259
backend/tests/services/mcp/test_exceptions.py
Normal file
259
backend/tests/services/mcp/test_exceptions.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Tests for MCP Exception Classes
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
MCPValidationError,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPError:
|
||||
"""Tests for base MCPError class."""
|
||||
|
||||
def test_basic_error(self):
|
||||
"""Test basic error creation."""
|
||||
error = MCPError("Test error")
|
||||
assert error.message == "Test error"
|
||||
assert error.server_name is None
|
||||
assert error.details == {}
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_error_with_server_name(self):
|
||||
"""Test error with server name."""
|
||||
error = MCPError("Test error", server_name="test-server")
|
||||
assert error.server_name == "test-server"
|
||||
assert "server=test-server" in str(error)
|
||||
|
||||
def test_error_with_details(self):
|
||||
"""Test error with additional details."""
|
||||
error = MCPError(
|
||||
"Test error",
|
||||
server_name="test-server",
|
||||
details={"key": "value"},
|
||||
)
|
||||
assert error.details == {"key": "value"}
|
||||
assert "details={'key': 'value'}" in str(error)
|
||||
|
||||
|
||||
class TestMCPConnectionError:
|
||||
"""Tests for MCPConnectionError class."""
|
||||
|
||||
def test_basic_connection_error(self):
|
||||
"""Test basic connection error."""
|
||||
error = MCPConnectionError("Connection failed")
|
||||
assert error.message == "Connection failed"
|
||||
assert error.url is None
|
||||
assert error.cause is None
|
||||
|
||||
def test_connection_error_with_url(self):
|
||||
"""Test connection error with URL."""
|
||||
error = MCPConnectionError(
|
||||
"Connection failed",
|
||||
server_name="test-server",
|
||||
url="http://localhost:8000",
|
||||
)
|
||||
assert error.url == "http://localhost:8000"
|
||||
assert "url=http://localhost:8000" in str(error)
|
||||
|
||||
def test_connection_error_with_cause(self):
|
||||
"""Test connection error with cause."""
|
||||
cause = ConnectionError("Network error")
|
||||
error = MCPConnectionError(
|
||||
"Connection failed",
|
||||
cause=cause,
|
||||
)
|
||||
assert error.cause is cause
|
||||
assert "ConnectionError" in str(error)
|
||||
|
||||
|
||||
class TestMCPTimeoutError:
|
||||
"""Tests for MCPTimeoutError class."""
|
||||
|
||||
def test_basic_timeout_error(self):
|
||||
"""Test basic timeout error."""
|
||||
error = MCPTimeoutError("Request timed out")
|
||||
assert error.message == "Request timed out"
|
||||
assert error.timeout_seconds is None
|
||||
assert error.operation is None
|
||||
|
||||
def test_timeout_error_with_details(self):
|
||||
"""Test timeout error with details."""
|
||||
error = MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name="test-server",
|
||||
timeout_seconds=30.0,
|
||||
operation="POST /mcp",
|
||||
)
|
||||
assert error.timeout_seconds == 30.0
|
||||
assert error.operation == "POST /mcp"
|
||||
assert "timeout=30.0s" in str(error)
|
||||
assert "operation=POST /mcp" in str(error)
|
||||
|
||||
|
||||
class TestMCPToolError:
|
||||
"""Tests for MCPToolError class."""
|
||||
|
||||
def test_basic_tool_error(self):
|
||||
"""Test basic tool error."""
|
||||
error = MCPToolError("Tool execution failed")
|
||||
assert error.message == "Tool execution failed"
|
||||
assert error.tool_name is None
|
||||
assert error.tool_args is None
|
||||
assert error.error_code is None
|
||||
|
||||
def test_tool_error_with_details(self):
|
||||
"""Test tool error with all details."""
|
||||
error = MCPToolError(
|
||||
"Tool execution failed",
|
||||
server_name="llm-gateway",
|
||||
tool_name="chat",
|
||||
tool_args={"prompt": "Hello"},
|
||||
error_code="INVALID_ARGS",
|
||||
)
|
||||
assert error.tool_name == "chat"
|
||||
assert error.tool_args == {"prompt": "Hello"}
|
||||
assert error.error_code == "INVALID_ARGS"
|
||||
assert "tool=chat" in str(error)
|
||||
assert "error_code=INVALID_ARGS" in str(error)
|
||||
|
||||
|
||||
class TestMCPServerNotFoundError:
|
||||
"""Tests for MCPServerNotFoundError class."""
|
||||
|
||||
def test_server_not_found(self):
|
||||
"""Test server not found error."""
|
||||
error = MCPServerNotFoundError("unknown-server")
|
||||
assert error.server_name == "unknown-server"
|
||||
assert "MCP server not found: unknown-server" in error.message
|
||||
assert error.available_servers == []
|
||||
|
||||
def test_server_not_found_with_available(self):
|
||||
"""Test server not found with available servers listed."""
|
||||
error = MCPServerNotFoundError(
|
||||
"unknown-server",
|
||||
available_servers=["server-1", "server-2"],
|
||||
)
|
||||
assert error.available_servers == ["server-1", "server-2"]
|
||||
assert "available=['server-1', 'server-2']" in str(error)
|
||||
|
||||
|
||||
class TestMCPToolNotFoundError:
|
||||
"""Tests for MCPToolNotFoundError class."""
|
||||
|
||||
def test_tool_not_found(self):
|
||||
"""Test tool not found error."""
|
||||
error = MCPToolNotFoundError("unknown-tool")
|
||||
assert error.tool_name == "unknown-tool"
|
||||
assert "Tool not found: unknown-tool" in error.message
|
||||
assert error.available_tools == []
|
||||
|
||||
def test_tool_not_found_with_available(self):
|
||||
"""Test tool not found with available tools listed."""
|
||||
error = MCPToolNotFoundError(
|
||||
"unknown-tool",
|
||||
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
|
||||
)
|
||||
assert len(error.available_tools) == 6
|
||||
# Should show first 5 tools with ellipsis
|
||||
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
|
||||
|
||||
|
||||
class TestMCPCircuitOpenError:
|
||||
"""Tests for MCPCircuitOpenError class."""
|
||||
|
||||
def test_circuit_open_error(self):
|
||||
"""Test circuit open error."""
|
||||
error = MCPCircuitOpenError("test-server")
|
||||
assert error.server_name == "test-server"
|
||||
assert "Circuit breaker open for server: test-server" in error.message
|
||||
assert error.failure_count is None
|
||||
assert error.reset_timeout is None
|
||||
|
||||
def test_circuit_open_error_with_details(self):
|
||||
"""Test circuit open error with details."""
|
||||
error = MCPCircuitOpenError(
|
||||
"test-server",
|
||||
failure_count=5,
|
||||
reset_timeout=30.0,
|
||||
)
|
||||
assert error.failure_count == 5
|
||||
assert error.reset_timeout == 30.0
|
||||
assert "failures=5" in str(error)
|
||||
assert "reset_in=30.0s" in str(error)
|
||||
|
||||
|
||||
class TestMCPValidationError:
|
||||
"""Tests for MCPValidationError class."""
|
||||
|
||||
def test_validation_error(self):
|
||||
"""Test validation error."""
|
||||
error = MCPValidationError("Validation failed")
|
||||
assert error.message == "Validation failed"
|
||||
assert error.tool_name is None
|
||||
assert error.field_errors == {}
|
||||
|
||||
def test_validation_error_with_details(self):
|
||||
"""Test validation error with field errors."""
|
||||
error = MCPValidationError(
|
||||
"Validation failed",
|
||||
tool_name="create_issue",
|
||||
field_errors={
|
||||
"title": "Title is required",
|
||||
"priority": "Invalid priority value",
|
||||
},
|
||||
)
|
||||
assert error.tool_name == "create_issue"
|
||||
assert error.field_errors == {
|
||||
"title": "Title is required",
|
||||
"priority": "Invalid priority value",
|
||||
}
|
||||
assert "tool=create_issue" in str(error)
|
||||
assert "fields=['title', 'priority']" in str(error)
|
||||
|
||||
|
||||
class TestExceptionInheritance:
|
||||
"""Tests for exception inheritance chain."""
|
||||
|
||||
def test_all_errors_inherit_from_mcp_error(self):
|
||||
"""Test that all custom exceptions inherit from MCPError."""
|
||||
assert issubclass(MCPConnectionError, MCPError)
|
||||
assert issubclass(MCPTimeoutError, MCPError)
|
||||
assert issubclass(MCPToolError, MCPError)
|
||||
assert issubclass(MCPServerNotFoundError, MCPError)
|
||||
assert issubclass(MCPToolNotFoundError, MCPError)
|
||||
assert issubclass(MCPCircuitOpenError, MCPError)
|
||||
assert issubclass(MCPValidationError, MCPError)
|
||||
|
||||
def test_all_errors_inherit_from_exception(self):
|
||||
"""Test that base MCPError inherits from Exception."""
|
||||
assert issubclass(MCPError, Exception)
|
||||
|
||||
def test_catch_all_with_mcp_error(self):
|
||||
"""Test that all errors can be caught with MCPError."""
|
||||
|
||||
def raise_connection_error():
|
||||
raise MCPConnectionError("Connection failed")
|
||||
|
||||
def raise_timeout_error():
|
||||
raise MCPTimeoutError("Timeout")
|
||||
|
||||
def raise_tool_error():
|
||||
raise MCPToolError("Tool failed")
|
||||
|
||||
with pytest.raises(MCPError):
|
||||
raise_connection_error()
|
||||
|
||||
with pytest.raises(MCPError):
|
||||
raise_timeout_error()
|
||||
|
||||
with pytest.raises(MCPError):
|
||||
raise_tool_error()
|
||||
272
backend/tests/services/mcp/test_registry.py
Normal file
272
backend/tests/services/mcp/test_registry.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Tests for MCP Server Registry
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
|
||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||
from app.services.mcp.registry import (
|
||||
MCPServerRegistry,
|
||||
ServerCapabilities,
|
||||
get_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
"""Create a sample MCP configuration."""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
),
|
||||
"server-2": MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
),
|
||||
"disabled-server": MCPServerConfig(
|
||||
url="http://disabled:8000",
|
||||
enabled=False,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestServerCapabilities:
|
||||
"""Tests for ServerCapabilities class."""
|
||||
|
||||
def test_empty_capabilities(self):
|
||||
"""Test creating empty capabilities."""
|
||||
caps = ServerCapabilities()
|
||||
assert caps.tools == []
|
||||
assert caps.resources == []
|
||||
assert caps.prompts == []
|
||||
assert caps.is_loaded is False
|
||||
assert caps.tool_names == []
|
||||
|
||||
def test_capabilities_with_tools(self):
|
||||
"""Test capabilities with tools."""
|
||||
caps = ServerCapabilities(
|
||||
tools=[
|
||||
{"name": "tool1", "description": "Tool 1"},
|
||||
{"name": "tool2", "description": "Tool 2"},
|
||||
]
|
||||
)
|
||||
assert len(caps.tools) == 2
|
||||
assert caps.tool_names == ["tool1", "tool2"]
|
||||
|
||||
def test_mark_loaded(self):
|
||||
"""Test marking capabilities as loaded."""
|
||||
caps = ServerCapabilities()
|
||||
assert caps.is_loaded is False
|
||||
assert caps._load_time is None
|
||||
|
||||
caps.mark_loaded()
|
||||
assert caps.is_loaded is True
|
||||
assert caps._load_time is not None
|
||||
|
||||
|
||||
class TestMCPServerRegistry:
|
||||
"""Tests for MCPServerRegistry singleton."""
|
||||
|
||||
def test_singleton_pattern(self, reset_registry):
|
||||
"""Test that registry is a singleton."""
|
||||
registry1 = MCPServerRegistry()
|
||||
registry2 = MCPServerRegistry()
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_get_instance(self, reset_registry):
|
||||
"""Test get_instance class method."""
|
||||
registry = MCPServerRegistry.get_instance()
|
||||
assert registry is MCPServerRegistry()
|
||||
|
||||
def test_reset_instance(self, reset_registry):
|
||||
"""Test resetting singleton instance."""
|
||||
registry1 = MCPServerRegistry()
|
||||
MCPServerRegistry.reset_instance()
|
||||
registry2 = MCPServerRegistry()
|
||||
assert registry1 is not registry2
|
||||
|
||||
def test_load_config(self, reset_registry, sample_config):
|
||||
"""Test loading configuration."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
assert len(registry.list_servers()) == 3
|
||||
assert "server-1" in registry.list_servers()
|
||||
assert "server-2" in registry.list_servers()
|
||||
assert "disabled-server" in registry.list_servers()
|
||||
|
||||
def test_list_enabled_servers(self, reset_registry, sample_config):
|
||||
"""Test listing only enabled servers."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
enabled = registry.list_enabled_servers()
|
||||
assert len(enabled) == 2
|
||||
assert "server-1" in enabled
|
||||
assert "server-2" in enabled
|
||||
assert "disabled-server" not in enabled
|
||||
|
||||
def test_register(self, reset_registry):
|
||||
"""Test registering a new server."""
|
||||
registry = MCPServerRegistry()
|
||||
config = MCPServerConfig(url="http://new:8000")
|
||||
|
||||
registry.register("new-server", config)
|
||||
assert "new-server" in registry.list_servers()
|
||||
assert registry.get("new-server").url == "http://new:8000"
|
||||
|
||||
def test_unregister(self, reset_registry, sample_config):
|
||||
"""Test unregistering a server."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
assert registry.unregister("server-1") is True
|
||||
assert "server-1" not in registry.list_servers()
|
||||
|
||||
# Unregistering non-existent server returns False
|
||||
assert registry.unregister("nonexistent") is False
|
||||
|
||||
def test_get(self, reset_registry, sample_config):
|
||||
"""Test getting server config."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
config = registry.get("server-1")
|
||||
assert config.url == "http://server1:8000"
|
||||
assert config.timeout == 30
|
||||
|
||||
def test_get_not_found(self, reset_registry, sample_config):
|
||||
"""Test getting non-existent server raises error."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
with pytest.raises(MCPServerNotFoundError) as exc_info:
|
||||
registry.get("nonexistent")
|
||||
|
||||
assert exc_info.value.server_name == "nonexistent"
|
||||
assert "server-1" in exc_info.value.available_servers
|
||||
|
||||
def test_get_or_none(self, reset_registry, sample_config):
|
||||
"""Test get_or_none method."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
config = registry.get_or_none("server-1")
|
||||
assert config is not None
|
||||
|
||||
config = registry.get_or_none("nonexistent")
|
||||
assert config is None
|
||||
|
||||
def test_get_all_configs(self, reset_registry, sample_config):
|
||||
"""Test getting all configs."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
configs = registry.get_all_configs()
|
||||
assert len(configs) == 3
|
||||
|
||||
def test_get_enabled_configs(self, reset_registry, sample_config):
|
||||
"""Test getting enabled configs."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
configs = registry.get_enabled_configs()
|
||||
assert len(configs) == 2
|
||||
assert "disabled-server" not in configs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_capabilities(self, reset_registry, sample_config):
|
||||
"""Test getting server capabilities."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
# Initially empty capabilities
|
||||
caps = await registry.get_capabilities("server-1")
|
||||
assert caps.is_loaded is False
|
||||
|
||||
def test_set_capabilities(self, reset_registry, sample_config):
|
||||
"""Test setting server capabilities."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
resources=[{"name": "resource1"}],
|
||||
)
|
||||
|
||||
caps = registry._capabilities["server-1"]
|
||||
assert len(caps.tools) == 2
|
||||
assert len(caps.resources) == 1
|
||||
assert caps.is_loaded is True
|
||||
|
||||
def test_find_server_for_tool(self, reset_registry, sample_config):
|
||||
"""Test finding server that provides a tool."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
registry.set_capabilities(
|
||||
"server-2",
|
||||
tools=[{"name": "tool3"}],
|
||||
)
|
||||
|
||||
assert registry.find_server_for_tool("tool1") == "server-1"
|
||||
assert registry.find_server_for_tool("tool3") == "server-2"
|
||||
assert registry.find_server_for_tool("unknown") is None
|
||||
|
||||
def test_get_all_tools(self, reset_registry, sample_config):
|
||||
"""Test getting all tools from all servers."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool1"}],
|
||||
)
|
||||
registry.set_capabilities(
|
||||
"server-2",
|
||||
tools=[{"name": "tool2"}, {"name": "tool3"}],
|
||||
)
|
||||
|
||||
all_tools = registry.get_all_tools()
|
||||
assert len(all_tools) == 2
|
||||
assert len(all_tools["server-1"]) == 1
|
||||
assert len(all_tools["server-2"]) == 2
|
||||
|
||||
def test_global_config_property(self, reset_registry, sample_config):
|
||||
"""Test accessing global config."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
global_config = registry.global_config
|
||||
assert global_config is not None
|
||||
assert len(global_config.mcp_servers) == 3
|
||||
|
||||
|
||||
class TestGetRegistry:
|
||||
"""Tests for get_registry convenience function."""
|
||||
|
||||
def test_get_registry_returns_singleton(self, reset_registry):
|
||||
"""Test that get_registry returns the singleton."""
|
||||
registry1 = get_registry()
|
||||
registry2 = get_registry()
|
||||
assert registry1 is registry2
|
||||
assert registry1 is MCPServerRegistry()
|
||||
345
backend/tests/services/mcp/test_routing.py
Normal file
345
backend/tests/services/mcp/test_routing.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Tests for MCP Tool Call Routing
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.connection import ConnectionPool
|
||||
from app.services.mcp.exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from app.services.mcp.registry import MCPServerRegistry
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry(reset_registry):
|
||||
"""Create a configured registry."""
|
||||
reg = MCPServerRegistry()
|
||||
reg.load_config(
|
||||
MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
retry_attempts=1,
|
||||
retry_delay=0.1, # Minimum allowed value
|
||||
circuit_breaker_threshold=3,
|
||||
circuit_breaker_timeout=5.0,
|
||||
),
|
||||
"server-2": MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
retry_attempts=1,
|
||||
retry_delay=0.1, # Minimum allowed value
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
return reg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool():
|
||||
"""Create a connection pool."""
|
||||
return ConnectionPool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router(registry, pool):
|
||||
"""Create a tool router."""
|
||||
return ToolRouter(registry, pool)
|
||||
|
||||
|
||||
class TestToolInfo:
|
||||
"""Tests for ToolInfo dataclass."""
|
||||
|
||||
def test_basic_tool_info(self):
|
||||
"""Test creating basic tool info."""
|
||||
info = ToolInfo(name="test-tool")
|
||||
assert info.name == "test-tool"
|
||||
assert info.description is None
|
||||
assert info.server_name is None
|
||||
assert info.input_schema is None
|
||||
|
||||
def test_full_tool_info(self):
|
||||
"""Test creating full tool info."""
|
||||
info = ToolInfo(
|
||||
name="create_issue",
|
||||
description="Create a new issue",
|
||||
server_name="issues",
|
||||
input_schema={"type": "object", "properties": {"title": {"type": "string"}}},
|
||||
)
|
||||
assert info.name == "create_issue"
|
||||
assert info.description == "Create a new issue"
|
||||
assert info.server_name == "issues"
|
||||
assert "properties" in info.input_schema
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting to dictionary."""
|
||||
info = ToolInfo(
|
||||
name="test-tool",
|
||||
description="A test tool",
|
||||
server_name="test-server",
|
||||
)
|
||||
result = info.to_dict()
|
||||
|
||||
assert result["name"] == "test-tool"
|
||||
assert result["description"] == "A test tool"
|
||||
assert result["server_name"] == "test-server"
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""Tests for ToolResult dataclass."""
|
||||
|
||||
def test_success_result(self):
|
||||
"""Test creating success result."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"id": "123"},
|
||||
tool_name="create_issue",
|
||||
server_name="issues",
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.data == {"id": "123"}
|
||||
assert result.error is None
|
||||
|
||||
def test_error_result(self):
|
||||
"""Test creating error result."""
|
||||
result = ToolResult(
|
||||
success=False,
|
||||
error="Tool execution failed",
|
||||
error_code="INTERNAL_ERROR",
|
||||
tool_name="create_issue",
|
||||
server_name="issues",
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error == "Tool execution failed"
|
||||
assert result.error_code == "INTERNAL_ERROR"
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting to dictionary."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="test",
|
||||
execution_time_ms=123.45,
|
||||
)
|
||||
d = result.to_dict()
|
||||
|
||||
assert d["success"] is True
|
||||
assert d["data"] == {"result": "ok"}
|
||||
assert d["tool_name"] == "test"
|
||||
assert d["execution_time_ms"] == 123.45
|
||||
assert "request_id" in d # Auto-generated
|
||||
|
||||
|
||||
class TestToolRouter:
|
||||
"""Tests for ToolRouter class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_tool_mapping(self, router):
|
||||
"""Test registering tool mappings."""
|
||||
await router.register_tool_mapping("tool1", "server-1")
|
||||
await router.register_tool_mapping("tool2", "server-2")
|
||||
|
||||
assert router.find_server_for_tool("tool1") == "server-1"
|
||||
assert router.find_server_for_tool("tool2") == "server-2"
|
||||
|
||||
def test_find_server_for_unknown_tool(self, router):
|
||||
"""Test finding server for unknown tool."""
|
||||
result = router.find_server_for_tool("unknown-tool")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_success(self, router, registry):
|
||||
"""Test successful tool call."""
|
||||
# Set up capabilities
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "test-tool"}],
|
||||
)
|
||||
await router.register_tool_mapping("test-tool", "server-1")
|
||||
|
||||
# Mock the pool connection and request
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={"result": {"status": "ok"}}
|
||||
)
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
result = await router.call_tool(
|
||||
server_name="server-1",
|
||||
tool_name="test-tool",
|
||||
arguments={"param": "value"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data == {"status": "ok"}
|
||||
assert result.tool_name == "test-tool"
|
||||
assert result.server_name == "server-1"
|
||||
assert result.execution_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_error_response(self, router, registry):
|
||||
"""Test tool call with error response."""
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "test-tool"}],
|
||||
)
|
||||
await router.register_tool_mapping("test-tool", "server-1")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={
|
||||
"error": {
|
||||
"code": -32000,
|
||||
"message": "Tool execution failed",
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
result = await router.call_tool(
|
||||
server_name="server-1",
|
||||
tool_name="test-tool",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Tool execution failed" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_tool(self, router, registry):
|
||||
"""Test routing tool to correct server."""
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool-on-server-1"}],
|
||||
)
|
||||
await router.register_tool_mapping("tool-on-server-1", "server-1")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={"result": "routed"}
|
||||
)
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
result = await router.route_tool(
|
||||
tool_name="tool-on-server-1",
|
||||
arguments={"key": "value"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.server_name == "server-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_tool_not_found(self, router):
|
||||
"""Test routing unknown tool raises error."""
|
||||
with pytest.raises(MCPToolNotFoundError) as exc_info:
|
||||
await router.route_tool(
|
||||
tool_name="unknown-tool",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
assert exc_info.value.tool_name == "unknown-tool"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_tools(self, router, registry):
|
||||
"""Test listing all tools."""
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[
|
||||
{"name": "tool1", "description": "Tool 1"},
|
||||
{"name": "tool2", "description": "Tool 2"},
|
||||
],
|
||||
)
|
||||
registry.set_capabilities(
|
||||
"server-2",
|
||||
tools=[{"name": "tool3", "description": "Tool 3"}],
|
||||
)
|
||||
|
||||
tools = await router.list_all_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "tool1" in tool_names
|
||||
assert "tool2" in tool_names
|
||||
assert "tool3" in tool_names
|
||||
|
||||
def test_circuit_breaker_status(self, router, registry):
|
||||
"""Test getting circuit breaker status."""
|
||||
# Initially no circuit breakers
|
||||
status = router.get_circuit_breaker_status()
|
||||
assert status == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_circuit_breaker(self, router, registry):
|
||||
"""Test resetting circuit breaker."""
|
||||
# Reset non-existent returns False
|
||||
result = await router.reset_circuit_breaker("server-1")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_tools(self, router, registry):
|
||||
"""Test tool discovery from servers."""
|
||||
# Create mocks for different servers
|
||||
mock_conn_1 = AsyncMock()
|
||||
mock_conn_1.execute_request = AsyncMock(
|
||||
return_value={
|
||||
"tools": [
|
||||
{"name": "discovered-tool", "description": "A discovered tool"},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_conn_1.server_name = "server-1"
|
||||
mock_conn_1.is_connected = True
|
||||
|
||||
mock_conn_2 = AsyncMock()
|
||||
mock_conn_2.execute_request = AsyncMock(
|
||||
return_value={"tools": []} # Empty for server-2
|
||||
)
|
||||
mock_conn_2.server_name = "server-2"
|
||||
mock_conn_2.is_connected = True
|
||||
|
||||
async def get_connection_side_effect(server_name, _config):
|
||||
if server_name == "server-1":
|
||||
return mock_conn_1
|
||||
return mock_conn_2
|
||||
|
||||
with patch.object(
|
||||
router._pool,
|
||||
"get_connection",
|
||||
side_effect=get_connection_side_effect,
|
||||
):
|
||||
await router.discover_tools()
|
||||
|
||||
# Check that tool mapping was registered
|
||||
server = router.find_server_for_tool("discovered-tool")
|
||||
assert server == "server-1"
|
||||
|
||||
def test_calculate_retry_delay(self, router, registry):
|
||||
"""Test retry delay calculation."""
|
||||
config = registry.get("server-1")
|
||||
|
||||
delay1 = router._calculate_retry_delay(1, config)
|
||||
delay2 = router._calculate_retry_delay(2, config)
|
||||
delay3 = router._calculate_retry_delay(3, config)
|
||||
|
||||
# Delays should increase with attempts
|
||||
assert delay1 > 0
|
||||
# Allow for jitter variation
|
||||
assert delay1 <= config.retry_max_delay * 1.25
|
||||
63
backend/uv.lock
generated
63
backend/uv.lock
generated
@@ -582,15 +582,18 @@ dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "fastapi-utils" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mcp" },
|
||||
{ name = "passlib" },
|
||||
{ name = "pillow" },
|
||||
{ name = "psycopg2-binary" },
|
||||
{ name = "pybreaker" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose" },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pytz" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "slowapi" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "sse-starlette" },
|
||||
@@ -632,10 +635,12 @@ requires-dist = [
|
||||
{ name = "fastapi-utils", specifier = "==0.8.0" },
|
||||
{ name = "freezegun", marker = "extra == 'dev'", specifier = "~=1.5.1" },
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "mcp", specifier = ">=1.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "pillow", specifier = ">=10.3.0" },
|
||||
{ name = "psycopg2-binary", specifier = ">=2.9.9" },
|
||||
{ name = "pybreaker", specifier = ">=1.0.0" },
|
||||
{ name = "pydantic", specifier = ">=2.10.6" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.2.1" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
||||
@@ -646,6 +651,7 @@ requires-dist = [
|
||||
{ name = "python-jose", specifier = "==3.4.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.19" },
|
||||
{ name = "pytz", specifier = ">=2024.1" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.0" },
|
||||
{ name = "requests", marker = "extra == 'dev'", specifier = ">=2.32.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" },
|
||||
{ name = "schemathesis", marker = "extra == 'e2e'", specifier = ">=3.30.0" },
|
||||
@@ -805,6 +811,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hypothesis"
|
||||
version = "6.148.2"
|
||||
@@ -1063,6 +1078,31 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.25.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "httpx" },
|
||||
{ name = "httpx-sse" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt", extra = ["crypto"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "sse-starlette" },
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "typing-inspection" },
|
||||
{ name = "uvicorn", marker = "sys_platform != 'emscripten'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387, upload-time = "2025-12-19T10:19:56.985Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076, upload-time = "2025-12-19T10:19:55.416Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
@@ -1294,6 +1334,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/62/1e/a94a8d635fa3ce4cfc7f506003548d0a2447ae76fd5ca53932970fe3053f/pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", size = 77145, upload-time = "2019-11-16T17:27:11.07Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pybreaker"
|
||||
version = "1.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/89/fbf98e383f1ec6d117af2cd983efdb3eb7018b63834c427025764194cac2/pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178", size = 15555, upload-time = "2025-09-21T15:12:04.499Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/44/75/e64d3d40a741e2be21d69154f4e5c43a66f0c603c5ef11f49e01429a5932/pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0", size = 12915, upload-time = "2025-09-21T15:12:02.284Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.23"
|
||||
@@ -1412,6 +1461,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyjwt"
|
||||
version = "2.10.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
crypto = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyrate-limiter"
|
||||
version = "3.9.0"
|
||||
|
||||
@@ -214,6 +214,71 @@ test(frontend): add unit tests for ProjectDashboard
|
||||
|
||||
---
|
||||
|
||||
## Phase 2+ Implementation Workflow
|
||||
|
||||
**For complex infrastructure issues (Phase 2 MCP, core systems), follow this rigorous process:**
|
||||
|
||||
### 1. Branch Setup
|
||||
```bash
|
||||
# Create feature branch from dev
|
||||
git checkout dev && git pull
|
||||
git checkout -b feature/<issue-number>-<description>
|
||||
```
|
||||
|
||||
### 2. Planning Phase
|
||||
- Read the issue thoroughly, understand ALL sub-tasks
|
||||
- Identify components and their dependencies
|
||||
- Determine if multi-agent parallel execution is appropriate
|
||||
- Create a detailed execution plan before writing any code
|
||||
|
||||
### 3. Implementation with Continuous Testing
|
||||
**After EACH sub-task:**
|
||||
- [ ] Run unit tests for the component
|
||||
- [ ] Run integration tests if applicable
|
||||
- [ ] Verify type checking passes
|
||||
- [ ] Verify linting passes
|
||||
- [ ] Keep coverage high (aim for >90%)
|
||||
|
||||
**Both modules must be tested:**
|
||||
- Backend: `IS_TEST=True uv run pytest` + E2E tests
|
||||
- Frontend: `npm test` + E2E tests (if applicable)
|
||||
|
||||
### 4. Multi-Agent Review (MANDATORY before considering done)
|
||||
Before closing an issue, perform deep review from multiple angles:
|
||||
|
||||
| Review Type | Focus Areas |
|
||||
|-------------|-------------|
|
||||
| **Code Quality** | Logic errors, edge cases, race conditions, error handling |
|
||||
| **Security** | OWASP Top 10, input validation, authentication, authorization |
|
||||
| **Performance** | N+1 queries, memory leaks, inefficient algorithms |
|
||||
| **Architecture** | Pattern adherence, separation of concerns, extensibility |
|
||||
| **Testing** | Coverage completeness, test quality, edge case coverage |
|
||||
| **Documentation** | Code comments, README, API docs, usage examples |
|
||||
|
||||
**No stone unturned. No sloppy results. No unreviewed work.**
|
||||
|
||||
### 5. Final Validation
|
||||
- [ ] All tests pass (unit, integration, E2E)
|
||||
- [ ] Type checking passes
|
||||
- [ ] Linting passes
|
||||
- [ ] Documentation updated
|
||||
- [ ] Coverage meets threshold
|
||||
- [ ] Issue checklist 100% complete
|
||||
- [ ] Multi-agent review passed
|
||||
|
||||
### When to Use Parallel Agents
|
||||
Use multiple agents working in parallel when:
|
||||
- Sub-tasks are independent (no shared state/dependencies)
|
||||
- Different expertise areas (backend vs frontend)
|
||||
- Time-critical deliveries with clear boundaries
|
||||
|
||||
Do NOT use parallel agents when:
|
||||
- Tasks share state or have dependencies
|
||||
- Sequential testing is required
|
||||
- Integration points need careful coordination
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| Action | Command/Location |
|
||||
|
||||
Reference in New Issue
Block a user