forked from cardosofelipe/fast-next-template
Implements RAG capabilities with pgvector for semantic search:
- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
296 lines
10 KiB
Python
296 lines
10 KiB
Python
"""Tests for search module."""
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
import pytest
|
|
|
|
|
|
class TestSearchEngine:
    """Tests for SearchEngine class.

    The engine is wired to the ``mock_database`` and ``mock_embeddings``
    fixtures (presumably provided by ``conftest.py`` -- not visible here), so
    these tests exercise search-type dispatch, filter forwarding and response
    shaping without a real database.
    """

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create a SearchEngine backed by the mocked database and embeddings."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    @pytest.fixture
    def sample_db_results(self):
        """Create sample database rows as ``(KnowledgeEmbedding, score)`` pairs.

        Rows are ordered by descending score (0.95, then 0.85) and have
        distinct ids, so tests can reason about ordering and deduplication.
        """
        from models import ChunkType, FileType, KnowledgeEmbedding

        # A single timestamp keeps created_at == updated_at for every row.
        now = datetime.now(UTC)

        return [
            (
                KnowledgeEmbedding(
                    id="id-1",
                    project_id="proj-123",
                    collection="default",
                    content="def hello(): pass",
                    embedding=[0.1] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path="/test/file.py",
                    file_type=FileType.PYTHON,
                    created_at=now,
                    updated_at=now,
                ),
                0.95,
            ),
            (
                KnowledgeEmbedding(
                    id="id-2",
                    project_id="proj-123",
                    collection="default",
                    content="def world(): pass",
                    embedding=[0.2] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path="/test/file2.py",
                    file_type=FileType.PYTHON,
                    created_at=now,
                    updated_at=now,
                ),
                0.85,
            ),
        ]

    @pytest.mark.asyncio
    async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
        """Test semantic search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "semantic"
        assert len(response.results) == 2
        assert response.results[0].score == 0.95
        search_engine._database.semantic_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
        """Test keyword search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.KEYWORD
        search_engine._database.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "keyword"
        assert len(response.results) == 2
        search_engine._database.keyword_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
        """Test that hybrid search queries both backends and fuses the results."""
        from models import SearchType

        sample_search_request.search_type = SearchType.HYBRID

        # Both searches return the same rows, so fusion has to merge duplicates.
        search_engine._database.semantic_search.return_value = sample_db_results
        search_engine._database.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "hybrid"
        # Hybrid search must consult both the vector and the keyword backend.
        search_engine._database.semantic_search.assert_called_once()
        search_engine._database.keyword_search.assert_called_once()
        # Results are fused: at least one hit, and never more hits than there
        # are distinct chunks (the same ids appear in both backend results).
        assert 1 <= len(response.results) <= len(sample_db_results)

    @pytest.mark.asyncio
    async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
        """Test search with collection filter."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.collection = "specific-collection"
        search_engine._database.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Verify collection was passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["collection"] == "specific-collection"

    @pytest.mark.asyncio
    async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
        """Test search with file type filter."""
        from models import FileType, SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.file_types = [FileType.PYTHON]
        search_engine._database.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Verify file types were passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["file_types"] == [FileType.PYTHON]

    @pytest.mark.asyncio
    async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
        """Test that search respects result limit."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.limit = 1
        # NOTE(review): the mock already returns at most ``limit`` rows, so this
        # only proves the engine does not add results. Consider also asserting
        # the limit forwarded in the database call once its kwarg name is confirmed.
        search_engine._database.semantic_search.return_value = sample_db_results[:1]

        response = await search_engine.search(sample_search_request)

        assert len(response.results) <= 1

    @pytest.mark.asyncio
    async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
        """Test that search records a non-negative elapsed time."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        # A fully mocked search can finish within a single tick of a coarse
        # clock (or round down to 0 ms), so a strict ``> 0`` check is flaky.
        assert isinstance(response.search_time_ms, (int, float))
        assert response.search_time_ms >= 0

    @pytest.mark.asyncio
    async def test_invalid_search_type(self, search_engine, sample_search_request):
        """Test handling invalid search type."""
        from exceptions import InvalidSearchTypeError

        # Force invalid search type. This assumes the request model does not
        # validate on assignment (otherwise this line itself would raise).
        sample_search_request.search_type = "invalid"

        with pytest.raises((InvalidSearchTypeError, ValueError)):
            await search_engine.search(sample_search_request)

    @pytest.mark.asyncio
    async def test_empty_results(self, search_engine, sample_search_request):
        """Test search with no results."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = []

        response = await search_engine.search(sample_search_request)

        assert len(response.results) == 0
        assert response.total_results == 0
