"""
Tests for MCP Tool Call Routing
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.connection import ConnectionPool
from app.services.mcp.exceptions import (
    MCPCircuitOpenError,
    MCPToolError,
    MCPToolNotFoundError,
)
from app.services.mcp.registry import MCPServerRegistry
from app.services.mcp.routing import ToolInfo, ToolResult, ToolRouter


def _mock_connection(response, server_name=None):
    """Build a connected mock MCP connection whose requests return *response*.

    Args:
        response: Value that awaiting ``execute_request`` resolves to.
        server_name: Optional value for the mock's ``server_name`` attribute;
            when ``None`` the attribute is left as the AsyncMock default.

    Returns:
        An ``AsyncMock`` with ``execute_request`` stubbed and
        ``is_connected`` set to ``True``.
    """
    conn = AsyncMock()
    conn.execute_request = AsyncMock(return_value=response)
    conn.is_connected = True
    if server_name is not None:
        conn.server_name = server_name
    return conn


@pytest.fixture
def reset_registry():
    """Reset the singleton registry before and after each test."""
    MCPServerRegistry.reset_instance()
    yield
    MCPServerRegistry.reset_instance()


@pytest.fixture
def registry(reset_registry):
    """Create a configured registry."""
    reg = MCPServerRegistry()
    reg.load_config(
        MCPConfig(
            mcp_servers={
                # Only server-1 gets explicit circuit-breaker settings.
                "server-1": MCPServerConfig(
                    url="http://server1:8000",
                    retry_attempts=1,
                    retry_delay=0.1,  # Minimum allowed value
                    circuit_breaker_threshold=3,
                    circuit_breaker_timeout=5.0,
                ),
                "server-2": MCPServerConfig(
                    url="http://server2:8000",
                    retry_attempts=1,
                    retry_delay=0.1,  # Minimum allowed value
                ),
            }
        )
    )
    return reg


@pytest.fixture
def pool():
    """Create a connection pool."""
    return ConnectionPool()


@pytest.fixture
def router(registry, pool):
    """Create a tool router."""
    return ToolRouter(registry, pool)


class TestToolInfo:
    """Tests for ToolInfo dataclass."""

    def test_basic_tool_info(self):
        """Test creating basic tool info."""
        info = ToolInfo(name="test-tool")

        assert info.name == "test-tool"
        assert info.description is None
        assert info.server_name is None
        assert info.input_schema is None

    def test_full_tool_info(self):
        """Test creating full tool info."""
        info = ToolInfo(
            name="create_issue",
            description="Create a new issue",
            server_name="issues",
            input_schema={"type": "object", "properties": {"title": {"type": "string"}}},
        )

        assert info.name == "create_issue"
        assert info.description == "Create a new issue"
        assert info.server_name == "issues"
        assert "properties" in info.input_schema

    def test_to_dict(self):
        """Test converting to dictionary."""
        info = ToolInfo(
            name="test-tool",
            description="A test tool",
            server_name="test-server",
        )

        result = info.to_dict()

        assert result["name"] == "test-tool"
        assert result["description"] == "A test tool"
        assert result["server_name"] == "test-server"


class TestToolResult:
    """Tests for ToolResult dataclass."""

    def test_success_result(self):
        """Test creating success result."""
        result = ToolResult(
            success=True,
            data={"id": "123"},
            tool_name="create_issue",
            server_name="issues",
        )

        assert result.success is True
        assert result.data == {"id": "123"}
        assert result.error is None

    def test_error_result(self):
        """Test creating error result."""
        result = ToolResult(
            success=False,
            error="Tool execution failed",
            error_code="INTERNAL_ERROR",
            tool_name="create_issue",
            server_name="issues",
        )

        assert result.success is False
        assert result.error == "Tool execution failed"
        assert result.error_code == "INTERNAL_ERROR"

    def test_to_dict(self):
        """Test converting to dictionary."""
        result = ToolResult(
            success=True,
            data={"result": "ok"},
            tool_name="test",
            execution_time_ms=123.45,
        )

        d = result.to_dict()

        assert d["success"] is True
        assert d["data"] == {"result": "ok"}
        assert d["tool_name"] == "test"
        assert d["execution_time_ms"] == 123.45
        assert "request_id" in d  # Auto-generated


class TestToolRouter:
    """Tests for ToolRouter class."""

    @pytest.mark.asyncio
    async def test_register_tool_mapping(self, router):
        """Test registering tool mappings."""
        await router.register_tool_mapping("tool1", "server-1")
        await router.register_tool_mapping("tool2", "server-2")

        assert router.find_server_for_tool("tool1") == "server-1"
        assert router.find_server_for_tool("tool2") == "server-2"

    def test_find_server_for_unknown_tool(self, router):
        """Test finding server for unknown tool."""
        result = router.find_server_for_tool("unknown-tool")
        assert result is None

    @pytest.mark.asyncio
    async def test_call_tool_success(self, router, registry):
        """Test successful tool call."""
        # Set up capabilities
        registry.set_capabilities(
            "server-1",
            tools=[{"name": "test-tool"}],
        )
        await router.register_tool_mapping("test-tool", "server-1")

        # Mock the pool connection and request
        mock_conn = _mock_connection({"result": {"status": "ok"}})

        with patch.object(router._pool, "get_connection", return_value=mock_conn):
            result = await router.call_tool(
                server_name="server-1",
                tool_name="test-tool",
                arguments={"param": "value"},
            )

        assert result.success is True
        assert result.data == {"status": "ok"}
        assert result.tool_name == "test-tool"
        assert result.server_name == "server-1"
        assert result.execution_time_ms > 0

    @pytest.mark.asyncio
    async def test_call_tool_error_response(self, router, registry):
        """Test tool call with error response."""
        registry.set_capabilities(
            "server-1",
            tools=[{"name": "test-tool"}],
        )
        await router.register_tool_mapping("test-tool", "server-1")

        # A JSON-RPC style error payload instead of a "result" key.
        mock_conn = _mock_connection(
            {
                "error": {
                    "code": -32000,
                    "message": "Tool execution failed",
                }
            }
        )

        with patch.object(router._pool, "get_connection", return_value=mock_conn):
            result = await router.call_tool(
                server_name="server-1",
                tool_name="test-tool",
                arguments={},
            )

        assert result.success is False
        assert "Tool execution failed" in result.error

    @pytest.mark.asyncio
    async def test_route_tool(self, router, registry):
        """Test routing tool to correct server."""
        registry.set_capabilities(
            "server-1",
            tools=[{"name": "tool-on-server-1"}],
        )
        await router.register_tool_mapping("tool-on-server-1", "server-1")

        mock_conn = _mock_connection({"result": "routed"})

        with patch.object(router._pool, "get_connection", return_value=mock_conn):
            result = await router.route_tool(
                tool_name="tool-on-server-1",
                arguments={"key": "value"},
            )

        assert result.success is True
        assert result.server_name == "server-1"

    @pytest.mark.asyncio
    async def test_route_tool_not_found(self, router):
        """Test routing unknown tool raises error."""
        with pytest.raises(MCPToolNotFoundError) as exc_info:
            await router.route_tool(
                tool_name="unknown-tool",
                arguments={},
            )

        assert exc_info.value.tool_name == "unknown-tool"

    @pytest.mark.asyncio
    async def test_list_all_tools(self, router, registry):
        """Test listing all tools."""
        registry.set_capabilities(
            "server-1",
            tools=[
                {"name": "tool1", "description": "Tool 1"},
                {"name": "tool2", "description": "Tool 2"},
            ],
        )
        registry.set_capabilities(
            "server-2",
            tools=[{"name": "tool3", "description": "Tool 3"}],
        )

        tools = await router.list_all_tools()

        assert len(tools) == 3
        tool_names = [t.name for t in tools]
        assert "tool1" in tool_names
        assert "tool2" in tool_names
        assert "tool3" in tool_names

    def test_circuit_breaker_status(self, router):
        """Test getting circuit breaker status."""
        # Initially no circuit breakers
        status = router.get_circuit_breaker_status()
        assert status == {}

    @pytest.mark.asyncio
    async def test_reset_circuit_breaker(self, router):
        """Test resetting circuit breaker."""
        # Reset non-existent returns False
        result = await router.reset_circuit_breaker("server-1")
        assert result is False

    @pytest.mark.asyncio
    async def test_discover_tools(self, router, registry):
        """Test tool discovery from servers."""
        # Create mocks for different servers
        mock_conn_1 = _mock_connection(
            {
                "tools": [
                    {"name": "discovered-tool", "description": "A discovered tool"},
                ]
            },
            server_name="server-1",
        )
        mock_conn_2 = _mock_connection(
            {"tools": []},  # Empty for server-2
            server_name="server-2",
        )

        # Hand out the matching mock for whichever server the router asks for.
        async def get_connection_side_effect(server_name, _config):
            if server_name == "server-1":
                return mock_conn_1
            return mock_conn_2

        with patch.object(
            router._pool,
            "get_connection",
            side_effect=get_connection_side_effect,
        ):
            await router.discover_tools()

        # Check that tool mapping was registered
        server = router.find_server_for_tool("discovered-tool")
        assert server == "server-1"

    def test_calculate_retry_delay(self, router, registry):
        """Test retry delay calculation."""
        config = registry.get("server-1")

        delay1 = router._calculate_retry_delay(1, config)
        delay2 = router._calculate_retry_delay(2, config)
        delay3 = router._calculate_retry_delay(3, config)

        # Delays should increase with attempts
        assert delay1 > 0
        # Allow for jitter variation
        assert delay1 <= config.retry_max_delay * 1.25