Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
447 lines
12 KiB
Python
447 lines
12 KiB
Python
"""
MCP (Model Context Protocol) API Endpoints

Provides REST endpoints for managing MCP server connections
and executing tool calls.
"""
|
|
|
|
import logging
|
|
import re
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.api.dependencies.permissions import require_superuser
|
|
from app.models.user import User
|
|
from app.services.mcp import (
|
|
MCPCircuitOpenError,
|
|
MCPClientManager,
|
|
MCPConnectionError,
|
|
MCPError,
|
|
MCPServerNotFoundError,
|
|
MCPTimeoutError,
|
|
MCPToolError,
|
|
MCPToolNotFoundError,
|
|
get_mcp_client,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

router = APIRouter()

# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars.
# This is the single source of truth for server-name validation; the path
# parameter alias below reuses it so the two rules cannot drift apart.
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")

# Type alias for a validated ``{server_name}`` path parameter.
# min_length/max_length are redundant with the pattern's {1,64} quantifier,
# but are kept so the limits show up explicitly in the OpenAPI schema.
ServerNamePath = Annotated[
    str,
    Path(
        description="MCP server name",
        min_length=1,
        max_length=64,
        pattern=SERVER_NAME_PATTERN.pattern,
    ),
]
|
|
|
|
|
|
# ============================================================================
|
|
# Request/Response Schemas
|
|
# ============================================================================
|
|
|
|
|
|
class ServerInfo(BaseModel):
    """Configuration summary of one registered MCP server."""

    # All fields except ``description`` are required.
    name: str = Field(default=..., description="Server name")
    url: str = Field(default=..., description="Server URL")
    enabled: bool = Field(default=..., description="Whether server is enabled")
    timeout: int = Field(default=..., description="Request timeout in seconds")
    transport: str = Field(
        default=...,
        description="Transport type (http, stdio, sse)",
    )
    description: str | None = Field(default=None, description="Server description")
|
|
|
|
|
|
class ServerListResponse(BaseModel):
    """Response containing list of MCP servers."""

    # Field descriptions mirror ServerInfo's style so they appear in OpenAPI.
    servers: list[ServerInfo] = Field(..., description="Registered MCP servers")
    total: int = Field(..., description="Number of servers in the list")
|
|
|
|
|
|
class ToolInfoResponse(BaseModel):
    """Description of a single tool exposed by an MCP server."""

    name: str = Field(default=..., description="Tool name")
    # The remaining fields are optional and default to None.
    description: str | None = Field(default=None, description="Tool description")
    server_name: str | None = Field(
        default=None,
        description="Server providing the tool",
    )
    input_schema: dict[str, Any] | None = Field(
        default=None,
        description="JSON schema for input",
    )
|
|
|
|
|
|
class ToolListResponse(BaseModel):
    """Response containing list of tools."""

    # Field descriptions mirror ToolInfoResponse's style so they appear in OpenAPI.
    tools: list[ToolInfoResponse] = Field(..., description="Available tools")
    total: int = Field(..., description="Number of tools in the list")
|
|
|
|
|
|
class ServerHealthStatus(BaseModel):
    """Health status for a server."""

    # Defaults are unchanged; descriptions added for the OpenAPI schema.
    name: str = Field(..., description="Server name")
    healthy: bool = Field(..., description="Whether the server is healthy")
    state: str = Field(
        ...,
        description="Server state as reported by the MCP client manager",
    )
    url: str = Field(..., description="Server URL")
    error: str | None = Field(None, description="Error message, if any")
    tools_count: int = Field(0, description="Number of tools on the server")
|
|
|
|
|
|
class HealthCheckResponse(BaseModel):
    """Response containing health status of all servers."""

    # Descriptions added for the OpenAPI schema; all fields remain required.
    servers: dict[str, ServerHealthStatus] = Field(
        ...,
        description="Per-server health status, keyed by server name",
    )
    healthy_count: int = Field(..., description="Number of healthy servers")
    unhealthy_count: int = Field(..., description="Number of unhealthy servers")
    total: int = Field(..., description="Total number of servers checked")
|
|
|
|
|
|
class ToolCallRequest(BaseModel):
    """Request to execute a tool.

    ``server`` is validated with the same rule as the ``{server_name}`` path
    parameter used by the other endpoints (1-64 characters from
    ``[a-zA-Z0-9_-]``). Malformed names are therefore rejected during request
    validation, before they reach the client manager or the log line in
    ``call_tool``.

    ``timeout``, when given, must be strictly positive.
    """

    server: str = Field(
        ...,
        description="MCP server name",
        min_length=1,
        max_length=64,
        pattern=SERVER_NAME_PATTERN.pattern,
    )
    tool: str = Field(..., description="Tool name to execute")
    arguments: dict[str, Any] = Field(
        default_factory=dict,
        description="Tool arguments",
    )
    timeout: float | None = Field(
        None,
        gt=0,
        description="Optional timeout override in seconds",
    )
|
|
|
|
|
|
class ToolCallResponse(BaseModel):
    """Response from tool execution."""

    # Defaults are unchanged; descriptions added for the OpenAPI schema.
    success: bool = Field(..., description="Whether the tool call succeeded")
    data: Any | None = Field(None, description="Data returned by the tool")
    error: str | None = Field(None, description="Error message, if the call failed")
    error_code: str | None = Field(None, description="Error code, if the call failed")
    tool_name: str | None = Field(None, description="Name of the executed tool")
    server_name: str | None = Field(None, description="Server that ran the tool")
    execution_time_ms: float = Field(
        0.0,
        description="Execution time in milliseconds",
    )
    request_id: str | None = Field(None, description="Request identifier, if any")
|
|
|
|
|
|
class CircuitBreakerStatus(BaseModel):
    """Status of a circuit breaker."""

    # All fields remain required; descriptions added for the OpenAPI schema.
    server_name: str = Field(..., description="Server name")
    state: str = Field(..., description="Circuit breaker state")
    failure_count: int = Field(..., description="Number of recorded failures")
|
|
|
|
|
|
class CircuitBreakerListResponse(BaseModel):
    """Response containing circuit breaker statuses."""

    # Description added for the OpenAPI schema; field remains required.
    circuit_breakers: list[CircuitBreakerStatus] = Field(
        ...,
        description="Status of each circuit breaker",
    )
|
|
|
|
|
|
# ============================================================================
|
|
# Endpoints
|
|
# ============================================================================
|
|
|
|
|
|
@router.get(
    "/servers",
    response_model=ServerListResponse,
    summary="List MCP Servers",
    description="Get list of all registered MCP servers with their configurations.",
)
async def list_servers(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ServerListResponse:
    """Return the configuration of every registered MCP server.

    A listed name for which ``get_server_config`` raises
    ``MCPServerNotFoundError`` is skipped instead of failing the request.
    This is presumably to tolerate concurrent removal; confirm.
    """
    infos: list[ServerInfo] = []

    for server_name in mcp.list_servers():
        try:
            cfg = mcp.get_server_config(server_name)
        except MCPServerNotFoundError:
            continue
        else:
            infos.append(
                ServerInfo(
                    name=server_name,
                    url=cfg.url,
                    enabled=cfg.enabled,
                    timeout=cfg.timeout,
                    transport=cfg.transport.value,
                    description=cfg.description,
                )
            )

    return ServerListResponse(servers=infos, total=len(infos))
|
|
|
|
|
|
@router.get(
    "/servers/{server_name}/tools",
    response_model=ToolListResponse,
    summary="List Server Tools",
    description="Get list of tools available on a specific MCP server.",
)
async def list_server_tools(
    server_name: ServerNamePath,
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ToolListResponse:
    """Return the tools exposed by a single MCP server.

    Raises:
        HTTPException: 404 when ``mcp.list_tools`` raises
            ``MCPServerNotFoundError`` for ``server_name``.
    """
    try:
        server_tools = await mcp.list_tools(server_name)
    except MCPServerNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Server not found: {server_name}",
        ) from exc

    entries = [
        ToolInfoResponse(
            name=tool.name,
            description=tool.description,
            server_name=tool.server_name,
            input_schema=tool.input_schema,
        )
        for tool in server_tools
    ]
    return ToolListResponse(tools=entries, total=len(server_tools))
|
|
|
|
|
|
@router.get(
    "/tools",
    response_model=ToolListResponse,
    summary="List All Tools",
    description="Get list of all tools from all MCP servers.",
)
async def list_all_tools(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ToolListResponse:
    """Return the tools aggregated across every MCP server."""
    every_tool = await mcp.list_all_tools()

    entries = [
        ToolInfoResponse(
            name=tool.name,
            description=tool.description,
            server_name=tool.server_name,
            input_schema=tool.input_schema,
        )
        for tool in every_tool
    ]
    return ToolListResponse(tools=entries, total=len(every_tool))
|
|
|
|
|
|
@router.get(
    "/health",
    response_model=HealthCheckResponse,
    summary="Health Check",
    description="Check health status of all MCP servers.",
)
async def health_check(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> HealthCheckResponse:
    """Perform health check on all MCP servers.

    Each entry returned by ``mcp.health_check()`` is converted into a
    ``ServerHealthStatus``. The healthy and unhealthy counts are derived from
    each entry's ``healthy`` flag.
    """
    health_results = await mcp.health_check()

    # The loop variable is deliberately not named ``status``: that would shadow
    # the ``fastapi.status`` module this file uses for HTTP status codes.
    servers = {
        name: ServerHealthStatus(
            name=server_health.name,
            healthy=server_health.healthy,
            state=server_health.state,
            url=server_health.url,
            error=server_health.error,
            tools_count=server_health.tools_count,
        )
        for name, server_health in health_results.items()
    }

    healthy_count = sum(1 for s in servers.values() if s.healthy)

    return HealthCheckResponse(
        servers=servers,
        healthy_count=healthy_count,
        unhealthy_count=len(servers) - healthy_count,
        total=len(servers),
    )
|
|
|
|
|
|
@router.post(
    "/call",
    response_model=ToolCallResponse,
    summary="Execute Tool (Admin Only)",
    description="Execute a tool on an MCP server. Requires superuser privileges.",
)
async def call_tool(
    request: ToolCallRequest,
    current_user: User = Depends(require_superuser),
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ToolCallResponse:
    """
    Execute a tool on an MCP server.

    This endpoint is restricted to superusers for direct tool execution.
    Normal tool execution should go through agent workflows.

    Exceptions from ``mcp.call_tool`` are matched in this order:
        - MCPCircuitOpenError  -> 503 Service Unavailable
        - MCPToolNotFoundError -> 404 Not Found
        - MCPServerNotFoundError -> 404 Not Found
        - MCPTimeoutError      -> 504 Gateway Timeout
        - MCPConnectionError   -> 502 Bad Gateway
        - MCPToolError         -> normal response with ``success=False``
        - any other MCPError   -> 500 Internal Server Error (traceback logged)
    """
    logger.info(
        "Tool call by user %s: %s.%s",
        current_user.id,
        request.server,
        request.tool,
    )

    # Only the remote call sits inside ``try``; building the response cannot
    # raise any of the MCP errors handled below.
    try:
        result = await mcp.call_tool(
            server=request.server,
            tool=request.tool,
            args=request.arguments,
            timeout=request.timeout,
        )
    except MCPCircuitOpenError as e:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Server temporarily unavailable: {e.server_name}",
        ) from e
    except MCPToolNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tool not found: {e.tool_name}",
        ) from e
    except MCPServerNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Server not found: {e.server_name}",
        ) from e
    except MCPTimeoutError as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=str(e),
        ) from e
    except MCPConnectionError as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=str(e),
        ) from e
    except MCPToolError as e:
        # Tool errors are returned in the response, not as HTTP errors.
        # Log them so failed executions remain visible server-side.
        logger.warning(
            "Tool call %s.%s failed: %s",
            request.server,
            request.tool,
            e,
        )
        return ToolCallResponse(
            success=False,
            error=str(e),
            error_code=e.error_code,
            tool_name=e.tool_name,
            server_name=e.server_name,
        )
    except MCPError as e:
        # Unexpected MCP failure. An HTTPException response carries no
        # traceback, so record it here before mapping to a 500.
        logger.exception(
            "Unexpected MCP error during tool call %s.%s",
            request.server,
            request.tool,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e

    return ToolCallResponse(
        success=result.success,
        data=result.data,
        error=result.error,
        error_code=result.error_code,
        tool_name=result.tool_name,
        server_name=result.server_name,
        execution_time_ms=result.execution_time_ms,
        request_id=result.request_id,
    )
|
|
|
|
|
|
@router.get(
    "/circuit-breakers",
    response_model=CircuitBreakerListResponse,
    summary="List Circuit Breakers",
    description="Get status of all circuit breakers.",
)
async def list_circuit_breakers(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> CircuitBreakerListResponse:
    """Report the state of every circuit breaker known to the client manager.

    Missing keys in a breaker's status mapping fall back to ``"unknown"``
    for the state and ``0`` for the failure count.
    """
    breakers = [
        CircuitBreakerStatus(
            server_name=server_name,
            state=details.get("state", "unknown"),
            failure_count=details.get("failure_count", 0),
        )
        for server_name, details in mcp.get_circuit_breaker_status().items()
    ]
    return CircuitBreakerListResponse(circuit_breakers=breakers)
|
|
|
|
|
|
@router.post(
    "/circuit-breakers/{server_name}/reset",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Reset Circuit Breaker (Admin Only)",
    description="Manually reset a circuit breaker for a server.",
)
async def reset_circuit_breaker(
    server_name: ServerNamePath,
    current_user: User = Depends(require_superuser),
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> None:
    """Reset the circuit breaker of ``server_name`` on behalf of a superuser.

    Raises:
        HTTPException: 404 when ``mcp.reset_circuit_breaker`` returns a falsy
            value.
    """
    logger.info(
        "Circuit breaker reset by user %s for server %s",
        current_user.id,
        server_name,
    )

    if await mcp.reset_circuit_breaker(server_name):
        return

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"No circuit breaker found for server: {server_name}",
    )
|
|
|
|
|
|
@router.post(
    "/servers/{server_name}/reconnect",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Reconnect to Server (Admin Only)",
    description="Force reconnection to an MCP server.",
)
async def reconnect_server(
    server_name: ServerNamePath,
    current_user: User = Depends(require_superuser),
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> None:
    """Disconnect from and then reconnect to ``server_name``.

    Raises:
        HTTPException: 404 if either step raises ``MCPServerNotFoundError``;
            502 if either step raises ``MCPConnectionError``.
    """
    logger.info(
        "Reconnect requested by user %s for server %s",
        current_user.id,
        server_name,
    )

    try:
        # Disconnect first so the following connect starts from a clean state.
        await mcp.disconnect(server_name)
        await mcp.connect(server_name)
    except MCPServerNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Server not found: {server_name}",
        ) from exc
    except MCPConnectionError as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to reconnect: {exc}",
        ) from exc
|