"""Tests for collection manager module.""" from datetime import UTC, datetime from unittest.mock import MagicMock import pytest class TestCollectionManager: """Tests for CollectionManager class.""" @pytest.fixture def collection_manager(self, settings, mock_database, mock_embeddings): """Create collection manager with mocks.""" from chunking.base import ChunkerFactory from collection_manager import CollectionManager mock_chunker_factory = MagicMock(spec=ChunkerFactory) # Mock chunk_content to return chunks from models import Chunk, ChunkType mock_chunker_factory.chunk_content.return_value = [ Chunk( content="def hello(): pass", chunk_type=ChunkType.CODE, token_count=10, ) ] return CollectionManager( settings=settings, database=mock_database, embeddings=mock_embeddings, chunker_factory=mock_chunker_factory, ) @pytest.mark.asyncio async def test_ingest_content(self, collection_manager, sample_ingest_request): """Test content ingestion.""" result = await collection_manager.ingest(sample_ingest_request) assert result.success is True assert result.chunks_created == 1 assert result.embeddings_generated == 1 assert len(result.chunk_ids) == 1 assert result.collection == "default" @pytest.mark.asyncio async def test_ingest_empty_content(self, collection_manager): """Test ingesting empty content.""" from models import IngestRequest # Mock chunker to return empty list collection_manager._chunker_factory.chunk_content.return_value = [] request = IngestRequest( project_id="proj-123", agent_id="agent-456", content="", ) result = await collection_manager.ingest(request) assert result.success is True assert result.chunks_created == 0 assert result.embeddings_generated == 0 @pytest.mark.asyncio async def test_ingest_error_handling( self, collection_manager, sample_ingest_request ): """Test ingest error handling.""" # Make embedding generation fail collection_manager._embeddings.generate_batch.side_effect = Exception( "Embedding error" ) result = await collection_manager.ingest(sample_ingest_request) assert result.success is False assert "Embedding error" in result.error @pytest.mark.asyncio async def test_delete_by_source(self, collection_manager, sample_delete_request): """Test deletion by source path.""" result = await collection_manager.delete(sample_delete_request) assert result.success is True assert result.chunks_deleted == 1 # Mock returns 1 collection_manager._database.delete_by_source.assert_called_once() @pytest.mark.asyncio async def test_delete_by_collection(self, collection_manager): """Test deletion by collection.""" from models import DeleteRequest request = DeleteRequest( project_id="proj-123", agent_id="agent-456", collection="to-delete", ) result = await collection_manager.delete(request) assert result.success is True collection_manager._database.delete_collection.assert_called_once() @pytest.mark.asyncio async def test_delete_by_ids(self, collection_manager): """Test deletion by chunk IDs.""" from models import DeleteRequest request = DeleteRequest( project_id="proj-123", agent_id="agent-456", chunk_ids=["id-1", "id-2"], ) result = await collection_manager.delete(request) assert result.success is True collection_manager._database.delete_by_ids.assert_called_once() @pytest.mark.asyncio async def test_delete_no_target(self, collection_manager): """Test deletion with no target specified.""" from models import DeleteRequest request = DeleteRequest( project_id="proj-123", agent_id="agent-456", ) result = await collection_manager.delete(request) assert result.success is False assert "Must specify" in result.error @pytest.mark.asyncio async def test_list_collections(self, collection_manager): """Test listing collections.""" from models import CollectionInfo collection_manager._database.list_collections.return_value = [ CollectionInfo( name="collection-1", project_id="proj-123", chunk_count=100, total_tokens=50000, file_types=["python"], created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ), CollectionInfo( name="collection-2", project_id="proj-123", chunk_count=50, total_tokens=25000, file_types=["javascript"], created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ), ] result = await collection_manager.list_collections("proj-123") assert result.project_id == "proj-123" assert result.total_collections == 2 assert len(result.collections) == 2 @pytest.mark.asyncio async def test_get_collection_stats(self, collection_manager): """Test getting collection statistics.""" from models import CollectionStats expected_stats = CollectionStats( collection="test-collection", project_id="proj-123", chunk_count=100, unique_sources=10, total_tokens=50000, avg_chunk_size=500.0, chunk_types={"code": 60, "text": 40}, file_types={"python": 50, "javascript": 10}, ) collection_manager._database.get_collection_stats.return_value = expected_stats stats = await collection_manager.get_collection_stats( "proj-123", "test-collection" ) assert stats.chunk_count == 100 assert stats.unique_sources == 10 collection_manager._database.get_collection_stats.assert_called_once_with( "proj-123", "test-collection" ) @pytest.mark.asyncio async def test_update_document(self, collection_manager): """Test updating a document with atomic replace.""" result = await collection_manager.update_document( project_id="proj-123", agent_id="agent-456", source_path="/test/file.py", content="def updated(): pass", collection="default", ) # Should use atomic replace (delete + insert in transaction) collection_manager._database.replace_source_embeddings.assert_called_once() assert result.success is True assert len(result.chunk_ids) == 1 @pytest.mark.asyncio async def test_cleanup_expired(self, collection_manager): """Test cleaning up expired embeddings.""" collection_manager._database.cleanup_expired.return_value = 10 count = await collection_manager.cleanup_expired() assert count == 10 collection_manager._database.cleanup_expired.assert_called_once() class TestGlobalCollectionManager: """Tests for global collection manager.""" def test_get_collection_manager_singleton(self): """Test that get_collection_manager returns singleton.""" from collection_manager import get_collection_manager, reset_collection_manager reset_collection_manager() manager1 = get_collection_manager() manager2 = get_collection_manager() assert manager1 is manager2 def test_reset_collection_manager(self): """Test resetting collection manager.""" from collection_manager import get_collection_manager, reset_collection_manager manager1 = get_collection_manager() reset_collection_manager() manager2 = get_collection_manager() assert manager1 is not manager2