fix(mcp-kb): address critical issues from deep review

- Fix SQL HAVING clause bug by using CTE approach (closes #73)
- Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74)
- Add /mcp/tools endpoint for tool discovery (closes #75)
- Add content size limits to prevent DoS attacks (closes #78)
- Add comprehensive tests for new endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-04 01:03:58 +01:00
parent e6e98d4ed1
commit 953af52d0e
4 changed files with 557 additions and 9 deletions

View File

@@ -112,6 +112,20 @@ class Settings(BaseSettings):
description="TTL for embedding records in days (0 = no expiry)",
)
# Content size limits (DoS prevention)
max_document_size: int = Field(
default=10 * 1024 * 1024, # 10 MB
description="Maximum size of a single document in bytes",
)
max_batch_size: int = Field(
default=100,
description="Maximum number of documents in a batch operation",
)
max_batch_total_size: int = Field(
default=50 * 1024 * 1024, # 50 MB
description="Maximum total size of all documents in a batch",
)
model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"}

View File

@@ -345,8 +345,9 @@ class DatabaseManager:
"""
try:
async with self.acquire() as conn:
# Build query with optional filters
query = """
# Build query with optional filters using CTE to filter by similarity
# We use a CTE to compute similarity once, then filter in outer query
inner_query = """
SELECT
id, project_id, collection, content, embedding,
chunk_type, source_path, start_line, end_line,
@@ -361,18 +362,21 @@ class DatabaseManager:
param_idx = 3
if collection:
query += f" AND collection = ${param_idx}"
inner_query += f" AND collection = ${param_idx}"
params.append(collection)
param_idx += 1
if file_types:
file_type_values = [ft.value for ft in file_types]
query += f" AND file_type = ANY(${param_idx})"
inner_query += f" AND file_type = ANY(${param_idx})"
params.append(file_type_values)
param_idx += 1
query += f"""
HAVING 1 - (embedding <=> $1) >= ${param_idx}
# Wrap in CTE and filter by threshold in outer query
query = f"""
WITH scored AS ({inner_query})
SELECT * FROM scored
WHERE similarity >= ${param_idx}
ORDER BY similarity DESC
LIMIT ${param_idx + 1}
"""

View File

@@ -5,11 +5,13 @@ Provides RAG capabilities with pgvector for semantic search,
intelligent chunking, and collection management.
"""
import inspect
import logging
from contextlib import asynccontextmanager
from typing import Any
from typing import Any, get_type_hints
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastmcp import FastMCP
from pydantic import Field
@@ -116,6 +118,259 @@ async def health_check() -> dict[str, Any]:
return status
# Tool registry for JSON-RPC
_tool_registry: dict[str, Any] = {}
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
"""Convert Python type annotation to JSON Schema."""
type_name = getattr(python_type, "__name__", str(python_type))
if python_type is str or type_name == "str":
return {"type": "string"}
elif python_type is int or type_name == "int":
return {"type": "integer"}
elif python_type is float or type_name == "float":
return {"type": "number"}
elif python_type is bool or type_name == "bool":
return {"type": "boolean"}
elif type_name == "NoneType":
return {"type": "null"}
elif hasattr(python_type, "__origin__"):
origin = python_type.__origin__
args = getattr(python_type, "__args__", ())
if origin is list:
item_type = args[0] if args else Any
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
elif origin is dict:
return {"type": "object"}
elif origin is type(None) or str(origin) == "typing.Union":
# Handle Optional types (Union with None)
non_none_args = [a for a in args if a is not type(None)]
if len(non_none_args) == 1:
schema = _python_type_to_json_schema(non_none_args[0])
schema["nullable"] = True
return schema
return {"type": "object"}
return {"type": "object"}
def _get_tool_schema(func: Any) -> dict[str, Any]:
"""Extract JSON Schema from a tool function."""
sig = inspect.signature(func)
hints = get_type_hints(func) if hasattr(func, "__annotations__") else {}
properties: dict[str, Any] = {}
required: list[str] = []
for name, param in sig.parameters.items():
if name in ("self", "cls"):
continue
prop: dict[str, Any] = {}
# Get type from hints
if name in hints:
prop = _python_type_to_json_schema(hints[name])
# Get description and constraints from Field default (FieldInfo object)
default_val = param.default
if hasattr(default_val, "description") and default_val.description:
prop["description"] = default_val.description
if hasattr(default_val, "ge") and default_val.ge is not None:
prop["minimum"] = default_val.ge
if hasattr(default_val, "le") and default_val.le is not None:
prop["maximum"] = default_val.le
# Handle Field default value (check for PydanticUndefined)
if hasattr(default_val, "default"):
field_default = default_val.default
# Check if it's the "required" sentinel (...)
if field_default is not ... and not (
hasattr(field_default, "__class__")
and "PydanticUndefined" in field_default.__class__.__name__
):
prop["default"] = field_default
# Determine if required
if param.default is inspect.Parameter.empty:
required.append(name)
elif hasattr(default_val, "default"):
field_default = default_val.default
# Required if default is ellipsis or PydanticUndefined
if field_default is ... or (
hasattr(field_default, "__class__")
and "PydanticUndefined" in field_default.__class__.__name__
):
required.append(name)
properties[name] = prop
return {
"type": "object",
"properties": properties,
"required": required,
}
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
"""Register a tool in the registry.
Handles both raw functions and FastMCP FunctionTool objects.
"""
# Extract the underlying function from FastMCP FunctionTool if needed
if hasattr(tool_or_func, "fn"):
func = tool_or_func.fn
# Use FunctionTool's description if available
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
description = tool_or_func.description
else:
func = tool_or_func
_tool_registry[name] = {
"func": func,
"description": description or (func.__doc__ or "").strip(),
"schema": _get_tool_schema(func),
}
@app.get("/mcp/tools")
async def list_mcp_tools() -> dict[str, Any]:
"""
Return list of available MCP tools with their schemas.
This endpoint enables tool discovery for the backend MCP client.
"""
tools = []
for name, info in _tool_registry.items():
tools.append({
"name": name,
"description": info["description"],
"inputSchema": info["schema"],
})
return {"tools": tools}
@app.post("/mcp")
async def mcp_rpc(request: Request) -> JSONResponse:
"""
JSON-RPC 2.0 endpoint for MCP tool execution.
Request format:
{
"jsonrpc": "2.0",
"method": "<tool_name>",
"params": {...},
"id": <request_id>
}
Response format:
{
"jsonrpc": "2.0",
"result": {...},
"id": <request_id>
}
"""
try:
body = await request.json()
except Exception as e:
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32700, "message": f"Parse error: {e}"},
"id": None,
},
)
# Validate JSON-RPC structure
jsonrpc = body.get("jsonrpc")
method = body.get("method")
params = body.get("params", {})
request_id = body.get("id")
if jsonrpc != "2.0":
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
"id": request_id,
},
)
if not method:
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: method is required"},
"id": request_id,
},
)
# Look up tool
tool_info = _tool_registry.get(method)
if not tool_info:
return JSONResponse(
status_code=404,
content={
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Method not found: {method}"},
"id": request_id,
},
)
# Execute tool
try:
func = tool_info["func"]
# Resolve Field defaults for missing parameters
sig = inspect.signature(func)
resolved_params = dict(params)
for name, param in sig.parameters.items():
if name not in resolved_params:
default_val = param.default
# Check if it's a FieldInfo with a default value
if hasattr(default_val, "default"):
field_default = default_val.default
# Only use if it has an actual default (not required)
if field_default is not ... and not (
hasattr(field_default, "__class__")
and "PydanticUndefined" in field_default.__class__.__name__
):
resolved_params[name] = field_default
result = await func(**resolved_params)
return JSONResponse(
content={
"jsonrpc": "2.0",
"result": result,
"id": request_id,
}
)
except TypeError as e:
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32602, "message": f"Invalid params: {e}"},
"id": request_id,
},
)
except Exception as e:
logger.error(f"Tool execution error: {e}")
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"error": {"code": -32000, "message": f"Server error: {e}"},
"id": request_id,
},
)
# MCP Tools
@@ -261,6 +516,15 @@ async def ingest_content(
the LLM Gateway, and stored in pgvector for search.
"""
try:
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
if content_size > settings.max_document_size:
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
}
# Parse chunk type
try:
chunk_type_enum = ChunkType(chunk_type.lower())
@@ -492,6 +756,15 @@ async def update_document(
Replaces all existing chunks for the source path with new content.
"""
try:
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
if content_size > settings.max_document_size:
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
}
# Parse chunk type
try:
chunk_type_enum = ChunkType(chunk_type.lower())
@@ -550,6 +823,16 @@ async def update_document(
}
# Register tools in the JSON-RPC registry
# This must happen after tool functions are defined
_register_tool("search_knowledge", search_knowledge)
_register_tool("ingest_content", ingest_content)
_register_tool("delete_content", delete_content)
_register_tool("list_collections", list_collections)
_register_tool("get_collection_stats", get_collection_stats)
_register_tool("update_document", update_document)
def main() -> None:
"""Run the server."""
import uvicorn

View File

@@ -1,9 +1,11 @@
"""Tests for server module and MCP tools."""
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
class TestHealthCheck:
@@ -355,3 +357,248 @@ class TestUpdateDocumentTool:
assert result["success"] is False
assert "Invalid chunk type" in result["error"]
class TestMCPToolsEndpoint:
"""Tests for /mcp/tools endpoint."""
def test_list_mcp_tools(self):
"""Test listing available MCP tools."""
import server
client = TestClient(server.app)
response = client.get("/mcp/tools")
assert response.status_code == 200
data = response.json()
assert "tools" in data
assert len(data["tools"]) == 6 # 6 tools registered
tool_names = [t["name"] for t in data["tools"]]
assert "search_knowledge" in tool_names
assert "ingest_content" in tool_names
assert "delete_content" in tool_names
assert "list_collections" in tool_names
assert "get_collection_stats" in tool_names
assert "update_document" in tool_names
def test_tool_has_schema(self):
"""Test that each tool has input schema."""
import server
client = TestClient(server.app)
response = client.get("/mcp/tools")
data = response.json()
for tool in data["tools"]:
assert "inputSchema" in tool
assert "type" in tool["inputSchema"]
assert tool["inputSchema"]["type"] == "object"
class TestMCPRPCEndpoint:
"""Tests for /mcp JSON-RPC endpoint."""
def test_valid_jsonrpc_request(self):
"""Test valid JSON-RPC request."""
import server
from models import SearchResponse, SearchResult
mock_search = MagicMock()
mock_search.search = AsyncMock(
return_value=SearchResponse(
query="test",
search_type="hybrid",
results=[
SearchResult(
id="id-1",
content="Test",
score=0.9,
source_path="/test.py",
chunk_type="code",
collection="default",
)
],
total_results=1,
search_time_ms=5.0,
)
)
server._search = mock_search
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "search_knowledge",
"params": {
"project_id": "proj-123",
"agent_id": "agent-456",
"query": "test",
},
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert data["jsonrpc"] == "2.0"
assert data["id"] == 1
assert "result" in data
assert data["result"]["success"] is True
def test_invalid_jsonrpc_version(self):
"""Test request with invalid JSON-RPC version."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "1.0",
"method": "search_knowledge",
"params": {},
"id": 1,
},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32600
assert "jsonrpc must be '2.0'" in data["error"]["message"]
def test_missing_method(self):
"""Test request without method."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"params": {},
"id": 1,
},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32600
assert "method is required" in data["error"]["message"]
def test_unknown_method(self):
"""Test request with unknown method."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "unknown_method",
"params": {},
"id": 1,
},
)
assert response.status_code == 404
data = response.json()
assert data["error"]["code"] == -32601
assert "Method not found" in data["error"]["message"]
def test_invalid_params(self):
"""Test request with invalid params."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "search_knowledge",
"params": {"invalid_param": "value"}, # Missing required params
"id": 1,
},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32602
class TestContentSizeLimits:
"""Tests for content size validation."""
@pytest.mark.asyncio
async def test_ingest_rejects_oversized_content(self):
"""Test that ingest rejects content exceeding size limit."""
import server
from config import get_settings
settings = get_settings()
# Create content larger than max size
oversized_content = "x" * (settings.max_document_size + 1)
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content=oversized_content,
chunk_type="text",
)
assert result["success"] is False
assert "exceeds maximum" in result["error"]
@pytest.mark.asyncio
async def test_update_rejects_oversized_content(self):
"""Test that update rejects content exceeding size limit."""
import server
from config import get_settings
settings = get_settings()
oversized_content = "x" * (settings.max_document_size + 1)
result = await server.update_document.fn(
project_id="proj-123",
agent_id="agent-456",
source_path="/test.py",
content=oversized_content,
chunk_type="text",
)
assert result["success"] is False
assert "exceeds maximum" in result["error"]
@pytest.mark.asyncio
async def test_ingest_accepts_valid_size_content(self):
"""Test that ingest accepts content within size limit."""
import server
from models import IngestResult
mock_collections = MagicMock()
mock_collections.ingest = AsyncMock(
return_value=IngestResult(
success=True,
chunks_created=1,
embeddings_generated=1,
source_path="/test.py",
collection="default",
chunk_ids=["id-1"],
)
)
server._collections = mock_collections
# Small content that's within limits
# Pass all parameters to avoid Field default resolution issues
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content="def hello(): pass",
source_path="/test.py",
collection="default",
chunk_type="text",
file_type=None,
metadata=None,
)
assert result["success"] is True