forked from cardosofelipe/fast-next-template
refactor(knowledge-base mcp server): adjust formatting for consistency and readability
Improved code formatting, line breaks, and indentation across chunking logic and multiple test modules to enhance code clarity and maintain consistent style. No functional changes made.
This commit is contained in:
@@ -184,7 +184,12 @@ class ChunkerFactory:
|
||||
if file_type:
|
||||
if file_type == FileType.MARKDOWN:
|
||||
return self._get_markdown_chunker()
|
||||
elif file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
|
||||
elif file_type in (
|
||||
FileType.TEXT,
|
||||
FileType.JSON,
|
||||
FileType.YAML,
|
||||
FileType.TOML,
|
||||
):
|
||||
return self._get_text_chunker()
|
||||
else:
|
||||
# Code files
|
||||
@@ -193,7 +198,9 @@ class ChunkerFactory:
|
||||
# Default to text chunker
|
||||
return self._get_text_chunker()
|
||||
|
||||
def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
|
||||
def get_chunker_for_path(
|
||||
self, source_path: str
|
||||
) -> tuple[BaseChunker, FileType | None]:
|
||||
"""
|
||||
Get chunker based on file path extension.
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ class CodeChunker(BaseChunker):
|
||||
for struct_type, pattern in patterns.items():
|
||||
for match in pattern.finditer(content):
|
||||
# Convert character position to line number
|
||||
line_num = content[:match.start()].count("\n")
|
||||
line_num = content[: match.start()].count("\n")
|
||||
boundaries.append((line_num, struct_type))
|
||||
|
||||
if not boundaries:
|
||||
|
||||
@@ -69,9 +69,7 @@ class MarkdownChunker(BaseChunker):
|
||||
|
||||
if not sections:
|
||||
# No headings, chunk as plain text
|
||||
return self._chunk_text_block(
|
||||
content, source_path, file_type, metadata, []
|
||||
)
|
||||
return self._chunk_text_block(content, source_path, file_type, metadata, [])
|
||||
|
||||
chunks: list[Chunk] = []
|
||||
heading_stack: list[tuple[int, str]] = [] # (level, text)
|
||||
@@ -292,7 +290,10 @@ class MarkdownChunker(BaseChunker):
|
||||
)
|
||||
|
||||
# Overlap: include last paragraph if it fits
|
||||
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_content
|
||||
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
|
||||
):
|
||||
current_content = [current_content[-1]]
|
||||
current_tokens = self.count_tokens(current_content[-1])
|
||||
else:
|
||||
@@ -341,12 +342,14 @@ class MarkdownChunker(BaseChunker):
|
||||
# Start of code block - save previous paragraph
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
}
|
||||
)
|
||||
current_para = [line]
|
||||
para_start = i
|
||||
in_code_block = True
|
||||
@@ -360,12 +363,14 @@ class MarkdownChunker(BaseChunker):
|
||||
if not line.strip():
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
}
|
||||
)
|
||||
current_para = []
|
||||
para_start = i + 1
|
||||
else:
|
||||
@@ -376,12 +381,14 @@ class MarkdownChunker(BaseChunker):
|
||||
# Final paragraph
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": len(lines) - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": len(lines) - 1,
|
||||
}
|
||||
)
|
||||
|
||||
return paragraphs
|
||||
|
||||
@@ -448,7 +455,10 @@ class MarkdownChunker(BaseChunker):
|
||||
)
|
||||
|
||||
# Overlap with last sentence
|
||||
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_content
|
||||
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
|
||||
):
|
||||
current_content = [current_content[-1]]
|
||||
current_tokens = self.count_tokens(current_content[-1])
|
||||
else:
|
||||
|
||||
@@ -79,9 +79,7 @@ class TextChunker(BaseChunker):
|
||||
)
|
||||
|
||||
# Fall back to sentence-based chunking
|
||||
return self._chunk_by_sentences(
|
||||
content, source_path, file_type, metadata
|
||||
)
|
||||
return self._chunk_by_sentences(content, source_path, file_type, metadata)
|
||||
|
||||
def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
|
||||
"""Split content into paragraphs."""
|
||||
@@ -97,12 +95,14 @@ class TextChunker(BaseChunker):
|
||||
continue
|
||||
|
||||
para_lines = para.count("\n") + 1
|
||||
paragraphs.append({
|
||||
"content": para,
|
||||
"tokens": self.count_tokens(para),
|
||||
"start_line": line_num,
|
||||
"end_line": line_num + para_lines - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para,
|
||||
"tokens": self.count_tokens(para),
|
||||
"start_line": line_num,
|
||||
"end_line": line_num + para_lines - 1,
|
||||
}
|
||||
)
|
||||
line_num += para_lines + 1 # +1 for blank line between paragraphs
|
||||
|
||||
return paragraphs
|
||||
@@ -172,7 +172,10 @@ class TextChunker(BaseChunker):
|
||||
|
||||
# Overlap: keep last paragraph if small enough
|
||||
overlap_para = None
|
||||
if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_paras
|
||||
and self.count_tokens(current_paras[-1]) <= self.chunk_overlap
|
||||
):
|
||||
overlap_para = current_paras[-1]
|
||||
|
||||
current_paras = [overlap_para] if overlap_para else []
|
||||
@@ -266,7 +269,10 @@ class TextChunker(BaseChunker):
|
||||
|
||||
# Overlap: keep last sentence if small enough
|
||||
overlap = None
|
||||
if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_sentences
|
||||
and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap
|
||||
):
|
||||
overlap = current_sentences[-1]
|
||||
|
||||
current_sentences = [overlap] if overlap else []
|
||||
@@ -317,14 +323,10 @@ class TextChunker(BaseChunker):
|
||||
sentences = self._split_sentences(text)
|
||||
|
||||
if len(sentences) > 1:
|
||||
return self._chunk_by_sentences(
|
||||
text, source_path, file_type, metadata
|
||||
)
|
||||
return self._chunk_by_sentences(text, source_path, file_type, metadata)
|
||||
|
||||
# Fall back to word-based splitting
|
||||
return self._chunk_by_words(
|
||||
text, source_path, file_type, metadata, base_line
|
||||
)
|
||||
return self._chunk_by_words(text, source_path, file_type, metadata, base_line)
|
||||
|
||||
def _chunk_by_words(
|
||||
self,
|
||||
|
||||
@@ -328,14 +328,18 @@ class CollectionManager:
|
||||
"source_path": chunk.source_path or source_path,
|
||||
"start_line": chunk.start_line,
|
||||
"end_line": chunk.end_line,
|
||||
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
|
||||
"file_type": effective_file_type.value
|
||||
if (effective_file_type := chunk.file_type or file_type)
|
||||
else None,
|
||||
}
|
||||
embeddings_data.append((
|
||||
chunk.content,
|
||||
embedding,
|
||||
chunk.chunk_type,
|
||||
chunk_metadata,
|
||||
))
|
||||
embeddings_data.append(
|
||||
(
|
||||
chunk.content,
|
||||
embedding,
|
||||
chunk.chunk_type,
|
||||
chunk_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Atomically replace old embeddings with new ones
|
||||
_, chunk_ids = await self.database.replace_source_embeddings(
|
||||
|
||||
@@ -214,9 +214,7 @@ class EmbeddingGenerator:
|
||||
return cached
|
||||
|
||||
# Generate via LLM Gateway
|
||||
embeddings = await self._call_llm_gateway(
|
||||
[text], project_id, agent_id
|
||||
)
|
||||
embeddings = await self._call_llm_gateway([text], project_id, agent_id)
|
||||
|
||||
if not embeddings:
|
||||
raise EmbeddingGenerationError(
|
||||
@@ -277,9 +275,7 @@ class EmbeddingGenerator:
|
||||
|
||||
for i in range(0, len(texts_to_embed), batch_size):
|
||||
batch = texts_to_embed[i : i + batch_size]
|
||||
batch_embeddings = await self._call_llm_gateway(
|
||||
batch, project_id, agent_id
|
||||
)
|
||||
batch_embeddings = await self._call_llm_gateway(batch, project_id, agent_id)
|
||||
new_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Validate dimensions
|
||||
|
||||
@@ -149,12 +149,8 @@ class IngestRequest(BaseModel):
|
||||
source_path: str | None = Field(
|
||||
default=None, description="Source file path for reference"
|
||||
)
|
||||
collection: str = Field(
|
||||
default="default", description="Collection to store in"
|
||||
)
|
||||
chunk_type: ChunkType = Field(
|
||||
default=ChunkType.TEXT, description="Type of content"
|
||||
)
|
||||
collection: str = Field(default="default", description="Collection to store in")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="Type of content")
|
||||
file_type: FileType | None = Field(
|
||||
default=None, description="File type for code chunking"
|
||||
)
|
||||
@@ -255,12 +251,8 @@ class DeleteRequest(BaseModel):
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
source_path: str | None = Field(
|
||||
default=None, description="Delete by source path"
|
||||
)
|
||||
collection: str | None = Field(
|
||||
default=None, description="Delete entire collection"
|
||||
)
|
||||
source_path: str | None = Field(default=None, description="Delete by source path")
|
||||
collection: str | None = Field(default=None, description="Delete entire collection")
|
||||
chunk_ids: list[str] | None = Field(
|
||||
default=None, description="Delete specific chunks"
|
||||
)
|
||||
|
||||
@@ -145,8 +145,7 @@ class SearchEngine:
|
||||
|
||||
# Filter by threshold (keyword search scores are normalized)
|
||||
filtered = [
|
||||
(emb, score) for emb, score in results
|
||||
if score >= request.threshold
|
||||
(emb, score) for emb, score in results if score >= request.threshold
|
||||
]
|
||||
|
||||
return [
|
||||
@@ -204,10 +203,9 @@ class SearchEngine:
|
||||
)
|
||||
|
||||
# Filter by threshold and limit
|
||||
filtered = [
|
||||
result for result in fused
|
||||
if result.score >= request.threshold
|
||||
][:request.limit]
|
||||
filtered = [result for result in fused if result.score >= request.threshold][
|
||||
: request.limit
|
||||
]
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ def _validate_source_path(value: str | None) -> str | None:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -213,7 +214,9 @@ async def health_check() -> dict[str, Any]:
|
||||
if response.status_code == 200:
|
||||
status["dependencies"]["llm_gateway"] = "connected"
|
||||
else:
|
||||
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
|
||||
status["dependencies"]["llm_gateway"] = (
|
||||
f"unhealthy (status {response.status_code})"
|
||||
)
|
||||
is_degraded = True
|
||||
else:
|
||||
status["dependencies"]["llm_gateway"] = "not initialized"
|
||||
@@ -328,7 +331,9 @@ def _get_tool_schema(func: Any) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
|
||||
def _register_tool(
|
||||
name: str, tool_or_func: Any, description: str | None = None
|
||||
) -> None:
|
||||
"""Register a tool in the registry.
|
||||
|
||||
Handles both raw functions and FastMCP FunctionTool objects.
|
||||
@@ -337,7 +342,11 @@ def _register_tool(name: str, tool_or_func: Any, description: str | None = None)
|
||||
if hasattr(tool_or_func, "fn"):
|
||||
func = tool_or_func.fn
|
||||
# Use FunctionTool's description if available
|
||||
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
|
||||
if (
|
||||
not description
|
||||
and hasattr(tool_or_func, "description")
|
||||
and tool_or_func.description
|
||||
):
|
||||
description = tool_or_func.description
|
||||
else:
|
||||
func = tool_or_func
|
||||
@@ -358,11 +367,13 @@ async def list_mcp_tools() -> dict[str, Any]:
|
||||
"""
|
||||
tools = []
|
||||
for name, info in _tool_registry.items():
|
||||
tools.append({
|
||||
"name": name,
|
||||
"description": info["description"],
|
||||
"inputSchema": info["schema"],
|
||||
})
|
||||
tools.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": info["description"],
|
||||
"inputSchema": info["schema"],
|
||||
}
|
||||
)
|
||||
|
||||
return {"tools": tools}
|
||||
|
||||
@@ -410,7 +421,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request: jsonrpc must be '2.0'",
|
||||
},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
@@ -420,7 +434,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: method is required"},
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request: method is required",
|
||||
},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
@@ -528,11 +545,23 @@ async def search_knowledge(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if collection and (error := _validate_collection(collection)):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse search type
|
||||
try:
|
||||
@@ -644,13 +673,29 @@ async def ingest_content(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
@@ -750,13 +795,29 @@ async def delete_content(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if collection and (error := _validate_collection(collection)):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
request = DeleteRequest(
|
||||
project_id=project_id,
|
||||
@@ -803,9 +864,17 @@ async def list_collections(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
|
||||
|
||||
@@ -856,11 +925,23 @@ async def get_collection_stats(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
|
||||
|
||||
@@ -874,8 +955,12 @@ async def get_collection_stats(
|
||||
"avg_chunk_size": stats.avg_chunk_size,
|
||||
"chunk_types": stats.chunk_types,
|
||||
"file_types": stats.file_types,
|
||||
"oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
|
||||
"newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
|
||||
"oldest_chunk": stats.oldest_chunk.isoformat()
|
||||
if stats.oldest_chunk
|
||||
else None,
|
||||
"newest_chunk": stats.newest_chunk.isoformat()
|
||||
if stats.newest_chunk
|
||||
else None,
|
||||
}
|
||||
|
||||
except KnowledgeBaseError as e:
|
||||
@@ -925,13 +1010,29 @@ async def update_document(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
|
||||
@@ -83,7 +83,9 @@ def mock_embeddings():
|
||||
return [0.1] * 1536
|
||||
|
||||
mock_emb.generate = AsyncMock(return_value=fake_embedding())
|
||||
mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])
|
||||
mock_emb.generate_batch = AsyncMock(
|
||||
side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts]
|
||||
)
|
||||
|
||||
return mock_emb
|
||||
|
||||
@@ -137,7 +139,7 @@ async def async_function() -> None:
|
||||
@pytest.fixture
|
||||
def sample_markdown():
|
||||
"""Sample Markdown content for chunking tests."""
|
||||
return '''# Project Documentation
|
||||
return """# Project Documentation
|
||||
|
||||
This is the main documentation for our project.
|
||||
|
||||
@@ -182,20 +184,20 @@ The search endpoint allows you to query the knowledge base.
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! Please see our contributing guide.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text():
|
||||
"""Sample plain text for chunking tests."""
|
||||
return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
|
||||
return """The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
|
||||
|
||||
Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.
|
||||
|
||||
When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.
|
||||
|
||||
The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for chunking module."""
|
||||
|
||||
|
||||
|
||||
class TestBaseChunker:
|
||||
"""Tests for base chunker functionality."""
|
||||
|
||||
@@ -149,7 +148,7 @@ class TestMarkdownChunker:
|
||||
"""Test that chunker respects heading hierarchy."""
|
||||
from chunking.markdown import MarkdownChunker
|
||||
|
||||
markdown = '''# Main Title
|
||||
markdown = """# Main Title
|
||||
|
||||
Introduction paragraph.
|
||||
|
||||
@@ -164,7 +163,7 @@ More detailed content.
|
||||
## Section Two
|
||||
|
||||
Content for section two.
|
||||
'''
|
||||
"""
|
||||
|
||||
chunker = MarkdownChunker(
|
||||
chunk_size=200,
|
||||
@@ -188,7 +187,7 @@ Content for section two.
|
||||
"""Test handling of code blocks in markdown."""
|
||||
from chunking.markdown import MarkdownChunker
|
||||
|
||||
markdown = '''# Code Example
|
||||
markdown = """# Code Example
|
||||
|
||||
Here's some code:
|
||||
|
||||
@@ -198,7 +197,7 @@ def hello():
|
||||
```
|
||||
|
||||
End of example.
|
||||
'''
|
||||
"""
|
||||
|
||||
chunker = MarkdownChunker(
|
||||
chunk_size=500,
|
||||
@@ -256,12 +255,12 @@ class TestTextChunker:
|
||||
"""Test that chunker respects paragraph boundaries."""
|
||||
from chunking.text import TextChunker
|
||||
|
||||
text = '''First paragraph with some content.
|
||||
text = """First paragraph with some content.
|
||||
|
||||
Second paragraph with different content.
|
||||
|
||||
Third paragraph to test chunking behavior.
|
||||
'''
|
||||
"""
|
||||
|
||||
chunker = TextChunker(
|
||||
chunk_size=100,
|
||||
|
||||
@@ -67,10 +67,14 @@ class TestCollectionManager:
|
||||
assert result.embeddings_generated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
|
||||
async def test_ingest_error_handling(
|
||||
self, collection_manager, sample_ingest_request
|
||||
):
|
||||
"""Test ingest error handling."""
|
||||
# Make embedding generation fail
|
||||
collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
|
||||
collection_manager._embeddings.generate_batch.side_effect = Exception(
|
||||
"Embedding error"
|
||||
)
|
||||
|
||||
result = await collection_manager.ingest(sample_ingest_request)
|
||||
|
||||
@@ -182,7 +186,9 @@ class TestCollectionManager:
|
||||
)
|
||||
collection_manager._database.get_collection_stats.return_value = expected_stats
|
||||
|
||||
stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
|
||||
stats = await collection_manager.get_collection_stats(
|
||||
"proj-123", "test-collection"
|
||||
)
|
||||
|
||||
assert stats.chunk_count == 100
|
||||
assert stats.unique_sources == 10
|
||||
|
||||
@@ -17,19 +17,15 @@ class TestEmbeddingGenerator:
|
||||
response.raise_for_status = MagicMock()
|
||||
response.json.return_value = {
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 1536]
|
||||
})
|
||||
}
|
||||
]
|
||||
"content": [{"text": json.dumps({"embeddings": [[0.1] * 1536]})}]
|
||||
}
|
||||
}
|
||||
return response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
|
||||
async def test_generate_single_embedding(
|
||||
self, settings, mock_redis, mock_http_response
|
||||
):
|
||||
"""Test generating a single embedding."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
@@ -67,9 +63,9 @@ class TestEmbeddingGenerator:
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
|
||||
})
|
||||
"text": json.dumps(
|
||||
{"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]}
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -166,9 +162,11 @@ class TestEmbeddingGenerator:
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 768] # Wrong dimension
|
||||
})
|
||||
"text": json.dumps(
|
||||
{
|
||||
"embeddings": [[0.1] * 768] # Wrong dimension
|
||||
}
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for exception classes."""
|
||||
|
||||
|
||||
|
||||
class TestErrorCode:
|
||||
"""Tests for ErrorCode enum."""
|
||||
|
||||
@@ -10,8 +9,13 @@ class TestErrorCode:
|
||||
from exceptions import ErrorCode
|
||||
|
||||
assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
|
||||
assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
|
||||
assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
|
||||
assert (
|
||||
ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
|
||||
)
|
||||
assert (
|
||||
ErrorCode.EMBEDDING_GENERATION_ERROR.value
|
||||
== "KB_EMBEDDING_GENERATION_ERROR"
|
||||
)
|
||||
assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
|
||||
assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
|
||||
assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"
|
||||
|
||||
@@ -59,7 +59,9 @@ class TestSearchEngine:
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_semantic_search(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test semantic search."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -74,7 +76,9 @@ class TestSearchEngine:
|
||||
search_engine._database.semantic_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_keyword_search(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test keyword search."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -88,7 +92,9 @@ class TestSearchEngine:
|
||||
search_engine._database.keyword_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_hybrid_search(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test hybrid search."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -105,7 +111,9 @@ class TestSearchEngine:
|
||||
assert len(response.results) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_with_collection_filter(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test search with collection filter."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -120,7 +128,9 @@ class TestSearchEngine:
|
||||
assert call_args.kwargs["collection"] == "specific-collection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_with_file_type_filter(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test search with file type filter."""
|
||||
from models import FileType, SearchType
|
||||
|
||||
@@ -135,7 +145,9 @@ class TestSearchEngine:
|
||||
assert call_args.kwargs["file_types"] == [FileType.PYTHON]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_respects_limit(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test that search respects result limit."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -148,7 +160,9 @@ class TestSearchEngine:
|
||||
assert len(response.results) <= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_records_time(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test that search records time."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -203,13 +217,21 @@ class TestReciprocalRankFusion:
|
||||
from models import SearchResult
|
||||
|
||||
semantic = [
|
||||
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
|
||||
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="a", content="A", score=0.9, chunk_type="code", collection="default"
|
||||
),
|
||||
SearchResult(
|
||||
id="b", content="B", score=0.8, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
keyword = [
|
||||
SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
|
||||
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="b", content="B", score=0.85, chunk_type="code", collection="default"
|
||||
),
|
||||
SearchResult(
|
||||
id="c", content="C", score=0.7, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
|
||||
@@ -230,19 +252,23 @@ class TestReciprocalRankFusion:
|
||||
|
||||
# Same results in same order
|
||||
results = [
|
||||
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="a", content="A", score=0.9, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
# High semantic weight
|
||||
fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
|
||||
results, [],
|
||||
results,
|
||||
[],
|
||||
semantic_weight=0.9,
|
||||
keyword_weight=0.1,
|
||||
)
|
||||
|
||||
# High keyword weight
|
||||
fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
|
||||
[], results,
|
||||
[],
|
||||
results,
|
||||
semantic_weight=0.1,
|
||||
keyword_weight=0.9,
|
||||
)
|
||||
@@ -256,12 +282,18 @@ class TestReciprocalRankFusion:
|
||||
from models import SearchResult
|
||||
|
||||
semantic = [
|
||||
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
|
||||
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="a", content="A", score=0.9, chunk_type="code", collection="default"
|
||||
),
|
||||
SearchResult(
|
||||
id="b", content="B", score=0.8, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
keyword = [
|
||||
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="c", content="C", score=0.7, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
|
||||
|
||||
Reference in New Issue
Block a user