forked from cardosofelipe/fast-next-template
fix(mcp-kb): address critical issues from deep review
- Fix SQL HAVING clause bug by using CTE approach (closes #73) - Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74) - Add /mcp/tools endpoint for tool discovery (closes #75) - Add content size limits to prevent DoS attacks (closes #78) - Add comprehensive tests for new endpoints 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -112,6 +112,20 @@ class Settings(BaseSettings):
|
||||
description="TTL for embedding records in days (0 = no expiry)",
|
||||
)
|
||||
|
||||
# Content size limits (DoS prevention)
|
||||
max_document_size: int = Field(
|
||||
default=10 * 1024 * 1024, # 10 MB
|
||||
description="Maximum size of a single document in bytes",
|
||||
)
|
||||
max_batch_size: int = Field(
|
||||
default=100,
|
||||
description="Maximum number of documents in a batch operation",
|
||||
)
|
||||
max_batch_total_size: int = Field(
|
||||
default=50 * 1024 * 1024, # 50 MB
|
||||
description="Maximum total size of all documents in a batch",
|
||||
)
|
||||
|
||||
model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"}
|
||||
|
||||
|
||||
|
||||
@@ -345,8 +345,9 @@ class DatabaseManager:
|
||||
"""
|
||||
try:
|
||||
async with self.acquire() as conn:
|
||||
# Build query with optional filters
|
||||
query = """
|
||||
# Build query with optional filters using CTE to filter by similarity
|
||||
# We use a CTE to compute similarity once, then filter in outer query
|
||||
inner_query = """
|
||||
SELECT
|
||||
id, project_id, collection, content, embedding,
|
||||
chunk_type, source_path, start_line, end_line,
|
||||
@@ -361,18 +362,21 @@ class DatabaseManager:
|
||||
param_idx = 3
|
||||
|
||||
if collection:
|
||||
query += f" AND collection = ${param_idx}"
|
||||
inner_query += f" AND collection = ${param_idx}"
|
||||
params.append(collection)
|
||||
param_idx += 1
|
||||
|
||||
if file_types:
|
||||
file_type_values = [ft.value for ft in file_types]
|
||||
query += f" AND file_type = ANY(${param_idx})"
|
||||
inner_query += f" AND file_type = ANY(${param_idx})"
|
||||
params.append(file_type_values)
|
||||
param_idx += 1
|
||||
|
||||
query += f"""
|
||||
HAVING 1 - (embedding <=> $1) >= ${param_idx}
|
||||
# Wrap in CTE and filter by threshold in outer query
|
||||
query = f"""
|
||||
WITH scored AS ({inner_query})
|
||||
SELECT * FROM scored
|
||||
WHERE similarity >= ${param_idx}
|
||||
ORDER BY similarity DESC
|
||||
LIMIT ${param_idx + 1}
|
||||
"""
|
||||
|
||||
@@ -5,11 +5,13 @@ Provides RAG capabilities with pgvector for semantic search,
|
||||
intelligent chunking, and collection management.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastmcp import FastMCP
|
||||
from pydantic import Field
|
||||
|
||||
@@ -116,6 +118,259 @@ async def health_check() -> dict[str, Any]:
|
||||
return status
|
||||
|
||||
|
||||
# Tool registry for JSON-RPC
|
||||
_tool_registry: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
|
||||
"""Convert Python type annotation to JSON Schema."""
|
||||
type_name = getattr(python_type, "__name__", str(python_type))
|
||||
|
||||
if python_type is str or type_name == "str":
|
||||
return {"type": "string"}
|
||||
elif python_type is int or type_name == "int":
|
||||
return {"type": "integer"}
|
||||
elif python_type is float or type_name == "float":
|
||||
return {"type": "number"}
|
||||
elif python_type is bool or type_name == "bool":
|
||||
return {"type": "boolean"}
|
||||
elif type_name == "NoneType":
|
||||
return {"type": "null"}
|
||||
elif hasattr(python_type, "__origin__"):
|
||||
origin = python_type.__origin__
|
||||
args = getattr(python_type, "__args__", ())
|
||||
|
||||
if origin is list:
|
||||
item_type = args[0] if args else Any
|
||||
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
|
||||
elif origin is dict:
|
||||
return {"type": "object"}
|
||||
elif origin is type(None) or str(origin) == "typing.Union":
|
||||
# Handle Optional types (Union with None)
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
schema = _python_type_to_json_schema(non_none_args[0])
|
||||
schema["nullable"] = True
|
||||
return schema
|
||||
return {"type": "object"}
|
||||
return {"type": "object"}
|
||||
|
||||
|
||||
def _get_tool_schema(func: Any) -> dict[str, Any]:
|
||||
"""Extract JSON Schema from a tool function."""
|
||||
sig = inspect.signature(func)
|
||||
hints = get_type_hints(func) if hasattr(func, "__annotations__") else {}
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
if name in ("self", "cls"):
|
||||
continue
|
||||
|
||||
prop: dict[str, Any] = {}
|
||||
|
||||
# Get type from hints
|
||||
if name in hints:
|
||||
prop = _python_type_to_json_schema(hints[name])
|
||||
|
||||
# Get description and constraints from Field default (FieldInfo object)
|
||||
default_val = param.default
|
||||
if hasattr(default_val, "description") and default_val.description:
|
||||
prop["description"] = default_val.description
|
||||
if hasattr(default_val, "ge") and default_val.ge is not None:
|
||||
prop["minimum"] = default_val.ge
|
||||
if hasattr(default_val, "le") and default_val.le is not None:
|
||||
prop["maximum"] = default_val.le
|
||||
# Handle Field default value (check for PydanticUndefined)
|
||||
if hasattr(default_val, "default"):
|
||||
field_default = default_val.default
|
||||
# Check if it's the "required" sentinel (...)
|
||||
if field_default is not ... and not (
|
||||
hasattr(field_default, "__class__")
|
||||
and "PydanticUndefined" in field_default.__class__.__name__
|
||||
):
|
||||
prop["default"] = field_default
|
||||
|
||||
# Determine if required
|
||||
if param.default is inspect.Parameter.empty:
|
||||
required.append(name)
|
||||
elif hasattr(default_val, "default"):
|
||||
field_default = default_val.default
|
||||
# Required if default is ellipsis or PydanticUndefined
|
||||
if field_default is ... or (
|
||||
hasattr(field_default, "__class__")
|
||||
and "PydanticUndefined" in field_default.__class__.__name__
|
||||
):
|
||||
required.append(name)
|
||||
|
||||
properties[name] = prop
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
|
||||
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
|
||||
"""Register a tool in the registry.
|
||||
|
||||
Handles both raw functions and FastMCP FunctionTool objects.
|
||||
"""
|
||||
# Extract the underlying function from FastMCP FunctionTool if needed
|
||||
if hasattr(tool_or_func, "fn"):
|
||||
func = tool_or_func.fn
|
||||
# Use FunctionTool's description if available
|
||||
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
|
||||
description = tool_or_func.description
|
||||
else:
|
||||
func = tool_or_func
|
||||
|
||||
_tool_registry[name] = {
|
||||
"func": func,
|
||||
"description": description or (func.__doc__ or "").strip(),
|
||||
"schema": _get_tool_schema(func),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/mcp/tools")
|
||||
async def list_mcp_tools() -> dict[str, Any]:
|
||||
"""
|
||||
Return list of available MCP tools with their schemas.
|
||||
|
||||
This endpoint enables tool discovery for the backend MCP client.
|
||||
"""
|
||||
tools = []
|
||||
for name, info in _tool_registry.items():
|
||||
tools.append({
|
||||
"name": name,
|
||||
"description": info["description"],
|
||||
"inputSchema": info["schema"],
|
||||
})
|
||||
|
||||
return {"tools": tools}
|
||||
|
||||
|
||||
@app.post("/mcp")
|
||||
async def mcp_rpc(request: Request) -> JSONResponse:
|
||||
"""
|
||||
JSON-RPC 2.0 endpoint for MCP tool execution.
|
||||
|
||||
Request format:
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "<tool_name>",
|
||||
"params": {...},
|
||||
"id": <request_id>
|
||||
}
|
||||
|
||||
Response format:
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"result": {...},
|
||||
"id": <request_id>
|
||||
}
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32700, "message": f"Parse error: {e}"},
|
||||
"id": None,
|
||||
},
|
||||
)
|
||||
|
||||
# Validate JSON-RPC structure
|
||||
jsonrpc = body.get("jsonrpc")
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
if jsonrpc != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
if not method:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: method is required"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Look up tool
|
||||
tool_info = _tool_registry.get(method)
|
||||
if not tool_info:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Execute tool
|
||||
try:
|
||||
func = tool_info["func"]
|
||||
|
||||
# Resolve Field defaults for missing parameters
|
||||
sig = inspect.signature(func)
|
||||
resolved_params = dict(params)
|
||||
for name, param in sig.parameters.items():
|
||||
if name not in resolved_params:
|
||||
default_val = param.default
|
||||
# Check if it's a FieldInfo with a default value
|
||||
if hasattr(default_val, "default"):
|
||||
field_default = default_val.default
|
||||
# Only use if it has an actual default (not required)
|
||||
if field_default is not ... and not (
|
||||
hasattr(field_default, "__class__")
|
||||
and "PydanticUndefined" in field_default.__class__.__name__
|
||||
):
|
||||
resolved_params[name] = field_default
|
||||
|
||||
result = await func(**resolved_params)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"result": result,
|
||||
"id": request_id,
|
||||
}
|
||||
)
|
||||
except TypeError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32602, "message": f"Invalid params: {e}"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32000, "message": f"Server error: {e}"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# MCP Tools
|
||||
|
||||
|
||||
@@ -261,6 +516,15 @@ async def ingest_content(
|
||||
the LLM Gateway, and stored in pgvector for search.
|
||||
"""
|
||||
try:
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
content_size = len(content.encode("utf-8"))
|
||||
if content_size > settings.max_document_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
||||
}
|
||||
|
||||
# Parse chunk type
|
||||
try:
|
||||
chunk_type_enum = ChunkType(chunk_type.lower())
|
||||
@@ -492,6 +756,15 @@ async def update_document(
|
||||
Replaces all existing chunks for the source path with new content.
|
||||
"""
|
||||
try:
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
content_size = len(content.encode("utf-8"))
|
||||
if content_size > settings.max_document_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
||||
}
|
||||
|
||||
# Parse chunk type
|
||||
try:
|
||||
chunk_type_enum = ChunkType(chunk_type.lower())
|
||||
@@ -550,6 +823,16 @@ async def update_document(
|
||||
}
|
||||
|
||||
|
||||
# Register tools in the JSON-RPC registry
|
||||
# This must happen after tool functions are defined
|
||||
_register_tool("search_knowledge", search_knowledge)
|
||||
_register_tool("ingest_content", ingest_content)
|
||||
_register_tool("delete_content", delete_content)
|
||||
_register_tool("list_collections", list_collections)
|
||||
_register_tool("get_collection_stats", get_collection_stats)
|
||||
_register_tool("update_document", update_document)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the server."""
|
||||
import uvicorn
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for server module and MCP tools."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
@@ -355,3 +357,248 @@ class TestUpdateDocumentTool:
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Invalid chunk type" in result["error"]
|
||||
|
||||
|
||||
class TestMCPToolsEndpoint:
|
||||
"""Tests for /mcp/tools endpoint."""
|
||||
|
||||
def test_list_mcp_tools(self):
|
||||
"""Test listing available MCP tools."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.get("/mcp/tools")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tools" in data
|
||||
assert len(data["tools"]) == 6 # 6 tools registered
|
||||
|
||||
tool_names = [t["name"] for t in data["tools"]]
|
||||
assert "search_knowledge" in tool_names
|
||||
assert "ingest_content" in tool_names
|
||||
assert "delete_content" in tool_names
|
||||
assert "list_collections" in tool_names
|
||||
assert "get_collection_stats" in tool_names
|
||||
assert "update_document" in tool_names
|
||||
|
||||
def test_tool_has_schema(self):
|
||||
"""Test that each tool has input schema."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.get("/mcp/tools")
|
||||
|
||||
data = response.json()
|
||||
for tool in data["tools"]:
|
||||
assert "inputSchema" in tool
|
||||
assert "type" in tool["inputSchema"]
|
||||
assert tool["inputSchema"]["type"] == "object"
|
||||
|
||||
|
||||
class TestMCPRPCEndpoint:
|
||||
"""Tests for /mcp JSON-RPC endpoint."""
|
||||
|
||||
def test_valid_jsonrpc_request(self):
|
||||
"""Test valid JSON-RPC request."""
|
||||
import server
|
||||
from models import SearchResponse, SearchResult
|
||||
|
||||
mock_search = MagicMock()
|
||||
mock_search.search = AsyncMock(
|
||||
return_value=SearchResponse(
|
||||
query="test",
|
||||
search_type="hybrid",
|
||||
results=[
|
||||
SearchResult(
|
||||
id="id-1",
|
||||
content="Test",
|
||||
score=0.9,
|
||||
source_path="/test.py",
|
||||
chunk_type="code",
|
||||
collection="default",
|
||||
)
|
||||
],
|
||||
total_results=1,
|
||||
search_time_ms=5.0,
|
||||
)
|
||||
)
|
||||
server._search = mock_search
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "search_knowledge",
|
||||
"params": {
|
||||
"project_id": "proj-123",
|
||||
"agent_id": "agent-456",
|
||||
"query": "test",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["jsonrpc"] == "2.0"
|
||||
assert data["id"] == 1
|
||||
assert "result" in data
|
||||
assert data["result"]["success"] is True
|
||||
|
||||
def test_invalid_jsonrpc_version(self):
|
||||
"""Test request with invalid JSON-RPC version."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "1.0",
|
||||
"method": "search_knowledge",
|
||||
"params": {},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32600
|
||||
assert "jsonrpc must be '2.0'" in data["error"]["message"]
|
||||
|
||||
def test_missing_method(self):
|
||||
"""Test request without method."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"params": {},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32600
|
||||
assert "method is required" in data["error"]["message"]
|
||||
|
||||
def test_unknown_method(self):
|
||||
"""Test request with unknown method."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "unknown_method",
|
||||
"params": {},
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32601
|
||||
assert "Method not found" in data["error"]["message"]
|
||||
|
||||
def test_invalid_params(self):
|
||||
"""Test request with invalid params."""
|
||||
import server
|
||||
|
||||
client = TestClient(server.app)
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "search_knowledge",
|
||||
"params": {"invalid_param": "value"}, # Missing required params
|
||||
"id": 1,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["error"]["code"] == -32602
|
||||
|
||||
|
||||
class TestContentSizeLimits:
|
||||
"""Tests for content size validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_rejects_oversized_content(self):
|
||||
"""Test that ingest rejects content exceeding size limit."""
|
||||
import server
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
# Create content larger than max size
|
||||
oversized_content = "x" * (settings.max_document_size + 1)
|
||||
|
||||
result = await server.ingest_content.fn(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
content=oversized_content,
|
||||
chunk_type="text",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "exceeds maximum" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_rejects_oversized_content(self):
|
||||
"""Test that update rejects content exceeding size limit."""
|
||||
import server
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
oversized_content = "x" * (settings.max_document_size + 1)
|
||||
|
||||
result = await server.update_document.fn(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
source_path="/test.py",
|
||||
content=oversized_content,
|
||||
chunk_type="text",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "exceeds maximum" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_accepts_valid_size_content(self):
|
||||
"""Test that ingest accepts content within size limit."""
|
||||
import server
|
||||
from models import IngestResult
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.ingest = AsyncMock(
|
||||
return_value=IngestResult(
|
||||
success=True,
|
||||
chunks_created=1,
|
||||
embeddings_generated=1,
|
||||
source_path="/test.py",
|
||||
collection="default",
|
||||
chunk_ids=["id-1"],
|
||||
)
|
||||
)
|
||||
server._collections = mock_collections
|
||||
|
||||
# Small content that's within limits
|
||||
# Pass all parameters to avoid Field default resolution issues
|
||||
result = await server.ingest_content.fn(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
content="def hello(): pass",
|
||||
source_path="/test.py",
|
||||
collection="default",
|
||||
chunk_type="text",
|
||||
file_type=None,
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
Reference in New Issue
Block a user