feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search:

- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document

Testing:

- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
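The hybrid search combines the ranked lists produced by semantic and keyword search with Reciprocal Rank Fusion. The fusion step lives outside database.py (not shown in this diff), so the sketch below is illustrative only; the function name rrf_fuse and the damping constant k=60 are assumptions, not the committed implementation:

def rrf_fuse(
    semantic: list[tuple[str, float]],   # (chunk_id, similarity), already ranked best-first
    keyword: list[tuple[str, float]],    # (chunk_id, relevance), already ranked best-first
    k: int = 60,                         # commonly used RRF damping constant (assumed default)
) -> list[tuple[str, float]]:
    """Fuse two ranked lists: score(d) = sum over lists of 1 / (k + rank(d))."""
    scores: dict[str, float] = {}
    for ranked in (semantic, keyword):
        for rank, (chunk_id, _) in enumerate(ranked, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)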
This commit is contained in:
774
mcp-servers/knowledge-base/database.py
Normal file
@@ -0,0 +1,774 @@
"""
Database management for Knowledge Base MCP Server.

Handles PostgreSQL connections with pgvector extension for
vector similarity search operations.
"""

import hashlib
import json
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from typing import Any

import asyncpg
from pgvector.asyncpg import register_vector

from config import Settings, get_settings
from exceptions import (
    CollectionNotFoundError,
    DatabaseConnectionError,
    DatabaseQueryError,
    ErrorCode,
    KnowledgeBaseError,
)
from models import (
    ChunkType,
    CollectionInfo,
    CollectionStats,
    FileType,
    KnowledgeEmbedding,
)

logger = logging.getLogger(__name__)


class DatabaseManager:
    """
    Manages PostgreSQL connections and vector operations.

    Uses asyncpg for async operations and pgvector for
    vector similarity search.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize database manager."""
        self._settings = settings or get_settings()
        self._pool: asyncpg.Pool | None = None  # type: ignore[type-arg]

    @property
    def pool(self) -> asyncpg.Pool:  # type: ignore[type-arg]
        """Get connection pool, raising if not initialized."""
        if self._pool is None:
            raise DatabaseConnectionError("Database pool not initialized")
        return self._pool

    async def initialize(self) -> None:
        """Initialize connection pool and create schema."""
        try:
            self._pool = await asyncpg.create_pool(
                self._settings.database_url,
                min_size=2,
                max_size=self._settings.database_pool_size,
                max_inactive_connection_lifetime=300,
                init=self._init_connection,
            )
            logger.info("Database pool created successfully")

            # Create schema
            await self._create_schema()
            logger.info("Database schema initialized")

        except Exception as e:
            logger.error(f"Failed to initialize database: {e}")
            raise DatabaseConnectionError(
                message=f"Failed to initialize database: {e}",
                cause=e,
            )

    async def _init_connection(self, conn: asyncpg.Connection) -> None:  # type: ignore[type-arg]
        """Initialize a connection with pgvector and JSONB support."""
        await register_vector(conn)
        # Encode/decode JSONB columns as Python dicts so metadata can be
        # passed and read directly instead of as raw JSON strings.
        await conn.set_type_codec(
            "jsonb",
            encoder=json.dumps,
            decoder=json.loads,
            schema="pg_catalog",
        )

    async def _create_schema(self) -> None:
        """Create database schema if not exists."""
        async with self.pool.acquire() as conn:
            # Enable pgvector extension
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")

            # Create main embeddings table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS knowledge_embeddings (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    project_id VARCHAR(255) NOT NULL,
                    collection VARCHAR(255) NOT NULL DEFAULT 'default',
                    content TEXT NOT NULL,
                    embedding vector(1536),
                    chunk_type VARCHAR(50) NOT NULL,
                    source_path TEXT,
                    start_line INTEGER,
                    end_line INTEGER,
                    file_type VARCHAR(50),
                    metadata JSONB DEFAULT '{}',
                    content_hash VARCHAR(64),
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    expires_at TIMESTAMPTZ
                )
            """)

            # Create indexes for common queries
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection
                ON knowledge_embeddings(project_id, collection)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_source_path
                ON knowledge_embeddings(project_id, source_path)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash
                ON knowledge_embeddings(project_id, content_hash)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type
                ON knowledge_embeddings(project_id, chunk_type)
            """)

            # Create HNSW index for vector similarity search
            # This dramatically improves search performance
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw
                ON knowledge_embeddings
                USING hnsw (embedding vector_cosine_ops)
                WITH (m = 16, ef_construction = 64)
            """)

            # Create GIN index for full-text search
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts
                ON knowledge_embeddings
                USING gin(to_tsvector('english', content))
            """)

    async def close(self) -> None:
        """Close the connection pool."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database pool closed")

    @asynccontextmanager
    async def acquire(self) -> Any:
        """Acquire a connection from the pool."""
        async with self.pool.acquire() as conn:
            yield conn

    @staticmethod
    def compute_content_hash(content: str) -> str:
        """Compute SHA-256 hash of content for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()

    async def store_embedding(
        self,
        project_id: str,
        collection: str,
        content: str,
        embedding: list[float],
        chunk_type: ChunkType,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Store an embedding in the database.

        Returns:
            The ID of the stored embedding.
        """
        content_hash = self.compute_content_hash(content)
        metadata = metadata or {}

        # Calculate expiration if TTL is set
        expires_at = None
        if self._settings.embedding_ttl_days > 0:
            expires_at = datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )

        try:
            async with self.acquire() as conn:
                # Check for duplicate content
                existing = await conn.fetchval(
                    """
                    SELECT id FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2 AND content_hash = $3
                    """,
                    project_id,
                    collection,
                    content_hash,
                )

                if existing:
                    # Update existing embedding
                    await conn.execute(
                        """
                        UPDATE knowledge_embeddings
                        SET embedding = $1, updated_at = NOW(), expires_at = $2,
                            metadata = $3, source_path = $4, start_line = $5,
                            end_line = $6, file_type = $7
                        WHERE id = $8
                        """,
                        embedding,
                        expires_at,
                        metadata,
                        source_path,
                        start_line,
                        end_line,
                        file_type.value if file_type else None,
                        existing,
                    )
                    logger.debug(f"Updated existing embedding: {existing}")
                    return str(existing)

                # Insert new embedding
                embedding_id = await conn.fetchval(
                    """
                    INSERT INTO knowledge_embeddings
                        (project_id, collection, content, embedding, chunk_type,
                         source_path, start_line, end_line, file_type, metadata,
                         content_hash, expires_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                    RETURNING id
                    """,
                    project_id,
                    collection,
                    content,
                    embedding,
                    chunk_type.value,
                    source_path,
                    start_line,
                    end_line,
                    file_type.value if file_type else None,
                    metadata,
                    content_hash,
                    expires_at,
                )
                logger.debug(f"Stored new embedding: {embedding_id}")
                return str(embedding_id)

        except asyncpg.PostgresError as e:
            logger.error(f"Database error storing embedding: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embedding: {e}",
                cause=e,
            )

    async def store_embeddings_batch(
        self,
        embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]],
    ) -> list[str]:
        """
        Store multiple embeddings in a batch.

        Args:
            embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata)

        Returns:
            List of created embedding IDs.
        """
        if not embeddings:
            return []

        ids = []
        expires_at = None
        if self._settings.embedding_ttl_days > 0:
            expires_at = datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )

        try:
            async with self.acquire() as conn:
                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
                    content_hash = self.compute_content_hash(content)
                    source_path = metadata.get("source_path")
                    start_line = metadata.get("start_line")
                    end_line = metadata.get("end_line")
                    file_type = metadata.get("file_type")

                    embedding_id = await conn.fetchval(
                        """
                        INSERT INTO knowledge_embeddings
                            (project_id, collection, content, embedding, chunk_type,
                             source_path, start_line, end_line, file_type, metadata,
                             content_hash, expires_at)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                        ON CONFLICT DO NOTHING
                        RETURNING id
                        """,
                        project_id,
                        collection,
                        content,
                        embedding,
                        chunk_type.value,
                        source_path,
                        start_line,
                        end_line,
                        file_type,
                        metadata,
                        content_hash,
                        expires_at,
                    )
                    if embedding_id:
                        ids.append(str(embedding_id))

            logger.info(f"Stored {len(ids)} embeddings in batch")
            return ids

        except asyncpg.PostgresError as e:
            logger.error(f"Database error in batch store: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embeddings batch: {e}",
                cause=e,
            )

    async def semantic_search(
        self,
        project_id: str,
        query_embedding: list[float],
        collection: str | None = None,
        limit: int = 10,
        threshold: float = 0.7,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform semantic (vector) search.

        Returns:
            List of (embedding, similarity_score) tuples.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters
                query = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        1 - (embedding <=> $1) as similarity
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query_embedding, project_id]
                param_idx = 3

                if collection:
                    query += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1

                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    query += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1

                # Apply the similarity threshold in the WHERE clause (the alias
                # cannot be referenced here, and HAVING would require GROUP BY).
                query += f"""
                    AND 1 - (embedding <=> $1) >= ${param_idx}
                    ORDER BY similarity DESC
                    LIMIT ${param_idx + 1}
                """
                params.extend([threshold, limit])

                rows = await conn.fetch(query, *params)

                results = []
                for row in rows:
                    embedding = KnowledgeEmbedding(
                        id=str(row["id"]),
                        project_id=row["project_id"],
                        collection=row["collection"],
                        content=row["content"],
                        embedding=list(row["embedding"]),
                        chunk_type=ChunkType(row["chunk_type"]),
                        source_path=row["source_path"],
                        start_line=row["start_line"],
                        end_line=row["end_line"],
                        file_type=FileType(row["file_type"]) if row["file_type"] else None,
                        metadata=row["metadata"] or {},
                        content_hash=row["content_hash"],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        expires_at=row["expires_at"],
                    )
                    results.append((embedding, float(row["similarity"])))

                return results

        except asyncpg.PostgresError as e:
            logger.error(f"Semantic search error: {e}")
            raise DatabaseQueryError(
                message=f"Semantic search failed: {e}",
                cause=e,
            )

    async def keyword_search(
        self,
        project_id: str,
        query: str,
        collection: str | None = None,
        limit: int = 10,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform full-text keyword search.

        Returns:
            List of (embedding, relevance_score) tuples.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters
                sql = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        ts_rank(to_tsvector('english', content),
                                plainto_tsquery('english', $1)) as relevance
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                        AND to_tsvector('english', content) @@ plainto_tsquery('english', $1)
                        AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query, project_id]
                param_idx = 3

                if collection:
                    sql += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1

                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    sql += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1

                sql += f" ORDER BY relevance DESC LIMIT ${param_idx}"
                params.append(limit)

                rows = await conn.fetch(sql, *params)

                results = []
                for row in rows:
                    embedding = KnowledgeEmbedding(
                        id=str(row["id"]),
                        project_id=row["project_id"],
                        collection=row["collection"],
                        content=row["content"],
                        embedding=list(row["embedding"]) if row["embedding"] else [],
                        chunk_type=ChunkType(row["chunk_type"]),
                        source_path=row["source_path"],
                        start_line=row["start_line"],
                        end_line=row["end_line"],
                        file_type=FileType(row["file_type"]) if row["file_type"] else None,
                        metadata=row["metadata"] or {},
                        content_hash=row["content_hash"],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        expires_at=row["expires_at"],
                    )
                    # Normalize relevance to 0-1 scale (approximate)
                    normalized_score = min(1.0, float(row["relevance"]))
                    results.append((embedding, normalized_score))

                return results

        except asyncpg.PostgresError as e:
            logger.error(f"Keyword search error: {e}")
            raise DatabaseQueryError(
                message=f"Keyword search failed: {e}",
                cause=e,
            )

    async def delete_by_source(
        self,
        project_id: str,
        source_path: str,
        collection: str | None = None,
    ) -> int:
        """Delete all embeddings for a source path."""
        try:
            async with self.acquire() as conn:
                if collection:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
                        """,
                        project_id,
                        source_path,
                        collection,
                    )
                else:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2
                        """,
                        project_id,
                        source_path,
                    )
                # Parse "DELETE N" result
                count = int(result.split()[-1])
                logger.info(f"Deleted {count} embeddings for source: {source_path}")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Delete error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by source: {e}",
                cause=e,
            )

    async def delete_collection(
        self,
        project_id: str,
        collection: str,
    ) -> int:
        """Delete an entire collection."""
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    """,
                    project_id,
                    collection,
                )
                count = int(result.split()[-1])
                logger.info(f"Deleted collection {collection}: {count} embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Delete collection error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete collection: {e}",
                cause=e,
            )

    async def delete_by_ids(
        self,
        project_id: str,
        chunk_ids: list[str],
    ) -> int:
        """Delete specific embeddings by ID."""
        if not chunk_ids:
            return 0

        try:
            # Convert string IDs to UUIDs
            uuids = [uuid.UUID(cid) for cid in chunk_ids]

            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND id = ANY($2)
                    """,
                    project_id,
                    uuids,
                )
                count = int(result.split()[-1])
                logger.info(f"Deleted {count} embeddings by ID")
                return count

        except ValueError as e:
            raise KnowledgeBaseError(
                message=f"Invalid chunk ID format: {e}",
                code=ErrorCode.INVALID_REQUEST,
            )
        except asyncpg.PostgresError as e:
            logger.error(f"Delete by IDs error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by IDs: {e}",
                cause=e,
            )

    async def list_collections(
        self,
        project_id: str,
    ) -> list[CollectionInfo]:
        """List all collections for a project."""
        try:
            async with self.acquire() as conn:
                rows = await conn.fetch(
                    """
                    SELECT
                        collection,
                        COUNT(*) as chunk_count,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types,
                        MIN(created_at) as created_at,
                        MAX(updated_at) as updated_at
                    FROM knowledge_embeddings
                    WHERE project_id = $1
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY collection
                    ORDER BY collection
                    """,
                    project_id,
                )

                return [
                    CollectionInfo(
                        name=row["collection"],
                        project_id=project_id,
                        chunk_count=row["chunk_count"],
                        total_tokens=row["total_tokens"] or 0,
                        file_types=row["file_types"] or [],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                    )
                    for row in rows
                ]

        except asyncpg.PostgresError as e:
            logger.error(f"List collections error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to list collections: {e}",
                cause=e,
            )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """Get detailed statistics for a collection."""
        try:
            async with self.acquire() as conn:
                # Check if collection exists
                exists = await conn.fetchval(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM knowledge_embeddings
                        WHERE project_id = $1 AND collection = $2
                    )
                    """,
                    project_id,
                    collection,
                )

                if not exists:
                    raise CollectionNotFoundError(collection, project_id)

                # Get stats
                row = await conn.fetchrow(
                    """
                    SELECT
                        COUNT(*) as chunk_count,
                        COUNT(DISTINCT source_path) as unique_sources,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size,
                        MIN(created_at) as oldest_chunk,
                        MAX(created_at) as newest_chunk
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                    """,
                    project_id,
                    collection,
                )

                # Get chunk type breakdown
                chunk_rows = await conn.fetch(
                    """
                    SELECT chunk_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY chunk_type
                    """,
                    project_id,
                    collection,
                )
                chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows}

                # Get file type breakdown
                file_rows = await conn.fetch(
                    """
                    SELECT file_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND file_type IS NOT NULL
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY file_type
                    """,
                    project_id,
                    collection,
                )
                file_types = {r["file_type"]: r["count"] for r in file_rows}

                return CollectionStats(
                    collection=collection,
                    project_id=project_id,
                    chunk_count=row["chunk_count"],
                    unique_sources=row["unique_sources"],
                    total_tokens=row["total_tokens"] or 0,
                    avg_chunk_size=float(row["avg_chunk_size"] or 0),
                    chunk_types=chunk_types,
                    file_types=file_types,
                    oldest_chunk=row["oldest_chunk"],
                    newest_chunk=row["newest_chunk"],
                )

        except CollectionNotFoundError:
            raise
        except asyncpg.PostgresError as e:
            logger.error(f"Get collection stats error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to get collection stats: {e}",
                cause=e,
            )

    async def cleanup_expired(self) -> int:
        """Remove expired embeddings."""
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE expires_at IS NOT NULL AND expires_at < NOW()
                    """
                )
                count = int(result.split()[-1])
                if count > 0:
                    logger.info(f"Cleaned up {count} expired embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Cleanup error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to cleanup expired: {e}",
                cause=e,
            )


# Global database manager instance (lazy initialization)
_db_manager: DatabaseManager | None = None


def get_database_manager() -> DatabaseManager:
    """Get the global database manager instance."""
    global _db_manager
    if _db_manager is None:
        _db_manager = DatabaseManager()
    return _db_manager


def reset_database_manager() -> None:
    """Reset the global database manager (for testing)."""
    global _db_manager
    _db_manager = None
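A minimal usage sketch of the DatabaseManager defined above, assuming the connection URL and pool size come from config.Settings; ChunkType.TEXT and the constant embedding vector are illustrative placeholders, not values produced by the server:

import asyncio

from database import get_database_manager
from models import ChunkType


async def demo() -> None:
    db = get_database_manager()
    await db.initialize()
    try:
        chunk_id = await db.store_embedding(
            project_id="demo-project",
            collection="docs",
            content="pgvector adds vector similarity search to PostgreSQL.",
            embedding=[0.1] * 1536,  # placeholder; real vectors come from the embedding service
            chunk_type=ChunkType.TEXT,  # assumed enum member
        )
        results = await db.semantic_search(
            project_id="demo-project",
            query_embedding=[0.1] * 1536,
            collection="docs",
            limit=5,
            threshold=0.0,  # accept any similarity for the demo
        )
        print(chunk_id, [(round(score, 3), emb.source_path) for emb, score in results])
    finally:
        await db.close()


asyncio.run(demo())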