forked from cardosofelipe/fast-next-template
fix(mcp-kb): add transactional batch insert and atomic document update
- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
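For context, here is a minimal sketch of the all-or-nothing pattern the first bullet relies on, using asyncpg (the driver the diff below already uses). The DSN, the demo_items table, and the insert_all_or_nothing helper are illustrative assumptions, not part of this commit:

import asyncpg

async def insert_all_or_nothing(dsn: str, rows: list[tuple[str, str]]) -> None:
    # Illustrative helper, not from this commit: inserts every row or none.
    conn = await asyncpg.connect(dsn)
    try:
        # conn.transaction() commits on normal exit and rolls back if any
        # statement raises, so a failure on row N also discards rows 1..N-1
        # instead of leaving a partial batch behind.
        async with conn.transaction():
            for name, value in rows:
                await conn.execute(
                    "INSERT INTO demo_items (name, value) VALUES ($1, $2)",
                    name,
                    value,
                )
    finally:
        await conn.close()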
@@ -265,9 +265,10 @@ class CollectionManager:
         metadata: dict[str, Any] | None = None,
     ) -> IngestResult:
         """
-        Update a document by replacing existing chunks.
+        Update a document by atomically replacing existing chunks.

-        Deletes existing chunks for the source path and ingests new content.
+        Uses a database transaction to delete existing chunks and insert new ones
+        atomically, preventing race conditions during concurrent updates.

         Args:
             project_id: Project ID
@@ -282,26 +283,76 @@ class CollectionManager:
         Returns:
             Ingest result
         """
-        # First delete existing chunks for this source
-        await self.database.delete_by_source(
-            project_id=project_id,
-            source_path=source_path,
-            collection=collection,
-        )
+        request_metadata = metadata or {}

-        # Then ingest new content
-        request = IngestRequest(
-            project_id=project_id,
-            agent_id=agent_id,
+        # Chunk the content
+        chunks = self.chunker_factory.chunk_content(
             content=content,
             source_path=source_path,
             collection=collection,
-            chunk_type=chunk_type,
-            file_type=file_type,
-            metadata=metadata or {},
+            chunk_type=chunk_type,
+            metadata=request_metadata,
         )

-        return await self.ingest(request)
+        if not chunks:
+            # No chunks = delete existing and return empty result
+            await self.database.delete_by_source(
+                project_id=project_id,
+                source_path=source_path,
+                collection=collection,
+            )
+            return IngestResult(
+                success=True,
+                chunks_created=0,
+                embeddings_generated=0,
+                source_path=source_path,
+                collection=collection,
+                chunk_ids=[],
+            )
+
+        # Generate embeddings for new chunks
+        chunk_texts = [chunk.content for chunk in chunks]
+        embeddings_list = await self.embeddings.generate_batch(
+            texts=chunk_texts,
+            project_id=project_id,
+            agent_id=agent_id,
+        )
+
+        # Build embeddings data for transactional replace
+        embeddings_data = []
+        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
+            chunk_metadata = {
+                **request_metadata,
+                **chunk.metadata,
+                "token_count": chunk.token_count,
+                "source_path": chunk.source_path or source_path,
+                "start_line": chunk.start_line,
+                "end_line": chunk.end_line,
+                "file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
+            }
+            embeddings_data.append((
+                chunk.content,
+                embedding,
+                chunk.chunk_type,
+                chunk_metadata,
+            ))
+
+        # Atomically replace old embeddings with new ones
+        _, chunk_ids = await self.database.replace_source_embeddings(
+            project_id=project_id,
+            source_path=source_path,
+            collection=collection,
+            embeddings=embeddings_data,
+        )
+
+        return IngestResult(
+            success=True,
+            chunks_created=len(chunk_ids),
+            embeddings_generated=len(embeddings_list),
+            source_path=source_path,
+            collection=collection,
+            chunk_ids=chunk_ids,
+        )

     async def cleanup_expired(self) -> int:
         """
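A hedged usage sketch of the updated method above; the collection_manager construction is elided, and the argument values (the project and agent IDs, the docs/guide.md path) are invented for illustration:

# Hypothetical call to CollectionManager.update_document after this change.
result = await collection_manager.update_document(
    project_id="proj-123",
    agent_id="agent-456",
    content="# Updated document\n\nNew body text.",
    source_path="docs/guide.md",  # assumed path, for illustration only
    collection="default",
)
# Readers never observe a half-updated document: the old chunks are deleted
# and the new ones inserted inside a single transaction.
assert result.success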
@@ -285,38 +285,40 @@ class DatabaseManager:

         try:
             async with self.acquire() as conn:
-                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
-                    content_hash = self.compute_content_hash(content)
-                    source_path = metadata.get("source_path")
-                    start_line = metadata.get("start_line")
-                    end_line = metadata.get("end_line")
-                    file_type = metadata.get("file_type")
+                # Wrap in transaction for all-or-nothing batch semantics
+                async with conn.transaction():
+                    for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        source_path = metadata.get("source_path")
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")

-                    embedding_id = await conn.fetchval(
-                        """
-                        INSERT INTO knowledge_embeddings
-                        (project_id, collection, content, embedding, chunk_type,
-                         source_path, start_line, end_line, file_type, metadata,
-                         content_hash, expires_at)
-                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
-                        ON CONFLICT DO NOTHING
-                        RETURNING id
-                        """,
-                        project_id,
-                        collection,
-                        content,
-                        embedding,
-                        chunk_type.value,
-                        source_path,
-                        start_line,
-                        end_line,
-                        file_type,
-                        metadata,
-                        content_hash,
-                        expires_at,
-                    )
-                    if embedding_id:
-                        ids.append(str(embedding_id))
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            ON CONFLICT DO NOTHING
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+                        if embedding_id:
+                            ids.append(str(embedding_id))

         logger.info(f"Stored {len(ids)} embeddings in batch")
         return ids
@@ -535,6 +537,96 @@ class DatabaseManager:
                 cause=e,
             )

+    async def replace_source_embeddings(
+        self,
+        project_id: str,
+        source_path: str,
+        collection: str,
+        embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
+    ) -> tuple[int, list[str]]:
+        """
+        Atomically replace all embeddings for a source path.
+
+        Deletes existing embeddings and inserts new ones in a single transaction,
+        preventing race conditions during document updates.
+
+        Args:
+            project_id: Project ID
+            source_path: Source file path being updated
+            collection: Collection name
+            embeddings: List of (content, embedding, chunk_type, metadata)
+
+        Returns:
+            Tuple of (deleted_count, new_embedding_ids)
+        """
+        expires_at = None
+        if self._settings.embedding_ttl_days > 0:
+            expires_at = datetime.now(UTC) + timedelta(
+                days=self._settings.embedding_ttl_days
+            )
+
+        try:
+            async with self.acquire() as conn:
+                # Use transaction for atomic replace
+                async with conn.transaction():
+                    # First, delete existing embeddings for this source
+                    delete_result = await conn.execute(
+                        """
+                        DELETE FROM knowledge_embeddings
+                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
+                        """,
+                        project_id,
+                        source_path,
+                        collection,
+                    )
+                    deleted_count = int(delete_result.split()[-1])
+
+                    # Then insert new embeddings
+                    new_ids = []
+                    for content, embedding, chunk_type, metadata in embeddings:
+                        content_hash = self.compute_content_hash(content)
+                        start_line = metadata.get("start_line")
+                        end_line = metadata.get("end_line")
+                        file_type = metadata.get("file_type")
+
+                        embedding_id = await conn.fetchval(
+                            """
+                            INSERT INTO knowledge_embeddings
+                            (project_id, collection, content, embedding, chunk_type,
+                             source_path, start_line, end_line, file_type, metadata,
+                             content_hash, expires_at)
+                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                            RETURNING id
+                            """,
+                            project_id,
+                            collection,
+                            content,
+                            embedding,
+                            chunk_type.value,
+                            source_path,
+                            start_line,
+                            end_line,
+                            file_type,
+                            metadata,
+                            content_hash,
+                            expires_at,
+                        )
+                        if embedding_id:
+                            new_ids.append(str(embedding_id))
+
+            logger.info(
+                f"Replaced source {source_path}: deleted {deleted_count}, "
+                f"inserted {len(new_ids)} embeddings"
+            )
+            return deleted_count, new_ids
+
+        except asyncpg.PostgresError as e:
+            logger.error(f"Replace source error: {e}")
+            raise DatabaseQueryError(
+                message=f"Failed to replace source embeddings: {e}",
+                cause=e,
+            )
+
     async def delete_collection(
         self,
         project_id: str,
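One detail in the new method above: asyncpg's Connection.execute() returns a command status tag (for example "DELETE 5") rather than a row count, which is why the code parses deleted_count out of a string. A tiny sketch of that convention, with the status value hard-coded for illustration:

status = "DELETE 5"  # illustrative stand-in for a real conn.execute() result
deleted_count = int(status.split()[-1])  # affected-row count is the last token
assert deleted_count == 5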
@@ -61,6 +61,7 @@ def mock_database():
     mock_db.delete_by_source = AsyncMock(return_value=1)
     mock_db.delete_collection = AsyncMock(return_value=5)
     mock_db.delete_by_ids = AsyncMock(return_value=2)
+    mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))
     mock_db.list_collections = AsyncMock(return_value=[])
     mock_db.get_collection_stats = AsyncMock()
     mock_db.cleanup_expired = AsyncMock(return_value=0)
@@ -192,7 +192,7 @@ class TestCollectionManager:

     @pytest.mark.asyncio
     async def test_update_document(self, collection_manager):
-        """Test updating a document."""
+        """Test updating a document with atomic replace."""
         result = await collection_manager.update_document(
             project_id="proj-123",
             agent_id="agent-456",
@@ -201,9 +201,10 @@ class TestCollectionManager:
             collection="default",
         )

-        # Should delete first, then ingest
-        collection_manager._database.delete_by_source.assert_called_once()
+        # Should use atomic replace (delete + insert in transaction)
+        collection_manager._database.replace_source_embeddings.assert_called_once()
         assert result.success is True
+        assert len(result.chunk_ids) == 1

     @pytest.mark.asyncio
     async def test_cleanup_expired(self, collection_manager):