feat(backend): implement MCP client infrastructure (#55)

Core MCP client implementation with comprehensive tooling:

**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker

**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO

**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection

**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support

**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 11:12:41 +01:00
parent 731a188a76
commit e5975fa5d0
22 changed files with 5763 additions and 0 deletions

View File

@@ -0,0 +1,491 @@
"""
Tests for MCP API Routes
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from app.main import app
from app.models.user import User
from app.services.mcp import (
MCPCircuitOpenError,
MCPClientManager,
MCPConnectionError,
MCPServerNotFoundError,
MCPTimeoutError,
MCPToolNotFoundError,
ServerHealth,
)
from app.services.mcp.config import MCPServerConfig, TransportType
from app.services.mcp.routing import ToolInfo, ToolResult
@pytest.fixture
def mock_mcp_client():
"""Create a mock MCP client manager."""
client = MagicMock(spec=MCPClientManager)
client.is_initialized = True
return client
@pytest.fixture
def mock_superuser():
"""Create a mock superuser."""
user = MagicMock(spec=User)
user.id = "00000000-0000-0000-0000-000000000001"
user.is_superuser = True
user.email = "admin@example.com"
return user
@pytest.fixture
def client(mock_mcp_client, mock_superuser):
"""Create a FastAPI test client with mocked dependencies."""
from app.api.routes.mcp import get_mcp_client
from app.api.dependencies.permissions import require_superuser
# Override dependencies
async def override_get_mcp_client():
return mock_mcp_client
async def override_require_superuser():
return mock_superuser
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
app.dependency_overrides[require_superuser] = override_require_superuser
with patch("app.main.check_database_health", return_value=True):
yield TestClient(app)
# Clean up
app.dependency_overrides.clear()
class TestListServers:
"""Tests for GET /mcp/servers endpoint."""
def test_list_servers_success(self, client, mock_mcp_client):
"""Test listing MCP servers returns correct data."""
# Setup mock
mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
mock_mcp_client.get_server_config.side_effect = [
MCPServerConfig(
url="http://server1:8000",
timeout=30,
enabled=True,
transport=TransportType.HTTP,
description="Server 1",
),
MCPServerConfig(
url="http://server2:8000",
timeout=60,
enabled=True,
transport=TransportType.SSE,
description="Server 2",
),
]
response = client.get("/api/v1/mcp/servers")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 2
assert len(data["servers"]) == 2
assert data["servers"][0]["name"] == "server-1"
assert data["servers"][0]["url"] == "http://server1:8000"
assert data["servers"][1]["name"] == "server-2"
assert data["servers"][1]["transport"] == "sse"
def test_list_servers_empty(self, client, mock_mcp_client):
"""Test listing servers when none are registered."""
mock_mcp_client.list_servers.return_value = []
response = client.get("/api/v1/mcp/servers")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 0
assert data["servers"] == []
def test_list_servers_handles_not_found(self, client, mock_mcp_client):
"""Test that missing server configs are skipped gracefully."""
mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
mock_mcp_client.get_server_config.side_effect = [
MCPServerConfig(url="http://server1:8000"),
MCPServerNotFoundError(server_name="missing"),
]
response = client.get("/api/v1/mcp/servers")
assert response.status_code == status.HTTP_200_OK
data = response.json()
# Should only include the successfully retrieved server
assert data["total"] == 1
class TestListServerTools:
"""Tests for GET /mcp/servers/{server_name}/tools endpoint."""
def test_list_server_tools_success(self, client, mock_mcp_client):
"""Test listing tools for a specific server."""
mock_mcp_client.list_tools = AsyncMock(
return_value=[
ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
]
)
response = client.get("/api/v1/mcp/servers/server-1/tools")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 2
assert data["tools"][0]["name"] == "tool1"
assert data["tools"][1]["name"] == "tool2"
def test_list_server_tools_not_found(self, client, mock_mcp_client):
"""Test listing tools for non-existent server."""
mock_mcp_client.list_tools = AsyncMock(
side_effect=MCPServerNotFoundError(server_name="unknown")
)
response = client.get("/api/v1/mcp/servers/unknown/tools")
assert response.status_code == status.HTTP_404_NOT_FOUND
class TestListAllTools:
"""Tests for GET /mcp/tools endpoint."""
def test_list_all_tools_success(self, client, mock_mcp_client):
"""Test listing all tools from all servers."""
mock_mcp_client.list_all_tools = AsyncMock(
return_value=[
ToolInfo(name="tool1", server_name="server-1"),
ToolInfo(name="tool2", server_name="server-1"),
ToolInfo(name="tool3", server_name="server-2"),
]
)
response = client.get("/api/v1/mcp/tools")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 3
def test_list_all_tools_empty(self, client, mock_mcp_client):
"""Test listing tools when none are available."""
mock_mcp_client.list_all_tools = AsyncMock(return_value=[])
response = client.get("/api/v1/mcp/tools")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 0
class TestHealthCheck:
"""Tests for GET /mcp/health endpoint."""
def test_health_check_success(self, client, mock_mcp_client):
"""Test health check returns correct data."""
mock_mcp_client.health_check = AsyncMock(
return_value={
"server-1": ServerHealth(
name="server-1",
healthy=True,
state="connected",
url="http://server1:8000",
tools_count=5,
),
"server-2": ServerHealth(
name="server-2",
healthy=False,
state="error",
url="http://server2:8000",
error="Connection refused",
),
}
)
response = client.get("/api/v1/mcp/health")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 2
assert data["healthy_count"] == 1
assert data["unhealthy_count"] == 1
assert data["servers"]["server-1"]["healthy"] is True
assert data["servers"]["server-2"]["healthy"] is False
def test_health_check_all_healthy(self, client, mock_mcp_client):
"""Test health check when all servers are healthy."""
mock_mcp_client.health_check = AsyncMock(
return_value={
"server-1": ServerHealth(
name="server-1",
healthy=True,
state="connected",
url="http://server1:8000",
),
}
)
response = client.get("/api/v1/mcp/health")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["healthy_count"] == 1
assert data["unhealthy_count"] == 0
class TestCallTool:
"""Tests for POST /mcp/call endpoint."""
def test_call_tool_success(self, client, mock_mcp_client):
"""Test successful tool execution."""
mock_mcp_client.call_tool = AsyncMock(
return_value=ToolResult(
success=True,
data={"result": "ok"},
tool_name="test-tool",
server_name="server-1",
execution_time_ms=123.45,
request_id="test-request-id",
)
)
response = client.post(
"/api/v1/mcp/call",
json={
"server": "server-1",
"tool": "test-tool",
"arguments": {"key": "value"},
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["success"] is True
assert data["data"] == {"result": "ok"}
assert data["tool_name"] == "test-tool"
assert data["server_name"] == "server-1"
def test_call_tool_with_timeout(self, client, mock_mcp_client):
"""Test tool execution with custom timeout."""
mock_mcp_client.call_tool = AsyncMock(
return_value=ToolResult(success=True, data={})
)
response = client.post(
"/api/v1/mcp/call",
json={
"server": "server-1",
"tool": "test-tool",
"timeout": 60.0,
},
)
assert response.status_code == status.HTTP_200_OK
mock_mcp_client.call_tool.assert_called_once()
call_args = mock_mcp_client.call_tool.call_args
assert call_args.kwargs["timeout"] == 60.0
def test_call_tool_server_not_found(self, client, mock_mcp_client):
"""Test tool execution with non-existent server."""
mock_mcp_client.call_tool = AsyncMock(
side_effect=MCPServerNotFoundError(server_name="unknown")
)
response = client.post(
"/api/v1/mcp/call",
json={"server": "unknown", "tool": "test-tool"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_call_tool_not_found(self, client, mock_mcp_client):
"""Test tool execution with non-existent tool."""
mock_mcp_client.call_tool = AsyncMock(
side_effect=MCPToolNotFoundError(tool_name="unknown")
)
response = client.post(
"/api/v1/mcp/call",
json={"server": "server-1", "tool": "unknown"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_call_tool_timeout(self, client, mock_mcp_client):
"""Test tool execution timeout."""
mock_mcp_client.call_tool = AsyncMock(
side_effect=MCPTimeoutError(
"Request timed out",
server_name="server-1",
timeout_seconds=30.0,
)
)
response = client.post(
"/api/v1/mcp/call",
json={"server": "server-1", "tool": "slow-tool"},
)
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
def test_call_tool_connection_error(self, client, mock_mcp_client):
"""Test tool execution with connection failure."""
mock_mcp_client.call_tool = AsyncMock(
side_effect=MCPConnectionError(
"Connection refused",
server_name="server-1",
)
)
response = client.post(
"/api/v1/mcp/call",
json={"server": "server-1", "tool": "test-tool"},
)
assert response.status_code == status.HTTP_502_BAD_GATEWAY
def test_call_tool_circuit_open(self, client, mock_mcp_client):
"""Test tool execution with open circuit breaker."""
mock_mcp_client.call_tool = AsyncMock(
side_effect=MCPCircuitOpenError(
server_name="server-1",
failure_count=5,
)
)
response = client.post(
"/api/v1/mcp/call",
json={"server": "server-1", "tool": "test-tool"},
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
class TestCircuitBreakers:
"""Tests for circuit breaker endpoints."""
def test_list_circuit_breakers(self, client, mock_mcp_client):
"""Test listing circuit breaker statuses."""
mock_mcp_client.get_circuit_breaker_status.return_value = {
"server-1": {"state": "closed", "failure_count": 0},
"server-2": {"state": "open", "failure_count": 5},
}
response = client.get("/api/v1/mcp/circuit-breakers")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data["circuit_breakers"]) == 2
def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
"""Test listing when no circuit breakers exist."""
mock_mcp_client.get_circuit_breaker_status.return_value = {}
response = client.get("/api/v1/mcp/circuit-breakers")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["circuit_breakers"] == []
def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
"""Test successfully resetting a circuit breaker."""
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)
response = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
"""Test resetting non-existent circuit breaker."""
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)
response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")
assert response.status_code == status.HTTP_404_NOT_FOUND
class TestReconnectServer:
"""Tests for POST /mcp/servers/{server_name}/reconnect endpoint."""
def test_reconnect_success(self, client, mock_mcp_client):
"""Test successful server reconnection."""
mock_mcp_client.disconnect = AsyncMock()
mock_mcp_client.connect = AsyncMock()
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
assert response.status_code == status.HTTP_204_NO_CONTENT
mock_mcp_client.disconnect.assert_called_once_with("server-1")
mock_mcp_client.connect.assert_called_once_with("server-1")
def test_reconnect_server_not_found(self, client, mock_mcp_client):
"""Test reconnecting to non-existent server."""
mock_mcp_client.disconnect = AsyncMock(
side_effect=MCPServerNotFoundError(server_name="unknown")
)
response = client.post("/api/v1/mcp/servers/unknown/reconnect")
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_reconnect_connection_failure(self, client, mock_mcp_client):
"""Test reconnection failure."""
mock_mcp_client.disconnect = AsyncMock()
mock_mcp_client.connect = AsyncMock(
side_effect=MCPConnectionError(
"Connection refused",
server_name="server-1",
)
)
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
assert response.status_code == status.HTTP_502_BAD_GATEWAY
class TestMCPEndpointsEdgeCases:
"""Edge case tests for MCP endpoints."""
def test_servers_content_type(self, client, mock_mcp_client):
"""Test that endpoints return JSON content type."""
mock_mcp_client.list_servers.return_value = []
response = client.get("/api/v1/mcp/servers")
assert "application/json" in response.headers["content-type"]
def test_call_tool_validation_error(self, client):
"""Test that invalid request body returns validation error."""
response = client.post(
"/api/v1/mcp/call",
json={}, # Missing required fields
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_call_tool_missing_server(self, client):
"""Test that missing server field returns validation error."""
response = client.post(
"/api/v1/mcp/call",
json={"tool": "test-tool"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_call_tool_missing_tool(self, client):
"""Test that missing tool field returns validation error."""
response = client.post(
"/api/v1/mcp/call",
json={"server": "server-1"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

View File

@@ -0,0 +1 @@
"""MCP Service Tests Package."""

View File

@@ -0,0 +1,395 @@
"""
Tests for MCP Client Manager
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.mcp.client_manager import (
MCPClientManager,
ServerHealth,
get_mcp_client,
reset_mcp_client,
shutdown_mcp_client,
)
from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.connection import ConnectionState
from app.services.mcp.exceptions import MCPServerNotFoundError
from app.services.mcp.registry import MCPServerRegistry
from app.services.mcp.routing import ToolInfo, ToolResult
@pytest.fixture
def reset_registry():
"""Reset the singleton registry before and after each test."""
MCPServerRegistry.reset_instance()
reset_mcp_client()
yield
MCPServerRegistry.reset_instance()
reset_mcp_client()
@pytest.fixture
def sample_config():
"""Create a sample MCP configuration."""
return MCPConfig(
mcp_servers={
"server-1": MCPServerConfig(
url="http://server1:8000",
timeout=30,
enabled=True,
),
"server-2": MCPServerConfig(
url="http://server2:8000",
timeout=60,
enabled=True,
),
}
)
class TestServerHealth:
"""Tests for ServerHealth dataclass."""
def test_healthy_server(self):
"""Test healthy server status."""
health = ServerHealth(
name="test-server",
healthy=True,
state="connected",
url="http://test:8000",
tools_count=5,
)
assert health.healthy is True
assert health.error is None
assert health.tools_count == 5
def test_unhealthy_server(self):
"""Test unhealthy server status."""
health = ServerHealth(
name="test-server",
healthy=False,
state="error",
url="http://test:8000",
error="Connection refused",
)
assert health.healthy is False
assert health.error == "Connection refused"
def test_to_dict(self):
"""Test converting to dictionary."""
health = ServerHealth(
name="test-server",
healthy=True,
state="connected",
url="http://test:8000",
tools_count=3,
)
d = health.to_dict()
assert d["name"] == "test-server"
assert d["healthy"] is True
assert d["state"] == "connected"
assert d["url"] == "http://test:8000"
assert d["tools_count"] == 3
class TestMCPClientManager:
"""Tests for MCPClientManager class."""
def test_initial_state(self, reset_registry):
"""Test initial manager state."""
manager = MCPClientManager()
assert manager.is_initialized is False
@pytest.mark.asyncio
async def test_initialize(self, reset_registry, sample_config):
"""Test manager initialization."""
manager = MCPClientManager(config=sample_config)
with patch.object(manager._pool, "get_connection") as mock_get_conn:
mock_conn = AsyncMock()
mock_conn.is_connected = True
mock_get_conn.return_value = mock_conn
with patch.object(manager, "_router") as mock_router:
mock_router.discover_tools = AsyncMock()
await manager.initialize()
assert manager.is_initialized is True
@pytest.mark.asyncio
async def test_shutdown(self, reset_registry, sample_config):
"""Test manager shutdown."""
manager = MCPClientManager(config=sample_config)
manager._initialized = True
with patch.object(manager._pool, "close_all") as mock_close:
mock_close.return_value = None
await manager.shutdown()
assert manager.is_initialized is False
mock_close.assert_called_once()
@pytest.mark.asyncio
async def test_connect(self, reset_registry, sample_config):
"""Test connecting to specific server."""
manager = MCPClientManager(config=sample_config)
with patch.object(manager._pool, "get_connection") as mock_get_conn:
mock_conn = AsyncMock()
mock_conn.is_connected = True
mock_get_conn.return_value = mock_conn
await manager.connect("server-1")
mock_get_conn.assert_called_once()
@pytest.mark.asyncio
async def test_connect_unknown_server(self, reset_registry, sample_config):
"""Test connecting to unknown server raises error."""
manager = MCPClientManager(config=sample_config)
with pytest.raises(MCPServerNotFoundError):
await manager.connect("unknown-server")
@pytest.mark.asyncio
async def test_disconnect(self, reset_registry, sample_config):
"""Test disconnecting from server."""
manager = MCPClientManager(config=sample_config)
with patch.object(manager._pool, "close_connection") as mock_close:
await manager.disconnect("server-1")
mock_close.assert_called_once_with("server-1")
@pytest.mark.asyncio
async def test_disconnect_all(self, reset_registry, sample_config):
"""Test disconnecting from all servers."""
manager = MCPClientManager(config=sample_config)
with patch.object(manager._pool, "close_all") as mock_close:
await manager.disconnect_all()
mock_close.assert_called_once()
@pytest.mark.asyncio
async def test_call_tool(self, reset_registry, sample_config):
"""Test calling a tool."""
manager = MCPClientManager(config=sample_config)
manager._initialized = True
expected_result = ToolResult(
success=True,
data={"id": "123"},
tool_name="create_issue",
server_name="server-1",
)
mock_router = MagicMock()
mock_router.call_tool = AsyncMock(return_value=expected_result)
manager._router = mock_router
result = await manager.call_tool(
server="server-1",
tool="create_issue",
args={"title": "Test"},
)
assert result.success is True
assert result.data == {"id": "123"}
mock_router.call_tool.assert_called_once()
@pytest.mark.asyncio
async def test_route_tool(self, reset_registry, sample_config):
"""Test routing a tool call."""
manager = MCPClientManager(config=sample_config)
manager._initialized = True
expected_result = ToolResult(
success=True,
data={"result": "ok"},
tool_name="auto_tool",
server_name="server-2",
)
mock_router = MagicMock()
mock_router.route_tool = AsyncMock(return_value=expected_result)
manager._router = mock_router
result = await manager.route_tool(
tool="auto_tool",
args={"key": "value"},
)
assert result.success is True
assert result.server_name == "server-2"
mock_router.route_tool.assert_called_once()
@pytest.mark.asyncio
async def test_list_tools(self, reset_registry, sample_config):
"""Test listing tools for a server."""
manager = MCPClientManager(config=sample_config)
# Set up capabilities in registry
manager._registry.set_capabilities(
"server-1",
tools=[
{"name": "tool1", "description": "Tool 1"},
{"name": "tool2", "description": "Tool 2"},
],
)
tools = await manager.list_tools("server-1")
assert len(tools) == 2
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"
@pytest.mark.asyncio
async def test_list_all_tools(self, reset_registry, sample_config):
"""Test listing all tools from all servers."""
manager = MCPClientManager(config=sample_config)
manager._initialized = True
expected_tools = [
ToolInfo(name="tool1", server_name="server-1"),
ToolInfo(name="tool2", server_name="server-2"),
]
mock_router = MagicMock()
mock_router.list_all_tools = AsyncMock(return_value=expected_tools)
manager._router = mock_router
tools = await manager.list_all_tools()
assert len(tools) == 2
@pytest.mark.asyncio
async def test_health_check(self, reset_registry, sample_config):
"""Test health check on all servers."""
manager = MCPClientManager(config=sample_config)
with patch.object(manager._pool, "get_status") as mock_status:
mock_status.return_value = {
"server-1": {"state": "connected"},
"server-2": {"state": "disconnected"},
}
with patch.object(manager._pool, "health_check_all") as mock_health:
mock_health.return_value = {
"server-1": True,
"server-2": False,
}
health = await manager.health_check()
assert "server-1" in health
assert "server-2" in health
assert health["server-1"].healthy is True
assert health["server-2"].healthy is False
def test_list_servers(self, reset_registry, sample_config):
"""Test listing registered servers."""
manager = MCPClientManager(config=sample_config)
servers = manager.list_servers()
assert "server-1" in servers
assert "server-2" in servers
def test_list_enabled_servers(self, reset_registry, sample_config):
"""Test listing enabled servers."""
manager = MCPClientManager(config=sample_config)
servers = manager.list_enabled_servers()
assert "server-1" in servers
assert "server-2" in servers
def test_get_server_config(self, reset_registry, sample_config):
"""Test getting server configuration."""
manager = MCPClientManager(config=sample_config)
config = manager.get_server_config("server-1")
assert config.url == "http://server1:8000"
assert config.timeout == 30
def test_get_server_config_not_found(self, reset_registry, sample_config):
"""Test getting unknown server config raises error."""
manager = MCPClientManager(config=sample_config)
with pytest.raises(MCPServerNotFoundError):
manager.get_server_config("unknown")
def test_register_server(self, reset_registry, sample_config):
"""Test registering new server at runtime."""
manager = MCPClientManager(config=sample_config)
new_config = MCPServerConfig(url="http://new:8000")
manager.register_server("new-server", new_config)
assert "new-server" in manager.list_servers()
def test_unregister_server(self, reset_registry, sample_config):
"""Test unregistering a server."""
manager = MCPClientManager(config=sample_config)
result = manager.unregister_server("server-1")
assert result is True
assert "server-1" not in manager.list_servers()
# Unregistering non-existent returns False
result = manager.unregister_server("nonexistent")
assert result is False
def test_circuit_breaker_status(self, reset_registry, sample_config):
"""Test getting circuit breaker status."""
manager = MCPClientManager(config=sample_config)
# No router yet
status = manager.get_circuit_breaker_status()
assert status == {}
@pytest.mark.asyncio
async def test_reset_circuit_breaker(self, reset_registry, sample_config):
"""Test resetting circuit breaker."""
manager = MCPClientManager(config=sample_config)
# No router yet
result = await manager.reset_circuit_breaker("server-1")
assert result is False
class TestModuleLevelFunctions:
"""Tests for module-level convenience functions."""
@pytest.mark.asyncio
async def test_get_mcp_client_creates_singleton(self, reset_registry):
"""Test get_mcp_client creates and returns singleton."""
with patch(
"app.services.mcp.client_manager.MCPClientManager.initialize"
) as mock_init:
mock_init.return_value = None
client1 = await get_mcp_client()
client2 = await get_mcp_client()
assert client1 is client2
@pytest.mark.asyncio
async def test_shutdown_mcp_client(self, reset_registry):
"""Test shutting down the global client."""
with patch(
"app.services.mcp.client_manager.MCPClientManager.initialize"
) as mock_init:
mock_init.return_value = None
client = await get_mcp_client()
with patch.object(client, "shutdown") as mock_shutdown:
mock_shutdown.return_value = None
await shutdown_mcp_client()
def test_reset_mcp_client(self, reset_registry):
"""Test resetting the global client."""
reset_mcp_client()
# Should not raise

View File

@@ -0,0 +1,319 @@
"""
Tests for MCP Configuration System
"""
import os
import tempfile
from pathlib import Path
import pytest
import yaml
from app.services.mcp.config import (
MCPConfig,
MCPServerConfig,
TransportType,
create_default_config,
load_mcp_config,
)
class TestTransportType:
"""Tests for TransportType enum."""
def test_transport_types(self):
"""Test that all transport types are defined."""
assert TransportType.HTTP == "http"
assert TransportType.STDIO == "stdio"
assert TransportType.SSE == "sse"
def test_transport_type_from_string(self):
"""Test creating transport type from string."""
assert TransportType("http") == TransportType.HTTP
assert TransportType("stdio") == TransportType.STDIO
assert TransportType("sse") == TransportType.SSE
class TestMCPServerConfig:
"""Tests for MCPServerConfig model."""
def test_minimal_config(self):
"""Test creating config with only required fields."""
config = MCPServerConfig(url="http://localhost:8000")
assert config.url == "http://localhost:8000"
assert config.transport == TransportType.HTTP
assert config.timeout == 30
assert config.retry_attempts == 3
assert config.enabled is True
def test_full_config(self):
"""Test creating config with all fields."""
config = MCPServerConfig(
url="http://localhost:8000",
transport=TransportType.SSE,
timeout=60,
retry_attempts=5,
retry_delay=2.0,
retry_max_delay=60.0,
circuit_breaker_threshold=10,
circuit_breaker_timeout=60.0,
enabled=False,
description="Test server",
)
assert config.timeout == 60
assert config.transport == TransportType.SSE
assert config.retry_attempts == 5
assert config.retry_delay == 2.0
assert config.retry_max_delay == 60.0
assert config.circuit_breaker_threshold == 10
assert config.circuit_breaker_timeout == 60.0
assert config.enabled is False
assert config.description == "Test server"
def test_env_var_expansion_simple(self):
"""Test simple environment variable expansion."""
os.environ["TEST_SERVER_URL"] = "http://test-server:9000"
try:
config = MCPServerConfig(url="${TEST_SERVER_URL}")
assert config.url == "http://test-server:9000"
finally:
del os.environ["TEST_SERVER_URL"]
def test_env_var_expansion_with_default(self):
"""Test environment variable expansion with default."""
# Ensure env var is not set
os.environ.pop("NONEXISTENT_URL", None)
config = MCPServerConfig(url="${NONEXISTENT_URL:-http://default:8000}")
assert config.url == "http://default:8000"
def test_env_var_expansion_override_default(self):
"""Test environment variable override of default."""
os.environ["TEST_OVERRIDE_URL"] = "http://override:9000"
try:
config = MCPServerConfig(url="${TEST_OVERRIDE_URL:-http://default:8000}")
assert config.url == "http://override:9000"
finally:
del os.environ["TEST_OVERRIDE_URL"]
def test_timeout_validation(self):
"""Test timeout validation bounds."""
# Valid bounds
config = MCPServerConfig(url="http://localhost", timeout=1)
assert config.timeout == 1
config = MCPServerConfig(url="http://localhost", timeout=600)
assert config.timeout == 600
# Invalid bounds
with pytest.raises(ValueError):
MCPServerConfig(url="http://localhost", timeout=0)
with pytest.raises(ValueError):
MCPServerConfig(url="http://localhost", timeout=601)
def test_retry_attempts_validation(self):
"""Test retry attempts validation bounds."""
config = MCPServerConfig(url="http://localhost", retry_attempts=0)
assert config.retry_attempts == 0
config = MCPServerConfig(url="http://localhost", retry_attempts=10)
assert config.retry_attempts == 10
with pytest.raises(ValueError):
MCPServerConfig(url="http://localhost", retry_attempts=-1)
with pytest.raises(ValueError):
MCPServerConfig(url="http://localhost", retry_attempts=11)
class TestMCPConfig:
"""Tests for MCPConfig model."""
def test_empty_config(self):
"""Test creating empty config."""
config = MCPConfig()
assert config.mcp_servers == {}
assert config.default_timeout == 30
assert config.default_retry_attempts == 3
assert config.connection_pool_size == 10
assert config.health_check_interval == 30
def test_config_with_servers(self):
"""Test creating config with servers."""
config = MCPConfig(
mcp_servers={
"server-1": MCPServerConfig(url="http://server1:8000"),
"server-2": MCPServerConfig(url="http://server2:8000"),
}
)
assert len(config.mcp_servers) == 2
assert "server-1" in config.mcp_servers
assert "server-2" in config.mcp_servers
def test_get_server(self):
"""Test getting server by name."""
config = MCPConfig(
mcp_servers={
"server-1": MCPServerConfig(url="http://server1:8000"),
}
)
server = config.get_server("server-1")
assert server is not None
assert server.url == "http://server1:8000"
missing = config.get_server("nonexistent")
assert missing is None
def test_get_enabled_servers(self):
"""Test getting only enabled servers."""
config = MCPConfig(
mcp_servers={
"enabled-1": MCPServerConfig(url="http://e1:8000", enabled=True),
"disabled-1": MCPServerConfig(url="http://d1:8000", enabled=False),
"enabled-2": MCPServerConfig(url="http://e2:8000", enabled=True),
}
)
enabled = config.get_enabled_servers()
assert len(enabled) == 2
assert "enabled-1" in enabled
assert "enabled-2" in enabled
assert "disabled-1" not in enabled
def test_list_server_names(self):
"""Test listing server names."""
config = MCPConfig(
mcp_servers={
"server-a": MCPServerConfig(url="http://a:8000"),
"server-b": MCPServerConfig(url="http://b:8000"),
}
)
names = config.list_server_names()
assert sorted(names) == ["server-a", "server-b"]
def test_from_dict(self):
"""Test creating config from dictionary."""
data = {
"mcp_servers": {
"test-server": {
"url": "http://test:8000",
"timeout": 45,
}
},
"default_timeout": 60,
}
config = MCPConfig.from_dict(data)
assert config.default_timeout == 60
assert config.mcp_servers["test-server"].timeout == 45
def test_from_yaml(self):
"""Test loading config from YAML file."""
yaml_content = """
mcp_servers:
test-server:
url: http://test:8000
timeout: 45
transport: http
enabled: true
default_timeout: 60
connection_pool_size: 20
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(yaml_content)
f.flush()
try:
config = MCPConfig.from_yaml(f.name)
assert config.default_timeout == 60
assert config.connection_pool_size == 20
assert "test-server" in config.mcp_servers
assert config.mcp_servers["test-server"].timeout == 45
finally:
os.unlink(f.name)
def test_from_yaml_file_not_found(self):
"""Test error when YAML file not found."""
with pytest.raises(FileNotFoundError):
MCPConfig.from_yaml("/nonexistent/path/config.yaml")
class TestLoadMCPConfig:
"""Tests for load_mcp_config function."""
def test_load_with_explicit_path(self):
"""Test loading config with explicit path."""
yaml_content = """
mcp_servers:
explicit-server:
url: http://explicit:8000
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(yaml_content)
f.flush()
try:
config = load_mcp_config(f.name)
assert "explicit-server" in config.mcp_servers
finally:
os.unlink(f.name)
def test_load_with_env_var(self):
"""Test loading config from environment variable path."""
yaml_content = """
mcp_servers:
env-server:
url: http://env:8000
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(yaml_content)
f.flush()
os.environ["MCP_CONFIG_PATH"] = f.name
try:
config = load_mcp_config()
assert "env-server" in config.mcp_servers
finally:
del os.environ["MCP_CONFIG_PATH"]
os.unlink(f.name)
def test_load_returns_empty_config_if_missing(self):
"""Test that missing file returns empty config."""
os.environ.pop("MCP_CONFIG_PATH", None)
config = load_mcp_config("/nonexistent/path/config.yaml")
assert config.mcp_servers == {}
class TestCreateDefaultConfig:
"""Tests for create_default_config function."""
def test_creates_standard_servers(self):
"""Test that default config has standard servers."""
config = create_default_config()
assert "llm-gateway" in config.mcp_servers
assert "knowledge-base" in config.mcp_servers
assert "git-ops" in config.mcp_servers
assert "issues" in config.mcp_servers
def test_servers_have_correct_defaults(self):
"""Test that servers have correct default values."""
config = create_default_config()
llm = config.mcp_servers["llm-gateway"]
assert llm.timeout == 60 # LLM has longer timeout
assert llm.transport == TransportType.HTTP
git = config.mcp_servers["git-ops"]
assert git.timeout == 120 # Git ops has longest timeout
def test_servers_are_enabled(self):
"""Test that all default servers are enabled."""
config = create_default_config()
for name, server in config.mcp_servers.items():
assert server.enabled is True, f"Server {name} should be enabled"

View File

@@ -0,0 +1,405 @@
"""
Tests for MCP Connection Management
"""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.services.mcp.config import MCPServerConfig, TransportType
from app.services.mcp.connection import (
ConnectionPool,
ConnectionState,
MCPConnection,
)
from app.services.mcp.exceptions import MCPConnectionError, MCPTimeoutError
@pytest.fixture
def server_config():
"""Create a sample server configuration."""
return MCPServerConfig(
url="http://localhost:8000",
transport=TransportType.HTTP,
timeout=30,
retry_attempts=3,
retry_delay=0.1, # Short delay for tests
retry_max_delay=1.0,
)
class TestConnectionState:
"""Tests for ConnectionState enum."""
def test_connection_states(self):
"""Test all connection states are defined."""
assert ConnectionState.DISCONNECTED == "disconnected"
assert ConnectionState.CONNECTING == "connecting"
assert ConnectionState.CONNECTED == "connected"
assert ConnectionState.RECONNECTING == "reconnecting"
assert ConnectionState.ERROR == "error"
class TestMCPConnection:
"""Tests for MCPConnection class."""
def test_initial_state(self, server_config):
"""Test initial connection state."""
conn = MCPConnection("test-server", server_config)
assert conn.server_name == "test-server"
assert conn.config == server_config
assert conn.state == ConnectionState.DISCONNECTED
assert conn.is_connected is False
assert conn.last_error is None
@pytest.mark.asyncio
async def test_connect_success(self, server_config):
"""Test successful connection."""
conn = MCPConnection("test-server", server_config)
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
await conn.connect()
assert conn.state == ConnectionState.CONNECTED
assert conn.is_connected is True
@pytest.mark.asyncio
async def test_connect_404_capabilities_ok(self, server_config):
"""Test connection succeeds even if /mcp/capabilities returns 404."""
conn = MCPConnection("test-server", server_config)
mock_response = MagicMock()
mock_response.status_code = 404
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
await conn.connect()
assert conn.state == ConnectionState.CONNECTED
@pytest.mark.asyncio
async def test_connect_failure_with_retry(self, server_config):
"""Test connection failure with retries."""
# Reduce retry attempts for faster test
server_config.retry_attempts = 2
server_config.retry_delay = 0.01
conn = MCPConnection("test-server", server_config)
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=httpx.ConnectError("Connection refused")
)
MockClient.return_value = mock_client
with pytest.raises(MCPConnectionError) as exc_info:
await conn.connect()
assert conn.state == ConnectionState.ERROR
assert "Failed to connect after" in str(exc_info.value)
@pytest.mark.asyncio
async def test_disconnect(self, server_config):
"""Test disconnection."""
conn = MCPConnection("test-server", server_config)
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.aclose = AsyncMock()
MockClient.return_value = mock_client
await conn.connect()
assert conn.is_connected is True
await conn.disconnect()
assert conn.state == ConnectionState.DISCONNECTED
assert conn.is_connected is False
@pytest.mark.asyncio
async def test_reconnect(self, server_config):
"""Test reconnection."""
conn = MCPConnection("test-server", server_config)
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.aclose = AsyncMock()
MockClient.return_value = mock_client
await conn.connect()
assert conn.is_connected is True
await conn.reconnect()
assert conn.is_connected is True
@pytest.mark.asyncio
async def test_health_check_success(self, server_config):
"""Test successful health check."""
conn = MCPConnection("test-server", server_config)
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
await conn.connect()
healthy = await conn.health_check()
assert healthy is True
@pytest.mark.asyncio
async def test_health_check_disconnected(self, server_config):
"""Test health check when disconnected."""
conn = MCPConnection("test-server", server_config)
healthy = await conn.health_check()
assert healthy is False
@pytest.mark.asyncio
async def test_execute_request_get(self, server_config):
"""Test executing GET request."""
conn = MCPConnection("test-server", server_config)
mock_connect_response = MagicMock()
mock_connect_response.status_code = 200
mock_request_response = MagicMock()
mock_request_response.status_code = 200
mock_request_response.json.return_value = {"tools": []}
mock_request_response.raise_for_status = MagicMock()
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[mock_connect_response, mock_request_response]
)
MockClient.return_value = mock_client
await conn.connect()
result = await conn.execute_request("GET", "/mcp/tools")
assert result == {"tools": []}
@pytest.mark.asyncio
async def test_execute_request_post(self, server_config):
"""Test executing POST request."""
conn = MCPConnection("test-server", server_config)
mock_connect_response = MagicMock()
mock_connect_response.status_code = 200
mock_request_response = MagicMock()
mock_request_response.status_code = 200
mock_request_response.json.return_value = {"result": "success"}
mock_request_response.raise_for_status = MagicMock()
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_connect_response)
mock_client.post = AsyncMock(return_value=mock_request_response)
MockClient.return_value = mock_client
await conn.connect()
result = await conn.execute_request(
"POST", "/mcp", data={"method": "test"}
)
assert result == {"result": "success"}
@pytest.mark.asyncio
async def test_execute_request_timeout(self, server_config):
"""Test request timeout."""
conn = MCPConnection("test-server", server_config)
mock_connect_response = MagicMock()
mock_connect_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[
mock_connect_response,
httpx.TimeoutException("Request timeout"),
]
)
MockClient.return_value = mock_client
await conn.connect()
with pytest.raises(MCPTimeoutError) as exc_info:
await conn.execute_request("GET", "/slow-endpoint")
assert "timeout" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_execute_request_not_connected(self, server_config):
"""Test request when not connected."""
conn = MCPConnection("test-server", server_config)
with pytest.raises(MCPConnectionError) as exc_info:
await conn.execute_request("GET", "/test")
assert "Not connected" in str(exc_info.value)
def test_backoff_delay_calculation(self, server_config):
"""Test exponential backoff delay calculation."""
conn = MCPConnection("test-server", server_config)
# First attempt
conn._connection_attempts = 1
delay1 = conn._calculate_backoff_delay()
# Second attempt
conn._connection_attempts = 2
delay2 = conn._calculate_backoff_delay()
# Third attempt
conn._connection_attempts = 3
delay3 = conn._calculate_backoff_delay()
# Delays should generally increase (with some jitter)
# Base is 0.1, so rough expectations:
# Attempt 1: ~0.1s
# Attempt 2: ~0.2s
# Attempt 3: ~0.4s
assert delay1 > 0
assert delay2 > delay1 * 0.5 # Allow for jitter
assert delay3 <= server_config.retry_max_delay * 1.25 # Within max + jitter
class TestConnectionPool:
"""Tests for ConnectionPool class."""
@pytest.fixture
def pool(self):
"""Create a connection pool."""
return ConnectionPool(max_connections_per_server=5)
@pytest.mark.asyncio
async def test_get_connection_creates_new(self, pool, server_config):
"""Test getting connection creates new one if not exists."""
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
conn = await pool.get_connection("test-server", server_config)
assert conn is not None
assert conn.is_connected is True
assert conn.server_name == "test-server"
@pytest.mark.asyncio
async def test_get_connection_reuses_existing(self, pool, server_config):
"""Test getting connection reuses existing one."""
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
conn1 = await pool.get_connection("test-server", server_config)
conn2 = await pool.get_connection("test-server", server_config)
assert conn1 is conn2
@pytest.mark.asyncio
async def test_close_connection(self, pool, server_config):
"""Test closing a connection."""
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.aclose = AsyncMock()
MockClient.return_value = mock_client
await pool.get_connection("test-server", server_config)
await pool.close_connection("test-server")
assert "test-server" not in pool._connections
@pytest.mark.asyncio
async def test_close_all(self, pool, server_config):
"""Test closing all connections."""
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.aclose = AsyncMock()
MockClient.return_value = mock_client
config2 = MCPServerConfig(url="http://server2:8000")
await pool.get_connection("server-1", server_config)
await pool.get_connection("server-2", config2)
assert len(pool._connections) == 2
await pool.close_all()
assert len(pool._connections) == 0
@pytest.mark.asyncio
async def test_health_check_all(self, pool, server_config):
"""Test health check on all connections."""
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
await pool.get_connection("test-server", server_config)
results = await pool.health_check_all()
assert "test-server" in results
assert results["test-server"] is True
def test_get_status(self, pool, server_config):
"""Test getting pool status."""
status = pool.get_status()
assert status == {}
@pytest.mark.asyncio
async def test_connection_context_manager(self, pool, server_config):
"""Test connection context manager."""
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.AsyncClient") as MockClient:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
MockClient.return_value = mock_client
async with pool.connection("test-server", server_config) as conn:
assert conn.is_connected is True
assert conn.server_name == "test-server"

View File

@@ -0,0 +1,259 @@
"""
Tests for MCP Exception Classes
"""
import pytest
from app.services.mcp.exceptions import (
MCPCircuitOpenError,
MCPConnectionError,
MCPError,
MCPServerNotFoundError,
MCPTimeoutError,
MCPToolError,
MCPToolNotFoundError,
MCPValidationError,
)
class TestMCPError:
"""Tests for base MCPError class."""
def test_basic_error(self):
"""Test basic error creation."""
error = MCPError("Test error")
assert error.message == "Test error"
assert error.server_name is None
assert error.details == {}
assert str(error) == "Test error"
def test_error_with_server_name(self):
"""Test error with server name."""
error = MCPError("Test error", server_name="test-server")
assert error.server_name == "test-server"
assert "server=test-server" in str(error)
def test_error_with_details(self):
"""Test error with additional details."""
error = MCPError(
"Test error",
server_name="test-server",
details={"key": "value"},
)
assert error.details == {"key": "value"}
assert "details={'key': 'value'}" in str(error)
class TestMCPConnectionError:
"""Tests for MCPConnectionError class."""
def test_basic_connection_error(self):
"""Test basic connection error."""
error = MCPConnectionError("Connection failed")
assert error.message == "Connection failed"
assert error.url is None
assert error.cause is None
def test_connection_error_with_url(self):
"""Test connection error with URL."""
error = MCPConnectionError(
"Connection failed",
server_name="test-server",
url="http://localhost:8000",
)
assert error.url == "http://localhost:8000"
assert "url=http://localhost:8000" in str(error)
def test_connection_error_with_cause(self):
"""Test connection error with cause."""
cause = ConnectionError("Network error")
error = MCPConnectionError(
"Connection failed",
cause=cause,
)
assert error.cause is cause
assert "ConnectionError" in str(error)
class TestMCPTimeoutError:
"""Tests for MCPTimeoutError class."""
def test_basic_timeout_error(self):
"""Test basic timeout error."""
error = MCPTimeoutError("Request timed out")
assert error.message == "Request timed out"
assert error.timeout_seconds is None
assert error.operation is None
def test_timeout_error_with_details(self):
"""Test timeout error with details."""
error = MCPTimeoutError(
"Request timed out",
server_name="test-server",
timeout_seconds=30.0,
operation="POST /mcp",
)
assert error.timeout_seconds == 30.0
assert error.operation == "POST /mcp"
assert "timeout=30.0s" in str(error)
assert "operation=POST /mcp" in str(error)
class TestMCPToolError:
"""Tests for MCPToolError class."""
def test_basic_tool_error(self):
"""Test basic tool error."""
error = MCPToolError("Tool execution failed")
assert error.message == "Tool execution failed"
assert error.tool_name is None
assert error.tool_args is None
assert error.error_code is None
def test_tool_error_with_details(self):
"""Test tool error with all details."""
error = MCPToolError(
"Tool execution failed",
server_name="llm-gateway",
tool_name="chat",
tool_args={"prompt": "Hello"},
error_code="INVALID_ARGS",
)
assert error.tool_name == "chat"
assert error.tool_args == {"prompt": "Hello"}
assert error.error_code == "INVALID_ARGS"
assert "tool=chat" in str(error)
assert "error_code=INVALID_ARGS" in str(error)
class TestMCPServerNotFoundError:
"""Tests for MCPServerNotFoundError class."""
def test_server_not_found(self):
"""Test server not found error."""
error = MCPServerNotFoundError("unknown-server")
assert error.server_name == "unknown-server"
assert "MCP server not found: unknown-server" in error.message
assert error.available_servers == []
def test_server_not_found_with_available(self):
"""Test server not found with available servers listed."""
error = MCPServerNotFoundError(
"unknown-server",
available_servers=["server-1", "server-2"],
)
assert error.available_servers == ["server-1", "server-2"]
assert "available=['server-1', 'server-2']" in str(error)
class TestMCPToolNotFoundError:
"""Tests for MCPToolNotFoundError class."""
def test_tool_not_found(self):
"""Test tool not found error."""
error = MCPToolNotFoundError("unknown-tool")
assert error.tool_name == "unknown-tool"
assert "Tool not found: unknown-tool" in error.message
assert error.available_tools == []
def test_tool_not_found_with_available(self):
"""Test tool not found with available tools listed."""
error = MCPToolNotFoundError(
"unknown-tool",
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
)
assert len(error.available_tools) == 6
# Should show first 5 tools with ellipsis
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
class TestMCPCircuitOpenError:
"""Tests for MCPCircuitOpenError class."""
def test_circuit_open_error(self):
"""Test circuit open error."""
error = MCPCircuitOpenError("test-server")
assert error.server_name == "test-server"
assert "Circuit breaker open for server: test-server" in error.message
assert error.failure_count is None
assert error.reset_timeout is None
def test_circuit_open_error_with_details(self):
"""Test circuit open error with details."""
error = MCPCircuitOpenError(
"test-server",
failure_count=5,
reset_timeout=30.0,
)
assert error.failure_count == 5
assert error.reset_timeout == 30.0
assert "failures=5" in str(error)
assert "reset_in=30.0s" in str(error)
class TestMCPValidationError:
"""Tests for MCPValidationError class."""
def test_validation_error(self):
"""Test validation error."""
error = MCPValidationError("Validation failed")
assert error.message == "Validation failed"
assert error.tool_name is None
assert error.field_errors == {}
def test_validation_error_with_details(self):
"""Test validation error with field errors."""
error = MCPValidationError(
"Validation failed",
tool_name="create_issue",
field_errors={
"title": "Title is required",
"priority": "Invalid priority value",
},
)
assert error.tool_name == "create_issue"
assert error.field_errors == {
"title": "Title is required",
"priority": "Invalid priority value",
}
assert "tool=create_issue" in str(error)
assert "fields=['title', 'priority']" in str(error)
class TestExceptionInheritance:
"""Tests for exception inheritance chain."""
def test_all_errors_inherit_from_mcp_error(self):
"""Test that all custom exceptions inherit from MCPError."""
assert issubclass(MCPConnectionError, MCPError)
assert issubclass(MCPTimeoutError, MCPError)
assert issubclass(MCPToolError, MCPError)
assert issubclass(MCPServerNotFoundError, MCPError)
assert issubclass(MCPToolNotFoundError, MCPError)
assert issubclass(MCPCircuitOpenError, MCPError)
assert issubclass(MCPValidationError, MCPError)
def test_all_errors_inherit_from_exception(self):
"""Test that base MCPError inherits from Exception."""
assert issubclass(MCPError, Exception)
def test_catch_all_with_mcp_error(self):
"""Test that all errors can be caught with MCPError."""
def raise_connection_error():
raise MCPConnectionError("Connection failed")
def raise_timeout_error():
raise MCPTimeoutError("Timeout")
def raise_tool_error():
raise MCPToolError("Tool failed")
with pytest.raises(MCPError):
raise_connection_error()
with pytest.raises(MCPError):
raise_timeout_error()
with pytest.raises(MCPError):
raise_tool_error()

View File

@@ -0,0 +1,272 @@
"""
Tests for MCP Server Registry
"""
import pytest
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
from app.services.mcp.exceptions import MCPServerNotFoundError
from app.services.mcp.registry import (
MCPServerRegistry,
ServerCapabilities,
get_registry,
)
@pytest.fixture
def reset_registry():
"""Reset the singleton registry before and after each test."""
MCPServerRegistry.reset_instance()
yield
MCPServerRegistry.reset_instance()
@pytest.fixture
def sample_config():
"""Create a sample MCP configuration."""
return MCPConfig(
mcp_servers={
"server-1": MCPServerConfig(
url="http://server1:8000",
timeout=30,
enabled=True,
),
"server-2": MCPServerConfig(
url="http://server2:8000",
timeout=60,
enabled=True,
),
"disabled-server": MCPServerConfig(
url="http://disabled:8000",
enabled=False,
),
}
)
class TestServerCapabilities:
"""Tests for ServerCapabilities class."""
def test_empty_capabilities(self):
"""Test creating empty capabilities."""
caps = ServerCapabilities()
assert caps.tools == []
assert caps.resources == []
assert caps.prompts == []
assert caps.is_loaded is False
assert caps.tool_names == []
def test_capabilities_with_tools(self):
"""Test capabilities with tools."""
caps = ServerCapabilities(
tools=[
{"name": "tool1", "description": "Tool 1"},
{"name": "tool2", "description": "Tool 2"},
]
)
assert len(caps.tools) == 2
assert caps.tool_names == ["tool1", "tool2"]
def test_mark_loaded(self):
"""Test marking capabilities as loaded."""
caps = ServerCapabilities()
assert caps.is_loaded is False
assert caps._load_time is None
caps.mark_loaded()
assert caps.is_loaded is True
assert caps._load_time is not None
class TestMCPServerRegistry:
"""Tests for MCPServerRegistry singleton."""
def test_singleton_pattern(self, reset_registry):
"""Test that registry is a singleton."""
registry1 = MCPServerRegistry()
registry2 = MCPServerRegistry()
assert registry1 is registry2
def test_get_instance(self, reset_registry):
"""Test get_instance class method."""
registry = MCPServerRegistry.get_instance()
assert registry is MCPServerRegistry()
def test_reset_instance(self, reset_registry):
"""Test resetting singleton instance."""
registry1 = MCPServerRegistry()
MCPServerRegistry.reset_instance()
registry2 = MCPServerRegistry()
assert registry1 is not registry2
def test_load_config(self, reset_registry, sample_config):
"""Test loading configuration."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
assert len(registry.list_servers()) == 3
assert "server-1" in registry.list_servers()
assert "server-2" in registry.list_servers()
assert "disabled-server" in registry.list_servers()
def test_list_enabled_servers(self, reset_registry, sample_config):
"""Test listing only enabled servers."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
enabled = registry.list_enabled_servers()
assert len(enabled) == 2
assert "server-1" in enabled
assert "server-2" in enabled
assert "disabled-server" not in enabled
def test_register(self, reset_registry):
"""Test registering a new server."""
registry = MCPServerRegistry()
config = MCPServerConfig(url="http://new:8000")
registry.register("new-server", config)
assert "new-server" in registry.list_servers()
assert registry.get("new-server").url == "http://new:8000"
def test_unregister(self, reset_registry, sample_config):
"""Test unregistering a server."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
assert registry.unregister("server-1") is True
assert "server-1" not in registry.list_servers()
# Unregistering non-existent server returns False
assert registry.unregister("nonexistent") is False
def test_get(self, reset_registry, sample_config):
"""Test getting server config."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
config = registry.get("server-1")
assert config.url == "http://server1:8000"
assert config.timeout == 30
def test_get_not_found(self, reset_registry, sample_config):
"""Test getting non-existent server raises error."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
with pytest.raises(MCPServerNotFoundError) as exc_info:
registry.get("nonexistent")
assert exc_info.value.server_name == "nonexistent"
assert "server-1" in exc_info.value.available_servers
def test_get_or_none(self, reset_registry, sample_config):
"""Test get_or_none method."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
config = registry.get_or_none("server-1")
assert config is not None
config = registry.get_or_none("nonexistent")
assert config is None
def test_get_all_configs(self, reset_registry, sample_config):
"""Test getting all configs."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
configs = registry.get_all_configs()
assert len(configs) == 3
def test_get_enabled_configs(self, reset_registry, sample_config):
"""Test getting enabled configs."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
configs = registry.get_enabled_configs()
assert len(configs) == 2
assert "disabled-server" not in configs
@pytest.mark.asyncio
async def test_get_capabilities(self, reset_registry, sample_config):
"""Test getting server capabilities."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
# Initially empty capabilities
caps = await registry.get_capabilities("server-1")
assert caps.is_loaded is False
def test_set_capabilities(self, reset_registry, sample_config):
"""Test setting server capabilities."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
registry.set_capabilities(
"server-1",
tools=[{"name": "tool1"}, {"name": "tool2"}],
resources=[{"name": "resource1"}],
)
caps = registry._capabilities["server-1"]
assert len(caps.tools) == 2
assert len(caps.resources) == 1
assert caps.is_loaded is True
def test_find_server_for_tool(self, reset_registry, sample_config):
"""Test finding server that provides a tool."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
registry.set_capabilities(
"server-1",
tools=[{"name": "tool1"}, {"name": "tool2"}],
)
registry.set_capabilities(
"server-2",
tools=[{"name": "tool3"}],
)
assert registry.find_server_for_tool("tool1") == "server-1"
assert registry.find_server_for_tool("tool3") == "server-2"
assert registry.find_server_for_tool("unknown") is None
def test_get_all_tools(self, reset_registry, sample_config):
"""Test getting all tools from all servers."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
registry.set_capabilities(
"server-1",
tools=[{"name": "tool1"}],
)
registry.set_capabilities(
"server-2",
tools=[{"name": "tool2"}, {"name": "tool3"}],
)
all_tools = registry.get_all_tools()
assert len(all_tools) == 2
assert len(all_tools["server-1"]) == 1
assert len(all_tools["server-2"]) == 2
def test_global_config_property(self, reset_registry, sample_config):
"""Test accessing global config."""
registry = MCPServerRegistry()
registry.load_config(sample_config)
global_config = registry.global_config
assert global_config is not None
assert len(global_config.mcp_servers) == 3
class TestGetRegistry:
"""Tests for get_registry convenience function."""
def test_get_registry_returns_singleton(self, reset_registry):
"""Test that get_registry returns the singleton."""
registry1 = get_registry()
registry2 = get_registry()
assert registry1 is registry2
assert registry1 is MCPServerRegistry()

View File

@@ -0,0 +1,345 @@
"""
Tests for MCP Tool Call Routing
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.connection import ConnectionPool
from app.services.mcp.exceptions import (
MCPCircuitOpenError,
MCPToolError,
MCPToolNotFoundError,
)
from app.services.mcp.registry import MCPServerRegistry
from app.services.mcp.routing import ToolInfo, ToolResult, ToolRouter
@pytest.fixture
def reset_registry():
"""Reset the singleton registry before and after each test."""
MCPServerRegistry.reset_instance()
yield
MCPServerRegistry.reset_instance()
@pytest.fixture
def registry(reset_registry):
"""Create a configured registry."""
reg = MCPServerRegistry()
reg.load_config(
MCPConfig(
mcp_servers={
"server-1": MCPServerConfig(
url="http://server1:8000",
retry_attempts=1,
retry_delay=0.1, # Minimum allowed value
circuit_breaker_threshold=3,
circuit_breaker_timeout=5.0,
),
"server-2": MCPServerConfig(
url="http://server2:8000",
retry_attempts=1,
retry_delay=0.1, # Minimum allowed value
),
}
)
)
return reg
@pytest.fixture
def pool():
"""Create a connection pool."""
return ConnectionPool()
@pytest.fixture
def router(registry, pool):
"""Create a tool router."""
return ToolRouter(registry, pool)
class TestToolInfo:
"""Tests for ToolInfo dataclass."""
def test_basic_tool_info(self):
"""Test creating basic tool info."""
info = ToolInfo(name="test-tool")
assert info.name == "test-tool"
assert info.description is None
assert info.server_name is None
assert info.input_schema is None
def test_full_tool_info(self):
"""Test creating full tool info."""
info = ToolInfo(
name="create_issue",
description="Create a new issue",
server_name="issues",
input_schema={"type": "object", "properties": {"title": {"type": "string"}}},
)
assert info.name == "create_issue"
assert info.description == "Create a new issue"
assert info.server_name == "issues"
assert "properties" in info.input_schema
def test_to_dict(self):
"""Test converting to dictionary."""
info = ToolInfo(
name="test-tool",
description="A test tool",
server_name="test-server",
)
result = info.to_dict()
assert result["name"] == "test-tool"
assert result["description"] == "A test tool"
assert result["server_name"] == "test-server"
class TestToolResult:
"""Tests for ToolResult dataclass."""
def test_success_result(self):
"""Test creating success result."""
result = ToolResult(
success=True,
data={"id": "123"},
tool_name="create_issue",
server_name="issues",
)
assert result.success is True
assert result.data == {"id": "123"}
assert result.error is None
def test_error_result(self):
"""Test creating error result."""
result = ToolResult(
success=False,
error="Tool execution failed",
error_code="INTERNAL_ERROR",
tool_name="create_issue",
server_name="issues",
)
assert result.success is False
assert result.error == "Tool execution failed"
assert result.error_code == "INTERNAL_ERROR"
def test_to_dict(self):
"""Test converting to dictionary."""
result = ToolResult(
success=True,
data={"result": "ok"},
tool_name="test",
execution_time_ms=123.45,
)
d = result.to_dict()
assert d["success"] is True
assert d["data"] == {"result": "ok"}
assert d["tool_name"] == "test"
assert d["execution_time_ms"] == 123.45
assert "request_id" in d # Auto-generated
class TestToolRouter:
"""Tests for ToolRouter class."""
@pytest.mark.asyncio
async def test_register_tool_mapping(self, router):
"""Test registering tool mappings."""
await router.register_tool_mapping("tool1", "server-1")
await router.register_tool_mapping("tool2", "server-2")
assert router.find_server_for_tool("tool1") == "server-1"
assert router.find_server_for_tool("tool2") == "server-2"
def test_find_server_for_unknown_tool(self, router):
"""Test finding server for unknown tool."""
result = router.find_server_for_tool("unknown-tool")
assert result is None
@pytest.mark.asyncio
async def test_call_tool_success(self, router, registry):
"""Test successful tool call."""
# Set up capabilities
registry.set_capabilities(
"server-1",
tools=[{"name": "test-tool"}],
)
await router.register_tool_mapping("test-tool", "server-1")
# Mock the pool connection and request
mock_conn = AsyncMock()
mock_conn.execute_request = AsyncMock(
return_value={"result": {"status": "ok"}}
)
mock_conn.is_connected = True
with patch.object(router._pool, "get_connection", return_value=mock_conn):
result = await router.call_tool(
server_name="server-1",
tool_name="test-tool",
arguments={"param": "value"},
)
assert result.success is True
assert result.data == {"status": "ok"}
assert result.tool_name == "test-tool"
assert result.server_name == "server-1"
assert result.execution_time_ms > 0
@pytest.mark.asyncio
async def test_call_tool_error_response(self, router, registry):
"""Test tool call with error response."""
registry.set_capabilities(
"server-1",
tools=[{"name": "test-tool"}],
)
await router.register_tool_mapping("test-tool", "server-1")
mock_conn = AsyncMock()
mock_conn.execute_request = AsyncMock(
return_value={
"error": {
"code": -32000,
"message": "Tool execution failed",
}
}
)
mock_conn.is_connected = True
with patch.object(router._pool, "get_connection", return_value=mock_conn):
result = await router.call_tool(
server_name="server-1",
tool_name="test-tool",
arguments={},
)
assert result.success is False
assert "Tool execution failed" in result.error
@pytest.mark.asyncio
async def test_route_tool(self, router, registry):
"""Test routing tool to correct server."""
registry.set_capabilities(
"server-1",
tools=[{"name": "tool-on-server-1"}],
)
await router.register_tool_mapping("tool-on-server-1", "server-1")
mock_conn = AsyncMock()
mock_conn.execute_request = AsyncMock(
return_value={"result": "routed"}
)
mock_conn.is_connected = True
with patch.object(router._pool, "get_connection", return_value=mock_conn):
result = await router.route_tool(
tool_name="tool-on-server-1",
arguments={"key": "value"},
)
assert result.success is True
assert result.server_name == "server-1"
@pytest.mark.asyncio
async def test_route_tool_not_found(self, router):
"""Test routing unknown tool raises error."""
with pytest.raises(MCPToolNotFoundError) as exc_info:
await router.route_tool(
tool_name="unknown-tool",
arguments={},
)
assert exc_info.value.tool_name == "unknown-tool"
@pytest.mark.asyncio
async def test_list_all_tools(self, router, registry):
"""Test listing all tools."""
registry.set_capabilities(
"server-1",
tools=[
{"name": "tool1", "description": "Tool 1"},
{"name": "tool2", "description": "Tool 2"},
],
)
registry.set_capabilities(
"server-2",
tools=[{"name": "tool3", "description": "Tool 3"}],
)
tools = await router.list_all_tools()
assert len(tools) == 3
tool_names = [t.name for t in tools]
assert "tool1" in tool_names
assert "tool2" in tool_names
assert "tool3" in tool_names
def test_circuit_breaker_status(self, router, registry):
"""Test getting circuit breaker status."""
# Initially no circuit breakers
status = router.get_circuit_breaker_status()
assert status == {}
@pytest.mark.asyncio
async def test_reset_circuit_breaker(self, router, registry):
"""Test resetting circuit breaker."""
# Reset non-existent returns False
result = await router.reset_circuit_breaker("server-1")
assert result is False
@pytest.mark.asyncio
async def test_discover_tools(self, router, registry):
"""Test tool discovery from servers."""
# Create mocks for different servers
mock_conn_1 = AsyncMock()
mock_conn_1.execute_request = AsyncMock(
return_value={
"tools": [
{"name": "discovered-tool", "description": "A discovered tool"},
]
}
)
mock_conn_1.server_name = "server-1"
mock_conn_1.is_connected = True
mock_conn_2 = AsyncMock()
mock_conn_2.execute_request = AsyncMock(
return_value={"tools": []} # Empty for server-2
)
mock_conn_2.server_name = "server-2"
mock_conn_2.is_connected = True
async def get_connection_side_effect(server_name, _config):
if server_name == "server-1":
return mock_conn_1
return mock_conn_2
with patch.object(
router._pool,
"get_connection",
side_effect=get_connection_side_effect,
):
await router.discover_tools()
# Check that tool mapping was registered
server = router.find_server_for_tool("discovered-tool")
assert server == "server-1"
def test_calculate_retry_delay(self, router, registry):
"""Test retry delay calculation."""
config = registry.get("server-1")
delay1 = router._calculate_retry_delay(1, config)
delay2 = router._calculate_retry_delay(2, config)
delay3 = router._calculate_retry_delay(3, config)
# Delays should increase with attempts
assert delay1 > 0
# Allow for jitter variation
assert delay1 <= config.retry_max_delay * 1.25