"""Tests for search module."""

from datetime import UTC, datetime

import pytest


def _make_embedding(embedding_id, content, fill, source_path):
    """Build a Python code-chunk ``KnowledgeEmbedding`` for project ``proj-123``.

    Args:
        embedding_id: Value for the ``id`` field.
        content: Chunk text.
        fill: Value repeated 1536 times to form the embedding vector.
        source_path: Value for the ``source_path`` field.
    """
    # Project imports stay function-local, like every other project import
    # in this module.
    from models import ChunkType, FileType, KnowledgeEmbedding

    now = datetime.now(UTC)
    return KnowledgeEmbedding(
        id=embedding_id,
        project_id="proj-123",
        collection="default",
        content=content,
        embedding=[fill] * 1536,
        chunk_type=ChunkType.CODE,
        source_path=source_path,
        file_type=FileType.PYTHON,
        created_at=now,
        updated_at=now,
    )


def _make_result(result_id, score):
    """Build a code-chunk ``SearchResult`` in the ``default`` collection.

    The content is the upper-cased id (``"a"`` -> ``"A"``).
    """
    from models import SearchResult

    return SearchResult(
        id=result_id,
        content=result_id.upper(),
        score=score,
        chunk_type="code",
        collection="default",
    )


@pytest.fixture
def search_engine(settings, mock_database, mock_embeddings):
    """Create a search engine wired to the mocked database and embeddings.

    Shared by every test class in this module that needs an engine.
    """
    from search import SearchEngine

    return SearchEngine(
        settings=settings,
        database=mock_database,
        embeddings=mock_embeddings,
    )


class TestSearchEngine:
    """Tests for SearchEngine class."""

    @pytest.fixture
    def sample_db_results(self):
        """Create sample database results as ``(embedding, score)`` pairs."""
        return [
            (
                _make_embedding("id-1", "def hello(): pass", 0.1, "/test/file.py"),
                0.95,
            ),
            (
                _make_embedding("id-2", "def world(): pass", 0.2, "/test/file2.py"),
                0.85,
            ),
        ]

    @pytest.mark.asyncio
    async def test_semantic_search(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test semantic search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "semantic"
        assert len(response.results) == 2
        # Scores are floats; compare with a tolerance rather than exactly.
        assert response.results[0].score == pytest.approx(0.95)
        search_engine._database.semantic_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_keyword_search(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test keyword search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.KEYWORD
        search_engine._database.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "keyword"
        assert len(response.results) == 2
        search_engine._database.keyword_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_hybrid_search(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test hybrid search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.HYBRID
        # Both searches return same results for simplicity
        search_engine._database.semantic_search.return_value = sample_db_results
        search_engine._database.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "hybrid"
        # Results should be fused
        assert len(response.results) >= 1

    @pytest.mark.asyncio
    async def test_search_with_collection_filter(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test search with collection filter."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.collection = "specific-collection"
        search_engine._database.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Verify collection was passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["collection"] == "specific-collection"

    @pytest.mark.asyncio
    async def test_search_with_file_type_filter(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test search with file type filter."""
        from models import FileType, SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.file_types = [FileType.PYTHON]
        search_engine._database.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Verify file types were passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["file_types"] == [FileType.PYTHON]

    @pytest.mark.asyncio
    async def test_search_respects_limit(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test that search respects result limit."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.limit = 1
        # NOTE(review): the mock already returns a single row, so this only
        # checks that the engine does not add results; it does not prove the
        # engine enforces or forwards the limit. Confirm where truncation is
        # meant to happen before tightening this test.
        search_engine._database.semantic_search.return_value = sample_db_results[:1]

        response = await search_engine.search(sample_search_request)

        assert len(response.results) <= 1

    @pytest.mark.asyncio
    async def test_search_records_time(
        self, search_engine, sample_search_request, sample_db_results
    ):
        """Test that search records time."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_time_ms > 0

    @pytest.mark.asyncio
    async def test_invalid_search_type(self, search_engine, sample_search_request):
        """Test handling invalid search type."""
        from exceptions import InvalidSearchTypeError

        # Force invalid search type
        sample_search_request.search_type = "invalid"

        with pytest.raises((InvalidSearchTypeError, ValueError)):
            await search_engine.search(sample_search_request)

    @pytest.mark.asyncio
    async def test_empty_results(self, search_engine, sample_search_request):
        """Test search with no results."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = []

        response = await search_engine.search(sample_search_request)

        assert len(response.results) == 0
        assert response.total_results == 0


class TestReciprocalRankFusion:
    """Tests for reciprocal rank fusion."""

    def test_fusion_combines_results(self, search_engine):
        """Test that RRF combines results from both searches."""
        semantic = [_make_result("a", 0.9), _make_result("b", 0.8)]
        keyword = [_make_result("b", 0.85), _make_result("c", 0.7)]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # Should have all unique results
        ids = [r.id for r in fused]
        assert "a" in ids
        assert "b" in ids
        assert "c" in ids

        # B should be ranked higher (appears in both)
        b_rank = ids.index("b")
        assert b_rank < 2  # Should be in top 2

    def test_fusion_respects_weights(self, search_engine):
        """Test that RRF still returns a lone result under skewed weights."""
        # Same single result fed through each side in turn
        results = [_make_result("a", 0.9)]

        # High semantic weight
        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            results,
            [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )

        # High keyword weight
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            [],
            results,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )

        # Both should still return the result, and it must be the one supplied
        assert len(fused_semantic_heavy) == 1
        assert len(fused_keyword_heavy) == 1
        assert fused_semantic_heavy[0].id == "a"
        assert fused_keyword_heavy[0].id == "a"

    def test_fusion_normalizes_scores(self, search_engine):
        """Test that RRF normalizes scores to 0-1."""
        semantic = [_make_result("a", 0.9), _make_result("b", 0.8)]
        keyword = [_make_result("c", 0.7)]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # All scores should be between 0 and 1
        for result in fused:
            assert 0 <= result.score <= 1


class TestGlobalSearchEngine:
    """Tests for global search engine."""

    @pytest.fixture(autouse=True)
    def _fresh_global_engine(self):
        """Reset the global engine before and after each test in this class.

        Without this, the tests start from whatever instance an earlier test
        left behind and leave their own instance for later tests.
        """
        from search import reset_search_engine

        reset_search_engine()
        yield
        reset_search_engine()

    def test_get_search_engine_singleton(self):
        """Test that get_search_engine returns singleton."""
        from search import get_search_engine

        engine1 = get_search_engine()
        engine2 = get_search_engine()

        assert engine1 is engine2

    def test_reset_search_engine(self):
        """Test resetting search engine."""
        from search import get_search_engine, reset_search_engine

        engine1 = get_search_engine()
        reset_search_engine()
        engine2 = get_search_engine()

        assert engine1 is not engine2