Files
syndarix/mcp-servers/knowledge-base/tests/test_collection_manager.py
Felipe Cardoso cd7a9ccbdf fix(mcp-kb): add transactional batch insert and atomic document update
- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:07:40 +01:00

242 lines
8.1 KiB
Python

"""Tests for collection manager module."""
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
class TestCollectionManager:
    """Tests for CollectionManager class.

    Every test runs against a ``CollectionManager`` whose database, embeddings
    and chunker collaborators are mocks, so only the manager's orchestration
    logic (and the results it builds) is exercised here.
    """

    @pytest.fixture
    def collection_manager(self, settings, mock_database, mock_embeddings):
        """Create a CollectionManager wired to mocked collaborators.

        ``settings``, ``mock_database`` and ``mock_embeddings`` are fixtures
        defined outside this file (presumably in ``conftest.py``). The chunker
        factory is mocked here so that, unless a test overrides
        ``chunk_content.return_value``, any content is split into exactly one
        code chunk.
        """
        from chunking.base import ChunkerFactory
        from collection_manager import CollectionManager
        from models import Chunk, ChunkType

        mock_chunker_factory = MagicMock(spec=ChunkerFactory)
        # One deterministic chunk: the "== 1" count assertions below rely on it.
        mock_chunker_factory.chunk_content.return_value = [
            Chunk(
                content="def hello(): pass",
                chunk_type=ChunkType.CODE,
                token_count=10,
            )
        ]

        return CollectionManager(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
            chunker_factory=mock_chunker_factory,
        )

    @pytest.mark.asyncio
    async def test_ingest_content(self, collection_manager, sample_ingest_request):
        """Test content ingestion."""
        result = await collection_manager.ingest(sample_ingest_request)

        assert result.success is True
        assert result.chunks_created == 1
        assert result.embeddings_generated == 1
        assert len(result.chunk_ids) == 1
        assert result.collection == "default"

    @pytest.mark.asyncio
    async def test_ingest_empty_content(self, collection_manager):
        """Test ingesting empty content succeeds and produces nothing."""
        from models import IngestRequest

        # Empty content yields no chunks from the chunker.
        collection_manager._chunker_factory.chunk_content.return_value = []

        request = IngestRequest(
            project_id="proj-123",
            agent_id="agent-456",
            content="",
        )

        result = await collection_manager.ingest(request)

        assert result.success is True
        assert result.chunks_created == 0
        assert result.embeddings_generated == 0

    @pytest.mark.asyncio
    async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
        """Test that an embedding failure is reported in the result, not raised."""
        collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")

        result = await collection_manager.ingest(sample_ingest_request)

        assert result.success is False
        assert "Embedding error" in result.error

    @pytest.mark.asyncio
    async def test_delete_by_source(self, collection_manager, sample_delete_request):
        """Test deletion by source path."""
        result = await collection_manager.delete(sample_delete_request)

        assert result.success is True
        assert result.chunks_deleted == 1  # Mock returns 1
        collection_manager._database.delete_by_source.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_collection(self, collection_manager):
        """Test deletion by collection."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            collection="to-delete",
        )

        result = await collection_manager.delete(request)

        assert result.success is True
        collection_manager._database.delete_collection.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_ids(self, collection_manager):
        """Test deletion by chunk IDs."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id-1", "id-2"],
        )

        result = await collection_manager.delete(request)

        assert result.success is True
        collection_manager._database.delete_by_ids.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_no_target(self, collection_manager):
        """Test deletion with no target specified is rejected."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
        )

        result = await collection_manager.delete(request)

        assert result.success is False
        assert "Must specify" in result.error

    @pytest.mark.asyncio
    async def test_list_collections(self, collection_manager):
        """Test listing collections."""
        from models import CollectionInfo

        # One shared timestamp keeps the fixture data deterministic.
        now = datetime.now(UTC)
        collection_manager._database.list_collections.return_value = [
            CollectionInfo(
                name="collection-1",
                project_id="proj-123",
                chunk_count=100,
                total_tokens=50000,
                file_types=["python"],
                created_at=now,
                updated_at=now,
            ),
            CollectionInfo(
                name="collection-2",
                project_id="proj-123",
                chunk_count=50,
                total_tokens=25000,
                file_types=["javascript"],
                created_at=now,
                updated_at=now,
            ),
        ]

        result = await collection_manager.list_collections("proj-123")

        assert result.project_id == "proj-123"
        assert result.total_collections == 2
        assert len(result.collections) == 2

    @pytest.mark.asyncio
    async def test_get_collection_stats(self, collection_manager):
        """Test getting collection statistics."""
        from models import CollectionStats

        expected_stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
        )
        collection_manager._database.get_collection_stats.return_value = expected_stats

        stats = await collection_manager.get_collection_stats("proj-123", "test-collection")

        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        collection_manager._database.get_collection_stats.assert_called_once_with(
            "proj-123", "test-collection"
        )

    @pytest.mark.asyncio
    async def test_update_document(self, collection_manager):
        """Test updating a document with atomic replace."""
        result = await collection_manager.update_document(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
        )

        # Should use atomic replace (delete + insert in one transaction)...
        collection_manager._database.replace_source_embeddings.assert_called_once()
        # ...and not a separate, non-transactional delete that could leave the
        # source half-updated if the subsequent insert failed.
        collection_manager._database.delete_by_source.assert_not_called()
        assert result.success is True
        assert len(result.chunk_ids) == 1

    @pytest.mark.asyncio
    async def test_cleanup_expired(self, collection_manager):
        """Test cleaning up expired embeddings."""
        collection_manager._database.cleanup_expired.return_value = 10

        count = await collection_manager.cleanup_expired()

        assert count == 10
        collection_manager._database.cleanup_expired.assert_called_once()
class TestGlobalCollectionManager:
    """Tests for global collection manager."""

    @pytest.fixture(autouse=True)
    def _isolate_singleton(self):
        """Reset the global collection manager before and after each test.

        ``get_collection_manager()`` caches a module-level instance. Without
        this fixture, each test would start from whatever instance an earlier
        test created. The instance created here would also leak into later
        tests.
        """
        from collection_manager import reset_collection_manager

        reset_collection_manager()
        yield
        reset_collection_manager()

    def test_get_collection_manager_singleton(self):
        """Test that get_collection_manager returns singleton."""
        from collection_manager import get_collection_manager

        manager1 = get_collection_manager()
        manager2 = get_collection_manager()

        assert manager1 is manager2

    def test_reset_collection_manager(self):
        """Test that resetting makes the next call build a new instance."""
        from collection_manager import get_collection_manager, reset_collection_manager

        manager1 = get_collection_manager()
        reset_collection_manager()
        manager2 = get_collection_manager()

        assert manager1 is not manager2