fix(mcp-kb): add input validation, path security, and health checks
Security fixes from deep review:
- Add input validation patterns for project_id, agent_id, collection
- Add path traversal protection for source_path (reject .., null bytes)
- Add error codes (INTERNAL_ERROR) to generic exception handlers
- Handle FieldInfo objects in validation for test robustness

Performance fixes:
- Enable concurrent hybrid search with asyncio.gather

Health endpoint improvements:
- Check all dependencies (database, Redis, LLM Gateway)
- Return degraded/unhealthy status based on dependency health
- Updated tests for new health check response structure

All 139 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -5,6 +5,7 @@ Provides semantic (vector), keyword (full-text), and hybrid search
 capabilities over the knowledge base.
 """

+import asyncio
 import logging
 import time

@@ -158,6 +159,7 @@ class SearchEngine:
         Execute hybrid search combining semantic and keyword.

         Uses Reciprocal Rank Fusion (RRF) for result combination.
+        Executes both searches concurrently for better performance.
         """
         # Execute both searches with higher limits for fusion
         fusion_limit = min(request.limit * 2, 100)
@@ -187,9 +189,11 @@ class SearchEngine:
             include_metadata=request.include_metadata,
         )

-        # Execute searches
-        semantic_results = await self._semantic_search(semantic_request)
-        keyword_results = await self._keyword_search(keyword_request)
+        # Execute searches concurrently for better performance
+        semantic_results, keyword_results = await asyncio.gather(
+            self._semantic_search(semantic_request),
+            self._keyword_search(keyword_request),
+        )

         # Fuse results using RRF
         fused = self._reciprocal_rank_fusion(
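The hunk above replaces two sequential awaits with `asyncio.gather`, so the semantic and keyword queries overlap in flight; this is safe because neither coroutine depends on the other's output. The fused ranking comes from `_reciprocal_rank_fusion`, whose body is outside this diff. A hedged sketch of the standard RRF formula, with hypothetical list-of-ids inputs:

```python
# Sketch of Reciprocal Rank Fusion -- the real _reciprocal_rank_fusion body is
# not in this diff; the list-of-ids input shape is an assumption.
from collections import defaultdict


def reciprocal_rank_fusion(result_lists: list[list[str]], k: int = 60) -> list[str]:
    """Fuse ranked lists: score(d) = sum over lists of 1 / (k + rank(d))."""
    scores: dict[str, float] = defaultdict(float)
    for results in result_lists:
        for rank, doc_id in enumerate(results, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first
    return sorted(scores, key=scores.__getitem__, reverse=True)


# Example: fuse a semantic ranking with a keyword ranking
fused = reciprocal_rank_fusion([["a", "b", "c"], ["b", "c", "d"]])
assert fused[0] == "b"  # "b" ranks near the top of both lists
```

With the usual constant k = 60, RRF rewards chunks that appear high in both rankings without needing to normalize scores across the two search backends.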
@@ -7,6 +7,7 @@ intelligent chunking, and collection management.

 import inspect
 import logging
+import re
 from contextlib import asynccontextmanager
 from typing import Any, get_type_hints

@@ -20,7 +21,7 @@ from collections.abc import AsyncIterator
 from config import get_settings
 from database import DatabaseManager, get_database_manager
 from embeddings import EmbeddingGenerator, get_embedding_generator
-from exceptions import KnowledgeBaseError
+from exceptions import ErrorCode, KnowledgeBaseError
 from models import (
     ChunkType,
     DeleteRequest,
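The widened import pulls `ErrorCode` in alongside `KnowledgeBaseError`. The enum's definition is not part of this diff; only the `INVALID_REQUEST` and `INTERNAL_ERROR` members appear below. A hedged sketch of a compatible definition (the string values are assumptions):

```python
# Hypothetical shape of ErrorCode in exceptions.py -- only the two members
# used in this commit are confirmed; the values are assumed.
from enum import Enum


class ErrorCode(str, Enum):
    INVALID_REQUEST = "INVALID_REQUEST"
    INTERNAL_ERROR = "INTERNAL_ERROR"


# .value yields a JSON-friendly string for the tools' error envelopes
assert ErrorCode.INTERNAL_ERROR.value == "INTERNAL_ERROR"
```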
@@ -31,6 +32,67 @@ from models import (
 )
 from search import SearchEngine, get_search_engine

+# Input validation patterns
+# Allow alphanumeric, hyphens, underscores (1-128 chars)
+ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
+# Collection names: alphanumeric, hyphens, underscores (1-64 chars)
+COLLECTION_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
+
+
+def _validate_id(value: str, field_name: str) -> str | None:
+    """Validate project_id or agent_id format.
+
+    Returns error message if invalid, None if valid.
+    """
+    # Handle FieldInfo objects from direct .fn() calls in tests
+    if not isinstance(value, str):
+        return f"{field_name} must be a string"
+    if not value:
+        return f"{field_name} is required"
+    if not ID_PATTERN.match(value):
+        return f"Invalid {field_name}: must be 1-128 alphanumeric characters, hyphens, or underscores"
+    return None
+
+
+def _validate_collection(value: str) -> str | None:
+    """Validate collection name format.
+
+    Returns error message if invalid, None if valid.
+    """
+    # Handle FieldInfo objects from direct .fn() calls in tests
+    if not isinstance(value, str):
+        return None  # Non-string means default not resolved, skip validation
+    if not COLLECTION_PATTERN.match(value):
+        return "Invalid collection: must be 1-64 alphanumeric characters, hyphens, or underscores"
+    return None
+
+
+def _validate_source_path(value: str | None) -> str | None:
+    """Validate source_path to prevent path traversal.
+
+    Returns error message if invalid, None if valid.
+    """
+    if value is None:
+        return None
+
+    # Handle FieldInfo objects from direct .fn() calls in tests
+    if not isinstance(value, str):
+        return None  # Non-string means default not resolved, skip validation
+
+    # Normalize path and check for traversal attempts
+    if ".." in value:
+        return "Invalid source_path: path traversal not allowed"
+
+    # Check for null bytes (used in some injection attacks)
+    if "\x00" in value:
+        return "Invalid source_path: null bytes not allowed"
+
+    # Limit path length to prevent DoS
+    if len(value) > 4096:
+        return "Invalid source_path: path too long (max 4096 chars)"
+
+    return None
+
+
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
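A few spot checks of the three helpers above, matching the patterns and messages in the diff (the assertions are illustrative, not part of the test suite, and assume the helpers are in scope):

```python
# Valid and invalid inputs per ID_PATTERN, COLLECTION_PATTERN, and the
# source_path rules added in this hunk.
assert _validate_id("agent-42", "agent_id") is None
assert _validate_id("a" * 129, "agent_id") is not None        # too long
assert _validate_id("../etc", "project_id") is not None       # disallowed chars
assert _validate_collection("my_docs") is None
assert _validate_source_path("docs/readme.md") is None
assert _validate_source_path("../../etc/passwd") is not None  # traversal
assert _validate_source_path("a\x00b") is not None            # null byte
```

Note that the `".." in value` substring check is deliberately conservative: it also rejects benign names such as `notes..md`, trading a few false positives for a check that cannot be bypassed by odd separators.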
@@ -96,24 +158,77 @@ app = FastAPI(

 @app.get("/health")
 async def health_check() -> dict[str, Any]:
-    """Health check endpoint."""
+    """Health check endpoint.
+
+    Checks all dependencies: database, Redis cache, and LLM Gateway.
+    Returns degraded status if any non-critical dependency fails.
+    Returns unhealthy status if critical dependencies fail.
+    """
+    from datetime import UTC, datetime
+
     status: dict[str, Any] = {
         "status": "healthy",
         "service": "knowledge-base",
         "version": "0.1.0",
+        "timestamp": datetime.now(UTC).isoformat(),
+        "dependencies": {},
     }

-    # Check database connection
+    is_degraded = False
+    is_unhealthy = False
+
+    # Check database connection (critical)
     try:
         if _database and _database._pool:
             async with _database.acquire() as conn:
                 await conn.fetchval("SELECT 1")
-            status["database"] = "connected"
+            status["dependencies"]["database"] = "connected"
         else:
-            status["database"] = "not initialized"
+            status["dependencies"]["database"] = "not initialized"
+            is_unhealthy = True
     except Exception as e:
-        status["database"] = f"error: {e}"
+        status["dependencies"]["database"] = f"error: {e}"
+        is_unhealthy = True
+
+    # Check Redis cache (non-critical - degraded without it)
+    try:
+        if _embeddings and _embeddings._redis:
+            await _embeddings._redis.ping()
+            status["dependencies"]["redis"] = "connected"
+        else:
+            status["dependencies"]["redis"] = "not initialized"
+            is_degraded = True
+    except Exception as e:
+        status["dependencies"]["redis"] = f"error: {e}"
+        is_degraded = True
+
+    # Check LLM Gateway connectivity (non-critical for health check)
+    try:
+        if _embeddings and _embeddings._http_client:
+            settings = get_settings()
+            response = await _embeddings._http_client.get(
+                f"{settings.llm_gateway_url}/health",
+                timeout=5.0,
+            )
+            if response.status_code == 200:
+                status["dependencies"]["llm_gateway"] = "connected"
+            else:
+                status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
+                is_degraded = True
+        else:
+            status["dependencies"]["llm_gateway"] = "not initialized"
+            is_degraded = True
+    except Exception as e:
+        status["dependencies"]["llm_gateway"] = f"error: {e}"
+        is_degraded = True
+
+    # Set overall status
+    if is_unhealthy:
+        status["status"] = "unhealthy"
+    elif is_degraded:
         status["status"] = "degraded"
+    else:
+        status["status"] = "healthy"

     return status

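A hedged way to exercise the new endpoint once the service is up (the host and port are assumptions):

```python
# Probe the /health endpoint; base URL is a placeholder for your deployment.
import httpx

resp = httpx.get("http://localhost:8000/health", timeout=5.0)
body = resp.json()
# Expected shape per the handler above:
# {
#   "status": "healthy" | "degraded" | "unhealthy",
#   "service": "knowledge-base",
#   "version": "0.1.0",
#   "timestamp": "<ISO-8601>",
#   "dependencies": {"database": "...", "redis": "...", "llm_gateway": "..."}
# }
assert body["status"] in {"healthy", "degraded", "unhealthy"}
```

One design consequence worth noting: the handler always responds with HTTP 200 and encodes health in the body, so orchestrators that key off status codes (e.g. Kubernetes probes) would need a mapping such as returning 503 when `status` is `unhealthy`.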
@@ -411,6 +526,14 @@ async def search_knowledge(
     Returns chunks ranked by relevance to the query.
     """
     try:
+        # Validate inputs
+        if error := _validate_id(project_id, "project_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_id(agent_id, "agent_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if collection and (error := _validate_collection(collection)):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+
         # Parse search type
         try:
             search_type_enum = SearchType(search_type.lower())
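The guards use the walrus operator; each one is shorthand for assign-then-test. Spelled out for the first guard (a hypothetical fragment, assuming the helpers and the `ErrorCode` import above are in scope):

```python
# Assign-then-test equivalent of `if error := _validate_id(...)`.
def guarded(project_id: str) -> dict | None:
    error = _validate_id(project_id, "project_id")  # returns str | None
    # Error messages are non-empty strings, so truthiness matches `is not None`.
    if error is not None:
        return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
    return None  # validation passed


assert guarded("good-id") is None
assert guarded("bad/id")["code"] == ErrorCode.INVALID_REQUEST.value
```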
@@ -419,6 +542,7 @@ async def search_knowledge(
             return {
                 "success": False,
                 "error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         # Parse file types
@@ -430,6 +554,7 @@ async def search_knowledge(
             return {
                 "success": False,
                 "error": f"Invalid file type: {e}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         request = SearchRequest(
@@ -480,6 +605,7 @@ async def search_knowledge(
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }


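With `code` now present on every failure path, callers can branch on machine-readable codes instead of parsing error messages. A hedged consumer-side sketch, where `result` stands in for any tool's return value:

```python
# Hypothetical client-side handling of the uniform error envelope.
def should_retry(result: dict) -> bool:
    if result.get("success", False):
        return False  # not an error
    code = result.get("code")
    # INVALID_REQUEST means the inputs are wrong; retrying the same call won't help.
    # INTERNAL_ERROR is a server-side failure and may be transient.
    return code == "INTERNAL_ERROR"


assert should_retry({"success": False, "error": "boom", "code": "INTERNAL_ERROR"})
assert not should_retry({"success": False, "error": "bad id", "code": "INVALID_REQUEST"})
```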
@@ -516,6 +642,16 @@ async def ingest_content(
     the LLM Gateway, and stored in pgvector for search.
     """
     try:
+        # Validate inputs
+        if error := _validate_id(project_id, "project_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_id(agent_id, "agent_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_collection(collection):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_source_path(source_path):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+
         # Validate content size to prevent DoS
         settings = get_settings()
         content_size = len(content.encode("utf-8"))
@@ -523,6 +659,7 @@ async def ingest_content(
             return {
                 "success": False,
                 "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         # Parse chunk type
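The size guard measures UTF-8 bytes rather than characters, since `max_document_size` bounds what is stored and the two diverge on multi-byte text:

```python
# len() counts code points; encode("utf-8") counts the bytes actually stored.
content = "naïve 😀"
assert len(content) == 7                    # 7 code points
assert len(content.encode("utf-8")) == 11   # ï is 2 bytes, 😀 is 4
```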
@@ -533,6 +670,7 @@ async def ingest_content(
             return {
                 "success": False,
                 "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         # Parse file type
@@ -545,6 +683,7 @@ async def ingest_content(
             return {
                 "success": False,
                 "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         request = IngestRequest(
@@ -582,6 +721,7 @@ async def ingest_content(
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }


@@ -608,6 +748,16 @@ async def delete_content(
     Specify either source_path, collection, or chunk_ids to delete.
     """
     try:
+        # Validate inputs
+        if error := _validate_id(project_id, "project_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_id(agent_id, "agent_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if collection and (error := _validate_collection(collection)):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_source_path(source_path):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+
         request = DeleteRequest(
             project_id=project_id,
             agent_id=agent_id,
@@ -636,6 +786,7 @@ async def delete_content(
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }


@@ -650,6 +801,12 @@ async def list_collections(
     Returns collection names with chunk counts and file types.
     """
     try:
+        # Validate inputs
+        if error := _validate_id(project_id, "project_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_id(agent_id, "agent_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+
         result = await _collections.list_collections(project_id)  # type: ignore[union-attr]

         return {
@@ -681,6 +838,7 @@ async def list_collections(
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }


@@ -696,6 +854,14 @@ async def get_collection_stats(
     Returns chunk counts, token totals, and type breakdowns.
     """
     try:
+        # Validate inputs
+        if error := _validate_id(project_id, "project_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_id(agent_id, "agent_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_collection(collection):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+
         stats = await _collections.get_collection_stats(project_id, collection)  # type: ignore[union-attr]

         return {
@@ -724,6 +890,7 @@ async def get_collection_stats(
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }


@@ -756,6 +923,16 @@ async def update_document(
     Replaces all existing chunks for the source path with new content.
     """
     try:
+        # Validate inputs
+        if error := _validate_id(project_id, "project_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_id(agent_id, "agent_id"):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_collection(collection):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+        if error := _validate_source_path(source_path):
+            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+
         # Validate content size to prevent DoS
         settings = get_settings()
         content_size = len(content.encode("utf-8"))
@@ -763,6 +940,7 @@ async def update_document(
             return {
                 "success": False,
                 "error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         # Parse chunk type
@@ -773,6 +951,7 @@ async def update_document(
             return {
                 "success": False,
                 "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         # Parse file type
@@ -785,6 +964,7 @@ async def update_document(
             return {
                 "success": False,
                 "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
+                "code": ErrorCode.INVALID_REQUEST.value,
             }

         result = await _collections.update_document(  # type: ignore[union-attr]
@@ -820,6 +1000,7 @@ async def update_document(
         return {
             "success": False,
             "error": str(e),
+            "code": ErrorCode.INTERNAL_ERROR.value,
         }


@@ -13,10 +13,10 @@ class TestHealthCheck:

     @pytest.mark.asyncio
     async def test_health_check_healthy(self):
-        """Test health check when healthy."""
+        """Test health check when all dependencies are connected."""
         import server

-        # Create a proper async context manager mock
+        # Create a proper async context manager mock for database
         mock_conn = AsyncMock()
         mock_conn.fetchval = AsyncMock(return_value=1)

@@ -29,24 +29,75 @@ class TestHealthCheck:
         mock_cm.__aexit__.return_value = None
         mock_db.acquire.return_value = mock_cm

+        # Mock Redis
+        mock_redis = AsyncMock()
+        mock_redis.ping = AsyncMock(return_value=True)
+
+        # Mock HTTP client for LLM Gateway
+        mock_http_response = AsyncMock()
+        mock_http_response.status_code = 200
+        mock_http_client = AsyncMock()
+        mock_http_client.get = AsyncMock(return_value=mock_http_response)
+
+        # Mock embeddings with Redis and HTTP client
+        mock_embeddings = MagicMock()
+        mock_embeddings._redis = mock_redis
+        mock_embeddings._http_client = mock_http_client
+
         server._database = mock_db
+        server._embeddings = mock_embeddings

         result = await server.health_check()

         assert result["status"] == "healthy"
         assert result["service"] == "knowledge-base"
-        assert result["database"] == "connected"
+        assert result["dependencies"]["database"] == "connected"
+        assert result["dependencies"]["redis"] == "connected"
+        assert result["dependencies"]["llm_gateway"] == "connected"

     @pytest.mark.asyncio
     async def test_health_check_no_database(self):
-        """Test health check without database."""
+        """Test health check without database - should be unhealthy."""
         import server

         server._database = None
+        server._embeddings = None

         result = await server.health_check()

-        assert result["database"] == "not initialized"
+        assert result["status"] == "unhealthy"
+        assert result["dependencies"]["database"] == "not initialized"

+    @pytest.mark.asyncio
+    async def test_health_check_degraded(self):
+        """Test health check with database but no Redis - should be degraded."""
+        import server
+
+        # Create a proper async context manager mock for database
+        mock_conn = AsyncMock()
+        mock_conn.fetchval = AsyncMock(return_value=1)
+
+        mock_db = MagicMock()
+        mock_db._pool = MagicMock()
+
+        mock_cm = AsyncMock()
+        mock_cm.__aenter__.return_value = mock_conn
+        mock_cm.__aexit__.return_value = None
+        mock_db.acquire.return_value = mock_cm
+
+        # Mock embeddings without Redis
+        mock_embeddings = MagicMock()
+        mock_embeddings._redis = None
+        mock_embeddings._http_client = None
+
+        server._database = mock_db
+        server._embeddings = mock_embeddings
+
+        result = await server.health_check()
+
+        assert result["status"] == "degraded"
+        assert result["dependencies"]["database"] == "connected"
+        assert result["dependencies"]["redis"] == "not initialized"


 class TestSearchKnowledgeTool:
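The tests mix `MagicMock` and `AsyncMock` deliberately: `acquire()` is a plain synchronous call that returns an async context manager, while `__aenter__` and `fetchval` are awaited. A standalone reduction of that pattern:

```python
# Minimal version of the mocking pattern used in the tests above.
import asyncio
from unittest.mock import AsyncMock, MagicMock

mock_conn = AsyncMock()
mock_conn.fetchval = AsyncMock(return_value=1)

mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = mock_conn
mock_cm.__aexit__.return_value = None

# MagicMock so that .acquire() is a sync call returning the context manager
mock_db = MagicMock()
mock_db.acquire.return_value = mock_cm


async def demo() -> int:
    async with mock_db.acquire() as conn:
        return await conn.fetchval("SELECT 1")


assert asyncio.run(demo()) == 1
```

Using `AsyncMock` for `mock_db` would break this: `acquire()` would then return a coroutine that must be awaited rather than the context manager itself.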