diff --git a/mcp-servers/knowledge-base/database.py b/mcp-servers/knowledge-base/database.py
index 41c2da5..551037d 100644
--- a/mcp-servers/knowledge-base/database.py
+++ b/mcp-servers/knowledge-base/database.py
@@ -57,6 +57,9 @@ class DatabaseManager:
     async def initialize(self) -> None:
         """Initialize connection pool and create schema."""
         try:
+            # First, create pgvector extension (required before register_vector in pool init)
+            await self._ensure_pgvector_extension()
+
             self._pool = await asyncpg.create_pool(
                 self._settings.database_url,
                 min_size=2,
@@ -66,7 +69,7 @@ class DatabaseManager:
             )
             logger.info("Database pool created successfully")
 
-            # Create schema
+            # Create schema (tables and indexes)
             await self._create_schema()
             logger.info("Database schema initialized")
 
@@ -77,6 +80,19 @@ class DatabaseManager:
                 cause=e,
             )
 
+    async def _ensure_pgvector_extension(self) -> None:
+        """Ensure pgvector extension exists before pool creation.
+
+        This must run before creating the connection pool because
+        register_vector() in _init_connection requires the extension to exist.
+        """
+        conn = await asyncpg.connect(self._settings.database_url)
+        try:
+            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            logger.info("pgvector extension ensured")
+        finally:
+            await conn.close()
+
     async def _init_connection(self, conn: asyncpg.Connection) -> None:  # type: ignore[type-arg]
         """Initialize a connection with pgvector support."""
         await register_vector(conn)
@@ -84,8 +100,7 @@ class DatabaseManager:
     async def _create_schema(self) -> None:
         """Create database schema if not exists."""
         async with self.pool.acquire() as conn:
-            # Enable pgvector extension
-            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            # Note: pgvector extension is created in _ensure_pgvector_extension()
 
             # Create main embeddings table
             await conn.execute("""
@@ -286,7 +301,14 @@ class DatabaseManager:
         try:
             async with self.acquire() as conn, conn.transaction():
                 # Wrap in transaction for all-or-nothing batch semantics
-                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
+                for (
+                    project_id,
+                    collection,
+                    content,
+                    embedding,
+                    chunk_type,
+                    metadata,
+                ) in embeddings:
                     content_hash = self.compute_content_hash(content)
                     source_path = metadata.get("source_path")
                     start_line = metadata.get("start_line")
@@ -397,7 +419,9 @@ class DatabaseManager:
                     source_path=row["source_path"],
                     start_line=row["start_line"],
                     end_line=row["end_line"],
-                    file_type=FileType(row["file_type"]) if row["file_type"] else None,
+                    file_type=FileType(row["file_type"])
+                    if row["file_type"]
+                    else None,
                     metadata=row["metadata"] or {},
                     content_hash=row["content_hash"],
                     created_at=row["created_at"],
@@ -476,7 +500,9 @@ class DatabaseManager:
                     source_path=row["source_path"],
                     start_line=row["start_line"],
                     end_line=row["end_line"],
-                    file_type=FileType(row["file_type"]) if row["file_type"] else None,
+                    file_type=FileType(row["file_type"])
+                    if row["file_type"]
+                    else None,
                     metadata=row["metadata"] or {},
                     content_hash=row["content_hash"],
                     created_at=row["created_at"],