""" Collection management for Knowledge Base MCP Server. Provides operations for managing document collections including ingestion, deletion, and statistics. """ import logging from typing import Any from chunking.base import ChunkerFactory, get_chunker_factory from config import Settings, get_settings from database import DatabaseManager, get_database_manager from embeddings import EmbeddingGenerator, get_embedding_generator from models import ( ChunkType, CollectionStats, DeleteRequest, DeleteResult, FileType, IngestRequest, IngestResult, ListCollectionsResponse, ) logger = logging.getLogger(__name__) class CollectionManager: """ Manages knowledge base collections. Handles document ingestion, chunking, embedding generation, and collection operations. """ def __init__( self, settings: Settings | None = None, database: DatabaseManager | None = None, embeddings: EmbeddingGenerator | None = None, chunker_factory: ChunkerFactory | None = None, ) -> None: """Initialize collection manager.""" self._settings = settings or get_settings() self._database = database self._embeddings = embeddings self._chunker_factory = chunker_factory @property def database(self) -> DatabaseManager: """Get database manager.""" if self._database is None: self._database = get_database_manager() return self._database @property def embeddings(self) -> EmbeddingGenerator: """Get embedding generator.""" if self._embeddings is None: self._embeddings = get_embedding_generator() return self._embeddings @property def chunker_factory(self) -> ChunkerFactory: """Get chunker factory.""" if self._chunker_factory is None: self._chunker_factory = get_chunker_factory() return self._chunker_factory async def ingest(self, request: IngestRequest) -> IngestResult: """ Ingest content into the knowledge base. Chunks the content, generates embeddings, and stores them. 
    async def ingest(self, request: IngestRequest) -> IngestResult:
        """
        Ingest content into the knowledge base.

        Chunks the content, generates embeddings, and stores them.

        Args:
            request: Ingest request with content and options

        Returns:
            Ingest result with created chunk IDs
        """
        try:
            # Chunk the content
            chunks = self.chunker_factory.chunk_content(
                content=request.content,
                source_path=request.source_path,
                file_type=request.file_type,
                chunk_type=request.chunk_type,
                metadata=request.metadata,
            )

            if not chunks:
                return IngestResult(
                    success=True,
                    chunks_created=0,
                    embeddings_generated=0,
                    source_path=request.source_path,
                    collection=request.collection,
                    chunk_ids=[],
                )

            # Extract chunk contents for embedding
            chunk_texts = [chunk.content for chunk in chunks]

            # Generate embeddings
            embeddings_list = await self.embeddings.generate_batch(
                texts=chunk_texts,
                project_id=request.project_id,
                agent_id=request.agent_id,
            )

            # Store embeddings
            chunk_ids: list[str] = []
            for chunk, embedding in zip(chunks, embeddings_list, strict=True):
                # Build metadata with chunk info
                chunk_metadata = {
                    **request.metadata,
                    **chunk.metadata,
                    "token_count": chunk.token_count,
                }

                chunk_id = await self.database.store_embedding(
                    project_id=request.project_id,
                    collection=request.collection,
                    content=chunk.content,
                    embedding=embedding,
                    chunk_type=chunk.chunk_type,
                    source_path=chunk.source_path or request.source_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    file_type=chunk.file_type or request.file_type,
                    metadata=chunk_metadata,
                )
                chunk_ids.append(chunk_id)

            logger.info(
                f"Ingested {len(chunks)} chunks into collection '{request.collection}' "
                f"for project {request.project_id}"
            )

            return IngestResult(
                success=True,
                chunks_created=len(chunks),
                embeddings_generated=len(embeddings_list),
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=chunk_ids,
            )

        except Exception as e:
            logger.error(f"Ingest error: {e}")
            return IngestResult(
                success=False,
                chunks_created=0,
                embeddings_generated=0,
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=[],
                error=str(e),
            )

    async def delete(self, request: DeleteRequest) -> DeleteResult:
        """
        Delete content from the knowledge base.

        Supports deletion by source path, collection, or chunk IDs.

        Args:
            request: Delete request with target specification

        Returns:
            Delete result with count of deleted chunks
        """
        try:
            deleted_count = 0

            if request.chunk_ids:
                # Delete specific chunks
                deleted_count = await self.database.delete_by_ids(
                    project_id=request.project_id,
                    chunk_ids=request.chunk_ids,
                )
            elif request.source_path:
                # Delete by source path
                deleted_count = await self.database.delete_by_source(
                    project_id=request.project_id,
                    source_path=request.source_path,
                    collection=request.collection,
                )
            elif request.collection:
                # Delete entire collection
                deleted_count = await self.database.delete_collection(
                    project_id=request.project_id,
                    collection=request.collection,
                )
            else:
                return DeleteResult(
                    success=False,
                    chunks_deleted=0,
                    error="Must specify chunk_ids, source_path, or collection",
                )

            logger.info(
                f"Deleted {deleted_count} chunks for project {request.project_id}"
            )

            return DeleteResult(
                success=True,
                chunks_deleted=deleted_count,
            )

        except Exception as e:
            logger.error(f"Delete error: {e}")
            return DeleteResult(
                success=False,
                chunks_deleted=0,
                error=str(e),
            )
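
    # Deletion precedence sketch (mirrors the branches above): chunk_ids
    # takes priority over source_path, which takes priority over collection.
    # A request carrying both a source_path and a collection therefore
    # removes only that source's chunks within the collection, not the whole
    # collection. The ID and path below are placeholders:
    #
    #   result = await manager.delete(
    #       DeleteRequest(
    #           project_id="proj-123",
    #           source_path="docs/guide.md",
    #           collection="default",
    #       )
    #   )
    #   assert result.success
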
    async def list_collections(self, project_id: str) -> ListCollectionsResponse:
        """
        List all collections for a project.

        Args:
            project_id: Project ID

        Returns:
            List of collection info
        """
        collections = await self.database.list_collections(project_id)

        return ListCollectionsResponse(
            project_id=project_id,
            collections=collections,
            total_collections=len(collections),
        )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """
        Get statistics for a collection.

        Args:
            project_id: Project ID
            collection: Collection name

        Returns:
            Collection statistics
        """
        return await self.database.get_collection_stats(project_id, collection)

    async def update_document(
        self,
        project_id: str,
        agent_id: str,
        source_path: str,
        content: str,
        collection: str = "default",
        chunk_type: ChunkType = ChunkType.TEXT,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> IngestResult:
        """
        Update a document by atomically replacing existing chunks.

        Uses a database transaction to delete existing chunks and insert
        new ones atomically, preventing race conditions during concurrent
        updates.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            source_path: Source file path
            content: New content
            collection: Collection name
            chunk_type: Type of content
            file_type: File type for code chunking
            metadata: Additional metadata

        Returns:
            Ingest result
        """
        request_metadata = metadata or {}

        # Chunk the content
        chunks = self.chunker_factory.chunk_content(
            content=content,
            source_path=source_path,
            file_type=file_type,
            chunk_type=chunk_type,
            metadata=request_metadata,
        )

        if not chunks:
            # No chunks = delete existing and return empty result
            await self.database.delete_by_source(
                project_id=project_id,
                source_path=source_path,
                collection=collection,
            )
            return IngestResult(
                success=True,
                chunks_created=0,
                embeddings_generated=0,
                source_path=source_path,
                collection=collection,
                chunk_ids=[],
            )

        # Generate embeddings for new chunks
        chunk_texts = [chunk.content for chunk in chunks]
        embeddings_list = await self.embeddings.generate_batch(
            texts=chunk_texts,
            project_id=project_id,
            agent_id=agent_id,
        )

        # Build embeddings data for transactional replace
        embeddings_data = []
        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
            chunk_metadata = {
                **request_metadata,
                **chunk.metadata,
                "token_count": chunk.token_count,
                "source_path": chunk.source_path or source_path,
                "start_line": chunk.start_line,
                "end_line": chunk.end_line,
                "file_type": effective_file_type.value
                if (effective_file_type := chunk.file_type or file_type)
                else None,
            }
            embeddings_data.append((
                chunk.content,
                embedding,
                chunk.chunk_type,
                chunk_metadata,
            ))

        # Atomically replace old embeddings with new ones
        _, chunk_ids = await self.database.replace_source_embeddings(
            project_id=project_id,
            source_path=source_path,
            collection=collection,
            embeddings=embeddings_data,
        )

        return IngestResult(
            success=True,
            chunks_created=len(chunk_ids),
            embeddings_generated=len(embeddings_list),
            source_path=source_path,
            collection=collection,
            chunk_ids=chunk_ids,
        )
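
    # Update sketch: re-submitting content for a known source_path swaps its
    # chunks in one transaction; content that chunks to nothing instead
    # deletes the source's existing chunks (see the early return above).
    # The IDs and path are placeholders:
    #
    #   result = await manager.update_document(
    #       project_id="proj-123",
    #       agent_id="agent-1",
    #       source_path="docs/guide.md",
    #       content=new_text,
    #       collection="default",
    #   )
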
    async def cleanup_expired(self) -> int:
        """
        Remove expired embeddings from all collections.

        Returns:
            Number of embeddings removed
        """
        return await self.database.cleanup_expired()


# Global collection manager instance (lazy initialization)
_collection_manager: CollectionManager | None = None


def get_collection_manager() -> CollectionManager:
    """Get the global collection manager instance."""
    global _collection_manager
    if _collection_manager is None:
        _collection_manager = CollectionManager()
    return _collection_manager


def reset_collection_manager() -> None:
    """Reset the global collection manager (for testing)."""
    global _collection_manager
    _collection_manager = None
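

# Usage sketch: a typical ingest-then-list round trip through the global
# manager. Only field names already used in this module appear below; the
# IDs, path, and content are placeholders:
#
#   manager = get_collection_manager()
#   result = await manager.ingest(
#       IngestRequest(
#           project_id="proj-123",
#           agent_id="agent-1",
#           content="# Guide\n\nSome documentation text.",
#           source_path="docs/guide.md",
#           collection="default",
#           chunk_type=ChunkType.TEXT,
#           metadata={},
#       )
#   )
#   if result.success:
#       listing = await manager.list_collections("proj-123")
#       print(listing.total_collections, result.chunk_ids)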