""" Tests for MCP Client Manager """ from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.services.mcp.client_manager import ( MCPClientManager, ServerHealth, get_mcp_client, reset_mcp_client, shutdown_mcp_client, ) from app.services.mcp.config import MCPConfig, MCPServerConfig from app.services.mcp.connection import ConnectionState from app.services.mcp.exceptions import MCPServerNotFoundError from app.services.mcp.registry import MCPServerRegistry from app.services.mcp.routing import ToolInfo, ToolResult @pytest.fixture def reset_registry(): """Reset the singleton registry before and after each test.""" MCPServerRegistry.reset_instance() reset_mcp_client() yield MCPServerRegistry.reset_instance() reset_mcp_client() @pytest.fixture def sample_config(): """Create a sample MCP configuration.""" return MCPConfig( mcp_servers={ "server-1": MCPServerConfig( url="http://server1:8000", timeout=30, enabled=True, ), "server-2": MCPServerConfig( url="http://server2:8000", timeout=60, enabled=True, ), } ) class TestServerHealth: """Tests for ServerHealth dataclass.""" def test_healthy_server(self): """Test healthy server status.""" health = ServerHealth( name="test-server", healthy=True, state="connected", url="http://test:8000", tools_count=5, ) assert health.healthy is True assert health.error is None assert health.tools_count == 5 def test_unhealthy_server(self): """Test unhealthy server status.""" health = ServerHealth( name="test-server", healthy=False, state="error", url="http://test:8000", error="Connection refused", ) assert health.healthy is False assert health.error == "Connection refused" def test_to_dict(self): """Test converting to dictionary.""" health = ServerHealth( name="test-server", healthy=True, state="connected", url="http://test:8000", tools_count=3, ) d = health.to_dict() assert d["name"] == "test-server" assert d["healthy"] is True assert d["state"] == "connected" assert d["url"] == "http://test:8000" assert d["tools_count"] == 3 class TestMCPClientManager: """Tests for MCPClientManager class.""" def test_initial_state(self, reset_registry): """Test initial manager state.""" manager = MCPClientManager() assert manager.is_initialized is False @pytest.mark.asyncio async def test_initialize(self, reset_registry, sample_config): """Test manager initialization.""" manager = MCPClientManager(config=sample_config) with patch.object(manager._pool, "get_connection") as mock_get_conn: mock_conn = AsyncMock() mock_conn.is_connected = True mock_get_conn.return_value = mock_conn with patch.object(manager, "_router") as mock_router: mock_router.discover_tools = AsyncMock() await manager.initialize() assert manager.is_initialized is True @pytest.mark.asyncio async def test_shutdown(self, reset_registry, sample_config): """Test manager shutdown.""" manager = MCPClientManager(config=sample_config) manager._initialized = True with patch.object(manager._pool, "close_all") as mock_close: mock_close.return_value = None await manager.shutdown() assert manager.is_initialized is False mock_close.assert_called_once() @pytest.mark.asyncio async def test_connect(self, reset_registry, sample_config): """Test connecting to specific server.""" manager = MCPClientManager(config=sample_config) with patch.object(manager._pool, "get_connection") as mock_get_conn: mock_conn = AsyncMock() mock_conn.is_connected = True mock_get_conn.return_value = mock_conn await manager.connect("server-1") mock_get_conn.assert_called_once() @pytest.mark.asyncio async def test_connect_unknown_server(self, reset_registry, sample_config): """Test connecting to unknown server raises error.""" manager = MCPClientManager(config=sample_config) with pytest.raises(MCPServerNotFoundError): await manager.connect("unknown-server") @pytest.mark.asyncio async def test_disconnect(self, reset_registry, sample_config): """Test disconnecting from server.""" manager = MCPClientManager(config=sample_config) with patch.object(manager._pool, "close_connection") as mock_close: await manager.disconnect("server-1") mock_close.assert_called_once_with("server-1") @pytest.mark.asyncio async def test_disconnect_all(self, reset_registry, sample_config): """Test disconnecting from all servers.""" manager = MCPClientManager(config=sample_config) with patch.object(manager._pool, "close_all") as mock_close: await manager.disconnect_all() mock_close.assert_called_once() @pytest.mark.asyncio async def test_call_tool(self, reset_registry, sample_config): """Test calling a tool.""" manager = MCPClientManager(config=sample_config) manager._initialized = True expected_result = ToolResult( success=True, data={"id": "123"}, tool_name="create_issue", server_name="server-1", ) mock_router = MagicMock() mock_router.call_tool = AsyncMock(return_value=expected_result) manager._router = mock_router result = await manager.call_tool( server="server-1", tool="create_issue", args={"title": "Test"}, ) assert result.success is True assert result.data == {"id": "123"} mock_router.call_tool.assert_called_once() @pytest.mark.asyncio async def test_route_tool(self, reset_registry, sample_config): """Test routing a tool call.""" manager = MCPClientManager(config=sample_config) manager._initialized = True expected_result = ToolResult( success=True, data={"result": "ok"}, tool_name="auto_tool", server_name="server-2", ) mock_router = MagicMock() mock_router.route_tool = AsyncMock(return_value=expected_result) manager._router = mock_router result = await manager.route_tool( tool="auto_tool", args={"key": "value"}, ) assert result.success is True assert result.server_name == "server-2" mock_router.route_tool.assert_called_once() @pytest.mark.asyncio async def test_list_tools(self, reset_registry, sample_config): """Test listing tools for a server.""" manager = MCPClientManager(config=sample_config) # Set up capabilities in registry manager._registry.set_capabilities( "server-1", tools=[ {"name": "tool1", "description": "Tool 1"}, {"name": "tool2", "description": "Tool 2"}, ], ) tools = await manager.list_tools("server-1") assert len(tools) == 2 assert tools[0].name == "tool1" assert tools[1].name == "tool2" @pytest.mark.asyncio async def test_list_all_tools(self, reset_registry, sample_config): """Test listing all tools from all servers.""" manager = MCPClientManager(config=sample_config) manager._initialized = True expected_tools = [ ToolInfo(name="tool1", server_name="server-1"), ToolInfo(name="tool2", server_name="server-2"), ] mock_router = MagicMock() mock_router.list_all_tools = AsyncMock(return_value=expected_tools) manager._router = mock_router tools = await manager.list_all_tools() assert len(tools) == 2 @pytest.mark.asyncio async def test_health_check(self, reset_registry, sample_config): """Test health check on all servers.""" manager = MCPClientManager(config=sample_config) with patch.object(manager._pool, "get_status") as mock_status: mock_status.return_value = { "server-1": {"state": "connected"}, "server-2": {"state": "disconnected"}, } with patch.object(manager._pool, "health_check_all") as mock_health: mock_health.return_value = { "server-1": True, "server-2": False, } health = await manager.health_check() assert "server-1" in health assert "server-2" in health assert health["server-1"].healthy is True assert health["server-2"].healthy is False def test_list_servers(self, reset_registry, sample_config): """Test listing registered servers.""" manager = MCPClientManager(config=sample_config) servers = manager.list_servers() assert "server-1" in servers assert "server-2" in servers def test_list_enabled_servers(self, reset_registry, sample_config): """Test listing enabled servers.""" manager = MCPClientManager(config=sample_config) servers = manager.list_enabled_servers() assert "server-1" in servers assert "server-2" in servers def test_get_server_config(self, reset_registry, sample_config): """Test getting server configuration.""" manager = MCPClientManager(config=sample_config) config = manager.get_server_config("server-1") assert config.url == "http://server1:8000" assert config.timeout == 30 def test_get_server_config_not_found(self, reset_registry, sample_config): """Test getting unknown server config raises error.""" manager = MCPClientManager(config=sample_config) with pytest.raises(MCPServerNotFoundError): manager.get_server_config("unknown") def test_register_server(self, reset_registry, sample_config): """Test registering new server at runtime.""" manager = MCPClientManager(config=sample_config) new_config = MCPServerConfig(url="http://new:8000") manager.register_server("new-server", new_config) assert "new-server" in manager.list_servers() def test_unregister_server(self, reset_registry, sample_config): """Test unregistering a server.""" manager = MCPClientManager(config=sample_config) result = manager.unregister_server("server-1") assert result is True assert "server-1" not in manager.list_servers() # Unregistering non-existent returns False result = manager.unregister_server("nonexistent") assert result is False def test_circuit_breaker_status(self, reset_registry, sample_config): """Test getting circuit breaker status.""" manager = MCPClientManager(config=sample_config) # No router yet status = manager.get_circuit_breaker_status() assert status == {} @pytest.mark.asyncio async def test_reset_circuit_breaker(self, reset_registry, sample_config): """Test resetting circuit breaker.""" manager = MCPClientManager(config=sample_config) # No router yet result = await manager.reset_circuit_breaker("server-1") assert result is False class TestModuleLevelFunctions: """Tests for module-level convenience functions.""" @pytest.mark.asyncio async def test_get_mcp_client_creates_singleton(self, reset_registry): """Test get_mcp_client creates and returns singleton.""" with patch( "app.services.mcp.client_manager.MCPClientManager.initialize" ) as mock_init: mock_init.return_value = None client1 = await get_mcp_client() client2 = await get_mcp_client() assert client1 is client2 @pytest.mark.asyncio async def test_shutdown_mcp_client(self, reset_registry): """Test shutting down the global client.""" with patch( "app.services.mcp.client_manager.MCPClientManager.initialize" ) as mock_init: mock_init.return_value = None client = await get_mcp_client() with patch.object(client, "shutdown") as mock_shutdown: mock_shutdown.return_value = None await shutdown_mcp_client() def test_reset_mcp_client(self, reset_registry): """Test resetting the global client.""" reset_mcp_client() # Should not raise