"""
Knowledge Base MCP Server.

Provides RAG capabilities with pgvector for semantic search,
intelligent chunking, and collection management.
"""

import inspect
import logging
import re
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, get_type_hints

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastmcp import FastMCP
from pydantic import Field

from collection_manager import CollectionManager, get_collection_manager
from config import get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import ErrorCode, KnowledgeBaseError
from models import (
    ChunkType,
    DeleteRequest,
    FileType,
    IngestRequest,
    SearchRequest,
    SearchType,
)
from search import SearchEngine, get_search_engine

# Input validation patterns
# Allow alphanumeric, hyphens, underscores (1-128 chars)
ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")

def _validate_id(value: str, field_name: str) -> str | None:
    """Validate project_id or agent_id format.

    Returns error message if invalid, None if valid.
    """
    # Handle FieldInfo objects from direct .fn() calls in tests
    if not isinstance(value, str):
        return f"{field_name} must be a string"
    if not value:
        return f"{field_name} is required"
    if not ID_PATTERN.match(value):
        return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
    return None


def _validate_collection(value: str) -> str | None:
    """Validate collection name format.

    Returns error message if invalid, None if valid.
    """
    # Handle FieldInfo objects from direct .fn() calls in tests
    if not isinstance(value, str):
        return None  # Non-string means default not resolved, skip validation
    if not COLLECTION_PATTERN.match(value):
        return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
    return None


def _validate_source_path(value: str | None) -> str | None:
    """Validate source_path to prevent path traversal.

    Returns error message if invalid, None if valid.
    """
    if value is None:
        return None

    # Handle FieldInfo objects from direct .fn() calls in tests
    if not isinstance(value, str):
        return None  # Non-string means default not resolved, skip validation

    # Normalize path and check for traversal attempts
    if ".." in value:
        return "Invalid source_path: path traversal not allowed"

    # Check for null bytes (used in some injection attacks)
    if "\x00" in value:
        return "Invalid source_path: null bytes not allowed"

    # Limit path length to prevent DoS
    if len(value) > 4096:
        return "Invalid source_path: path too long (max 4096 chars)"

    return None

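# Illustrative behavior of the validators above (a sketch, not executed at
# import time): each helper returns None on success and an error string on
# failure, so callers can use the walrus pattern `if error := _validate_id(...)`.
#
#     _validate_id("agent-42", "agent_id")    # -> None
#     _validate_id("bad id!", "agent_id")     # -> "Invalid agent_id: must be ..."
#     _validate_source_path("../etc/passwd")  # -> "Invalid source_path: path traversal not allowed"
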
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Global instances
_database: DatabaseManager | None = None
_embeddings: EmbeddingGenerator | None = None
_search: SearchEngine | None = None
_collections: CollectionManager | None = None

@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan handler."""
    global _database, _embeddings, _search, _collections

    logger.info("Starting Knowledge Base MCP Server...")

    # Initialize database
    _database = get_database_manager()
    await _database.initialize()

    # Initialize embedding generator
    _embeddings = get_embedding_generator()
    await _embeddings.initialize()

    # Initialize search engine
    _search = get_search_engine()

    # Initialize collection manager
    _collections = get_collection_manager()

    logger.info("Knowledge Base MCP Server started successfully")

    yield

    # Cleanup
    logger.info("Shutting down Knowledge Base MCP Server...")

    if _embeddings:
        await _embeddings.close()

    if _database:
        await _database.close()

    logger.info("Knowledge Base MCP Server shut down")

# Create FastMCP server
mcp = FastMCP("syndarix-knowledge-base")

# Create FastAPI app with lifespan
app = FastAPI(
    title="Knowledge Base MCP Server",
    description="RAG with pgvector for semantic search",
    version="0.1.0",
    lifespan=lifespan,
)

@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint.

    Checks all dependencies: database, Redis cache, and LLM Gateway.
    Returns degraded status if any non-critical dependency fails.
    Returns unhealthy status if critical dependencies fail.
    """
    from datetime import UTC, datetime

    status: dict[str, Any] = {
        "status": "healthy",
        "service": "knowledge-base",
        "version": "0.1.0",
        "timestamp": datetime.now(UTC).isoformat(),
        "dependencies": {},
    }

    is_degraded = False
    is_unhealthy = False

    # Check database connection (critical)
    try:
        if _database and _database._pool:
            async with _database.acquire() as conn:
                await conn.fetchval("SELECT 1")
            status["dependencies"]["database"] = "connected"
        else:
            status["dependencies"]["database"] = "not initialized"
            is_unhealthy = True
    except Exception as e:
        status["dependencies"]["database"] = f"error: {e}"
        is_unhealthy = True

    # Check Redis cache (non-critical - degraded without it)
    try:
        if _embeddings and _embeddings._redis:
            await _embeddings._redis.ping()  # type: ignore[misc]
            status["dependencies"]["redis"] = "connected"
        else:
            status["dependencies"]["redis"] = "not initialized"
            is_degraded = True
    except Exception as e:
        status["dependencies"]["redis"] = f"error: {e}"
        is_degraded = True

    # Check LLM Gateway connectivity (non-critical for health check)
    try:
        if _embeddings and _embeddings._http_client:
            settings = get_settings()
            response = await _embeddings._http_client.get(
                f"{settings.llm_gateway_url}/health",
                timeout=5.0,
            )
            if response.status_code == 200:
                status["dependencies"]["llm_gateway"] = "connected"
            else:
                status["dependencies"]["llm_gateway"] = (
                    f"unhealthy (status {response.status_code})"
                )
                is_degraded = True
        else:
            status["dependencies"]["llm_gateway"] = "not initialized"
            is_degraded = True
    except Exception as e:
        status["dependencies"]["llm_gateway"] = f"error: {e}"
        is_degraded = True

    # Set overall status
    if is_unhealthy:
        status["status"] = "unhealthy"
    elif is_degraded:
        status["status"] = "degraded"
    else:
        status["status"] = "healthy"

    return status

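# Illustrative /health response shape when Redis is down but the database is
# up (field values are examples, not fixed strings for every failure mode):
#
#     {
#         "status": "degraded",
#         "service": "knowledge-base",
#         "version": "0.1.0",
#         "timestamp": "2024-01-01T00:00:00+00:00",
#         "dependencies": {
#             "database": "connected",
#             "redis": "error: ...",
#             "llm_gateway": "connected",
#         },
#     }
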
# Tool registry for JSON-RPC
_tool_registry: dict[str, Any] = {}

def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
    """Convert Python type annotation to JSON Schema."""
    type_name = getattr(python_type, "__name__", str(python_type))

    if python_type is str or type_name == "str":
        return {"type": "string"}
    elif python_type is int or type_name == "int":
        return {"type": "integer"}
    elif python_type is float or type_name == "float":
        return {"type": "number"}
    elif python_type is bool or type_name == "bool":
        return {"type": "boolean"}
    elif type_name == "NoneType":
        return {"type": "null"}
    elif hasattr(python_type, "__origin__"):
        origin = python_type.__origin__
        args = getattr(python_type, "__args__", ())

        if origin is list:
            item_type = args[0] if args else Any
            return {"type": "array", "items": _python_type_to_json_schema(item_type)}
        elif origin is dict:
            return {"type": "object"}
        elif origin is type(None) or str(origin) == "typing.Union":
            # Handle Optional types (Union with None)
            non_none_args = [a for a in args if a is not type(None)]
            if len(non_none_args) == 1:
                schema = _python_type_to_json_schema(non_none_args[0])
                schema["nullable"] = True
                return schema
            return {"type": "object"}
    return {"type": "object"}

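# Illustrative mappings (a sketch of expected output, not executed here):
#
#     _python_type_to_json_schema(str)        # -> {"type": "string"}
#     _python_type_to_json_schema(list[int])  # -> {"type": "array",
#                                             #     "items": {"type": "integer"}}
#     _python_type_to_json_schema(typing.Optional[str])
#     # -> {"type": "string", "nullable": True}  (a Union with None collapses to
#     #    the non-None member plus a "nullable" flag)
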
def _get_tool_schema(func: Any) -> dict[str, Any]:
    """Extract JSON Schema from a tool function."""
    sig = inspect.signature(func)
    hints = get_type_hints(func) if hasattr(func, "__annotations__") else {}

    properties: dict[str, Any] = {}
    required: list[str] = []

    for name, param in sig.parameters.items():
        if name in ("self", "cls"):
            continue

        prop: dict[str, Any] = {}

        # Get type from hints
        if name in hints:
            prop = _python_type_to_json_schema(hints[name])

        # Get description and constraints from Field default (FieldInfo object)
        default_val = param.default
        if hasattr(default_val, "description") and default_val.description:
            prop["description"] = default_val.description
        if hasattr(default_val, "ge") and default_val.ge is not None:
            prop["minimum"] = default_val.ge
        if hasattr(default_val, "le") and default_val.le is not None:
            prop["maximum"] = default_val.le
        # Handle Field default value (check for PydanticUndefined)
        if hasattr(default_val, "default"):
            field_default = default_val.default
            # Check if it's the "required" sentinel (...)
            if field_default is not ... and not (
                hasattr(field_default, "__class__")
                and "PydanticUndefined" in field_default.__class__.__name__
            ):
                prop["default"] = field_default

        # Determine if required
        if param.default is inspect.Parameter.empty:
            required.append(name)
        elif hasattr(default_val, "default"):
            field_default = default_val.default
            # Required if default is ellipsis or PydanticUndefined
            if field_default is ... or (
                hasattr(field_default, "__class__")
                and "PydanticUndefined" in field_default.__class__.__name__
            ):
                required.append(name)

        properties[name] = prop

    return {
        "type": "object",
        "properties": properties,
        "required": required,
    }

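# Illustrative schema extraction for a hypothetical tool signature (not one of
# the tools registered below):
#
#     async def echo(text: str = Field(..., description="Text to echo")) -> dict: ...
#
#     _get_tool_schema(echo)
#     # -> {
#     #     "type": "object",
#     #     "properties": {"text": {"type": "string", "description": "Text to echo"}},
#     #     "required": ["text"],
#     # }
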
def _register_tool(
    name: str, tool_or_func: Any, description: str | None = None
) -> None:
    """Register a tool in the registry.

    Handles both raw functions and FastMCP FunctionTool objects.
    """
    # Extract the underlying function from FastMCP FunctionTool if needed
    if hasattr(tool_or_func, "fn"):
        func = tool_or_func.fn
        # Use FunctionTool's description if available
        if (
            not description
            and hasattr(tool_or_func, "description")
            and tool_or_func.description
        ):
            description = tool_or_func.description
    else:
        func = tool_or_func

    _tool_registry[name] = {
        "func": func,
        "description": description or (func.__doc__ or "").strip(),
        "schema": _get_tool_schema(func),
    }

@app.get("/mcp/tools")
async def list_mcp_tools() -> dict[str, Any]:
    """
    Return list of available MCP tools with their schemas.

    This endpoint enables tool discovery for the backend MCP client.
    """
    tools = []
    for name, info in _tool_registry.items():
        tools.append(
            {
                "name": name,
                "description": info["description"],
                "inputSchema": info["schema"],
            }
        )

    return {"tools": tools}

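# Example discovery request (illustrative; host and port depend on deployment):
#
#     curl http://localhost:8000/mcp/tools
#     # -> {"tools": [{"name": "search_knowledge", "description": "...",
#     #                "inputSchema": {...}}, ...]}
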
@app.post("/mcp")
async def mcp_rpc(request: Request) -> JSONResponse:
    """
    JSON-RPC 2.0 endpoint for MCP tool execution.

    Request format:
        {
            "jsonrpc": "2.0",
            "method": "<tool_name>",
            "params": {...},
            "id": <request_id>
        }

    Response format:
        {
            "jsonrpc": "2.0",
            "result": {...},
            "id": <request_id>
        }
    """
    try:
        body = await request.json()
    except Exception as e:
        return JSONResponse(
            status_code=400,
            content={
                "jsonrpc": "2.0",
                "error": {"code": -32700, "message": f"Parse error: {e}"},
                "id": None,
            },
        )

    # Validate JSON-RPC structure
    jsonrpc = body.get("jsonrpc")
    method = body.get("method")
    params = body.get("params", {})
    request_id = body.get("id")

    if jsonrpc != "2.0":
        return JSONResponse(
            status_code=400,
            content={
                "jsonrpc": "2.0",
                "error": {
                    "code": -32600,
                    "message": "Invalid Request: jsonrpc must be '2.0'",
                },
                "id": request_id,
            },
        )

    if not method:
        return JSONResponse(
            status_code=400,
            content={
                "jsonrpc": "2.0",
                "error": {
                    "code": -32600,
                    "message": "Invalid Request: method is required",
                },
                "id": request_id,
            },
        )

    # Look up tool
    tool_info = _tool_registry.get(method)
    if not tool_info:
        return JSONResponse(
            status_code=404,
            content={
                "jsonrpc": "2.0",
                "error": {"code": -32601, "message": f"Method not found: {method}"},
                "id": request_id,
            },
        )

    # Execute tool
    try:
        func = tool_info["func"]

        # Resolve Field defaults for missing parameters
        sig = inspect.signature(func)
        resolved_params = dict(params)
        for name, param in sig.parameters.items():
            if name not in resolved_params:
                default_val = param.default
                # Check if it's a FieldInfo with a default value
                if hasattr(default_val, "default"):
                    field_default = default_val.default
                    # Only use if it has an actual default (not required)
                    if field_default is not ... and not (
                        hasattr(field_default, "__class__")
                        and "PydanticUndefined" in field_default.__class__.__name__
                    ):
                        resolved_params[name] = field_default

        result = await func(**resolved_params)
        return JSONResponse(
            content={
                "jsonrpc": "2.0",
                "result": result,
                "id": request_id,
            }
        )
    except TypeError as e:
        return JSONResponse(
            status_code=400,
            content={
                "jsonrpc": "2.0",
                "error": {"code": -32602, "message": f"Invalid params: {e}"},
                "id": request_id,
            },
        )
    except Exception as e:
        logger.error(f"Tool execution error: {e}")
        return JSONResponse(
            status_code=500,
            content={
                "jsonrpc": "2.0",
                "error": {"code": -32000, "message": f"Server error: {e}"},
                "id": request_id,
            },
        )

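# Example JSON-RPC call (illustrative; host and port depend on deployment):
#
#     curl -X POST http://localhost:8000/mcp \
#       -H "Content-Type: application/json" \
#       -d '{"jsonrpc": "2.0", "method": "list_collections",
#            "params": {"project_id": "proj-1", "agent_id": "agent-1"}, "id": 1}'
#
# A successful call returns {"jsonrpc": "2.0", "result": {...}, "id": 1};
# unknown methods yield a -32601 error, bad params a -32602 error.
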
# MCP Tools


@mcp.tool()
async def search_knowledge(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    query: str = Field(..., description="Search query"),
    search_type: str = Field(
        default="hybrid",
        description="Search type: semantic, keyword, or hybrid",
    ),
    collection: str | None = Field(
        default=None,
        description="Collection to search (None = all)",
    ),
    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum number of results",
    ),
    threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum similarity score",
    ),
    file_types: list[str] | None = Field(
        default=None,
        description="Filter by file types (python, javascript, etc.)",
    ),
) -> dict[str, Any]:
    """
    Search the knowledge base for relevant content.

    Supports semantic (vector), keyword (full-text), and hybrid search.
    Returns chunks ranked by relevance to the query.
    """
    try:
        # Validate inputs
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if collection and (error := _validate_collection(collection)):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Parse search type
        try:
            search_type_enum = SearchType(search_type.lower())
        except ValueError:
            valid_types = [t.value for t in SearchType]
            return {
                "success": False,
                "error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Parse file types
        file_type_enums = None
        if file_types:
            try:
                file_type_enums = [FileType(ft.lower()) for ft in file_types]
            except ValueError as e:
                return {
                    "success": False,
                    "error": f"Invalid file type: {e}",
                    "code": ErrorCode.INVALID_REQUEST.value,
                }

        request = SearchRequest(
            project_id=project_id,
            agent_id=agent_id,
            query=query,
            search_type=search_type_enum,
            collection=collection,
            limit=limit,
            threshold=threshold,
            file_types=file_type_enums,
        )

        response = await _search.search(request)  # type: ignore[union-attr]

        return {
            "success": True,
            "query": response.query,
            "search_type": response.search_type,
            "results": [
                {
                    "id": r.id,
                    "content": r.content,
                    "score": r.score,
                    "source_path": r.source_path,
                    "start_line": r.start_line,
                    "end_line": r.end_line,
                    "chunk_type": r.chunk_type,
                    "file_type": r.file_type,
                    "collection": r.collection,
                    "metadata": r.metadata,
                }
                for r in response.results
            ],
            "total_results": response.total_results,
            "search_time_ms": response.search_time_ms,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Search error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected search error: {e}")
        return {
            "success": False,
            "error": str(e),
            "code": ErrorCode.INTERNAL_ERROR.value,
        }

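# Example params for invoking search_knowledge through the /mcp endpoint
# (illustrative values):
#
#     {"project_id": "proj-1", "agent_id": "agent-1",
#      "query": "how does chunking work?", "search_type": "semantic",
#      "limit": 5}
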
@mcp.tool()
async def ingest_content(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    content: str = Field(..., description="Content to ingest"),
    source_path: str | None = Field(
        default=None,
        description="Source file path for reference",
    ),
    collection: str = Field(
        default="default",
        description="Collection to store in",
    ),
    chunk_type: str = Field(
        default="text",
        description="Content type: code, markdown, or text",
    ),
    file_type: str | None = Field(
        default=None,
        description="File type for code chunking (python, javascript, etc.)",
    ),
    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata to store",
    ),
) -> dict[str, Any]:
    """
    Ingest content into the knowledge base.

    Content is automatically chunked based on type, embedded using
    the LLM Gateway, and stored in pgvector for search.
    """
    try:
        # Validate inputs
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_collection(collection):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_source_path(source_path):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Validate content size to prevent DoS
        settings = get_settings()
        content_size = len(content.encode("utf-8"))
        if content_size > settings.max_document_size:
            return {
                "success": False,
                "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Parse chunk type
        try:
            chunk_type_enum = ChunkType(chunk_type.lower())
        except ValueError:
            valid_types = [t.value for t in ChunkType]
            return {
                "success": False,
                "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Parse file type
        file_type_enum = None
        if file_type:
            try:
                file_type_enum = FileType(file_type.lower())
            except ValueError:
                valid_types = [t.value for t in FileType]
                return {
                    "success": False,
                    "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
                    "code": ErrorCode.INVALID_REQUEST.value,
                }

        request = IngestRequest(
            project_id=project_id,
            agent_id=agent_id,
            content=content,
            source_path=source_path,
            collection=collection,
            chunk_type=chunk_type_enum,
            file_type=file_type_enum,
            metadata=metadata or {},
        )

        result = await _collections.ingest(request)  # type: ignore[union-attr]

        return {
            "success": result.success,
            "chunks_created": result.chunks_created,
            "embeddings_generated": result.embeddings_generated,
            "source_path": result.source_path,
            "collection": result.collection,
            "chunk_ids": result.chunk_ids,
            "error": result.error,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Ingest error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected ingest error: {e}")
        return {
            "success": False,
            "error": str(e),
            "code": ErrorCode.INTERNAL_ERROR.value,
        }

@mcp.tool()
async def delete_content(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    source_path: str | None = Field(
        default=None,
        description="Delete by source file path",
    ),
    collection: str | None = Field(
        default=None,
        description="Delete entire collection",
    ),
    chunk_ids: list[str] | None = Field(
        default=None,
        description="Delete specific chunk IDs",
    ),
) -> dict[str, Any]:
    """
    Delete content from the knowledge base.

    Specify either source_path, collection, or chunk_ids to delete.
    """
    try:
        # Validate inputs
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if collection and (error := _validate_collection(collection)):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_source_path(source_path):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        request = DeleteRequest(
            project_id=project_id,
            agent_id=agent_id,
            source_path=source_path,
            collection=collection,
            chunk_ids=chunk_ids,
        )

        result = await _collections.delete(request)  # type: ignore[union-attr]

        return {
            "success": result.success,
            "chunks_deleted": result.chunks_deleted,
            "error": result.error,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Delete error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected delete error: {e}")
        return {
            "success": False,
            "error": str(e),
            "code": ErrorCode.INTERNAL_ERROR.value,
        }

@mcp.tool()
async def list_collections(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),  # noqa: ARG001
) -> dict[str, Any]:
    """
    List all collections in a project's knowledge base.

    Returns collection names with chunk counts and file types.
    """
    try:
        # Validate inputs
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        result = await _collections.list_collections(project_id)  # type: ignore[union-attr]

        return {
            "success": True,
            "project_id": result.project_id,
            "collections": [
                {
                    "name": c.name,
                    "chunk_count": c.chunk_count,
                    "total_tokens": c.total_tokens,
                    "file_types": c.file_types,
                    "created_at": c.created_at.isoformat(),
                    "updated_at": c.updated_at.isoformat(),
                }
                for c in result.collections
            ],
            "total_collections": result.total_collections,
        }

    except KnowledgeBaseError as e:
        logger.error(f"List collections error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected list collections error: {e}")
        return {
            "success": False,
            "error": str(e),
            "code": ErrorCode.INTERNAL_ERROR.value,
        }

@mcp.tool()
async def get_collection_stats(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),  # noqa: ARG001
    collection: str = Field(..., description="Collection name"),
) -> dict[str, Any]:
    """
    Get detailed statistics for a collection.

    Returns chunk counts, token totals, and type breakdowns.
    """
    try:
        # Validate inputs
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_collection(collection):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        stats = await _collections.get_collection_stats(project_id, collection)  # type: ignore[union-attr]

        return {
            "success": True,
            "collection": stats.collection,
            "project_id": stats.project_id,
            "chunk_count": stats.chunk_count,
            "unique_sources": stats.unique_sources,
            "total_tokens": stats.total_tokens,
            "avg_chunk_size": stats.avg_chunk_size,
            "chunk_types": stats.chunk_types,
            "file_types": stats.file_types,
            "oldest_chunk": stats.oldest_chunk.isoformat()
            if stats.oldest_chunk
            else None,
            "newest_chunk": stats.newest_chunk.isoformat()
            if stats.newest_chunk
            else None,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Get collection stats error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected get collection stats error: {e}")
        return {
            "success": False,
            "error": str(e),
            "code": ErrorCode.INTERNAL_ERROR.value,
        }

@mcp.tool()
async def update_document(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    source_path: str = Field(..., description="Source file path"),
    content: str = Field(..., description="New content"),
    collection: str = Field(
        default="default",
        description="Collection name",
    ),
    chunk_type: str = Field(
        default="text",
        description="Content type: code, markdown, or text",
    ),
    file_type: str | None = Field(
        default=None,
        description="File type for code chunking",
    ),
    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata",
    ),
) -> dict[str, Any]:
    """
    Update a document in the knowledge base.

    Replaces all existing chunks for the source path with new content.
    """
    try:
        # Validate inputs
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_collection(collection):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }
        if error := _validate_source_path(source_path):
            return {
                "success": False,
                "error": error,
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Validate content size to prevent DoS
        settings = get_settings()
        content_size = len(content.encode("utf-8"))
        if content_size > settings.max_document_size:
            return {
                "success": False,
                "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Parse chunk type
        try:
            chunk_type_enum = ChunkType(chunk_type.lower())
        except ValueError:
            valid_types = [t.value for t in ChunkType]
            return {
                "success": False,
                "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
                "code": ErrorCode.INVALID_REQUEST.value,
            }

        # Parse file type
        file_type_enum = None
        if file_type:
            try:
                file_type_enum = FileType(file_type.lower())
            except ValueError:
                valid_types = [t.value for t in FileType]
                return {
                    "success": False,
                    "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
                    "code": ErrorCode.INVALID_REQUEST.value,
                }

        result = await _collections.update_document(  # type: ignore[union-attr]
            project_id=project_id,
            agent_id=agent_id,
            source_path=source_path,
            content=content,
            collection=collection,
            chunk_type=chunk_type_enum,
            file_type=file_type_enum,
            metadata=metadata,
        )

        return {
            "success": result.success,
            "chunks_created": result.chunks_created,
            "embeddings_generated": result.embeddings_generated,
            "source_path": result.source_path,
            "collection": result.collection,
            "chunk_ids": result.chunk_ids,
            "error": result.error,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Update document error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected update document error: {e}")
        return {
            "success": False,
            "error": str(e),
            "code": ErrorCode.INTERNAL_ERROR.value,
        }

# Register tools in the JSON-RPC registry
# This must happen after tool functions are defined
_register_tool("search_knowledge", search_knowledge)
_register_tool("ingest_content", ingest_content)
_register_tool("delete_content", delete_content)
_register_tool("list_collections", list_collections)
_register_tool("get_collection_stats", get_collection_stats)
_register_tool("update_document", update_document)

def main() -> None:
    """Run the server."""
    import uvicorn

    settings = get_settings()

    uvicorn.run(
        "server:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
        log_level="info",
    )


if __name__ == "__main__":
    main()