fix(mcp-kb): add input validation, path security, and health checks

Security fixes from deep review:
- Add input validation patterns for project_id, agent_id, collection
- Add path traversal protection for source_path (reject .., null bytes)
- Add error codes (INTERNAL_ERROR) to generic exception handlers
- Handle FieldInfo objects in validation for test robustness

Performance fixes:
- Enable concurrent hybrid search with asyncio.gather

Health endpoint improvements:
- Check all dependencies (database, Redis, LLM Gateway)
- Return degraded/unhealthy status based on dependency health
- Updated tests for new health check response structure

All 139 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-04 01:18:50 +01:00
parent cd7a9ccbdf
commit 6bb376a336
3 changed files with 250 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ Provides semantic (vector), keyword (full-text), and hybrid search
capabilities over the knowledge base.
"""
import asyncio
import logging
import time
@@ -158,6 +159,7 @@ class SearchEngine:
Execute hybrid search combining semantic and keyword.
Uses Reciprocal Rank Fusion (RRF) for result combination.
Executes both searches concurrently for better performance.
"""
# Execute both searches with higher limits for fusion
fusion_limit = min(request.limit * 2, 100)
@@ -187,9 +189,11 @@ class SearchEngine:
include_metadata=request.include_metadata,
)
# Execute searches
semantic_results = await self._semantic_search(semantic_request)
keyword_results = await self._keyword_search(keyword_request)
# Execute searches concurrently for better performance
semantic_results, keyword_results = await asyncio.gather(
self._semantic_search(semantic_request),
self._keyword_search(keyword_request),
)
# Fuse results using RRF
fused = self._reciprocal_rank_fusion(

View File

@@ -7,6 +7,7 @@ intelligent chunking, and collection management.
import inspect
import logging
import re
from contextlib import asynccontextmanager
from typing import Any, get_type_hints
@@ -20,7 +21,7 @@ from collections.abc import AsyncIterator
from config import get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import KnowledgeBaseError
from exceptions import ErrorCode, KnowledgeBaseError
from models import (
ChunkType,
DeleteRequest,
@@ -31,6 +32,67 @@ from models import (
)
from search import SearchEngine, get_search_engine
# Input validation patterns
# Allow alphanumeric, hyphens, underscores (1-128 chars)
ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
def _validate_id(value: str, field_name: str) -> str | None:
"""Validate project_id or agent_id format.
Returns error message if invalid, None if valid.
"""
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return f"{field_name} must be a string"
if not value:
return f"{field_name} is required"
if not ID_PATTERN.match(value):
return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
return None
def _validate_collection(value: str) -> str | None:
"""Validate collection name format.
Returns error message if invalid, None if valid.
"""
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return None # Non-string means default not resolved, skip validation
if not COLLECTION_PATTERN.match(value):
return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
return None
def _validate_source_path(value: str | None) -> str | None:
"""Validate source_path to prevent path traversal.
Returns error message if invalid, None if valid.
"""
if value is None:
return None
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return None # Non-string means default not resolved, skip validation
# Normalize path and check for traversal attempts
if ".." in value:
return "Invalid source_path: path traversal not allowed"
# Check for null bytes (used in some injection attacks)
if "\x00" in value:
return "Invalid source_path: null bytes not allowed"
# Limit path length to prevent DoS
if len(value) > 4096:
return "Invalid source_path: path too long (max 4096 chars)"
return None
# Configure logging
logging.basicConfig(
level=logging.INFO,
@@ -96,24 +158,77 @@ app = FastAPI(
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint."""
"""Health check endpoint.
Checks all dependencies: database, Redis cache, and LLM Gateway.
Returns degraded status if any non-critical dependency fails.
Returns unhealthy status if critical dependencies fail.
"""
from datetime import UTC, datetime
status: dict[str, Any] = {
"status": "healthy",
"service": "knowledge-base",
"version": "0.1.0",
"timestamp": datetime.now(UTC).isoformat(),
"dependencies": {},
}
# Check database connection
is_degraded = False
is_unhealthy = False
# Check database connection (critical)
try:
if _database and _database._pool:
async with _database.acquire() as conn:
await conn.fetchval("SELECT 1")
status["database"] = "connected"
status["dependencies"]["database"] = "connected"
else:
status["database"] = "not initialized"
status["dependencies"]["database"] = "not initialized"
is_unhealthy = True
except Exception as e:
status["database"] = f"error: {e}"
status["dependencies"]["database"] = f"error: {e}"
is_unhealthy = True
# Check Redis cache (non-critical - degraded without it)
try:
if _embeddings and _embeddings._redis:
await _embeddings._redis.ping()
status["dependencies"]["redis"] = "connected"
else:
status["dependencies"]["redis"] = "not initialized"
is_degraded = True
except Exception as e:
status["dependencies"]["redis"] = f"error: {e}"
is_degraded = True
# Check LLM Gateway connectivity (non-critical for health check)
try:
if _embeddings and _embeddings._http_client:
settings = get_settings()
response = await _embeddings._http_client.get(
f"{settings.llm_gateway_url}/health",
timeout=5.0,
)
if response.status_code == 200:
status["dependencies"]["llm_gateway"] = "connected"
else:
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
is_degraded = True
else:
status["dependencies"]["llm_gateway"] = "not initialized"
is_degraded = True
except Exception as e:
status["dependencies"]["llm_gateway"] = f"error: {e}"
is_degraded = True
# Set overall status
if is_unhealthy:
status["status"] = "unhealthy"
elif is_degraded:
status["status"] = "degraded"
else:
status["status"] = "healthy"
return status
@@ -411,6 +526,14 @@ async def search_knowledge(
Returns chunks ranked by relevance to the query.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Parse search type
try:
search_type_enum = SearchType(search_type.lower())
@@ -419,6 +542,7 @@ async def search_knowledge(
return {
"success": False,
"error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse file types
@@ -430,6 +554,7 @@ async def search_knowledge(
return {
"success": False,
"error": f"Invalid file type: {e}",
"code": ErrorCode.INVALID_REQUEST.value,
}
request = SearchRequest(
@@ -480,6 +605,7 @@ async def search_knowledge(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -516,6 +642,16 @@ async def ingest_content(
the LLM Gateway, and stored in pgvector for search.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
@@ -523,6 +659,7 @@ async def ingest_content(
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse chunk type
@@ -533,6 +670,7 @@ async def ingest_content(
return {
"success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse file type
@@ -545,6 +683,7 @@ async def ingest_content(
return {
"success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
request = IngestRequest(
@@ -582,6 +721,7 @@ async def ingest_content(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -608,6 +748,16 @@ async def delete_content(
Specify either source_path, collection, or chunk_ids to delete.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
request = DeleteRequest(
project_id=project_id,
agent_id=agent_id,
@@ -636,6 +786,7 @@ async def delete_content(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -650,6 +801,12 @@ async def list_collections(
Returns collection names with chunk counts and file types.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
return {
@@ -681,6 +838,7 @@ async def list_collections(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -696,6 +854,14 @@ async def get_collection_stats(
Returns chunk counts, token totals, and type breakdowns.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
return {
@@ -724,6 +890,7 @@ async def get_collection_stats(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@@ -756,6 +923,16 @@ async def update_document(
Replaces all existing chunks for the source path with new content.
"""
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
@@ -763,6 +940,7 @@ async def update_document(
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse chunk type
@@ -773,6 +951,7 @@ async def update_document(
return {
"success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse file type
@@ -785,6 +964,7 @@ async def update_document(
return {
"success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
}
result = await _collections.update_document( # type: ignore[union-attr]
@@ -820,6 +1000,7 @@ async def update_document(
return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}

View File

@@ -13,10 +13,10 @@ class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_healthy(self):
"""Test health check when healthy."""
"""Test health check when all dependencies are connected."""
import server
# Create a proper async context manager mock
# Create a proper async context manager mock for database
mock_conn = AsyncMock()
mock_conn.fetchval = AsyncMock(return_value=1)
@@ -29,24 +29,75 @@ class TestHealthCheck:
mock_cm.__aexit__.return_value = None
mock_db.acquire.return_value = mock_cm
# Mock Redis
mock_redis = AsyncMock()
mock_redis.ping = AsyncMock(return_value=True)
# Mock HTTP client for LLM Gateway
mock_http_response = AsyncMock()
mock_http_response.status_code = 200
mock_http_client = AsyncMock()
mock_http_client.get = AsyncMock(return_value=mock_http_response)
# Mock embeddings with Redis and HTTP client
mock_embeddings = MagicMock()
mock_embeddings._redis = mock_redis
mock_embeddings._http_client = mock_http_client
server._database = mock_db
server._embeddings = mock_embeddings
result = await server.health_check()
assert result["status"] == "healthy"
assert result["service"] == "knowledge-base"
assert result["database"] == "connected"
assert result["dependencies"]["database"] == "connected"
assert result["dependencies"]["redis"] == "connected"
assert result["dependencies"]["llm_gateway"] == "connected"
@pytest.mark.asyncio
async def test_health_check_no_database(self):
"""Test health check without database."""
"""Test health check without database - should be unhealthy."""
import server
server._database = None
server._embeddings = None
result = await server.health_check()
assert result["database"] == "not initialized"
assert result["status"] == "unhealthy"
assert result["dependencies"]["database"] == "not initialized"
@pytest.mark.asyncio
async def test_health_check_degraded(self):
"""Test health check with database but no Redis - should be degraded."""
import server
# Create a proper async context manager mock for database
mock_conn = AsyncMock()
mock_conn.fetchval = AsyncMock(return_value=1)
mock_db = MagicMock()
mock_db._pool = MagicMock()
mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = mock_conn
mock_cm.__aexit__.return_value = None
mock_db.acquire.return_value = mock_cm
# Mock embeddings without Redis
mock_embeddings = MagicMock()
mock_embeddings._redis = None
mock_embeddings._http_client = None
server._database = mock_db
server._embeddings = mock_embeddings
result = await server.health_check()
assert result["status"] == "degraded"
assert result["dependencies"]["database"] == "connected"
assert result["dependencies"]["redis"] == "not initialized"
class TestSearchKnowledgeTool: