forked from cardosofelipe/fast-next-template
feat(backend): implement MCP client infrastructure (#55)
Core MCP client implementation with comprehensive tooling:
**Services:**
- MCPClientManager: Main facade for all MCP operations
- MCPServerRegistry: Thread-safe singleton for server configs
- ConnectionPool: Connection pooling with auto-reconnection
- ToolRouter: Automatic tool routing with circuit breaker
- AsyncCircuitBreaker: Custom async-compatible circuit breaker
**Configuration:**
- YAML-based config with Pydantic models
- Environment variable expansion support
- Transport types: HTTP, SSE, STDIO
**API Endpoints:**
- GET /mcp/servers - List all MCP servers
- GET /mcp/servers/{name}/tools - List server tools
- GET /mcp/tools - List all tools from all servers
- GET /mcp/health - Health check all servers
- POST /mcp/call - Execute tool (admin only)
- GET /mcp/circuit-breakers - Circuit breaker status
- POST /mcp/circuit-breakers/{name}/reset - Reset circuit breaker
- POST /mcp/servers/{name}/reconnect - Force reconnection
**Testing:**
- 156 unit tests with comprehensive coverage
- Tests for all services, routes, and error handling
- Proper mocking and async test support
**Documentation:**
- MCP_CLIENT.md with usage examples
- Phase 2+ workflow documentation
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
491
backend/tests/api/routes/test_mcp.py
Normal file
491
backend/tests/api/routes/test_mcp.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
Tests for MCP API Routes
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.mcp import (
|
||||
MCPCircuitOpenError,
|
||||
MCPClientManager,
|
||||
MCPConnectionError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolNotFoundError,
|
||||
ServerHealth,
|
||||
)
|
||||
from app.services.mcp.config import MCPServerConfig, TransportType
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client():
|
||||
"""Create a mock MCP client manager."""
|
||||
client = MagicMock(spec=MCPClientManager)
|
||||
client.is_initialized = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_superuser():
|
||||
"""Create a mock superuser."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = "00000000-0000-0000-0000-000000000001"
|
||||
user.is_superuser = True
|
||||
user.email = "admin@example.com"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_mcp_client, mock_superuser):
|
||||
"""Create a FastAPI test client with mocked dependencies."""
|
||||
from app.api.routes.mcp import get_mcp_client
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
|
||||
# Override dependencies
|
||||
async def override_get_mcp_client():
|
||||
return mock_mcp_client
|
||||
|
||||
async def override_require_superuser():
|
||||
return mock_superuser
|
||||
|
||||
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||
|
||||
with patch("app.main.check_database_health", return_value=True):
|
||||
yield TestClient(app)
|
||||
|
||||
# Clean up
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestListServers:
|
||||
"""Tests for GET /mcp/servers endpoint."""
|
||||
|
||||
def test_list_servers_success(self, client, mock_mcp_client):
|
||||
"""Test listing MCP servers returns correct data."""
|
||||
# Setup mock
|
||||
mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
|
||||
mock_mcp_client.get_server_config.side_effect = [
|
||||
MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
transport=TransportType.HTTP,
|
||||
description="Server 1",
|
||||
),
|
||||
MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
transport=TransportType.SSE,
|
||||
description="Server 2",
|
||||
),
|
||||
]
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert len(data["servers"]) == 2
|
||||
assert data["servers"][0]["name"] == "server-1"
|
||||
assert data["servers"][0]["url"] == "http://server1:8000"
|
||||
assert data["servers"][1]["name"] == "server-2"
|
||||
assert data["servers"][1]["transport"] == "sse"
|
||||
|
||||
def test_list_servers_empty(self, client, mock_mcp_client):
|
||||
"""Test listing servers when none are registered."""
|
||||
mock_mcp_client.list_servers.return_value = []
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
assert data["servers"] == []
|
||||
|
||||
def test_list_servers_handles_not_found(self, client, mock_mcp_client):
|
||||
"""Test that missing server configs are skipped gracefully."""
|
||||
mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
|
||||
mock_mcp_client.get_server_config.side_effect = [
|
||||
MCPServerConfig(url="http://server1:8000"),
|
||||
MCPServerNotFoundError(server_name="missing"),
|
||||
]
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
# Should only include the successfully retrieved server
|
||||
assert data["total"] == 1
|
||||
|
||||
|
||||
class TestListServerTools:
|
||||
"""Tests for GET /mcp/servers/{server_name}/tools endpoint."""
|
||||
|
||||
def test_list_server_tools_success(self, client, mock_mcp_client):
|
||||
"""Test listing tools for a specific server."""
|
||||
mock_mcp_client.list_tools = AsyncMock(
|
||||
return_value=[
|
||||
ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
|
||||
]
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/servers/server-1/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert data["tools"][0]["name"] == "tool1"
|
||||
assert data["tools"][1]["name"] == "tool2"
|
||||
|
||||
def test_list_server_tools_not_found(self, client, mock_mcp_client):
|
||||
"""Test listing tools for non-existent server."""
|
||||
mock_mcp_client.list_tools = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/servers/unknown/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestListAllTools:
|
||||
"""Tests for GET /mcp/tools endpoint."""
|
||||
|
||||
def test_list_all_tools_success(self, client, mock_mcp_client):
|
||||
"""Test listing all tools from all servers."""
|
||||
mock_mcp_client.list_all_tools = AsyncMock(
|
||||
return_value=[
|
||||
ToolInfo(name="tool1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", server_name="server-1"),
|
||||
ToolInfo(name="tool3", server_name="server-2"),
|
||||
]
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
|
||||
def test_list_all_tools_empty(self, client, mock_mcp_client):
|
||||
"""Test listing tools when none are available."""
|
||||
mock_mcp_client.list_all_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = client.get("/api/v1/mcp/tools")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for GET /mcp/health endpoint."""
|
||||
|
||||
def test_health_check_success(self, client, mock_mcp_client):
|
||||
"""Test health check returns correct data."""
|
||||
mock_mcp_client.health_check = AsyncMock(
|
||||
return_value={
|
||||
"server-1": ServerHealth(
|
||||
name="server-1",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://server1:8000",
|
||||
tools_count=5,
|
||||
),
|
||||
"server-2": ServerHealth(
|
||||
name="server-2",
|
||||
healthy=False,
|
||||
state="error",
|
||||
url="http://server2:8000",
|
||||
error="Connection refused",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert data["healthy_count"] == 1
|
||||
assert data["unhealthy_count"] == 1
|
||||
assert data["servers"]["server-1"]["healthy"] is True
|
||||
assert data["servers"]["server-2"]["healthy"] is False
|
||||
|
||||
def test_health_check_all_healthy(self, client, mock_mcp_client):
|
||||
"""Test health check when all servers are healthy."""
|
||||
mock_mcp_client.health_check = AsyncMock(
|
||||
return_value={
|
||||
"server-1": ServerHealth(
|
||||
name="server-1",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://server1:8000",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/mcp/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["healthy_count"] == 1
|
||||
assert data["unhealthy_count"] == 0
|
||||
|
||||
|
||||
class TestCallTool:
|
||||
"""Tests for POST /mcp/call endpoint."""
|
||||
|
||||
def test_call_tool_success(self, client, mock_mcp_client):
|
||||
"""Test successful tool execution."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
return_value=ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="test-tool",
|
||||
server_name="server-1",
|
||||
execution_time_ms=123.45,
|
||||
request_id="test-request-id",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={
|
||||
"server": "server-1",
|
||||
"tool": "test-tool",
|
||||
"arguments": {"key": "value"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"] == {"result": "ok"}
|
||||
assert data["tool_name"] == "test-tool"
|
||||
assert data["server_name"] == "server-1"
|
||||
|
||||
def test_call_tool_with_timeout(self, client, mock_mcp_client):
|
||||
"""Test tool execution with custom timeout."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
return_value=ToolResult(success=True, data={})
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={
|
||||
"server": "server-1",
|
||||
"tool": "test-tool",
|
||||
"timeout": 60.0,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
mock_mcp_client.call_tool.assert_called_once()
|
||||
call_args = mock_mcp_client.call_tool.call_args
|
||||
assert call_args.kwargs["timeout"] == 60.0
|
||||
|
||||
def test_call_tool_server_not_found(self, client, mock_mcp_client):
|
||||
"""Test tool execution with non-existent server."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "unknown", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_call_tool_not_found(self, client, mock_mcp_client):
|
||||
"""Test tool execution with non-existent tool."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPToolNotFoundError(tool_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "unknown"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_call_tool_timeout(self, client, mock_mcp_client):
|
||||
"""Test tool execution timeout."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name="server-1",
|
||||
timeout_seconds=30.0,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "slow-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||
|
||||
def test_call_tool_connection_error(self, client, mock_mcp_client):
|
||||
"""Test tool execution with connection failure."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPConnectionError(
|
||||
"Connection refused",
|
||||
server_name="server-1",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||
|
||||
def test_call_tool_circuit_open(self, client, mock_mcp_client):
|
||||
"""Test tool execution with open circuit breaker."""
|
||||
mock_mcp_client.call_tool = AsyncMock(
|
||||
side_effect=MCPCircuitOpenError(
|
||||
server_name="server-1",
|
||||
failure_count=5,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1", "tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
|
||||
|
||||
class TestCircuitBreakers:
|
||||
"""Tests for circuit breaker endpoints."""
|
||||
|
||||
def test_list_circuit_breakers(self, client, mock_mcp_client):
|
||||
"""Test listing circuit breaker statuses."""
|
||||
mock_mcp_client.get_circuit_breaker_status.return_value = {
|
||||
"server-1": {"state": "closed", "failure_count": 0},
|
||||
"server-2": {"state": "open", "failure_count": 5},
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data["circuit_breakers"]) == 2
|
||||
|
||||
def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
|
||||
"""Test listing when no circuit breakers exist."""
|
||||
mock_mcp_client.get_circuit_breaker_status.return_value = {}
|
||||
|
||||
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["circuit_breakers"] == []
|
||||
|
||||
def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
|
||||
"""Test successfully resetting a circuit breaker."""
|
||||
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)
|
||||
|
||||
response = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
|
||||
"""Test resetting non-existent circuit breaker."""
|
||||
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)
|
||||
|
||||
response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestReconnectServer:
|
||||
"""Tests for POST /mcp/servers/{server_name}/reconnect endpoint."""
|
||||
|
||||
def test_reconnect_success(self, client, mock_mcp_client):
|
||||
"""Test successful server reconnection."""
|
||||
mock_mcp_client.disconnect = AsyncMock()
|
||||
mock_mcp_client.connect = AsyncMock()
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
mock_mcp_client.disconnect.assert_called_once_with("server-1")
|
||||
mock_mcp_client.connect.assert_called_once_with("server-1")
|
||||
|
||||
def test_reconnect_server_not_found(self, client, mock_mcp_client):
|
||||
"""Test reconnecting to non-existent server."""
|
||||
mock_mcp_client.disconnect = AsyncMock(
|
||||
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||
)
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/unknown/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_reconnect_connection_failure(self, client, mock_mcp_client):
|
||||
"""Test reconnection failure."""
|
||||
mock_mcp_client.disconnect = AsyncMock()
|
||||
mock_mcp_client.connect = AsyncMock(
|
||||
side_effect=MCPConnectionError(
|
||||
"Connection refused",
|
||||
server_name="server-1",
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||
|
||||
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||
|
||||
|
||||
class TestMCPEndpointsEdgeCases:
|
||||
"""Edge case tests for MCP endpoints."""
|
||||
|
||||
def test_servers_content_type(self, client, mock_mcp_client):
|
||||
"""Test that endpoints return JSON content type."""
|
||||
mock_mcp_client.list_servers.return_value = []
|
||||
|
||||
response = client.get("/api/v1/mcp/servers")
|
||||
|
||||
assert "application/json" in response.headers["content-type"]
|
||||
|
||||
def test_call_tool_validation_error(self, client):
|
||||
"""Test that invalid request body returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={}, # Missing required fields
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_call_tool_missing_server(self, client):
|
||||
"""Test that missing server field returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"tool": "test-tool"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_call_tool_missing_tool(self, client):
|
||||
"""Test that missing tool field returns validation error."""
|
||||
response = client.post(
|
||||
"/api/v1/mcp/call",
|
||||
json={"server": "server-1"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
1
backend/tests/services/mcp/__init__.py
Normal file
1
backend/tests/services/mcp/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""MCP Service Tests Package."""
|
||||
395
backend/tests/services/mcp/test_client_manager.py
Normal file
395
backend/tests/services/mcp/test_client_manager.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Tests for MCP Client Manager
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.client_manager import (
|
||||
MCPClientManager,
|
||||
ServerHealth,
|
||||
get_mcp_client,
|
||||
reset_mcp_client,
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.connection import ConnectionState
|
||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||
from app.services.mcp.registry import MCPServerRegistry
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
reset_mcp_client()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
reset_mcp_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
"""Create a sample MCP configuration."""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
),
|
||||
"server-2": MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestServerHealth:
|
||||
"""Tests for ServerHealth dataclass."""
|
||||
|
||||
def test_healthy_server(self):
|
||||
"""Test healthy server status."""
|
||||
health = ServerHealth(
|
||||
name="test-server",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://test:8000",
|
||||
tools_count=5,
|
||||
)
|
||||
assert health.healthy is True
|
||||
assert health.error is None
|
||||
assert health.tools_count == 5
|
||||
|
||||
def test_unhealthy_server(self):
|
||||
"""Test unhealthy server status."""
|
||||
health = ServerHealth(
|
||||
name="test-server",
|
||||
healthy=False,
|
||||
state="error",
|
||||
url="http://test:8000",
|
||||
error="Connection refused",
|
||||
)
|
||||
assert health.healthy is False
|
||||
assert health.error == "Connection refused"
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting to dictionary."""
|
||||
health = ServerHealth(
|
||||
name="test-server",
|
||||
healthy=True,
|
||||
state="connected",
|
||||
url="http://test:8000",
|
||||
tools_count=3,
|
||||
)
|
||||
d = health.to_dict()
|
||||
|
||||
assert d["name"] == "test-server"
|
||||
assert d["healthy"] is True
|
||||
assert d["state"] == "connected"
|
||||
assert d["url"] == "http://test:8000"
|
||||
assert d["tools_count"] == 3
|
||||
|
||||
|
||||
class TestMCPClientManager:
|
||||
"""Tests for MCPClientManager class."""
|
||||
|
||||
def test_initial_state(self, reset_registry):
|
||||
"""Test initial manager state."""
|
||||
manager = MCPClientManager()
|
||||
assert manager.is_initialized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, reset_registry, sample_config):
|
||||
"""Test manager initialization."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "get_connection") as mock_get_conn:
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.is_connected = True
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
with patch.object(manager, "_router") as mock_router:
|
||||
mock_router.discover_tools = AsyncMock()
|
||||
|
||||
await manager.initialize()
|
||||
|
||||
assert manager.is_initialized is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown(self, reset_registry, sample_config):
|
||||
"""Test manager shutdown."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
with patch.object(manager._pool, "close_all") as mock_close:
|
||||
mock_close.return_value = None
|
||||
await manager.shutdown()
|
||||
|
||||
assert manager.is_initialized is False
|
||||
mock_close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(self, reset_registry, sample_config):
|
||||
"""Test connecting to specific server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "get_connection") as mock_get_conn:
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.is_connected = True
|
||||
mock_get_conn.return_value = mock_conn
|
||||
|
||||
await manager.connect("server-1")
|
||||
|
||||
mock_get_conn.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_unknown_server(self, reset_registry, sample_config):
|
||||
"""Test connecting to unknown server raises error."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with pytest.raises(MCPServerNotFoundError):
|
||||
await manager.connect("unknown-server")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect(self, reset_registry, sample_config):
|
||||
"""Test disconnecting from server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "close_connection") as mock_close:
|
||||
await manager.disconnect("server-1")
|
||||
mock_close.assert_called_once_with("server-1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all(self, reset_registry, sample_config):
|
||||
"""Test disconnecting from all servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "close_all") as mock_close:
|
||||
await manager.disconnect_all()
|
||||
mock_close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool(self, reset_registry, sample_config):
|
||||
"""Test calling a tool."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
expected_result = ToolResult(
|
||||
success=True,
|
||||
data={"id": "123"},
|
||||
tool_name="create_issue",
|
||||
server_name="server-1",
|
||||
)
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.call_tool = AsyncMock(return_value=expected_result)
|
||||
manager._router = mock_router
|
||||
|
||||
result = await manager.call_tool(
|
||||
server="server-1",
|
||||
tool="create_issue",
|
||||
args={"title": "Test"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data == {"id": "123"}
|
||||
mock_router.call_tool.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_tool(self, reset_registry, sample_config):
|
||||
"""Test routing a tool call."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
expected_result = ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="auto_tool",
|
||||
server_name="server-2",
|
||||
)
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.route_tool = AsyncMock(return_value=expected_result)
|
||||
manager._router = mock_router
|
||||
|
||||
result = await manager.route_tool(
|
||||
tool="auto_tool",
|
||||
args={"key": "value"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.server_name == "server-2"
|
||||
mock_router.route_tool.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(self, reset_registry, sample_config):
|
||||
"""Test listing tools for a server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
# Set up capabilities in registry
|
||||
manager._registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[
|
||||
{"name": "tool1", "description": "Tool 1"},
|
||||
{"name": "tool2", "description": "Tool 2"},
|
||||
],
|
||||
)
|
||||
|
||||
tools = await manager.list_tools("server-1")
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "tool1"
|
||||
assert tools[1].name == "tool2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_tools(self, reset_registry, sample_config):
|
||||
"""Test listing all tools from all servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
manager._initialized = True
|
||||
|
||||
expected_tools = [
|
||||
ToolInfo(name="tool1", server_name="server-1"),
|
||||
ToolInfo(name="tool2", server_name="server-2"),
|
||||
]
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.list_all_tools = AsyncMock(return_value=expected_tools)
|
||||
manager._router = mock_router
|
||||
|
||||
tools = await manager.list_all_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, reset_registry, sample_config):
|
||||
"""Test health check on all servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with patch.object(manager._pool, "get_status") as mock_status:
|
||||
mock_status.return_value = {
|
||||
"server-1": {"state": "connected"},
|
||||
"server-2": {"state": "disconnected"},
|
||||
}
|
||||
|
||||
with patch.object(manager._pool, "health_check_all") as mock_health:
|
||||
mock_health.return_value = {
|
||||
"server-1": True,
|
||||
"server-2": False,
|
||||
}
|
||||
|
||||
health = await manager.health_check()
|
||||
|
||||
assert "server-1" in health
|
||||
assert "server-2" in health
|
||||
assert health["server-1"].healthy is True
|
||||
assert health["server-2"].healthy is False
|
||||
|
||||
def test_list_servers(self, reset_registry, sample_config):
|
||||
"""Test listing registered servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
servers = manager.list_servers()
|
||||
|
||||
assert "server-1" in servers
|
||||
assert "server-2" in servers
|
||||
|
||||
def test_list_enabled_servers(self, reset_registry, sample_config):
|
||||
"""Test listing enabled servers."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
servers = manager.list_enabled_servers()
|
||||
|
||||
assert "server-1" in servers
|
||||
assert "server-2" in servers
|
||||
|
||||
def test_get_server_config(self, reset_registry, sample_config):
|
||||
"""Test getting server configuration."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
config = manager.get_server_config("server-1")
|
||||
assert config.url == "http://server1:8000"
|
||||
assert config.timeout == 30
|
||||
|
||||
def test_get_server_config_not_found(self, reset_registry, sample_config):
|
||||
"""Test getting unknown server config raises error."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
with pytest.raises(MCPServerNotFoundError):
|
||||
manager.get_server_config("unknown")
|
||||
|
||||
def test_register_server(self, reset_registry, sample_config):
|
||||
"""Test registering new server at runtime."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
new_config = MCPServerConfig(url="http://new:8000")
|
||||
manager.register_server("new-server", new_config)
|
||||
|
||||
assert "new-server" in manager.list_servers()
|
||||
|
||||
def test_unregister_server(self, reset_registry, sample_config):
|
||||
"""Test unregistering a server."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
result = manager.unregister_server("server-1")
|
||||
assert result is True
|
||||
assert "server-1" not in manager.list_servers()
|
||||
|
||||
# Unregistering non-existent returns False
|
||||
result = manager.unregister_server("nonexistent")
|
||||
assert result is False
|
||||
|
||||
def test_circuit_breaker_status(self, reset_registry, sample_config):
|
||||
"""Test getting circuit breaker status."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
# No router yet
|
||||
status = manager.get_circuit_breaker_status()
|
||||
assert status == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_circuit_breaker(self, reset_registry, sample_config):
|
||||
"""Test resetting circuit breaker."""
|
||||
manager = MCPClientManager(config=sample_config)
|
||||
|
||||
# No router yet
|
||||
result = await manager.reset_circuit_breaker("server-1")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
"""Tests for module-level convenience functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_mcp_client_creates_singleton(self, reset_registry):
|
||||
"""Test get_mcp_client creates and returns singleton."""
|
||||
with patch(
|
||||
"app.services.mcp.client_manager.MCPClientManager.initialize"
|
||||
) as mock_init:
|
||||
mock_init.return_value = None
|
||||
|
||||
client1 = await get_mcp_client()
|
||||
client2 = await get_mcp_client()
|
||||
|
||||
assert client1 is client2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_mcp_client(self, reset_registry):
|
||||
"""Test shutting down the global client."""
|
||||
with patch(
|
||||
"app.services.mcp.client_manager.MCPClientManager.initialize"
|
||||
) as mock_init:
|
||||
mock_init.return_value = None
|
||||
|
||||
client = await get_mcp_client()
|
||||
|
||||
with patch.object(client, "shutdown") as mock_shutdown:
|
||||
mock_shutdown.return_value = None
|
||||
await shutdown_mcp_client()
|
||||
|
||||
def test_reset_mcp_client(self, reset_registry):
|
||||
"""Test resetting the global client."""
|
||||
reset_mcp_client()
|
||||
# Should not raise
|
||||
319
backend/tests/services/mcp/test_config.py
Normal file
319
backend/tests/services/mcp/test_config.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Tests for MCP Configuration System
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.services.mcp.config import (
|
||||
MCPConfig,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
create_default_config,
|
||||
load_mcp_config,
|
||||
)
|
||||
|
||||
|
||||
class TestTransportType:
|
||||
"""Tests for TransportType enum."""
|
||||
|
||||
def test_transport_types(self):
|
||||
"""Test that all transport types are defined."""
|
||||
assert TransportType.HTTP == "http"
|
||||
assert TransportType.STDIO == "stdio"
|
||||
assert TransportType.SSE == "sse"
|
||||
|
||||
def test_transport_type_from_string(self):
|
||||
"""Test creating transport type from string."""
|
||||
assert TransportType("http") == TransportType.HTTP
|
||||
assert TransportType("stdio") == TransportType.STDIO
|
||||
assert TransportType("sse") == TransportType.SSE
|
||||
|
||||
|
||||
class TestMCPServerConfig:
|
||||
"""Tests for MCPServerConfig model."""
|
||||
|
||||
def test_minimal_config(self):
|
||||
"""Test creating config with only required fields."""
|
||||
config = MCPServerConfig(url="http://localhost:8000")
|
||||
assert config.url == "http://localhost:8000"
|
||||
assert config.transport == TransportType.HTTP
|
||||
assert config.timeout == 30
|
||||
assert config.retry_attempts == 3
|
||||
assert config.enabled is True
|
||||
|
||||
def test_full_config(self):
|
||||
"""Test creating config with all fields."""
|
||||
config = MCPServerConfig(
|
||||
url="http://localhost:8000",
|
||||
transport=TransportType.SSE,
|
||||
timeout=60,
|
||||
retry_attempts=5,
|
||||
retry_delay=2.0,
|
||||
retry_max_delay=60.0,
|
||||
circuit_breaker_threshold=10,
|
||||
circuit_breaker_timeout=60.0,
|
||||
enabled=False,
|
||||
description="Test server",
|
||||
)
|
||||
assert config.timeout == 60
|
||||
assert config.transport == TransportType.SSE
|
||||
assert config.retry_attempts == 5
|
||||
assert config.retry_delay == 2.0
|
||||
assert config.retry_max_delay == 60.0
|
||||
assert config.circuit_breaker_threshold == 10
|
||||
assert config.circuit_breaker_timeout == 60.0
|
||||
assert config.enabled is False
|
||||
assert config.description == "Test server"
|
||||
|
||||
def test_env_var_expansion_simple(self):
|
||||
"""Test simple environment variable expansion."""
|
||||
os.environ["TEST_SERVER_URL"] = "http://test-server:9000"
|
||||
try:
|
||||
config = MCPServerConfig(url="${TEST_SERVER_URL}")
|
||||
assert config.url == "http://test-server:9000"
|
||||
finally:
|
||||
del os.environ["TEST_SERVER_URL"]
|
||||
|
||||
def test_env_var_expansion_with_default(self):
|
||||
"""Test environment variable expansion with default."""
|
||||
# Ensure env var is not set
|
||||
os.environ.pop("NONEXISTENT_URL", None)
|
||||
config = MCPServerConfig(url="${NONEXISTENT_URL:-http://default:8000}")
|
||||
assert config.url == "http://default:8000"
|
||||
|
||||
def test_env_var_expansion_override_default(self):
|
||||
"""Test environment variable override of default."""
|
||||
os.environ["TEST_OVERRIDE_URL"] = "http://override:9000"
|
||||
try:
|
||||
config = MCPServerConfig(url="${TEST_OVERRIDE_URL:-http://default:8000}")
|
||||
assert config.url == "http://override:9000"
|
||||
finally:
|
||||
del os.environ["TEST_OVERRIDE_URL"]
|
||||
|
||||
def test_timeout_validation(self):
|
||||
"""Test timeout validation bounds."""
|
||||
# Valid bounds
|
||||
config = MCPServerConfig(url="http://localhost", timeout=1)
|
||||
assert config.timeout == 1
|
||||
|
||||
config = MCPServerConfig(url="http://localhost", timeout=600)
|
||||
assert config.timeout == 600
|
||||
|
||||
# Invalid bounds
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", timeout=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", timeout=601)
|
||||
|
||||
def test_retry_attempts_validation(self):
|
||||
"""Test retry attempts validation bounds."""
|
||||
config = MCPServerConfig(url="http://localhost", retry_attempts=0)
|
||||
assert config.retry_attempts == 0
|
||||
|
||||
config = MCPServerConfig(url="http://localhost", retry_attempts=10)
|
||||
assert config.retry_attempts == 10
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", retry_attempts=-1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MCPServerConfig(url="http://localhost", retry_attempts=11)
|
||||
|
||||
|
||||
class TestMCPConfig:
|
||||
"""Tests for MCPConfig model."""
|
||||
|
||||
def test_empty_config(self):
|
||||
"""Test creating empty config."""
|
||||
config = MCPConfig()
|
||||
assert config.mcp_servers == {}
|
||||
assert config.default_timeout == 30
|
||||
assert config.default_retry_attempts == 3
|
||||
assert config.connection_pool_size == 10
|
||||
assert config.health_check_interval == 30
|
||||
|
||||
def test_config_with_servers(self):
|
||||
"""Test creating config with servers."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(url="http://server1:8000"),
|
||||
"server-2": MCPServerConfig(url="http://server2:8000"),
|
||||
}
|
||||
)
|
||||
assert len(config.mcp_servers) == 2
|
||||
assert "server-1" in config.mcp_servers
|
||||
assert "server-2" in config.mcp_servers
|
||||
|
||||
def test_get_server(self):
|
||||
"""Test getting server by name."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(url="http://server1:8000"),
|
||||
}
|
||||
)
|
||||
server = config.get_server("server-1")
|
||||
assert server is not None
|
||||
assert server.url == "http://server1:8000"
|
||||
|
||||
missing = config.get_server("nonexistent")
|
||||
assert missing is None
|
||||
|
||||
def test_get_enabled_servers(self):
|
||||
"""Test getting only enabled servers."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"enabled-1": MCPServerConfig(url="http://e1:8000", enabled=True),
|
||||
"disabled-1": MCPServerConfig(url="http://d1:8000", enabled=False),
|
||||
"enabled-2": MCPServerConfig(url="http://e2:8000", enabled=True),
|
||||
}
|
||||
)
|
||||
enabled = config.get_enabled_servers()
|
||||
assert len(enabled) == 2
|
||||
assert "enabled-1" in enabled
|
||||
assert "enabled-2" in enabled
|
||||
assert "disabled-1" not in enabled
|
||||
|
||||
def test_list_server_names(self):
|
||||
"""Test listing server names."""
|
||||
config = MCPConfig(
|
||||
mcp_servers={
|
||||
"server-a": MCPServerConfig(url="http://a:8000"),
|
||||
"server-b": MCPServerConfig(url="http://b:8000"),
|
||||
}
|
||||
)
|
||||
names = config.list_server_names()
|
||||
assert sorted(names) == ["server-a", "server-b"]
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
data = {
|
||||
"mcp_servers": {
|
||||
"test-server": {
|
||||
"url": "http://test:8000",
|
||||
"timeout": 45,
|
||||
}
|
||||
},
|
||||
"default_timeout": 60,
|
||||
}
|
||||
config = MCPConfig.from_dict(data)
|
||||
assert config.default_timeout == 60
|
||||
assert config.mcp_servers["test-server"].timeout == 45
|
||||
|
||||
def test_from_yaml(self):
|
||||
"""Test loading config from YAML file."""
|
||||
yaml_content = """
|
||||
mcp_servers:
|
||||
test-server:
|
||||
url: http://test:8000
|
||||
timeout: 45
|
||||
transport: http
|
||||
enabled: true
|
||||
default_timeout: 60
|
||||
connection_pool_size: 20
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
config = MCPConfig.from_yaml(f.name)
|
||||
assert config.default_timeout == 60
|
||||
assert config.connection_pool_size == 20
|
||||
assert "test-server" in config.mcp_servers
|
||||
assert config.mcp_servers["test-server"].timeout == 45
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_from_yaml_file_not_found(self):
|
||||
"""Test error when YAML file not found."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
MCPConfig.from_yaml("/nonexistent/path/config.yaml")
|
||||
|
||||
|
||||
class TestLoadMCPConfig:
|
||||
"""Tests for load_mcp_config function."""
|
||||
|
||||
def test_load_with_explicit_path(self):
|
||||
"""Test loading config with explicit path."""
|
||||
yaml_content = """
|
||||
mcp_servers:
|
||||
explicit-server:
|
||||
url: http://explicit:8000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
config = load_mcp_config(f.name)
|
||||
assert "explicit-server" in config.mcp_servers
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_with_env_var(self):
|
||||
"""Test loading config from environment variable path."""
|
||||
yaml_content = """
|
||||
mcp_servers:
|
||||
env-server:
|
||||
url: http://env:8000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
os.environ["MCP_CONFIG_PATH"] = f.name
|
||||
try:
|
||||
config = load_mcp_config()
|
||||
assert "env-server" in config.mcp_servers
|
||||
finally:
|
||||
del os.environ["MCP_CONFIG_PATH"]
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_returns_empty_config_if_missing(self):
|
||||
"""Test that missing file returns empty config."""
|
||||
os.environ.pop("MCP_CONFIG_PATH", None)
|
||||
config = load_mcp_config("/nonexistent/path/config.yaml")
|
||||
assert config.mcp_servers == {}
|
||||
|
||||
|
||||
class TestCreateDefaultConfig:
|
||||
"""Tests for create_default_config function."""
|
||||
|
||||
def test_creates_standard_servers(self):
|
||||
"""Test that default config has standard servers."""
|
||||
config = create_default_config()
|
||||
|
||||
assert "llm-gateway" in config.mcp_servers
|
||||
assert "knowledge-base" in config.mcp_servers
|
||||
assert "git-ops" in config.mcp_servers
|
||||
assert "issues" in config.mcp_servers
|
||||
|
||||
def test_servers_have_correct_defaults(self):
|
||||
"""Test that servers have correct default values."""
|
||||
config = create_default_config()
|
||||
|
||||
llm = config.mcp_servers["llm-gateway"]
|
||||
assert llm.timeout == 60 # LLM has longer timeout
|
||||
assert llm.transport == TransportType.HTTP
|
||||
|
||||
git = config.mcp_servers["git-ops"]
|
||||
assert git.timeout == 120 # Git ops has longest timeout
|
||||
|
||||
def test_servers_are_enabled(self):
|
||||
"""Test that all default servers are enabled."""
|
||||
config = create_default_config()
|
||||
|
||||
for name, server in config.mcp_servers.items():
|
||||
assert server.enabled is True, f"Server {name} should be enabled"
|
||||
405
backend/tests/services/mcp/test_connection.py
Normal file
405
backend/tests/services/mcp/test_connection.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Tests for MCP Connection Management
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPServerConfig, TransportType
|
||||
from app.services.mcp.connection import (
|
||||
ConnectionPool,
|
||||
ConnectionState,
|
||||
MCPConnection,
|
||||
)
|
||||
from app.services.mcp.exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
"""Create a sample server configuration."""
|
||||
return MCPServerConfig(
|
||||
url="http://localhost:8000",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
retry_attempts=3,
|
||||
retry_delay=0.1, # Short delay for tests
|
||||
retry_max_delay=1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestConnectionState:
|
||||
"""Tests for ConnectionState enum."""
|
||||
|
||||
def test_connection_states(self):
|
||||
"""Test all connection states are defined."""
|
||||
assert ConnectionState.DISCONNECTED == "disconnected"
|
||||
assert ConnectionState.CONNECTING == "connecting"
|
||||
assert ConnectionState.CONNECTED == "connected"
|
||||
assert ConnectionState.RECONNECTING == "reconnecting"
|
||||
assert ConnectionState.ERROR == "error"
|
||||
|
||||
|
||||
class TestMCPConnection:
|
||||
"""Tests for MCPConnection class."""
|
||||
|
||||
def test_initial_state(self, server_config):
|
||||
"""Test initial connection state."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
assert conn.server_name == "test-server"
|
||||
assert conn.config == server_config
|
||||
assert conn.state == ConnectionState.DISCONNECTED
|
||||
assert conn.is_connected is False
|
||||
assert conn.last_error is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_success(self, server_config):
|
||||
"""Test successful connection."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
|
||||
assert conn.state == ConnectionState.CONNECTED
|
||||
assert conn.is_connected is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_404_capabilities_ok(self, server_config):
|
||||
"""Test connection succeeds even if /mcp/capabilities returns 404."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
|
||||
assert conn.state == ConnectionState.CONNECTED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_failure_with_retry(self, server_config):
|
||||
"""Test connection failure with retries."""
|
||||
# Reduce retry attempts for faster test
|
||||
server_config.retry_attempts = 2
|
||||
server_config.retry_delay = 0.01
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=httpx.ConnectError("Connection refused")
|
||||
)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
with pytest.raises(MCPConnectionError) as exc_info:
|
||||
await conn.connect()
|
||||
|
||||
assert conn.state == ConnectionState.ERROR
|
||||
assert "Failed to connect after" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect(self, server_config):
|
||||
"""Test disconnection."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
assert conn.is_connected is True
|
||||
|
||||
await conn.disconnect()
|
||||
assert conn.state == ConnectionState.DISCONNECTED
|
||||
assert conn.is_connected is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect(self, server_config):
|
||||
"""Test reconnection."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
assert conn.is_connected is True
|
||||
|
||||
await conn.reconnect()
|
||||
assert conn.is_connected is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self, server_config):
|
||||
"""Test successful health check."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
healthy = await conn.health_check()
|
||||
|
||||
assert healthy is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_disconnected(self, server_config):
|
||||
"""Test health check when disconnected."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
healthy = await conn.health_check()
|
||||
assert healthy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_get(self, server_config):
|
||||
"""Test executing GET request."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_connect_response = MagicMock()
|
||||
mock_connect_response.status_code = 200
|
||||
|
||||
mock_request_response = MagicMock()
|
||||
mock_request_response.status_code = 200
|
||||
mock_request_response.json.return_value = {"tools": []}
|
||||
mock_request_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[mock_connect_response, mock_request_response]
|
||||
)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
result = await conn.execute_request("GET", "/mcp/tools")
|
||||
|
||||
assert result == {"tools": []}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_post(self, server_config):
|
||||
"""Test executing POST request."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_connect_response = MagicMock()
|
||||
mock_connect_response.status_code = 200
|
||||
|
||||
mock_request_response = MagicMock()
|
||||
mock_request_response.status_code = 200
|
||||
mock_request_response.json.return_value = {"result": "success"}
|
||||
mock_request_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_connect_response)
|
||||
mock_client.post = AsyncMock(return_value=mock_request_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
result = await conn.execute_request(
|
||||
"POST", "/mcp", data={"method": "test"}
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_timeout(self, server_config):
|
||||
"""Test request timeout."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
mock_connect_response = MagicMock()
|
||||
mock_connect_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[
|
||||
mock_connect_response,
|
||||
httpx.TimeoutException("Request timeout"),
|
||||
]
|
||||
)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
|
||||
with pytest.raises(MCPTimeoutError) as exc_info:
|
||||
await conn.execute_request("GET", "/slow-endpoint")
|
||||
|
||||
assert "timeout" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_request_not_connected(self, server_config):
|
||||
"""Test request when not connected."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
with pytest.raises(MCPConnectionError) as exc_info:
|
||||
await conn.execute_request("GET", "/test")
|
||||
|
||||
assert "Not connected" in str(exc_info.value)
|
||||
|
||||
def test_backoff_delay_calculation(self, server_config):
|
||||
"""Test exponential backoff delay calculation."""
|
||||
conn = MCPConnection("test-server", server_config)
|
||||
|
||||
# First attempt
|
||||
conn._connection_attempts = 1
|
||||
delay1 = conn._calculate_backoff_delay()
|
||||
|
||||
# Second attempt
|
||||
conn._connection_attempts = 2
|
||||
delay2 = conn._calculate_backoff_delay()
|
||||
|
||||
# Third attempt
|
||||
conn._connection_attempts = 3
|
||||
delay3 = conn._calculate_backoff_delay()
|
||||
|
||||
# Delays should generally increase (with some jitter)
|
||||
# Base is 0.1, so rough expectations:
|
||||
# Attempt 1: ~0.1s
|
||||
# Attempt 2: ~0.2s
|
||||
# Attempt 3: ~0.4s
|
||||
assert delay1 > 0
|
||||
assert delay2 > delay1 * 0.5 # Allow for jitter
|
||||
assert delay3 <= server_config.retry_max_delay * 1.25 # Within max + jitter
|
||||
|
||||
|
||||
class TestConnectionPool:
|
||||
"""Tests for ConnectionPool class."""
|
||||
|
||||
@pytest.fixture
|
||||
def pool(self):
|
||||
"""Create a connection pool."""
|
||||
return ConnectionPool(max_connections_per_server=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_creates_new(self, pool, server_config):
|
||||
"""Test getting connection creates new one if not exists."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
conn = await pool.get_connection("test-server", server_config)
|
||||
|
||||
assert conn is not None
|
||||
assert conn.is_connected is True
|
||||
assert conn.server_name == "test-server"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_reuses_existing(self, pool, server_config):
|
||||
"""Test getting connection reuses existing one."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
conn1 = await pool.get_connection("test-server", server_config)
|
||||
conn2 = await pool.get_connection("test-server", server_config)
|
||||
|
||||
assert conn1 is conn2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_connection(self, pool, server_config):
|
||||
"""Test closing a connection."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await pool.get_connection("test-server", server_config)
|
||||
await pool.close_connection("test-server")
|
||||
|
||||
assert "test-server" not in pool._connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all(self, pool, server_config):
|
||||
"""Test closing all connections."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.aclose = AsyncMock()
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
config2 = MCPServerConfig(url="http://server2:8000")
|
||||
await pool.get_connection("server-1", server_config)
|
||||
await pool.get_connection("server-2", config2)
|
||||
|
||||
assert len(pool._connections) == 2
|
||||
|
||||
await pool.close_all()
|
||||
|
||||
assert len(pool._connections) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_all(self, pool, server_config):
|
||||
"""Test health check on all connections."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await pool.get_connection("test-server", server_config)
|
||||
results = await pool.health_check_all()
|
||||
|
||||
assert "test-server" in results
|
||||
assert results["test-server"] is True
|
||||
|
||||
def test_get_status(self, pool, server_config):
|
||||
"""Test getting pool status."""
|
||||
status = pool.get_status()
|
||||
assert status == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_context_manager(self, pool, server_config):
|
||||
"""Test connection context manager."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.AsyncClient") as MockClient:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
async with pool.connection("test-server", server_config) as conn:
|
||||
assert conn.is_connected is True
|
||||
assert conn.server_name == "test-server"
|
||||
259
backend/tests/services/mcp/test_exceptions.py
Normal file
259
backend/tests/services/mcp/test_exceptions.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Tests for MCP Exception Classes
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
MCPValidationError,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPError:
|
||||
"""Tests for base MCPError class."""
|
||||
|
||||
def test_basic_error(self):
|
||||
"""Test basic error creation."""
|
||||
error = MCPError("Test error")
|
||||
assert error.message == "Test error"
|
||||
assert error.server_name is None
|
||||
assert error.details == {}
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_error_with_server_name(self):
|
||||
"""Test error with server name."""
|
||||
error = MCPError("Test error", server_name="test-server")
|
||||
assert error.server_name == "test-server"
|
||||
assert "server=test-server" in str(error)
|
||||
|
||||
def test_error_with_details(self):
|
||||
"""Test error with additional details."""
|
||||
error = MCPError(
|
||||
"Test error",
|
||||
server_name="test-server",
|
||||
details={"key": "value"},
|
||||
)
|
||||
assert error.details == {"key": "value"}
|
||||
assert "details={'key': 'value'}" in str(error)
|
||||
|
||||
|
||||
class TestMCPConnectionError:
|
||||
"""Tests for MCPConnectionError class."""
|
||||
|
||||
def test_basic_connection_error(self):
|
||||
"""Test basic connection error."""
|
||||
error = MCPConnectionError("Connection failed")
|
||||
assert error.message == "Connection failed"
|
||||
assert error.url is None
|
||||
assert error.cause is None
|
||||
|
||||
def test_connection_error_with_url(self):
|
||||
"""Test connection error with URL."""
|
||||
error = MCPConnectionError(
|
||||
"Connection failed",
|
||||
server_name="test-server",
|
||||
url="http://localhost:8000",
|
||||
)
|
||||
assert error.url == "http://localhost:8000"
|
||||
assert "url=http://localhost:8000" in str(error)
|
||||
|
||||
def test_connection_error_with_cause(self):
|
||||
"""Test connection error with cause."""
|
||||
cause = ConnectionError("Network error")
|
||||
error = MCPConnectionError(
|
||||
"Connection failed",
|
||||
cause=cause,
|
||||
)
|
||||
assert error.cause is cause
|
||||
assert "ConnectionError" in str(error)
|
||||
|
||||
|
||||
class TestMCPTimeoutError:
|
||||
"""Tests for MCPTimeoutError class."""
|
||||
|
||||
def test_basic_timeout_error(self):
|
||||
"""Test basic timeout error."""
|
||||
error = MCPTimeoutError("Request timed out")
|
||||
assert error.message == "Request timed out"
|
||||
assert error.timeout_seconds is None
|
||||
assert error.operation is None
|
||||
|
||||
def test_timeout_error_with_details(self):
|
||||
"""Test timeout error with details."""
|
||||
error = MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name="test-server",
|
||||
timeout_seconds=30.0,
|
||||
operation="POST /mcp",
|
||||
)
|
||||
assert error.timeout_seconds == 30.0
|
||||
assert error.operation == "POST /mcp"
|
||||
assert "timeout=30.0s" in str(error)
|
||||
assert "operation=POST /mcp" in str(error)
|
||||
|
||||
|
||||
class TestMCPToolError:
|
||||
"""Tests for MCPToolError class."""
|
||||
|
||||
def test_basic_tool_error(self):
|
||||
"""Test basic tool error."""
|
||||
error = MCPToolError("Tool execution failed")
|
||||
assert error.message == "Tool execution failed"
|
||||
assert error.tool_name is None
|
||||
assert error.tool_args is None
|
||||
assert error.error_code is None
|
||||
|
||||
def test_tool_error_with_details(self):
|
||||
"""Test tool error with all details."""
|
||||
error = MCPToolError(
|
||||
"Tool execution failed",
|
||||
server_name="llm-gateway",
|
||||
tool_name="chat",
|
||||
tool_args={"prompt": "Hello"},
|
||||
error_code="INVALID_ARGS",
|
||||
)
|
||||
assert error.tool_name == "chat"
|
||||
assert error.tool_args == {"prompt": "Hello"}
|
||||
assert error.error_code == "INVALID_ARGS"
|
||||
assert "tool=chat" in str(error)
|
||||
assert "error_code=INVALID_ARGS" in str(error)
|
||||
|
||||
|
||||
class TestMCPServerNotFoundError:
|
||||
"""Tests for MCPServerNotFoundError class."""
|
||||
|
||||
def test_server_not_found(self):
|
||||
"""Test server not found error."""
|
||||
error = MCPServerNotFoundError("unknown-server")
|
||||
assert error.server_name == "unknown-server"
|
||||
assert "MCP server not found: unknown-server" in error.message
|
||||
assert error.available_servers == []
|
||||
|
||||
def test_server_not_found_with_available(self):
|
||||
"""Test server not found with available servers listed."""
|
||||
error = MCPServerNotFoundError(
|
||||
"unknown-server",
|
||||
available_servers=["server-1", "server-2"],
|
||||
)
|
||||
assert error.available_servers == ["server-1", "server-2"]
|
||||
assert "available=['server-1', 'server-2']" in str(error)
|
||||
|
||||
|
||||
class TestMCPToolNotFoundError:
|
||||
"""Tests for MCPToolNotFoundError class."""
|
||||
|
||||
def test_tool_not_found(self):
|
||||
"""Test tool not found error."""
|
||||
error = MCPToolNotFoundError("unknown-tool")
|
||||
assert error.tool_name == "unknown-tool"
|
||||
assert "Tool not found: unknown-tool" in error.message
|
||||
assert error.available_tools == []
|
||||
|
||||
def test_tool_not_found_with_available(self):
|
||||
"""Test tool not found with available tools listed."""
|
||||
error = MCPToolNotFoundError(
|
||||
"unknown-tool",
|
||||
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
|
||||
)
|
||||
assert len(error.available_tools) == 6
|
||||
# Should show first 5 tools with ellipsis
|
||||
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
|
||||
|
||||
|
||||
class TestMCPCircuitOpenError:
|
||||
"""Tests for MCPCircuitOpenError class."""
|
||||
|
||||
def test_circuit_open_error(self):
|
||||
"""Test circuit open error."""
|
||||
error = MCPCircuitOpenError("test-server")
|
||||
assert error.server_name == "test-server"
|
||||
assert "Circuit breaker open for server: test-server" in error.message
|
||||
assert error.failure_count is None
|
||||
assert error.reset_timeout is None
|
||||
|
||||
def test_circuit_open_error_with_details(self):
|
||||
"""Test circuit open error with details."""
|
||||
error = MCPCircuitOpenError(
|
||||
"test-server",
|
||||
failure_count=5,
|
||||
reset_timeout=30.0,
|
||||
)
|
||||
assert error.failure_count == 5
|
||||
assert error.reset_timeout == 30.0
|
||||
assert "failures=5" in str(error)
|
||||
assert "reset_in=30.0s" in str(error)
|
||||
|
||||
|
||||
class TestMCPValidationError:
|
||||
"""Tests for MCPValidationError class."""
|
||||
|
||||
def test_validation_error(self):
|
||||
"""Test validation error."""
|
||||
error = MCPValidationError("Validation failed")
|
||||
assert error.message == "Validation failed"
|
||||
assert error.tool_name is None
|
||||
assert error.field_errors == {}
|
||||
|
||||
def test_validation_error_with_details(self):
|
||||
"""Test validation error with field errors."""
|
||||
error = MCPValidationError(
|
||||
"Validation failed",
|
||||
tool_name="create_issue",
|
||||
field_errors={
|
||||
"title": "Title is required",
|
||||
"priority": "Invalid priority value",
|
||||
},
|
||||
)
|
||||
assert error.tool_name == "create_issue"
|
||||
assert error.field_errors == {
|
||||
"title": "Title is required",
|
||||
"priority": "Invalid priority value",
|
||||
}
|
||||
assert "tool=create_issue" in str(error)
|
||||
assert "fields=['title', 'priority']" in str(error)
|
||||
|
||||
|
||||
class TestExceptionInheritance:
|
||||
"""Tests for exception inheritance chain."""
|
||||
|
||||
def test_all_errors_inherit_from_mcp_error(self):
|
||||
"""Test that all custom exceptions inherit from MCPError."""
|
||||
assert issubclass(MCPConnectionError, MCPError)
|
||||
assert issubclass(MCPTimeoutError, MCPError)
|
||||
assert issubclass(MCPToolError, MCPError)
|
||||
assert issubclass(MCPServerNotFoundError, MCPError)
|
||||
assert issubclass(MCPToolNotFoundError, MCPError)
|
||||
assert issubclass(MCPCircuitOpenError, MCPError)
|
||||
assert issubclass(MCPValidationError, MCPError)
|
||||
|
||||
def test_all_errors_inherit_from_exception(self):
|
||||
"""Test that base MCPError inherits from Exception."""
|
||||
assert issubclass(MCPError, Exception)
|
||||
|
||||
def test_catch_all_with_mcp_error(self):
|
||||
"""Test that all errors can be caught with MCPError."""
|
||||
|
||||
def raise_connection_error():
|
||||
raise MCPConnectionError("Connection failed")
|
||||
|
||||
def raise_timeout_error():
|
||||
raise MCPTimeoutError("Timeout")
|
||||
|
||||
def raise_tool_error():
|
||||
raise MCPToolError("Tool failed")
|
||||
|
||||
with pytest.raises(MCPError):
|
||||
raise_connection_error()
|
||||
|
||||
with pytest.raises(MCPError):
|
||||
raise_timeout_error()
|
||||
|
||||
with pytest.raises(MCPError):
|
||||
raise_tool_error()
|
||||
272
backend/tests/services/mcp/test_registry.py
Normal file
272
backend/tests/services/mcp/test_registry.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Tests for MCP Server Registry
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
|
||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||
from app.services.mcp.registry import (
|
||||
MCPServerRegistry,
|
||||
ServerCapabilities,
|
||||
get_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
"""Create a sample MCP configuration."""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
timeout=30,
|
||||
enabled=True,
|
||||
),
|
||||
"server-2": MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
timeout=60,
|
||||
enabled=True,
|
||||
),
|
||||
"disabled-server": MCPServerConfig(
|
||||
url="http://disabled:8000",
|
||||
enabled=False,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestServerCapabilities:
|
||||
"""Tests for ServerCapabilities class."""
|
||||
|
||||
def test_empty_capabilities(self):
|
||||
"""Test creating empty capabilities."""
|
||||
caps = ServerCapabilities()
|
||||
assert caps.tools == []
|
||||
assert caps.resources == []
|
||||
assert caps.prompts == []
|
||||
assert caps.is_loaded is False
|
||||
assert caps.tool_names == []
|
||||
|
||||
def test_capabilities_with_tools(self):
|
||||
"""Test capabilities with tools."""
|
||||
caps = ServerCapabilities(
|
||||
tools=[
|
||||
{"name": "tool1", "description": "Tool 1"},
|
||||
{"name": "tool2", "description": "Tool 2"},
|
||||
]
|
||||
)
|
||||
assert len(caps.tools) == 2
|
||||
assert caps.tool_names == ["tool1", "tool2"]
|
||||
|
||||
def test_mark_loaded(self):
|
||||
"""Test marking capabilities as loaded."""
|
||||
caps = ServerCapabilities()
|
||||
assert caps.is_loaded is False
|
||||
assert caps._load_time is None
|
||||
|
||||
caps.mark_loaded()
|
||||
assert caps.is_loaded is True
|
||||
assert caps._load_time is not None
|
||||
|
||||
|
||||
class TestMCPServerRegistry:
|
||||
"""Tests for MCPServerRegistry singleton."""
|
||||
|
||||
def test_singleton_pattern(self, reset_registry):
|
||||
"""Test that registry is a singleton."""
|
||||
registry1 = MCPServerRegistry()
|
||||
registry2 = MCPServerRegistry()
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_get_instance(self, reset_registry):
|
||||
"""Test get_instance class method."""
|
||||
registry = MCPServerRegistry.get_instance()
|
||||
assert registry is MCPServerRegistry()
|
||||
|
||||
def test_reset_instance(self, reset_registry):
|
||||
"""Test resetting singleton instance."""
|
||||
registry1 = MCPServerRegistry()
|
||||
MCPServerRegistry.reset_instance()
|
||||
registry2 = MCPServerRegistry()
|
||||
assert registry1 is not registry2
|
||||
|
||||
def test_load_config(self, reset_registry, sample_config):
|
||||
"""Test loading configuration."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
assert len(registry.list_servers()) == 3
|
||||
assert "server-1" in registry.list_servers()
|
||||
assert "server-2" in registry.list_servers()
|
||||
assert "disabled-server" in registry.list_servers()
|
||||
|
||||
def test_list_enabled_servers(self, reset_registry, sample_config):
|
||||
"""Test listing only enabled servers."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
enabled = registry.list_enabled_servers()
|
||||
assert len(enabled) == 2
|
||||
assert "server-1" in enabled
|
||||
assert "server-2" in enabled
|
||||
assert "disabled-server" not in enabled
|
||||
|
||||
def test_register(self, reset_registry):
|
||||
"""Test registering a new server."""
|
||||
registry = MCPServerRegistry()
|
||||
config = MCPServerConfig(url="http://new:8000")
|
||||
|
||||
registry.register("new-server", config)
|
||||
assert "new-server" in registry.list_servers()
|
||||
assert registry.get("new-server").url == "http://new:8000"
|
||||
|
||||
def test_unregister(self, reset_registry, sample_config):
|
||||
"""Test unregistering a server."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
assert registry.unregister("server-1") is True
|
||||
assert "server-1" not in registry.list_servers()
|
||||
|
||||
# Unregistering non-existent server returns False
|
||||
assert registry.unregister("nonexistent") is False
|
||||
|
||||
def test_get(self, reset_registry, sample_config):
|
||||
"""Test getting server config."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
config = registry.get("server-1")
|
||||
assert config.url == "http://server1:8000"
|
||||
assert config.timeout == 30
|
||||
|
||||
def test_get_not_found(self, reset_registry, sample_config):
|
||||
"""Test getting non-existent server raises error."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
with pytest.raises(MCPServerNotFoundError) as exc_info:
|
||||
registry.get("nonexistent")
|
||||
|
||||
assert exc_info.value.server_name == "nonexistent"
|
||||
assert "server-1" in exc_info.value.available_servers
|
||||
|
||||
def test_get_or_none(self, reset_registry, sample_config):
|
||||
"""Test get_or_none method."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
config = registry.get_or_none("server-1")
|
||||
assert config is not None
|
||||
|
||||
config = registry.get_or_none("nonexistent")
|
||||
assert config is None
|
||||
|
||||
def test_get_all_configs(self, reset_registry, sample_config):
|
||||
"""Test getting all configs."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
configs = registry.get_all_configs()
|
||||
assert len(configs) == 3
|
||||
|
||||
def test_get_enabled_configs(self, reset_registry, sample_config):
|
||||
"""Test getting enabled configs."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
configs = registry.get_enabled_configs()
|
||||
assert len(configs) == 2
|
||||
assert "disabled-server" not in configs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_capabilities(self, reset_registry, sample_config):
|
||||
"""Test getting server capabilities."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
# Initially empty capabilities
|
||||
caps = await registry.get_capabilities("server-1")
|
||||
assert caps.is_loaded is False
|
||||
|
||||
def test_set_capabilities(self, reset_registry, sample_config):
|
||||
"""Test setting server capabilities."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
resources=[{"name": "resource1"}],
|
||||
)
|
||||
|
||||
caps = registry._capabilities["server-1"]
|
||||
assert len(caps.tools) == 2
|
||||
assert len(caps.resources) == 1
|
||||
assert caps.is_loaded is True
|
||||
|
||||
def test_find_server_for_tool(self, reset_registry, sample_config):
|
||||
"""Test finding server that provides a tool."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
registry.set_capabilities(
|
||||
"server-2",
|
||||
tools=[{"name": "tool3"}],
|
||||
)
|
||||
|
||||
assert registry.find_server_for_tool("tool1") == "server-1"
|
||||
assert registry.find_server_for_tool("tool3") == "server-2"
|
||||
assert registry.find_server_for_tool("unknown") is None
|
||||
|
||||
def test_get_all_tools(self, reset_registry, sample_config):
|
||||
"""Test getting all tools from all servers."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool1"}],
|
||||
)
|
||||
registry.set_capabilities(
|
||||
"server-2",
|
||||
tools=[{"name": "tool2"}, {"name": "tool3"}],
|
||||
)
|
||||
|
||||
all_tools = registry.get_all_tools()
|
||||
assert len(all_tools) == 2
|
||||
assert len(all_tools["server-1"]) == 1
|
||||
assert len(all_tools["server-2"]) == 2
|
||||
|
||||
def test_global_config_property(self, reset_registry, sample_config):
|
||||
"""Test accessing global config."""
|
||||
registry = MCPServerRegistry()
|
||||
registry.load_config(sample_config)
|
||||
|
||||
global_config = registry.global_config
|
||||
assert global_config is not None
|
||||
assert len(global_config.mcp_servers) == 3
|
||||
|
||||
|
||||
class TestGetRegistry:
|
||||
"""Tests for get_registry convenience function."""
|
||||
|
||||
def test_get_registry_returns_singleton(self, reset_registry):
|
||||
"""Test that get_registry returns the singleton."""
|
||||
registry1 = get_registry()
|
||||
registry2 = get_registry()
|
||||
assert registry1 is registry2
|
||||
assert registry1 is MCPServerRegistry()
|
||||
345
backend/tests/services/mcp/test_routing.py
Normal file
345
backend/tests/services/mcp/test_routing.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Tests for MCP Tool Call Routing
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.connection import ConnectionPool
|
||||
from app.services.mcp.exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from app.services.mcp.registry import MCPServerRegistry
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry(reset_registry):
|
||||
"""Create a configured registry."""
|
||||
reg = MCPServerRegistry()
|
||||
reg.load_config(
|
||||
MCPConfig(
|
||||
mcp_servers={
|
||||
"server-1": MCPServerConfig(
|
||||
url="http://server1:8000",
|
||||
retry_attempts=1,
|
||||
retry_delay=0.1, # Minimum allowed value
|
||||
circuit_breaker_threshold=3,
|
||||
circuit_breaker_timeout=5.0,
|
||||
),
|
||||
"server-2": MCPServerConfig(
|
||||
url="http://server2:8000",
|
||||
retry_attempts=1,
|
||||
retry_delay=0.1, # Minimum allowed value
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
return reg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool():
|
||||
"""Create a connection pool."""
|
||||
return ConnectionPool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router(registry, pool):
|
||||
"""Create a tool router."""
|
||||
return ToolRouter(registry, pool)
|
||||
|
||||
|
||||
class TestToolInfo:
|
||||
"""Tests for ToolInfo dataclass."""
|
||||
|
||||
def test_basic_tool_info(self):
|
||||
"""Test creating basic tool info."""
|
||||
info = ToolInfo(name="test-tool")
|
||||
assert info.name == "test-tool"
|
||||
assert info.description is None
|
||||
assert info.server_name is None
|
||||
assert info.input_schema is None
|
||||
|
||||
def test_full_tool_info(self):
|
||||
"""Test creating full tool info."""
|
||||
info = ToolInfo(
|
||||
name="create_issue",
|
||||
description="Create a new issue",
|
||||
server_name="issues",
|
||||
input_schema={"type": "object", "properties": {"title": {"type": "string"}}},
|
||||
)
|
||||
assert info.name == "create_issue"
|
||||
assert info.description == "Create a new issue"
|
||||
assert info.server_name == "issues"
|
||||
assert "properties" in info.input_schema
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting to dictionary."""
|
||||
info = ToolInfo(
|
||||
name="test-tool",
|
||||
description="A test tool",
|
||||
server_name="test-server",
|
||||
)
|
||||
result = info.to_dict()
|
||||
|
||||
assert result["name"] == "test-tool"
|
||||
assert result["description"] == "A test tool"
|
||||
assert result["server_name"] == "test-server"
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""Tests for ToolResult dataclass."""
|
||||
|
||||
def test_success_result(self):
|
||||
"""Test creating success result."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"id": "123"},
|
||||
tool_name="create_issue",
|
||||
server_name="issues",
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.data == {"id": "123"}
|
||||
assert result.error is None
|
||||
|
||||
def test_error_result(self):
|
||||
"""Test creating error result."""
|
||||
result = ToolResult(
|
||||
success=False,
|
||||
error="Tool execution failed",
|
||||
error_code="INTERNAL_ERROR",
|
||||
tool_name="create_issue",
|
||||
server_name="issues",
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error == "Tool execution failed"
|
||||
assert result.error_code == "INTERNAL_ERROR"
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting to dictionary."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"result": "ok"},
|
||||
tool_name="test",
|
||||
execution_time_ms=123.45,
|
||||
)
|
||||
d = result.to_dict()
|
||||
|
||||
assert d["success"] is True
|
||||
assert d["data"] == {"result": "ok"}
|
||||
assert d["tool_name"] == "test"
|
||||
assert d["execution_time_ms"] == 123.45
|
||||
assert "request_id" in d # Auto-generated
|
||||
|
||||
|
||||
class TestToolRouter:
|
||||
"""Tests for ToolRouter class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_tool_mapping(self, router):
|
||||
"""Test registering tool mappings."""
|
||||
await router.register_tool_mapping("tool1", "server-1")
|
||||
await router.register_tool_mapping("tool2", "server-2")
|
||||
|
||||
assert router.find_server_for_tool("tool1") == "server-1"
|
||||
assert router.find_server_for_tool("tool2") == "server-2"
|
||||
|
||||
def test_find_server_for_unknown_tool(self, router):
|
||||
"""Test finding server for unknown tool."""
|
||||
result = router.find_server_for_tool("unknown-tool")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_success(self, router, registry):
|
||||
"""Test successful tool call."""
|
||||
# Set up capabilities
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "test-tool"}],
|
||||
)
|
||||
await router.register_tool_mapping("test-tool", "server-1")
|
||||
|
||||
# Mock the pool connection and request
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={"result": {"status": "ok"}}
|
||||
)
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
result = await router.call_tool(
|
||||
server_name="server-1",
|
||||
tool_name="test-tool",
|
||||
arguments={"param": "value"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data == {"status": "ok"}
|
||||
assert result.tool_name == "test-tool"
|
||||
assert result.server_name == "server-1"
|
||||
assert result.execution_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_error_response(self, router, registry):
|
||||
"""Test tool call with error response."""
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "test-tool"}],
|
||||
)
|
||||
await router.register_tool_mapping("test-tool", "server-1")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={
|
||||
"error": {
|
||||
"code": -32000,
|
||||
"message": "Tool execution failed",
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
result = await router.call_tool(
|
||||
server_name="server-1",
|
||||
tool_name="test-tool",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Tool execution failed" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_tool(self, router, registry):
|
||||
"""Test routing tool to correct server."""
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[{"name": "tool-on-server-1"}],
|
||||
)
|
||||
await router.register_tool_mapping("tool-on-server-1", "server-1")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={"result": "routed"}
|
||||
)
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
result = await router.route_tool(
|
||||
tool_name="tool-on-server-1",
|
||||
arguments={"key": "value"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.server_name == "server-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_tool_not_found(self, router):
|
||||
"""Test routing unknown tool raises error."""
|
||||
with pytest.raises(MCPToolNotFoundError) as exc_info:
|
||||
await router.route_tool(
|
||||
tool_name="unknown-tool",
|
||||
arguments={},
|
||||
)
|
||||
|
||||
assert exc_info.value.tool_name == "unknown-tool"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_tools(self, router, registry):
|
||||
"""Test listing all tools."""
|
||||
registry.set_capabilities(
|
||||
"server-1",
|
||||
tools=[
|
||||
{"name": "tool1", "description": "Tool 1"},
|
||||
{"name": "tool2", "description": "Tool 2"},
|
||||
],
|
||||
)
|
||||
registry.set_capabilities(
|
||||
"server-2",
|
||||
tools=[{"name": "tool3", "description": "Tool 3"}],
|
||||
)
|
||||
|
||||
tools = await router.list_all_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "tool1" in tool_names
|
||||
assert "tool2" in tool_names
|
||||
assert "tool3" in tool_names
|
||||
|
||||
def test_circuit_breaker_status(self, router, registry):
|
||||
"""Test getting circuit breaker status."""
|
||||
# Initially no circuit breakers
|
||||
status = router.get_circuit_breaker_status()
|
||||
assert status == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_circuit_breaker(self, router, registry):
|
||||
"""Test resetting circuit breaker."""
|
||||
# Reset non-existent returns False
|
||||
result = await router.reset_circuit_breaker("server-1")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_tools(self, router, registry):
|
||||
"""Test tool discovery from servers."""
|
||||
# Create mocks for different servers
|
||||
mock_conn_1 = AsyncMock()
|
||||
mock_conn_1.execute_request = AsyncMock(
|
||||
return_value={
|
||||
"tools": [
|
||||
{"name": "discovered-tool", "description": "A discovered tool"},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_conn_1.server_name = "server-1"
|
||||
mock_conn_1.is_connected = True
|
||||
|
||||
mock_conn_2 = AsyncMock()
|
||||
mock_conn_2.execute_request = AsyncMock(
|
||||
return_value={"tools": []} # Empty for server-2
|
||||
)
|
||||
mock_conn_2.server_name = "server-2"
|
||||
mock_conn_2.is_connected = True
|
||||
|
||||
async def get_connection_side_effect(server_name, _config):
|
||||
if server_name == "server-1":
|
||||
return mock_conn_1
|
||||
return mock_conn_2
|
||||
|
||||
with patch.object(
|
||||
router._pool,
|
||||
"get_connection",
|
||||
side_effect=get_connection_side_effect,
|
||||
):
|
||||
await router.discover_tools()
|
||||
|
||||
# Check that tool mapping was registered
|
||||
server = router.find_server_for_tool("discovered-tool")
|
||||
assert server == "server-1"
|
||||
|
||||
def test_calculate_retry_delay(self, router, registry):
|
||||
"""Test retry delay calculation."""
|
||||
config = registry.get("server-1")
|
||||
|
||||
delay1 = router._calculate_retry_delay(1, config)
|
||||
delay2 = router._calculate_retry_delay(2, config)
|
||||
delay3 = router._calculate_retry_delay(3, config)
|
||||
|
||||
# Delays should increase with attempts
|
||||
assert delay1 > 0
|
||||
# Allow for jitter variation
|
||||
assert delay1 <= config.retry_max_delay * 1.25
|
||||
Reference in New Issue
Block a user