forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
395 lines
13 KiB
Python
395 lines
13 KiB
Python
"""
Tests for MCP Client Manager
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.services.mcp.client_manager import (
|
|
MCPClientManager,
|
|
ServerHealth,
|
|
get_mcp_client,
|
|
reset_mcp_client,
|
|
shutdown_mcp_client,
|
|
)
|
|
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
|
from app.services.mcp.exceptions import MCPServerNotFoundError
|
|
from app.services.mcp.registry import MCPServerRegistry
|
|
from app.services.mcp.routing import ToolInfo, ToolResult
|
|
|
|
|
|
@pytest.fixture
def reset_registry():
    """Give the requesting test a clean slate, then clean up after it.

    The registry singleton and the global MCP client are both reset before
    the test body runs and again once it has finished.
    """

    def _reset_globals():
        MCPServerRegistry.reset_instance()
        reset_mcp_client()

    _reset_globals()
    yield
    _reset_globals()
|
|
|
|
|
|
@pytest.fixture
def sample_config():
    """Build an MCP configuration holding two enabled servers.

    ``server-1`` uses a 30 second timeout and ``server-2`` a 60 second one.
    """
    first = MCPServerConfig(url="http://server1:8000", timeout=30, enabled=True)
    second = MCPServerConfig(url="http://server2:8000", timeout=60, enabled=True)
    return MCPConfig(mcp_servers={"server-1": first, "server-2": second})
|
|
|
|
|
|
class TestServerHealth:
    """Exercise ServerHealth: stored field values and dictionary export."""

    def test_healthy_server(self):
        """A healthy record keeps its tool count and carries no error."""
        report = ServerHealth(
            name="test-server",
            healthy=True,
            state="connected",
            url="http://test:8000",
            tools_count=5,
        )

        assert report.healthy is True
        assert report.error is None
        assert report.tools_count == 5

    def test_unhealthy_server(self):
        """An unhealthy record exposes the error text it was given."""
        report = ServerHealth(
            name="test-server",
            healthy=False,
            state="error",
            url="http://test:8000",
            error="Connection refused",
        )

        assert report.healthy is False
        assert report.error == "Connection refused"

    def test_to_dict(self):
        """to_dict() returns the constructor values under matching keys."""
        exported = ServerHealth(
            name="test-server",
            healthy=True,
            state="connected",
            url="http://test:8000",
            tools_count=3,
        ).to_dict()

        assert exported["name"] == "test-server"
        assert exported["healthy"] is True
        assert exported["state"] == "connected"
        assert exported["url"] == "http://test:8000"
        assert exported["tools_count"] == 3
|
|
|
|
|
|
class TestMCPClientManager:
    """Cover MCPClientManager lifecycle, tool calls, and server bookkeeping."""

    def test_initial_state(self, reset_registry):
        """A manager built with no arguments is not yet initialized."""
        assert MCPClientManager().is_initialized is False

    @pytest.mark.asyncio
    async def test_initialize(self, reset_registry, sample_config):
        """initialize() sets is_initialized when pool and router are patched."""
        mgr = MCPClientManager(config=sample_config)
        fake_conn = AsyncMock()
        fake_conn.is_connected = True

        pool_patch = patch.object(mgr._pool, "get_connection", return_value=fake_conn)
        router_patch = patch.object(mgr, "_router")
        with pool_patch, router_patch as router_stub:
            router_stub.discover_tools = AsyncMock()

            await mgr.initialize()

            assert mgr.is_initialized is True

    @pytest.mark.asyncio
    async def test_shutdown(self, reset_registry, sample_config):
        """shutdown() clears the initialized flag and closes the pool once."""
        mgr = MCPClientManager(config=sample_config)
        mgr._initialized = True

        with patch.object(mgr._pool, "close_all", return_value=None) as close_all:
            await mgr.shutdown()

            assert mgr.is_initialized is False
            close_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect(self, reset_registry, sample_config):
        """connect() on a configured server requests one pool connection."""
        mgr = MCPClientManager(config=sample_config)
        fake_conn = AsyncMock()
        fake_conn.is_connected = True

        with patch.object(
            mgr._pool, "get_connection", return_value=fake_conn
        ) as get_connection:
            await mgr.connect("server-1")

            get_connection.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_unknown_server(self, reset_registry, sample_config):
        """connect() raises MCPServerNotFoundError for an unconfigured name."""
        mgr = MCPClientManager(config=sample_config)

        with pytest.raises(MCPServerNotFoundError):
            await mgr.connect("unknown-server")

    @pytest.mark.asyncio
    async def test_disconnect(self, reset_registry, sample_config):
        """disconnect() forwards the server name to the pool's close_connection."""
        mgr = MCPClientManager(config=sample_config)

        with patch.object(mgr._pool, "close_connection") as close_connection:
            await mgr.disconnect("server-1")
            close_connection.assert_called_once_with("server-1")

    @pytest.mark.asyncio
    async def test_disconnect_all(self, reset_registry, sample_config):
        """disconnect_all() invokes the pool's close_all exactly once."""
        mgr = MCPClientManager(config=sample_config)

        with patch.object(mgr._pool, "close_all") as close_all:
            await mgr.disconnect_all()
            close_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_call_tool(self, reset_registry, sample_config):
        """call_tool() hands back the result produced by the router."""
        mgr = MCPClientManager(config=sample_config)
        mgr._initialized = True

        canned = ToolResult(
            success=True,
            data={"id": "123"},
            tool_name="create_issue",
            server_name="server-1",
        )
        router_stub = MagicMock()
        router_stub.call_tool = AsyncMock(return_value=canned)
        mgr._router = router_stub

        outcome = await mgr.call_tool(
            server="server-1",
            tool="create_issue",
            args={"title": "Test"},
        )

        assert outcome.success is True
        assert outcome.data == {"id": "123"}
        router_stub.call_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_route_tool(self, reset_registry, sample_config):
        """route_tool() hands back the result produced by the router."""
        mgr = MCPClientManager(config=sample_config)
        mgr._initialized = True

        canned = ToolResult(
            success=True,
            data={"result": "ok"},
            tool_name="auto_tool",
            server_name="server-2",
        )
        router_stub = MagicMock()
        router_stub.route_tool = AsyncMock(return_value=canned)
        mgr._router = router_stub

        outcome = await mgr.route_tool(tool="auto_tool", args={"key": "value"})

        assert outcome.success is True
        assert outcome.server_name == "server-2"
        router_stub.route_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_list_tools(self, reset_registry, sample_config):
        """list_tools() returns the tools stored in the registry for a server."""
        mgr = MCPClientManager(config=sample_config)

        # Seed the registry with two tool descriptions for server-1
        tool_specs = [
            {"name": "tool1", "description": "Tool 1"},
            {"name": "tool2", "description": "Tool 2"},
        ]
        mgr._registry.set_capabilities("server-1", tools=tool_specs)

        listed = await mgr.list_tools("server-1")

        assert len(listed) == 2
        assert listed[0].name == "tool1"
        assert listed[1].name == "tool2"

    @pytest.mark.asyncio
    async def test_list_all_tools(self, reset_registry, sample_config):
        """list_all_tools() returns every tool reported by the router."""
        mgr = MCPClientManager(config=sample_config)
        mgr._initialized = True

        canned = [
            ToolInfo(name="tool1", server_name="server-1"),
            ToolInfo(name="tool2", server_name="server-2"),
        ]
        router_stub = MagicMock()
        router_stub.list_all_tools = AsyncMock(return_value=canned)
        mgr._router = router_stub

        listed = await mgr.list_all_tools()

        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_health_check(self, reset_registry, sample_config):
        """health_check() reports per-server health from the pool's answers."""
        mgr = MCPClientManager(config=sample_config)

        status_patch = patch.object(
            mgr._pool,
            "get_status",
            return_value={
                "server-1": {"state": "connected"},
                "server-2": {"state": "disconnected"},
            },
        )
        probe_patch = patch.object(
            mgr._pool,
            "health_check_all",
            return_value={"server-1": True, "server-2": False},
        )
        with status_patch, probe_patch:
            report = await mgr.health_check()

            assert "server-1" in report
            assert "server-2" in report
            assert report["server-1"].healthy is True
            assert report["server-2"].healthy is False

    def test_list_servers(self, reset_registry, sample_config):
        """list_servers() includes both servers from the sample config."""
        names = MCPClientManager(config=sample_config).list_servers()

        for expected in ("server-1", "server-2"):
            assert expected in names

    def test_list_enabled_servers(self, reset_registry, sample_config):
        """list_enabled_servers() includes both servers from the sample config."""
        names = MCPClientManager(config=sample_config).list_enabled_servers()

        for expected in ("server-1", "server-2"):
            assert expected in names

    def test_get_server_config(self, reset_registry, sample_config):
        """get_server_config() returns the url and timeout set for server-1."""
        mgr = MCPClientManager(config=sample_config)

        server_cfg = mgr.get_server_config("server-1")

        assert server_cfg.url == "http://server1:8000"
        assert server_cfg.timeout == 30

    def test_get_server_config_not_found(self, reset_registry, sample_config):
        """get_server_config() raises MCPServerNotFoundError for an unknown name."""
        mgr = MCPClientManager(config=sample_config)

        with pytest.raises(MCPServerNotFoundError):
            mgr.get_server_config("unknown")

    def test_register_server(self, reset_registry, sample_config):
        """A server registered at runtime appears in list_servers()."""
        mgr = MCPClientManager(config=sample_config)

        mgr.register_server("new-server", MCPServerConfig(url="http://new:8000"))

        assert "new-server" in mgr.list_servers()

    def test_unregister_server(self, reset_registry, sample_config):
        """unregister_server() returns True on removal and False otherwise."""
        mgr = MCPClientManager(config=sample_config)

        removed = mgr.unregister_server("server-1")
        assert removed is True
        assert "server-1" not in mgr.list_servers()

        # A name that is not registered yields False
        removed = mgr.unregister_server("nonexistent")
        assert removed is False

    def test_circuit_breaker_status(self, reset_registry, sample_config):
        """get_circuit_breaker_status() is an empty dict on a fresh manager."""
        mgr = MCPClientManager(config=sample_config)

        # No router yet
        assert mgr.get_circuit_breaker_status() == {}

    @pytest.mark.asyncio
    async def test_reset_circuit_breaker(self, reset_registry, sample_config):
        """reset_circuit_breaker() returns False on a fresh manager."""
        mgr = MCPClientManager(config=sample_config)

        # No router yet
        was_reset = await mgr.reset_circuit_breaker("server-1")
        assert was_reset is False
|
|
|
|
|
|
class TestModuleLevelFunctions:
    """Cover the module-level helpers that manage the global MCP client."""

    @pytest.mark.asyncio
    async def test_get_mcp_client_creates_singleton(self, reset_registry):
        """Two consecutive get_mcp_client() calls yield the same object."""
        initialize_path = "app.services.mcp.client_manager.MCPClientManager.initialize"

        with patch(initialize_path, return_value=None):
            first = await get_mcp_client()
            second = await get_mcp_client()

            assert first is second

    @pytest.mark.asyncio
    async def test_shutdown_mcp_client(self, reset_registry):
        """shutdown_mcp_client() completes with the client's shutdown patched."""
        initialize_path = "app.services.mcp.client_manager.MCPClientManager.initialize"

        with patch(initialize_path, return_value=None):
            client = await get_mcp_client()

            with patch.object(client, "shutdown", return_value=None):
                await shutdown_mcp_client()

    def test_reset_mcp_client(self, reset_registry):
        """reset_mcp_client() can be called directly without raising."""
        # The test passes as long as this call does not raise
        reset_mcp_client()
|