diff --git a/mcp-servers/knowledge-base/search.py b/mcp-servers/knowledge-base/search.py
index e4df6f2..d1d9bf5 100644
--- a/mcp-servers/knowledge-base/search.py
+++ b/mcp-servers/knowledge-base/search.py
@@ -5,6 +5,7 @@
 Provides semantic (vector), keyword (full-text), and hybrid search
 capabilities over the knowledge base.
 """
 
+import asyncio
 import logging
 import time
@@ -158,6 +159,7 @@ class SearchEngine:
         Execute hybrid search combining semantic and keyword.
 
         Uses Reciprocal Rank Fusion (RRF) for result combination.
+        Executes both searches concurrently for better performance.
         """
         # Execute both searches with higher limits for fusion
         fusion_limit = min(request.limit * 2, 100)
@@ -187,9 +189,11 @@ class SearchEngine:
             include_metadata=request.include_metadata,
         )
 
-        # Execute searches
-        semantic_results = await self._semantic_search(semantic_request)
-        keyword_results = await self._keyword_search(keyword_request)
+        # Execute searches concurrently for better performance
+        semantic_results, keyword_results = await asyncio.gather(
+            self._semantic_search(semantic_request),
+            self._keyword_search(keyword_request),
+        )
 
         # Fuse results using RRF
         fused = self._reciprocal_rank_fusion(
diff --git a/mcp-servers/knowledge-base/server.py b/mcp-servers/knowledge-base/server.py
index 0e2dc9e..7a5595a 100644
--- a/mcp-servers/knowledge-base/server.py
+++ b/mcp-servers/knowledge-base/server.py
@@ -7,6 +7,7 @@ intelligent chunking, and collection management.
 
 import inspect
 import logging
+import re
 from contextlib import asynccontextmanager
 from typing import Any, get_type_hints
@@ -20,7 +21,7 @@ from collections.abc import AsyncIterator
 from config import get_settings
 from database import DatabaseManager, get_database_manager
 from embeddings import EmbeddingGenerator, get_embedding_generator
-from exceptions import KnowledgeBaseError
+from exceptions import ErrorCode, KnowledgeBaseError
 from models import (
     ChunkType,
     DeleteRequest,
@@ -31,6 +32,67 @@ from models import (
 )
 from search import SearchEngine, get_search_engine
 
+# Input validation patterns
+# Allow alphanumeric, hyphens, underscores (1-128 chars)
+ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
+# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
+COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
+
+
+def _validate_id(value: str, field_name: str) -> str | None:
+    """Validate project_id or agent_id format.
+
+    Returns error message if invalid, None if valid.
+    """
+    # Handle FieldInfo objects from direct .fn() calls in tests
+    if not isinstance(value, str):
+        return f"{field_name} must be a string"
+    if not value:
+        return f"{field_name} is required"
+    if not ID_PATTERN.match(value):
+        return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
+    return None
+
+
+def _validate_collection(value: str) -> str | None:
+    """Validate collection name format.
+
+    Returns error message if invalid, None if valid.
+    """
+    # Handle FieldInfo objects from direct .fn() calls in tests
+    if not isinstance(value, str):
+        return None  # Non-string means default not resolved, skip validation
+    if not COLLECTION_PATTERN.match(value):
+        return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
+    return None
+
+
+def _validate_source_path(value: str | None) -> str | None:
+    """Validate source_path to prevent path traversal.
+
+    Returns error message if invalid, None if valid.
+ """ + if value is None: + return None + + # Handle FieldInfo objects from direct .fn() calls in tests + if not isinstance(value, str): + return None # Non-string means default not resolved, skip validation + + # Normalize path and check for traversal attempts + if ".." in value: + return "Invalid source_path: path traversal not allowed" + + # Check for null bytes (used in some injection attacks) + if "\x00" in value: + return "Invalid source_path: null bytes not allowed" + + # Limit path length to prevent DoS + if len(value) > 4096: + return "Invalid source_path: path too long (max 4096 chars)" + + return None + # Configure logging logging.basicConfig( level=logging.INFO, @@ -96,24 +158,77 @@ app = FastAPI( @app.get("/health") async def health_check() -> dict[str, Any]: - """Health check endpoint.""" + """Health check endpoint. + + Checks all dependencies: database, Redis cache, and LLM Gateway. + Returns degraded status if any non-critical dependency fails. + Returns unhealthy status if critical dependencies fail. + """ + from datetime import UTC, datetime + status: dict[str, Any] = { "status": "healthy", "service": "knowledge-base", "version": "0.1.0", + "timestamp": datetime.now(UTC).isoformat(), + "dependencies": {}, } - # Check database connection + is_degraded = False + is_unhealthy = False + + # Check database connection (critical) try: if _database and _database._pool: async with _database.acquire() as conn: await conn.fetchval("SELECT 1") - status["database"] = "connected" + status["dependencies"]["database"] = "connected" else: - status["database"] = "not initialized" + status["dependencies"]["database"] = "not initialized" + is_unhealthy = True except Exception as e: - status["database"] = f"error: {e}" + status["dependencies"]["database"] = f"error: {e}" + is_unhealthy = True + + # Check Redis cache (non-critical - degraded without it) + try: + if _embeddings and _embeddings._redis: + await _embeddings._redis.ping() + status["dependencies"]["redis"] = "connected" + else: + status["dependencies"]["redis"] = "not initialized" + is_degraded = True + except Exception as e: + status["dependencies"]["redis"] = f"error: {e}" + is_degraded = True + + # Check LLM Gateway connectivity (non-critical for health check) + try: + if _embeddings and _embeddings._http_client: + settings = get_settings() + response = await _embeddings._http_client.get( + f"{settings.llm_gateway_url}/health", + timeout=5.0, + ) + if response.status_code == 200: + status["dependencies"]["llm_gateway"] = "connected" + else: + status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})" + is_degraded = True + else: + status["dependencies"]["llm_gateway"] = "not initialized" + is_degraded = True + except Exception as e: + status["dependencies"]["llm_gateway"] = f"error: {e}" + is_degraded = True + + # Set overall status + if is_unhealthy: + status["status"] = "unhealthy" + elif is_degraded: status["status"] = "degraded" + else: + status["status"] = "healthy" return status @@ -411,6 +526,14 @@ async def search_knowledge( Returns chunks ranked by relevance to the query. 
""" try: + # Validate inputs + if error := _validate_id(project_id, "project_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_id(agent_id, "agent_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if collection and (error := _validate_collection(collection)): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + # Parse search type try: search_type_enum = SearchType(search_type.lower()) @@ -419,6 +542,7 @@ async def search_knowledge( return { "success": False, "error": f"Invalid search type: {search_type}. Valid types: {valid_types}", + "code": ErrorCode.INVALID_REQUEST.value, } # Parse file types @@ -430,6 +554,7 @@ async def search_knowledge( return { "success": False, "error": f"Invalid file type: {e}", + "code": ErrorCode.INVALID_REQUEST.value, } request = SearchRequest( @@ -480,6 +605,7 @@ async def search_knowledge( return { "success": False, "error": str(e), + "code": ErrorCode.INTERNAL_ERROR.value, } @@ -516,6 +642,16 @@ async def ingest_content( the LLM Gateway, and stored in pgvector for search. """ try: + # Validate inputs + if error := _validate_id(project_id, "project_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_id(agent_id, "agent_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_collection(collection): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_source_path(source_path): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + # Validate content size to prevent DoS settings = get_settings() content_size = len(content.encode("utf-8")) @@ -523,6 +659,7 @@ async def ingest_content( return { "success": False, "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)", + "code": ErrorCode.INVALID_REQUEST.value, } # Parse chunk type @@ -533,6 +670,7 @@ async def ingest_content( return { "success": False, "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}", + "code": ErrorCode.INVALID_REQUEST.value, } # Parse file type @@ -545,6 +683,7 @@ async def ingest_content( return { "success": False, "error": f"Invalid file type: {file_type}. Valid types: {valid_types}", + "code": ErrorCode.INVALID_REQUEST.value, } request = IngestRequest( @@ -582,6 +721,7 @@ async def ingest_content( return { "success": False, "error": str(e), + "code": ErrorCode.INTERNAL_ERROR.value, } @@ -608,6 +748,16 @@ async def delete_content( Specify either source_path, collection, or chunk_ids to delete. 
""" try: + # Validate inputs + if error := _validate_id(project_id, "project_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_id(agent_id, "agent_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if collection and (error := _validate_collection(collection)): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_source_path(source_path): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + request = DeleteRequest( project_id=project_id, agent_id=agent_id, @@ -636,6 +786,7 @@ async def delete_content( return { "success": False, "error": str(e), + "code": ErrorCode.INTERNAL_ERROR.value, } @@ -650,6 +801,12 @@ async def list_collections( Returns collection names with chunk counts and file types. """ try: + # Validate inputs + if error := _validate_id(project_id, "project_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_id(agent_id, "agent_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + result = await _collections.list_collections(project_id) # type: ignore[union-attr] return { @@ -681,6 +838,7 @@ async def list_collections( return { "success": False, "error": str(e), + "code": ErrorCode.INTERNAL_ERROR.value, } @@ -696,6 +854,14 @@ async def get_collection_stats( Returns chunk counts, token totals, and type breakdowns. """ try: + # Validate inputs + if error := _validate_id(project_id, "project_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_id(agent_id, "agent_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_collection(collection): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr] return { @@ -724,6 +890,7 @@ async def get_collection_stats( return { "success": False, "error": str(e), + "code": ErrorCode.INTERNAL_ERROR.value, } @@ -756,6 +923,16 @@ async def update_document( Replaces all existing chunks for the source path with new content. """ try: + # Validate inputs + if error := _validate_id(project_id, "project_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_id(agent_id, "agent_id"): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_collection(collection): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + if error := _validate_source_path(source_path): + return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} + # Validate content size to prevent DoS settings = get_settings() content_size = len(content.encode("utf-8")) @@ -763,6 +940,7 @@ async def update_document( return { "success": False, "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)", + "code": ErrorCode.INVALID_REQUEST.value, } # Parse chunk type @@ -773,6 +951,7 @@ async def update_document( return { "success": False, "error": f"Invalid chunk type: {chunk_type}. 
+                "code": ErrorCode.INVALID_REQUEST.value,
             }
 
         # Parse file type
@@ -785,6 +964,7 @@
             return {
                 "success": False,
                 "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }
 
         result = await _collections.update_document(  # type: ignore[union-attr]
@@ -820,6 +1000,7 @@
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }
diff --git a/mcp-servers/knowledge-base/tests/test_server.py b/mcp-servers/knowledge-base/tests/test_server.py
index 2dcc771..e40f701 100644
--- a/mcp-servers/knowledge-base/tests/test_server.py
+++ b/mcp-servers/knowledge-base/tests/test_server.py
@@ -13,10 +13,10 @@ class TestHealthCheck:
 
     @pytest.mark.asyncio
     async def test_health_check_healthy(self):
-        """Test health check when healthy."""
+        """Test health check when all dependencies are connected."""
         import server
 
-        # Create a proper async context manager mock
+        # Create a proper async context manager mock for database
         mock_conn = AsyncMock()
         mock_conn.fetchval = AsyncMock(return_value=1)
 
@@ -29,24 +29,75 @@
         mock_cm.__aexit__.return_value = None
         mock_db.acquire.return_value = mock_cm
 
+        # Mock Redis
+        mock_redis = AsyncMock()
+        mock_redis.ping = AsyncMock(return_value=True)
+
+        # Mock HTTP client for LLM Gateway
+        mock_http_response = AsyncMock()
+        mock_http_response.status_code = 200
+        mock_http_client = AsyncMock()
+        mock_http_client.get = AsyncMock(return_value=mock_http_response)
+
+        # Mock embeddings with Redis and HTTP client
+        mock_embeddings = MagicMock()
+        mock_embeddings._redis = mock_redis
+        mock_embeddings._http_client = mock_http_client
+
         server._database = mock_db
+        server._embeddings = mock_embeddings
 
         result = await server.health_check()
 
         assert result["status"] == "healthy"
         assert result["service"] == "knowledge-base"
-        assert result["database"] == "connected"
+        assert result["dependencies"]["database"] == "connected"
+        assert result["dependencies"]["redis"] == "connected"
+        assert result["dependencies"]["llm_gateway"] == "connected"
 
     @pytest.mark.asyncio
     async def test_health_check_no_database(self):
-        """Test health check without database."""
+        """Test health check without database - should be unhealthy."""
         import server
 
         server._database = None
+        server._embeddings = None
 
         result = await server.health_check()
 
-        assert result["database"] == "not initialized"
+        assert result["status"] == "unhealthy"
+        assert result["dependencies"]["database"] == "not initialized"
+
+    @pytest.mark.asyncio
+    async def test_health_check_degraded(self):
+        """Test health check with database but no Redis - should be degraded."""
+        import server
+
+        # Create a proper async context manager mock for database
+        mock_conn = AsyncMock()
+        mock_conn.fetchval = AsyncMock(return_value=1)
+
+        mock_db = MagicMock()
+        mock_db._pool = MagicMock()
+
+        mock_cm = AsyncMock()
+        mock_cm.__aenter__.return_value = mock_conn
+        mock_cm.__aexit__.return_value = None
+        mock_db.acquire.return_value = mock_cm
+
+        # Mock embeddings without Redis
+        mock_embeddings = MagicMock()
+        mock_embeddings._redis = None
+        mock_embeddings._http_client = None
+
+        server._database = mock_db
+        server._embeddings = mock_embeddings
+
+        result = await server.health_check()
+
+        assert result["status"] == "degraded"
+        assert result["dependencies"]["database"] == "connected"
+        assert result["dependencies"]["redis"] == "not initialized"
 
 
 class TestSearchKnowledgeTool:
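
Reviewer note (illustrative sketch, not part of the change): the hybrid-search docstring references Reciprocal Rank Fusion, but _reciprocal_rank_fusion itself is untouched by this diff and its body does not appear above. For context, standard RRF scores each document as the sum of 1 / (k + rank) over the input rankings, conventionally with k = 60; the helper's real signature and constant may differ. A minimal sketch with hypothetical names:

    def reciprocal_rank_fusion(
        semantic_ids: list[str],
        keyword_ids: list[str],
        k: int = 60,
    ) -> list[str]:
        # Hypothetical helper; the server's version operates on result objects.
        scores: dict[str, float] = {}
        for ranking in (semantic_ids, keyword_ids):
            for rank, doc_id in enumerate(ranking, start=1):
                # Each list a document appears in contributes 1 / (k + rank),
                # so high placement in either search raises the fused score.
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
        # Highest fused score first
        return sorted(scores, key=scores.get, reverse=True)

A document ranked first in both lists scores 2/61 (about 0.033) versus 1/61 (about 0.016) for one found by a single search, so fusion rewards cross-search agreement without requiring the raw semantic and keyword scores to be comparable.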