forked from cardosofelipe/fast-next-template
fix(mcp-kb): add input validation, path security, and health checks
Security fixes from deep review: - Add input validation patterns for project_id, agent_id, collection - Add path traversal protection for source_path (reject .., null bytes) - Add error codes (INTERNAL_ERROR) to generic exception handlers - Handle FieldInfo objects in validation for test robustness Performance fixes: - Enable concurrent hybrid search with asyncio.gather Health endpoint improvements: - Check all dependencies (database, Redis, LLM Gateway) - Return degraded/unhealthy status based on dependency health - Updated tests for new health check response structure All 139 tests pass. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ Provides semantic (vector), keyword (full-text), and hybrid search
|
||||
capabilities over the knowledge base.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
@@ -158,6 +159,7 @@ class SearchEngine:
|
||||
Execute hybrid search combining semantic and keyword.
|
||||
|
||||
Uses Reciprocal Rank Fusion (RRF) for result combination.
|
||||
Executes both searches concurrently for better performance.
|
||||
"""
|
||||
# Execute both searches with higher limits for fusion
|
||||
fusion_limit = min(request.limit * 2, 100)
|
||||
@@ -187,9 +189,11 @@ class SearchEngine:
|
||||
include_metadata=request.include_metadata,
|
||||
)
|
||||
|
||||
# Execute searches
|
||||
semantic_results = await self._semantic_search(semantic_request)
|
||||
keyword_results = await self._keyword_search(keyword_request)
|
||||
# Execute searches concurrently for better performance
|
||||
semantic_results, keyword_results = await asyncio.gather(
|
||||
self._semantic_search(semantic_request),
|
||||
self._keyword_search(keyword_request),
|
||||
)
|
||||
|
||||
# Fuse results using RRF
|
||||
fused = self._reciprocal_rank_fusion(
|
||||
|
||||
@@ -7,6 +7,7 @@ intelligent chunking, and collection management.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
@@ -20,7 +21,7 @@ from collections.abc import AsyncIterator
|
||||
from config import get_settings
|
||||
from database import DatabaseManager, get_database_manager
|
||||
from embeddings import EmbeddingGenerator, get_embedding_generator
|
||||
from exceptions import KnowledgeBaseError
|
||||
from exceptions import ErrorCode, KnowledgeBaseError
|
||||
from models import (
|
||||
ChunkType,
|
||||
DeleteRequest,
|
||||
@@ -31,6 +32,67 @@ from models import (
|
||||
)
|
||||
from search import SearchEngine, get_search_engine
|
||||
|
||||
# Input validation patterns
|
||||
# Allow alphanumeric, hyphens, underscores (1-128 chars)
|
||||
ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
|
||||
# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
|
||||
COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
|
||||
|
||||
def _validate_id(value: str, field_name: str) -> str | None:
|
||||
"""Validate project_id or agent_id format.
|
||||
|
||||
Returns error message if invalid, None if valid.
|
||||
"""
|
||||
# Handle FieldInfo objects from direct .fn() calls in tests
|
||||
if not isinstance(value, str):
|
||||
return f"{field_name} must be a string"
|
||||
if not value:
|
||||
return f"{field_name} is required"
|
||||
if not ID_PATTERN.match(value):
|
||||
return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
|
||||
return None
|
||||
|
||||
|
||||
def _validate_collection(value: str) -> str | None:
|
||||
"""Validate collection name format.
|
||||
|
||||
Returns error message if invalid, None if valid.
|
||||
"""
|
||||
# Handle FieldInfo objects from direct .fn() calls in tests
|
||||
if not isinstance(value, str):
|
||||
return None # Non-string means default not resolved, skip validation
|
||||
if not COLLECTION_PATTERN.match(value):
|
||||
return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
|
||||
return None
|
||||
|
||||
|
||||
def _validate_source_path(value: str | None) -> str | None:
|
||||
"""Validate source_path to prevent path traversal.
|
||||
|
||||
Returns error message if invalid, None if valid.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Handle FieldInfo objects from direct .fn() calls in tests
|
||||
if not isinstance(value, str):
|
||||
return None # Non-string means default not resolved, skip validation
|
||||
|
||||
# Normalize path and check for traversal attempts
|
||||
if ".." in value:
|
||||
return "Invalid source_path: path traversal not allowed"
|
||||
|
||||
# Check for null bytes (used in some injection attacks)
|
||||
if "\x00" in value:
|
||||
return "Invalid source_path: null bytes not allowed"
|
||||
|
||||
# Limit path length to prevent DoS
|
||||
if len(value) > 4096:
|
||||
return "Invalid source_path: path too long (max 4096 chars)"
|
||||
|
||||
return None
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -96,24 +158,77 @@ app = FastAPI(
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, Any]:
|
||||
"""Health check endpoint."""
|
||||
"""Health check endpoint.
|
||||
|
||||
Checks all dependencies: database, Redis cache, and LLM Gateway.
|
||||
Returns degraded status if any non-critical dependency fails.
|
||||
Returns unhealthy status if critical dependencies fail.
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
status: dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"service": "knowledge-base",
|
||||
"version": "0.1.0",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"dependencies": {},
|
||||
}
|
||||
|
||||
# Check database connection
|
||||
is_degraded = False
|
||||
is_unhealthy = False
|
||||
|
||||
# Check database connection (critical)
|
||||
try:
|
||||
if _database and _database._pool:
|
||||
async with _database.acquire() as conn:
|
||||
await conn.fetchval("SELECT 1")
|
||||
status["database"] = "connected"
|
||||
status["dependencies"]["database"] = "connected"
|
||||
else:
|
||||
status["database"] = "not initialized"
|
||||
status["dependencies"]["database"] = "not initialized"
|
||||
is_unhealthy = True
|
||||
except Exception as e:
|
||||
status["database"] = f"error: {e}"
|
||||
status["dependencies"]["database"] = f"error: {e}"
|
||||
is_unhealthy = True
|
||||
|
||||
# Check Redis cache (non-critical - degraded without it)
|
||||
try:
|
||||
if _embeddings and _embeddings._redis:
|
||||
await _embeddings._redis.ping()
|
||||
status["dependencies"]["redis"] = "connected"
|
||||
else:
|
||||
status["dependencies"]["redis"] = "not initialized"
|
||||
is_degraded = True
|
||||
except Exception as e:
|
||||
status["dependencies"]["redis"] = f"error: {e}"
|
||||
is_degraded = True
|
||||
|
||||
# Check LLM Gateway connectivity (non-critical for health check)
|
||||
try:
|
||||
if _embeddings and _embeddings._http_client:
|
||||
settings = get_settings()
|
||||
response = await _embeddings._http_client.get(
|
||||
f"{settings.llm_gateway_url}/health",
|
||||
timeout=5.0,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
status["dependencies"]["llm_gateway"] = "connected"
|
||||
else:
|
||||
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
|
||||
is_degraded = True
|
||||
else:
|
||||
status["dependencies"]["llm_gateway"] = "not initialized"
|
||||
is_degraded = True
|
||||
except Exception as e:
|
||||
status["dependencies"]["llm_gateway"] = f"error: {e}"
|
||||
is_degraded = True
|
||||
|
||||
# Set overall status
|
||||
if is_unhealthy:
|
||||
status["status"] = "unhealthy"
|
||||
elif is_degraded:
|
||||
status["status"] = "degraded"
|
||||
else:
|
||||
status["status"] = "healthy"
|
||||
|
||||
return status
|
||||
|
||||
@@ -411,6 +526,14 @@ async def search_knowledge(
|
||||
Returns chunks ranked by relevance to the query.
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if collection and (error := _validate_collection(collection)):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
|
||||
# Parse search type
|
||||
try:
|
||||
search_type_enum = SearchType(search_type.lower())
|
||||
@@ -419,6 +542,7 @@ async def search_knowledge(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse file types
|
||||
@@ -430,6 +554,7 @@ async def search_knowledge(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid file type: {e}",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
request = SearchRequest(
|
||||
@@ -480,6 +605,7 @@ async def search_knowledge(
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@@ -516,6 +642,16 @@ async def ingest_content(
|
||||
the LLM Gateway, and stored in pgvector for search.
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
content_size = len(content.encode("utf-8"))
|
||||
@@ -523,6 +659,7 @@ async def ingest_content(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse chunk type
|
||||
@@ -533,6 +670,7 @@ async def ingest_content(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse file type
|
||||
@@ -545,6 +683,7 @@ async def ingest_content(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
request = IngestRequest(
|
||||
@@ -582,6 +721,7 @@ async def ingest_content(
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@@ -608,6 +748,16 @@ async def delete_content(
|
||||
Specify either source_path, collection, or chunk_ids to delete.
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if collection and (error := _validate_collection(collection)):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
|
||||
request = DeleteRequest(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
@@ -636,6 +786,7 @@ async def delete_content(
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@@ -650,6 +801,12 @@ async def list_collections(
|
||||
Returns collection names with chunk counts and file types.
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
|
||||
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
|
||||
|
||||
return {
|
||||
@@ -681,6 +838,7 @@ async def list_collections(
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@@ -696,6 +854,14 @@ async def get_collection_stats(
|
||||
Returns chunk counts, token totals, and type breakdowns.
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
|
||||
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
|
||||
|
||||
return {
|
||||
@@ -724,6 +890,7 @@ async def get_collection_stats(
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@@ -756,6 +923,16 @@ async def update_document(
|
||||
Replaces all existing chunks for the source path with new content.
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
content_size = len(content.encode("utf-8"))
|
||||
@@ -763,6 +940,7 @@ async def update_document(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse chunk type
|
||||
@@ -773,6 +951,7 @@ async def update_document(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse file type
|
||||
@@ -785,6 +964,7 @@ async def update_document(
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
result = await _collections.update_document( # type: ignore[union-attr]
|
||||
@@ -820,6 +1000,7 @@ async def update_document(
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ class TestHealthCheck:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_healthy(self):
|
||||
"""Test health check when healthy."""
|
||||
"""Test health check when all dependencies are connected."""
|
||||
import server
|
||||
|
||||
# Create a proper async context manager mock
|
||||
# Create a proper async context manager mock for database
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval = AsyncMock(return_value=1)
|
||||
|
||||
@@ -29,24 +29,75 @@ class TestHealthCheck:
|
||||
mock_cm.__aexit__.return_value = None
|
||||
mock_db.acquire.return_value = mock_cm
|
||||
|
||||
# Mock Redis
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(return_value=True)
|
||||
|
||||
# Mock HTTP client for LLM Gateway
|
||||
mock_http_response = AsyncMock()
|
||||
mock_http_response.status_code = 200
|
||||
mock_http_client = AsyncMock()
|
||||
mock_http_client.get = AsyncMock(return_value=mock_http_response)
|
||||
|
||||
# Mock embeddings with Redis and HTTP client
|
||||
mock_embeddings = MagicMock()
|
||||
mock_embeddings._redis = mock_redis
|
||||
mock_embeddings._http_client = mock_http_client
|
||||
|
||||
server._database = mock_db
|
||||
server._embeddings = mock_embeddings
|
||||
|
||||
result = await server.health_check()
|
||||
|
||||
assert result["status"] == "healthy"
|
||||
assert result["service"] == "knowledge-base"
|
||||
assert result["database"] == "connected"
|
||||
assert result["dependencies"]["database"] == "connected"
|
||||
assert result["dependencies"]["redis"] == "connected"
|
||||
assert result["dependencies"]["llm_gateway"] == "connected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_no_database(self):
|
||||
"""Test health check without database."""
|
||||
"""Test health check without database - should be unhealthy."""
|
||||
import server
|
||||
|
||||
server._database = None
|
||||
server._embeddings = None
|
||||
|
||||
result = await server.health_check()
|
||||
|
||||
assert result["database"] == "not initialized"
|
||||
assert result["status"] == "unhealthy"
|
||||
assert result["dependencies"]["database"] == "not initialized"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_degraded(self):
|
||||
"""Test health check with database but no Redis - should be degraded."""
|
||||
import server
|
||||
|
||||
# Create a proper async context manager mock for database
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval = AsyncMock(return_value=1)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db._pool = MagicMock()
|
||||
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__.return_value = mock_conn
|
||||
mock_cm.__aexit__.return_value = None
|
||||
mock_db.acquire.return_value = mock_cm
|
||||
|
||||
# Mock embeddings without Redis
|
||||
mock_embeddings = MagicMock()
|
||||
mock_embeddings._redis = None
|
||||
mock_embeddings._http_client = None
|
||||
|
||||
server._database = mock_db
|
||||
server._embeddings = mock_embeddings
|
||||
|
||||
result = await server.health_check()
|
||||
|
||||
assert result["status"] == "degraded"
|
||||
assert result["dependencies"]["database"] == "connected"
|
||||
assert result["dependencies"]["redis"] == "not initialized"
|
||||
|
||||
|
||||
class TestSearchKnowledgeTool:
|
||||
|
||||
Reference in New Issue
Block a user