From cd7a9ccbdf1d6534073a8bf70e5392cd48350acc Mon Sep 17 00:00:00 2001
From: Felipe Cardoso
Date: Sun, 4 Jan 2026 01:07:40 +0100
Subject: [PATCH] fix(mcp-kb): add transactional batch insert and atomic
 document update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Wrap store_embeddings_batch in a transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use the transactional replace
- Prevent race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
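
Note for reviewers (placed below the ---, so git-am drops it): a minimal
sketch of the call path this patch changes, not code from the diff itself.
The `manager` instance and all argument values are made up for illustration;
only the parameter names come from update_document's signature below.

    # Assumes an initialized CollectionManager `manager`.
    result = await manager.update_document(
        project_id="proj-123",
        agent_id="agent-456",
        source_path="docs/guide.md",
        content="# Updated guide\n...",
        collection="default",
    )
    # The delete of the old chunks and the insert of the new ones now run
    # in a single transaction, so a concurrent reader sees either the old
    # document or the new one, never a half-replaced state.
    assert result.success
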
 .../knowledge-base/collection_manager.py      |  83 ++++++++--
 mcp-servers/knowledge-base/database.py        | 154 ++++++++++++++----
 mcp-servers/knowledge-base/tests/conftest.py  |   1 +
 .../tests/test_collection_manager.py          |   7 +-
 4 files changed, 195 insertions(+), 50 deletions(-)

diff --git a/mcp-servers/knowledge-base/collection_manager.py b/mcp-servers/knowledge-base/collection_manager.py
index e708ab0..25083a5 100644
--- a/mcp-servers/knowledge-base/collection_manager.py
+++ b/mcp-servers/knowledge-base/collection_manager.py
@@ -265,9 +265,10 @@ class CollectionManager:
         metadata: dict[str, Any] | None = None,
     ) -> IngestResult:
         """
-        Update a document by replacing existing chunks.
+        Update a document by atomically replacing existing chunks.
 
-        Deletes existing chunks for the source path and ingests new content.
+        Uses a database transaction to delete existing chunks and insert new ones
+        atomically, preventing race conditions during concurrent updates.
 
         Args:
             project_id: Project ID
@@ -282,26 +283,76 @@ class CollectionManager:
         Returns:
             Ingest result
         """
-        # First delete existing chunks for this source
-        await self.database.delete_by_source(
-            project_id=project_id,
-            source_path=source_path,
-            collection=collection,
-        )
+        request_metadata = metadata or {}
 
-        # Then ingest new content
-        request = IngestRequest(
-            project_id=project_id,
-            agent_id=agent_id,
+        # Chunk the content
+        chunks = self.chunker_factory.chunk_content(
             content=content,
             source_path=source_path,
-            collection=collection,
-            chunk_type=chunk_type,
             file_type=file_type,
-            metadata=metadata or {},
+            chunk_type=chunk_type,
+            metadata=request_metadata,
         )
 
-        return await self.ingest(request)
+        if not chunks:
+            # No chunks: delete existing embeddings and return an empty result
+            await self.database.delete_by_source(
+                project_id=project_id,
+                source_path=source_path,
+                collection=collection,
+            )
+            return IngestResult(
+                success=True,
+                chunks_created=0,
+                embeddings_generated=0,
+                source_path=source_path,
+                collection=collection,
+                chunk_ids=[],
+            )
+
+        # Generate embeddings for new chunks
+        chunk_texts = [chunk.content for chunk in chunks]
+        embeddings_list = await self.embeddings.generate_batch(
+            texts=chunk_texts,
+            project_id=project_id,
+            agent_id=agent_id,
+        )
+
+        # Build embeddings data for transactional replace
+        embeddings_data = []
+        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
+            chunk_metadata = {
+                **request_metadata,
+                **chunk.metadata,
+                "token_count": chunk.token_count,
+                "source_path": chunk.source_path or source_path,
+                "start_line": chunk.start_line,
+                "end_line": chunk.end_line,
+                "file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
+            }
+            embeddings_data.append((
+                chunk.content,
+                embedding,
+                chunk.chunk_type,
+                chunk_metadata,
+            ))
+
+        # Atomically replace old embeddings with new ones
+        _, chunk_ids = await self.database.replace_source_embeddings(
+            project_id=project_id,
+            source_path=source_path,
+            collection=collection,
+            embeddings=embeddings_data,
+        )
+
+        return IngestResult(
+            success=True,
+            chunks_created=len(chunk_ids),
+            embeddings_generated=len(embeddings_list),
+            source_path=source_path,
+            collection=collection,
+            chunk_ids=chunk_ids,
+        )
 
     async def cleanup_expired(self) -> int:
         """
diff --git a/mcp-servers/knowledge-base/database.py b/mcp-servers/knowledge-base/database.py
index 2db9fba..e28f00f 100644
--- a/mcp-servers/knowledge-base/database.py
+++ b/mcp-servers/knowledge-base/database.py
@@ -285,38 +285,40 @@ class DatabaseManager:
 
         try:
             async with self.acquire() as conn:
-                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
-                    content_hash = self.compute_content_hash(content)
-                    source_path = metadata.get("source_path")
-                    start_line = metadata.get("start_line")
-                    end_line = metadata.get("end_line")
-                    file_type = metadata.get("file_type")
+                # Wrap in a transaction for all-or-nothing batch semantics
+                async with conn.transaction():
+                    for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        source_path = metadata.get("source_path")
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
 
-                    embedding_id = await conn.fetchval(
-                        """
-                        INSERT INTO knowledge_embeddings
-                        (project_id, collection, content, embedding, chunk_type,
-                         source_path, start_line, end_line, file_type, metadata,
-                         content_hash, expires_at)
-                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
-                        ON CONFLICT DO NOTHING
-                        RETURNING id
-                        """,
-                        project_id,
-                        collection,
-                        content,
-                        embedding,
-                        chunk_type.value,
-                        source_path,
-                        start_line,
-                        end_line,
-                        file_type,
-                        metadata,
-                        content_hash,
-                        expires_at,
-                    )
-                    if embedding_id:
-                        ids.append(str(embedding_id))
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            ON CONFLICT DO NOTHING
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+                        if embedding_id:
+                            ids.append(str(embedding_id))
 
             logger.info(f"Stored {len(ids)} embeddings in batch")
             return ids
@@ -535,6 +537,96 @@ class DatabaseManager:
                 cause=e,
             )
 
+    async def replace_source_embeddings(
+        self,
+        project_id: str,
+        source_path: str,
+        collection: str,
+        embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
+    ) -> tuple[int, list[str]]:
+        """
+        Atomically replace all embeddings for a source path.
+
+        Deletes existing embeddings and inserts new ones in a single transaction,
+        preventing race conditions during document updates.
+
+        Args:
+            project_id: Project ID
+            source_path: Source file path being updated
+            collection: Collection name
+            embeddings: List of (content, embedding, chunk_type, metadata)
+
+        Returns:
+            Tuple of (deleted_count, new_embedding_ids)
+        """
+        expires_at = None
+        if self._settings.embedding_ttl_days > 0:
+            expires_at = datetime.now(UTC) + timedelta(
+                days=self._settings.embedding_ttl_days
+            )
+
+        try:
+            async with self.acquire() as conn:
+                # Use transaction for atomic replace
+                async with conn.transaction():
+                    # First, delete existing embeddings for this source
+                    delete_result = await conn.execute(
+                        """
+                        DELETE FROM knowledge_embeddings
+                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
+                        """,
+                        project_id,
+                        source_path,
+                        collection,
+                    )
+                    deleted_count = int(delete_result.split()[-1])
+
+                    # Then insert new embeddings
+                    new_ids = []
+                    for content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
+
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+                        if embedding_id:
+                            new_ids.append(str(embedding_id))
+
+            logger.info(
+                f"Replaced source {source_path}: deleted {deleted_count}, "
+                f"inserted {len(new_ids)} embeddings"
+            )
+            return deleted_count, new_ids
+
+        except asyncpg.PostgresError as e:
+            logger.error(f"Replace source error: {e}")
+            raise DatabaseQueryError(
+                message=f"Failed to replace source embeddings: {e}",
+                cause=e,
+            )
+
     async def delete_collection(
         self,
         project_id: str,
diff --git a/mcp-servers/knowledge-base/tests/conftest.py b/mcp-servers/knowledge-base/tests/conftest.py
index 55fc922..7c68f09 100644
--- a/mcp-servers/knowledge-base/tests/conftest.py
+++ b/mcp-servers/knowledge-base/tests/conftest.py
@@ -61,6 +61,7 @@ def mock_database():
     mock_db.delete_by_source = AsyncMock(return_value=1)
     mock_db.delete_collection = AsyncMock(return_value=5)
     mock_db.delete_by_ids = AsyncMock(return_value=2)
+    mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))
     mock_db.list_collections = AsyncMock(return_value=[])
     mock_db.get_collection_stats = AsyncMock()
     mock_db.cleanup_expired = AsyncMock(return_value=0)
diff --git a/mcp-servers/knowledge-base/tests/test_collection_manager.py b/mcp-servers/knowledge-base/tests/test_collection_manager.py
index 41e91fe..95a201f 100644
--- a/mcp-servers/knowledge-base/tests/test_collection_manager.py
+++ b/mcp-servers/knowledge-base/tests/test_collection_manager.py
@@ -192,7 +192,7 @@ class TestCollectionManager:
 
     @pytest.mark.asyncio
     async def test_update_document(self, collection_manager):
-        """Test updating a document."""
+        """Test updating a document with atomic replace."""
         result = await collection_manager.update_document(
             project_id="proj-123",
             agent_id="agent-456",
@@ -201,9 +201,10 @@ class TestCollectionManager:
             collection="default",
         )
 
-        # Should delete first, then ingest
-        collection_manager._database.delete_by_source.assert_called_once()
+        # Should use atomic replace (delete + insert in transaction)
+        collection_manager._database.replace_source_embeddings.assert_called_once()
         assert result.success is True
+        assert len(result.chunk_ids) == 1
 
     @pytest.mark.asyncio
     async def test_cleanup_expired(self, collection_manager):
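
Testing note: the mocked unit test above only proves the new method gets
called; the all-or-nothing behavior itself needs a live database. A rough
integration sketch follows, with assumptions flagged inline (the import
path, the `db` fixture, ChunkType.TEXT, and the 1536-dim vector are all
guesses, and the None embedding is just one way to force the INSERT to
fail mid-batch, assuming the embedding column is NOT NULL):

    import pytest

    from database import ChunkType, DatabaseQueryError  # import path assumed

    @pytest.mark.asyncio
    async def test_replace_rolls_back_on_failure(db):  # `db` = live DatabaseManager fixture
        batch = [
            ("chunk one", [0.1] * 1536, ChunkType.TEXT, {}),
            ("chunk two", None, ChunkType.TEXT, {}),  # invalid embedding -> INSERT fails
        ]
        with pytest.raises(DatabaseQueryError):
            await db.replace_source_embeddings(
                project_id="proj-123",
                source_path="docs/guide.md",
                collection="default",
                embeddings=batch,
            )
        # Because the DELETE and the INSERTs share one transaction, the
        # failure must also roll back the delete: chunks stored earlier for
        # docs/guide.md should still be queryable at this point.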