"""
|
|
Database management for Knowledge Base MCP Server.
|
|
|
|
Handles PostgreSQL connections with pgvector extension for
|
|
vector similarity search operations.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
import asyncpg
|
|
from pgvector.asyncpg import register_vector
|
|
|
|
from config import Settings, get_settings
|
|
from exceptions import (
|
|
CollectionNotFoundError,
|
|
DatabaseConnectionError,
|
|
DatabaseQueryError,
|
|
ErrorCode,
|
|
KnowledgeBaseError,
|
|
)
|
|
from models import (
|
|
ChunkType,
|
|
CollectionInfo,
|
|
CollectionStats,
|
|
FileType,
|
|
KnowledgeEmbedding,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabaseManager:
    """
    Manages PostgreSQL connections and vector operations.

    Uses asyncpg for async operations and pgvector for
    vector similarity search.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize database manager."""
        self._settings = settings or get_settings()
        self._pool: asyncpg.Pool | None = None  # type: ignore[type-arg]

    @property
    def pool(self) -> asyncpg.Pool:  # type: ignore[type-arg]
        """Get connection pool, raising if not initialized."""
        if self._pool is None:
            raise DatabaseConnectionError("Database pool not initialized")
        return self._pool

    async def initialize(self) -> None:
        """Initialize connection pool and create schema."""
        try:
            self._pool = await asyncpg.create_pool(
                self._settings.database_url,
                min_size=2,
                max_size=self._settings.database_pool_size,
                max_inactive_connection_lifetime=300,
                init=self._init_connection,
            )
            logger.info("Database pool created successfully")

            # Create schema
            await self._create_schema()
            logger.info("Database schema initialized")

        except Exception as e:
            logger.error(f"Failed to initialize database: {e}")
            raise DatabaseConnectionError(
                message=f"Failed to initialize database: {e}",
                cause=e,
            )

    async def _init_connection(self, conn: asyncpg.Connection) -> None:  # type: ignore[type-arg]
        """Initialize a connection with pgvector support."""
        await register_vector(conn)

    async def _create_schema(self) -> None:
        """Create database schema if not exists."""
        async with self.pool.acquire() as conn:
            # Enable pgvector extension
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")

            # Create main embeddings table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS knowledge_embeddings (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    project_id VARCHAR(255) NOT NULL,
                    collection VARCHAR(255) NOT NULL DEFAULT 'default',
                    content TEXT NOT NULL,
                    embedding vector(1536),
                    chunk_type VARCHAR(50) NOT NULL,
                    source_path TEXT,
                    start_line INTEGER,
                    end_line INTEGER,
                    file_type VARCHAR(50),
                    metadata JSONB DEFAULT '{}',
                    content_hash VARCHAR(64),
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    expires_at TIMESTAMPTZ
                )
            """)

            # Create indexes for common queries
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection
                ON knowledge_embeddings(project_id, collection)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_source_path
                ON knowledge_embeddings(project_id, source_path)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash
                ON knowledge_embeddings(project_id, content_hash)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type
                ON knowledge_embeddings(project_id, chunk_type)
            """)

            # Create HNSW index for vector similarity search
            # This dramatically improves search performance
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw
                ON knowledge_embeddings
                USING hnsw (embedding vector_cosine_ops)
                WITH (m = 16, ef_construction = 64)
            """)

            # Create GIN index for full-text search
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts
                ON knowledge_embeddings
                USING gin(to_tsvector('english', content))
            """)

    async def close(self) -> None:
        """Close the connection pool."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database pool closed")

    @asynccontextmanager
    async def acquire(self) -> Any:
        """Acquire a connection from the pool."""
        async with self.pool.acquire() as conn:
            yield conn

    @staticmethod
    def compute_content_hash(content: str) -> str:
        """Compute SHA-256 hash of content for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()

    async def store_embedding(
        self,
        project_id: str,
        collection: str,
        content: str,
        embedding: list[float],
        chunk_type: ChunkType,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Store an embedding in the database.

        Returns:
            The ID of the stored embedding.
        """
        content_hash = self.compute_content_hash(content)
        metadata = metadata or {}

        # Calculate expiration if TTL is set
        expires_at = None
        if self._settings.embedding_ttl_days > 0:
            expires_at = datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )

        try:
            async with self.acquire() as conn:
                # Check for duplicate content
                existing = await conn.fetchval(
                    """
                    SELECT id FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2 AND content_hash = $3
                    """,
                    project_id,
                    collection,
                    content_hash,
                )

                if existing:
                    # Update existing embedding
                    await conn.execute(
                        """
                        UPDATE knowledge_embeddings
                        SET embedding = $1, updated_at = NOW(), expires_at = $2,
                            metadata = $3, source_path = $4, start_line = $5,
                            end_line = $6, file_type = $7
                        WHERE id = $8
                        """,
                        embedding,
                        expires_at,
                        metadata,
                        source_path,
                        start_line,
                        end_line,
                        file_type.value if file_type else None,
                        existing,
                    )
                    logger.debug(f"Updated existing embedding: {existing}")
                    return str(existing)

                # Insert new embedding
                embedding_id = await conn.fetchval(
                    """
                    INSERT INTO knowledge_embeddings
                        (project_id, collection, content, embedding, chunk_type,
                         source_path, start_line, end_line, file_type, metadata,
                         content_hash, expires_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                    RETURNING id
                    """,
                    project_id,
                    collection,
                    content,
                    embedding,
                    chunk_type.value,
                    source_path,
                    start_line,
                    end_line,
                    file_type.value if file_type else None,
                    metadata,
                    content_hash,
                    expires_at,
                )
                logger.debug(f"Stored new embedding: {embedding_id}")
                return str(embedding_id)

        except asyncpg.PostgresError as e:
            logger.error(f"Database error storing embedding: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embedding: {e}",
                cause=e,
            )

    async def store_embeddings_batch(
        self,
        embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]],
    ) -> list[str]:
        """
        Store multiple embeddings in a batch.

        Args:
            embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata)

        Returns:
            List of created embedding IDs.
        """
        if not embeddings:
            return []

        ids = []
        expires_at = None
        if self._settings.embedding_ttl_days > 0:
            expires_at = datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )

        try:
            async with self.acquire() as conn:
                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
                    content_hash = self.compute_content_hash(content)
                    source_path = metadata.get("source_path")
                    start_line = metadata.get("start_line")
                    end_line = metadata.get("end_line")
                    file_type = metadata.get("file_type")

                    embedding_id = await conn.fetchval(
                        """
                        INSERT INTO knowledge_embeddings
                            (project_id, collection, content, embedding, chunk_type,
                             source_path, start_line, end_line, file_type, metadata,
                             content_hash, expires_at)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                        ON CONFLICT DO NOTHING
                        RETURNING id
                        """,
                        project_id,
                        collection,
                        content,
                        embedding,
                        chunk_type.value,
                        source_path,
                        start_line,
                        end_line,
                        file_type,
                        metadata,
                        content_hash,
                        expires_at,
                    )
                    if embedding_id:
                        ids.append(str(embedding_id))

                logger.info(f"Stored {len(ids)} embeddings in batch")
                return ids

        except asyncpg.PostgresError as e:
            logger.error(f"Database error in batch store: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embeddings batch: {e}",
                cause=e,
            )

    async def semantic_search(
        self,
        project_id: str,
        query_embedding: list[float],
        collection: str | None = None,
        limit: int = 10,
        threshold: float = 0.7,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform semantic (vector) search.

        Returns:
            List of (embedding, similarity_score) tuples.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters using CTE to filter by similarity
                # We use a CTE to compute similarity once, then filter in outer query
                inner_query = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        1 - (embedding <=> $1) as similarity
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query_embedding, project_id]
                param_idx = 3

                if collection:
                    inner_query += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1

                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    inner_query += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1

                # Wrap in CTE and filter by threshold in outer query
                query = f"""
                    WITH scored AS ({inner_query})
                    SELECT * FROM scored
                    WHERE similarity >= ${param_idx}
                    ORDER BY similarity DESC
                    LIMIT ${param_idx + 1}
                """
                params.extend([threshold, limit])

                rows = await conn.fetch(query, *params)

                results = []
                for row in rows:
                    embedding = KnowledgeEmbedding(
                        id=str(row["id"]),
                        project_id=row["project_id"],
                        collection=row["collection"],
                        content=row["content"],
                        embedding=list(row["embedding"]),
                        chunk_type=ChunkType(row["chunk_type"]),
                        source_path=row["source_path"],
                        start_line=row["start_line"],
                        end_line=row["end_line"],
                        file_type=FileType(row["file_type"]) if row["file_type"] else None,
                        metadata=row["metadata"] or {},
                        content_hash=row["content_hash"],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        expires_at=row["expires_at"],
                    )
                    results.append((embedding, float(row["similarity"])))

                return results

        except asyncpg.PostgresError as e:
            logger.error(f"Semantic search error: {e}")
            raise DatabaseQueryError(
                message=f"Semantic search failed: {e}",
                cause=e,
            )

    async def keyword_search(
        self,
        project_id: str,
        query: str,
        collection: str | None = None,
        limit: int = 10,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform full-text keyword search.

        Returns:
            List of (embedding, relevance_score) tuples.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters
                sql = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        ts_rank(to_tsvector('english', content),
                                plainto_tsquery('english', $1)) as relevance
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                        AND to_tsvector('english', content) @@ plainto_tsquery('english', $1)
                        AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query, project_id]
                param_idx = 3

                if collection:
                    sql += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1

                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    sql += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1

                sql += f" ORDER BY relevance DESC LIMIT ${param_idx}"
                params.append(limit)

                rows = await conn.fetch(sql, *params)

                results = []
                for row in rows:
                    embedding = KnowledgeEmbedding(
                        id=str(row["id"]),
                        project_id=row["project_id"],
                        collection=row["collection"],
                        content=row["content"],
                        embedding=list(row["embedding"]) if row["embedding"] else [],
                        chunk_type=ChunkType(row["chunk_type"]),
                        source_path=row["source_path"],
                        start_line=row["start_line"],
                        end_line=row["end_line"],
                        file_type=FileType(row["file_type"]) if row["file_type"] else None,
                        metadata=row["metadata"] or {},
                        content_hash=row["content_hash"],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        expires_at=row["expires_at"],
                    )
                    # Normalize relevance to 0-1 scale (approximate)
                    normalized_score = min(1.0, float(row["relevance"]))
                    results.append((embedding, normalized_score))

                return results

        except asyncpg.PostgresError as e:
            logger.error(f"Keyword search error: {e}")
            raise DatabaseQueryError(
                message=f"Keyword search failed: {e}",
                cause=e,
            )

    async def delete_by_source(
        self,
        project_id: str,
        source_path: str,
        collection: str | None = None,
    ) -> int:
        """Delete all embeddings for a source path."""
        try:
            async with self.acquire() as conn:
                if collection:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
                        """,
                        project_id,
                        source_path,
                        collection,
                    )
                else:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2
                        """,
                        project_id,
                        source_path,
                    )
                # Parse "DELETE N" result
                count = int(result.split()[-1])
                logger.info(f"Deleted {count} embeddings for source: {source_path}")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Delete error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by source: {e}",
                cause=e,
            )

    async def delete_collection(
        self,
        project_id: str,
        collection: str,
    ) -> int:
        """Delete an entire collection."""
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    """,
                    project_id,
                    collection,
                )
                count = int(result.split()[-1])
                logger.info(f"Deleted collection {collection}: {count} embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Delete collection error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete collection: {e}",
                cause=e,
            )

    async def delete_by_ids(
        self,
        project_id: str,
        chunk_ids: list[str],
    ) -> int:
        """Delete specific embeddings by ID."""
        if not chunk_ids:
            return 0

        try:
            # Convert string IDs to UUIDs
            uuids = [uuid.UUID(cid) for cid in chunk_ids]

            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND id = ANY($2)
                    """,
                    project_id,
                    uuids,
                )
                count = int(result.split()[-1])
                logger.info(f"Deleted {count} embeddings by ID")
                return count

        except ValueError as e:
            raise KnowledgeBaseError(
                message=f"Invalid chunk ID format: {e}",
                code=ErrorCode.INVALID_REQUEST,
            )
        except asyncpg.PostgresError as e:
            logger.error(f"Delete by IDs error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by IDs: {e}",
                cause=e,
            )

    async def list_collections(
        self,
        project_id: str,
    ) -> list[CollectionInfo]:
        """List all collections for a project."""
        try:
            async with self.acquire() as conn:
                rows = await conn.fetch(
                    """
                    SELECT
                        collection,
                        COUNT(*) as chunk_count,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types,
                        MIN(created_at) as created_at,
                        MAX(updated_at) as updated_at
                    FROM knowledge_embeddings
                    WHERE project_id = $1
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY collection
                    ORDER BY collection
                    """,
                    project_id,
                )

                return [
                    CollectionInfo(
                        name=row["collection"],
                        project_id=project_id,
                        chunk_count=row["chunk_count"],
                        total_tokens=row["total_tokens"] or 0,
                        file_types=row["file_types"] or [],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                    )
                    for row in rows
                ]

        except asyncpg.PostgresError as e:
            logger.error(f"List collections error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to list collections: {e}",
                cause=e,
            )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """Get detailed statistics for a collection."""
        try:
            async with self.acquire() as conn:
                # Check if collection exists
                exists = await conn.fetchval(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM knowledge_embeddings
                        WHERE project_id = $1 AND collection = $2
                    )
                    """,
                    project_id,
                    collection,
                )

                if not exists:
                    raise CollectionNotFoundError(collection, project_id)

                # Get stats
                row = await conn.fetchrow(
                    """
                    SELECT
                        COUNT(*) as chunk_count,
                        COUNT(DISTINCT source_path) as unique_sources,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size,
                        MIN(created_at) as oldest_chunk,
                        MAX(created_at) as newest_chunk
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                    """,
                    project_id,
                    collection,
                )

                # Get chunk type breakdown
                chunk_rows = await conn.fetch(
                    """
                    SELECT chunk_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY chunk_type
                    """,
                    project_id,
                    collection,
                )
                chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows}

                # Get file type breakdown
                file_rows = await conn.fetch(
                    """
                    SELECT file_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND file_type IS NOT NULL
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY file_type
                    """,
                    project_id,
                    collection,
                )
                file_types = {r["file_type"]: r["count"] for r in file_rows}

                return CollectionStats(
                    collection=collection,
                    project_id=project_id,
                    chunk_count=row["chunk_count"],
                    unique_sources=row["unique_sources"],
                    total_tokens=row["total_tokens"] or 0,
                    avg_chunk_size=float(row["avg_chunk_size"] or 0),
                    chunk_types=chunk_types,
                    file_types=file_types,
                    oldest_chunk=row["oldest_chunk"],
                    newest_chunk=row["newest_chunk"],
                )

        except CollectionNotFoundError:
            raise
        except asyncpg.PostgresError as e:
            logger.error(f"Get collection stats error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to get collection stats: {e}",
                cause=e,
            )

    async def cleanup_expired(self) -> int:
        """Remove expired embeddings."""
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE expires_at IS NOT NULL AND expires_at < NOW()
                    """
                )
                count = int(result.split()[-1])
                if count > 0:
                    logger.info(f"Cleaned up {count} expired embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Cleanup error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to cleanup expired: {e}",
                cause=e,
            )


# Global database manager instance (lazy initialization)
_db_manager: DatabaseManager | None = None


def get_database_manager() -> DatabaseManager:
    """Get the global database manager instance."""
    global _db_manager
    if _db_manager is None:
        _db_manager = DatabaseManager()
    return _db_manager


def reset_database_manager() -> None:
    """Reset the global database manager (for testing)."""
    global _db_manager
    _db_manager = None
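

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; kept as comments so nothing runs on import).
# It assumes a reachable PostgreSQL instance with the pgvector extension and a
# Settings object supplying database_url, database_pool_size and
# embedding_ttl_days. The zero vector stands in for a real 1536-dimension
# embedding, and ChunkType.CODE is a hypothetical member of ChunkType from
# models; substitute whatever member that enum actually defines.
#
#     import asyncio
#     from models import ChunkType
#
#     async def main() -> None:
#         db = get_database_manager()
#         await db.initialize()
#         try:
#             chunk_id = await db.store_embedding(
#                 project_id="demo-project",
#                 collection="default",
#                 content="def hello() -> None: ...",
#                 embedding=[0.0] * 1536,
#                 chunk_type=ChunkType.CODE,
#             )
#             matches = await db.semantic_search(
#                 project_id="demo-project",
#                 query_embedding=[0.0] * 1536,
#                 limit=5,
#                 threshold=0.0,
#             )
#             print(chunk_id, [score for _, score in matches])
#         finally:
#             await db.close()
#
#     asyncio.run(main())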