fix(mcp-kb): add input validation, path security, and health checks

Security fixes from deep review:
- Add input validation patterns for project_id, agent_id, collection
- Add path traversal protection for source_path (reject .., null bytes)
- Add error codes (INTERNAL_ERROR) to generic exception handlers
- Handle FieldInfo objects in validation for test robustness

Performance fixes:
- Enable concurrent hybrid search with asyncio.gather

Health endpoint improvements:
- Check all dependencies (database, Redis, LLM Gateway)
- Return degraded/unhealthy status based on dependency health
- Updated tests for new health check response structure

All 139 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-04 01:18:50 +01:00
parent cd7a9ccbdf
commit 6bb376a336
3 changed files with 250 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ Provides semantic (vector), keyword (full-text), and hybrid search
capabilities over the knowledge base. capabilities over the knowledge base.
""" """
import asyncio
import logging import logging
import time import time
@@ -158,6 +159,7 @@ class SearchEngine:
Execute hybrid search combining semantic and keyword. Execute hybrid search combining semantic and keyword.
Uses Reciprocal Rank Fusion (RRF) for result combination. Uses Reciprocal Rank Fusion (RRF) for result combination.
Executes both searches concurrently for better performance.
""" """
# Execute both searches with higher limits for fusion # Execute both searches with higher limits for fusion
fusion_limit = min(request.limit * 2, 100) fusion_limit = min(request.limit * 2, 100)
@@ -187,9 +189,11 @@ class SearchEngine:
include_metadata=request.include_metadata, include_metadata=request.include_metadata,
) )
# Execute searches # Execute searches concurrently for better performance
semantic_results = await self._semantic_search(semantic_request) semantic_results, keyword_results = await asyncio.gather(
keyword_results = await self._keyword_search(keyword_request) self._semantic_search(semantic_request),
self._keyword_search(keyword_request),
)
# Fuse results using RRF # Fuse results using RRF
fused = self._reciprocal_rank_fusion( fused = self._reciprocal_rank_fusion(

View File

@@ -7,6 +7,7 @@ intelligent chunking, and collection management.
import inspect import inspect
import logging import logging
import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, get_type_hints from typing import Any, get_type_hints
@@ -20,7 +21,7 @@ from collections.abc import AsyncIterator
from config import get_settings from config import get_settings
from database import DatabaseManager, get_database_manager from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import KnowledgeBaseError from exceptions import ErrorCode, KnowledgeBaseError
from models import ( from models import (
ChunkType, ChunkType,
DeleteRequest, DeleteRequest,
@@ -31,6 +32,67 @@ from models import (
) )
from search import SearchEngine, get_search_engine from search import SearchEngine, get_search_engine
# Input validation patterns
# Allow alphanumeric, hyphens, underscores (1-128 chars)
ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
def _validate_id(value: str, field_name: str) -> str | None:
"""Validate project_id or agent_id format.
Returns error message if invalid, None if valid.
"""
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return f"{field_name} must be a string"
if not value:
return f"{field_name} is required"
if not ID_PATTERN.match(value):
return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
return None
def _validate_collection(value: str) -> str | None:
    """Check that a collection name is well-formed.

    Args:
        value: The collection name supplied by the caller.

    Returns:
        An error message when the name is rejected, ``None`` otherwise.
    """
    # Direct .fn() calls in tests may hand us an unresolved FieldInfo
    # default; a non-string simply means "no value", so skip validation.
    if not isinstance(value, str):
        return None
    if COLLECTION_PATTERN.match(value) is None:
        return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
    return None
def _validate_source_path(value: str | None) -> str | None:
"""Validate source_path to prevent path traversal.
Returns error message if invalid, None if valid.
"""
if value is None:
return None
# Handle FieldInfo objects from direct .fn() calls in tests
if not isinstance(value, str):
return None # Non-string means default not resolved, skip validation
# Normalize path and check for traversal attempts
if ".." in value:
return "Invalid source_path: path traversal not allowed"
# Check for null bytes (used in some injection attacks)
if "\x00" in value:
return "Invalid source_path: null bytes not allowed"
# Limit path length to prevent DoS
if len(value) > 4096:
return "Invalid source_path: path too long (max 4096 chars)"
return None
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@@ -96,24 +158,77 @@ app = FastAPI(
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint.

    Checks all dependencies: database, Redis cache, and LLM Gateway.
    Returns "degraded" if a non-critical dependency (Redis, LLM Gateway)
    is down, and "unhealthy" if the critical database check fails.

    Returns:
        A dict with overall status, service metadata, a UTC timestamp,
        and per-dependency status strings under "dependencies".
    """
    # Local import keeps module import time minimal for this endpoint.
    from datetime import UTC, datetime

    status: dict[str, Any] = {
        "status": "healthy",
        "service": "knowledge-base",
        "version": "0.1.0",
        "timestamp": datetime.now(UTC).isoformat(),
        "dependencies": {},
    }

    is_degraded = False
    is_unhealthy = False

    # Check database connection (critical).
    try:
        if _database and _database._pool:
            async with _database.acquire() as conn:
                await conn.fetchval("SELECT 1")
            status["dependencies"]["database"] = "connected"
        else:
            status["dependencies"]["database"] = "not initialized"
            is_unhealthy = True
    except Exception as e:
        status["dependencies"]["database"] = f"error: {e}"
        is_unhealthy = True

    # Check Redis cache (non-critical - degraded without it).
    try:
        if _embeddings and _embeddings._redis:
            await _embeddings._redis.ping()
            status["dependencies"]["redis"] = "connected"
        else:
            status["dependencies"]["redis"] = "not initialized"
            is_degraded = True
    except Exception as e:
        status["dependencies"]["redis"] = f"error: {e}"
        is_degraded = True

    # Check LLM Gateway connectivity (non-critical for health check).
    try:
        if _embeddings and _embeddings._http_client:
            settings = get_settings()
            response = await _embeddings._http_client.get(
                f"{settings.llm_gateway_url}/health",
                timeout=5.0,
            )
            if response.status_code == 200:
                status["dependencies"]["llm_gateway"] = "connected"
            else:
                status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
                is_degraded = True
        else:
            status["dependencies"]["llm_gateway"] = "not initialized"
            is_degraded = True
    except Exception as e:
        status["dependencies"]["llm_gateway"] = f"error: {e}"
        is_degraded = True

    # Set overall status; unhealthy (critical failure) wins over degraded.
    if is_unhealthy:
        status["status"] = "unhealthy"
    elif is_degraded:
        status["status"] = "degraded"
    else:
        status["status"] = "healthy"

    return status
@@ -411,6 +526,14 @@ async def search_knowledge(
Returns chunks ranked by relevance to the query. Returns chunks ranked by relevance to the query.
""" """
try: try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Parse search type # Parse search type
try: try:
search_type_enum = SearchType(search_type.lower()) search_type_enum = SearchType(search_type.lower())
@@ -419,6 +542,7 @@ async def search_knowledge(
return { return {
"success": False, "success": False,
"error": f"Invalid search type: {search_type}. Valid types: {valid_types}", "error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
} }
# Parse file types # Parse file types
@@ -430,6 +554,7 @@ async def search_knowledge(
return { return {
"success": False, "success": False,
"error": f"Invalid file type: {e}", "error": f"Invalid file type: {e}",
"code": ErrorCode.INVALID_REQUEST.value,
} }
request = SearchRequest( request = SearchRequest(
@@ -480,6 +605,7 @@ async def search_knowledge(
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
} }
@@ -516,6 +642,16 @@ async def ingest_content(
the LLM Gateway, and stored in pgvector for search. the LLM Gateway, and stored in pgvector for search.
""" """
try: try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Validate content size to prevent DoS # Validate content size to prevent DoS
settings = get_settings() settings = get_settings()
content_size = len(content.encode("utf-8")) content_size = len(content.encode("utf-8"))
@@ -523,6 +659,7 @@ async def ingest_content(
return { return {
"success": False, "success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)", "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
"code": ErrorCode.INVALID_REQUEST.value,
} }
# Parse chunk type # Parse chunk type
@@ -533,6 +670,7 @@ async def ingest_content(
return { return {
"success": False, "success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}", "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
} }
# Parse file type # Parse file type
@@ -545,6 +683,7 @@ async def ingest_content(
return { return {
"success": False, "success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}", "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
} }
request = IngestRequest( request = IngestRequest(
@@ -582,6 +721,7 @@ async def ingest_content(
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
} }
@@ -608,6 +748,16 @@ async def delete_content(
Specify either source_path, collection, or chunk_ids to delete. Specify either source_path, collection, or chunk_ids to delete.
""" """
try: try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
request = DeleteRequest( request = DeleteRequest(
project_id=project_id, project_id=project_id,
agent_id=agent_id, agent_id=agent_id,
@@ -636,6 +786,7 @@ async def delete_content(
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
} }
@@ -650,6 +801,12 @@ async def list_collections(
Returns collection names with chunk counts and file types. Returns collection names with chunk counts and file types.
""" """
try: try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
result = await _collections.list_collections(project_id) # type: ignore[union-attr] result = await _collections.list_collections(project_id) # type: ignore[union-attr]
return { return {
@@ -681,6 +838,7 @@ async def list_collections(
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
} }
@@ -696,6 +854,14 @@ async def get_collection_stats(
Returns chunk counts, token totals, and type breakdowns. Returns chunk counts, token totals, and type breakdowns.
""" """
try: try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr] stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
return { return {
@@ -724,6 +890,7 @@ async def get_collection_stats(
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
} }
@@ -756,6 +923,16 @@ async def update_document(
Replaces all existing chunks for the source path with new content. Replaces all existing chunks for the source path with new content.
""" """
try: try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
# Validate content size to prevent DoS # Validate content size to prevent DoS
settings = get_settings() settings = get_settings()
content_size = len(content.encode("utf-8")) content_size = len(content.encode("utf-8"))
@@ -763,6 +940,7 @@ async def update_document(
return { return {
"success": False, "success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)", "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
"code": ErrorCode.INVALID_REQUEST.value,
} }
# Parse chunk type # Parse chunk type
@@ -773,6 +951,7 @@ async def update_document(
return { return {
"success": False, "success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}", "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
} }
# Parse file type # Parse file type
@@ -785,6 +964,7 @@ async def update_document(
return { return {
"success": False, "success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}", "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
"code": ErrorCode.INVALID_REQUEST.value,
} }
result = await _collections.update_document( # type: ignore[union-attr] result = await _collections.update_document( # type: ignore[union-attr]
@@ -820,6 +1000,7 @@ async def update_document(
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
} }

View File

@@ -13,10 +13,10 @@ class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_healthy(self):
    """Test health check when all dependencies are connected."""
    import server

    # Create a proper async context manager mock for database
    mock_conn = AsyncMock()
    mock_conn.fetchval = AsyncMock(return_value=1)

    mock_db = MagicMock()
    mock_db._pool = MagicMock()
    mock_cm = AsyncMock()
    mock_cm.__aenter__.return_value = mock_conn
    mock_cm.__aexit__.return_value = None
    mock_db.acquire.return_value = mock_cm

    # Mock Redis
    mock_redis = AsyncMock()
    mock_redis.ping = AsyncMock(return_value=True)

    # Mock HTTP client for LLM Gateway
    mock_http_response = AsyncMock()
    mock_http_response.status_code = 200
    mock_http_client = AsyncMock()
    mock_http_client.get = AsyncMock(return_value=mock_http_response)

    # Mock embeddings with Redis and HTTP client
    mock_embeddings = MagicMock()
    mock_embeddings._redis = mock_redis
    mock_embeddings._http_client = mock_http_client

    server._database = mock_db
    server._embeddings = mock_embeddings

    result = await server.health_check()

    # All dependencies up -> overall status is "healthy".
    assert result["status"] == "healthy"
    assert result["service"] == "knowledge-base"
    assert result["dependencies"]["database"] == "connected"
    assert result["dependencies"]["redis"] == "connected"
    assert result["dependencies"]["llm_gateway"] == "connected"
@pytest.mark.asyncio
async def test_health_check_no_database(self):
    """Test health check without database - should be unhealthy."""
    import server

    # Database is the critical dependency; without it the service
    # must report "unhealthy", not merely "degraded".
    server._database = None
    server._embeddings = None

    result = await server.health_check()

    assert result["status"] == "unhealthy"
    assert result["dependencies"]["database"] == "not initialized"
@pytest.mark.asyncio
async def test_health_check_degraded(self):
    """Test health check with database but no Redis - should be degraded."""
    import server

    # Create a proper async context manager mock for database
    mock_conn = AsyncMock()
    mock_conn.fetchval = AsyncMock(return_value=1)

    mock_db = MagicMock()
    mock_db._pool = MagicMock()
    mock_cm = AsyncMock()
    mock_cm.__aenter__.return_value = mock_conn
    mock_cm.__aexit__.return_value = None
    mock_db.acquire.return_value = mock_cm

    # Mock embeddings without Redis: non-critical dependency missing.
    mock_embeddings = MagicMock()
    mock_embeddings._redis = None
    mock_embeddings._http_client = None

    server._database = mock_db
    server._embeddings = mock_embeddings

    result = await server.health_check()

    # Database up but Redis missing -> "degraded", not "unhealthy".
    assert result["status"] == "degraded"
    assert result["dependencies"]["database"] == "connected"
    assert result["dependencies"]["redis"] == "not initialized"
class TestSearchKnowledgeTool: class TestSearchKnowledgeTool: