""" Tests for MCP API Routes """ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import status from fastapi.testclient import TestClient from app.main import app from app.models.user import User from app.services.mcp import ( MCPCircuitOpenError, MCPClientManager, MCPConnectionError, MCPServerNotFoundError, MCPTimeoutError, MCPToolNotFoundError, ServerHealth, ) from app.services.mcp.config import MCPServerConfig, TransportType from app.services.mcp.routing import ToolInfo, ToolResult @pytest.fixture def mock_mcp_client(): """Create a mock MCP client manager.""" client = MagicMock(spec=MCPClientManager) client.is_initialized = True return client @pytest.fixture def mock_superuser(): """Create a mock superuser.""" user = MagicMock(spec=User) user.id = "00000000-0000-0000-0000-000000000001" user.is_superuser = True user.email = "admin@example.com" return user @pytest.fixture def client(mock_mcp_client, mock_superuser): """Create a FastAPI test client with mocked dependencies.""" from app.api.dependencies.permissions import require_superuser from app.api.routes.mcp import get_mcp_client # Override dependencies async def override_get_mcp_client(): return mock_mcp_client async def override_require_superuser(): return mock_superuser app.dependency_overrides[get_mcp_client] = override_get_mcp_client app.dependency_overrides[require_superuser] = override_require_superuser with patch("app.main.check_database_health", return_value=True): yield TestClient(app) # Clean up app.dependency_overrides.clear() class TestListServers: """Tests for GET /mcp/servers endpoint.""" def test_list_servers_success(self, client, mock_mcp_client): """Test listing MCP servers returns correct data.""" # Setup mock mock_mcp_client.list_servers.return_value = ["server-1", "server-2"] mock_mcp_client.get_server_config.side_effect = [ MCPServerConfig( url="http://server1:8000", timeout=30, enabled=True, transport=TransportType.HTTP, description="Server 1", ), MCPServerConfig( url="http://server2:8000", timeout=60, enabled=True, transport=TransportType.SSE, description="Server 2", ), ] response = client.get("/api/v1/mcp/servers") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total"] == 2 assert len(data["servers"]) == 2 assert data["servers"][0]["name"] == "server-1" assert data["servers"][0]["url"] == "http://server1:8000" assert data["servers"][1]["name"] == "server-2" assert data["servers"][1]["transport"] == "sse" def test_list_servers_empty(self, client, mock_mcp_client): """Test listing servers when none are registered.""" mock_mcp_client.list_servers.return_value = [] response = client.get("/api/v1/mcp/servers") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total"] == 0 assert data["servers"] == [] def test_list_servers_handles_not_found(self, client, mock_mcp_client): """Test that missing server configs are skipped gracefully.""" mock_mcp_client.list_servers.return_value = ["server-1", "missing"] mock_mcp_client.get_server_config.side_effect = [ MCPServerConfig(url="http://server1:8000"), MCPServerNotFoundError(server_name="missing"), ] response = client.get("/api/v1/mcp/servers") assert response.status_code == status.HTTP_200_OK data = response.json() # Should only include the successfully retrieved server assert data["total"] == 1 class TestListServerTools: """Tests for GET /mcp/servers/{server_name}/tools endpoint.""" def test_list_server_tools_success(self, client, mock_mcp_client): """Test listing tools for a specific server.""" mock_mcp_client.list_tools = AsyncMock( return_value=[ ToolInfo(name="tool1", description="Tool 1", server_name="server-1"), ToolInfo(name="tool2", description="Tool 2", server_name="server-1"), ] ) response = client.get("/api/v1/mcp/servers/server-1/tools") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total"] == 2 assert data["tools"][0]["name"] == "tool1" assert data["tools"][1]["name"] == "tool2" def test_list_server_tools_not_found(self, client, mock_mcp_client): """Test listing tools for non-existent server.""" mock_mcp_client.list_tools = AsyncMock( side_effect=MCPServerNotFoundError(server_name="unknown") ) response = client.get("/api/v1/mcp/servers/unknown/tools") assert response.status_code == status.HTTP_404_NOT_FOUND class TestListAllTools: """Tests for GET /mcp/tools endpoint.""" def test_list_all_tools_success(self, client, mock_mcp_client): """Test listing all tools from all servers.""" mock_mcp_client.list_all_tools = AsyncMock( return_value=[ ToolInfo(name="tool1", server_name="server-1"), ToolInfo(name="tool2", server_name="server-1"), ToolInfo(name="tool3", server_name="server-2"), ] ) response = client.get("/api/v1/mcp/tools") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total"] == 3 def test_list_all_tools_empty(self, client, mock_mcp_client): """Test listing tools when none are available.""" mock_mcp_client.list_all_tools = AsyncMock(return_value=[]) response = client.get("/api/v1/mcp/tools") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total"] == 0 class TestHealthCheck: """Tests for GET /mcp/health endpoint.""" def test_health_check_success(self, client, mock_mcp_client): """Test health check returns correct data.""" mock_mcp_client.health_check = AsyncMock( return_value={ "server-1": ServerHealth( name="server-1", healthy=True, state="connected", url="http://server1:8000", tools_count=5, ), "server-2": ServerHealth( name="server-2", healthy=False, state="error", url="http://server2:8000", error="Connection refused", ), } ) response = client.get("/api/v1/mcp/health") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total"] == 2 assert data["healthy_count"] == 1 assert data["unhealthy_count"] == 1 assert data["servers"]["server-1"]["healthy"] is True assert data["servers"]["server-2"]["healthy"] is False def test_health_check_all_healthy(self, client, mock_mcp_client): """Test health check when all servers are healthy.""" mock_mcp_client.health_check = AsyncMock( return_value={ "server-1": ServerHealth( name="server-1", healthy=True, state="connected", url="http://server1:8000", ), } ) response = client.get("/api/v1/mcp/health") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["healthy_count"] == 1 assert data["unhealthy_count"] == 0 class TestCallTool: """Tests for POST /mcp/call endpoint.""" def test_call_tool_success(self, client, mock_mcp_client): """Test successful tool execution.""" mock_mcp_client.call_tool = AsyncMock( return_value=ToolResult( success=True, data={"result": "ok"}, tool_name="test-tool", server_name="server-1", execution_time_ms=123.45, request_id="test-request-id", ) ) response = client.post( "/api/v1/mcp/call", json={ "server": "server-1", "tool": "test-tool", "arguments": {"key": "value"}, }, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["success"] is True assert data["data"] == {"result": "ok"} assert data["tool_name"] == "test-tool" assert data["server_name"] == "server-1" def test_call_tool_with_timeout(self, client, mock_mcp_client): """Test tool execution with custom timeout.""" mock_mcp_client.call_tool = AsyncMock( return_value=ToolResult(success=True, data={}) ) response = client.post( "/api/v1/mcp/call", json={ "server": "server-1", "tool": "test-tool", "timeout": 60.0, }, ) assert response.status_code == status.HTTP_200_OK mock_mcp_client.call_tool.assert_called_once() call_args = mock_mcp_client.call_tool.call_args assert call_args.kwargs["timeout"] == 60.0 def test_call_tool_server_not_found(self, client, mock_mcp_client): """Test tool execution with non-existent server.""" mock_mcp_client.call_tool = AsyncMock( side_effect=MCPServerNotFoundError(server_name="unknown") ) response = client.post( "/api/v1/mcp/call", json={"server": "unknown", "tool": "test-tool"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND def test_call_tool_not_found(self, client, mock_mcp_client): """Test tool execution with non-existent tool.""" mock_mcp_client.call_tool = AsyncMock( side_effect=MCPToolNotFoundError(tool_name="unknown") ) response = client.post( "/api/v1/mcp/call", json={"server": "server-1", "tool": "unknown"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND def test_call_tool_timeout(self, client, mock_mcp_client): """Test tool execution timeout.""" mock_mcp_client.call_tool = AsyncMock( side_effect=MCPTimeoutError( "Request timed out", server_name="server-1", timeout_seconds=30.0, ) ) response = client.post( "/api/v1/mcp/call", json={"server": "server-1", "tool": "slow-tool"}, ) assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT def test_call_tool_connection_error(self, client, mock_mcp_client): """Test tool execution with connection failure.""" mock_mcp_client.call_tool = AsyncMock( side_effect=MCPConnectionError( "Connection refused", server_name="server-1", ) ) response = client.post( "/api/v1/mcp/call", json={"server": "server-1", "tool": "test-tool"}, ) assert response.status_code == status.HTTP_502_BAD_GATEWAY def test_call_tool_circuit_open(self, client, mock_mcp_client): """Test tool execution with open circuit breaker.""" mock_mcp_client.call_tool = AsyncMock( side_effect=MCPCircuitOpenError( server_name="server-1", failure_count=5, ) ) response = client.post( "/api/v1/mcp/call", json={"server": "server-1", "tool": "test-tool"}, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE class TestCircuitBreakers: """Tests for circuit breaker endpoints.""" def test_list_circuit_breakers(self, client, mock_mcp_client): """Test listing circuit breaker statuses.""" mock_mcp_client.get_circuit_breaker_status.return_value = { "server-1": {"state": "closed", "failure_count": 0}, "server-2": {"state": "open", "failure_count": 5}, } response = client.get("/api/v1/mcp/circuit-breakers") assert response.status_code == status.HTTP_200_OK data = response.json() assert len(data["circuit_breakers"]) == 2 def test_list_circuit_breakers_empty(self, client, mock_mcp_client): """Test listing when no circuit breakers exist.""" mock_mcp_client.get_circuit_breaker_status.return_value = {} response = client.get("/api/v1/mcp/circuit-breakers") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["circuit_breakers"] == [] def test_reset_circuit_breaker_success(self, client, mock_mcp_client): """Test successfully resetting a circuit breaker.""" mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True) response = client.post("/api/v1/mcp/circuit-breakers/server-1/reset") assert response.status_code == status.HTTP_204_NO_CONTENT def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client): """Test resetting non-existent circuit breaker.""" mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False) response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset") assert response.status_code == status.HTTP_404_NOT_FOUND class TestReconnectServer: """Tests for POST /mcp/servers/{server_name}/reconnect endpoint.""" def test_reconnect_success(self, client, mock_mcp_client): """Test successful server reconnection.""" mock_mcp_client.disconnect = AsyncMock() mock_mcp_client.connect = AsyncMock() response = client.post("/api/v1/mcp/servers/server-1/reconnect") assert response.status_code == status.HTTP_204_NO_CONTENT mock_mcp_client.disconnect.assert_called_once_with("server-1") mock_mcp_client.connect.assert_called_once_with("server-1") def test_reconnect_server_not_found(self, client, mock_mcp_client): """Test reconnecting to non-existent server.""" mock_mcp_client.disconnect = AsyncMock( side_effect=MCPServerNotFoundError(server_name="unknown") ) response = client.post("/api/v1/mcp/servers/unknown/reconnect") assert response.status_code == status.HTTP_404_NOT_FOUND def test_reconnect_connection_failure(self, client, mock_mcp_client): """Test reconnection failure.""" mock_mcp_client.disconnect = AsyncMock() mock_mcp_client.connect = AsyncMock( side_effect=MCPConnectionError( "Connection refused", server_name="server-1", ) ) response = client.post("/api/v1/mcp/servers/server-1/reconnect") assert response.status_code == status.HTTP_502_BAD_GATEWAY class TestMCPEndpointsEdgeCases: """Edge case tests for MCP endpoints.""" def test_servers_content_type(self, client, mock_mcp_client): """Test that endpoints return JSON content type.""" mock_mcp_client.list_servers.return_value = [] response = client.get("/api/v1/mcp/servers") assert "application/json" in response.headers["content-type"] def test_call_tool_validation_error(self, client): """Test that invalid request body returns validation error.""" response = client.post( "/api/v1/mcp/call", json={}, # Missing required fields ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY def test_call_tool_missing_server(self, client): """Test that missing server field returns validation error.""" response = client.post( "/api/v1/mcp/call", json={"tool": "test-tool"}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY def test_call_tool_missing_tool(self, client): """Test that missing tool field returns validation error.""" response = client.post( "/api/v1/mcp/call", json={"server": "server-1"}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY