"""
|
|
Collection management for Knowledge Base MCP Server.
|
|
|
|
Provides operations for managing document collections including
|
|
ingestion, deletion, and statistics.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from chunking.base import ChunkerFactory, get_chunker_factory
|
|
from config import Settings, get_settings
|
|
from database import DatabaseManager, get_database_manager
|
|
from embeddings import EmbeddingGenerator, get_embedding_generator
|
|
from models import (
|
|
ChunkType,
|
|
CollectionStats,
|
|
DeleteRequest,
|
|
DeleteResult,
|
|
FileType,
|
|
IngestRequest,
|
|
IngestResult,
|
|
ListCollectionsResponse,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CollectionManager:
|
|
"""
|
|
Manages knowledge base collections.
|
|
|
|
Handles document ingestion, chunking, embedding generation,
|
|
and collection operations.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
settings: Settings | None = None,
|
|
database: DatabaseManager | None = None,
|
|
embeddings: EmbeddingGenerator | None = None,
|
|
chunker_factory: ChunkerFactory | None = None,
|
|
) -> None:
|
|
"""Initialize collection manager."""
|
|
self._settings = settings or get_settings()
|
|
self._database = database
|
|
self._embeddings = embeddings
|
|
self._chunker_factory = chunker_factory
|
|
|
|
@property
|
|
def database(self) -> DatabaseManager:
|
|
"""Get database manager."""
|
|
if self._database is None:
|
|
self._database = get_database_manager()
|
|
return self._database
|
|
|
|
@property
|
|
def embeddings(self) -> EmbeddingGenerator:
|
|
"""Get embedding generator."""
|
|
if self._embeddings is None:
|
|
self._embeddings = get_embedding_generator()
|
|
return self._embeddings
|
|
|
|
@property
|
|
def chunker_factory(self) -> ChunkerFactory:
|
|
"""Get chunker factory."""
|
|
if self._chunker_factory is None:
|
|
self._chunker_factory = get_chunker_factory()
|
|
return self._chunker_factory
|
|
|
|
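
    # Dependency note (added comment): the database, embeddings, and
    # chunker_factory properties above resolve lazily, so tests can inject
    # fakes through __init__ while production callers fall back to the
    # module-level getters on first access.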

    async def ingest(self, request: IngestRequest) -> IngestResult:
        """
        Ingest content into the knowledge base.

        Chunks the content, generates embeddings, and stores them.

        Args:
            request: Ingest request with content and options

        Returns:
            Ingest result with created chunk IDs
        """
        try:
            # Chunk the content
            chunks = self.chunker_factory.chunk_content(
                content=request.content,
                source_path=request.source_path,
                file_type=request.file_type,
                chunk_type=request.chunk_type,
                metadata=request.metadata,
            )

            if not chunks:
                return IngestResult(
                    success=True,
                    chunks_created=0,
                    embeddings_generated=0,
                    source_path=request.source_path,
                    collection=request.collection,
                    chunk_ids=[],
                )

            # Extract chunk contents for embedding
            chunk_texts = [chunk.content for chunk in chunks]

            # Generate embeddings
            embeddings_list = await self.embeddings.generate_batch(
                texts=chunk_texts,
                project_id=request.project_id,
                agent_id=request.agent_id,
            )

            # Store embeddings
            chunk_ids: list[str] = []
            for chunk, embedding in zip(chunks, embeddings_list, strict=True):
                # Build metadata with chunk info
                chunk_metadata = {
                    **request.metadata,
                    **chunk.metadata,
                    "token_count": chunk.token_count,
                }

                chunk_id = await self.database.store_embedding(
                    project_id=request.project_id,
                    collection=request.collection,
                    content=chunk.content,
                    embedding=embedding,
                    chunk_type=chunk.chunk_type,
                    source_path=chunk.source_path or request.source_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    file_type=chunk.file_type or request.file_type,
                    metadata=chunk_metadata,
                )
                chunk_ids.append(chunk_id)

            logger.info(
                f"Ingested {len(chunks)} chunks into collection '{request.collection}' "
                f"for project {request.project_id}"
            )

            return IngestResult(
                success=True,
                chunks_created=len(chunks),
                embeddings_generated=len(embeddings_list),
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=chunk_ids,
            )

        except Exception as e:
            logger.error(f"Ingest error: {e}")
            return IngestResult(
                success=False,
                chunks_created=0,
                embeddings_generated=0,
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=[],
                error=str(e),
            )
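
    # Note (added comment): ingest() stores each chunk with an individual
    # store_embedding() call, so it is not all-or-nothing; a failure partway
    # through can leave earlier chunks persisted. For replacing an existing
    # document, update_document() below swaps chunks in a single transaction.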

    async def delete(self, request: DeleteRequest) -> DeleteResult:
        """
        Delete content from the knowledge base.

        Supports deletion by chunk IDs, source path, or collection; exactly
        one target is used, checked in that order of precedence.

        Args:
            request: Delete request with target specification

        Returns:
            Delete result with count of deleted chunks
        """
        try:
            deleted_count = 0

            if request.chunk_ids:
                # Delete specific chunks
                deleted_count = await self.database.delete_by_ids(
                    project_id=request.project_id,
                    chunk_ids=request.chunk_ids,
                )
            elif request.source_path:
                # Delete by source path
                deleted_count = await self.database.delete_by_source(
                    project_id=request.project_id,
                    source_path=request.source_path,
                    collection=request.collection,
                )
            elif request.collection:
                # Delete entire collection
                deleted_count = await self.database.delete_collection(
                    project_id=request.project_id,
                    collection=request.collection,
                )
            else:
                return DeleteResult(
                    success=False,
                    chunks_deleted=0,
                    error="Must specify chunk_ids, source_path, or collection",
                )

            logger.info(
                f"Deleted {deleted_count} chunks for project {request.project_id}"
            )

            return DeleteResult(
                success=True,
                chunks_deleted=deleted_count,
            )

        except Exception as e:
            logger.error(f"Delete error: {e}")
            return DeleteResult(
                success=False,
                chunks_deleted=0,
                error=str(e),
            )

    async def list_collections(self, project_id: str) -> ListCollectionsResponse:
        """
        List all collections for a project.

        Args:
            project_id: Project ID

        Returns:
            List of collection info
        """
        collections = await self.database.list_collections(project_id)

        return ListCollectionsResponse(
            project_id=project_id,
            collections=collections,
            total_collections=len(collections),
        )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """
        Get statistics for a collection.

        Args:
            project_id: Project ID
            collection: Collection name

        Returns:
            Collection statistics
        """
        return await self.database.get_collection_stats(project_id, collection)

    async def update_document(
        self,
        project_id: str,
        agent_id: str,
        source_path: str,
        content: str,
        collection: str = "default",
        chunk_type: ChunkType = ChunkType.TEXT,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> IngestResult:
        """
        Update a document by atomically replacing its existing chunks.

        Uses a database transaction to delete existing chunks and insert new
        ones atomically, preventing race conditions during concurrent updates.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            source_path: Source file path
            content: New content
            collection: Collection name
            chunk_type: Type of content
            file_type: File type for code chunking
            metadata: Additional metadata

        Returns:
            Ingest result
        """
        request_metadata = metadata or {}

        # Chunk the content
        chunks = self.chunker_factory.chunk_content(
            content=content,
            source_path=source_path,
            file_type=file_type,
            chunk_type=chunk_type,
            metadata=request_metadata,
        )

        if not chunks:
            # No chunks: delete any existing chunks and return an empty result
            await self.database.delete_by_source(
                project_id=project_id,
                source_path=source_path,
                collection=collection,
            )
            return IngestResult(
                success=True,
                chunks_created=0,
                embeddings_generated=0,
                source_path=source_path,
                collection=collection,
                chunk_ids=[],
            )

        # Generate embeddings for new chunks
        chunk_texts = [chunk.content for chunk in chunks]
        embeddings_list = await self.embeddings.generate_batch(
            texts=chunk_texts,
            project_id=project_id,
            agent_id=agent_id,
        )

        # Build embeddings data for the transactional replace
        embeddings_data = []
        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
            resolved_file_type = chunk.file_type or file_type
            chunk_metadata = {
                **request_metadata,
                **chunk.metadata,
                "token_count": chunk.token_count,
                "source_path": chunk.source_path or source_path,
                "start_line": chunk.start_line,
                "end_line": chunk.end_line,
                "file_type": resolved_file_type.value if resolved_file_type else None,
            }
            embeddings_data.append((
                chunk.content,
                embedding,
                chunk.chunk_type,
                chunk_metadata,
            ))

        # Atomically replace old embeddings with new ones
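        # (Added comment: replace_source_embeddings is expected to delete all
        # rows for this (project_id, source_path, collection) and insert
        # embeddings_data in one transaction; judging by the unpacking below,
        # it returns a pair whose second element is the list of new chunk IDs.)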
        _, chunk_ids = await self.database.replace_source_embeddings(
            project_id=project_id,
            source_path=source_path,
            collection=collection,
            embeddings=embeddings_data,
        )

        return IngestResult(
            success=True,
            chunks_created=len(chunk_ids),
            embeddings_generated=len(embeddings_list),
            source_path=source_path,
            collection=collection,
            chunk_ids=chunk_ids,
        )

    async def cleanup_expired(self) -> int:
        """
        Remove expired embeddings from all collections.

        Returns:
            Number of embeddings removed
        """
        return await self.database.cleanup_expired()


# Global collection manager instance (lazy initialization)
_collection_manager: CollectionManager | None = None


def get_collection_manager() -> CollectionManager:
    """Get the global collection manager instance."""
    global _collection_manager
    if _collection_manager is None:
        _collection_manager = CollectionManager()
    return _collection_manager


def reset_collection_manager() -> None:
    """Reset the global collection manager (for testing)."""
    global _collection_manager
    _collection_manager = None
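

# Hedged usage sketch (not part of the original module): how a caller might
# drive CollectionManager end to end. The IngestRequest field names mirror how
# this module reads them (project_id, agent_id, content, source_path,
# collection, metadata); which fields are required, and their defaults, are
# assumptions here, as are the example ID values.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        manager = get_collection_manager()

        # Initial ingest: chunk, embed, and store a document.
        result = await manager.ingest(
            IngestRequest(
                project_id="proj-1",
                agent_id="agent-1",
                content="Example document text to be chunked and embedded.",
                source_path="docs/example.md",
                collection="default",
                metadata={},
            )
        )
        print(f"ingested {result.chunks_created} chunks: {result.chunk_ids}")

        # Later update: atomically swap the document's chunks in one transaction.
        updated = await manager.update_document(
            project_id="proj-1",
            agent_id="agent-1",
            source_path="docs/example.md",
            content="Revised document text.",
        )
        print(f"update produced {updated.chunks_created} chunks")

    asyncio.run(_demo())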