"""
Search implementations for Knowledge Base MCP Server.

Provides semantic (vector), keyword (full-text), and hybrid search
capabilities over the knowledge base.
"""

import logging
import time

from config import Settings, get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import InvalidSearchTypeError, SearchError
from models import (
    SearchRequest,
    SearchResponse,
    SearchResult,
    SearchType,
)

logger = logging.getLogger(__name__)

# Cap on how many candidates each sub-search of a hybrid search fetches
# before fusion (never reduced below the caller's own limit).
_MAX_FUSION_LIMIT = 100

# Multiplier applied to the caller's threshold for the semantic leg of a
# hybrid search; the lower threshold lets more candidates reach fusion.
_HYBRID_SEMANTIC_THRESHOLD_FACTOR = 0.8


class SearchEngine:
    """
    Unified search engine supporting multiple search types.

    Features:
    - Semantic search using vector similarity
    - Keyword search using full-text search
    - Hybrid search combining both approaches via Reciprocal Rank Fusion
    - Fusion weights taken from settings (hybrid_semantic_weight /
      hybrid_keyword_weight)
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
    ) -> None:
        """
        Initialize search engine.

        Args:
            settings: Settings to use; falls back to ``get_settings()``.
            database: Database manager; when omitted it is resolved lazily
                via ``get_database_manager()`` on first access.
            embeddings: Embedding generator; when omitted it is resolved
                lazily via ``get_embedding_generator()`` on first access.
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings

    @property
    def database(self) -> DatabaseManager:
        """Get database manager, resolving the global one on first use."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator, resolving the global one on first use."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    async def search(self, request: SearchRequest) -> SearchResponse:
        """
        Execute a search request.

        Args:
            request: Search request with query and options.

        Returns:
            Search response with results, result count and elapsed time.

        Raises:
            InvalidSearchTypeError: If ``request.search_type`` is not one of
                the supported search types.
            SearchError: If the search fails. A SearchError raised by a lower
                layer is propagated unchanged; any other exception is wrapped
                in a SearchError and chained as its cause.
        """
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        try:
            if request.search_type == SearchType.SEMANTIC:
                results = await self._semantic_search(request)
            elif request.search_type == SearchType.KEYWORD:
                results = await self._keyword_search(request)
            elif request.search_type == SearchType.HYBRID:
                results = await self._hybrid_search(request)
            else:
                raise InvalidSearchTypeError(
                    search_type=request.search_type,
                    valid_types=[t.value for t in SearchType],
                )

            search_time_ms = (time.perf_counter() - start_time) * 1000

            logger.info(
                "Search completed: type=%s, results=%d, time=%.1fms",
                request.search_type.value,
                len(results),
                search_time_ms,
            )

            return SearchResponse(
                query=request.query,
                search_type=request.search_type.value,
                results=results,
                total_results=len(results),
                search_time_ms=search_time_ms,
            )

        except (InvalidSearchTypeError, SearchError):
            # Already meaningful errors: re-raise as-is instead of producing
            # a doubly-wrapped "Search failed: Search failed: ..." message.
            raise
        except Exception as e:
            logger.exception("Search error: %s", e)
            raise SearchError(
                message=f"Search failed: {e}",
                cause=e,
            ) from e

    async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute semantic (vector) search.

        Embeds the query text, then asks the database for the nearest stored
        embeddings honoring the request's collection, limit, threshold and
        file-type filters.
        """
        query_embedding = await self.embeddings.generate(
            text=request.query,
            project_id=request.project_id,
            agent_id=request.agent_id,
        )

        results = await self.database.semantic_search(
            project_id=request.project_id,
            query_embedding=query_embedding,
            collection=request.collection,
            limit=request.limit,
            threshold=request.threshold,
            file_types=request.file_types,
        )

        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
        ]

    async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute keyword (full-text) search.

        The threshold is applied here, after the database query, because the
        database call does not take one.
        """
        results = await self.database.keyword_search(
            project_id=request.project_id,
            query=request.query,
            collection=request.collection,
            limit=request.limit,
            file_types=request.file_types,
        )

        # Assumes keyword scores from the database are normalized to the same
        # range as request.threshold -- confirm against DatabaseManager.
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
            if score >= request.threshold
        ]

    async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute hybrid search combining semantic and keyword.

        Both sub-searches fetch an enlarged candidate pool, the two ranked
        lists are combined with Reciprocal Rank Fusion (RRF), and the fused
        list is filtered by ``request.threshold`` and truncated to
        ``request.limit``.
        """
        # Fetch more candidates than requested so fusion has room to re-rank,
        # but never fewer than the caller asked for.
        fusion_limit = max(
            request.limit, min(request.limit * 2, _MAX_FUSION_LIMIT)
        )

        semantic_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.SEMANTIC,
            collection=request.collection,
            limit=fusion_limit,
            threshold=request.threshold * _HYBRID_SEMANTIC_THRESHOLD_FACTOR,
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        keyword_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.KEYWORD,
            collection=request.collection,
            limit=fusion_limit,
            threshold=0.0,  # No threshold for keyword; filtered after fusion.
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        semantic_results = await self._semantic_search(semantic_request)
        keyword_results = await self._keyword_search(keyword_request)

        fused = self._reciprocal_rank_fusion(
            semantic_results=semantic_results,
            keyword_results=keyword_results,
            semantic_weight=self._settings.hybrid_semantic_weight,
            keyword_weight=self._settings.hybrid_keyword_weight,
        )

        # Threshold is compared against the max-normalized fused score.
        return [
            result for result in fused if result.score >= request.threshold
        ][: request.limit]

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF constant
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) over each list a result appears
        in (ranks are 1-based).

        Args:
            semantic_results: Ranked semantic results (best first).
            keyword_results: Ranked keyword results (best first).
            semantic_weight: Weight applied to semantic ranks.
            keyword_weight: Weight applied to keyword ranks.
            k: RRF smoothing constant added to each rank.

        Returns:
            New SearchResult objects ordered by descending fused score. Scores
            are divided by the maximum fused score, so the top result scores
            1.0 whenever any fused score is non-zero. Ties keep first-seen
            order (semantic results before keyword-only ones). For results
            present in both lists, the semantic copy's fields are used.
        """
        scores: dict[str, float] = {}
        results_by_id: dict[str, SearchResult] = {}

        for rank, result in enumerate(semantic_results, start=1):
            scores[result.id] = scores.get(result.id, 0.0) + semantic_weight / (k + rank)
            results_by_id[result.id] = result

        for rank, result in enumerate(keyword_results, start=1):
            scores[result.id] = scores.get(result.id, 0.0) + keyword_weight / (k + rank)
            results_by_id.setdefault(result.id, result)

        sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)

        # Guard against a zero maximum (e.g. both weights set to 0), which
        # would otherwise raise ZeroDivisionError during normalization.
        max_score = max(scores.values(), default=0.0) or 1.0

        final_results: list[SearchResult] = []
        for result_id in sorted_ids:
            result = results_by_id[result_id]
            final_results.append(
                SearchResult(
                    id=result.id,
                    content=result.content,
                    score=scores[result_id] / max_score,
                    source_path=result.source_path,
                    start_line=result.start_line,
                    end_line=result.end_line,
                    chunk_type=result.chunk_type,
                    file_type=result.file_type,
                    collection=result.collection,
                    metadata=result.metadata,
                )
            )

        return final_results


# Global search engine instance (lazy initialization)
_search_engine: SearchEngine | None = None


def get_search_engine() -> SearchEngine:
    """Get the global search engine instance, creating it on first call."""
    global _search_engine
    if _search_engine is None:
        _search_engine = SearchEngine()
    return _search_engine


def reset_search_engine() -> None:
    """Reset the global search engine (for testing)."""
    global _search_engine
    _search_engine = None