From 953af52d0e15d5edfb126276ac041037b2cc09d4 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sun, 4 Jan 2026 01:03:58 +0100 Subject: [PATCH] fix(mcp-kb): address critical issues from deep review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix SQL HAVING clause bug by using CTE approach (closes #73) - Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74) - Add /mcp/tools endpoint for tool discovery (closes #75) - Add content size limits to prevent DoS attacks (closes #78) - Add comprehensive tests for new endpoints 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- mcp-servers/knowledge-base/config.py | 14 + mcp-servers/knowledge-base/database.py | 16 +- mcp-servers/knowledge-base/server.py | 287 +++++++++++++++++- .../knowledge-base/tests/test_server.py | 249 ++++++++++++++- 4 files changed, 557 insertions(+), 9 deletions(-) diff --git a/mcp-servers/knowledge-base/config.py b/mcp-servers/knowledge-base/config.py index a174ea2..ad9ed78 100644 --- a/mcp-servers/knowledge-base/config.py +++ b/mcp-servers/knowledge-base/config.py @@ -112,6 +112,20 @@ class Settings(BaseSettings): description="TTL for embedding records in days (0 = no expiry)", ) + # Content size limits (DoS prevention) + max_document_size: int = Field( + default=10 * 1024 * 1024, # 10 MB + description="Maximum size of a single document in bytes", + ) + max_batch_size: int = Field( + default=100, + description="Maximum number of documents in a batch operation", + ) + max_batch_total_size: int = Field( + default=50 * 1024 * 1024, # 50 MB + description="Maximum total size of all documents in a batch", + ) + model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"} diff --git a/mcp-servers/knowledge-base/database.py b/mcp-servers/knowledge-base/database.py index ab739c8..2db9fba 100644 --- a/mcp-servers/knowledge-base/database.py +++ b/mcp-servers/knowledge-base/database.py @@ -345,8 +345,9 @@ class DatabaseManager: """ try: async with self.acquire() as conn: - # Build query with optional filters - query = """ + # Build query with optional filters using CTE to filter by similarity + # We use a CTE to compute similarity once, then filter in outer query + inner_query = """ SELECT id, project_id, collection, content, embedding, chunk_type, source_path, start_line, end_line, @@ -361,18 +362,21 @@ class DatabaseManager: param_idx = 3 if collection: - query += f" AND collection = ${param_idx}" + inner_query += f" AND collection = ${param_idx}" params.append(collection) param_idx += 1 if file_types: file_type_values = [ft.value for ft in file_types] - query += f" AND file_type = ANY(${param_idx})" + inner_query += f" AND file_type = ANY(${param_idx})" params.append(file_type_values) param_idx += 1 - query += f""" - HAVING 1 - (embedding <=> $1) >= ${param_idx} + # Wrap in CTE and filter by threshold in outer query + query = f""" + WITH scored AS ({inner_query}) + SELECT * FROM scored + WHERE similarity >= ${param_idx} ORDER BY similarity DESC LIMIT ${param_idx + 1} """ diff --git a/mcp-servers/knowledge-base/server.py b/mcp-servers/knowledge-base/server.py index 380b2aa..0e2dc9e 100644 --- a/mcp-servers/knowledge-base/server.py +++ b/mcp-servers/knowledge-base/server.py @@ -5,11 +5,13 @@ Provides RAG capabilities with pgvector for semantic search, intelligent chunking, and collection management. 
""" +import inspect import logging from contextlib import asynccontextmanager -from typing import Any +from typing import Any, get_type_hints -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from fastmcp import FastMCP from pydantic import Field @@ -116,6 +118,259 @@ async def health_check() -> dict[str, Any]: return status +# Tool registry for JSON-RPC +_tool_registry: dict[str, Any] = {} + + +def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]: + """Convert Python type annotation to JSON Schema.""" + type_name = getattr(python_type, "__name__", str(python_type)) + + if python_type is str or type_name == "str": + return {"type": "string"} + elif python_type is int or type_name == "int": + return {"type": "integer"} + elif python_type is float or type_name == "float": + return {"type": "number"} + elif python_type is bool or type_name == "bool": + return {"type": "boolean"} + elif type_name == "NoneType": + return {"type": "null"} + elif hasattr(python_type, "__origin__"): + origin = python_type.__origin__ + args = getattr(python_type, "__args__", ()) + + if origin is list: + item_type = args[0] if args else Any + return {"type": "array", "items": _python_type_to_json_schema(item_type)} + elif origin is dict: + return {"type": "object"} + elif origin is type(None) or str(origin) == "typing.Union": + # Handle Optional types (Union with None) + non_none_args = [a for a in args if a is not type(None)] + if len(non_none_args) == 1: + schema = _python_type_to_json_schema(non_none_args[0]) + schema["nullable"] = True + return schema + return {"type": "object"} + return {"type": "object"} + + +def _get_tool_schema(func: Any) -> dict[str, Any]: + """Extract JSON Schema from a tool function.""" + sig = inspect.signature(func) + hints = get_type_hints(func) if hasattr(func, "__annotations__") else {} + + properties: dict[str, Any] = {} + required: list[str] = [] + + for name, param in sig.parameters.items(): + if name in ("self", "cls"): + continue + + prop: dict[str, Any] = {} + + # Get type from hints + if name in hints: + prop = _python_type_to_json_schema(hints[name]) + + # Get description and constraints from Field default (FieldInfo object) + default_val = param.default + if hasattr(default_val, "description") and default_val.description: + prop["description"] = default_val.description + if hasattr(default_val, "ge") and default_val.ge is not None: + prop["minimum"] = default_val.ge + if hasattr(default_val, "le") and default_val.le is not None: + prop["maximum"] = default_val.le + # Handle Field default value (check for PydanticUndefined) + if hasattr(default_val, "default"): + field_default = default_val.default + # Check if it's the "required" sentinel (...) + if field_default is not ... and not ( + hasattr(field_default, "__class__") + and "PydanticUndefined" in field_default.__class__.__name__ + ): + prop["default"] = field_default + + # Determine if required + if param.default is inspect.Parameter.empty: + required.append(name) + elif hasattr(default_val, "default"): + field_default = default_val.default + # Required if default is ellipsis or PydanticUndefined + if field_default is ... 
or (
+                hasattr(field_default, "__class__")
+                and "PydanticUndefined" in field_default.__class__.__name__
+            ):
+                required.append(name)
+
+        properties[name] = prop
+
+    return {
+        "type": "object",
+        "properties": properties,
+        "required": required,
+    }
+
+
+def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
+    """Register a tool in the registry.
+
+    Handles both raw functions and FastMCP FunctionTool objects.
+    """
+    # Extract the underlying function from FastMCP FunctionTool if needed
+    if hasattr(tool_or_func, "fn"):
+        func = tool_or_func.fn
+        # Use FunctionTool's description if available
+        if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
+            description = tool_or_func.description
+    else:
+        func = tool_or_func
+
+    _tool_registry[name] = {
+        "func": func,
+        "description": description or (func.__doc__ or "").strip(),
+        "schema": _get_tool_schema(func),
+    }
+
+
+@app.get("/mcp/tools")
+async def list_mcp_tools() -> dict[str, Any]:
+    """
+    Return list of available MCP tools with their schemas.
+
+    This endpoint enables tool discovery for the backend MCP client.
+    """
+    tools = []
+    for name, info in _tool_registry.items():
+        tools.append({
+            "name": name,
+            "description": info["description"],
+            "inputSchema": info["schema"],
+        })
+
+    return {"tools": tools}
+
+
+@app.post("/mcp")
+async def mcp_rpc(request: Request) -> JSONResponse:
+    """
+    JSON-RPC 2.0 endpoint for MCP tool execution.
+
+    Request format:
+    {
+        "jsonrpc": "2.0",
+        "method": "<tool_name>",
+        "params": {...},
+        "id": <request_id>
+    }
+
+    Response format:
+    {
+        "jsonrpc": "2.0",
+        "result": {...},
+        "id": <request_id>
+    }
+    """
+    try:
+        body = await request.json()
+    except Exception as e:
+        return JSONResponse(
+            status_code=400,
+            content={
+                "jsonrpc": "2.0",
+                "error": {"code": -32700, "message": f"Parse error: {e}"},
+                "id": None,
+            },
+        )
+
+    # Validate JSON-RPC structure
+    jsonrpc = body.get("jsonrpc")
+    method = body.get("method")
+    params = body.get("params", {})
+    request_id = body.get("id")
+
+    if jsonrpc != "2.0":
+        return JSONResponse(
+            status_code=400,
+            content={
+                "jsonrpc": "2.0",
+                "error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
+                "id": request_id,
+            },
+        )
+
+    if not method:
+        return JSONResponse(
+            status_code=400,
+            content={
+                "jsonrpc": "2.0",
+                "error": {"code": -32600, "message": "Invalid Request: method is required"},
+                "id": request_id,
+            },
+        )
+
+    # Look up tool
+    tool_info = _tool_registry.get(method)
+    if not tool_info:
+        return JSONResponse(
+            status_code=404,
+            content={
+                "jsonrpc": "2.0",
+                "error": {"code": -32601, "message": f"Method not found: {method}"},
+                "id": request_id,
+            },
+        )
+
+    # Execute tool
+    try:
+        func = tool_info["func"]
+
+        # Resolve Field defaults for missing parameters
+        sig = inspect.signature(func)
+        resolved_params = dict(params)
+        for name, param in sig.parameters.items():
+            if name not in resolved_params:
+                default_val = param.default
+                # Check if it's a FieldInfo with a default value
+                if hasattr(default_val, "default"):
+                    field_default = default_val.default
+                    # Only use if it has an actual default (not required)
+                    if field_default is not ... 
and not ( + hasattr(field_default, "__class__") + and "PydanticUndefined" in field_default.__class__.__name__ + ): + resolved_params[name] = field_default + + result = await func(**resolved_params) + return JSONResponse( + content={ + "jsonrpc": "2.0", + "result": result, + "id": request_id, + } + ) + except TypeError as e: + return JSONResponse( + status_code=400, + content={ + "jsonrpc": "2.0", + "error": {"code": -32602, "message": f"Invalid params: {e}"}, + "id": request_id, + }, + ) + except Exception as e: + logger.error(f"Tool execution error: {e}") + return JSONResponse( + status_code=500, + content={ + "jsonrpc": "2.0", + "error": {"code": -32000, "message": f"Server error: {e}"}, + "id": request_id, + }, + ) + + # MCP Tools @@ -261,6 +516,15 @@ async def ingest_content( the LLM Gateway, and stored in pgvector for search. """ try: + # Validate content size to prevent DoS + settings = get_settings() + content_size = len(content.encode("utf-8")) + if content_size > settings.max_document_size: + return { + "success": False, + "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)", + } + # Parse chunk type try: chunk_type_enum = ChunkType(chunk_type.lower()) @@ -492,6 +756,15 @@ async def update_document( Replaces all existing chunks for the source path with new content. """ try: + # Validate content size to prevent DoS + settings = get_settings() + content_size = len(content.encode("utf-8")) + if content_size > settings.max_document_size: + return { + "success": False, + "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)", + } + # Parse chunk type try: chunk_type_enum = ChunkType(chunk_type.lower()) @@ -550,6 +823,16 @@ async def update_document( } +# Register tools in the JSON-RPC registry +# This must happen after tool functions are defined +_register_tool("search_knowledge", search_knowledge) +_register_tool("ingest_content", ingest_content) +_register_tool("delete_content", delete_content) +_register_tool("list_collections", list_collections) +_register_tool("get_collection_stats", get_collection_stats) +_register_tool("update_document", update_document) + + def main() -> None: """Run the server.""" import uvicorn diff --git a/mcp-servers/knowledge-base/tests/test_server.py b/mcp-servers/knowledge-base/tests/test_server.py index 3c6ffff..2dcc771 100644 --- a/mcp-servers/knowledge-base/tests/test_server.py +++ b/mcp-servers/knowledge-base/tests/test_server.py @@ -1,9 +1,11 @@ """Tests for server module and MCP tools.""" +import json from datetime import UTC, datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from fastapi.testclient import TestClient class TestHealthCheck: @@ -355,3 +357,248 @@ class TestUpdateDocumentTool: assert result["success"] is False assert "Invalid chunk type" in result["error"] + + +class TestMCPToolsEndpoint: + """Tests for /mcp/tools endpoint.""" + + def test_list_mcp_tools(self): + """Test listing available MCP tools.""" + import server + + client = TestClient(server.app) + response = client.get("/mcp/tools") + + assert response.status_code == 200 + data = response.json() + assert "tools" in data + assert len(data["tools"]) == 6 # 6 tools registered + + tool_names = [t["name"] for t in data["tools"]] + assert "search_knowledge" in tool_names + assert "ingest_content" in tool_names + assert "delete_content" in tool_names + assert "list_collections" in tool_names + assert 
"get_collection_stats" in tool_names + assert "update_document" in tool_names + + def test_tool_has_schema(self): + """Test that each tool has input schema.""" + import server + + client = TestClient(server.app) + response = client.get("/mcp/tools") + + data = response.json() + for tool in data["tools"]: + assert "inputSchema" in tool + assert "type" in tool["inputSchema"] + assert tool["inputSchema"]["type"] == "object" + + +class TestMCPRPCEndpoint: + """Tests for /mcp JSON-RPC endpoint.""" + + def test_valid_jsonrpc_request(self): + """Test valid JSON-RPC request.""" + import server + from models import SearchResponse, SearchResult + + mock_search = MagicMock() + mock_search.search = AsyncMock( + return_value=SearchResponse( + query="test", + search_type="hybrid", + results=[ + SearchResult( + id="id-1", + content="Test", + score=0.9, + source_path="/test.py", + chunk_type="code", + collection="default", + ) + ], + total_results=1, + search_time_ms=5.0, + ) + ) + server._search = mock_search + + client = TestClient(server.app) + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "search_knowledge", + "params": { + "project_id": "proj-123", + "agent_id": "agent-456", + "query": "test", + }, + "id": 1, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == 1 + assert "result" in data + assert data["result"]["success"] is True + + def test_invalid_jsonrpc_version(self): + """Test request with invalid JSON-RPC version.""" + import server + + client = TestClient(server.app) + response = client.post( + "/mcp", + json={ + "jsonrpc": "1.0", + "method": "search_knowledge", + "params": {}, + "id": 1, + }, + ) + + assert response.status_code == 400 + data = response.json() + assert data["error"]["code"] == -32600 + assert "jsonrpc must be '2.0'" in data["error"]["message"] + + def test_missing_method(self): + """Test request without method.""" + import server + + client = TestClient(server.app) + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "params": {}, + "id": 1, + }, + ) + + assert response.status_code == 400 + data = response.json() + assert data["error"]["code"] == -32600 + assert "method is required" in data["error"]["message"] + + def test_unknown_method(self): + """Test request with unknown method.""" + import server + + client = TestClient(server.app) + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "unknown_method", + "params": {}, + "id": 1, + }, + ) + + assert response.status_code == 404 + data = response.json() + assert data["error"]["code"] == -32601 + assert "Method not found" in data["error"]["message"] + + def test_invalid_params(self): + """Test request with invalid params.""" + import server + + client = TestClient(server.app) + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "search_knowledge", + "params": {"invalid_param": "value"}, # Missing required params + "id": 1, + }, + ) + + assert response.status_code == 400 + data = response.json() + assert data["error"]["code"] == -32602 + + +class TestContentSizeLimits: + """Tests for content size validation.""" + + @pytest.mark.asyncio + async def test_ingest_rejects_oversized_content(self): + """Test that ingest rejects content exceeding size limit.""" + import server + from config import get_settings + + settings = get_settings() + # Create content larger than max size + oversized_content = "x" * (settings.max_document_size + 1) + + result = await 
server.ingest_content.fn( + project_id="proj-123", + agent_id="agent-456", + content=oversized_content, + chunk_type="text", + ) + + assert result["success"] is False + assert "exceeds maximum" in result["error"] + + @pytest.mark.asyncio + async def test_update_rejects_oversized_content(self): + """Test that update rejects content exceeding size limit.""" + import server + from config import get_settings + + settings = get_settings() + oversized_content = "x" * (settings.max_document_size + 1) + + result = await server.update_document.fn( + project_id="proj-123", + agent_id="agent-456", + source_path="/test.py", + content=oversized_content, + chunk_type="text", + ) + + assert result["success"] is False + assert "exceeds maximum" in result["error"] + + @pytest.mark.asyncio + async def test_ingest_accepts_valid_size_content(self): + """Test that ingest accepts content within size limit.""" + import server + from models import IngestResult + + mock_collections = MagicMock() + mock_collections.ingest = AsyncMock( + return_value=IngestResult( + success=True, + chunks_created=1, + embeddings_generated=1, + source_path="/test.py", + collection="default", + chunk_ids=["id-1"], + ) + ) + server._collections = mock_collections + + # Small content that's within limits + # Pass all parameters to avoid Field default resolution issues + result = await server.ingest_content.fn( + project_id="proj-123", + agent_id="agent-456", + content="def hello(): pass", + source_path="/test.py", + collection="default", + chunk_type="text", + file_type=None, + metadata=None, + ) + + assert result["success"] is True
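
For reviewers, a minimal smoke-test sketch against the new endpoints (not part of the patch). It assumes a locally running server and the default 10 MB KB_MAX_DOCUMENT_SIZE; BASE_URL, the project/agent IDs, and the query text are placeholders, and httpx stands in for any HTTP client:

    """Illustrative client for the /mcp/tools and /mcp endpoints."""

    import httpx

    BASE_URL = "http://localhost:8000"  # placeholder; adjust to your deployment


    def main() -> None:
        # 1. Tool discovery: GET /mcp/tools returns name, description, and
        #    inputSchema for each of the six registered tools.
        tools = httpx.get(f"{BASE_URL}/mcp/tools").json()["tools"]
        print("available tools:", [t["name"] for t in tools])

        # 2. JSON-RPC 2.0 call: "method" is the tool name, "params" the tool kwargs.
        response = httpx.post(
            f"{BASE_URL}/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "query": "how is content chunked?",
                },
                "id": 1,
            },
        )
        body = response.json()
        if "error" in body:
            # -32601: unknown tool, -32602: bad params, -32000: tool raised
            print("rpc error:", body["error"])
        else:
            print("rpc result:", body["result"])

        # 3. Size limit: ingest_content rejects documents larger than
        #    KB_MAX_DOCUMENT_SIZE. Validation happens inside the tool, so the
        #    RPC call itself succeeds and the rejection arrives as
        #    result.success == False rather than an RPC-level error.
        oversized = httpx.post(
            f"{BASE_URL}/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "ingest_content",
                "params": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "content": "x" * (10 * 1024 * 1024 + 1),
                    "chunk_type": "text",
                },
                "id": 2,
            },
        )
        print("oversized ingest:", oversized.json()["result"]["error"])


    if __name__ == "__main__":
        main()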