refactor(knowledge-base mcp server): adjust formatting for consistency and readability

Improved formatting, line breaks, and indentation across the chunking logic and several test modules for clarity and a consistent style. No functional changes.
2026-01-06 17:20:31 +01:00
parent 3f23bc3db3
commit 51404216ae
15 changed files with 306 additions and 155 deletions
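Every hunk below follows the same line-length-driven style that auto-formatters such as Black produce: a statement that exceeds the line limit has its condition, call arguments, or literal broken across multiple lines, one element per line, with a trailing comma. A minimal before/after sketch of the pattern; the FileType enum and pick_chunker helper here are hypothetical stand-ins, not code from this commit:

from enum import Enum


class FileType(Enum):
    MARKDOWN = "markdown"
    TEXT = "text"
    JSON = "json"
    YAML = "yaml"
    TOML = "toml"


def pick_chunker(file_type: FileType) -> str:
    # Before: one long line.
    #   if file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
    # After: the tuple is exploded once the line exceeds the length limit.
    if file_type in (
        FileType.TEXT,
        FileType.JSON,
        FileType.YAML,
        FileType.TOML,
    ):
        return "text"
    return "markdown" if file_type == FileType.MARKDOWN else "code"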

View File

@@ -184,7 +184,12 @@ class ChunkerFactory:
if file_type:
if file_type == FileType.MARKDOWN:
return self._get_markdown_chunker()
elif file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
elif file_type in (
FileType.TEXT,
FileType.JSON,
FileType.YAML,
FileType.TOML,
):
return self._get_text_chunker()
else:
# Code files
@@ -193,7 +198,9 @@ class ChunkerFactory:
# Default to text chunker
return self._get_text_chunker()
def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
def get_chunker_for_path(
self, source_path: str
) -> tuple[BaseChunker, FileType | None]:
"""
Get chunker based on file path extension.

View File

@@ -151,7 +151,7 @@ class CodeChunker(BaseChunker):
for struct_type, pattern in patterns.items():
for match in pattern.finditer(content):
# Convert character position to line number
line_num = content[:match.start()].count("\n")
line_num = content[: match.start()].count("\n")
boundaries.append((line_num, struct_type))
if not boundaries:

View File

@@ -69,9 +69,7 @@ class MarkdownChunker(BaseChunker):
if not sections:
# No headings, chunk as plain text
return self._chunk_text_block(
content, source_path, file_type, metadata, []
)
return self._chunk_text_block(content, source_path, file_type, metadata, [])
chunks: list[Chunk] = []
heading_stack: list[tuple[int, str]] = [] # (level, text)
@@ -292,7 +290,10 @@ class MarkdownChunker(BaseChunker):
)
# Overlap: include last paragraph if it fits
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
if (
current_content
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
):
current_content = [current_content[-1]]
current_tokens = self.count_tokens(current_content[-1])
else:
@@ -341,12 +342,14 @@ class MarkdownChunker(BaseChunker):
# Start of code block - save previous paragraph
if current_para and any(p.strip() for p in current_para):
para_content = "\n".join(current_para)
paragraphs.append({
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
})
paragraphs.append(
{
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
}
)
current_para = [line]
para_start = i
in_code_block = True
@@ -360,12 +363,14 @@ class MarkdownChunker(BaseChunker):
if not line.strip():
if current_para and any(p.strip() for p in current_para):
para_content = "\n".join(current_para)
paragraphs.append({
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
})
paragraphs.append(
{
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
}
)
current_para = []
para_start = i + 1
else:
@@ -376,12 +381,14 @@ class MarkdownChunker(BaseChunker):
# Final paragraph
if current_para and any(p.strip() for p in current_para):
para_content = "\n".join(current_para)
paragraphs.append({
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": len(lines) - 1,
})
paragraphs.append(
{
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": len(lines) - 1,
}
)
return paragraphs
@@ -448,7 +455,10 @@ class MarkdownChunker(BaseChunker):
)
# Overlap with last sentence
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
if (
current_content
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
):
current_content = [current_content[-1]]
current_tokens = self.count_tokens(current_content[-1])
else:
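The two overlap hunks in this file wrap the same idea: when a section is split into multiple chunks, the previous chunk's last paragraph (or sentence) is carried into the next chunk, but only if it fits within the chunk_overlap token budget. A minimal sketch of that behaviour, assuming a whitespace-based token count in place of the chunker's real count_tokens():

def count_tokens(text: str) -> int:
    # Stand-in tokenizer; the real chunker has its own count_tokens().
    return len(text.split())


def split_with_overlap(
    paragraphs: list[str], chunk_size: int, chunk_overlap: int
) -> list[list[str]]:
    chunks: list[list[str]] = []
    current: list[str] = []
    current_tokens = 0
    for para in paragraphs:
        para_tokens = count_tokens(para)
        if current and current_tokens + para_tokens > chunk_size:
            chunks.append(current)
            # Carry the last paragraph forward only when it fits the overlap budget.
            if count_tokens(current[-1]) <= chunk_overlap:
                current = [current[-1]]
                current_tokens = count_tokens(current[-1])
            else:
                current = []
                current_tokens = 0
        current.append(para)
        current_tokens += para_tokens
    if current:
        chunks.append(current)
    return chunks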

View File

@@ -79,9 +79,7 @@ class TextChunker(BaseChunker):
)
# Fall back to sentence-based chunking
return self._chunk_by_sentences(
content, source_path, file_type, metadata
)
return self._chunk_by_sentences(content, source_path, file_type, metadata)
def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
"""Split content into paragraphs."""
@@ -97,12 +95,14 @@ class TextChunker(BaseChunker):
continue
para_lines = para.count("\n") + 1
paragraphs.append({
"content": para,
"tokens": self.count_tokens(para),
"start_line": line_num,
"end_line": line_num + para_lines - 1,
})
paragraphs.append(
{
"content": para,
"tokens": self.count_tokens(para),
"start_line": line_num,
"end_line": line_num + para_lines - 1,
}
)
line_num += para_lines + 1 # +1 for blank line between paragraphs
return paragraphs
@@ -172,7 +172,10 @@ class TextChunker(BaseChunker):
# Overlap: keep last paragraph if small enough
overlap_para = None
if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
if (
current_paras
and self.count_tokens(current_paras[-1]) <= self.chunk_overlap
):
overlap_para = current_paras[-1]
current_paras = [overlap_para] if overlap_para else []
@@ -266,7 +269,10 @@ class TextChunker(BaseChunker):
# Overlap: keep last sentence if small enough
overlap = None
if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
if (
current_sentences
and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap
):
overlap = current_sentences[-1]
current_sentences = [overlap] if overlap else []
@@ -317,14 +323,10 @@ class TextChunker(BaseChunker):
sentences = self._split_sentences(text)
if len(sentences) > 1:
return self._chunk_by_sentences(
text, source_path, file_type, metadata
)
return self._chunk_by_sentences(text, source_path, file_type, metadata)
# Fall back to word-based splitting
return self._chunk_by_words(
text, source_path, file_type, metadata, base_line
)
return self._chunk_by_words(text, source_path, file_type, metadata, base_line)
def _chunk_by_words(
self,

View File

@@ -328,14 +328,18 @@ class CollectionManager:
"source_path": chunk.source_path or source_path,
"start_line": chunk.start_line,
"end_line": chunk.end_line,
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
"file_type": effective_file_type.value
if (effective_file_type := chunk.file_type or file_type)
else None,
}
embeddings_data.append((
chunk.content,
embedding,
chunk.chunk_type,
chunk_metadata,
))
embeddings_data.append(
(
chunk.content,
embedding,
chunk.chunk_type,
chunk_metadata,
)
)
# Atomically replace old embeddings with new ones
_, chunk_ids = await self.database.replace_source_embeddings(

View File

@@ -214,9 +214,7 @@ class EmbeddingGenerator:
return cached
# Generate via LLM Gateway
embeddings = await self._call_llm_gateway(
[text], project_id, agent_id
)
embeddings = await self._call_llm_gateway([text], project_id, agent_id)
if not embeddings:
raise EmbeddingGenerationError(
@@ -277,9 +275,7 @@ class EmbeddingGenerator:
for i in range(0, len(texts_to_embed), batch_size):
batch = texts_to_embed[i : i + batch_size]
batch_embeddings = await self._call_llm_gateway(
batch, project_id, agent_id
)
batch_embeddings = await self._call_llm_gateway(batch, project_id, agent_id)
new_embeddings.extend(batch_embeddings)
# Validate dimensions

View File

@@ -149,12 +149,8 @@ class IngestRequest(BaseModel):
source_path: str | None = Field(
default=None, description="Source file path for reference"
)
collection: str = Field(
default="default", description="Collection to store in"
)
chunk_type: ChunkType = Field(
default=ChunkType.TEXT, description="Type of content"
)
collection: str = Field(default="default", description="Collection to store in")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="Type of content")
file_type: FileType | None = Field(
default=None, description="File type for code chunking"
)
@@ -255,12 +251,8 @@ class DeleteRequest(BaseModel):
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
source_path: str | None = Field(
default=None, description="Delete by source path"
)
collection: str | None = Field(
default=None, description="Delete entire collection"
)
source_path: str | None = Field(default=None, description="Delete by source path")
collection: str | None = Field(default=None, description="Delete entire collection")
chunk_ids: list[str] | None = Field(
default=None, description="Delete specific chunks"
)

View File

@@ -145,8 +145,7 @@ class SearchEngine:
# Filter by threshold (keyword search scores are normalized)
filtered = [
(emb, score) for emb, score in results
if score >= request.threshold
(emb, score) for emb, score in results if score >= request.threshold
]
return [
@@ -204,10 +203,9 @@ class SearchEngine:
)
# Filter by threshold and limit
filtered = [
result for result in fused
if result.score >= request.threshold
][:request.limit]
filtered = [result for result in fused if result.score >= request.threshold][
: request.limit
]
return filtered
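For context on what is being filtered here: the hybrid path first fuses the semantic and keyword result lists with _reciprocal_rank_fusion, then applies the score threshold and result limit. A minimal sketch of weighted reciprocal rank fusion under the conventional k = 60 constant; the project's actual implementation, weights, and score scaling may differ:

def reciprocal_rank_fusion(
    semantic_ids: list[str],
    keyword_ids: list[str],
    semantic_weight: float = 0.7,
    keyword_weight: float = 0.3,
    k: int = 60,
) -> list[tuple[str, float]]:
    scores: dict[str, float] = {}
    for weight, ranked in (
        (semantic_weight, semantic_ids),
        (keyword_weight, keyword_ids),
    ):
        for rank, doc_id in enumerate(ranked, start=1):
            # Each list contributes weight / (k + rank); ids in both lists accumulate score.
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)

The fused, sorted list is what the hunk above then filters by request.threshold and truncates to request.limit.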

View File

@@ -93,6 +93,7 @@ def _validate_source_path(value: str | None) -> str | None:
return None
# Configure logging
logging.basicConfig(
level=logging.INFO,
@@ -213,7 +214,9 @@ async def health_check() -> dict[str, Any]:
if response.status_code == 200:
status["dependencies"]["llm_gateway"] = "connected"
else:
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
status["dependencies"]["llm_gateway"] = (
f"unhealthy (status {response.status_code})"
)
is_degraded = True
else:
status["dependencies"]["llm_gateway"] = "not initialized"
@@ -328,7 +331,9 @@ def _get_tool_schema(func: Any) -> dict[str, Any]:
}
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
def _register_tool(
name: str, tool_or_func: Any, description: str | None = None
) -> None:
"""Register a tool in the registry.
Handles both raw functions and FastMCP FunctionTool objects.
@@ -337,7 +342,11 @@ def _register_tool(name: str, tool_or_func: Any, description: str | None = None)
if hasattr(tool_or_func, "fn"):
func = tool_or_func.fn
# Use FunctionTool's description if available
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
if (
not description
and hasattr(tool_or_func, "description")
and tool_or_func.description
):
description = tool_or_func.description
else:
func = tool_or_func
@@ -358,11 +367,13 @@ async def list_mcp_tools() -> dict[str, Any]:
"""
tools = []
for name, info in _tool_registry.items():
tools.append({
"name": name,
"description": info["description"],
"inputSchema": info["schema"],
})
tools.append(
{
"name": name,
"description": info["description"],
"inputSchema": info["schema"],
}
)
return {"tools": tools}
@@ -410,7 +421,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
"error": {
"code": -32600,
"message": "Invalid Request: jsonrpc must be '2.0'",
},
"id": request_id,
},
)
@@ -420,7 +434,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: method is required"},
"error": {
"code": -32600,
"message": "Invalid Request: method is required",
},
"id": request_id,
},
)
@@ -528,11 +545,23 @@ async def search_knowledge(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse search type
try:
@@ -644,13 +673,29 @@ async def ingest_content(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Validate content size to prevent DoS
settings = get_settings()
@@ -750,13 +795,29 @@ async def delete_content(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
request = DeleteRequest(
project_id=project_id,
@@ -803,9 +864,17 @@ async def list_collections(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
@@ -856,11 +925,23 @@ async def get_collection_stats(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
@@ -874,8 +955,12 @@ async def get_collection_stats(
"avg_chunk_size": stats.avg_chunk_size,
"chunk_types": stats.chunk_types,
"file_types": stats.file_types,
"oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
"newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
"oldest_chunk": stats.oldest_chunk.isoformat()
if stats.oldest_chunk
else None,
"newest_chunk": stats.newest_chunk.isoformat()
if stats.newest_chunk
else None,
}
except KnowledgeBaseError as e:
@@ -925,13 +1010,29 @@ async def update_document(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Validate content size to prevent DoS
settings = get_settings()

View File

@@ -83,7 +83,9 @@ def mock_embeddings():
return [0.1] * 1536
mock_emb.generate = AsyncMock(return_value=fake_embedding())
mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])
mock_emb.generate_batch = AsyncMock(
side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts]
)
return mock_emb
@@ -137,7 +139,7 @@ async def async_function() -> None:
@pytest.fixture
def sample_markdown():
"""Sample Markdown content for chunking tests."""
return '''# Project Documentation
return """# Project Documentation
This is the main documentation for our project.
@@ -182,20 +184,20 @@ The search endpoint allows you to query the knowledge base.
## Contributing
We welcome contributions! Please see our contributing guide.
'''
"""
@pytest.fixture
def sample_text():
"""Sample plain text for chunking tests."""
return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
return """The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.
When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.
The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''
"""
@pytest.fixture

View File

@@ -1,7 +1,6 @@
"""Tests for chunking module."""
class TestBaseChunker:
"""Tests for base chunker functionality."""
@@ -149,7 +148,7 @@ class TestMarkdownChunker:
"""Test that chunker respects heading hierarchy."""
from chunking.markdown import MarkdownChunker
markdown = '''# Main Title
markdown = """# Main Title
Introduction paragraph.
@@ -164,7 +163,7 @@ More detailed content.
## Section Two
Content for section two.
'''
"""
chunker = MarkdownChunker(
chunk_size=200,
@@ -188,7 +187,7 @@ Content for section two.
"""Test handling of code blocks in markdown."""
from chunking.markdown import MarkdownChunker
markdown = '''# Code Example
markdown = """# Code Example
Here's some code:
@@ -198,7 +197,7 @@ def hello():
```
End of example.
'''
"""
chunker = MarkdownChunker(
chunk_size=500,
@@ -256,12 +255,12 @@ class TestTextChunker:
"""Test that chunker respects paragraph boundaries."""
from chunking.text import TextChunker
text = '''First paragraph with some content.
text = """First paragraph with some content.
Second paragraph with different content.
Third paragraph to test chunking behavior.
'''
"""
chunker = TextChunker(
chunk_size=100,

View File

@@ -67,10 +67,14 @@ class TestCollectionManager:
assert result.embeddings_generated == 0
@pytest.mark.asyncio
async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
async def test_ingest_error_handling(
self, collection_manager, sample_ingest_request
):
"""Test ingest error handling."""
# Make embedding generation fail
collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
collection_manager._embeddings.generate_batch.side_effect = Exception(
"Embedding error"
)
result = await collection_manager.ingest(sample_ingest_request)
@@ -182,7 +186,9 @@ class TestCollectionManager:
)
collection_manager._database.get_collection_stats.return_value = expected_stats
stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
stats = await collection_manager.get_collection_stats(
"proj-123", "test-collection"
)
assert stats.chunk_count == 100
assert stats.unique_sources == 10

View File

@@ -17,19 +17,15 @@ class TestEmbeddingGenerator:
response.raise_for_status = MagicMock()
response.json.return_value = {
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 1536]
})
}
]
"content": [{"text": json.dumps({"embeddings": [[0.1] * 1536]})}]
}
}
return response
@pytest.mark.asyncio
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
async def test_generate_single_embedding(
self, settings, mock_redis, mock_http_response
):
"""Test generating a single embedding."""
from embeddings import EmbeddingGenerator
@@ -67,9 +63,9 @@ class TestEmbeddingGenerator:
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
})
"text": json.dumps(
{"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]}
)
}
]
}
@@ -166,9 +162,11 @@ class TestEmbeddingGenerator:
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 768] # Wrong dimension
})
"text": json.dumps(
{
"embeddings": [[0.1] * 768] # Wrong dimension
}
)
}
]
}

View File

@@ -1,7 +1,6 @@
"""Tests for exception classes."""
class TestErrorCode:
"""Tests for ErrorCode enum."""
@@ -10,8 +9,13 @@ class TestErrorCode:
from exceptions import ErrorCode
assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
assert (
ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
)
assert (
ErrorCode.EMBEDDING_GENERATION_ERROR.value
== "KB_EMBEDDING_GENERATION_ERROR"
)
assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"

View File

@@ -59,7 +59,9 @@ class TestSearchEngine:
]
@pytest.mark.asyncio
async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
async def test_semantic_search(
self, search_engine, sample_search_request, sample_db_results
):
"""Test semantic search."""
from models import SearchType
@@ -74,7 +76,9 @@ class TestSearchEngine:
search_engine._database.semantic_search.assert_called_once()
@pytest.mark.asyncio
async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
async def test_keyword_search(
self, search_engine, sample_search_request, sample_db_results
):
"""Test keyword search."""
from models import SearchType
@@ -88,7 +92,9 @@ class TestSearchEngine:
search_engine._database.keyword_search.assert_called_once()
@pytest.mark.asyncio
async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
async def test_hybrid_search(
self, search_engine, sample_search_request, sample_db_results
):
"""Test hybrid search."""
from models import SearchType
@@ -105,7 +111,9 @@ class TestSearchEngine:
assert len(response.results) >= 1
@pytest.mark.asyncio
async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
async def test_search_with_collection_filter(
self, search_engine, sample_search_request, sample_db_results
):
"""Test search with collection filter."""
from models import SearchType
@@ -120,7 +128,9 @@ class TestSearchEngine:
assert call_args.kwargs["collection"] == "specific-collection"
@pytest.mark.asyncio
async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
async def test_search_with_file_type_filter(
self, search_engine, sample_search_request, sample_db_results
):
"""Test search with file type filter."""
from models import FileType, SearchType
@@ -135,7 +145,9 @@ class TestSearchEngine:
assert call_args.kwargs["file_types"] == [FileType.PYTHON]
@pytest.mark.asyncio
async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
async def test_search_respects_limit(
self, search_engine, sample_search_request, sample_db_results
):
"""Test that search respects result limit."""
from models import SearchType
@@ -148,7 +160,9 @@ class TestSearchEngine:
assert len(response.results) <= 1
@pytest.mark.asyncio
async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
async def test_search_records_time(
self, search_engine, sample_search_request, sample_db_results
):
"""Test that search records time."""
from models import SearchType
@@ -203,13 +217,21 @@ class TestReciprocalRankFusion:
from models import SearchResult
semantic = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
SearchResult(
id="a", content="A", score=0.9, chunk_type="code", collection="default"
),
SearchResult(
id="b", content="B", score=0.8, chunk_type="code", collection="default"
),
]
keyword = [
SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
SearchResult(
id="b", content="B", score=0.85, chunk_type="code", collection="default"
),
SearchResult(
id="c", content="C", score=0.7, chunk_type="code", collection="default"
),
]
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
@@ -230,19 +252,23 @@ class TestReciprocalRankFusion:
# Same results in same order
results = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(
id="a", content="A", score=0.9, chunk_type="code", collection="default"
),
]
# High semantic weight
fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
results, [],
results,
[],
semantic_weight=0.9,
keyword_weight=0.1,
)
# High keyword weight
fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
[], results,
[],
results,
semantic_weight=0.1,
keyword_weight=0.9,
)
@@ -256,12 +282,18 @@ class TestReciprocalRankFusion:
from models import SearchResult
semantic = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
SearchResult(
id="a", content="A", score=0.9, chunk_type="code", collection="default"
),
SearchResult(
id="b", content="B", score=0.8, chunk_type="code", collection="default"
),
]
keyword = [
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
SearchResult(
id="c", content="C", score=0.7, chunk_type="code", collection="default"
),
]
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)