"""
|
|
Collection management for Knowledge Base MCP Server.
|
|
|
|
Provides operations for managing document collections including
|
|
ingestion, deletion, and statistics.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from chunking.base import ChunkerFactory, get_chunker_factory
|
|
from config import Settings, get_settings
|
|
from database import DatabaseManager, get_database_manager
|
|
from embeddings import EmbeddingGenerator, get_embedding_generator
|
|
from models import (
|
|
ChunkType,
|
|
CollectionStats,
|
|
DeleteRequest,
|
|
DeleteResult,
|
|
FileType,
|
|
IngestRequest,
|
|
IngestResult,
|
|
ListCollectionsResponse,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CollectionManager:
|
|
"""
|
|
Manages knowledge base collections.
|
|
|
|
Handles document ingestion, chunking, embedding generation,
|
|
and collection operations.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
settings: Settings | None = None,
|
|
database: DatabaseManager | None = None,
|
|
embeddings: EmbeddingGenerator | None = None,
|
|
chunker_factory: ChunkerFactory | None = None,
|
|
) -> None:
|
|
"""Initialize collection manager."""
|
|
self._settings = settings or get_settings()
|
|
self._database = database
|
|
self._embeddings = embeddings
|
|
self._chunker_factory = chunker_factory
|
|
|
|
@property
|
|
def database(self) -> DatabaseManager:
|
|
"""Get database manager."""
|
|
if self._database is None:
|
|
self._database = get_database_manager()
|
|
return self._database
|
|
|
|
@property
|
|
def embeddings(self) -> EmbeddingGenerator:
|
|
"""Get embedding generator."""
|
|
if self._embeddings is None:
|
|
self._embeddings = get_embedding_generator()
|
|
return self._embeddings
|
|
|
|
@property
|
|
def chunker_factory(self) -> ChunkerFactory:
|
|
"""Get chunker factory."""
|
|
if self._chunker_factory is None:
|
|
self._chunker_factory = get_chunker_factory()
|
|
return self._chunker_factory
|
|
|
|
async def ingest(self, request: IngestRequest) -> IngestResult:
|
|
"""
|
|
Ingest content into the knowledge base.
|
|
|
|
Chunks the content, generates embeddings, and stores them.
|
|
|
|
Args:
|
|
request: Ingest request with content and options
|
|
|
|
Returns:
|
|
Ingest result with created chunk IDs
|
|
"""
|
|
try:
|
|
# Chunk the content
|
|
chunks = self.chunker_factory.chunk_content(
|
|
content=request.content,
|
|
source_path=request.source_path,
|
|
file_type=request.file_type,
|
|
chunk_type=request.chunk_type,
|
|
metadata=request.metadata,
|
|
)
|
|
|
|
if not chunks:
|
|
return IngestResult(
|
|
success=True,
|
|
chunks_created=0,
|
|
embeddings_generated=0,
|
|
source_path=request.source_path,
|
|
collection=request.collection,
|
|
chunk_ids=[],
|
|
)
|
|
|
|
# Extract chunk contents for embedding
|
|
chunk_texts = [chunk.content for chunk in chunks]
|
|
|
|
# Generate embeddings
|
|
embeddings_list = await self.embeddings.generate_batch(
|
|
texts=chunk_texts,
|
|
project_id=request.project_id,
|
|
agent_id=request.agent_id,
|
|
)
|
|
|
|
# Store embeddings
|
|
chunk_ids: list[str] = []
|
|
for chunk, embedding in zip(chunks, embeddings_list, strict=True):
|
|
# Build metadata with chunk info
|
|
chunk_metadata = {
|
|
**request.metadata,
|
|
**chunk.metadata,
|
|
"token_count": chunk.token_count,
|
|
}
|
|
|
|
chunk_id = await self.database.store_embedding(
|
|
project_id=request.project_id,
|
|
collection=request.collection,
|
|
content=chunk.content,
|
|
embedding=embedding,
|
|
chunk_type=chunk.chunk_type,
|
|
source_path=chunk.source_path or request.source_path,
|
|
start_line=chunk.start_line,
|
|
end_line=chunk.end_line,
|
|
file_type=chunk.file_type or request.file_type,
|
|
metadata=chunk_metadata,
|
|
)
|
|
chunk_ids.append(chunk_id)
|
|
|
|
logger.info(
|
|
f"Ingested {len(chunks)} chunks into collection '{request.collection}' "
|
|
f"for project {request.project_id}"
|
|
)
|
|
|
|
return IngestResult(
|
|
success=True,
|
|
chunks_created=len(chunks),
|
|
embeddings_generated=len(embeddings_list),
|
|
source_path=request.source_path,
|
|
collection=request.collection,
|
|
chunk_ids=chunk_ids,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Ingest error: {e}")
|
|
return IngestResult(
|
|
success=False,
|
|
chunks_created=0,
|
|
embeddings_generated=0,
|
|
source_path=request.source_path,
|
|
collection=request.collection,
|
|
chunk_ids=[],
|
|
error=str(e),
|
|
)
|
|
|
|
async def delete(self, request: DeleteRequest) -> DeleteResult:
|
|
"""
|
|
Delete content from the knowledge base.
|
|
|
|
Supports deletion by source path, collection, or chunk IDs.
|
|
|
|
Args:
|
|
request: Delete request with target specification
|
|
|
|
Returns:
|
|
Delete result with count of deleted chunks
|
|
"""
|
|
try:
|
|
deleted_count = 0
|
|
|
|
if request.chunk_ids:
|
|
# Delete specific chunks
|
|
deleted_count = await self.database.delete_by_ids(
|
|
project_id=request.project_id,
|
|
chunk_ids=request.chunk_ids,
|
|
)
|
|
elif request.source_path:
|
|
# Delete by source path
|
|
deleted_count = await self.database.delete_by_source(
|
|
project_id=request.project_id,
|
|
source_path=request.source_path,
|
|
collection=request.collection,
|
|
)
|
|
elif request.collection:
|
|
# Delete entire collection
|
|
deleted_count = await self.database.delete_collection(
|
|
project_id=request.project_id,
|
|
collection=request.collection,
|
|
)
|
|
else:
|
|
return DeleteResult(
|
|
success=False,
|
|
chunks_deleted=0,
|
|
error="Must specify chunk_ids, source_path, or collection",
|
|
)
|
|
|
|
logger.info(
|
|
f"Deleted {deleted_count} chunks for project {request.project_id}"
|
|
)
|
|
|
|
return DeleteResult(
|
|
success=True,
|
|
chunks_deleted=deleted_count,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Delete error: {e}")
|
|
return DeleteResult(
|
|
success=False,
|
|
chunks_deleted=0,
|
|
error=str(e),
|
|
)
|
|
|
|
async def list_collections(self, project_id: str) -> ListCollectionsResponse:
|
|
"""
|
|
List all collections for a project.
|
|
|
|
Args:
|
|
project_id: Project ID
|
|
|
|
Returns:
|
|
List of collection info
|
|
"""
|
|
collections = await self.database.list_collections(project_id)
|
|
|
|
return ListCollectionsResponse(
|
|
project_id=project_id,
|
|
collections=collections,
|
|
total_collections=len(collections),
|
|
)
|
|
|
|
async def get_collection_stats(
|
|
self,
|
|
project_id: str,
|
|
collection: str,
|
|
) -> CollectionStats:
|
|
"""
|
|
Get statistics for a collection.
|
|
|
|
Args:
|
|
project_id: Project ID
|
|
collection: Collection name
|
|
|
|
Returns:
|
|
Collection statistics
|
|
"""
|
|
return await self.database.get_collection_stats(project_id, collection)
|
|
|
|
async def update_document(
|
|
self,
|
|
project_id: str,
|
|
agent_id: str,
|
|
source_path: str,
|
|
content: str,
|
|
collection: str = "default",
|
|
chunk_type: ChunkType = ChunkType.TEXT,
|
|
file_type: FileType | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> IngestResult:
|
|
"""
|
|
Update a document by atomically replacing existing chunks.
|
|
|
|
Uses a database transaction to delete existing chunks and insert new ones
|
|
atomically, preventing race conditions during concurrent updates.
|
|
|
|
Args:
|
|
project_id: Project ID
|
|
agent_id: Agent ID
|
|
source_path: Source file path
|
|
content: New content
|
|
collection: Collection name
|
|
chunk_type: Type of content
|
|
file_type: File type for code chunking
|
|
metadata: Additional metadata
|
|
|
|
Returns:
|
|
Ingest result
|
|
"""
|
|
request_metadata = metadata or {}
|
|
|
|
# Chunk the content
|
|
chunks = self.chunker_factory.chunk_content(
|
|
content=content,
|
|
source_path=source_path,
|
|
file_type=file_type,
|
|
chunk_type=chunk_type,
|
|
metadata=request_metadata,
|
|
)
|
|
|
|
if not chunks:
|
|
# No chunks = delete existing and return empty result
|
|
await self.database.delete_by_source(
|
|
project_id=project_id,
|
|
source_path=source_path,
|
|
collection=collection,
|
|
)
|
|
return IngestResult(
|
|
success=True,
|
|
chunks_created=0,
|
|
embeddings_generated=0,
|
|
source_path=source_path,
|
|
collection=collection,
|
|
chunk_ids=[],
|
|
)
|
|
|
|
# Generate embeddings for new chunks
|
|
chunk_texts = [chunk.content for chunk in chunks]
|
|
embeddings_list = await self.embeddings.generate_batch(
|
|
texts=chunk_texts,
|
|
project_id=project_id,
|
|
agent_id=agent_id,
|
|
)
|
|
|
|
# Build embeddings data for transactional replace
|
|
embeddings_data = []
|
|
for chunk, embedding in zip(chunks, embeddings_list, strict=True):
|
|
chunk_metadata = {
|
|
**request_metadata,
|
|
**chunk.metadata,
|
|
"token_count": chunk.token_count,
|
|
"source_path": chunk.source_path or source_path,
|
|
"start_line": chunk.start_line,
|
|
"end_line": chunk.end_line,
|
|
"file_type": effective_file_type.value
|
|
if (effective_file_type := chunk.file_type or file_type)
|
|
else None,
|
|
}
|
|
embeddings_data.append(
|
|
(
|
|
chunk.content,
|
|
embedding,
|
|
chunk.chunk_type,
|
|
chunk_metadata,
|
|
)
|
|
)
|
|
|
|
# Atomically replace old embeddings with new ones
|
|
_, chunk_ids = await self.database.replace_source_embeddings(
|
|
project_id=project_id,
|
|
source_path=source_path,
|
|
collection=collection,
|
|
embeddings=embeddings_data,
|
|
)
|
|
|
|
return IngestResult(
|
|
success=True,
|
|
chunks_created=len(chunk_ids),
|
|
embeddings_generated=len(embeddings_list),
|
|
source_path=source_path,
|
|
collection=collection,
|
|
chunk_ids=chunk_ids,
|
|
)
|
|
|
|
async def cleanup_expired(self) -> int:
|
|
"""
|
|
Remove expired embeddings from all collections.
|
|
|
|
Returns:
|
|
Number of embeddings removed
|
|
"""
|
|
return await self.database.cleanup_expired()
|
|
|
|
|
|
# Global collection manager instance (lazy initialization)
|
|
_collection_manager: CollectionManager | None = None
|
|
|
|
|
|
def get_collection_manager() -> CollectionManager:
|
|
"""Get the global collection manager instance."""
|
|
global _collection_manager
|
|
if _collection_manager is None:
|
|
_collection_manager = CollectionManager()
|
|
return _collection_manager
|
|
|
|
|
|
def reset_collection_manager() -> None:
|
|
"""Reset the global collection manager (for testing)."""
|
|
global _collection_manager
|
|
_collection_manager = None
|