"""
Search implementations for Knowledge Base MCP Server.
Provides semantic (vector), keyword (full-text), and hybrid search
capabilities over the knowledge base.
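
Example (illustrative sketch): running a hybrid search through the module-level
accessor from within an async context. The project id and query are
placeholders, and SearchRequest is assumed to provide defaults for any fields
not shown:

    engine = get_search_engine()
    response = await engine.search(
        SearchRequest(
            project_id="my-project",
            query="how are document chunks embedded?",
            search_type=SearchType.HYBRID,
        )
    )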
"""

import asyncio
import logging
import time

from config import Settings, get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import InvalidSearchTypeError, SearchError
from models import (
    SearchRequest,
    SearchResponse,
    SearchResult,
    SearchType,
)

logger = logging.getLogger(__name__)


class SearchEngine:
    """
    Unified search engine supporting multiple search types.

    Features:
    - Semantic search using vector similarity
    - Keyword search using full-text search
    - Hybrid search combining both approaches
    - Configurable result fusion and weighting
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
    ) -> None:
        """Initialize search engine."""
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings

    @property
    def database(self) -> DatabaseManager:
        """Get database manager."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    async def search(self, request: SearchRequest) -> SearchResponse:
        """
        Execute a search request.

        Args:
            request: Search request with query and options

        Returns:
            Search response with results
        """
        start_time = time.time()

        try:
            if request.search_type == SearchType.SEMANTIC:
                results = await self._semantic_search(request)
            elif request.search_type == SearchType.KEYWORD:
                results = await self._keyword_search(request)
            elif request.search_type == SearchType.HYBRID:
                results = await self._hybrid_search(request)
            else:
                raise InvalidSearchTypeError(
                    search_type=request.search_type,
                    valid_types=[t.value for t in SearchType],
                )

            search_time_ms = (time.time() - start_time) * 1000

            logger.info(
                f"Search completed: type={request.search_type.value}, "
                f"results={len(results)}, time={search_time_ms:.1f}ms"
            )

            return SearchResponse(
                query=request.query,
                search_type=request.search_type.value,
                results=results,
                total_results=len(results),
                search_time_ms=search_time_ms,
            )
        except InvalidSearchTypeError:
            raise
        except Exception as e:
            logger.error(f"Search error: {e}")
            raise SearchError(
                message=f"Search failed: {e}",
                cause=e,
            )

    async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute semantic (vector) search."""
        # Generate embedding for query
        query_embedding = await self.embeddings.generate(
            text=request.query,
            project_id=request.project_id,
            agent_id=request.agent_id,
        )

        # Search database
        results = await self.database.semantic_search(
            project_id=request.project_id,
            query_embedding=query_embedding,
            collection=request.collection,
            limit=request.limit,
            threshold=request.threshold,
            file_types=request.file_types,
        )

        # Convert to SearchResult
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
        ]

    async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute keyword (full-text) search."""
        results = await self.database.keyword_search(
            project_id=request.project_id,
            query=request.query,
            collection=request.collection,
            limit=request.limit,
            file_types=request.file_types,
        )

        # Filter by threshold (keyword search scores are normalized)
        filtered = [
            (emb, score) for emb, score in results if score >= request.threshold
        ]

        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in filtered
        ]

    async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute hybrid search combining semantic and keyword search.

        Uses Reciprocal Rank Fusion (RRF) to combine the two result lists and
        executes both sub-searches concurrently for better performance.
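
        Fusion weights come from settings (hybrid_semantic_weight and
        hybrid_keyword_weight); _reciprocal_rank_fusion defines the fallback
        defaults used when no weights are passed.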
"""
# Execute both searches with higher limits for fusion
fusion_limit = min(request.limit * 2, 100)
# Create modified request for sub-searches
semantic_request = SearchRequest(
project_id=request.project_id,
agent_id=request.agent_id,
query=request.query,
search_type=SearchType.SEMANTIC,
collection=request.collection,
limit=fusion_limit,
threshold=request.threshold * 0.8, # Lower threshold for fusion
file_types=request.file_types,
include_metadata=request.include_metadata,
)
keyword_request = SearchRequest(
project_id=request.project_id,
agent_id=request.agent_id,
query=request.query,
search_type=SearchType.KEYWORD,
collection=request.collection,
limit=fusion_limit,
threshold=0.0, # No threshold for keyword, we'll filter after fusion
file_types=request.file_types,
include_metadata=request.include_metadata,
)
# Execute searches concurrently for better performance
semantic_results, keyword_results = await asyncio.gather(
self._semantic_search(semantic_request),
self._keyword_search(keyword_request),
)
# Fuse results using RRF
fused = self._reciprocal_rank_fusion(
semantic_results=semantic_results,
keyword_results=keyword_results,
semantic_weight=self._settings.hybrid_semantic_weight,
keyword_weight=self._settings.hybrid_keyword_weight,
)
# Filter by threshold and limit
filtered = [result for result in fused if result.score >= request.threshold][
: request.limit
]
return filtered

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF constant
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) for each result list.
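
        Worked example (illustrative numbers, not from real data): with k=60
        and the default weights, a chunk ranked 1st in the semantic list and
        2nd in the keyword list scores 0.7 / 61 + 0.3 / 62 ≈ 0.0163 before the
        0-1 normalization applied below.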
"""
# Calculate RRF scores
scores: dict[str, float] = {}
results_by_id: dict[str, SearchResult] = {}
# Process semantic results
for rank, result in enumerate(semantic_results, start=1):
rrf_score = semantic_weight / (k + rank)
scores[result.id] = scores.get(result.id, 0) + rrf_score
results_by_id[result.id] = result
# Process keyword results
for rank, result in enumerate(keyword_results, start=1):
rrf_score = keyword_weight / (k + rank)
scores[result.id] = scores.get(result.id, 0) + rrf_score
if result.id not in results_by_id:
results_by_id[result.id] = result
# Sort by combined score
sorted_ids = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
# Normalize scores to 0-1 range
max_score = max(scores.values()) if scores else 1.0
# Create final results with normalized scores
final_results: list[SearchResult] = []
for result_id in sorted_ids:
result = results_by_id[result_id]
normalized_score = scores[result_id] / max_score
# Create new result with updated score
final_results.append(
SearchResult(
id=result.id,
content=result.content,
score=normalized_score,
source_path=result.source_path,
start_line=result.start_line,
end_line=result.end_line,
chunk_type=result.chunk_type,
file_type=result.file_type,
collection=result.collection,
metadata=result.metadata,
)
)
return final_results


# Global search engine instance (lazy initialization)
_search_engine: SearchEngine | None = None


def get_search_engine() -> SearchEngine:
    """Get the global search engine instance."""
    global _search_engine
    if _search_engine is None:
        _search_engine = SearchEngine()
    return _search_engine


def reset_search_engine() -> None:
    """Reset the global search engine (for testing)."""
    global _search_engine
    _search_engine = None
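

# Example (illustrative only, not part of the module API): tests that need an
# isolated engine can pair the accessor with reset_search_engine(), e.g. as a
# fixture (assuming pytest is the test runner used by the suite):
#
#     @pytest.fixture(autouse=True)
#     def fresh_search_engine():
#         reset_search_engine()
#         yield
#         reset_search_engine()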