|
|
|
|
|
|
class TestReciprocalRankFusion:
    """Tests for reciprocal rank fusion (RRF) of semantic and keyword results.

    These call ``SearchEngine._reciprocal_rank_fusion`` directly with
    hand-built ``SearchResult`` lists, so no database access is involved.
    """

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create a SearchEngine backed by the mocked database and embeddings."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    def test_fusion_combines_results(self, search_engine):
        """Test that RRF combines results from both searches without duplicates."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        keyword = [
            SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        ids = [r.id for r in fused]
        # Every document from either list is present exactly once ("b" appears
        # in both inputs but must be merged into a single fused entry).
        assert sorted(ids) == ["a", "b", "c"]

        # B should be ranked higher (appears in both)
        b_rank = ids.index("b")
        assert b_rank < 2  # Should be in top 2

    def test_fusion_respects_weights(self, search_engine):
        """Test that RRF respects semantic/keyword weights.

        Both lists contain the same two documents in opposite order, so the
        list with the larger weight must decide which document ranks first.
        """
        from models import SearchResult

        # A result present in only one list survives regardless of weighting.
        only = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
        ]
        assert len(search_engine._reciprocal_rank_fusion(
            only, [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )) == 1
        assert len(search_engine._reciprocal_rank_fusion(
            [], only,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )) == 1

        # Opposite rankings: semantic prefers "a", keyword prefers "b".
        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        keyword = [
            SearchResult(id="b", content="B", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="a", content="A", score=0.8, chunk_type="code", collection="default"),
        ]

        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            semantic, keyword,
            semantic_weight=0.9,
            keyword_weight=0.1,
        )
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            semantic, keyword,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )

        # Both fusions keep both documents...
        assert len(fused_semantic_heavy) == 2
        assert len(fused_keyword_heavy) == 2
        # ...but the more heavily weighted list decides the winner.
        assert fused_semantic_heavy[0].id == "a"
        assert fused_keyword_heavy[0].id == "b"

    def test_fusion_normalizes_scores(self, search_engine):
        """Test that RRF normalizes scores to 0-1."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        keyword = [
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # All scores should be between 0 and 1
        for result in fused:
            assert 0 <= result.score <= 1
|
|
|
|
|
|
class TestGlobalSearchEngine:
    """Tests for the module-level search engine singleton.

    ``get_search_engine`` / ``reset_search_engine`` manage global state, so an
    autouse fixture resets it around every test. This keeps the tests
    independent of execution order and avoids leaking an engine into other
    tests.
    """

    @pytest.fixture(autouse=True)
    def _reset_global_engine(self):
        """Reset the global search engine before and after each test."""
        from search import reset_search_engine

        reset_search_engine()
        yield
        reset_search_engine()

    def test_get_search_engine_singleton(self):
        """Test that get_search_engine returns singleton."""
        from search import get_search_engine

        engine1 = get_search_engine()
        engine2 = get_search_engine()

        assert engine1 is engine2

    def test_reset_search_engine(self):
        """Test that resetting discards the cached engine."""
        from search import get_search_engine, reset_search_engine

        engine1 = get_search_engine()
        reset_search_engine()
        engine2 = get_search_engine()

        assert engine1 is not engine2
|