forked from cardosofelipe/fast-next-template
Security fixes from deep review: - Add input validation patterns for project_id, agent_id, collection - Add path traversal protection for source_path (reject .., null bytes) - Add error codes (INTERNAL_ERROR) to generic exception handlers - Handle FieldInfo objects in validation for test robustness Performance fixes: - Enable concurrent hybrid search with asyncio.gather Health endpoint improvements: - Check all dependencies (database, Redis, LLM Gateway) - Return degraded/unhealthy status based on dependency health - Updated tests for new health check response structure All 139 tests pass. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1034 lines
34 KiB
Python
1034 lines
34 KiB
Python
"""
|
|
Knowledge Base MCP Server.
|
|
|
|
Provides RAG capabilities with pgvector for semantic search,
|
|
intelligent chunking, and collection management.
|
|
"""
|
|
|
|
import inspect
|
|
import logging
|
|
import re
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, get_type_hints
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from fastmcp import FastMCP
|
|
from pydantic import Field
|
|
|
|
from collection_manager import CollectionManager, get_collection_manager
|
|
from collections.abc import AsyncIterator
|
|
from config import get_settings
|
|
from database import DatabaseManager, get_database_manager
|
|
from embeddings import EmbeddingGenerator, get_embedding_generator
|
|
from exceptions import ErrorCode, KnowledgeBaseError
|
|
from models import (
|
|
ChunkType,
|
|
DeleteRequest,
|
|
FileType,
|
|
IngestRequest,
|
|
SearchRequest,
|
|
SearchType,
|
|
)
|
|
from search import SearchEngine, get_search_engine
|
|
|
|
# Input validation patterns
|
|
# Allow alphanumeric, hyphens, underscores (1-128 chars)
|
|
ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
|
|
# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
|
|
COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
|
|
|
|
|
def _validate_id(value: str, field_name: str) -> str | None:
|
|
"""Validate project_id or agent_id format.
|
|
|
|
Returns error message if invalid, None if valid.
|
|
"""
|
|
# Handle FieldInfo objects from direct .fn() calls in tests
|
|
if not isinstance(value, str):
|
|
return f"{field_name} must be a string"
|
|
if not value:
|
|
return f"{field_name} is required"
|
|
if not ID_PATTERN.match(value):
|
|
return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
|
|
return None
|
|
|
|
|
|
def _validate_collection(value: str) -> str | None:
|
|
"""Validate collection name format.
|
|
|
|
Returns error message if invalid, None if valid.
|
|
"""
|
|
# Handle FieldInfo objects from direct .fn() calls in tests
|
|
if not isinstance(value, str):
|
|
return None # Non-string means default not resolved, skip validation
|
|
if not COLLECTION_PATTERN.match(value):
|
|
return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
|
|
return None
|
|
|
|
|
|
def _validate_source_path(value: str | None) -> str | None:
|
|
"""Validate source_path to prevent path traversal.
|
|
|
|
Returns error message if invalid, None if valid.
|
|
"""
|
|
if value is None:
|
|
return None
|
|
|
|
# Handle FieldInfo objects from direct .fn() calls in tests
|
|
if not isinstance(value, str):
|
|
return None # Non-string means default not resolved, skip validation
|
|
|
|
# Normalize path and check for traversal attempts
|
|
if ".." in value:
|
|
return "Invalid source_path: path traversal not allowed"
|
|
|
|
# Check for null bytes (used in some injection attacks)
|
|
if "\x00" in value:
|
|
return "Invalid source_path: null bytes not allowed"
|
|
|
|
# Limit path length to prevent DoS
|
|
if len(value) > 4096:
|
|
return "Invalid source_path: path too long (max 4096 chars)"
|
|
|
|
return None
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Global instances
|
|
_database: DatabaseManager | None = None
|
|
_embeddings: EmbeddingGenerator | None = None
|
|
_search: SearchEngine | None = None
|
|
_collections: CollectionManager | None = None
|
|
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
|
"""Application lifespan handler."""
|
|
global _database, _embeddings, _search, _collections
|
|
|
|
logger.info("Starting Knowledge Base MCP Server...")
|
|
|
|
# Initialize database
|
|
_database = get_database_manager()
|
|
await _database.initialize()
|
|
|
|
# Initialize embedding generator
|
|
_embeddings = get_embedding_generator()
|
|
await _embeddings.initialize()
|
|
|
|
# Initialize search engine
|
|
_search = get_search_engine()
|
|
|
|
# Initialize collection manager
|
|
_collections = get_collection_manager()
|
|
|
|
logger.info("Knowledge Base MCP Server started successfully")
|
|
|
|
yield
|
|
|
|
# Cleanup
|
|
logger.info("Shutting down Knowledge Base MCP Server...")
|
|
|
|
if _embeddings:
|
|
await _embeddings.close()
|
|
|
|
if _database:
|
|
await _database.close()
|
|
|
|
logger.info("Knowledge Base MCP Server shut down")
|
|
|
|
|
|
# Create FastMCP server
|
|
mcp = FastMCP("syndarix-knowledge-base")
|
|
|
|
# Create FastAPI app with lifespan
|
|
app = FastAPI(
|
|
title="Knowledge Base MCP Server",
|
|
description="RAG with pgvector for semantic search",
|
|
version="0.1.0",
|
|
lifespan=lifespan,
|
|
)
|
|
|
|
|
|
@app.get("/health")
|
|
async def health_check() -> dict[str, Any]:
|
|
"""Health check endpoint.
|
|
|
|
Checks all dependencies: database, Redis cache, and LLM Gateway.
|
|
Returns degraded status if any non-critical dependency fails.
|
|
Returns unhealthy status if critical dependencies fail.
|
|
"""
|
|
from datetime import UTC, datetime
|
|
|
|
status: dict[str, Any] = {
|
|
"status": "healthy",
|
|
"service": "knowledge-base",
|
|
"version": "0.1.0",
|
|
"timestamp": datetime.now(UTC).isoformat(),
|
|
"dependencies": {},
|
|
}
|
|
|
|
is_degraded = False
|
|
is_unhealthy = False
|
|
|
|
# Check database connection (critical)
|
|
try:
|
|
if _database and _database._pool:
|
|
async with _database.acquire() as conn:
|
|
await conn.fetchval("SELECT 1")
|
|
status["dependencies"]["database"] = "connected"
|
|
else:
|
|
status["dependencies"]["database"] = "not initialized"
|
|
is_unhealthy = True
|
|
except Exception as e:
|
|
status["dependencies"]["database"] = f"error: {e}"
|
|
is_unhealthy = True
|
|
|
|
# Check Redis cache (non-critical - degraded without it)
|
|
try:
|
|
if _embeddings and _embeddings._redis:
|
|
await _embeddings._redis.ping()
|
|
status["dependencies"]["redis"] = "connected"
|
|
else:
|
|
status["dependencies"]["redis"] = "not initialized"
|
|
is_degraded = True
|
|
except Exception as e:
|
|
status["dependencies"]["redis"] = f"error: {e}"
|
|
is_degraded = True
|
|
|
|
# Check LLM Gateway connectivity (non-critical for health check)
|
|
try:
|
|
if _embeddings and _embeddings._http_client:
|
|
settings = get_settings()
|
|
response = await _embeddings._http_client.get(
|
|
f"{settings.llm_gateway_url}/health",
|
|
timeout=5.0,
|
|
)
|
|
if response.status_code == 200:
|
|
status["dependencies"]["llm_gateway"] = "connected"
|
|
else:
|
|
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
|
|
is_degraded = True
|
|
else:
|
|
status["dependencies"]["llm_gateway"] = "not initialized"
|
|
is_degraded = True
|
|
except Exception as e:
|
|
status["dependencies"]["llm_gateway"] = f"error: {e}"
|
|
is_degraded = True
|
|
|
|
# Set overall status
|
|
if is_unhealthy:
|
|
status["status"] = "unhealthy"
|
|
elif is_degraded:
|
|
status["status"] = "degraded"
|
|
else:
|
|
status["status"] = "healthy"
|
|
|
|
return status
|
|
|
|
|
|
# Tool registry for JSON-RPC
|
|
_tool_registry: dict[str, Any] = {}
|
|
|
|
|
|
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
|
|
"""Convert Python type annotation to JSON Schema."""
|
|
type_name = getattr(python_type, "__name__", str(python_type))
|
|
|
|
if python_type is str or type_name == "str":
|
|
return {"type": "string"}
|
|
elif python_type is int or type_name == "int":
|
|
return {"type": "integer"}
|
|
elif python_type is float or type_name == "float":
|
|
return {"type": "number"}
|
|
elif python_type is bool or type_name == "bool":
|
|
return {"type": "boolean"}
|
|
elif type_name == "NoneType":
|
|
return {"type": "null"}
|
|
elif hasattr(python_type, "__origin__"):
|
|
origin = python_type.__origin__
|
|
args = getattr(python_type, "__args__", ())
|
|
|
|
if origin is list:
|
|
item_type = args[0] if args else Any
|
|
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
|
|
elif origin is dict:
|
|
return {"type": "object"}
|
|
elif origin is type(None) or str(origin) == "typing.Union":
|
|
# Handle Optional types (Union with None)
|
|
non_none_args = [a for a in args if a is not type(None)]
|
|
if len(non_none_args) == 1:
|
|
schema = _python_type_to_json_schema(non_none_args[0])
|
|
schema["nullable"] = True
|
|
return schema
|
|
return {"type": "object"}
|
|
return {"type": "object"}
|
|
|
|
|
|
def _get_tool_schema(func: Any) -> dict[str, Any]:
|
|
"""Extract JSON Schema from a tool function."""
|
|
sig = inspect.signature(func)
|
|
hints = get_type_hints(func) if hasattr(func, "__annotations__") else {}
|
|
|
|
properties: dict[str, Any] = {}
|
|
required: list[str] = []
|
|
|
|
for name, param in sig.parameters.items():
|
|
if name in ("self", "cls"):
|
|
continue
|
|
|
|
prop: dict[str, Any] = {}
|
|
|
|
# Get type from hints
|
|
if name in hints:
|
|
prop = _python_type_to_json_schema(hints[name])
|
|
|
|
# Get description and constraints from Field default (FieldInfo object)
|
|
default_val = param.default
|
|
if hasattr(default_val, "description") and default_val.description:
|
|
prop["description"] = default_val.description
|
|
if hasattr(default_val, "ge") and default_val.ge is not None:
|
|
prop["minimum"] = default_val.ge
|
|
if hasattr(default_val, "le") and default_val.le is not None:
|
|
prop["maximum"] = default_val.le
|
|
# Handle Field default value (check for PydanticUndefined)
|
|
if hasattr(default_val, "default"):
|
|
field_default = default_val.default
|
|
# Check if it's the "required" sentinel (...)
|
|
if field_default is not ... and not (
|
|
hasattr(field_default, "__class__")
|
|
and "PydanticUndefined" in field_default.__class__.__name__
|
|
):
|
|
prop["default"] = field_default
|
|
|
|
# Determine if required
|
|
if param.default is inspect.Parameter.empty:
|
|
required.append(name)
|
|
elif hasattr(default_val, "default"):
|
|
field_default = default_val.default
|
|
# Required if default is ellipsis or PydanticUndefined
|
|
if field_default is ... or (
|
|
hasattr(field_default, "__class__")
|
|
and "PydanticUndefined" in field_default.__class__.__name__
|
|
):
|
|
required.append(name)
|
|
|
|
properties[name] = prop
|
|
|
|
return {
|
|
"type": "object",
|
|
"properties": properties,
|
|
"required": required,
|
|
}
|
|
|
|
|
|
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
|
|
"""Register a tool in the registry.
|
|
|
|
Handles both raw functions and FastMCP FunctionTool objects.
|
|
"""
|
|
# Extract the underlying function from FastMCP FunctionTool if needed
|
|
if hasattr(tool_or_func, "fn"):
|
|
func = tool_or_func.fn
|
|
# Use FunctionTool's description if available
|
|
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
|
|
description = tool_or_func.description
|
|
else:
|
|
func = tool_or_func
|
|
|
|
_tool_registry[name] = {
|
|
"func": func,
|
|
"description": description or (func.__doc__ or "").strip(),
|
|
"schema": _get_tool_schema(func),
|
|
}
|
|
|
|
|
|
@app.get("/mcp/tools")
|
|
async def list_mcp_tools() -> dict[str, Any]:
|
|
"""
|
|
Return list of available MCP tools with their schemas.
|
|
|
|
This endpoint enables tool discovery for the backend MCP client.
|
|
"""
|
|
tools = []
|
|
for name, info in _tool_registry.items():
|
|
tools.append({
|
|
"name": name,
|
|
"description": info["description"],
|
|
"inputSchema": info["schema"],
|
|
})
|
|
|
|
return {"tools": tools}
|
|
|
|
|
|
@app.post("/mcp")
|
|
async def mcp_rpc(request: Request) -> JSONResponse:
|
|
"""
|
|
JSON-RPC 2.0 endpoint for MCP tool execution.
|
|
|
|
Request format:
|
|
{
|
|
"jsonrpc": "2.0",
|
|
"method": "<tool_name>",
|
|
"params": {...},
|
|
"id": <request_id>
|
|
}
|
|
|
|
Response format:
|
|
{
|
|
"jsonrpc": "2.0",
|
|
"result": {...},
|
|
"id": <request_id>
|
|
}
|
|
"""
|
|
try:
|
|
body = await request.json()
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"error": {"code": -32700, "message": f"Parse error: {e}"},
|
|
"id": None,
|
|
},
|
|
)
|
|
|
|
# Validate JSON-RPC structure
|
|
jsonrpc = body.get("jsonrpc")
|
|
method = body.get("method")
|
|
params = body.get("params", {})
|
|
request_id = body.get("id")
|
|
|
|
if jsonrpc != "2.0":
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
|
|
"id": request_id,
|
|
},
|
|
)
|
|
|
|
if not method:
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"error": {"code": -32600, "message": "Invalid Request: method is required"},
|
|
"id": request_id,
|
|
},
|
|
)
|
|
|
|
# Look up tool
|
|
tool_info = _tool_registry.get(method)
|
|
if not tool_info:
|
|
return JSONResponse(
|
|
status_code=404,
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
|
"id": request_id,
|
|
},
|
|
)
|
|
|
|
# Execute tool
|
|
try:
|
|
func = tool_info["func"]
|
|
|
|
# Resolve Field defaults for missing parameters
|
|
sig = inspect.signature(func)
|
|
resolved_params = dict(params)
|
|
for name, param in sig.parameters.items():
|
|
if name not in resolved_params:
|
|
default_val = param.default
|
|
# Check if it's a FieldInfo with a default value
|
|
if hasattr(default_val, "default"):
|
|
field_default = default_val.default
|
|
# Only use if it has an actual default (not required)
|
|
if field_default is not ... and not (
|
|
hasattr(field_default, "__class__")
|
|
and "PydanticUndefined" in field_default.__class__.__name__
|
|
):
|
|
resolved_params[name] = field_default
|
|
|
|
result = await func(**resolved_params)
|
|
return JSONResponse(
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"result": result,
|
|
"id": request_id,
|
|
}
|
|
)
|
|
except TypeError as e:
|
|
return JSONResponse(
|
|
status_code=400,
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"error": {"code": -32602, "message": f"Invalid params: {e}"},
|
|
"id": request_id,
|
|
},
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Tool execution error: {e}")
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={
|
|
"jsonrpc": "2.0",
|
|
"error": {"code": -32000, "message": f"Server error: {e}"},
|
|
"id": request_id,
|
|
},
|
|
)
|
|
|
|
|
|
# MCP Tools
|
|
|
|
|
|
@mcp.tool()
|
|
async def search_knowledge(
|
|
project_id: str = Field(..., description="Project ID for scoping"),
|
|
agent_id: str = Field(..., description="Agent ID making the request"),
|
|
query: str = Field(..., description="Search query"),
|
|
search_type: str = Field(
|
|
default="hybrid",
|
|
description="Search type: semantic, keyword, or hybrid",
|
|
),
|
|
collection: str | None = Field(
|
|
default=None,
|
|
description="Collection to search (None = all)",
|
|
),
|
|
limit: int = Field(
|
|
default=10,
|
|
ge=1,
|
|
le=100,
|
|
description="Maximum number of results",
|
|
),
|
|
threshold: float = Field(
|
|
default=0.7,
|
|
ge=0.0,
|
|
le=1.0,
|
|
description="Minimum similarity score",
|
|
),
|
|
file_types: list[str] | None = Field(
|
|
default=None,
|
|
description="Filter by file types (python, javascript, etc.)",
|
|
),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Search the knowledge base for relevant content.
|
|
|
|
Supports semantic (vector), keyword (full-text), and hybrid search.
|
|
Returns chunks ranked by relevance to the query.
|
|
"""
|
|
try:
|
|
# Validate inputs
|
|
if error := _validate_id(project_id, "project_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_id(agent_id, "agent_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if collection and (error := _validate_collection(collection)):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
|
|
# Parse search type
|
|
try:
|
|
search_type_enum = SearchType(search_type.lower())
|
|
except ValueError:
|
|
valid_types = [t.value for t in SearchType]
|
|
return {
|
|
"success": False,
|
|
"error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
# Parse file types
|
|
file_type_enums = None
|
|
if file_types:
|
|
try:
|
|
file_type_enums = [FileType(ft.lower()) for ft in file_types]
|
|
except ValueError as e:
|
|
return {
|
|
"success": False,
|
|
"error": f"Invalid file type: {e}",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
request = SearchRequest(
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
query=query,
|
|
search_type=search_type_enum,
|
|
collection=collection,
|
|
limit=limit,
|
|
threshold=threshold,
|
|
file_types=file_type_enums,
|
|
)
|
|
|
|
response = await _search.search(request) # type: ignore[union-attr]
|
|
|
|
return {
|
|
"success": True,
|
|
"query": response.query,
|
|
"search_type": response.search_type,
|
|
"results": [
|
|
{
|
|
"id": r.id,
|
|
"content": r.content,
|
|
"score": r.score,
|
|
"source_path": r.source_path,
|
|
"start_line": r.start_line,
|
|
"end_line": r.end_line,
|
|
"chunk_type": r.chunk_type,
|
|
"file_type": r.file_type,
|
|
"collection": r.collection,
|
|
"metadata": r.metadata,
|
|
}
|
|
for r in response.results
|
|
],
|
|
"total_results": response.total_results,
|
|
"search_time_ms": response.search_time_ms,
|
|
}
|
|
|
|
except KnowledgeBaseError as e:
|
|
logger.error(f"Search error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": e.message,
|
|
"code": e.code.value,
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Unexpected search error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e),
|
|
"code": ErrorCode.INTERNAL_ERROR.value,
|
|
}
|
|
|
|
|
|
@mcp.tool()
|
|
async def ingest_content(
|
|
project_id: str = Field(..., description="Project ID for scoping"),
|
|
agent_id: str = Field(..., description="Agent ID making the request"),
|
|
content: str = Field(..., description="Content to ingest"),
|
|
source_path: str | None = Field(
|
|
default=None,
|
|
description="Source file path for reference",
|
|
),
|
|
collection: str = Field(
|
|
default="default",
|
|
description="Collection to store in",
|
|
),
|
|
chunk_type: str = Field(
|
|
default="text",
|
|
description="Content type: code, markdown, or text",
|
|
),
|
|
file_type: str | None = Field(
|
|
default=None,
|
|
description="File type for code chunking (python, javascript, etc.)",
|
|
),
|
|
metadata: dict[str, Any] | None = Field(
|
|
default=None,
|
|
description="Additional metadata to store",
|
|
),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Ingest content into the knowledge base.
|
|
|
|
Content is automatically chunked based on type, embedded using
|
|
the LLM Gateway, and stored in pgvector for search.
|
|
"""
|
|
try:
|
|
# Validate inputs
|
|
if error := _validate_id(project_id, "project_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_id(agent_id, "agent_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_collection(collection):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_source_path(source_path):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
|
|
# Validate content size to prevent DoS
|
|
settings = get_settings()
|
|
content_size = len(content.encode("utf-8"))
|
|
if content_size > settings.max_document_size:
|
|
return {
|
|
"success": False,
|
|
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
# Parse chunk type
|
|
try:
|
|
chunk_type_enum = ChunkType(chunk_type.lower())
|
|
except ValueError:
|
|
valid_types = [t.value for t in ChunkType]
|
|
return {
|
|
"success": False,
|
|
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
# Parse file type
|
|
file_type_enum = None
|
|
if file_type:
|
|
try:
|
|
file_type_enum = FileType(file_type.lower())
|
|
except ValueError:
|
|
valid_types = [t.value for t in FileType]
|
|
return {
|
|
"success": False,
|
|
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
request = IngestRequest(
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
content=content,
|
|
source_path=source_path,
|
|
collection=collection,
|
|
chunk_type=chunk_type_enum,
|
|
file_type=file_type_enum,
|
|
metadata=metadata or {},
|
|
)
|
|
|
|
result = await _collections.ingest(request) # type: ignore[union-attr]
|
|
|
|
return {
|
|
"success": result.success,
|
|
"chunks_created": result.chunks_created,
|
|
"embeddings_generated": result.embeddings_generated,
|
|
"source_path": result.source_path,
|
|
"collection": result.collection,
|
|
"chunk_ids": result.chunk_ids,
|
|
"error": result.error,
|
|
}
|
|
|
|
except KnowledgeBaseError as e:
|
|
logger.error(f"Ingest error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": e.message,
|
|
"code": e.code.value,
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Unexpected ingest error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e),
|
|
"code": ErrorCode.INTERNAL_ERROR.value,
|
|
}
|
|
|
|
|
|
@mcp.tool()
|
|
async def delete_content(
|
|
project_id: str = Field(..., description="Project ID for scoping"),
|
|
agent_id: str = Field(..., description="Agent ID making the request"),
|
|
source_path: str | None = Field(
|
|
default=None,
|
|
description="Delete by source file path",
|
|
),
|
|
collection: str | None = Field(
|
|
default=None,
|
|
description="Delete entire collection",
|
|
),
|
|
chunk_ids: list[str] | None = Field(
|
|
default=None,
|
|
description="Delete specific chunk IDs",
|
|
),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Delete content from the knowledge base.
|
|
|
|
Specify either source_path, collection, or chunk_ids to delete.
|
|
"""
|
|
try:
|
|
# Validate inputs
|
|
if error := _validate_id(project_id, "project_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_id(agent_id, "agent_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if collection and (error := _validate_collection(collection)):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_source_path(source_path):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
|
|
request = DeleteRequest(
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
source_path=source_path,
|
|
collection=collection,
|
|
chunk_ids=chunk_ids,
|
|
)
|
|
|
|
result = await _collections.delete(request) # type: ignore[union-attr]
|
|
|
|
return {
|
|
"success": result.success,
|
|
"chunks_deleted": result.chunks_deleted,
|
|
"error": result.error,
|
|
}
|
|
|
|
except KnowledgeBaseError as e:
|
|
logger.error(f"Delete error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": e.message,
|
|
"code": e.code.value,
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Unexpected delete error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e),
|
|
"code": ErrorCode.INTERNAL_ERROR.value,
|
|
}
|
|
|
|
|
|
@mcp.tool()
|
|
async def list_collections(
|
|
project_id: str = Field(..., description="Project ID for scoping"),
|
|
agent_id: str = Field(..., description="Agent ID making the request"), # noqa: ARG001
|
|
) -> dict[str, Any]:
|
|
"""
|
|
List all collections in a project's knowledge base.
|
|
|
|
Returns collection names with chunk counts and file types.
|
|
"""
|
|
try:
|
|
# Validate inputs
|
|
if error := _validate_id(project_id, "project_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_id(agent_id, "agent_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
|
|
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
|
|
|
|
return {
|
|
"success": True,
|
|
"project_id": result.project_id,
|
|
"collections": [
|
|
{
|
|
"name": c.name,
|
|
"chunk_count": c.chunk_count,
|
|
"total_tokens": c.total_tokens,
|
|
"file_types": c.file_types,
|
|
"created_at": c.created_at.isoformat(),
|
|
"updated_at": c.updated_at.isoformat(),
|
|
}
|
|
for c in result.collections
|
|
],
|
|
"total_collections": result.total_collections,
|
|
}
|
|
|
|
except KnowledgeBaseError as e:
|
|
logger.error(f"List collections error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": e.message,
|
|
"code": e.code.value,
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Unexpected list collections error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e),
|
|
"code": ErrorCode.INTERNAL_ERROR.value,
|
|
}
|
|
|
|
|
|
@mcp.tool()
|
|
async def get_collection_stats(
|
|
project_id: str = Field(..., description="Project ID for scoping"),
|
|
agent_id: str = Field(..., description="Agent ID making the request"), # noqa: ARG001
|
|
collection: str = Field(..., description="Collection name"),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Get detailed statistics for a collection.
|
|
|
|
Returns chunk counts, token totals, and type breakdowns.
|
|
"""
|
|
try:
|
|
# Validate inputs
|
|
if error := _validate_id(project_id, "project_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_id(agent_id, "agent_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_collection(collection):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
|
|
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
|
|
|
|
return {
|
|
"success": True,
|
|
"collection": stats.collection,
|
|
"project_id": stats.project_id,
|
|
"chunk_count": stats.chunk_count,
|
|
"unique_sources": stats.unique_sources,
|
|
"total_tokens": stats.total_tokens,
|
|
"avg_chunk_size": stats.avg_chunk_size,
|
|
"chunk_types": stats.chunk_types,
|
|
"file_types": stats.file_types,
|
|
"oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
|
|
"newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
|
|
}
|
|
|
|
except KnowledgeBaseError as e:
|
|
logger.error(f"Get collection stats error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": e.message,
|
|
"code": e.code.value,
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Unexpected get collection stats error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e),
|
|
"code": ErrorCode.INTERNAL_ERROR.value,
|
|
}
|
|
|
|
|
|
@mcp.tool()
|
|
async def update_document(
|
|
project_id: str = Field(..., description="Project ID for scoping"),
|
|
agent_id: str = Field(..., description="Agent ID making the request"),
|
|
source_path: str = Field(..., description="Source file path"),
|
|
content: str = Field(..., description="New content"),
|
|
collection: str = Field(
|
|
default="default",
|
|
description="Collection name",
|
|
),
|
|
chunk_type: str = Field(
|
|
default="text",
|
|
description="Content type: code, markdown, or text",
|
|
),
|
|
file_type: str | None = Field(
|
|
default=None,
|
|
description="File type for code chunking",
|
|
),
|
|
metadata: dict[str, Any] | None = Field(
|
|
default=None,
|
|
description="Additional metadata",
|
|
),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Update a document in the knowledge base.
|
|
|
|
Replaces all existing chunks for the source path with new content.
|
|
"""
|
|
try:
|
|
# Validate inputs
|
|
if error := _validate_id(project_id, "project_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_id(agent_id, "agent_id"):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_collection(collection):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
if error := _validate_source_path(source_path):
|
|
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
|
|
|
# Validate content size to prevent DoS
|
|
settings = get_settings()
|
|
content_size = len(content.encode("utf-8"))
|
|
if content_size > settings.max_document_size:
|
|
return {
|
|
"success": False,
|
|
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
# Parse chunk type
|
|
try:
|
|
chunk_type_enum = ChunkType(chunk_type.lower())
|
|
except ValueError:
|
|
valid_types = [t.value for t in ChunkType]
|
|
return {
|
|
"success": False,
|
|
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
# Parse file type
|
|
file_type_enum = None
|
|
if file_type:
|
|
try:
|
|
file_type_enum = FileType(file_type.lower())
|
|
except ValueError:
|
|
valid_types = [t.value for t in FileType]
|
|
return {
|
|
"success": False,
|
|
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
|
|
"code": ErrorCode.INVALID_REQUEST.value,
|
|
}
|
|
|
|
result = await _collections.update_document( # type: ignore[union-attr]
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
source_path=source_path,
|
|
content=content,
|
|
collection=collection,
|
|
chunk_type=chunk_type_enum,
|
|
file_type=file_type_enum,
|
|
metadata=metadata,
|
|
)
|
|
|
|
return {
|
|
"success": result.success,
|
|
"chunks_created": result.chunks_created,
|
|
"embeddings_generated": result.embeddings_generated,
|
|
"source_path": result.source_path,
|
|
"collection": result.collection,
|
|
"chunk_ids": result.chunk_ids,
|
|
"error": result.error,
|
|
}
|
|
|
|
except KnowledgeBaseError as e:
|
|
logger.error(f"Update document error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": e.message,
|
|
"code": e.code.value,
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Unexpected update document error: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e),
|
|
"code": ErrorCode.INTERNAL_ERROR.value,
|
|
}
|
|
|
|
|
|
# Register tools in the JSON-RPC registry
|
|
# This must happen after tool functions are defined
|
|
_register_tool("search_knowledge", search_knowledge)
|
|
_register_tool("ingest_content", ingest_content)
|
|
_register_tool("delete_content", delete_content)
|
|
_register_tool("list_collections", list_collections)
|
|
_register_tool("get_collection_stats", get_collection_stats)
|
|
_register_tool("update_document", update_document)
|
|
|
|
|
|
def main() -> None:
|
|
"""Run the server."""
|
|
import uvicorn
|
|
|
|
settings = get_settings()
|
|
|
|
uvicorn.run(
|
|
"server:app",
|
|
host=settings.host,
|
|
port=settings.port,
|
|
reload=settings.debug,
|
|
log_level="info",
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|