"""Tests for MCP connection management (MCPConnection and ConnectionPool)."""

from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from app.services.mcp.config import MCPServerConfig, TransportType
from app.services.mcp.connection import (
    ConnectionPool,
    ConnectionState,
    MCPConnection,
)
from app.services.mcp.exceptions import MCPConnectionError, MCPTimeoutError


def _mock_response(status_code=200, json_data=None):
    """Build a MagicMock standing in for an ``httpx.Response``.

    Args:
        status_code: Value exposed as ``response.status_code``.
        json_data: If given, the value returned by ``response.json()``.

    Returns:
        A MagicMock whose ``raise_for_status()`` is a no-op.
    """
    response = MagicMock()
    response.status_code = status_code
    if json_data is not None:
        response.json.return_value = json_data
    # Explicit no-op: these fake responses should never raise on their own.
    response.raise_for_status = MagicMock()
    return response


@contextmanager
def _patched_http_client(get=None, post=None):
    """Patch ``httpx.AsyncClient`` so that constructing it returns one shared mock.

    Presumes the code under test resolves the class through the ``httpx``
    module attribute (``httpx.AsyncClient``); that is what ``patch`` replaces.

    Args:
        get: Mock installed as ``client.get``. Defaults to an AsyncMock that
            always returns a 200 response.
        post: Mock installed as ``client.post``. If omitted, the AsyncMock's
            auto-created ``post`` child is left in place.

    Yields:
        The shared mock client, so tests can inspect calls made on it.
    """
    with patch("httpx.AsyncClient") as mock_client_cls:
        client = AsyncMock()
        client.get = get if get is not None else AsyncMock(return_value=_mock_response())
        if post is not None:
            client.post = post
        client.aclose = AsyncMock()
        mock_client_cls.return_value = client
        yield client


@pytest.fixture
def server_config():
    """Create a sample server configuration."""
    return MCPServerConfig(
        url="http://localhost:8000",
        transport=TransportType.HTTP,
        timeout=30,
        retry_attempts=3,
        retry_delay=0.1,  # Short delay for tests
        retry_max_delay=1.0,
    )


class TestConnectionState:
    """Tests for ConnectionState enum."""

    def test_connection_states(self):
        """Test all connection states are defined."""
        assert ConnectionState.DISCONNECTED == "disconnected"
        assert ConnectionState.CONNECTING == "connecting"
        assert ConnectionState.CONNECTED == "connected"
        assert ConnectionState.RECONNECTING == "reconnecting"
        assert ConnectionState.ERROR == "error"


