""" Database management for Knowledge Base MCP Server. Handles PostgreSQL connections with pgvector extension for vector similarity search operations. """ import hashlib import logging import uuid from contextlib import asynccontextmanager from datetime import UTC, datetime, timedelta from typing import Any import asyncpg from pgvector.asyncpg import register_vector from config import Settings, get_settings from exceptions import ( CollectionNotFoundError, DatabaseConnectionError, DatabaseQueryError, ErrorCode, KnowledgeBaseError, ) from models import ( ChunkType, CollectionInfo, CollectionStats, FileType, KnowledgeEmbedding, ) logger = logging.getLogger(__name__) class DatabaseManager: """ Manages PostgreSQL connections and vector operations. Uses asyncpg for async operations and pgvector for vector similarity search. """ def __init__(self, settings: Settings | None = None) -> None: """Initialize database manager.""" self._settings = settings or get_settings() self._pool: asyncpg.Pool | None = None # type: ignore[type-arg] @property def pool(self) -> asyncpg.Pool: # type: ignore[type-arg] """Get connection pool, raising if not initialized.""" if self._pool is None: raise DatabaseConnectionError("Database pool not initialized") return self._pool async def initialize(self) -> None: """Initialize connection pool and create schema.""" try: self._pool = await asyncpg.create_pool( self._settings.database_url, min_size=2, max_size=self._settings.database_pool_size, max_inactive_connection_lifetime=300, init=self._init_connection, ) logger.info("Database pool created successfully") # Create schema await self._create_schema() logger.info("Database schema initialized") except Exception as e: logger.error(f"Failed to initialize database: {e}") raise DatabaseConnectionError( message=f"Failed to initialize database: {e}", cause=e, ) async def _init_connection(self, conn: asyncpg.Connection) -> None: # type: ignore[type-arg] """Initialize a connection with pgvector support.""" await register_vector(conn) async def _create_schema(self) -> None: """Create database schema if not exists.""" async with self.pool.acquire() as conn: # Enable pgvector extension await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") # Create main embeddings table await conn.execute(""" CREATE TABLE IF NOT EXISTS knowledge_embeddings ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), project_id VARCHAR(255) NOT NULL, collection VARCHAR(255) NOT NULL DEFAULT 'default', content TEXT NOT NULL, embedding vector(1536), chunk_type VARCHAR(50) NOT NULL, source_path TEXT, start_line INTEGER, end_line INTEGER, file_type VARCHAR(50), metadata JSONB DEFAULT '{}', content_hash VARCHAR(64), created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW(), expires_at TIMESTAMPTZ ) """) # Create indexes for common queries await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection ON knowledge_embeddings(project_id, collection) """) await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_embeddings_source_path ON knowledge_embeddings(project_id, source_path) """) await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash ON knowledge_embeddings(project_id, content_hash) """) await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type ON knowledge_embeddings(project_id, chunk_type) """) # Create HNSW index for vector similarity search # This dramatically improves search performance await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw ON knowledge_embeddings USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64) """) # Create GIN index for full-text search await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts ON knowledge_embeddings USING gin(to_tsvector('english', content)) """) async def close(self) -> None: """Close the connection pool.""" if self._pool: await self._pool.close() self._pool = None logger.info("Database pool closed") @asynccontextmanager async def acquire(self) -> Any: """Acquire a connection from the pool.""" async with self.pool.acquire() as conn: yield conn @staticmethod def compute_content_hash(content: str) -> str: """Compute SHA-256 hash of content for deduplication.""" return hashlib.sha256(content.encode()).hexdigest() async def store_embedding( self, project_id: str, collection: str, content: str, embedding: list[float], chunk_type: ChunkType, source_path: str | None = None, start_line: int | None = None, end_line: int | None = None, file_type: FileType | None = None, metadata: dict[str, Any] | None = None, ) -> str: """ Store an embedding in the database. Returns: The ID of the stored embedding. """ content_hash = self.compute_content_hash(content) metadata = metadata or {} # Calculate expiration if TTL is set expires_at = None if self._settings.embedding_ttl_days > 0: expires_at = datetime.now(UTC) + timedelta( days=self._settings.embedding_ttl_days ) try: async with self.acquire() as conn: # Check for duplicate content existing = await conn.fetchval( """ SELECT id FROM knowledge_embeddings WHERE project_id = $1 AND collection = $2 AND content_hash = $3 """, project_id, collection, content_hash, ) if existing: # Update existing embedding await conn.execute( """ UPDATE knowledge_embeddings SET embedding = $1, updated_at = NOW(), expires_at = $2, metadata = $3, source_path = $4, start_line = $5, end_line = $6, file_type = $7 WHERE id = $8 """, embedding, expires_at, metadata, source_path, start_line, end_line, file_type.value if file_type else None, existing, ) logger.debug(f"Updated existing embedding: {existing}") return str(existing) # Insert new embedding embedding_id = await conn.fetchval( """ INSERT INTO knowledge_embeddings (project_id, collection, content, embedding, chunk_type, source_path, start_line, end_line, file_type, metadata, content_hash, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id """, project_id, collection, content, embedding, chunk_type.value, source_path, start_line, end_line, file_type.value if file_type else None, metadata, content_hash, expires_at, ) logger.debug(f"Stored new embedding: {embedding_id}") return str(embedding_id) except asyncpg.PostgresError as e: logger.error(f"Database error storing embedding: {e}") raise DatabaseQueryError( message=f"Failed to store embedding: {e}", cause=e, ) async def store_embeddings_batch( self, embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]], ) -> list[str]: """ Store multiple embeddings in a batch. Args: embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata) Returns: List of created embedding IDs. """ if not embeddings: return [] ids = [] expires_at = None if self._settings.embedding_ttl_days > 0: expires_at = datetime.now(UTC) + timedelta( days=self._settings.embedding_ttl_days ) try: async with self.acquire() as conn: # Wrap in transaction for all-or-nothing batch semantics async with conn.transaction(): for project_id, collection, content, embedding, chunk_type, metadata in embeddings: content_hash = self.compute_content_hash(content) source_path = metadata.get("source_path") start_line = metadata.get("start_line") end_line = metadata.get("end_line") file_type = metadata.get("file_type") embedding_id = await conn.fetchval( """ INSERT INTO knowledge_embeddings (project_id, collection, content, embedding, chunk_type, source_path, start_line, end_line, file_type, metadata, content_hash, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT DO NOTHING RETURNING id """, project_id, collection, content, embedding, chunk_type.value, source_path, start_line, end_line, file_type, metadata, content_hash, expires_at, ) if embedding_id: ids.append(str(embedding_id)) logger.info(f"Stored {len(ids)} embeddings in batch") return ids except asyncpg.PostgresError as e: logger.error(f"Database error in batch store: {e}") raise DatabaseQueryError( message=f"Failed to store embeddings batch: {e}", cause=e, ) async def semantic_search( self, project_id: str, query_embedding: list[float], collection: str | None = None, limit: int = 10, threshold: float = 0.7, file_types: list[FileType] | None = None, ) -> list[tuple[KnowledgeEmbedding, float]]: """ Perform semantic (vector) search. Returns: List of (embedding, similarity_score) tuples. """ try: async with self.acquire() as conn: # Build query with optional filters using CTE to filter by similarity # We use a CTE to compute similarity once, then filter in outer query inner_query = """ SELECT id, project_id, collection, content, embedding, chunk_type, source_path, start_line, end_line, file_type, metadata, content_hash, created_at, updated_at, expires_at, 1 - (embedding <=> $1) as similarity FROM knowledge_embeddings WHERE project_id = $2 AND (expires_at IS NULL OR expires_at > NOW()) """ params: list[Any] = [query_embedding, project_id] param_idx = 3 if collection: inner_query += f" AND collection = ${param_idx}" params.append(collection) param_idx += 1 if file_types: file_type_values = [ft.value for ft in file_types] inner_query += f" AND file_type = ANY(${param_idx})" params.append(file_type_values) param_idx += 1 # Wrap in CTE and filter by threshold in outer query query = f""" WITH scored AS ({inner_query}) SELECT * FROM scored WHERE similarity >= ${param_idx} ORDER BY similarity DESC LIMIT ${param_idx + 1} """ params.extend([threshold, limit]) rows = await conn.fetch(query, *params) results = [] for row in rows: embedding = KnowledgeEmbedding( id=str(row["id"]), project_id=row["project_id"], collection=row["collection"], content=row["content"], embedding=list(row["embedding"]), chunk_type=ChunkType(row["chunk_type"]), source_path=row["source_path"], start_line=row["start_line"], end_line=row["end_line"], file_type=FileType(row["file_type"]) if row["file_type"] else None, metadata=row["metadata"] or {}, content_hash=row["content_hash"], created_at=row["created_at"], updated_at=row["updated_at"], expires_at=row["expires_at"], ) results.append((embedding, float(row["similarity"]))) return results except asyncpg.PostgresError as e: logger.error(f"Semantic search error: {e}") raise DatabaseQueryError( message=f"Semantic search failed: {e}", cause=e, ) async def keyword_search( self, project_id: str, query: str, collection: str | None = None, limit: int = 10, file_types: list[FileType] | None = None, ) -> list[tuple[KnowledgeEmbedding, float]]: """ Perform full-text keyword search. Returns: List of (embedding, relevance_score) tuples. """ try: async with self.acquire() as conn: # Build query with optional filters sql = """ SELECT id, project_id, collection, content, embedding, chunk_type, source_path, start_line, end_line, file_type, metadata, content_hash, created_at, updated_at, expires_at, ts_rank(to_tsvector('english', content), plainto_tsquery('english', $1)) as relevance FROM knowledge_embeddings WHERE project_id = $2 AND to_tsvector('english', content) @@ plainto_tsquery('english', $1) AND (expires_at IS NULL OR expires_at > NOW()) """ params: list[Any] = [query, project_id] param_idx = 3 if collection: sql += f" AND collection = ${param_idx}" params.append(collection) param_idx += 1 if file_types: file_type_values = [ft.value for ft in file_types] sql += f" AND file_type = ANY(${param_idx})" params.append(file_type_values) param_idx += 1 sql += f" ORDER BY relevance DESC LIMIT ${param_idx}" params.append(limit) rows = await conn.fetch(sql, *params) results = [] for row in rows: embedding = KnowledgeEmbedding( id=str(row["id"]), project_id=row["project_id"], collection=row["collection"], content=row["content"], embedding=list(row["embedding"]) if row["embedding"] else [], chunk_type=ChunkType(row["chunk_type"]), source_path=row["source_path"], start_line=row["start_line"], end_line=row["end_line"], file_type=FileType(row["file_type"]) if row["file_type"] else None, metadata=row["metadata"] or {}, content_hash=row["content_hash"], created_at=row["created_at"], updated_at=row["updated_at"], expires_at=row["expires_at"], ) # Normalize relevance to 0-1 scale (approximate) normalized_score = min(1.0, float(row["relevance"])) results.append((embedding, normalized_score)) return results except asyncpg.PostgresError as e: logger.error(f"Keyword search error: {e}") raise DatabaseQueryError( message=f"Keyword search failed: {e}", cause=e, ) async def delete_by_source( self, project_id: str, source_path: str, collection: str | None = None, ) -> int: """Delete all embeddings for a source path.""" try: async with self.acquire() as conn: if collection: result = await conn.execute( """ DELETE FROM knowledge_embeddings WHERE project_id = $1 AND source_path = $2 AND collection = $3 """, project_id, source_path, collection, ) else: result = await conn.execute( """ DELETE FROM knowledge_embeddings WHERE project_id = $1 AND source_path = $2 """, project_id, source_path, ) # Parse "DELETE N" result count = int(result.split()[-1]) logger.info(f"Deleted {count} embeddings for source: {source_path}") return count except asyncpg.PostgresError as e: logger.error(f"Delete error: {e}") raise DatabaseQueryError( message=f"Failed to delete by source: {e}", cause=e, ) async def replace_source_embeddings( self, project_id: str, source_path: str, collection: str, embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]], ) -> tuple[int, list[str]]: """ Atomically replace all embeddings for a source path. Deletes existing embeddings and inserts new ones in a single transaction, preventing race conditions during document updates. Args: project_id: Project ID source_path: Source file path being updated collection: Collection name embeddings: List of (content, embedding, chunk_type, metadata) Returns: Tuple of (deleted_count, new_embedding_ids) """ expires_at = None if self._settings.embedding_ttl_days > 0: expires_at = datetime.now(UTC) + timedelta( days=self._settings.embedding_ttl_days ) try: async with self.acquire() as conn: # Use transaction for atomic replace async with conn.transaction(): # First, delete existing embeddings for this source delete_result = await conn.execute( """ DELETE FROM knowledge_embeddings WHERE project_id = $1 AND source_path = $2 AND collection = $3 """, project_id, source_path, collection, ) deleted_count = int(delete_result.split()[-1]) # Then insert new embeddings new_ids = [] for content, embedding, chunk_type, metadata in embeddings: content_hash = self.compute_content_hash(content) start_line = metadata.get("start_line") end_line = metadata.get("end_line") file_type = metadata.get("file_type") embedding_id = await conn.fetchval( """ INSERT INTO knowledge_embeddings (project_id, collection, content, embedding, chunk_type, source_path, start_line, end_line, file_type, metadata, content_hash, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id """, project_id, collection, content, embedding, chunk_type.value, source_path, start_line, end_line, file_type, metadata, content_hash, expires_at, ) if embedding_id: new_ids.append(str(embedding_id)) logger.info( f"Replaced source {source_path}: deleted {deleted_count}, " f"inserted {len(new_ids)} embeddings" ) return deleted_count, new_ids except asyncpg.PostgresError as e: logger.error(f"Replace source error: {e}") raise DatabaseQueryError( message=f"Failed to replace source embeddings: {e}", cause=e, ) async def delete_collection( self, project_id: str, collection: str, ) -> int: """Delete an entire collection.""" try: async with self.acquire() as conn: result = await conn.execute( """ DELETE FROM knowledge_embeddings WHERE project_id = $1 AND collection = $2 """, project_id, collection, ) count = int(result.split()[-1]) logger.info(f"Deleted collection {collection}: {count} embeddings") return count except asyncpg.PostgresError as e: logger.error(f"Delete collection error: {e}") raise DatabaseQueryError( message=f"Failed to delete collection: {e}", cause=e, ) async def delete_by_ids( self, project_id: str, chunk_ids: list[str], ) -> int: """Delete specific embeddings by ID.""" if not chunk_ids: return 0 try: # Convert string IDs to UUIDs uuids = [uuid.UUID(cid) for cid in chunk_ids] async with self.acquire() as conn: result = await conn.execute( """ DELETE FROM knowledge_embeddings WHERE project_id = $1 AND id = ANY($2) """, project_id, uuids, ) count = int(result.split()[-1]) logger.info(f"Deleted {count} embeddings by ID") return count except ValueError as e: raise KnowledgeBaseError( message=f"Invalid chunk ID format: {e}", code=ErrorCode.INVALID_REQUEST, ) except asyncpg.PostgresError as e: logger.error(f"Delete by IDs error: {e}") raise DatabaseQueryError( message=f"Failed to delete by IDs: {e}", cause=e, ) async def list_collections( self, project_id: str, ) -> list[CollectionInfo]: """List all collections for a project.""" try: async with self.acquire() as conn: rows = await conn.fetch( """ SELECT collection, COUNT(*) as chunk_count, COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens, ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types, MIN(created_at) as created_at, MAX(updated_at) as updated_at FROM knowledge_embeddings WHERE project_id = $1 AND (expires_at IS NULL OR expires_at > NOW()) GROUP BY collection ORDER BY collection """, project_id, ) return [ CollectionInfo( name=row["collection"], project_id=project_id, chunk_count=row["chunk_count"], total_tokens=row["total_tokens"] or 0, file_types=row["file_types"] or [], created_at=row["created_at"], updated_at=row["updated_at"], ) for row in rows ] except asyncpg.PostgresError as e: logger.error(f"List collections error: {e}") raise DatabaseQueryError( message=f"Failed to list collections: {e}", cause=e, ) async def get_collection_stats( self, project_id: str, collection: str, ) -> CollectionStats: """Get detailed statistics for a collection.""" try: async with self.acquire() as conn: # Check if collection exists exists = await conn.fetchval( """ SELECT EXISTS( SELECT 1 FROM knowledge_embeddings WHERE project_id = $1 AND collection = $2 ) """, project_id, collection, ) if not exists: raise CollectionNotFoundError(collection, project_id) # Get stats row = await conn.fetchrow( """ SELECT COUNT(*) as chunk_count, COUNT(DISTINCT source_path) as unique_sources, COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens, COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size, MIN(created_at) as oldest_chunk, MAX(created_at) as newest_chunk FROM knowledge_embeddings WHERE project_id = $1 AND collection = $2 AND (expires_at IS NULL OR expires_at > NOW()) """, project_id, collection, ) # Get chunk type breakdown chunk_rows = await conn.fetch( """ SELECT chunk_type, COUNT(*) as count FROM knowledge_embeddings WHERE project_id = $1 AND collection = $2 AND (expires_at IS NULL OR expires_at > NOW()) GROUP BY chunk_type """, project_id, collection, ) chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows} # Get file type breakdown file_rows = await conn.fetch( """ SELECT file_type, COUNT(*) as count FROM knowledge_embeddings WHERE project_id = $1 AND collection = $2 AND file_type IS NOT NULL AND (expires_at IS NULL OR expires_at > NOW()) GROUP BY file_type """, project_id, collection, ) file_types = {r["file_type"]: r["count"] for r in file_rows} return CollectionStats( collection=collection, project_id=project_id, chunk_count=row["chunk_count"], unique_sources=row["unique_sources"], total_tokens=row["total_tokens"] or 0, avg_chunk_size=float(row["avg_chunk_size"] or 0), chunk_types=chunk_types, file_types=file_types, oldest_chunk=row["oldest_chunk"], newest_chunk=row["newest_chunk"], ) except CollectionNotFoundError: raise except asyncpg.PostgresError as e: logger.error(f"Get collection stats error: {e}") raise DatabaseQueryError( message=f"Failed to get collection stats: {e}", cause=e, ) async def cleanup_expired(self) -> int: """Remove expired embeddings.""" try: async with self.acquire() as conn: result = await conn.execute( """ DELETE FROM knowledge_embeddings WHERE expires_at IS NOT NULL AND expires_at < NOW() """ ) count = int(result.split()[-1]) if count > 0: logger.info(f"Cleaned up {count} expired embeddings") return count except asyncpg.PostgresError as e: logger.error(f"Cleanup error: {e}") raise DatabaseQueryError( message=f"Failed to cleanup expired: {e}", cause=e, ) # Global database manager instance (lazy initialization) _db_manager: DatabaseManager | None = None def get_database_manager() -> DatabaseManager: """Get the global database manager instance.""" global _db_manager if _db_manager is None: _db_manager = DatabaseManager() return _db_manager def reset_database_manager() -> None: """Reset the global database manager (for testing).""" global _db_manager _db_manager = None