forked from cardosofelipe/fast-next-template
fix(mcp-kb): address critical issues from deep review
- Fix SQL HAVING clause bug by using CTE approach (closes #73) - Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74) - Add /mcp/tools endpoint for tool discovery (closes #75) - Add content size limits to prevent DoS attacks (closes #78) - Add comprehensive tests for new endpoints 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
"""Tests for server module and MCP tools."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
@@ -355,3 +357,248 @@ class TestUpdateDocumentTool:
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Invalid chunk type" in result["error"]
|
||||
|
||||
|
||||
class TestMCPToolsEndpoint:
|
||||
"""Tests for /mcp/tools endpoint."""
|
||||
|
||||
def test_list_mcp_tools(self):
|
||||
"""Test listing available MCP tools."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.get("/mcp/tools")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tools" in data
|
||||
assert len(data["tools"]) == 6 # 6 tools registered
|
||||
|
||||
tool_names = [t["name"] for t in data["tools"]]
|
||||
assert "search_knowledge" in tool_names
|
||||
assert "ingest_content" in tool_names
|
||||
assert "delete_content" in tool_names
|
||||
assert "list_collections" in tool_names
|
||||
assert "get_collection_stats" in tool_names
|
||||
assert "update_document" in tool_names
|
||||
|
||||
def test_tool_has_schema(self):
|
||||
"""Test that each tool has input schema."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.get("/mcp/tools")
|
||||
|
||||
data = response.json()
|
||||
for tool in data["tools"]:
|
||||
assert "inputSchema" in tool
|
||||
assert "type" in tool["inputSchema"]
|
||||
assert tool["inputSchema"]["type"] == "object"
|
||||
|
||||
|
||||
class TestMCPRPCEndpoint:
|
||||
"""Tests for /mcp JSON-RPC endpoint."""
|
||||
|
||||
def test_valid_jsonrpc_request(self):
|
||||
"""Test valid JSON-RPC request."""
|
||||
import server
|
||||
from models import SearchResponse, SearchResult
|
||||
|
||||
mock_search = MagicMock()
|
||||
mock_search.search = AsyncMock(
|
||||
return_value=SearchResponse(
|
||||
query="test",
|
||||
search_type="hybrid",
|
||||
results=[
|
||||
SearchResult(
|
||||
id="id-1",
|
||||
content="Test",
|
||||
score=0.9,
|
||||
source_path="/test.py",
|
||||
chunk_type="code",
|
||||
collection="default",
|
||||
)
|
||||
],
|
||||
total_results=1,
|
||||
search_time_ms=5.0,
|
||||
)
|
||||
)
|
||||
server._search = mock_search
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "search_knowledge",
|
||||
"params": {
|
||||
"project_id": "proj-123",
|
||||
"agent_id": "agent-456",
|
||||
"query": "test",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["jsonrpc"] == "2.0"
|
||||
assert data["id"] == 1
|
||||
assert "result" in data
|
||||
assert data["result"]["success"] is True
|
||||
|
||||
def test_invalid_jsonrpc_version(self):
|
||||
"""Test request with invalid JSON-RPC version."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "1.0",
|
||||
"method": "search_knowledge",
|
||||
"params": {},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32600
|
||||
assert "jsonrpc must be '2.0'" in data["error"]["message"]
|
||||
|
||||
def test_missing_method(self):
|
||||
"""Test request without method."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"params": {},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32600
|
||||
assert "method is required" in data["error"]["message"]
|
||||
|
||||
def test_unknown_method(self):
|
||||
"""Test request with unknown method."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "unknown_method",
|
||||
"params": {},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32601
|
||||
assert "Method not found" in data["error"]["message"]
|
||||
|
||||
def test_invalid_params(self):
|
||||
"""Test request with invalid params."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "search_knowledge",
|
||||
"params": {"invalid_param": "value"}, # Missing required params
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32602
|
||||
|
||||
|
||||
class TestContentSizeLimits:
|
||||
"""Tests for content size validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_rejects_oversized_content(self):
|
||||
"""Test that ingest rejects content exceeding size limit."""
|
||||
import server
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
# Create content larger than max size
|
||||
oversized_content = "x" * (settings.max_document_size + 1)
|
||||
|
||||
result = await server.ingest_content.fn(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
content=oversized_content,
|
||||
chunk_type="text",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "exceeds maximum" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_rejects_oversized_content(self):
|
||||
"""Test that update rejects content exceeding size limit."""
|
||||
import server
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
oversized_content = "x" * (settings.max_document_size + 1)
|
||||
|
||||
result = await server.update_document.fn(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
source_path="/test.py",
|
||||
content=oversized_content,
|
||||
chunk_type="text",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "exceeds maximum" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_accepts_valid_size_content(self):
|
||||
"""Test that ingest accepts content within size limit."""
|
||||
import server
|
||||
from models import IngestResult
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.ingest = AsyncMock(
|
||||
return_value=IngestResult(
|
||||
success=True,
|
||||
chunks_created=1,
|
||||
embeddings_generated=1,
|
||||
source_path="/test.py",
|
||||
collection="default",
|
||||
chunk_ids=["id-1"],
|
||||
)
|
||||
)
|
||||
server._collections = mock_collections
|
||||
|
||||
# Small content that's within limits
|
||||
# Pass all parameters to avoid Field default resolution issues
|
||||
result = await server.ingest_content.fn(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
content="def hello(): pass",
|
||||
source_path="/test.py",
|
||||
collection="default",
|
||||
chunk_type="text",
|
||||
file_type=None,
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
Reference in New Issue
Block a user