syndarix/mcp-servers/knowledge-base/database.py
commit cd7a9ccbdf by Felipe Cardoso (2026-01-04 01:07:40 +01:00)
fix(mcp-kb): add transactional batch insert and atomic document update
- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

"""
Database management for Knowledge Base MCP Server.
Handles PostgreSQL connections with pgvector extension for
vector similarity search operations.
"""
import hashlib
import json
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from typing import Any
import asyncpg
from pgvector.asyncpg import register_vector
from config import Settings, get_settings
from exceptions import (
CollectionNotFoundError,
DatabaseConnectionError,
DatabaseQueryError,
ErrorCode,
KnowledgeBaseError,
)
from models import (
ChunkType,
CollectionInfo,
CollectionStats,
FileType,
KnowledgeEmbedding,
)
logger = logging.getLogger(__name__)
class DatabaseManager:
"""
Manages PostgreSQL connections and vector operations.
Uses asyncpg for async operations and pgvector for
vector similarity search.
"""
def __init__(self, settings: Settings | None = None) -> None:
"""Initialize database manager."""
self._settings = settings or get_settings()
self._pool: asyncpg.Pool | None = None # type: ignore[type-arg]
@property
def pool(self) -> asyncpg.Pool: # type: ignore[type-arg]
"""Get connection pool, raising if not initialized."""
if self._pool is None:
raise DatabaseConnectionError("Database pool not initialized")
return self._pool
async def initialize(self) -> None:
"""Initialize connection pool and create schema."""
try:
self._pool = await asyncpg.create_pool(
self._settings.database_url,
min_size=2,
max_size=self._settings.database_pool_size,
max_inactive_connection_lifetime=300,
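                # max_inactive_connection_lifetime is in seconds; idle connections are closed after five minutes.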
init=self._init_connection,
)
logger.info("Database pool created successfully")
# Create schema
await self._create_schema()
logger.info("Database schema initialized")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise DatabaseConnectionError(
message=f"Failed to initialize database: {e}",
cause=e,
)
    async def _init_connection(self, conn: asyncpg.Connection) -> None:  # type: ignore[type-arg]
        """Initialize a connection with pgvector and JSONB dict support."""
        await register_vector(conn)
        # asyncpg sends and returns json/jsonb as text by default; register a
        # codec so the dict metadata used throughout this module round-trips.
        await conn.set_type_codec("jsonb", encoder=json.dumps, decoder=json.loads, schema="pg_catalog")
async def _create_schema(self) -> None:
"""Create database schema if not exists."""
async with self.pool.acquire() as conn:
# Enable pgvector extension
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
# Create main embeddings table
await conn.execute("""
CREATE TABLE IF NOT EXISTS knowledge_embeddings (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
project_id VARCHAR(255) NOT NULL,
collection VARCHAR(255) NOT NULL DEFAULT 'default',
content TEXT NOT NULL,
embedding vector(1536),
chunk_type VARCHAR(50) NOT NULL,
source_path TEXT,
start_line INTEGER,
end_line INTEGER,
file_type VARCHAR(50),
metadata JSONB DEFAULT '{}',
content_hash VARCHAR(64),
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ
)
""")
# Create indexes for common queries
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection
ON knowledge_embeddings(project_id, collection)
""")
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_embeddings_source_path
ON knowledge_embeddings(project_id, source_path)
""")
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash
ON knowledge_embeddings(project_id, content_hash)
""")
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type
ON knowledge_embeddings(project_id, chunk_type)
""")
# Create HNSW index for vector similarity search
# This dramatically improves search performance
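            # m (links per node) and ef_construction (build-time candidate list size)
            # are pgvector's defaults; larger values raise recall at the cost of
            # index build time and memory.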
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw
ON knowledge_embeddings
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64)
""")
# Create GIN index for full-text search
await conn.execute("""
CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts
ON knowledge_embeddings
USING gin(to_tsvector('english', content))
""")
async def close(self) -> None:
"""Close the connection pool."""
if self._pool:
await self._pool.close()
self._pool = None
logger.info("Database pool closed")
@asynccontextmanager
async def acquire(self) -> Any:
"""Acquire a connection from the pool."""
async with self.pool.acquire() as conn:
yield conn
@staticmethod
def compute_content_hash(content: str) -> str:
"""Compute SHA-256 hash of content for deduplication."""
return hashlib.sha256(content.encode()).hexdigest()
async def store_embedding(
self,
project_id: str,
collection: str,
content: str,
embedding: list[float],
chunk_type: ChunkType,
source_path: str | None = None,
start_line: int | None = None,
end_line: int | None = None,
file_type: FileType | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Store an embedding in the database.
Returns:
The ID of the stored embedding.
"""
content_hash = self.compute_content_hash(content)
metadata = metadata or {}
# Calculate expiration if TTL is set
expires_at = None
if self._settings.embedding_ttl_days > 0:
expires_at = datetime.now(UTC) + timedelta(
days=self._settings.embedding_ttl_days
)
try:
async with self.acquire() as conn:
# Check for duplicate content
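                # A row with the same (project_id, collection, content_hash) means this
                # exact chunk text is already stored; refresh its embedding, metadata,
                # and location fields instead of inserting a duplicate.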
existing = await conn.fetchval(
"""
SELECT id FROM knowledge_embeddings
WHERE project_id = $1 AND collection = $2 AND content_hash = $3
""",
project_id,
collection,
content_hash,
)
if existing:
# Update existing embedding
await conn.execute(
"""
UPDATE knowledge_embeddings
SET embedding = $1, updated_at = NOW(), expires_at = $2,
metadata = $3, source_path = $4, start_line = $5,
end_line = $6, file_type = $7
WHERE id = $8
""",
embedding,
expires_at,
metadata,
source_path,
start_line,
end_line,
file_type.value if file_type else None,
existing,
)
logger.debug(f"Updated existing embedding: {existing}")
return str(existing)
# Insert new embedding
embedding_id = await conn.fetchval(
"""
INSERT INTO knowledge_embeddings
(project_id, collection, content, embedding, chunk_type,
source_path, start_line, end_line, file_type, metadata,
content_hash, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING id
""",
project_id,
collection,
content,
embedding,
chunk_type.value,
source_path,
start_line,
end_line,
file_type.value if file_type else None,
metadata,
content_hash,
expires_at,
)
logger.debug(f"Stored new embedding: {embedding_id}")
return str(embedding_id)
except asyncpg.PostgresError as e:
logger.error(f"Database error storing embedding: {e}")
raise DatabaseQueryError(
message=f"Failed to store embedding: {e}",
cause=e,
)
async def store_embeddings_batch(
self,
embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]],
) -> list[str]:
"""
Store multiple embeddings in a batch.
Args:
embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata)
Returns:
List of created embedding IDs.
"""
if not embeddings:
return []
ids = []
expires_at = None
if self._settings.embedding_ttl_days > 0:
expires_at = datetime.now(UTC) + timedelta(
days=self._settings.embedding_ttl_days
)
try:
async with self.acquire() as conn:
# Wrap in transaction for all-or-nothing batch semantics
async with conn.transaction():
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
content_hash = self.compute_content_hash(content)
source_path = metadata.get("source_path")
start_line = metadata.get("start_line")
end_line = metadata.get("end_line")
file_type = metadata.get("file_type")
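                        # The schema defines no unique constraint on
                        # (project_id, collection, content_hash), so ON CONFLICT DO NOTHING
                        # below only guards against primary-key collisions; duplicate
                        # content within a batch is not filtered here.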
embedding_id = await conn.fetchval(
"""
INSERT INTO knowledge_embeddings
(project_id, collection, content, embedding, chunk_type,
source_path, start_line, end_line, file_type, metadata,
content_hash, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT DO NOTHING
RETURNING id
""",
project_id,
collection,
content,
embedding,
chunk_type.value,
source_path,
start_line,
end_line,
file_type,
metadata,
content_hash,
expires_at,
)
if embedding_id:
ids.append(str(embedding_id))
logger.info(f"Stored {len(ids)} embeddings in batch")
return ids
except asyncpg.PostgresError as e:
logger.error(f"Database error in batch store: {e}")
raise DatabaseQueryError(
message=f"Failed to store embeddings batch: {e}",
cause=e,
)
async def semantic_search(
self,
project_id: str,
query_embedding: list[float],
collection: str | None = None,
limit: int = 10,
threshold: float = 0.7,
file_types: list[FileType] | None = None,
) -> list[tuple[KnowledgeEmbedding, float]]:
"""
Perform semantic (vector) search.
Returns:
List of (embedding, similarity_score) tuples.
"""
try:
async with self.acquire() as conn:
# Build query with optional filters using CTE to filter by similarity
# We use a CTE to compute similarity once, then filter in outer query
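                # pgvector's <=> operator is cosine distance, so 1 - (embedding <=> $1)
                # yields cosine similarity.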
inner_query = """
SELECT
id, project_id, collection, content, embedding,
chunk_type, source_path, start_line, end_line,
file_type, metadata, content_hash, created_at,
updated_at, expires_at,
1 - (embedding <=> $1) as similarity
FROM knowledge_embeddings
WHERE project_id = $2
AND (expires_at IS NULL OR expires_at > NOW())
"""
params: list[Any] = [query_embedding, project_id]
param_idx = 3
if collection:
inner_query += f" AND collection = ${param_idx}"
params.append(collection)
param_idx += 1
if file_types:
file_type_values = [ft.value for ft in file_types]
inner_query += f" AND file_type = ANY(${param_idx})"
params.append(file_type_values)
param_idx += 1
# Wrap in CTE and filter by threshold in outer query
query = f"""
WITH scored AS ({inner_query})
SELECT * FROM scored
WHERE similarity >= ${param_idx}
ORDER BY similarity DESC
LIMIT ${param_idx + 1}
"""
params.extend([threshold, limit])
rows = await conn.fetch(query, *params)
results = []
for row in rows:
embedding = KnowledgeEmbedding(
id=str(row["id"]),
project_id=row["project_id"],
collection=row["collection"],
content=row["content"],
embedding=list(row["embedding"]),
chunk_type=ChunkType(row["chunk_type"]),
source_path=row["source_path"],
start_line=row["start_line"],
end_line=row["end_line"],
file_type=FileType(row["file_type"]) if row["file_type"] else None,
metadata=row["metadata"] or {},
content_hash=row["content_hash"],
created_at=row["created_at"],
updated_at=row["updated_at"],
expires_at=row["expires_at"],
)
results.append((embedding, float(row["similarity"])))
return results
except asyncpg.PostgresError as e:
logger.error(f"Semantic search error: {e}")
raise DatabaseQueryError(
message=f"Semantic search failed: {e}",
cause=e,
)
async def keyword_search(
self,
project_id: str,
query: str,
collection: str | None = None,
limit: int = 10,
file_types: list[FileType] | None = None,
) -> list[tuple[KnowledgeEmbedding, float]]:
"""
Perform full-text keyword search.
Returns:
List of (embedding, relevance_score) tuples.
"""
try:
async with self.acquire() as conn:
# Build query with optional filters
sql = """
SELECT
id, project_id, collection, content, embedding,
chunk_type, source_path, start_line, end_line,
file_type, metadata, content_hash, created_at,
updated_at, expires_at,
ts_rank(to_tsvector('english', content),
plainto_tsquery('english', $1)) as relevance
FROM knowledge_embeddings
WHERE project_id = $2
AND to_tsvector('english', content) @@ plainto_tsquery('english', $1)
AND (expires_at IS NULL OR expires_at > NOW())
"""
params: list[Any] = [query, project_id]
param_idx = 3
if collection:
sql += f" AND collection = ${param_idx}"
params.append(collection)
param_idx += 1
if file_types:
file_type_values = [ft.value for ft in file_types]
sql += f" AND file_type = ANY(${param_idx})"
params.append(file_type_values)
param_idx += 1
sql += f" ORDER BY relevance DESC LIMIT ${param_idx}"
params.append(limit)
rows = await conn.fetch(sql, *params)
results = []
for row in rows:
embedding = KnowledgeEmbedding(
id=str(row["id"]),
project_id=row["project_id"],
collection=row["collection"],
content=row["content"],
embedding=list(row["embedding"]) if row["embedding"] else [],
chunk_type=ChunkType(row["chunk_type"]),
source_path=row["source_path"],
start_line=row["start_line"],
end_line=row["end_line"],
file_type=FileType(row["file_type"]) if row["file_type"] else None,
metadata=row["metadata"] or {},
content_hash=row["content_hash"],
created_at=row["created_at"],
updated_at=row["updated_at"],
expires_at=row["expires_at"],
)
# Normalize relevance to 0-1 scale (approximate)
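                    # ts_rank values have no absolute scale, so clamping keeps them
                    # roughly comparable with the 0-1 cosine similarities returned by
                    # semantic_search.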
normalized_score = min(1.0, float(row["relevance"]))
results.append((embedding, normalized_score))
return results
except asyncpg.PostgresError as e:
logger.error(f"Keyword search error: {e}")
raise DatabaseQueryError(
message=f"Keyword search failed: {e}",
cause=e,
)
async def delete_by_source(
self,
project_id: str,
source_path: str,
collection: str | None = None,
) -> int:
"""Delete all embeddings for a source path."""
try:
async with self.acquire() as conn:
if collection:
result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND source_path = $2 AND collection = $3
""",
project_id,
source_path,
collection,
)
else:
result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND source_path = $2
""",
project_id,
source_path,
)
# Parse "DELETE N" result
count = int(result.split()[-1])
logger.info(f"Deleted {count} embeddings for source: {source_path}")
return count
except asyncpg.PostgresError as e:
logger.error(f"Delete error: {e}")
raise DatabaseQueryError(
message=f"Failed to delete by source: {e}",
cause=e,
)
async def replace_source_embeddings(
self,
project_id: str,
source_path: str,
collection: str,
embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
) -> tuple[int, list[str]]:
"""
Atomically replace all embeddings for a source path.
Deletes existing embeddings and inserts new ones in a single transaction,
preventing race conditions during document updates.
Args:
project_id: Project ID
source_path: Source file path being updated
collection: Collection name
embeddings: List of (content, embedding, chunk_type, metadata)
Returns:
Tuple of (deleted_count, new_embedding_ids)
"""
expires_at = None
if self._settings.embedding_ttl_days > 0:
expires_at = datetime.now(UTC) + timedelta(
days=self._settings.embedding_ttl_days
)
try:
async with self.acquire() as conn:
# Use transaction for atomic replace
async with conn.transaction():
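                    # If any insert below fails, the delete is rolled back as well, so
                    # readers never observe a partially updated document.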
# First, delete existing embeddings for this source
delete_result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND source_path = $2 AND collection = $3
""",
project_id,
source_path,
collection,
)
deleted_count = int(delete_result.split()[-1])
# Then insert new embeddings
new_ids = []
for content, embedding, chunk_type, metadata in embeddings:
content_hash = self.compute_content_hash(content)
start_line = metadata.get("start_line")
end_line = metadata.get("end_line")
file_type = metadata.get("file_type")
embedding_id = await conn.fetchval(
"""
INSERT INTO knowledge_embeddings
(project_id, collection, content, embedding, chunk_type,
source_path, start_line, end_line, file_type, metadata,
content_hash, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING id
""",
project_id,
collection,
content,
embedding,
chunk_type.value,
source_path,
start_line,
end_line,
file_type,
metadata,
content_hash,
expires_at,
)
if embedding_id:
new_ids.append(str(embedding_id))
logger.info(
f"Replaced source {source_path}: deleted {deleted_count}, "
f"inserted {len(new_ids)} embeddings"
)
return deleted_count, new_ids
except asyncpg.PostgresError as e:
logger.error(f"Replace source error: {e}")
raise DatabaseQueryError(
message=f"Failed to replace source embeddings: {e}",
cause=e,
)
async def delete_collection(
self,
project_id: str,
collection: str,
) -> int:
"""Delete an entire collection."""
try:
async with self.acquire() as conn:
result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND collection = $2
""",
project_id,
collection,
)
count = int(result.split()[-1])
logger.info(f"Deleted collection {collection}: {count} embeddings")
return count
except asyncpg.PostgresError as e:
logger.error(f"Delete collection error: {e}")
raise DatabaseQueryError(
message=f"Failed to delete collection: {e}",
cause=e,
)
async def delete_by_ids(
self,
project_id: str,
chunk_ids: list[str],
) -> int:
"""Delete specific embeddings by ID."""
if not chunk_ids:
return 0
try:
# Convert string IDs to UUIDs
uuids = [uuid.UUID(cid) for cid in chunk_ids]
async with self.acquire() as conn:
result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND id = ANY($2)
""",
project_id,
uuids,
)
count = int(result.split()[-1])
logger.info(f"Deleted {count} embeddings by ID")
return count
except ValueError as e:
raise KnowledgeBaseError(
message=f"Invalid chunk ID format: {e}",
code=ErrorCode.INVALID_REQUEST,
)
except asyncpg.PostgresError as e:
logger.error(f"Delete by IDs error: {e}")
raise DatabaseQueryError(
message=f"Failed to delete by IDs: {e}",
cause=e,
)
async def list_collections(
self,
project_id: str,
) -> list[CollectionInfo]:
"""List all collections for a project."""
try:
async with self.acquire() as conn:
rows = await conn.fetch(
"""
SELECT
collection,
COUNT(*) as chunk_count,
COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types,
MIN(created_at) as created_at,
MAX(updated_at) as updated_at
FROM knowledge_embeddings
WHERE project_id = $1
AND (expires_at IS NULL OR expires_at > NOW())
GROUP BY collection
ORDER BY collection
""",
project_id,
)
return [
CollectionInfo(
name=row["collection"],
project_id=project_id,
chunk_count=row["chunk_count"],
total_tokens=row["total_tokens"] or 0,
file_types=row["file_types"] or [],
created_at=row["created_at"],
updated_at=row["updated_at"],
)
for row in rows
]
except asyncpg.PostgresError as e:
logger.error(f"List collections error: {e}")
raise DatabaseQueryError(
message=f"Failed to list collections: {e}",
cause=e,
)
async def get_collection_stats(
self,
project_id: str,
collection: str,
) -> CollectionStats:
"""Get detailed statistics for a collection."""
try:
async with self.acquire() as conn:
# Check if collection exists
exists = await conn.fetchval(
"""
SELECT EXISTS(
SELECT 1 FROM knowledge_embeddings
WHERE project_id = $1 AND collection = $2
)
""",
project_id,
collection,
)
if not exists:
raise CollectionNotFoundError(collection, project_id)
# Get stats
row = await conn.fetchrow(
"""
SELECT
COUNT(*) as chunk_count,
COUNT(DISTINCT source_path) as unique_sources,
COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size,
MIN(created_at) as oldest_chunk,
MAX(created_at) as newest_chunk
FROM knowledge_embeddings
WHERE project_id = $1 AND collection = $2
AND (expires_at IS NULL OR expires_at > NOW())
""",
project_id,
collection,
)
# Get chunk type breakdown
chunk_rows = await conn.fetch(
"""
SELECT chunk_type, COUNT(*) as count
FROM knowledge_embeddings
WHERE project_id = $1 AND collection = $2
AND (expires_at IS NULL OR expires_at > NOW())
GROUP BY chunk_type
""",
project_id,
collection,
)
chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows}
# Get file type breakdown
file_rows = await conn.fetch(
"""
SELECT file_type, COUNT(*) as count
FROM knowledge_embeddings
WHERE project_id = $1 AND collection = $2
AND file_type IS NOT NULL
AND (expires_at IS NULL OR expires_at > NOW())
GROUP BY file_type
""",
project_id,
collection,
)
file_types = {r["file_type"]: r["count"] for r in file_rows}
return CollectionStats(
collection=collection,
project_id=project_id,
chunk_count=row["chunk_count"],
unique_sources=row["unique_sources"],
total_tokens=row["total_tokens"] or 0,
avg_chunk_size=float(row["avg_chunk_size"] or 0),
chunk_types=chunk_types,
file_types=file_types,
oldest_chunk=row["oldest_chunk"],
newest_chunk=row["newest_chunk"],
)
except CollectionNotFoundError:
raise
except asyncpg.PostgresError as e:
logger.error(f"Get collection stats error: {e}")
raise DatabaseQueryError(
message=f"Failed to get collection stats: {e}",
cause=e,
)
async def cleanup_expired(self) -> int:
"""Remove expired embeddings."""
try:
async with self.acquire() as conn:
result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE expires_at IS NOT NULL AND expires_at < NOW()
"""
)
count = int(result.split()[-1])
if count > 0:
logger.info(f"Cleaned up {count} expired embeddings")
return count
except asyncpg.PostgresError as e:
logger.error(f"Cleanup error: {e}")
raise DatabaseQueryError(
message=f"Failed to cleanup expired: {e}",
cause=e,
)
# Global database manager instance (lazy initialization)
_db_manager: DatabaseManager | None = None
def get_database_manager() -> DatabaseManager:
"""Get the global database manager instance."""
global _db_manager
if _db_manager is None:
_db_manager = DatabaseManager()
return _db_manager
def reset_database_manager() -> None:
"""Reset the global database manager (for testing)."""
global _db_manager
_db_manager = None
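
# Illustrative usage sketch (not part of the server runtime; values and the
# ChunkType member below are placeholders, and the calls assume an async context):
#
#     db = get_database_manager()
#     await db.initialize()
#     try:
#         await db.store_embedding(
#             project_id="demo",
#             collection="default",
#             content="def add(a, b): return a + b",
#             embedding=[0.0] * 1536,      # must match vector(1536) in the schema
#             chunk_type=ChunkType.CODE,   # placeholder; use an actual ChunkType member
#         )
#         results = await db.semantic_search(
#             project_id="demo", query_embedding=[0.0] * 1536, limit=5
#         )
#     finally:
#         await db.close()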