refactor(knowledge-base mcp server): adjust formatting for consistency and readability

Improved code formatting, line wrapping, and indentation across the chunking logic and several test modules for clarity and a consistent style. No functional changes.
2026-01-06 17:20:31 +01:00
parent 3f23bc3db3
commit 51404216ae
15 changed files with 306 additions and 155 deletions
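The diffs below are consistent with an 88-column Black/Ruff-style layout: expressions that overflow the limit are exploded one element per line with a trailing comma, and calls that now fit are collapsed back to a single line. As a minimal, self-contained sketch of both transformations (the commit does not name a formatter, so the 88-column convention is an inference; the snippet is illustrative, not project code):

    def classify(file_type: str) -> str:
        # Long membership tests get exploded one element per line
        # with a trailing comma, as in the ChunkerFactory hunk below.
        if file_type in (
            "text",
            "json",
            "yaml",
            "toml",
        ):
            return "text"
        return "code"

    # Calls that fit within the limit are collapsed back to one line.
    print(classify("yaml"))  # -> text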

View File

@@ -184,7 +184,12 @@ class ChunkerFactory:
         if file_type:
             if file_type == FileType.MARKDOWN:
                 return self._get_markdown_chunker()
-            elif file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
+            elif file_type in (
+                FileType.TEXT,
+                FileType.JSON,
+                FileType.YAML,
+                FileType.TOML,
+            ):
                 return self._get_text_chunker()
             else:
                 # Code files
@@ -193,7 +198,9 @@ class ChunkerFactory:
         # Default to text chunker
         return self._get_text_chunker()

-    def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
+    def get_chunker_for_path(
+        self, source_path: str
+    ) -> tuple[BaseChunker, FileType | None]:
         """
         Get chunker based on file path extension.

View File

@@ -151,7 +151,7 @@ class CodeChunker(BaseChunker):
         for struct_type, pattern in patterns.items():
             for match in pattern.finditer(content):
                 # Convert character position to line number
-                line_num = content[:match.start()].count("\n")
+                line_num = content[: match.start()].count("\n")
                 boundaries.append((line_num, struct_type))

         if not boundaries:
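Only the slice spacing changes here (Black's style for slices with complex expressions), but the line itself is a handy idiom: counting newlines before a match offset converts a character position into a 0-based line number. A standalone sketch (names are illustrative, not the module's API):

    import re

    content = "def a():\n    pass\n\nclass B:\n    pass\n"
    pattern = re.compile(r"^class \w+", re.MULTILINE)

    for match in pattern.finditer(content):
        # Newlines preceding the match offset give the 0-based line number.
        line_num = content[: match.start()].count("\n")
        print(line_num)  # -> 3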

View File

@@ -69,9 +69,7 @@ class MarkdownChunker(BaseChunker):
         if not sections:
             # No headings, chunk as plain text
-            return self._chunk_text_block(
-                content, source_path, file_type, metadata, []
-            )
+            return self._chunk_text_block(content, source_path, file_type, metadata, [])

         chunks: list[Chunk] = []
         heading_stack: list[tuple[int, str]] = []  # (level, text)
@@ -292,7 +290,10 @@ class MarkdownChunker(BaseChunker):
                 )

             # Overlap: include last paragraph if it fits
-            if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
+            if (
+                current_content
+                and self.count_tokens(current_content[-1]) <= self.chunk_overlap
+            ):
                 current_content = [current_content[-1]]
                 current_tokens = self.count_tokens(current_content[-1])
             else:
@@ -341,12 +342,14 @@ class MarkdownChunker(BaseChunker):
             # Start of code block - save previous paragraph
             if current_para and any(p.strip() for p in current_para):
                 para_content = "\n".join(current_para)
-                paragraphs.append({
-                    "content": para_content,
-                    "tokens": self.count_tokens(para_content),
-                    "start_line": para_start,
-                    "end_line": i - 1,
-                })
+                paragraphs.append(
+                    {
+                        "content": para_content,
+                        "tokens": self.count_tokens(para_content),
+                        "start_line": para_start,
+                        "end_line": i - 1,
+                    }
+                )
             current_para = [line]
             para_start = i
             in_code_block = True
@@ -360,12 +363,14 @@ class MarkdownChunker(BaseChunker):
         if not line.strip():
             if current_para and any(p.strip() for p in current_para):
                 para_content = "\n".join(current_para)
-                paragraphs.append({
-                    "content": para_content,
-                    "tokens": self.count_tokens(para_content),
-                    "start_line": para_start,
-                    "end_line": i - 1,
-                })
+                paragraphs.append(
+                    {
+                        "content": para_content,
+                        "tokens": self.count_tokens(para_content),
+                        "start_line": para_start,
+                        "end_line": i - 1,
+                    }
+                )
             current_para = []
             para_start = i + 1
         else:
@@ -376,12 +381,14 @@ class MarkdownChunker(BaseChunker):
         # Final paragraph
         if current_para and any(p.strip() for p in current_para):
             para_content = "\n".join(current_para)
-            paragraphs.append({
-                "content": para_content,
-                "tokens": self.count_tokens(para_content),
-                "start_line": para_start,
-                "end_line": len(lines) - 1,
-            })
+            paragraphs.append(
+                {
+                    "content": para_content,
+                    "tokens": self.count_tokens(para_content),
+                    "start_line": para_start,
+                    "end_line": len(lines) - 1,
+                }
+            )

         return paragraphs
@@ -448,7 +455,10 @@ class MarkdownChunker(BaseChunker):
                 )

             # Overlap with last sentence
-            if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
+            if (
+                current_content
+                and self.count_tokens(current_content[-1]) <= self.chunk_overlap
+            ):
                 current_content = [current_content[-1]]
                 current_tokens = self.count_tokens(current_content[-1])
             else:
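The two wrapped conditions above encode the chunker's overlap rule: when a chunk is flushed, its trailing paragraph (or sentence) seeds the next chunk only if it fits within the `chunk_overlap` token budget. A minimal sketch of that rule, assuming a whitespace tokenizer purely for illustration:

    def count_tokens(text: str) -> int:
        # Stand-in tokenizer; the chunker's real count_tokens is not shown in this diff.
        return len(text.split())

    def next_chunk_seed(current_content: list[str], chunk_overlap: int) -> list[str]:
        # Carry the trailing unit into the next chunk only if it fits the budget.
        if current_content and count_tokens(current_content[-1]) <= chunk_overlap:
            return [current_content[-1]]
        return []

    print(next_chunk_seed(["first paragraph", "short tail"], chunk_overlap=5))  # -> ['short tail']
    print(next_chunk_seed(["first paragraph", "a much longer trailing paragraph"], 2))  # -> []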

View File

@@ -79,9 +79,7 @@ class TextChunker(BaseChunker):
             )

         # Fall back to sentence-based chunking
-        return self._chunk_by_sentences(
-            content, source_path, file_type, metadata
-        )
+        return self._chunk_by_sentences(content, source_path, file_type, metadata)

     def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
         """Split content into paragraphs."""
@@ -97,12 +95,14 @@ class TextChunker(BaseChunker):
                 continue

             para_lines = para.count("\n") + 1
-            paragraphs.append({
-                "content": para,
-                "tokens": self.count_tokens(para),
-                "start_line": line_num,
-                "end_line": line_num + para_lines - 1,
-            })
+            paragraphs.append(
+                {
+                    "content": para,
+                    "tokens": self.count_tokens(para),
+                    "start_line": line_num,
+                    "end_line": line_num + para_lines - 1,
+                }
+            )
             line_num += para_lines + 1  # +1 for blank line between paragraphs

         return paragraphs
@@ -172,7 +172,10 @@ class TextChunker(BaseChunker):
             # Overlap: keep last paragraph if small enough
             overlap_para = None
-            if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
+            if (
+                current_paras
+                and self.count_tokens(current_paras[-1]) <= self.chunk_overlap
+            ):
                 overlap_para = current_paras[-1]

             current_paras = [overlap_para] if overlap_para else []
@@ -266,7 +269,10 @@ class TextChunker(BaseChunker):
             # Overlap: keep last sentence if small enough
             overlap = None
-            if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
+            if (
+                current_sentences
+                and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap
+            ):
                 overlap = current_sentences[-1]

             current_sentences = [overlap] if overlap else []
@@ -317,14 +323,10 @@ class TextChunker(BaseChunker):
         sentences = self._split_sentences(text)
         if len(sentences) > 1:
-            return self._chunk_by_sentences(
-                text, source_path, file_type, metadata
-            )
+            return self._chunk_by_sentences(text, source_path, file_type, metadata)

         # Fall back to word-based splitting
-        return self._chunk_by_words(
-            text, source_path, file_type, metadata, base_line
-        )
+        return self._chunk_by_words(text, source_path, file_type, metadata, base_line)

     def _chunk_by_words(
         self,

View File

@@ -328,14 +328,18 @@ class CollectionManager:
                 "source_path": chunk.source_path or source_path,
                 "start_line": chunk.start_line,
                 "end_line": chunk.end_line,
-                "file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
+                "file_type": effective_file_type.value
+                if (effective_file_type := chunk.file_type or file_type)
+                else None,
             }
-            embeddings_data.append((
-                chunk.content,
-                embedding,
-                chunk.chunk_type,
-                chunk_metadata,
-            ))
+            embeddings_data.append(
+                (
+                    chunk.content,
+                    embedding,
+                    chunk.chunk_type,
+                    chunk_metadata,
+                )
+            )

         # Atomically replace old embeddings with new ones
         _, chunk_ids = await self.database.replace_source_embeddings(
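The reflowed `file_type` entry is a walrus binding inside a conditional expression: `effective_file_type` is assigned from the first non-None candidate and used in the same expression. A self-contained sketch of the pattern (types reduced for illustration):

    from enum import Enum

    class FileType(Enum):
        MARKDOWN = "markdown"

    chunk_file_type: FileType | None = None
    file_type: FileType | None = FileType.MARKDOWN

    # Bind the first non-None candidate and use it, else fall back to None.
    value = (
        effective_file_type.value
        if (effective_file_type := chunk_file_type or file_type)
        else None
    )
    print(value)  # -> markdown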

View File

@@ -214,9 +214,7 @@ class EmbeddingGenerator:
             return cached

         # Generate via LLM Gateway
-        embeddings = await self._call_llm_gateway(
-            [text], project_id, agent_id
-        )
+        embeddings = await self._call_llm_gateway([text], project_id, agent_id)

         if not embeddings:
             raise EmbeddingGenerationError(
@@ -277,9 +275,7 @@ class EmbeddingGenerator:
         for i in range(0, len(texts_to_embed), batch_size):
             batch = texts_to_embed[i : i + batch_size]
-            batch_embeddings = await self._call_llm_gateway(
-                batch, project_id, agent_id
-            )
+            batch_embeddings = await self._call_llm_gateway(batch, project_id, agent_id)
             new_embeddings.extend(batch_embeddings)

         # Validate dimensions

View File

@@ -149,12 +149,8 @@ class IngestRequest(BaseModel):
     source_path: str | None = Field(
         default=None, description="Source file path for reference"
     )
-    collection: str = Field(
-        default="default", description="Collection to store in"
-    )
-    chunk_type: ChunkType = Field(
-        default=ChunkType.TEXT, description="Type of content"
-    )
+    collection: str = Field(default="default", description="Collection to store in")
+    chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="Type of content")
     file_type: FileType | None = Field(
         default=None, description="File type for code chunking"
     )
@@ -255,12 +251,8 @@ class DeleteRequest(BaseModel):
     project_id: str = Field(..., description="Project ID for scoping")
     agent_id: str = Field(..., description="Agent ID making the request")
-    source_path: str | None = Field(
-        default=None, description="Delete by source path"
-    )
-    collection: str | None = Field(
-        default=None, description="Delete entire collection"
-    )
+    source_path: str | None = Field(default=None, description="Delete by source path")
+    collection: str | None = Field(default=None, description="Delete entire collection")
     chunk_ids: list[str] | None = Field(
         default=None, description="Delete specific chunks"
     )

View File

@@ -145,8 +145,7 @@ class SearchEngine:
         # Filter by threshold (keyword search scores are normalized)
         filtered = [
-            (emb, score) for emb, score in results
-            if score >= request.threshold
+            (emb, score) for emb, score in results if score >= request.threshold
         ]

         return [
@@ -204,10 +203,9 @@ class SearchEngine:
         )

         # Filter by threshold and limit
-        filtered = [
-            result for result in fused
-            if result.score >= request.threshold
-        ][:request.limit]
+        filtered = [result for result in fused if result.score >= request.threshold][
+            : request.limit
+        ]

         return filtered

View File

@@ -93,6 +93,7 @@ def _validate_source_path(value: str | None) -> str | None:
     return None

+
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
@@ -213,7 +214,9 @@ async def health_check() -> dict[str, Any]:
             if response.status_code == 200:
                 status["dependencies"]["llm_gateway"] = "connected"
             else:
-                status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
+                status["dependencies"]["llm_gateway"] = (
+                    f"unhealthy (status {response.status_code})"
+                )
                 is_degraded = True
         else:
             status["dependencies"]["llm_gateway"] = "not initialized"
@@ -328,7 +331,9 @@ def _get_tool_schema(func: Any) -> dict[str, Any]:
     }


-def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
+def _register_tool(
+    name: str, tool_or_func: Any, description: str | None = None
+) -> None:
     """Register a tool in the registry.

     Handles both raw functions and FastMCP FunctionTool objects.
@@ -337,7 +342,11 @@ def _register_tool(name: str, tool_or_func: Any, description: str | None = None)
     if hasattr(tool_or_func, "fn"):
         func = tool_or_func.fn
         # Use FunctionTool's description if available
-        if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
+        if (
+            not description
+            and hasattr(tool_or_func, "description")
+            and tool_or_func.description
+        ):
             description = tool_or_func.description
     else:
         func = tool_or_func
@@ -358,11 +367,13 @@ async def list_mcp_tools() -> dict[str, Any]:
     """
     tools = []
     for name, info in _tool_registry.items():
-        tools.append({
-            "name": name,
-            "description": info["description"],
-            "inputSchema": info["schema"],
-        })
+        tools.append(
+            {
+                "name": name,
+                "description": info["description"],
+                "inputSchema": info["schema"],
+            }
+        )

     return {"tools": tools}
@@ -410,7 +421,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
             status_code=400,
             content={
                 "jsonrpc": "2.0",
-                "error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
+                "error": {
+                    "code": -32600,
+                    "message": "Invalid Request: jsonrpc must be '2.0'",
+                },
                 "id": request_id,
             },
         )
@@ -420,7 +434,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
             status_code=400,
             content={
                 "jsonrpc": "2.0",
-                "error": {"code": -32600, "message": "Invalid Request: method is required"},
+                "error": {
+                    "code": -32600,
+                    "message": "Invalid Request: method is required",
+                },
                 "id": request_id,
             },
         )
@@ -528,11 +545,23 @@ async def search_knowledge(
     try:
         # Validate inputs
         if error := _validate_id(project_id, "project_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_id(agent_id, "agent_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if collection and (error := _validate_collection(collection)):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }

         # Parse search type
         try:
@@ -644,13 +673,29 @@ async def ingest_content(
     try:
         # Validate inputs
         if error := _validate_id(project_id, "project_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_id(agent_id, "agent_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_collection(collection):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
        if error := _validate_source_path(source_path):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }

         # Validate content size to prevent DoS
         settings = get_settings()
@@ -750,13 +795,29 @@ async def delete_content(
     try:
         # Validate inputs
         if error := _validate_id(project_id, "project_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_id(agent_id, "agent_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if collection and (error := _validate_collection(collection)):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_source_path(source_path):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }

         request = DeleteRequest(
             project_id=project_id,
@@ -803,9 +864,17 @@ async def list_collections(
     try:
         # Validate inputs
         if error := _validate_id(project_id, "project_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_id(agent_id, "agent_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }

         result = await _collections.list_collections(project_id)  # type: ignore[union-attr]
@@ -856,11 +925,23 @@ async def get_collection_stats(
     try:
         # Validate inputs
         if error := _validate_id(project_id, "project_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_id(agent_id, "agent_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_collection(collection):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }

         stats = await _collections.get_collection_stats(project_id, collection)  # type: ignore[union-attr]
@@ -874,8 +955,12 @@ async def get_collection_stats(
             "avg_chunk_size": stats.avg_chunk_size,
             "chunk_types": stats.chunk_types,
             "file_types": stats.file_types,
-            "oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
-            "newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
+            "oldest_chunk": stats.oldest_chunk.isoformat()
+            if stats.oldest_chunk
+            else None,
+            "newest_chunk": stats.newest_chunk.isoformat()
+            if stats.newest_chunk
+            else None,
         }

     except KnowledgeBaseError as e:
@@ -925,13 +1010,29 @@ async def update_document(
     try:
         # Validate inputs
         if error := _validate_id(project_id, "project_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_id(agent_id, "agent_id"):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_collection(collection):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }
         if error := _validate_source_path(source_path):
-            return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
+            return {
+                "success": False,
+                "error": error,
+                "code": ErrorCode.INVALID_REQUEST.value,
+            }

         # Validate content size to prevent DoS
         settings = get_settings()
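Every handler hunk above expands the same early-return guard: a walrus-assigned validation error short-circuits the tool with a structured failure payload, now wrapped one key per line. A reduced, runnable sketch of the pattern (`_validate_id` here is a stand-in, and the literal "KB_INVALID_REQUEST" string is a guess from the KB_ prefix seen in the error-code tests; the real code uses ErrorCode.INVALID_REQUEST.value):

    def _validate_id(value: str, field: str) -> str | None:
        # Stand-in validator: returns an error message, or None when valid.
        if not value or not value.replace("-", "").isalnum():
            return f"{field} must be a non-empty alphanumeric identifier"
        return None

    def tool(project_id: str, agent_id: str) -> dict:
        if error := _validate_id(project_id, "project_id"):
            return {
                "success": False,
                "error": error,
                "code": "KB_INVALID_REQUEST",
            }
        if error := _validate_id(agent_id, "agent_id"):
            return {
                "success": False,
                "error": error,
                "code": "KB_INVALID_REQUEST",
            }
        return {"success": True}

    print(tool("proj-123", ""))  # agent_id fails validation -> success: False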

View File

@@ -83,7 +83,9 @@ def mock_embeddings():
         return [0.1] * 1536

     mock_emb.generate = AsyncMock(return_value=fake_embedding())
-    mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])
+    mock_emb.generate_batch = AsyncMock(
+        side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts]
+    )

     return mock_emb
@@ -137,7 +139,7 @@ async def async_function() -> None:
 @pytest.fixture
 def sample_markdown():
     """Sample Markdown content for chunking tests."""
-    return '''# Project Documentation
+    return """# Project Documentation

 This is the main documentation for our project.
@@ -182,20 +184,20 @@ The search endpoint allows you to query the knowledge base.
 ## Contributing

 We welcome contributions! Please see our contributing guide.
-'''
+"""


 @pytest.fixture
 def sample_text():
     """Sample plain text for chunking tests."""
-    return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
+    return """The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.

 Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.

 When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.

 The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
-'''
+"""


 @pytest.fixture

View File

@@ -1,7 +1,6 @@
 """Tests for chunking module."""

-

 class TestBaseChunker:
     """Tests for base chunker functionality."""
@@ -149,7 +148,7 @@ class TestMarkdownChunker:
         """Test that chunker respects heading hierarchy."""
         from chunking.markdown import MarkdownChunker

-        markdown = '''# Main Title
+        markdown = """# Main Title

 Introduction paragraph.
@@ -164,7 +163,7 @@ More detailed content.
 ## Section Two

 Content for section two.
-'''
+"""

         chunker = MarkdownChunker(
             chunk_size=200,
@@ -188,7 +187,7 @@ Content for section two.
         """Test handling of code blocks in markdown."""
         from chunking.markdown import MarkdownChunker

-        markdown = '''# Code Example
+        markdown = """# Code Example

 Here's some code:
@@ -198,7 +197,7 @@ def hello():
 ```

 End of example.
-'''
+"""

         chunker = MarkdownChunker(
             chunk_size=500,
@@ -256,12 +255,12 @@ class TestTextChunker:
         """Test that chunker respects paragraph boundaries."""
         from chunking.text import TextChunker

-        text = '''First paragraph with some content.
+        text = """First paragraph with some content.

 Second paragraph with different content.

 Third paragraph to test chunking behavior.
-'''
+"""

         chunker = TextChunker(
             chunk_size=100,

View File

@@ -67,10 +67,14 @@ class TestCollectionManager:
         assert result.embeddings_generated == 0

     @pytest.mark.asyncio
-    async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
+    async def test_ingest_error_handling(
+        self, collection_manager, sample_ingest_request
+    ):
         """Test ingest error handling."""
         # Make embedding generation fail
-        collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
+        collection_manager._embeddings.generate_batch.side_effect = Exception(
+            "Embedding error"
+        )

         result = await collection_manager.ingest(sample_ingest_request)
@@ -182,7 +186,9 @@ class TestCollectionManager:
         )
         collection_manager._database.get_collection_stats.return_value = expected_stats

-        stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
+        stats = await collection_manager.get_collection_stats(
+            "proj-123", "test-collection"
+        )

         assert stats.chunk_count == 100
         assert stats.unique_sources == 10

View File

@@ -17,19 +17,15 @@ class TestEmbeddingGenerator:
         response.raise_for_status = MagicMock()
         response.json.return_value = {
             "result": {
-                "content": [
-                    {
-                        "text": json.dumps({
-                            "embeddings": [[0.1] * 1536]
-                        })
-                    }
-                ]
+                "content": [{"text": json.dumps({"embeddings": [[0.1] * 1536]})}]
             }
         }
         return response

     @pytest.mark.asyncio
-    async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
+    async def test_generate_single_embedding(
+        self, settings, mock_redis, mock_http_response
+    ):
         """Test generating a single embedding."""
         from embeddings import EmbeddingGenerator
@@ -67,9 +63,9 @@ class TestEmbeddingGenerator:
             "result": {
                 "content": [
                     {
-                        "text": json.dumps({
-                            "embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
-                        })
+                        "text": json.dumps(
+                            {"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]}
+                        )
                     }
                 ]
             }
@@ -166,9 +162,11 @@ class TestEmbeddingGenerator:
             "result": {
                 "content": [
                     {
-                        "text": json.dumps({
-                            "embeddings": [[0.1] * 768]  # Wrong dimension
-                        })
+                        "text": json.dumps(
+                            {
+                                "embeddings": [[0.1] * 768]  # Wrong dimension
+                            }
+                        )
                     }
                 ]
             }

View File

@@ -1,7 +1,6 @@
 """Tests for exception classes."""

-

 class TestErrorCode:
     """Tests for ErrorCode enum."""
@@ -10,8 +9,13 @@ class TestErrorCode:
         from exceptions import ErrorCode

         assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
-        assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
-        assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
+        assert (
+            ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
+        )
+        assert (
+            ErrorCode.EMBEDDING_GENERATION_ERROR.value
+            == "KB_EMBEDDING_GENERATION_ERROR"
+        )
         assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
         assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
         assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"

View File

@@ -59,7 +59,9 @@ class TestSearchEngine:
     ]

     @pytest.mark.asyncio
-    async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
+    async def test_semantic_search(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test semantic search."""
         from models import SearchType
@@ -74,7 +76,9 @@ class TestSearchEngine:
         search_engine._database.semantic_search.assert_called_once()

     @pytest.mark.asyncio
-    async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
+    async def test_keyword_search(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test keyword search."""
         from models import SearchType
@@ -88,7 +92,9 @@ class TestSearchEngine:
         search_engine._database.keyword_search.assert_called_once()

     @pytest.mark.asyncio
-    async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
+    async def test_hybrid_search(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test hybrid search."""
         from models import SearchType
@@ -105,7 +111,9 @@ class TestSearchEngine:
         assert len(response.results) >= 1

     @pytest.mark.asyncio
-    async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
+    async def test_search_with_collection_filter(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test search with collection filter."""
         from models import SearchType
@@ -120,7 +128,9 @@ class TestSearchEngine:
         assert call_args.kwargs["collection"] == "specific-collection"

     @pytest.mark.asyncio
-    async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
+    async def test_search_with_file_type_filter(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test search with file type filter."""
         from models import FileType, SearchType
@@ -135,7 +145,9 @@ class TestSearchEngine:
         assert call_args.kwargs["file_types"] == [FileType.PYTHON]

     @pytest.mark.asyncio
-    async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
+    async def test_search_respects_limit(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test that search respects result limit."""
         from models import SearchType
@@ -148,7 +160,9 @@ class TestSearchEngine:
         assert len(response.results) <= 1

     @pytest.mark.asyncio
-    async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
+    async def test_search_records_time(
+        self, search_engine, sample_search_request, sample_db_results
+    ):
         """Test that search records time."""
         from models import SearchType
@@ -203,13 +217,21 @@ class TestReciprocalRankFusion:
         from models import SearchResult

         semantic = [
-            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
-            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
+            SearchResult(
+                id="a", content="A", score=0.9, chunk_type="code", collection="default"
+            ),
+            SearchResult(
+                id="b", content="B", score=0.8, chunk_type="code", collection="default"
+            ),
         ]
         keyword = [
-            SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
-            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
+            SearchResult(
+                id="b", content="B", score=0.85, chunk_type="code", collection="default"
+            ),
+            SearchResult(
+                id="c", content="C", score=0.7, chunk_type="code", collection="default"
+            ),
         ]

         fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
@@ -230,19 +252,23 @@ class TestReciprocalRankFusion:
         # Same results in same order
         results = [
-            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
+            SearchResult(
+                id="a", content="A", score=0.9, chunk_type="code", collection="default"
+            ),
         ]

         # High semantic weight
         fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
-            results, [],
+            results,
+            [],
             semantic_weight=0.9,
             keyword_weight=0.1,
         )

         # High keyword weight
         fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
-            [], results,
+            [],
+            results,
             semantic_weight=0.1,
             keyword_weight=0.9,
         )
@@ -256,12 +282,18 @@ class TestReciprocalRankFusion:
         from models import SearchResult

         semantic = [
-            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
-            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
+            SearchResult(
+                id="a", content="A", score=0.9, chunk_type="code", collection="default"
+            ),
+            SearchResult(
+                id="b", content="B", score=0.8, chunk_type="code", collection="default"
+            ),
         ]
         keyword = [
-            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
+            SearchResult(
+                id="c", content="C", score=0.7, chunk_type="code", collection="default"
+            ),
        ]

         fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
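These fusion tests pin down the observable behavior of `_reciprocal_rank_fusion`: a result present in both lists should outrank single-list results, and the weights bias the blend toward one side. A minimal weighted-RRF sketch consistent with that behavior (the k=60 constant and the exact scoring are assumptions, not the server's confirmed implementation):

    def reciprocal_rank_fusion(
        semantic_ids: list[str],
        keyword_ids: list[str],
        semantic_weight: float = 0.5,
        keyword_weight: float = 0.5,
        k: int = 60,
    ) -> list[tuple[str, float]]:
        scores: dict[str, float] = {}
        for rank, doc_id in enumerate(semantic_ids):
            scores[doc_id] = scores.get(doc_id, 0.0) + semantic_weight / (k + rank + 1)
        for rank, doc_id in enumerate(keyword_ids):
            scores[doc_id] = scores.get(doc_id, 0.0) + keyword_weight / (k + rank + 1)
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)

    # "b" appears in both lists, so it fuses to the top.
    print(reciprocal_rank_fusion(["a", "b"], ["b", "c"]))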