The delay2 and delay3 variables were computed but never checked, which triggered unused-variable lint warnings. Added assertions verifying that all three delays are positive and stay within the maximum-delay bound, with headroom allowed for jitter.
347 lines
11 KiB
Python
347 lines
11 KiB
Python
"""
|
|
Tests for MCP Tool Call Routing
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
|
from app.services.mcp.connection import ConnectionPool
|
|
from app.services.mcp.exceptions import (
|
|
MCPToolNotFoundError,
|
|
)
|
|
from app.services.mcp.registry import MCPServerRegistry
|
|
from app.services.mcp.routing import ToolInfo, ToolResult, ToolRouter
|
|
|
|
|
|
@pytest.fixture
def reset_registry():
    """Give each test a pristine MCPServerRegistry singleton.

    The singleton is cleared before control passes to the test and cleared
    again once the test finishes, so registrations made by one test cannot
    leak into the next.
    """
    MCPServerRegistry.reset_instance()  # clean slate before the test body runs
    yield None
    MCPServerRegistry.reset_instance()  # drop anything the test registered
|
|
|
|
|
|
@pytest.fixture
def registry(reset_registry):
    """Return a freshly reset registry preloaded with two test MCP servers.

    Both servers use a single retry attempt and the minimum allowed
    ``retry_delay`` (0.1). ``server-1`` additionally configures a circuit
    breaker (``circuit_breaker_threshold=3``, ``circuit_breaker_timeout=5.0``).
    Depends on ``reset_registry`` so the singleton starts empty.
    """
    reg = MCPServerRegistry()

    # Shared retry settings; 0.1 is the minimum retry_delay the config allows.
    single_quick_retry = {"retry_attempts": 1, "retry_delay": 0.1}

    servers = {
        "server-1": MCPServerConfig(
            url="http://server1:8000",
            circuit_breaker_threshold=3,
            circuit_breaker_timeout=5.0,
            **single_quick_retry,
        ),
        "server-2": MCPServerConfig(
            url="http://server2:8000",
            **single_quick_retry,
        ),
    }
    reg.load_config(MCPConfig(mcp_servers=servers))
    return reg
|
|
|
|
|
|
@pytest.fixture
def pool():
    """Provide a new ConnectionPool instance for each test."""
    connection_pool = ConnectionPool()
    return connection_pool
|
|
|
|
|
|
@pytest.fixture
def router(registry, pool):
    """Build a ToolRouter wired to the test registry and connection pool."""
    tool_router = ToolRouter(registry, pool)
    return tool_router
|
|
|
|
|
|
class TestToolInfo:
    """Unit tests covering construction and serialisation of ToolInfo."""

    def test_basic_tool_info(self):
        """Constructing with just ``name`` leaves the optional fields as None."""
        tool = ToolInfo(name="test-tool")

        assert tool.name == "test-tool"
        for optional_field in (tool.description, tool.server_name, tool.input_schema):
            assert optional_field is None

    def test_full_tool_info(self):
        """Every field passed to the constructor is stored on the instance."""
        schema = {
            "type": "object",
            "properties": {"title": {"type": "string"}},
        }
        tool = ToolInfo(
            name="create_issue",
            description="Create a new issue",
            server_name="issues",
            input_schema=schema,
        )

        assert tool.name == "create_issue"
        assert tool.description == "Create a new issue"
        assert tool.server_name == "issues"
        assert "properties" in tool.input_schema

    def test_to_dict(self):
        """to_dict() exposes name, description and server_name as keys."""
        expected = {
            "name": "test-tool",
            "description": "A test tool",
            "server_name": "test-server",
        }
        as_dict = ToolInfo(**expected).to_dict()

        for key, value in expected.items():
            assert as_dict[key] == value
|
|
|
|
|
|
class TestToolResult:
    """Unit tests covering construction and serialisation of ToolResult."""

    def test_success_result(self):
        """A successful result carries its payload and no error message."""
        outcome = ToolResult(
            success=True,
            data={"id": "123"},
            tool_name="create_issue",
            server_name="issues",
        )

        assert outcome.success is True
        assert outcome.data == {"id": "123"}
        assert outcome.error is None

    def test_error_result(self):
        """A failed result keeps both the error message and its code."""
        outcome = ToolResult(
            success=False,
            error="Tool execution failed",
            error_code="INTERNAL_ERROR",
            tool_name="create_issue",
            server_name="issues",
        )

        assert outcome.success is False
        assert (outcome.error, outcome.error_code) == (
            "Tool execution failed",
            "INTERNAL_ERROR",
        )

    def test_to_dict(self):
        """to_dict() serialises the given fields plus a generated request_id."""
        serialized = ToolResult(
            success=True,
            data={"result": "ok"},
            tool_name="test",
            execution_time_ms=123.45,
        ).to_dict()

        assert serialized["success"] is True
        assert serialized["data"] == {"result": "ok"}
        assert serialized["tool_name"] == "test"
        assert serialized["execution_time_ms"] == 123.45
        # request_id is never passed in, so its presence shows it is auto-generated.
        assert "request_id" in serialized
|
|
|
|
|
|
class TestToolRouter:
    """Tests for ToolRouter class.

    Network access is avoided by patching ``router._pool.get_connection`` so
    that it hands back an ``AsyncMock`` connection whose ``execute_request``
    returns a canned JSON-RPC-style response.
    """

    @pytest.mark.asyncio
    async def test_register_tool_mapping(self, router):
        """Test registering tool mappings."""
        await router.register_tool_mapping("tool1", "server-1")
        await router.register_tool_mapping("tool2", "server-2")

        assert router.find_server_for_tool("tool1") == "server-1"
        assert router.find_server_for_tool("tool2") == "server-2"

    def test_find_server_for_unknown_tool(self, router):
        """Test finding server for unknown tool."""
        result = router.find_server_for_tool("unknown-tool")
        assert result is None

    @pytest.mark.asyncio
    async def test_call_tool_success(self, router, registry):
        """Test successful tool call."""
        # Set up capabilities
        registry.set_capabilities(
            "server-1",
            tools=[{"name": "test-tool"}],
        )
        await router.register_tool_mapping("test-tool", "server-1")

        # Mock the pool connection and request
        mock_conn = AsyncMock()
        mock_conn.execute_request = AsyncMock(return_value={"result": {"status": "ok"}})
        mock_conn.is_connected = True

        # NOTE: assumes get_connection is a coroutine function; patch.object then
        # substitutes an AsyncMock, so awaiting it yields mock_conn.
        with patch.object(router._pool, "get_connection", return_value=mock_conn):
            result = await router.call_tool(
                server_name="server-1",
                tool_name="test-tool",
                arguments={"param": "value"},
            )

        assert result.success is True
        assert result.data == {"status": "ok"}
        assert result.tool_name == "test-tool"
        assert result.server_name == "server-1"
        assert result.execution_time_ms > 0

    @pytest.mark.asyncio
    async def test_call_tool_error_response(self, router, registry):
        """Test tool call with error response."""
        registry.set_capabilities(
            "server-1",
            tools=[{"name": "test-tool"}],
        )
        await router.register_tool_mapping("test-tool", "server-1")

        mock_conn = AsyncMock()
        mock_conn.execute_request = AsyncMock(
            return_value={
                "error": {
                    "code": -32000,
                    "message": "Tool execution failed",
                }
            }
        )
        mock_conn.is_connected = True

        with patch.object(router._pool, "get_connection", return_value=mock_conn):
            result = await router.call_tool(
                server_name="server-1",
                tool_name="test-tool",
                arguments={},
            )

        assert result.success is False
        assert "Tool execution failed" in result.error

    @pytest.mark.asyncio
    async def test_route_tool(self, router, registry):
        """Test routing tool to correct server."""
        registry.set_capabilities(
            "server-1",
            tools=[{"name": "tool-on-server-1"}],
        )
        await router.register_tool_mapping("tool-on-server-1", "server-1")

        mock_conn = AsyncMock()
        mock_conn.execute_request = AsyncMock(return_value={"result": "routed"})
        mock_conn.is_connected = True

        with patch.object(router._pool, "get_connection", return_value=mock_conn):
            result = await router.route_tool(
                tool_name="tool-on-server-1",
                arguments={"key": "value"},
            )

        assert result.success is True
        assert result.server_name == "server-1"

    @pytest.mark.asyncio
    async def test_route_tool_not_found(self, router):
        """Test routing unknown tool raises error."""
        with pytest.raises(MCPToolNotFoundError) as exc_info:
            await router.route_tool(
                tool_name="unknown-tool",
                arguments={},
            )

        assert exc_info.value.tool_name == "unknown-tool"

    @pytest.mark.asyncio
    async def test_list_all_tools(self, router, registry):
        """Test listing all tools."""
        registry.set_capabilities(
            "server-1",
            tools=[
                {"name": "tool1", "description": "Tool 1"},
                {"name": "tool2", "description": "Tool 2"},
            ],
        )
        registry.set_capabilities(
            "server-2",
            tools=[{"name": "tool3", "description": "Tool 3"}],
        )

        tools = await router.list_all_tools()

        assert len(tools) == 3
        tool_names = [t.name for t in tools]
        assert "tool1" in tool_names
        assert "tool2" in tool_names
        assert "tool3" in tool_names

    def test_circuit_breaker_status(self, router):
        """Test getting circuit breaker status."""
        # Initially no circuit breakers
        status = router.get_circuit_breaker_status()
        assert status == {}

    @pytest.mark.asyncio
    async def test_reset_circuit_breaker(self, router):
        """Test resetting circuit breaker."""
        # Reset non-existent returns False
        result = await router.reset_circuit_breaker("server-1")
        assert result is False

    @pytest.mark.asyncio
    async def test_discover_tools(self, router):
        """Test tool discovery from servers."""
        # Create mocks for different servers
        mock_conn_1 = AsyncMock()
        mock_conn_1.execute_request = AsyncMock(
            return_value={
                "tools": [
                    {"name": "discovered-tool", "description": "A discovered tool"},
                ]
            }
        )
        mock_conn_1.server_name = "server-1"
        mock_conn_1.is_connected = True

        mock_conn_2 = AsyncMock()
        mock_conn_2.execute_request = AsyncMock(
            return_value={"tools": []}  # Empty for server-2
        )
        mock_conn_2.server_name = "server-2"
        mock_conn_2.is_connected = True

        # Hand back the mock matching the requested server name.
        async def get_connection_side_effect(server_name, _config):
            if server_name == "server-1":
                return mock_conn_1
            return mock_conn_2

        with patch.object(
            router._pool,
            "get_connection",
            side_effect=get_connection_side_effect,
        ):
            await router.discover_tools()

        # Check that tool mapping was registered
        server = router.find_server_for_tool("discovered-tool")
        assert server == "server-1"

    def test_calculate_retry_delay(self, router, registry):
        """Every retry delay is positive and stays within the configured cap.

        The 1.25 multiplier leaves headroom for jitter on top of
        ``retry_max_delay`` (presumably up to +25%; the actual jitter is
        computed by ``ToolRouter._calculate_retry_delay``).
        """
        config = registry.get("server-1")
        upper_bound = config.retry_max_delay * 1.25

        for attempt in (1, 2, 3):
            delay = router._calculate_retry_delay(attempt, config)
            assert delay > 0, f"attempt {attempt}: delay {delay} is not positive"
            assert delay <= upper_bound, (
                f"attempt {attempt}: delay {delay} exceeds {upper_bound}"
            )
|