class TestMCPConnection:
    """Tests for MCPConnection class."""

    def test_initial_state(self, server_config):
        """Test initial connection state."""
        conn = MCPConnection("test-server", server_config)

        assert conn.server_name == "test-server"
        assert conn.config == server_config
        assert conn.state == ConnectionState.DISCONNECTED
        assert conn.is_connected is False
        assert conn.last_error is None

    @pytest.mark.asyncio
    async def test_connect_success(self, server_config):
        """Test successful connection."""
        conn = MCPConnection("test-server", server_config)

        with _patched_http_client():
            await conn.connect()

            assert conn.state == ConnectionState.CONNECTED
            assert conn.is_connected is True

    @pytest.mark.asyncio
    async def test_connect_404_capabilities_ok(self, server_config):
        """Test connection succeeds even if /mcp/capabilities returns 404."""
        conn = MCPConnection("test-server", server_config)

        with _patched_http_client(get=AsyncMock(return_value=_mock_response(404))):
            await conn.connect()

            assert conn.state == ConnectionState.CONNECTED

    @pytest.mark.asyncio
    async def test_connect_failure_with_retry(self, server_config):
        """Test connection failure with retries."""
        # Reduce retry attempts for faster test (fixture is function-scoped,
        # so mutating it does not leak into other tests).
        server_config.retry_attempts = 2
        server_config.retry_delay = 0.01
        conn = MCPConnection("test-server", server_config)

        failing_get = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
        with _patched_http_client(get=failing_get):
            with pytest.raises(MCPConnectionError) as exc_info:
                await conn.connect()

            assert conn.state == ConnectionState.ERROR
            assert "Failed to connect after" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_disconnect(self, server_config):
        """Test disconnection closes the HTTP client and resets state."""
        conn = MCPConnection("test-server", server_config)

        with _patched_http_client() as client:
            await conn.connect()
            assert conn.is_connected is True

            await conn.disconnect()

            assert conn.state == ConnectionState.DISCONNECTED
            assert conn.is_connected is False
            # The underlying client must actually be released.
            client.aclose.assert_awaited()

    @pytest.mark.asyncio
    async def test_reconnect(self, server_config):
        """Test reconnection."""
        conn = MCPConnection("test-server", server_config)

        with _patched_http_client():
            await conn.connect()
            assert conn.is_connected is True

            await conn.reconnect()

            assert conn.is_connected is True

    @pytest.mark.asyncio
    async def test_health_check_success(self, server_config):
        """Test successful health check."""
        conn = MCPConnection("test-server", server_config)

        with _patched_http_client():
            await conn.connect()
            healthy = await conn.health_check()

            assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_disconnected(self, server_config):
        """Test health check when disconnected."""
        conn = MCPConnection("test-server", server_config)

        healthy = await conn.health_check()

        assert healthy is False

    @pytest.mark.asyncio
    async def test_execute_request_get(self, server_config):
        """Test executing GET request."""
        conn = MCPConnection("test-server", server_config)

        # First GET is consumed by connect(), second by execute_request().
        get = AsyncMock(
            side_effect=[_mock_response(), _mock_response(json_data={"tools": []})]
        )
        with _patched_http_client(get=get):
            await conn.connect()
            result = await conn.execute_request("GET", "/mcp/tools")

            assert result == {"tools": []}

    @pytest.mark.asyncio
    async def test_execute_request_post(self, server_config):
        """Test executing POST request."""
        conn = MCPConnection("test-server", server_config)

        post = AsyncMock(return_value=_mock_response(json_data={"result": "success"}))
        with _patched_http_client(post=post):
            await conn.connect()
            result = await conn.execute_request("POST", "/mcp", data={"method": "test"})

            assert result == {"result": "success"}

    @pytest.mark.asyncio
    async def test_execute_request_timeout(self, server_config):
        """Test request timeout."""
        conn = MCPConnection("test-server", server_config)

        # connect() succeeds, then the actual request times out.
        get = AsyncMock(
            side_effect=[
                _mock_response(),
                httpx.TimeoutException("Request timeout"),
            ]
        )
        with _patched_http_client(get=get):
            await conn.connect()

            with pytest.raises(MCPTimeoutError) as exc_info:
                await conn.execute_request("GET", "/slow-endpoint")

            assert "timeout" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_execute_request_not_connected(self, server_config):
        """Test request when not connected."""
        conn = MCPConnection("test-server", server_config)

        with pytest.raises(MCPConnectionError) as exc_info:
            await conn.execute_request("GET", "/test")

        assert "Not connected" in str(exc_info.value)

    def test_backoff_delay_calculation(self, server_config):
        """Test exponential backoff delay calculation."""
        conn = MCPConnection("test-server", server_config)

        delays = []
        for attempt in (1, 2, 3):
            conn._connection_attempts = attempt
            delays.append(conn._calculate_backoff_delay())
        delay1, delay2, delay3 = delays

        # Delays should generally increase (with some jitter).
        # Base is 0.1, so rough expectations:
        #   Attempt 1: ~0.1s, Attempt 2: ~0.2s, Attempt 3: ~0.4s
        assert delay1 > 0
        assert delay2 > delay1 * 0.5  # Allow for jitter
        assert delay3 <= server_config.retry_max_delay * 1.25  # Within max + jitter


class TestConnectionPool:
    """Tests for ConnectionPool class."""

    @pytest.fixture
    def pool(self):
        """Create a connection pool."""
        return ConnectionPool(max_connections_per_server=5)

    @pytest.mark.asyncio
    async def test_get_connection_creates_new(self, pool, server_config):
        """Test getting connection creates new one if not exists."""
        with _patched_http_client():
            conn = await pool.get_connection("test-server", server_config)

            assert conn is not None
            assert conn.is_connected is True
            assert conn.server_name == "test-server"

    @pytest.mark.asyncio
    async def test_get_connection_reuses_existing(self, pool, server_config):
        """Test getting connection reuses existing one."""
        with _patched_http_client():
            conn1 = await pool.get_connection("test-server", server_config)
            conn2 = await pool.get_connection("test-server", server_config)

            assert conn1 is conn2

    @pytest.mark.asyncio
    async def test_close_connection(self, pool, server_config):
        """Test closing a connection."""
        with _patched_http_client():
            await pool.get_connection("test-server", server_config)
            await pool.close_connection("test-server")

            assert "test-server" not in pool._connections

    @pytest.mark.asyncio
    async def test_close_all(self, pool, server_config):
        """Test closing all connections."""
        with _patched_http_client():
            config2 = MCPServerConfig(url="http://server2:8000")
            await pool.get_connection("server-1", server_config)
            await pool.get_connection("server-2", config2)

            assert len(pool._connections) == 2

            await pool.close_all()

            assert len(pool._connections) == 0

    @pytest.mark.asyncio
    async def test_health_check_all(self, pool, server_config):
        """Test health check on all connections."""
        with _patched_http_client():
            await pool.get_connection("test-server", server_config)
            results = await pool.health_check_all()

            assert "test-server" in results
            assert results["test-server"] is True

    def test_get_status(self, pool):
        """Test getting status of an empty pool."""
        status = pool.get_status()

        assert status == {}

    @pytest.mark.asyncio
    async def test_connection_context_manager(self, pool, server_config):
        """Test connection context manager."""
        with _patched_http_client():
            async with pool.connection("test-server", server_config) as conn:
                assert conn.is_connected is True
                assert conn.server_name == "test-server"