forked from cardosofelipe/fast-next-template
Improved code formatting, line breaks, and indentation across chunking logic and multiple test modules to enhance code clarity and maintain consistent style. No functional changes made.
328 lines
11 KiB
Python
328 lines
11 KiB
Python
"""Tests for search module."""
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
import pytest
|
|
|
|
|
|
class TestSearchEngine:
    """Exercise ``SearchEngine.search`` against mocked collaborators."""

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Build a SearchEngine from the settings, database and embeddings fixtures."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    @pytest.fixture
    def sample_db_results(self):
        """Return two ``(KnowledgeEmbedding, float)`` tuples used as canned database rows."""
        from models import ChunkType, FileType, KnowledgeEmbedding

        # (id, content, embedding fill value, source path, paired float)
        specs = (
            ("id-1", "def hello(): pass", 0.1, "/test/file.py", 0.95),
            ("id-2", "def world(): pass", 0.2, "/test/file2.py", 0.85),
        )
        return [
            (
                KnowledgeEmbedding(
                    id=row_id,
                    project_id="proj-123",
                    collection="default",
                    content=content,
                    embedding=[fill] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path=source_path,
                    file_type=FileType.PYTHON,
                    created_at=datetime.now(UTC),
                    updated_at=datetime.now(UTC),
                ),
                paired_value,
            )
            for row_id, content, fill, source_path, paired_value in specs
        ]

    @pytest.mark.asyncio
    async def test_semantic_search(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """A SEMANTIC request returns both canned rows and queries the database once."""
        from models import SearchType

        db = search_engine._database
        sample_search_request.search_type = SearchType.SEMANTIC
        db.semantic_search.return_value = sample_db_results

        outcome = await search_engine.search(sample_search_request)

        assert outcome.search_type == "semantic"
        assert len(outcome.results) == 2
        assert outcome.results[0].score == 0.95
        db.semantic_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_keyword_search(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """A KEYWORD request returns both canned rows and queries the database once."""
        from models import SearchType

        db = search_engine._database
        sample_search_request.search_type = SearchType.KEYWORD
        db.keyword_search.return_value = sample_db_results

        outcome = await search_engine.search(sample_search_request)

        assert outcome.search_type == "keyword"
        assert len(outcome.results) == 2
        db.keyword_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_hybrid_search(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """A HYBRID request yields at least one result when both backends respond."""
        from models import SearchType

        db = search_engine._database
        sample_search_request.search_type = SearchType.HYBRID

        # Feed the same rows to both backends to keep the setup simple.
        db.semantic_search.return_value = sample_db_results
        db.keyword_search.return_value = sample_db_results

        outcome = await search_engine.search(sample_search_request)

        assert outcome.search_type == "hybrid"
        # The two result lists are expected to be fused into a non-empty list.
        assert len(outcome.results) >= 1

    @pytest.mark.asyncio
    async def test_search_with_collection_filter(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """The request's collection is forwarded to the database as a keyword argument."""
        from models import SearchType

        db = search_engine._database
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.collection = "specific-collection"
        db.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Inspect the keyword arguments of the recorded database call.
        forwarded = db.semantic_search.call_args.kwargs
        assert forwarded["collection"] == "specific-collection"

    @pytest.mark.asyncio
    async def test_search_with_file_type_filter(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """The request's file types are forwarded to the database as a keyword argument."""
        from models import FileType, SearchType

        db = search_engine._database
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.file_types = [FileType.PYTHON]
        db.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Inspect the keyword arguments of the recorded database call.
        forwarded = db.semantic_search.call_args.kwargs
        assert forwarded["file_types"] == [FileType.PYTHON]

    @pytest.mark.asyncio
    async def test_search_respects_limit(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """With limit=1 and a single canned row, at most one result comes back."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.limit = 1
        search_engine._database.semantic_search.return_value = sample_db_results[:1]

        outcome = await search_engine.search(sample_search_request)

        assert len(outcome.results) <= 1

    @pytest.mark.asyncio
    async def test_search_records_time(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """The response reports a strictly positive ``search_time_ms``."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        outcome = await search_engine.search(sample_search_request)

        assert outcome.search_time_ms > 0

    @pytest.mark.asyncio
    async def test_invalid_search_type(self, search_engine, sample_search_request):
        """An unrecognised search type makes ``search`` raise."""
        from exceptions import InvalidSearchTypeError

        # Bypass the enum by assigning a raw string.
        sample_search_request.search_type = "invalid"

        acceptable_errors = (InvalidSearchTypeError, ValueError)
        with pytest.raises(acceptable_errors):
            await search_engine.search(sample_search_request)

    @pytest.mark.asyncio
    async def test_empty_results(self, search_engine, sample_search_request):
        """An empty database answer produces an empty response with a zero total."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = []

        outcome = await search_engine.search(sample_search_request)

        assert len(outcome.results) == 0
        assert outcome.total_results == 0
|
|
|
|
|
|
class TestReciprocalRankFusion:
    """Tests for ``SearchEngine._reciprocal_rank_fusion``."""

    @staticmethod
    def _result(result_id, score):
        """Build a code-chunk SearchResult in the "default" collection.

        The content is the upper-cased id, matching the literals the tests
        previously spelled out by hand ("a" -> "A").
        """
        from models import SearchResult

        return SearchResult(
            id=result_id,
            content=result_id.upper(),
            score=score,
            chunk_type="code",
            collection="default",
        )

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create search engine with mocks."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    def test_fusion_combines_results(self, search_engine):
        """Test that RRF combines results from both searches."""
        semantic = [self._result("a", 0.9), self._result("b", 0.8)]
        keyword = [self._result("b", 0.85), self._result("c", 0.7)]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # Every input id must be present ...
        ids = [r.id for r in fused]
        assert "a" in ids
        assert "b" in ids
        assert "c" in ids
        # ... and "b", which appears in both inputs, must not be duplicated.
        assert len(ids) == len(set(ids))

        # B should be ranked higher (appears in both)
        b_rank = ids.index("b")
        assert b_rank < 2  # Should be in top 2

    def test_fusion_respects_weights(self, search_engine):
        """Test that RRF still returns the result under skewed weights."""
        # The same single result is supplied on one side only in each call.
        results = [self._result("a", 0.9)]

        # High semantic weight
        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            results,
            [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )

        # High keyword weight
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            [],
            results,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )

        # Both should still return the result, and it must be the one supplied.
        assert len(fused_semantic_heavy) == 1
        assert len(fused_keyword_heavy) == 1
        assert fused_semantic_heavy[0].id == "a"
        assert fused_keyword_heavy[0].id == "a"

    def test_fusion_normalizes_scores(self, search_engine):
        """Test that RRF normalizes scores to 0-1."""
        semantic = [self._result("a", 0.9), self._result("b", 0.8)]
        keyword = [self._result("c", 0.7)]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # All scores should be between 0 and 1
        for result in fused:
            assert 0 <= result.score <= 1
|
|
|
|
|
|
class TestGlobalSearchEngine:
    """Tests for the module-level ``get_search_engine`` / ``reset_search_engine`` pair.

    Both tests touch process-wide state, so each one resets the global engine
    on the way out to avoid leaking an instance into later tests.
    """

    def test_get_search_engine_singleton(self):
        """Test that get_search_engine returns singleton."""
        from search import get_search_engine, reset_search_engine

        # Start from a clean slate so an earlier test's engine is not reused.
        reset_search_engine()
        try:
            engine1 = get_search_engine()
            engine2 = get_search_engine()

            assert engine1 is engine2
        finally:
            reset_search_engine()

    def test_reset_search_engine(self):
        """Test resetting search engine."""
        from search import get_search_engine, reset_search_engine

        try:
            engine1 = get_search_engine()
            reset_search_engine()
            engine2 = get_search_engine()

            assert engine1 is not engine2
        finally:
            # Drop engine2 so it does not outlive this test.
            reset_search_engine()
|