syndarix/mcp-servers/knowledge-base/collection_manager.py
commit cd7a9ccbdf by Felipe Cardoso
fix(mcp-kb): add transactional batch insert and atomic document update
- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:07:40 +01:00
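
The fix described above reduces to a delete-then-insert performed inside a
single database transaction. A minimal sketch of that pattern (assuming an
asyncpg-style connection; the "embeddings" table and its columns here are
hypothetical, not the project's actual schema):

    async with conn.transaction():
        # Remove the document's current chunks...
        await conn.execute(
            "DELETE FROM embeddings"
            " WHERE project_id = $1 AND source_path = $2 AND collection = $3",
            project_id, source_path, collection,
        )
        # ...and insert the new batch; both steps commit or roll back together.
        await conn.executemany(
            "INSERT INTO embeddings"
            " (project_id, collection, content, embedding, chunk_type, metadata)"
            " VALUES ($1, $2, $3, $4, $5, $6)",
            rows,
        )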


"""
Collection management for the Knowledge Base MCP Server.

Provides operations for managing document collections, including
ingestion, deletion, and statistics.
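
Typical usage (illustrative sketch; assumes an async caller and that the
IngestRequest was built with the fields consumed by ingest() below):

    manager = get_collection_manager()
    result = await manager.ingest(request)  # request: IngestRequest
    if result.success:
        print(f"stored {result.chunks_created} chunks")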
"""

import logging
from typing import Any

from chunking.base import ChunkerFactory, get_chunker_factory
from config import Settings, get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from models import (
    ChunkType,
    CollectionStats,
    DeleteRequest,
    DeleteResult,
    FileType,
    IngestRequest,
    IngestResult,
    ListCollectionsResponse,
)

logger = logging.getLogger(__name__)


class CollectionManager:
    """
    Manages knowledge base collections.

    Handles document ingestion, chunking, embedding generation,
    and collection operations.
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
        chunker_factory: ChunkerFactory | None = None,
    ) -> None:
        """
        Initialize the collection manager.

        Dependencies left as None are resolved lazily from the module-level
        accessors, which allows tests to inject fakes.
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings
        self._chunker_factory = chunker_factory

    @property
    def database(self) -> DatabaseManager:
        """Get database manager."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    @property
    def chunker_factory(self) -> ChunkerFactory:
        """Get chunker factory."""
        if self._chunker_factory is None:
            self._chunker_factory = get_chunker_factory()
        return self._chunker_factory

    async def ingest(self, request: IngestRequest) -> IngestResult:
        """
        Ingest content into the knowledge base.

        Chunks the content, generates embeddings, and stores them.

        Args:
            request: Ingest request with content and options

        Returns:
            Ingest result with created chunk IDs
        """
        try:
            # Chunk the content
            chunks = self.chunker_factory.chunk_content(
                content=request.content,
                source_path=request.source_path,
                file_type=request.file_type,
                chunk_type=request.chunk_type,
                metadata=request.metadata,
            )
            if not chunks:
                return IngestResult(
                    success=True,
                    chunks_created=0,
                    embeddings_generated=0,
                    source_path=request.source_path,
                    collection=request.collection,
                    chunk_ids=[],
                )

            # Extract chunk contents for embedding
            chunk_texts = [chunk.content for chunk in chunks]

            # Generate embeddings
            embeddings_list = await self.embeddings.generate_batch(
                texts=chunk_texts,
                project_id=request.project_id,
                agent_id=request.agent_id,
            )
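            # generate_batch returns one vector per input text, so the list
            # stays aligned 1:1 with chunks (enforced by strict=True below).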

            # Store embeddings
            chunk_ids: list[str] = []
            for chunk, embedding in zip(chunks, embeddings_list, strict=True):
                # Build metadata with chunk info
                chunk_metadata = {
                    **request.metadata,
                    **chunk.metadata,
                    "token_count": chunk.token_count,
                }
                chunk_id = await self.database.store_embedding(
                    project_id=request.project_id,
                    collection=request.collection,
                    content=chunk.content,
                    embedding=embedding,
                    chunk_type=chunk.chunk_type,
                    source_path=chunk.source_path or request.source_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    file_type=chunk.file_type or request.file_type,
                    metadata=chunk_metadata,
                )
                chunk_ids.append(chunk_id)

            logger.info(
                f"Ingested {len(chunks)} chunks into collection '{request.collection}' "
                f"for project {request.project_id}"
            )
            return IngestResult(
                success=True,
                chunks_created=len(chunks),
                embeddings_generated=len(embeddings_list),
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=chunk_ids,
            )
        except Exception as e:
            logger.error(f"Ingest error: {e}")
            return IngestResult(
                success=False,
                chunks_created=0,
                embeddings_generated=0,
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=[],
                error=str(e),
            )

    async def delete(self, request: DeleteRequest) -> DeleteResult:
        """
        Delete content from the knowledge base.

        Supports deletion by source path, collection, or chunk IDs.

        Args:
            request: Delete request with target specification

        Returns:
            Delete result with count of deleted chunks
        """
        try:
            deleted_count = 0
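
            # Precedence: explicit chunk_ids win over source_path, which in
            # turn wins over deleting the whole collection.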
            if request.chunk_ids:
                # Delete specific chunks
                deleted_count = await self.database.delete_by_ids(
                    project_id=request.project_id,
                    chunk_ids=request.chunk_ids,
                )
            elif request.source_path:
                # Delete by source path
                deleted_count = await self.database.delete_by_source(
                    project_id=request.project_id,
                    source_path=request.source_path,
                    collection=request.collection,
                )
            elif request.collection:
                # Delete entire collection
                deleted_count = await self.database.delete_collection(
                    project_id=request.project_id,
                    collection=request.collection,
                )
            else:
                return DeleteResult(
                    success=False,
                    chunks_deleted=0,
                    error="Must specify chunk_ids, source_path, or collection",
                )

            logger.info(
                f"Deleted {deleted_count} chunks for project {request.project_id}"
            )
            return DeleteResult(
                success=True,
                chunks_deleted=deleted_count,
            )
        except Exception as e:
            logger.error(f"Delete error: {e}")
            return DeleteResult(
                success=False,
                chunks_deleted=0,
                error=str(e),
            )

    async def list_collections(self, project_id: str) -> ListCollectionsResponse:
        """
        List all collections for a project.

        Args:
            project_id: Project ID

        Returns:
            List of collection info
        """
        collections = await self.database.list_collections(project_id)
        return ListCollectionsResponse(
            project_id=project_id,
            collections=collections,
            total_collections=len(collections),
        )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """
        Get statistics for a collection.

        Args:
            project_id: Project ID
            collection: Collection name

        Returns:
            Collection statistics
        """
        return await self.database.get_collection_stats(project_id, collection)

    async def update_document(
        self,
        project_id: str,
        agent_id: str,
        source_path: str,
        content: str,
        collection: str = "default",
        chunk_type: ChunkType = ChunkType.TEXT,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> IngestResult:
        """
        Update a document by atomically replacing existing chunks.

        Uses a database transaction to delete existing chunks and insert new
        ones atomically, preventing race conditions during concurrent updates.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            source_path: Source file path
            content: New content
            collection: Collection name
            chunk_type: Type of content
            file_type: File type for code chunking
            metadata: Additional metadata

        Returns:
            Ingest result
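
        Example (sketch; the identifiers and path are hypothetical):
            result = await manager.update_document(
                project_id="proj-1",
                agent_id="agent-1",
                source_path="docs/guide.md",
                content=new_text,
            )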
"""
request_metadata = metadata or {}
# Chunk the content
chunks = self.chunker_factory.chunk_content(
content=content,
source_path=source_path,
file_type=file_type,
chunk_type=chunk_type,
metadata=request_metadata,
)
if not chunks:
# No chunks = delete existing and return empty result
await self.database.delete_by_source(
project_id=project_id,
source_path=source_path,
collection=collection,
)
return IngestResult(
success=True,
chunks_created=0,
embeddings_generated=0,
source_path=source_path,
collection=collection,
chunk_ids=[],
)
# Generate embeddings for new chunks
chunk_texts = [chunk.content for chunk in chunks]
embeddings_list = await self.embeddings.generate_batch(
texts=chunk_texts,
project_id=project_id,
agent_id=agent_id,
)
# Build embeddings data for transactional replace
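        # (each entry is a (content, embedding, chunk_type, metadata) tuple,
        # matching what replace_source_embeddings is passed below)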
        embeddings_data = []
        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
            effective_file_type = chunk.file_type or file_type
            chunk_metadata = {
                **request_metadata,
                **chunk.metadata,
                "token_count": chunk.token_count,
                "source_path": chunk.source_path or source_path,
                "start_line": chunk.start_line,
                "end_line": chunk.end_line,
                "file_type": effective_file_type.value if effective_file_type else None,
            }
            embeddings_data.append((
                chunk.content,
                embedding,
                chunk.chunk_type,
                chunk_metadata,
            ))

        # Atomically replace old embeddings with new ones
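        # (per the commit: the delete of the old rows and the batch insert of
        # the new ones run inside a single transaction, all-or-nothing)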
        _, chunk_ids = await self.database.replace_source_embeddings(
            project_id=project_id,
            source_path=source_path,
            collection=collection,
            embeddings=embeddings_data,
        )
        return IngestResult(
            success=True,
            chunks_created=len(chunk_ids),
            embeddings_generated=len(embeddings_list),
            source_path=source_path,
            collection=collection,
            chunk_ids=chunk_ids,
        )

    async def cleanup_expired(self) -> int:
        """
        Remove expired embeddings from all collections.

        Returns:
            Number of embeddings removed
        """
        return await self.database.cleanup_expired()


# Global collection manager instance (lazy initialization)
_collection_manager: CollectionManager | None = None


def get_collection_manager() -> CollectionManager:
    """Get the global collection manager instance."""
    global _collection_manager
    if _collection_manager is None:
        _collection_manager = CollectionManager()
    return _collection_manager


def reset_collection_manager() -> None:
    """Reset the global collection manager (for testing)."""
    global _collection_manager
    _collection_manager = None