fix(mcp-kb): add transactional batch insert and atomic document update

- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:07:40 +01:00
parent 953af52d0e
commit cd7a9ccbdf
4 changed files with 195 additions and 50 deletions
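
The all-or-nothing guarantees in this commit come from asyncpg's Connection.transaction() async context manager, which commits when the block exits cleanly and rolls back when an exception escapes it. A minimal standalone sketch of that behavior (the demo_items table and DSN are placeholders, not part of this repository):

import asyncio

import asyncpg


async def demo(dsn: str) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        async with conn.transaction():
            # This row becomes visible to other connections only on commit.
            await conn.execute(
                "INSERT INTO demo_items (name) VALUES ($1)", "first"
            )
            # An exception escaping the block rolls back the INSERT above,
            # which is what makes a batch wrapped this way all-or-nothing.
            raise RuntimeError("simulated failure mid-batch")
    except RuntimeError:
        pass  # transaction was rolled back; demo_items is unchanged
    finally:
        await conn.close()


# asyncio.run(demo("postgresql://user:pass@localhost/db"))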


@@ -265,9 +265,10 @@ class CollectionManager:
         metadata: dict[str, Any] | None = None,
     ) -> IngestResult:
         """
-        Update a document by replacing existing chunks.
+        Update a document by atomically replacing existing chunks.
 
-        Deletes existing chunks for the source path and ingests new content.
+        Uses a database transaction to delete existing chunks and insert new ones
+        atomically, preventing race conditions during concurrent updates.
 
         Args:
             project_id: Project ID
@@ -282,26 +283,76 @@ class CollectionManager:
         Returns:
             Ingest result
         """
-        # First delete existing chunks for this source
-        await self.database.delete_by_source(
-            project_id=project_id,
-            source_path=source_path,
-            collection=collection,
-        )
-
-        # Then ingest new content
-        request = IngestRequest(
-            project_id=project_id,
-            agent_id=agent_id,
-            content=content,
-            source_path=source_path,
-            collection=collection,
-            chunk_type=chunk_type,
-            file_type=file_type,
-            metadata=metadata or {},
-        )
-        return await self.ingest(request)
+        request_metadata = metadata or {}
+
+        # Chunk the content
+        chunks = self.chunker_factory.chunk_content(
+            content=content,
+            source_path=source_path,
+            file_type=file_type,
+            chunk_type=chunk_type,
+            metadata=request_metadata,
+        )
+
+        if not chunks:
+            # No chunks = delete existing and return empty result
+            await self.database.delete_by_source(
+                project_id=project_id,
+                source_path=source_path,
+                collection=collection,
+            )
+            return IngestResult(
+                success=True,
+                chunks_created=0,
+                embeddings_generated=0,
+                source_path=source_path,
+                collection=collection,
+                chunk_ids=[],
+            )
+
+        # Generate embeddings for new chunks
+        chunk_texts = [chunk.content for chunk in chunks]
+        embeddings_list = await self.embeddings.generate_batch(
+            texts=chunk_texts,
+            project_id=project_id,
+            agent_id=agent_id,
+        )
+
+        # Build embeddings data for transactional replace
+        embeddings_data = []
+        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
+            chunk_metadata = {
+                **request_metadata,
+                **chunk.metadata,
+                "token_count": chunk.token_count,
+                "source_path": chunk.source_path or source_path,
+                "start_line": chunk.start_line,
+                "end_line": chunk.end_line,
+                "file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
+            }
+            embeddings_data.append((
+                chunk.content,
+                embedding,
+                chunk.chunk_type,
+                chunk_metadata,
+            ))
+
+        # Atomically replace old embeddings with new ones
+        _, chunk_ids = await self.database.replace_source_embeddings(
+            project_id=project_id,
+            source_path=source_path,
+            collection=collection,
+            embeddings=embeddings_data,
+        )
+
+        return IngestResult(
+            success=True,
+            chunks_created=len(chunk_ids),
+            embeddings_generated=len(embeddings_list),
+            source_path=source_path,
+            collection=collection,
+            chunk_ids=chunk_ids,
+        )
 
     async def cleanup_expired(self) -> int:
         """

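With this change, callers get atomic update semantics from a single call. A hypothetical usage sketch (argument values are illustrative; the CollectionManager wiring is assumed to exist elsewhere):

from typing import Any


async def refresh_document(manager: Any) -> None:
    # `manager` is a fully wired CollectionManager; construction is out of
    # scope here. Parameter names follow the diff above.
    result = await manager.update_document(
        project_id="proj-123",
        agent_id="agent-456",
        content="# Updated document\n\nNew body text.",
        source_path="docs/readme.md",
        collection="default",
    )
    assert result.success
    print(f"created {result.chunks_created} chunks: {result.chunk_ids}")
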

@@ -285,38 +285,40 @@ class DatabaseManager:
         try:
             async with self.acquire() as conn:
-                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
-                    content_hash = self.compute_content_hash(content)
-                    source_path = metadata.get("source_path")
-                    start_line = metadata.get("start_line")
-                    end_line = metadata.get("end_line")
-                    file_type = metadata.get("file_type")
-
-                    embedding_id = await conn.fetchval(
-                        """
-                        INSERT INTO knowledge_embeddings
-                        (project_id, collection, content, embedding, chunk_type,
-                         source_path, start_line, end_line, file_type, metadata,
-                         content_hash, expires_at)
-                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
-                        ON CONFLICT DO NOTHING
-                        RETURNING id
-                        """,
-                        project_id,
-                        collection,
-                        content,
-                        embedding,
-                        chunk_type.value,
-                        source_path,
-                        start_line,
-                        end_line,
-                        file_type,
-                        metadata,
-                        content_hash,
-                        expires_at,
-                    )
-
-                    if embedding_id:
-                        ids.append(str(embedding_id))
+                # Wrap in transaction for all-or-nothing batch semantics
+                async with conn.transaction():
+                    for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        source_path = metadata.get("source_path")
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
+
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            ON CONFLICT DO NOTHING
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+
+                        if embedding_id:
+                            ids.append(str(embedding_id))
 
             logger.info(f"Stored {len(ids)} embeddings in batch")
             return ids
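
One subtlety in this hunk: with ON CONFLICT DO NOTHING, the RETURNING clause produces no row for a skipped duplicate, so fetchval resolves to None and the `if embedding_id:` guard keeps duplicates out of `ids`. A self-contained sketch of the same pattern (placeholder table with a unique `name` column, not from this repository):

import asyncpg


async def insert_unique(conn: asyncpg.Connection, name: str) -> str | None:
    # RETURNING yields a row only when the INSERT actually happens; on a
    # skipped conflict, fetchval() resolves to None.
    row_id = await conn.fetchval(
        """
        INSERT INTO demo_items (name) VALUES ($1)
        ON CONFLICT (name) DO NOTHING
        RETURNING id
        """,
        name,
    )
    return str(row_id) if row_id is not None else None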
@@ -535,6 +537,96 @@ class DatabaseManager:
                 cause=e,
             )
 
+    async def replace_source_embeddings(
+        self,
+        project_id: str,
+        source_path: str,
+        collection: str,
+        embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
+    ) -> tuple[int, list[str]]:
+        """
+        Atomically replace all embeddings for a source path.
+
+        Deletes existing embeddings and inserts new ones in a single transaction,
+        preventing race conditions during document updates.
+
+        Args:
+            project_id: Project ID
+            source_path: Source file path being updated
+            collection: Collection name
+            embeddings: List of (content, embedding, chunk_type, metadata)
+
+        Returns:
+            Tuple of (deleted_count, new_embedding_ids)
+        """
+        expires_at = None
+        if self._settings.embedding_ttl_days > 0:
+            expires_at = datetime.now(UTC) + timedelta(
+                days=self._settings.embedding_ttl_days
+            )
+
+        try:
+            async with self.acquire() as conn:
+                # Use transaction for atomic replace
+                async with conn.transaction():
+                    # First, delete existing embeddings for this source
+                    delete_result = await conn.execute(
+                        """
+                        DELETE FROM knowledge_embeddings
+                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
+                        """,
+                        project_id,
+                        source_path,
+                        collection,
+                    )
+                    deleted_count = int(delete_result.split()[-1])
+
+                    # Then insert new embeddings
+                    new_ids = []
+                    for content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
+
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+
+                        if embedding_id:
+                            new_ids.append(str(embedding_id))
+
+            logger.info(
+                f"Replaced source {source_path}: deleted {deleted_count}, "
+                f"inserted {len(new_ids)} embeddings"
+            )
+            return deleted_count, new_ids
+
+        except asyncpg.PostgresError as e:
+            logger.error(f"Replace source error: {e}")
+            raise DatabaseQueryError(
+                message=f"Failed to replace source embeddings: {e}",
+                cause=e,
+            )
+
     async def delete_collection(
         self,
         project_id: str,

@@ -61,6 +61,7 @@ def mock_database():
     mock_db.delete_by_source = AsyncMock(return_value=1)
     mock_db.delete_collection = AsyncMock(return_value=5)
     mock_db.delete_by_ids = AsyncMock(return_value=2)
+    mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))
     mock_db.list_collections = AsyncMock(return_value=[])
     mock_db.get_collection_stats = AsyncMock()
     mock_db.cleanup_expired = AsyncMock(return_value=0)
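
The mock's `(1, ["new-id-1"])` return value mirrors the `(deleted_count, new_embedding_ids)` tuple of the real method, which update_document unpacks as `_, chunk_ids`. A tiny standalone check of that shape (purely illustrative):

import asyncio
from unittest.mock import AsyncMock


async def main() -> None:
    mock_db = AsyncMock()
    mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))

    # update_document unpacks the tuple the same way.
    _, chunk_ids = await mock_db.replace_source_embeddings()
    assert chunk_ids == ["new-id-1"]


asyncio.run(main())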


@@ -192,7 +192,7 @@ class TestCollectionManager:
     @pytest.mark.asyncio
     async def test_update_document(self, collection_manager):
-        """Test updating a document."""
+        """Test updating a document with atomic replace."""
         result = await collection_manager.update_document(
             project_id="proj-123",
             agent_id="agent-456",
@@ -201,9 +201,10 @@ class TestCollectionManager:
             collection="default",
         )
 
-        # Should delete first, then ingest
-        collection_manager._database.delete_by_source.assert_called_once()
+        # Should use atomic replace (delete + insert in transaction)
+        collection_manager._database.replace_source_embeddings.assert_called_once()
         assert result.success is True
+        assert len(result.chunk_ids) == 1
 
     @pytest.mark.asyncio
     async def test_cleanup_expired(self, collection_manager):
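
A natural follow-up test, not part of this commit, would cover the failure path: because embeddings are generated before the replace, a generate_batch failure should mean replace_source_embeddings is never called and existing chunks stay intact. A hedged sketch (the `_embeddings` attribute name is an assumption, chosen by analogy with `_database` used above):

import pytest
from unittest.mock import AsyncMock


@pytest.mark.asyncio
async def test_update_document_embedding_failure(collection_manager):
    """If embedding generation fails, no replace should be attempted."""
    # `_embeddings` is assumed by analogy with `_database`; adjust to the
    # actual private attribute name on CollectionManager.
    collection_manager._embeddings.generate_batch = AsyncMock(
        side_effect=RuntimeError("embedding backend down")
    )

    with pytest.raises(RuntimeError):
        await collection_manager.update_document(
            project_id="proj-123",
            agent_id="agent-456",
            content="new content",
            source_path="docs/readme.md",
            collection="default",
        )

    collection_manager._database.replace_source_embeddings.assert_not_called()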