fix(mcp-kb): add input validation, path security, and health checks

Security fixes from deep review:
- Add input validation patterns for project_id, agent_id, collection
- Add path traversal protection for source_path (reject .., null bytes)
- Add error codes (INTERNAL_ERROR) to generic exception handlers
- Handle FieldInfo objects in validation for test robustness
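
As an illustration, a traversal attempt is rejected before any query or embedding
call is made (inputs are made up; the helpers and error shape are the ones added
in this commit):

    # A project_id containing path separators fails ID_PATTERN
    _validate_id("../etc/passwd", "project_id")
    # -> "Invalid project_id: must be 1-128 alphanumeric characters, hyphens, or underscores"

    # ".." anywhere in source_path is refused outright
    _validate_source_path("docs/../secrets.md")
    # -> "Invalid source_path: path traversal not allowed"

    # Tool handlers return this as a structured error:
    # {"success": False, "error": <message>, "code": ErrorCode.INVALID_REQUEST.value}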

Performance fixes:
- Enable concurrent hybrid search with asyncio.gather
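
The search-side change is not in the file shown below; a minimal sketch of the
pattern, with run_vector/run_keyword standing in for the engine's two search
coroutines (names assumed, not taken from the diff):

    import asyncio

    async def hybrid_search(run_vector, run_keyword, request):
        # Fire both legs of the hybrid search concurrently instead of awaiting
        # them one after the other; gather returns results in argument order.
        vector_hits, keyword_hits = await asyncio.gather(
            run_vector(request),
            run_keyword(request),
        )
        return vector_hits, keyword_hits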

Health endpoint improvements:
- Check all dependencies (database, Redis, LLM Gateway)
- Return degraded/unhealthy status based on dependency health
- Update tests for the new health check response structure
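
The response now nests per-dependency results, for example (values illustrative,
shape taken from the handler below):

    {
        "status": "degraded",
        "service": "knowledge-base",
        "version": "0.1.0",
        "timestamp": "2026-01-04T00:18:50+00:00",
        "dependencies": {
            "database": "connected",
            "redis": "error: connection refused",
            "llm_gateway": "connected",
        },
    }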

All 139 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
commit 6bb376a336
parent cd7a9ccbdf
Date: 2026-01-04 01:18:50 +01:00
3 changed files with 250 additions and 14 deletions


@@ -7,6 +7,7 @@ intelligent chunking, and collection management.
import inspect
import logging
import re
from contextlib import asynccontextmanager
from typing import Any, get_type_hints
@@ -20,7 +21,7 @@ from collections.abc import AsyncIterator
from config import get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
-from exceptions import KnowledgeBaseError
from exceptions import ErrorCode, KnowledgeBaseError
from models import (
ChunkType,
DeleteRequest,
@@ -31,6 +32,67 @@ from models import (
)
from search import SearchEngine, get_search_engine
# Input validation patterns
# Allow alphanumeric, hyphens, underscores (1-128 chars)
ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
def _validate_id(value: str, field_name: str) -> str | None:
"""Validate project_id or agent_id format.
Returns error message if invalid, None if valid.
"""
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return f"{field_name} must be a string"
if not value:
return f"{field_name} is required"
if not ID_PATTERN.match(value):
return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
return None
def _validate_collection(value: str) -> str | None:
"""Validate collection name format.
Returns error message if invalid, None if valid.
"""
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return None # Non-string means default not resolved, skip validation
if not COLLECTION_PATTERN.match(value):
return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
return None
def _validate_source_path(value: str | None) -> str | None:
"""Validate source_path to prevent path traversal.
Returns error message if invalid, None if valid.
"""
if value is None:
return None
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return None # Non-string means default not resolved, skip validation
# Normalize path and check for traversal attempts
if ".." in value:
return "Invalid source_path: path traversal not allowed"
# Check for null bytes (used in some injection attacks)
if "\x00" in value:
return "Invalid source_path: null bytes not allowed"
# Limit path length to prevent DoS
if len(value) > 4096:
return "Invalid source_path: path too long (max 4096 chars)"
return None
# Configure logging
logging.basicConfig(
level=logging.INFO,
@@ -96,24 +158,77 @@ app = FastAPI(
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint."""
"""Health check endpoint.
Checks all dependencies: database, Redis cache, and LLM Gateway.
Returns degraded status if any non-critical dependency fails.
Returns unhealthy status if critical dependencies fail.
"""
from datetime import UTC, datetime
status: dict[str, Any] = {
"status": "healthy",
"service": "knowledge-base",
"version": "0.1.0",
"timestamp": datetime.now(UTC).isoformat(),
"dependencies": {},
}
# Check database connection
is_degraded = False
is_unhealthy = False
# Check database connection (critical)
try:
if _database and _database._pool:
async with _database.acquire() as conn:
await conn.fetchval("SELECT 1")
status["database"] = "connected"
status["dependencies"]["database"] = "connected"
else:
status["database"] = "not initialized"
status["dependencies"]["database"] = "not initialized"
is_unhealthy = True
except Exception as e:
status["database"] = f"error: {e}"
status["dependencies"]["database"] = f"error: {e}"
is_unhealthy = True
# Check Redis cache (non-critical - degraded without it)
try:
if _embeddings and _embeddings._redis:
await _embeddings._redis.ping()
status["dependencies"]["redis"] = "connected"
else:
status["dependencies"]["redis"] = "not initialized"
is_degraded = True
except Exception as e:
status["dependencies"]["redis"] = f"error: {e}"
is_degraded = True
# Check LLM Gateway connectivity (non-critical for health check)
try:
if _embeddings and _embeddings._http_client:
settings = get_settings()
response = await _embeddings._http_client.get(
f"{settings.llm_gateway_url}/health",
timeout=5.0,
)
if response.status_code == 200:
status["dependencies"]["llm_gateway"] = "connected"
else:
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
is_degraded = True
else:
status["dependencies"]["llm_gateway"] = "not initialized"
is_degraded = True
except Exception as e:
status["dependencies"]["llm_gateway"] = f"error: {e}"
is_degraded = True
# Set overall status
if is_unhealthy:
status["status"] = "unhealthy"
elif is_degraded:
status["status"] = "degraded"
else:
status["status"] = "healthy"
return status
@@ -411,6 +526,14 @@ async def search_knowledge(
Returns chunks ranked by relevance to the query.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Parse search type
try:
search_type_enum = SearchType(search_type.lower())
@@ -419,6 +542,7 @@ async def search_knowledge(
return {
"success": False,
"error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse file types
@@ -430,6 +554,7 @@ async def search_knowledge(
return {
"success": False,
"error": f"Invalid file type: {e}",
"code": ErrorCode.INVALID_REQUEST.value,
}
request = SearchRequest(
@@ -480,6 +605,7 @@ async def search_knowledge(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -516,6 +642,16 @@ async def ingest_content(
the LLM Gateway, and stored in pgvector for search.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
@@ -523,6 +659,7 @@ async def ingest_content(
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse chunk type
@@ -533,6 +670,7 @@ async def ingest_content(
return {
"success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse file type
@@ -545,6 +683,7 @@ async def ingest_content(
return {
"success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
request = IngestRequest(
@@ -582,6 +721,7 @@ async def ingest_content(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -608,6 +748,16 @@ async def delete_content(
Specify either source_path, collection, or chunk_ids to delete.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
request = DeleteRequest(
project_id=project_id,
agent_id=agent_id,
@@ -636,6 +786,7 @@ async def delete_content(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -650,6 +801,12 @@ async def list_collections(
Returns collection names with chunk counts and file types.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
return {
@@ -681,6 +838,7 @@ async def list_collections(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -696,6 +854,14 @@ async def get_collection_stats(
Returns chunk counts, token totals, and type breakdowns.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
return {
@@ -724,6 +890,7 @@ async def get_collection_stats(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -756,6 +923,16 @@ async def update_document(
Replaces all existing chunks for the source path with new content.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
@@ -763,6 +940,7 @@ async def update_document(
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse chunk type
@@ -773,6 +951,7 @@ async def update_document(
return {
"success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse file type
@@ -785,6 +964,7 @@ async def update_document(
return {
"success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
result = await _collections.update_document( # type: ignore[union-attr]
@@ -820,6 +1000,7 @@ async def update_document(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}