fix(mcp-kb): address critical issues from deep review
- Fix SQL HAVING clause bug by using CTE approach (closes #73) - Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74) - Add /mcp/tools endpoint for tool discovery (closes #75) - Add content size limits to prevent DoS attacks (closes #78) - Add comprehensive tests for new endpoints 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -5,11 +5,13 @@ Provides RAG capabilities with pgvector for semantic search,
|
||||
intelligent chunking, and collection management.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastmcp import FastMCP
|
||||
from pydantic import Field
|
||||
|
||||
@@ -116,6 +118,259 @@ async def health_check() -> dict[str, Any]:
|
||||
return status
|
||||
|
||||
|
||||
# Tool registry for JSON-RPC
|
||||
_tool_registry: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
|
||||
"""Convert Python type annotation to JSON Schema."""
|
||||
type_name = getattr(python_type, "__name__", str(python_type))
|
||||
|
||||
if python_type is str or type_name == "str":
|
||||
return {"type": "string"}
|
||||
elif python_type is int or type_name == "int":
|
||||
return {"type": "integer"}
|
||||
elif python_type is float or type_name == "float":
|
||||
return {"type": "number"}
|
||||
elif python_type is bool or type_name == "bool":
|
||||
return {"type": "boolean"}
|
||||
elif type_name == "NoneType":
|
||||
return {"type": "null"}
|
||||
elif hasattr(python_type, "__origin__"):
|
||||
origin = python_type.__origin__
|
||||
args = getattr(python_type, "__args__", ())
|
||||
|
||||
if origin is list:
|
||||
item_type = args[0] if args else Any
|
||||
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
|
||||
elif origin is dict:
|
||||
return {"type": "object"}
|
||||
elif origin is type(None) or str(origin) == "typing.Union":
|
||||
# Handle Optional types (Union with None)
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
schema = _python_type_to_json_schema(non_none_args[0])
|
||||
schema["nullable"] = True
|
||||
return schema
|
||||
return {"type": "object"}
|
||||
return {"type": "object"}
|
||||
|
||||
|
||||
def _get_tool_schema(func: Any) -> dict[str, Any]:
|
||||
"""Extract JSON Schema from a tool function."""
|
||||
sig = inspect.signature(func)
|
||||
hints = get_type_hints(func) if hasattr(func, "__annotations__") else {}
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
if name in ("self", "cls"):
|
||||
continue
|
||||
|
||||
prop: dict[str, Any] = {}
|
||||
|
||||
# Get type from hints
|
||||
if name in hints:
|
||||
prop = _python_type_to_json_schema(hints[name])
|
||||
|
||||
# Get description and constraints from Field default (FieldInfo object)
|
||||
default_val = param.default
|
||||
if hasattr(default_val, "description") and default_val.description:
|
||||
prop["description"] = default_val.description
|
||||
if hasattr(default_val, "ge") and default_val.ge is not None:
|
||||
prop["minimum"] = default_val.ge
|
||||
if hasattr(default_val, "le") and default_val.le is not None:
|
||||
prop["maximum"] = default_val.le
|
||||
# Handle Field default value (check for PydanticUndefined)
|
||||
if hasattr(default_val, "default"):
|
||||
field_default = default_val.default
|
||||
# Check if it's the "required" sentinel (...)
|
||||
if field_default is not ... and not (
|
||||
hasattr(field_default, "__class__")
|
||||
and "PydanticUndefined" in field_default.__class__.__name__
|
||||
):
|
||||
prop["default"] = field_default
|
||||
|
||||
# Determine if required
|
||||
if param.default is inspect.Parameter.empty:
|
||||
required.append(name)
|
||||
elif hasattr(default_val, "default"):
|
||||
field_default = default_val.default
|
||||
# Required if default is ellipsis or PydanticUndefined
|
||||
if field_default is ... or (
|
||||
hasattr(field_default, "__class__")
|
||||
and "PydanticUndefined" in field_default.__class__.__name__
|
||||
):
|
||||
required.append(name)
|
||||
|
||||
properties[name] = prop
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
|
||||
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
|
||||
"""Register a tool in the registry.
|
||||
|
||||
Handles both raw functions and FastMCP FunctionTool objects.
|
||||
"""
|
||||
# Extract the underlying function from FastMCP FunctionTool if needed
|
||||
if hasattr(tool_or_func, "fn"):
|
||||
func = tool_or_func.fn
|
||||
# Use FunctionTool's description if available
|
||||
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
|
||||
description = tool_or_func.description
|
||||
else:
|
||||
func = tool_or_func
|
||||
|
||||
_tool_registry[name] = {
|
||||
"func": func,
|
||||
"description": description or (func.__doc__ or "").strip(),
|
||||
"schema": _get_tool_schema(func),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/mcp/tools")
|
||||
async def list_mcp_tools() -> dict[str, Any]:
|
||||
"""
|
||||
Return list of available MCP tools with their schemas.
|
||||
|
||||
This endpoint enables tool discovery for the backend MCP client.
|
||||
"""
|
||||
tools = []
|
||||
for name, info in _tool_registry.items():
|
||||
tools.append({
|
||||
"name": name,
|
||||
"description": info["description"],
|
||||
"inputSchema": info["schema"],
|
||||
})
|
||||
|
||||
return {"tools": tools}
|
||||
|
||||
|
||||
@app.post("/mcp")
|
||||
async def mcp_rpc(request: Request) -> JSONResponse:
|
||||
"""
|
||||
JSON-RPC 2.0 endpoint for MCP tool execution.
|
||||
|
||||
Request format:
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "<tool_name>",
|
||||
"params": {...},
|
||||
"id": <request_id>
|
||||
}
|
||||
|
||||
Response format:
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"result": {...},
|
||||
"id": <request_id>
|
||||
}
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32700, "message": f"Parse error: {e}"},
|
||||
"id": None,
|
||||
},
|
||||
)
|
||||
|
||||
# Validate JSON-RPC structure
|
||||
jsonrpc = body.get("jsonrpc")
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
if jsonrpc != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
if not method:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: method is required"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Look up tool
|
||||
tool_info = _tool_registry.get(method)
|
||||
if not tool_info:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Execute tool
|
||||
try:
|
||||
func = tool_info["func"]
|
||||
|
||||
# Resolve Field defaults for missing parameters
|
||||
sig = inspect.signature(func)
|
||||
resolved_params = dict(params)
|
||||
for name, param in sig.parameters.items():
|
||||
if name not in resolved_params:
|
||||
default_val = param.default
|
||||
# Check if it's a FieldInfo with a default value
|
||||
if hasattr(default_val, "default"):
|
||||
field_default = default_val.default
|
||||
# Only use if it has an actual default (not required)
|
||||
if field_default is not ... and not (
|
||||
hasattr(field_default, "__class__")
|
||||
and "PydanticUndefined" in field_default.__class__.__name__
|
||||
):
|
||||
resolved_params[name] = field_default
|
||||
|
||||
result = await func(**resolved_params)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"result": result,
|
||||
"id": request_id,
|
||||
}
|
||||
)
|
||||
except TypeError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32602, "message": f"Invalid params: {e}"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32000, "message": f"Server error: {e}"},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# MCP Tools
|
||||
|
||||
|
||||
@@ -261,6 +516,15 @@ async def ingest_content(
|
||||
the LLM Gateway, and stored in pgvector for search.
|
||||
"""
|
||||
try:
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
content_size = len(content.encode("utf-8"))
|
||||
if content_size > settings.max_document_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
||||
}
|
||||
|
||||
# Parse chunk type
|
||||
try:
|
||||
chunk_type_enum = ChunkType(chunk_type.lower())
|
||||
@@ -492,6 +756,15 @@ async def update_document(
|
||||
Replaces all existing chunks for the source path with new content.
|
||||
"""
|
||||
try:
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
content_size = len(content.encode("utf-8"))
|
||||
if content_size > settings.max_document_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
||||
}
|
||||
|
||||
# Parse chunk type
|
||||
try:
|
||||
chunk_type_enum = ChunkType(chunk_type.lower())
|
||||
@@ -550,6 +823,16 @@ async def update_document(
|
||||
}
|
||||
|
||||
|
||||
# Register tools in the JSON-RPC registry
|
||||
# This must happen after tool functions are defined
|
||||
_register_tool("search_knowledge", search_knowledge)
|
||||
_register_tool("ingest_content", ingest_content)
|
||||
_register_tool("delete_content", delete_content)
|
||||
_register_tool("list_collections", list_collections)
|
||||
_register_tool("get_collection_stats", get_collection_stats)
|
||||
_register_tool("update_document", update_document)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the server."""
|
||||
import uvicorn
|
||||
|
||||
Reference in New Issue
Block a user