fix(mcp-kb): add transactional batch insert and atomic document update

- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:07:40 +01:00
parent 953af52d0e
commit cd7a9ccbdf
4 changed files with 195 additions and 50 deletions
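
Both changes lean on the same asyncpg primitive: `Connection.transaction()`, an async context manager that commits on clean exit and rolls back if any statement inside raises. A minimal sketch of the delete-then-insert pattern, assuming an `asyncpg.Pool` and a hypothetical two-column `docs` table rather than the project's real twelve-column schema:

```python
import asyncpg

async def replace_rows(pool: asyncpg.Pool, source_path: str, contents: list[str]) -> None:
    """Sketch only: the DELETE and all INSERTs commit together or not at all."""
    async with pool.acquire() as conn:
        async with conn.transaction():  # BEGIN; COMMIT on exit, ROLLBACK on error
            await conn.execute(
                "DELETE FROM docs WHERE source_path = $1",
                source_path,
            )
            for content in contents:
                await conn.execute(
                    "INSERT INTO docs (source_path, content) VALUES ($1, $2)",
                    source_path,
                    content,
                )
```

Readers of the table never observe the window where the old rows are gone but the new ones are not yet committed, which is the inconsistency window the commit message refers to.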


@@ -265,9 +265,10 @@ class CollectionManager:
         metadata: dict[str, Any] | None = None,
     ) -> IngestResult:
         """
-        Update a document by replacing existing chunks.
+        Update a document by atomically replacing existing chunks.
 
-        Deletes existing chunks for the source path and ingests new content.
+        Uses a database transaction to delete existing chunks and insert new ones
+        atomically, preventing race conditions during concurrent updates.
 
         Args:
             project_id: Project ID
@@ -282,26 +283,76 @@ class CollectionManager:
         Returns:
             Ingest result
         """
-        # First delete existing chunks for this source
-        await self.database.delete_by_source(
-            project_id=project_id,
-            source_path=source_path,
-            collection=collection,
-        )
+        request_metadata = metadata or {}
 
-        # Then ingest new content
-        request = IngestRequest(
-            project_id=project_id,
-            agent_id=agent_id,
+        # Chunk the content
+        chunks = self.chunker_factory.chunk_content(
             content=content,
             source_path=source_path,
-            collection=collection,
-            chunk_type=chunk_type,
             file_type=file_type,
-            metadata=metadata or {},
+            chunk_type=chunk_type,
+            metadata=request_metadata,
         )
-        return await self.ingest(request)
+
+        if not chunks:
+            # No chunks = delete existing and return empty result
+            await self.database.delete_by_source(
+                project_id=project_id,
+                source_path=source_path,
+                collection=collection,
+            )
+            return IngestResult(
+                success=True,
+                chunks_created=0,
+                embeddings_generated=0,
+                source_path=source_path,
+                collection=collection,
+                chunk_ids=[],
+            )
+
+        # Generate embeddings for new chunks
+        chunk_texts = [chunk.content for chunk in chunks]
+        embeddings_list = await self.embeddings.generate_batch(
+            texts=chunk_texts,
+            project_id=project_id,
+            agent_id=agent_id,
+        )
+
+        # Build embeddings data for transactional replace
+        embeddings_data = []
+        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
+            chunk_metadata = {
+                **request_metadata,
+                **chunk.metadata,
+                "token_count": chunk.token_count,
+                "source_path": chunk.source_path or source_path,
+                "start_line": chunk.start_line,
+                "end_line": chunk.end_line,
+                "file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
+            }
+            embeddings_data.append((
+                chunk.content,
+                embedding,
+                chunk.chunk_type,
+                chunk_metadata,
+            ))
+
+        # Atomically replace old embeddings with new ones
+        _, chunk_ids = await self.database.replace_source_embeddings(
+            project_id=project_id,
+            source_path=source_path,
+            collection=collection,
+            embeddings=embeddings_data,
+        )
+        return IngestResult(
+            success=True,
+            chunks_created=len(chunk_ids),
+            embeddings_generated=len(embeddings_list),
+            source_path=source_path,
+            collection=collection,
+            chunk_ids=chunk_ids,
+        )
 
     async def cleanup_expired(self) -> int:
         """


@@ -285,38 +285,40 @@ class DatabaseManager:
         try:
             async with self.acquire() as conn:
-                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
-                    content_hash = self.compute_content_hash(content)
-                    source_path = metadata.get("source_path")
-                    start_line = metadata.get("start_line")
-                    end_line = metadata.get("end_line")
-                    file_type = metadata.get("file_type")
+                # Wrap in transaction for all-or-nothing batch semantics
+                async with conn.transaction():
+                    for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        source_path = metadata.get("source_path")
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
 
-                    embedding_id = await conn.fetchval(
-                        """
-                        INSERT INTO knowledge_embeddings
-                        (project_id, collection, content, embedding, chunk_type,
-                         source_path, start_line, end_line, file_type, metadata,
-                         content_hash, expires_at)
-                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
-                        ON CONFLICT DO NOTHING
-                        RETURNING id
-                        """,
-                        project_id,
-                        collection,
-                        content,
-                        embedding,
-                        chunk_type.value,
-                        source_path,
-                        start_line,
-                        end_line,
-                        file_type,
-                        metadata,
-                        content_hash,
-                        expires_at,
-                    )
-                    if embedding_id:
-                        ids.append(str(embedding_id))
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            ON CONFLICT DO NOTHING
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+                        if embedding_id:
+                            ids.append(str(embedding_id))
 
             logger.info(f"Stored {len(ids)} embeddings in batch")
             return ids
@@ -535,6 +537,96 @@ class DatabaseManager:
                 cause=e,
             )
 
+    async def replace_source_embeddings(
+        self,
+        project_id: str,
+        source_path: str,
+        collection: str,
+        embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
+    ) -> tuple[int, list[str]]:
+        """
+        Atomically replace all embeddings for a source path.
+
+        Deletes existing embeddings and inserts new ones in a single transaction,
+        preventing race conditions during document updates.
+
+        Args:
+            project_id: Project ID
+            source_path: Source file path being updated
+            collection: Collection name
+            embeddings: List of (content, embedding, chunk_type, metadata)
+
+        Returns:
+            Tuple of (deleted_count, new_embedding_ids)
+        """
+        expires_at = None
+        if self._settings.embedding_ttl_days > 0:
+            expires_at = datetime.now(UTC) + timedelta(
+                days=self._settings.embedding_ttl_days
+            )
+
+        try:
+            async with self.acquire() as conn:
+                # Use transaction for atomic replace
+                async with conn.transaction():
+                    # First, delete existing embeddings for this source
+                    delete_result = await conn.execute(
+                        """
+                        DELETE FROM knowledge_embeddings
+                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
+                        """,
+                        project_id,
+                        source_path,
+                        collection,
+                    )
+                    deleted_count = int(delete_result.split()[-1])
+
+                    # Then insert new embeddings
+                    new_ids = []
+                    for content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
+
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+                        if embedding_id:
+                            new_ids.append(str(embedding_id))
+
+            logger.info(
+                f"Replaced source {source_path}: deleted {deleted_count}, "
+                f"inserted {len(new_ids)} embeddings"
+            )
+            return deleted_count, new_ids
+        except asyncpg.PostgresError as e:
+            logger.error(f"Replace source error: {e}")
+            raise DatabaseQueryError(
+                message=f"Failed to replace source embeddings: {e}",
+                cause=e,
+            )
+
     async def delete_collection(
         self,
         project_id: str,
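
One detail worth flagging: asyncpg's `Connection.execute()` resolves to the server's command tag as a plain string, which is why `replace_source_embeddings` recovers the deleted-row count by taking the tag's last token:

```python
status = await conn.execute(
    "DELETE FROM knowledge_embeddings WHERE source_path = $1",
    source_path,
)
# The command tag looks like "DELETE 5"; its last token is the row count.
deleted_count = int(status.split()[-1])
```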


@@ -61,6 +61,7 @@ def mock_database():
     mock_db.delete_by_source = AsyncMock(return_value=1)
     mock_db.delete_collection = AsyncMock(return_value=5)
     mock_db.delete_by_ids = AsyncMock(return_value=2)
+    mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))
     mock_db.list_collections = AsyncMock(return_value=[])
     mock_db.get_collection_stats = AsyncMock()
     mock_db.cleanup_expired = AsyncMock(return_value=0)
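
The `(1, ["new-id-1"])` return value matches the real method's `(deleted_count, new_embedding_ids)` shape, so caller code that unpacks the tuple runs unchanged against the mock. A small sketch of what a test can assert (argument values are arbitrary, since `AsyncMock` accepts anything):

```python
deleted, ids = await mock_db.replace_source_embeddings(
    project_id="proj-123",
    source_path="docs/readme.md",
    collection="default",
    embeddings=[],
)
assert deleted == 1
assert ids == ["new-id-1"]
```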


@@ -192,7 +192,7 @@ class TestCollectionManager:
 
     @pytest.mark.asyncio
     async def test_update_document(self, collection_manager):
-        """Test updating a document."""
+        """Test updating a document with atomic replace."""
         result = await collection_manager.update_document(
             project_id="proj-123",
             agent_id="agent-456",
@@ -201,9 +201,10 @@ class TestCollectionManager:
             collection="default",
         )
 
-        # Should delete first, then ingest
-        collection_manager._database.delete_by_source.assert_called_once()
+        # Should use atomic replace (delete + insert in transaction)
+        collection_manager._database.replace_source_embeddings.assert_called_once()
         assert result.success is True
+        assert len(result.chunk_ids) == 1
 
     @pytest.mark.asyncio
     async def test_cleanup_expired(self, collection_manager):