forked from cardosofelipe/fast-next-template
Implements RAG capabilities with pgvector for semantic search: - Intelligent chunking strategies (code-aware, markdown-aware, text) - Semantic search with vector similarity (HNSW index) - Keyword search with PostgreSQL full-text search - Hybrid search using Reciprocal Rank Fusion (RRF) - Redis caching for embeddings - Collection management (ingest, search, delete, stats) - FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document Testing: - 128 comprehensive tests covering all components - 58% code coverage (database integration tests use mocks) - Passes ruff linting and mypy type checking 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
286 lines
9.4 KiB
Python
286 lines
9.4 KiB
Python
"""
|
|
Search implementations for Knowledge Base MCP Server.
|
|
|
|
Provides semantic (vector), keyword (full-text), and hybrid search
|
|
capabilities over the knowledge base.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
|
|
from config import Settings, get_settings
|
|
from database import DatabaseManager, get_database_manager
|
|
from embeddings import EmbeddingGenerator, get_embedding_generator
|
|
from exceptions import InvalidSearchTypeError, SearchError
|
|
from models import (
|
|
SearchRequest,
|
|
SearchResponse,
|
|
SearchResult,
|
|
SearchType,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SearchEngine:
    """
    Unified search engine supporting multiple search types.

    Features:
    - Semantic search using vector similarity
    - Keyword search using full-text search
    - Hybrid search combining both approaches
    - Configurable result fusion and weighting
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
    ) -> None:
        """Initialize search engine.

        Args:
            settings: Application settings; falls back to ``get_settings()``.
            database: Database manager; resolved lazily when omitted.
            embeddings: Embedding generator; resolved lazily when omitted.
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings

    @property
    def database(self) -> DatabaseManager:
        """Get database manager (lazily initialized on first access)."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator (lazily initialized on first access)."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    async def search(self, request: SearchRequest) -> SearchResponse:
        """
        Execute a search request.

        Args:
            request: Search request with query and options

        Returns:
            Search response with results

        Raises:
            InvalidSearchTypeError: If ``request.search_type`` is unsupported.
            SearchError: If the underlying search fails for any other reason.
        """
        # perf_counter is monotonic; time.time() can jump backwards on
        # system clock adjustments and must not be used for durations.
        start_time = time.perf_counter()

        try:
            if request.search_type == SearchType.SEMANTIC:
                results = await self._semantic_search(request)
            elif request.search_type == SearchType.KEYWORD:
                results = await self._keyword_search(request)
            elif request.search_type == SearchType.HYBRID:
                results = await self._hybrid_search(request)
            else:
                raise InvalidSearchTypeError(
                    search_type=request.search_type,
                    valid_types=[t.value for t in SearchType],
                )

            search_time_ms = (time.perf_counter() - start_time) * 1000

            # Lazy %-args: no string formatting cost when INFO is disabled.
            logger.info(
                "Search completed: type=%s, results=%d, time=%.1fms",
                request.search_type.value,
                len(results),
                search_time_ms,
            )

            return SearchResponse(
                query=request.query,
                search_type=request.search_type.value,
                results=results,
                total_results=len(results),
                search_time_ms=search_time_ms,
            )

        except InvalidSearchTypeError:
            # Caller-facing validation error: propagate unchanged.
            raise
        except Exception as e:
            # logger.exception records the full traceback for debugging.
            logger.exception("Search error: %s", e)
            raise SearchError(
                message=f"Search failed: {e}",
                cause=e,
            ) from e

    async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute semantic (vector) search.

        Embeds the query text, then delegates to the database's vector
        similarity search and wraps each (embedding, score) row.
        """
        # Generate embedding for query
        query_embedding = await self.embeddings.generate(
            text=request.query,
            project_id=request.project_id,
            agent_id=request.agent_id,
        )

        # Search database
        rows = await self.database.semantic_search(
            project_id=request.project_id,
            query_embedding=query_embedding,
            collection=request.collection,
            limit=request.limit,
            threshold=request.threshold,
            file_types=request.file_types,
        )

        # Convert (embedding, score) pairs to SearchResult
        return [SearchResult.from_embedding(emb, score) for emb, score in rows]

    async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute keyword (full-text) search."""
        rows = await self.database.keyword_search(
            project_id=request.project_id,
            query=request.query,
            collection=request.collection,
            limit=request.limit,
            file_types=request.file_types,
        )

        # Filter by threshold while converting (keyword scores are
        # normalized, so a direct comparison against the threshold is valid).
        return [
            SearchResult.from_embedding(emb, score)
            for emb, score in rows
            if score >= request.threshold
        ]

    async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute hybrid search combining semantic and keyword.

        Uses Reciprocal Rank Fusion (RRF) for result combination.
        """
        # Over-fetch so each ranker contributes enough candidates to fuse.
        fusion_limit = min(request.limit * 2, 100)

        def sub_request(search_type: SearchType, threshold: float) -> SearchRequest:
            # Clone the incoming request; only type/limit/threshold differ.
            return SearchRequest(
                project_id=request.project_id,
                agent_id=request.agent_id,
                query=request.query,
                search_type=search_type,
                collection=request.collection,
                limit=fusion_limit,
                threshold=threshold,
                file_types=request.file_types,
                include_metadata=request.include_metadata,
            )

        # Lower semantic threshold so fusion sees near-miss candidates;
        # keyword results are unfiltered here and thresholded after fusion.
        semantic_results = await self._semantic_search(
            sub_request(SearchType.SEMANTIC, request.threshold * 0.8)
        )
        keyword_results = await self._keyword_search(
            sub_request(SearchType.KEYWORD, 0.0)
        )

        # Fuse results using RRF
        fused = self._reciprocal_rank_fusion(
            semantic_results=semantic_results,
            keyword_results=keyword_results,
            semantic_weight=self._settings.hybrid_semantic_weight,
            keyword_weight=self._settings.hybrid_keyword_weight,
        )

        # Apply the caller's threshold and limit to the fused ranking.
        return [r for r in fused if r.score >= request.threshold][: request.limit]

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF smoothing constant; larger dampens rank differences
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) over each ranked list; scores
        are then normalized so the top fused result has score 1.0.

        Args:
            semantic_results: Ranked results from vector search.
            keyword_results: Ranked results from full-text search.
            semantic_weight: Contribution weight of the semantic ranking.
            keyword_weight: Contribution weight of the keyword ranking.
            k: RRF constant.

        Returns:
            Fused results sorted by normalized RRF score, descending.
        """
        scores: dict[str, float] = {}
        results_by_id: dict[str, SearchResult] = {}

        # Accumulate weighted reciprocal ranks from both lists. When a chunk
        # appears in both, the semantic copy (processed first) is kept as the
        # representative via setdefault.
        for weight, ranked in (
            (semantic_weight, semantic_results),
            (keyword_weight, keyword_results),
        ):
            for rank, result in enumerate(ranked, start=1):
                scores[result.id] = scores.get(result.id, 0.0) + weight / (k + rank)
                results_by_id.setdefault(result.id, result)

        if not scores:
            return []

        # Normalize scores to 0-1 range (top fused result becomes 1.0).
        max_score = max(scores.values())

        # Build final results sorted by combined score, descending.
        final_results: list[SearchResult] = []
        for result_id in sorted(scores, key=scores.__getitem__, reverse=True):
            source = results_by_id[result_id]
            final_results.append(
                SearchResult(
                    id=source.id,
                    content=source.content,
                    score=scores[result_id] / max_score,
                    source_path=source.source_path,
                    start_line=source.start_line,
                    end_line=source.end_line,
                    chunk_type=source.chunk_type,
                    file_type=source.file_type,
                    collection=source.collection,
                    metadata=source.metadata,
                )
            )

        return final_results
|
|
|
|
|
|
# Module-level singleton; created on first call to get_search_engine().
_search_engine: SearchEngine | None = None


def get_search_engine() -> SearchEngine:
    """Return the process-wide SearchEngine, constructing it on first use."""
    global _search_engine
    engine = _search_engine
    if engine is None:
        engine = SearchEngine()
        _search_engine = engine
    return engine
|
|
|
|
|
|
def reset_search_engine() -> None:
    """Discard the cached global search engine (test-isolation helper)."""
    global _search_engine
    _search_engine = None
|