# forked from cardosofelipe/fast-next-template
# Improved code formatting, line breaks, and indentation across chunking logic
# and multiple test modules to enhance code clarity and maintain consistent
# style. No functional changes made.
"""
|
|
Search implementations for Knowledge Base MCP Server.
|
|
|
|
Provides semantic (vector), keyword (full-text), and hybrid search
|
|
capabilities over the knowledge base.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
|
|
from config import Settings, get_settings
|
|
from database import DatabaseManager, get_database_manager
|
|
from embeddings import EmbeddingGenerator, get_embedding_generator
|
|
from exceptions import InvalidSearchTypeError, SearchError
|
|
from models import (
|
|
SearchRequest,
|
|
SearchResponse,
|
|
SearchResult,
|
|
SearchType,
|
|
)
|
|
|
|
# Module-level logger named after this module, per standard logging practice.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SearchEngine:
    """
    Unified search engine supporting multiple search types.

    Features:
    - Semantic search using vector similarity
    - Keyword search using full-text search
    - Hybrid search combining both approaches
    - Configurable result fusion and weighting
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
    ) -> None:
        """Initialize search engine.

        Args:
            settings: Application settings; falls back to ``get_settings()``.
            database: Database manager; resolved lazily when omitted.
            embeddings: Embedding generator; resolved lazily when omitted.
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings

    @property
    def database(self) -> DatabaseManager:
        """Get database manager (created lazily on first access)."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator (created lazily on first access)."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    async def search(self, request: SearchRequest) -> SearchResponse:
        """
        Execute a search request.

        Dispatches to the semantic, keyword, or hybrid implementation
        based on ``request.search_type``.

        Args:
            request: Search request with query and options

        Returns:
            Search response with results

        Raises:
            InvalidSearchTypeError: If the request's search type is unknown.
            SearchError: If the underlying search fails for any other reason.
        """
        start_time = time.time()

        # Map each supported search type to its implementation.
        handlers = {
            SearchType.SEMANTIC: self._semantic_search,
            SearchType.KEYWORD: self._keyword_search,
            SearchType.HYBRID: self._hybrid_search,
        }

        try:
            handler = handlers.get(request.search_type)
            if handler is None:
                raise InvalidSearchTypeError(
                    search_type=request.search_type,
                    valid_types=[t.value for t in SearchType],
                )

            results = await handler(request)

            search_time_ms = (time.time() - start_time) * 1000

            # Lazy %-style args avoid string formatting when INFO is disabled.
            logger.info(
                "Search completed: type=%s, results=%d, time=%.1fms",
                request.search_type.value,
                len(results),
                search_time_ms,
            )

            return SearchResponse(
                query=request.query,
                search_type=request.search_type.value,
                results=results,
                total_results=len(results),
                search_time_ms=search_time_ms,
            )

        except InvalidSearchTypeError:
            # Caller error, not a search failure: propagate as-is.
            raise
        except Exception as e:
            # logger.exception records the traceback, which logger.error
            # with just the message would lose.
            logger.exception("Search error: %s", e)
            # Chain explicitly so the original traceback survives in
            # __cause__ as well as in the SearchError's own field.
            raise SearchError(
                message=f"Search failed: {e}",
                cause=e,
            ) from e

    async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute semantic (vector) search.

        Embeds the query text, then runs a vector-similarity search
        against the database.
        """
        # Generate embedding for query
        query_embedding = await self.embeddings.generate(
            text=request.query,
            project_id=request.project_id,
            agent_id=request.agent_id,
        )

        # Search database
        results = await self.database.semantic_search(
            project_id=request.project_id,
            query_embedding=query_embedding,
            collection=request.collection,
            limit=request.limit,
            threshold=request.threshold,
            file_types=request.file_types,
        )

        # Convert (embedding, score) pairs to SearchResult
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
        ]

    async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute keyword (full-text) search."""
        results = await self.database.keyword_search(
            project_id=request.project_id,
            query=request.query,
            collection=request.collection,
            limit=request.limit,
            file_types=request.file_types,
        )

        # Filter by threshold (keyword search scores are normalized)
        filtered = [
            (emb, score) for emb, score in results if score >= request.threshold
        ]

        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in filtered
        ]

    async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute hybrid search combining semantic and keyword.

        Uses Reciprocal Rank Fusion (RRF) for result combination.
        Executes both searches concurrently for better performance.
        """
        # Fetch more candidates than requested so fusion has material
        # to work with, capped to keep sub-searches bounded.
        fusion_limit = min(request.limit * 2, 100)

        # Create modified requests for the two sub-searches.
        semantic_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.SEMANTIC,
            collection=request.collection,
            limit=fusion_limit,
            threshold=request.threshold * 0.8,  # Lower threshold for fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        keyword_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.KEYWORD,
            collection=request.collection,
            limit=fusion_limit,
            threshold=0.0,  # No threshold for keyword, we'll filter after fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        # Execute searches concurrently for better performance
        semantic_results, keyword_results = await asyncio.gather(
            self._semantic_search(semantic_request),
            self._keyword_search(keyword_request),
        )

        # Fuse results using RRF
        fused = self._reciprocal_rank_fusion(
            semantic_results=semantic_results,
            keyword_results=keyword_results,
            semantic_weight=self._settings.hybrid_semantic_weight,
            keyword_weight=self._settings.hybrid_keyword_weight,
        )

        # Apply the caller's threshold and limit to the fused ranking.
        filtered = [result for result in fused if result.score >= request.threshold][
            : request.limit
        ]

        return filtered

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF constant
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) for each result list.
        Scores are normalized to the 0-1 range before being returned.
        """
        # Accumulated RRF score per result id, and the result object to
        # reuse for each id (semantic copy wins when both lists hit it).
        scores: dict[str, float] = {}
        results_by_id: dict[str, SearchResult] = {}

        # Process semantic results (ranks start at 1 per the RRF formula).
        for rank, result in enumerate(semantic_results, start=1):
            rrf_score = semantic_weight / (k + rank)
            scores[result.id] = scores.get(result.id, 0.0) + rrf_score
            results_by_id[result.id] = result

        # Process keyword results
        for rank, result in enumerate(keyword_results, start=1):
            rrf_score = keyword_weight / (k + rank)
            scores[result.id] = scores.get(result.id, 0.0) + rrf_score
            if result.id not in results_by_id:
                results_by_id[result.id] = result

        # Sort by combined score, best first.
        sorted_ids = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)

        # Normalize scores to 0-1 range (guard against empty input).
        max_score = max(scores.values()) if scores else 1.0

        # Create final results with normalized scores
        final_results: list[SearchResult] = []
        for result_id in sorted_ids:
            result = results_by_id[result_id]
            normalized_score = scores[result_id] / max_score
            # SearchResult is rebuilt rather than mutated so the inputs
            # remain untouched.
            final_results.append(
                SearchResult(
                    id=result.id,
                    content=result.content,
                    score=normalized_score,
                    source_path=result.source_path,
                    start_line=result.start_line,
                    end_line=result.end_line,
                    chunk_type=result.chunk_type,
                    file_type=result.file_type,
                    collection=result.collection,
                    metadata=result.metadata,
                )
            )

        return final_results
|
|
|
|
|
|
# Process-wide SearchEngine singleton, built on first use.
_search_engine: SearchEngine | None = None


def get_search_engine() -> SearchEngine:
    """Return the shared search engine, creating it on first call."""
    global _search_engine
    engine = _search_engine
    if engine is None:
        engine = SearchEngine()
        _search_engine = engine
    return engine
|
|
|
|
|
|
def reset_search_engine() -> None:
    """Drop the cached engine so the next access builds a fresh one (test helper)."""
    global _search_engine
    _search_engine = None
|