forked from cardosofelipe/fast-next-template
feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search: - Intelligent chunking strategies (code-aware, markdown-aware, text) - Semantic search with vector similarity (HNSW index) - Keyword search with PostgreSQL full-text search - Hybrid search using Reciprocal Rank Fusion (RRF) - Redis caching for embeddings - Collection management (ingest, search, delete, stats) - FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document Testing: - 128 comprehensive tests covering all components - 58% code coverage (database integration tests use mocks) - Passes ruff linting and mypy type checking 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
31
mcp-servers/knowledge-base/Dockerfile
Normal file
31
mcp-servers/knowledge-base/Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv for fast package installation
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
|
||||
|
||||
# Copy project files
|
||||
COPY pyproject.toml ./
|
||||
COPY *.py ./
|
||||
COPY chunking/ ./chunking/
|
||||
|
||||
# Install dependencies
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd --create-home --shell /bin/bash appuser
|
||||
USER appuser
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"
|
||||
|
||||
EXPOSE 8002
|
||||
|
||||
CMD ["python", "server.py"]
|
||||
19
mcp-servers/knowledge-base/chunking/__init__.py
Normal file
19
mcp-servers/knowledge-base/chunking/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Chunking module for Knowledge Base MCP Server.
|
||||
|
||||
Provides intelligent content chunking for different file types
|
||||
with overlap and context preservation.
|
||||
"""
|
||||
|
||||
from chunking.base import BaseChunker, ChunkerFactory
|
||||
from chunking.code import CodeChunker
|
||||
from chunking.markdown import MarkdownChunker
|
||||
from chunking.text import TextChunker
|
||||
|
||||
__all__ = [
|
||||
"BaseChunker",
|
||||
"ChunkerFactory",
|
||||
"CodeChunker",
|
||||
"MarkdownChunker",
|
||||
"TextChunker",
|
||||
]
|
||||
281
mcp-servers/knowledge-base/chunking/base.py
Normal file
281
mcp-servers/knowledge-base/chunking/base.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Base chunker implementation.
|
||||
|
||||
Provides abstract interface and common utilities for content chunking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import ChunkingError
|
||||
from models import FILE_EXTENSION_MAP, Chunk, ChunkType, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseChunker(ABC):
    """
    Common machinery shared by all content chunkers.

    Concrete subclasses (code, markdown, text) supply the actual
    splitting strategy via ``chunk()`` and declare what kind of
    chunks they emit via ``chunk_type``.
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """
        Set up token accounting for a chunker.

        Args:
            chunk_size: Target tokens per chunk
            chunk_overlap: Token overlap between consecutive chunks
            settings: Application settings (defaults to the global settings)
        """
        self._settings = settings or get_settings()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # cl100k_base matches GPT-4 / text-embedding-3 tokenization
        self._tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens *text* encodes to."""
        return len(self._tokenizer.encode(text))

    def truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """Return *text* cut down to at most *max_tokens* tokens."""
        encoded = self._tokenizer.encode(text)
        if len(encoded) > max_tokens:
            return self._tokenizer.decode(encoded[:max_tokens])
        return text

    @abstractmethod
    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Split *content* into chunks.

        Args:
            content: Content to chunk
            source_path: Source file path for reference
            file_type: File type for specialized handling
            metadata: Additional metadata to include

        Returns:
            List of Chunk objects
        """

    @property
    @abstractmethod
    def chunk_type(self) -> ChunkType:
        """The ChunkType carried by every chunk this chunker produces."""

    def _create_chunk(
        self,
        content: str,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Chunk:
        """Build a Chunk for *content*, filling in its token count."""
        return Chunk(
            content=content,
            chunk_type=self.chunk_type,
            file_type=file_type,
            source_path=source_path,
            start_line=start_line,
            end_line=end_line,
            metadata=metadata or {},
            token_count=self.count_tokens(content),
        )
|
||||
|
||||
|
||||
class ChunkerFactory:
    """
    Factory for creating appropriate chunkers.

    Lazily constructs and caches one chunker per strategy and selects
    the best one based on explicit chunk type, file type, or file
    extension.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize factory.

        Args:
            settings: Application settings (defaults to the global settings).
        """
        self._settings = settings or get_settings()
        # Cache: strategy name -> chunker instance, created on first use.
        self._chunkers: dict[str, BaseChunker] = {}

    def _get_or_create(
        self,
        key: str,
        chunker_cls: type[BaseChunker],
        chunk_size: int,
        chunk_overlap: int,
    ) -> BaseChunker:
        """Return the cached chunker for *key*, creating it on first use."""
        if key not in self._chunkers:
            self._chunkers[key] = chunker_cls(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers[key]

    def _get_code_chunker(self) -> "BaseChunker":
        """Get or create code chunker."""
        # Imported lazily to avoid a circular import with chunking.code.
        from chunking.code import CodeChunker

        return self._get_or_create(
            "code",
            CodeChunker,
            self._settings.code_chunk_size,
            self._settings.code_chunk_overlap,
        )

    def _get_markdown_chunker(self) -> "BaseChunker":
        """Get or create markdown chunker."""
        from chunking.markdown import MarkdownChunker

        return self._get_or_create(
            "markdown",
            MarkdownChunker,
            self._settings.markdown_chunk_size,
            self._settings.markdown_chunk_overlap,
        )

    def _get_text_chunker(self) -> "BaseChunker":
        """Get or create text chunker."""
        from chunking.text import TextChunker

        return self._get_or_create(
            "text",
            TextChunker,
            self._settings.text_chunk_size,
            self._settings.text_chunk_overlap,
        )

    def get_chunker(
        self,
        file_type: FileType | None = None,
        chunk_type: ChunkType | None = None,
    ) -> BaseChunker:
        """
        Get appropriate chunker for content type.

        Args:
            file_type: File type to chunk
            chunk_type: Explicit chunk type to use (takes precedence)

        Returns:
            Appropriate chunker instance
        """
        # An explicit chunk type wins over any file-type inference.
        if chunk_type:
            if chunk_type == ChunkType.CODE:
                return self._get_code_chunker()
            if chunk_type == ChunkType.MARKDOWN:
                return self._get_markdown_chunker()
            return self._get_text_chunker()

        # Otherwise, infer from file type.
        if file_type:
            if file_type == FileType.MARKDOWN:
                return self._get_markdown_chunker()
            if file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
                return self._get_text_chunker()
            # Everything else is treated as source code.
            return self._get_code_chunker()

        # No hints at all: default to the text chunker.
        return self._get_text_chunker()

    def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
        """
        Get chunker based on file path extension.

        Args:
            source_path: File path to chunk

        Returns:
            Tuple of (chunker, file_type); file_type is None for unknown
            extensions, which fall back to the text chunker.
        """
        # Extract the (lowercased) extension, if any.
        ext = ""
        if "." in source_path:
            ext = "." + source_path.rsplit(".", 1)[-1].lower()

        file_type = FILE_EXTENSION_MAP.get(ext)
        return self.get_chunker(file_type=file_type), file_type

    def chunk_content(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        chunk_type: ChunkType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk content using appropriate strategy.

        Args:
            content: Content to chunk
            source_path: Source file path (used to infer file_type if absent)
            chunk_type: Explicit chunk type (takes precedence over inference)
            file_type: File type
            metadata: Additional metadata

        Returns:
            List of chunks

        Raises:
            ChunkingError: If the selected chunker fails.
        """
        # Infer the file type from the path when not given explicitly.
        if source_path and not file_type:
            _, file_type = self.get_chunker_for_path(source_path)

        # BUGFIX: previously, when file_type was inferred from source_path,
        # the chunker came straight from the path inference and an explicit
        # chunk_type argument was silently ignored. Selecting through
        # get_chunker() lets chunk_type take precedence; behavior for
        # chunk_type=None is unchanged.
        chunker = self.get_chunker(file_type=file_type, chunk_type=chunk_type)

        try:
            chunks = chunker.chunk(
                content=content,
                source_path=source_path,
                file_type=file_type,
                metadata=metadata,
            )

            logger.debug(
                f"Chunked content into {len(chunks)} chunks "
                f"(type={chunker.chunk_type.value})"
            )

            return chunks

        except Exception as e:
            logger.error(f"Chunking error: {e}")
            raise ChunkingError(
                message=f"Failed to chunk content: {e}",
                cause=e,
            )
|
||||
|
||||
|
||||
# Global chunker factory instance, created lazily by get_chunker_factory()
# and cleared by reset_chunker_factory() between tests.
_chunker_factory: ChunkerFactory | None = None
|
||||
|
||||
|
||||
def get_chunker_factory() -> ChunkerFactory:
    """Return the process-wide ChunkerFactory, creating it on first use."""
    global _chunker_factory
    factory = _chunker_factory
    if factory is None:
        factory = ChunkerFactory()
        _chunker_factory = factory
    return factory
|
||||
|
||||
|
||||
def reset_chunker_factory() -> None:
    """Drop the cached global factory so the next caller builds a fresh one.

    Intended for test isolation.
    """
    global _chunker_factory
    _chunker_factory = None
|
||||
410
mcp-servers/knowledge-base/chunking/code.py
Normal file
410
mcp-servers/knowledge-base/chunking/code.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Code-aware chunking implementation.
|
||||
|
||||
Provides intelligent chunking for source code that respects
|
||||
function/class boundaries and preserves context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from chunking.base import BaseChunker
|
||||
from config import Settings
|
||||
from models import Chunk, ChunkType, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Language-specific patterns for detecting function/class definitions.
# Keys are FileType values; each maps a structure name (stored in chunk
# metadata as "structure_type") to a MULTILINE regex matching the line
# that opens that structure.
LANGUAGE_PATTERNS: dict[FileType, dict[str, re.Pattern[str]]] = {
    FileType.PYTHON: {
        # `def` / `async def` at any indentation
        "function": re.compile(r"^(\s*)(async\s+)?def\s+\w+", re.MULTILINE),
        "class": re.compile(r"^(\s*)class\s+\w+", re.MULTILINE),
        "decorator": re.compile(r"^(\s*)@\w+", re.MULTILINE),
    },
    FileType.JAVASCRIPT: {
        # function declarations and function-expression const/let/var bindings
        "function": re.compile(
            r"^(\s*)(export\s+)?(async\s+)?function\s+\w+|"
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?\(",
            re.MULTILINE,
        ),
        "class": re.compile(r"^(\s*)(export\s+)?class\s+\w+", re.MULTILINE),
        # arrow functions bound to const/let/var
        "arrow": re.compile(
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?(\([^)]*\)|[^=])\s*=>",
            re.MULTILINE,
        ),
    },
    FileType.TYPESCRIPT: {
        # function declarations, plus typed const/let/var bindings (': T' or '<T>')
        "function": re.compile(
            r"^(\s*)(export\s+)?(async\s+)?function\s+\w+|"
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*[:<]",
            re.MULTILINE,
        ),
        "class": re.compile(r"^(\s*)(export\s+)?class\s+\w+", re.MULTILINE),
        "interface": re.compile(r"^(\s*)(export\s+)?interface\s+\w+", re.MULTILINE),
        "type": re.compile(r"^(\s*)(export\s+)?type\s+\w+", re.MULTILINE),
    },
    FileType.GO: {
        # Go top-level declarations start at column 0; the optional group
        # matches method receivers like `func (s *Server) Name`.
        "function": re.compile(r"^func\s+(\([^)]+\)\s+)?\w+", re.MULTILINE),
        "struct": re.compile(r"^type\s+\w+\s+struct", re.MULTILINE),
        "interface": re.compile(r"^type\s+\w+\s+interface", re.MULTILINE),
    },
    FileType.RUST: {
        "function": re.compile(r"^(\s*)(pub\s+)?(async\s+)?fn\s+\w+", re.MULTILINE),
        "struct": re.compile(r"^(\s*)(pub\s+)?struct\s+\w+", re.MULTILINE),
        "impl": re.compile(r"^(\s*)impl\s+", re.MULTILINE),
        "trait": re.compile(r"^(\s*)(pub\s+)?trait\s+\w+", re.MULTILINE),
    },
    FileType.JAVA: {
        # NOTE(review): this method pattern is heuristic ("type name(") and
        # may also match constructor calls or miss annotated signatures —
        # confirm against representative Java sources.
        "method": re.compile(
            r"^(\s*)(public|private|protected)?\s*(static)?\s*\w+\s+\w+\s*\(",
            re.MULTILINE,
        ),
        "class": re.compile(
            r"^(\s*)(public|private|protected)?\s*(abstract)?\s*class\s+\w+",
            re.MULTILINE,
        ),
        "interface": re.compile(
            r"^(\s*)(public|private|protected)?\s*interface\s+\w+",
            re.MULTILINE,
        ),
    },
}
|
||||
|
||||
|
||||
class CodeChunker(BaseChunker):
    """
    Code-aware chunker that respects logical boundaries.

    Features:
    - Detects function/class boundaries (via LANGUAGE_PATTERNS regexes)
    - Preserves decorator/annotation context
    - Handles nested structures
    - Falls back to line-based chunking when needed
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize code chunker.

        Args:
            chunk_size: Target tokens per chunk
            chunk_overlap: Token overlap between consecutive chunks
            settings: Application settings
        """
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """All chunks produced by this chunker are ChunkType.CODE."""
        return ChunkType.CODE

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk code content.

        Tries to respect function/class boundaries, falling back
        to line-based chunking if needed.

        Args:
            content: Source text to split
            source_path: File path recorded on each chunk
            file_type: Enables structure-aware splitting when it has
                patterns in LANGUAGE_PATTERNS
            metadata: Extra metadata copied onto every chunk

        Returns:
            List of Chunk objects; empty for blank input.
        """
        # Blank/whitespace-only input produces no chunks.
        if not content.strip():
            return []

        metadata = metadata or {}
        # keepends=True so "".join(lines) reproduces the original text.
        lines = content.splitlines(keepends=True)

        # Try language-aware chunking if we have patterns
        if file_type and file_type in LANGUAGE_PATTERNS:
            chunks = self._chunk_by_structure(
                content, lines, file_type, source_path, metadata
            )
            # _chunk_by_structure returns [] when structure detection
            # isn't useful; fall through in that case.
            if chunks:
                return chunks

        # Fall back to line-based chunking
        return self._chunk_by_lines(lines, source_path, file_type, metadata)

    def _chunk_by_structure(
        self,
        content: str,
        lines: list[str],
        file_type: FileType,
        source_path: str | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """
        Chunk by detecting code structure (functions, classes).

        Each detected definition line becomes a chunk boundary; the text
        between consecutive boundaries becomes one chunk (split further
        if oversized). Returns an empty list if structure detection
        isn't useful, signalling the caller to use line-based chunking.
        """
        patterns = LANGUAGE_PATTERNS.get(file_type, {})
        if not patterns:
            return []

        # Find all structure boundaries
        boundaries: list[tuple[int, str]] = []  # (line_number, type)

        for struct_type, pattern in patterns.items():
            for match in pattern.finditer(content):
                # Convert character position to line number
                line_num = content[:match.start()].count("\n")
                boundaries.append((line_num, struct_type))

        if not boundaries:
            return []

        # Sort boundaries by line number
        # NOTE(review): two patterns can match the same line, producing
        # duplicate boundaries; duplicates yield empty slices below and are
        # skipped by the .strip() check, so this appears harmless — confirm.
        boundaries.sort(key=lambda x: x[0])

        # If we have very few boundaries, line-based might be better
        if len(boundaries) < 3 and len(lines) > 50:
            return []

        # Create chunks based on boundaries
        chunks: list[Chunk] = []
        current_start = 0

        for _i, (line_num, struct_type) in enumerate(boundaries):
            # Check if we need to create a chunk before this boundary
            if line_num > current_start:
                # Include any preceding comments/decorators
                # NOTE(review): actual_start is computed but never used — the
                # slice below uses current_start:line_num, so comments and
                # decorators immediately preceding a definition stay with the
                # *previous* chunk rather than moving into the next one.
                # Looks like dead code or an unfinished feature; confirm intent.
                actual_start = self._find_context_start(lines, line_num)
                if actual_start < current_start:
                    actual_start = current_start

                chunk_lines = lines[current_start:line_num]
                chunk_content = "".join(chunk_lines)

                if chunk_content.strip():
                    token_count = self.count_tokens(chunk_content)

                    # If chunk is too large, split it
                    if token_count > self.chunk_size * 1.5:
                        sub_chunks = self._split_large_chunk(
                            chunk_lines, current_start, source_path, file_type, metadata
                        )
                        chunks.extend(sub_chunks)
                    elif token_count > 0:
                        chunks.append(
                            self._create_chunk(
                                content=chunk_content.rstrip(),
                                # Line numbers are 1-based and inclusive.
                                source_path=source_path,
                                start_line=current_start + 1,
                                end_line=line_num,
                                file_type=file_type,
                                # Tag the chunk with the structure kind that
                                # *follows* it (the boundary's pattern name).
                                metadata={**metadata, "structure_type": struct_type},
                            )
                        )

            current_start = line_num

        # Handle remaining content
        if current_start < len(lines):
            chunk_lines = lines[current_start:]
            chunk_content = "".join(chunk_lines)

            if chunk_content.strip():
                token_count = self.count_tokens(chunk_content)

                if token_count > self.chunk_size * 1.5:
                    sub_chunks = self._split_large_chunk(
                        chunk_lines, current_start, source_path, file_type, metadata
                    )
                    chunks.extend(sub_chunks)
                else:
                    chunks.append(
                        self._create_chunk(
                            content=chunk_content.rstrip(),
                            source_path=source_path,
                            start_line=current_start + 1,
                            end_line=len(lines),
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )

        return chunks

    def _find_context_start(self, lines: list[str], line_num: int) -> int:
        """Find the start of context (decorators, comments) before a line.

        Scans at most 10 lines backwards from *line_num* (0-based) and
        returns the earliest line that is part of a contiguous run of
        comment/decorator lines; blank lines are skipped, anything else
        stops the scan. Returns *line_num* itself if there is no context.
        """
        start = line_num

        # Look backwards for decorators/comments
        for i in range(line_num - 1, max(0, line_num - 10), -1):
            line = lines[i].strip()
            if not line:
                continue
            # Prefixes cover Python/C-family comments, doc continuation
            # lines ('*'), decorators/annotations ('@'), and quote-style
            # comments ("'").
            if line.startswith(("#", "//", "/*", "*", "@", "'")):
                start = i
            else:
                break

        return start

    def _split_large_chunk(
        self,
        chunk_lines: list[str],
        base_line: int,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Split a large chunk into smaller pieces with overlap.

        Args:
            chunk_lines: Lines (with line endings) of the oversized chunk
            base_line: 0-based offset of chunk_lines[0] in the whole file,
                used to report absolute line numbers
            source_path: File path recorded on each chunk
            file_type: File type recorded on each chunk
            metadata: Metadata copied onto each chunk
        """
        chunks: list[Chunk] = []
        current_lines: list[str] = []
        current_tokens = 0
        chunk_start = 0

        for i, line in enumerate(chunk_lines):
            line_tokens = self.count_tokens(line)

            if current_tokens + line_tokens > self.chunk_size and current_lines:
                # Create chunk
                chunk_content = "".join(current_lines).rstrip()
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=base_line + chunk_start + 1,
                        end_line=base_line + i,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Calculate overlap: walk backwards until at least
                # chunk_overlap tokens are covered.
                # NOTE(review): this accumulates until >= chunk_overlap (may
                # exceed it), whereas _chunk_by_lines stops *before* exceeding
                # it; also, if the whole chunk is under chunk_overlap tokens
                # the break never fires and overlap_lines stays empty (no
                # overlap at all). Confirm the asymmetry is intended.
                overlap_tokens = 0
                overlap_lines: list[str] = []
                for j in range(len(current_lines) - 1, -1, -1):
                    overlap_tokens += self.count_tokens(current_lines[j])
                    if overlap_tokens >= self.chunk_overlap:
                        overlap_lines = current_lines[j:]
                        break

                # Seed the next chunk with the overlap tail.
                current_lines = overlap_lines
                current_tokens = sum(self.count_tokens(line) for line in current_lines)
                chunk_start = i - len(overlap_lines)

            current_lines.append(line)
            current_tokens += line_tokens

        # Final chunk
        if current_lines:
            chunk_content = "".join(current_lines).rstrip()
            if chunk_content.strip():
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=base_line + chunk_start + 1,
                        end_line=base_line + len(chunk_lines),
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

        return chunks

    def _chunk_by_lines(
        self,
        lines: list[str],
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by lines with overlap.

        Structure-agnostic fallback: accumulates whole lines up to
        chunk_size tokens per chunk, carrying up to chunk_overlap tokens
        of trailing lines into the next chunk. A single line longer than
        chunk_size is emitted alone, truncated, and tagged
        metadata["truncated"] = True.
        """
        chunks: list[Chunk] = []
        current_lines: list[str] = []
        current_tokens = 0
        chunk_start = 0

        for i, line in enumerate(lines):
            line_tokens = self.count_tokens(line)

            # If this line alone exceeds chunk size, handle specially
            if line_tokens > self.chunk_size:
                # Flush current chunk
                if current_lines:
                    chunk_content = "".join(current_lines).rstrip()
                    if chunk_content.strip():
                        chunks.append(
                            self._create_chunk(
                                content=chunk_content,
                                source_path=source_path,
                                start_line=chunk_start + 1,
                                end_line=i,
                                file_type=file_type,
                                metadata=metadata,
                            )
                        )
                    current_lines = []
                    current_tokens = 0
                    chunk_start = i

                # Truncate and add long line as its own single-line chunk
                truncated = self.truncate_to_tokens(line, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated.rstrip(),
                        source_path=source_path,
                        start_line=i + 1,
                        end_line=i + 1,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                chunk_start = i + 1
                continue

            if current_tokens + line_tokens > self.chunk_size and current_lines:
                # Create chunk
                chunk_content = "".join(current_lines).rstrip()
                if chunk_content.strip():
                    chunks.append(
                        self._create_chunk(
                            content=chunk_content,
                            source_path=source_path,
                            start_line=chunk_start + 1,
                            end_line=i,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )

                # Calculate overlap: take trailing lines while staying
                # strictly within chunk_overlap tokens.
                overlap_tokens = 0
                overlap_lines: list[str] = []
                for j in range(len(current_lines) - 1, -1, -1):
                    line_tok = self.count_tokens(current_lines[j])
                    if overlap_tokens + line_tok > self.chunk_overlap:
                        break
                    overlap_lines.insert(0, current_lines[j])
                    overlap_tokens += line_tok

                # Seed the next chunk with the overlap tail.
                current_lines = overlap_lines
                current_tokens = overlap_tokens
                chunk_start = i - len(overlap_lines)

            current_lines.append(line)
            current_tokens += line_tokens

        # Final chunk
        if current_lines:
            chunk_content = "".join(current_lines).rstrip()
            if chunk_content.strip():
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=chunk_start + 1,
                        end_line=len(lines),
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

        return chunks
|
||||
483
mcp-servers/knowledge-base/chunking/markdown.py
Normal file
483
mcp-servers/knowledge-base/chunking/markdown.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Markdown-aware chunking implementation.
|
||||
|
||||
Provides intelligent chunking for markdown content that respects
|
||||
heading hierarchy and preserves document structure.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from chunking.base import BaseChunker
|
||||
from config import Settings
|
||||
from models import Chunk, ChunkType, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns for markdown elements
# ATX heading: 1-6 '#' chars; group(1) = hashes (level), group(2) = title.
HEADING_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
# Opening/closing fence of a ``` code block at line start.
# NOTE(review): only HEADING_PATTERN is referenced in the visible code —
# the chunker tracks fences with line.strip().startswith("```") instead;
# CODE_BLOCK_PATTERN and HR_PATTERN may be unused.
CODE_BLOCK_PATTERN = re.compile(r"^```", re.MULTILINE)
# Thematic break (horizontal rule): 3+ of '-', '_', or '*' alone on a line.
HR_PATTERN = re.compile(r"^(-{3,}|_{3,}|\*{3,})$", re.MULTILINE)
|
||||
|
||||
|
||||
class MarkdownChunker(BaseChunker):
|
||||
"""
|
||||
Markdown-aware chunker that respects document structure.
|
||||
|
||||
Features:
|
||||
- Respects heading hierarchy
|
||||
- Preserves heading context in chunks
|
||||
- Handles code blocks as units
|
||||
- Maintains list continuity where possible
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
settings: Settings | None = None,
|
||||
) -> None:
|
||||
"""Initialize markdown chunker."""
|
||||
super().__init__(chunk_size, chunk_overlap, settings)
|
||||
|
||||
@property
|
||||
def chunk_type(self) -> ChunkType:
|
||||
"""Get chunk type."""
|
||||
return ChunkType.MARKDOWN
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
content: str,
|
||||
source_path: str | None = None,
|
||||
file_type: FileType | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> list[Chunk]:
|
||||
"""
|
||||
Chunk markdown content.
|
||||
|
||||
Splits on heading boundaries and preserves heading context.
|
||||
"""
|
||||
if not content.strip():
|
||||
return []
|
||||
|
||||
metadata = metadata or {}
|
||||
file_type = file_type or FileType.MARKDOWN
|
||||
|
||||
# Split content into sections by headings
|
||||
sections = self._split_by_headings(content)
|
||||
|
||||
if not sections:
|
||||
# No headings, chunk as plain text
|
||||
return self._chunk_text_block(
|
||||
content, source_path, file_type, metadata, []
|
||||
)
|
||||
|
||||
chunks: list[Chunk] = []
|
||||
heading_stack: list[tuple[int, str]] = [] # (level, text)
|
||||
|
||||
for section in sections:
|
||||
heading_level = section.get("level", 0)
|
||||
heading_text = section.get("heading", "")
|
||||
section_content = section.get("content", "")
|
||||
start_line = section.get("start_line", 1)
|
||||
end_line = section.get("end_line", 1)
|
||||
|
||||
# Update heading stack
|
||||
if heading_level > 0:
|
||||
# Pop headings of equal or higher level
|
||||
while heading_stack and heading_stack[-1][0] >= heading_level:
|
||||
heading_stack.pop()
|
||||
heading_stack.append((heading_level, heading_text))
|
||||
|
||||
# Build heading context prefix
|
||||
heading_context = " > ".join(h[1] for h in heading_stack)
|
||||
|
||||
section_chunks = self._chunk_section(
|
||||
content=section_content,
|
||||
heading_context=heading_context,
|
||||
heading_level=heading_level,
|
||||
heading_text=heading_text,
|
||||
start_line=start_line,
|
||||
end_line=end_line,
|
||||
source_path=source_path,
|
||||
file_type=file_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
chunks.extend(section_chunks)
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_by_headings(self, content: str) -> list[dict[str, Any]]:
|
||||
"""Split content into sections by headings."""
|
||||
sections: list[dict[str, Any]] = []
|
||||
lines = content.split("\n")
|
||||
|
||||
current_section: dict[str, Any] = {
|
||||
"level": 0,
|
||||
"heading": "",
|
||||
"content": "",
|
||||
"start_line": 1,
|
||||
"end_line": 1,
|
||||
}
|
||||
current_lines: list[str] = []
|
||||
in_code_block = False
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# Track code blocks
|
||||
if line.strip().startswith("```"):
|
||||
in_code_block = not in_code_block
|
||||
current_lines.append(line)
|
||||
continue
|
||||
|
||||
# Skip heading detection in code blocks
|
||||
if in_code_block:
|
||||
current_lines.append(line)
|
||||
continue
|
||||
|
||||
# Check for heading
|
||||
heading_match = HEADING_PATTERN.match(line)
|
||||
if heading_match:
|
||||
# Save previous section
|
||||
if current_lines:
|
||||
current_section["content"] = "\n".join(current_lines)
|
||||
current_section["end_line"] = i
|
||||
if current_section["content"].strip():
|
||||
sections.append(current_section)
|
||||
|
||||
# Start new section
|
||||
level = len(heading_match.group(1))
|
||||
heading_text = heading_match.group(2).strip()
|
||||
current_section = {
|
||||
"level": level,
|
||||
"heading": heading_text,
|
||||
"content": "",
|
||||
"start_line": i + 1,
|
||||
"end_line": i + 1,
|
||||
}
|
||||
current_lines = [line]
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
# Save final section
|
||||
if current_lines:
|
||||
current_section["content"] = "\n".join(current_lines)
|
||||
current_section["end_line"] = len(lines)
|
||||
if current_section["content"].strip():
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
def _chunk_section(
|
||||
self,
|
||||
content: str,
|
||||
heading_context: str,
|
||||
heading_level: int,
|
||||
heading_text: str,
|
||||
start_line: int,
|
||||
end_line: int,
|
||||
source_path: str | None,
|
||||
file_type: FileType,
|
||||
metadata: dict[str, Any],
|
||||
) -> list[Chunk]:
|
||||
"""Chunk a single section of markdown."""
|
||||
if not content.strip():
|
||||
return []
|
||||
|
||||
token_count = self.count_tokens(content)
|
||||
|
||||
# If section fits in one chunk, return as-is
|
||||
if token_count <= self.chunk_size:
|
||||
section_metadata = {
|
||||
**metadata,
|
||||
"heading_context": heading_context,
|
||||
"heading_level": heading_level,
|
||||
"heading_text": heading_text,
|
||||
}
|
||||
return [
|
||||
self._create_chunk(
|
||||
content=content.strip(),
|
||||
source_path=source_path,
|
||||
start_line=start_line,
|
||||
end_line=end_line,
|
||||
file_type=file_type,
|
||||
metadata=section_metadata,
|
||||
)
|
||||
]
|
||||
|
||||
# Need to split - try to split on paragraphs first
|
||||
return self._chunk_text_block(
|
||||
content,
|
||||
source_path,
|
||||
file_type,
|
||||
{
|
||||
**metadata,
|
||||
"heading_context": heading_context,
|
||||
"heading_level": heading_level,
|
||||
"heading_text": heading_text,
|
||||
},
|
||||
_heading_stack=[(heading_level, heading_text)] if heading_text else [],
|
||||
base_line=start_line,
|
||||
)
|
||||
|
||||
    def _chunk_text_block(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
        _heading_stack: list[tuple[int, str]],
        base_line: int = 1,
    ) -> list[Chunk]:
        """Chunk a block of text by paragraphs.

        Paragraphs are accumulated until adding the next one would exceed
        ``self.chunk_size`` tokens; oversized paragraphs are split further
        via ``_split_large_paragraph``.  The last paragraph of a flushed
        chunk is carried into the next chunk when it fits within
        ``self.chunk_overlap`` tokens, preserving context across chunks.

        Args:
            content: Text block to split.
            source_path: Originating file path, if known.
            file_type: File type tag propagated into each chunk.
            metadata: Metadata attached to every produced chunk.
            _heading_stack: Heading context from the caller.
                NOTE(review): accepted but never referenced in this body —
                confirm whether it is intentionally reserved or dead.
            base_line: Absolute line number of the first line of ``content``;
                paragraph-relative offsets are added to it.

        Returns:
            List of chunks, possibly empty if the block has no paragraphs.
        """
        # Split into paragraphs (separated by blank lines)
        paragraphs = self._split_into_paragraphs(content)

        if not paragraphs:
            return []

        chunks: list[Chunk] = []
        current_content: list[str] = []
        current_tokens = 0
        chunk_start_line = base_line

        for para_info in paragraphs:
            para_content = para_info["content"]
            para_tokens = para_info["tokens"]
            # start_line here is 0-based relative to `content`.
            para_start = para_info["start_line"]

            # Handle very large paragraphs
            if para_tokens > self.chunk_size:
                # Flush current content
                if current_content:
                    chunk_text = "\n\n".join(current_content)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text.strip(),
                            source_path=source_path,
                            start_line=chunk_start_line,
                            end_line=base_line + para_start - 1,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_content = []
                    current_tokens = 0

                # Split large paragraph by sentences/lines
                sub_chunks = self._split_large_paragraph(
                    para_content,
                    source_path,
                    file_type,
                    metadata,
                    base_line + para_start,
                )
                chunks.extend(sub_chunks)
                # Next chunk starts just past the oversized paragraph.
                chunk_start_line = base_line + para_info["end_line"] + 1
                continue

            # Check if adding this paragraph exceeds limit
            if current_tokens + para_tokens > self.chunk_size and current_content:
                # Create chunk
                chunk_text = "\n\n".join(current_content)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text.strip(),
                        source_path=source_path,
                        start_line=chunk_start_line,
                        end_line=base_line + para_start - 1,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap: include last paragraph if it fits
                if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
                    current_content = [current_content[-1]]
                    current_tokens = self.count_tokens(current_content[-1])
                else:
                    current_content = []
                    current_tokens = 0

                chunk_start_line = base_line + para_start

            current_content.append(para_content)
            current_tokens += para_tokens

        # Final chunk
        if current_content:
            chunk_text = "\n\n".join(current_content)
            end_line_num = base_line + (paragraphs[-1]["end_line"] if paragraphs else 0)
            chunks.append(
                self._create_chunk(
                    content=chunk_text.strip(),
                    source_path=source_path,
                    start_line=chunk_start_line,
                    end_line=end_line_num,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks
|
||||
|
||||
def _split_into_paragraphs(self, content: str) -> list[dict[str, Any]]:
|
||||
"""Split content into paragraphs with metadata."""
|
||||
paragraphs: list[dict[str, Any]] = []
|
||||
lines = content.split("\n")
|
||||
|
||||
current_para: list[str] = []
|
||||
para_start = 0
|
||||
in_code_block = False
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# Track code blocks (keep them as single units)
|
||||
if line.strip().startswith("```"):
|
||||
if in_code_block:
|
||||
# End of code block
|
||||
current_para.append(line)
|
||||
in_code_block = False
|
||||
else:
|
||||
# Start of code block - save previous paragraph
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
})
|
||||
current_para = [line]
|
||||
para_start = i
|
||||
in_code_block = True
|
||||
continue
|
||||
|
||||
if in_code_block:
|
||||
current_para.append(line)
|
||||
continue
|
||||
|
||||
# Empty line indicates paragraph break
|
||||
if not line.strip():
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
})
|
||||
current_para = []
|
||||
para_start = i + 1
|
||||
else:
|
||||
if not current_para:
|
||||
para_start = i
|
||||
current_para.append(line)
|
||||
|
||||
# Final paragraph
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": len(lines) - 1,
|
||||
})
|
||||
|
||||
return paragraphs
|
||||
|
||||
    def _split_large_paragraph(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Split a large paragraph into smaller chunks.

        Splits on sentence boundaries, accumulating sentences up to
        ``self.chunk_size`` tokens per chunk with sentence-level overlap.
        Sentences that alone exceed the chunk size are truncated and
        tagged with ``"truncated": True`` in their metadata.

        NOTE: every produced chunk records ``base_line`` as both its
        start and end line — per-sentence line tracking is not attempted
        within a single paragraph.

        Args:
            content: Paragraph text that exceeds the chunk size.
            source_path: Originating file path, if known.
            file_type: File type tag propagated into each chunk.
            metadata: Metadata attached to every produced chunk.
            base_line: Absolute line number of the paragraph's first line.

        Returns:
            List of sentence-accumulated chunks.
        """
        # Try splitting by sentences
        sentences = self._split_into_sentences(content)

        chunks: list[Chunk] = []
        current_content: list[str] = []
        current_tokens = 0

        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)

            # If single sentence is too large, truncate
            if sentence_tokens > self.chunk_size:
                # Flush whatever accumulated before the oversized sentence.
                if current_content:
                    chunk_text = " ".join(current_content)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text.strip(),
                            source_path=source_path,
                            start_line=base_line,
                            end_line=base_line,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_content = []
                    current_tokens = 0

                truncated = self.truncate_to_tokens(sentence, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated.strip(),
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        # Mark so consumers know content was lost here.
                        metadata={**metadata, "truncated": True},
                    )
                )
                continue

            if current_tokens + sentence_tokens > self.chunk_size and current_content:
                chunk_text = " ".join(current_content)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text.strip(),
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap with last sentence
                if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
                    current_content = [current_content[-1]]
                    current_tokens = self.count_tokens(current_content[-1])
                else:
                    current_content = []
                    current_tokens = 0

            current_content.append(sentence)
            current_tokens += sentence_tokens

        # Final chunk
        if current_content:
            chunk_text = " ".join(current_content)
            chunks.append(
                self._create_chunk(
                    content=chunk_text.strip(),
                    source_path=source_path,
                    start_line=base_line,
                    end_line=base_line,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks
|
||||
|
||||
def _split_into_sentences(self, text: str) -> list[str]:
|
||||
"""Split text into sentences."""
|
||||
# Simple sentence splitting on common terminators
|
||||
# More sophisticated splitting could use nltk or spacy
|
||||
sentence_endings = re.compile(r"(?<=[.!?])\s+")
|
||||
sentences = sentence_endings.split(text)
|
||||
return [s.strip() for s in sentences if s.strip()]
|
||||
389
mcp-servers/knowledge-base/chunking/text.py
Normal file
389
mcp-servers/knowledge-base/chunking/text.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
Plain text chunking implementation.
|
||||
|
||||
Provides simple text chunking with paragraph and sentence
|
||||
boundary detection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from chunking.base import BaseChunker
|
||||
from config import Settings
|
||||
from models import Chunk, ChunkType, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextChunker(BaseChunker):
    """
    Plain text chunker with paragraph awareness.

    Features:
    - Splits on paragraph boundaries (blank-line separated)
    - Falls back to sentence boundaries, then word boundaries
    - Configurable token overlap for context preservation

    NOTE: sentence- and word-level fallbacks record approximate line
    numbers only (see `_chunk_by_sentences` / `_chunk_by_words`).
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize text chunker.

        Args:
            chunk_size: Target maximum tokens per chunk.
            chunk_overlap: Token budget for carrying trailing context
                into the next chunk.
            settings: Optional settings override; defaults resolved by
                the base class.
        """
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Chunk type tag stamped on every chunk this chunker emits."""
        return ChunkType.TEXT

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk plain text content.

        Strategy: emit a single chunk if the whole text fits; otherwise
        combine paragraphs up to the size limit; if the text has no
        paragraph structure, fall back to sentence-based chunking.

        Args:
            content: Text to chunk.
            source_path: Originating file path, if known.
            file_type: Optional file type tag for the chunks.
            metadata: Extra metadata attached to every chunk.

        Returns:
            List of chunks (empty for blank input).
        """
        if not content.strip():
            return []

        metadata = metadata or {}

        # Check if content fits in a single chunk
        total_tokens = self.count_tokens(content)
        if total_tokens <= self.chunk_size:
            return [
                self._create_chunk(
                    content=content.strip(),
                    source_path=source_path,
                    start_line=1,
                    end_line=content.count("\n") + 1,
                    file_type=file_type,
                    metadata=metadata,
                )
            ]

        # Try paragraph-based chunking
        paragraphs = self._split_paragraphs(content)
        if len(paragraphs) > 1:
            return self._chunk_by_paragraphs(
                paragraphs, source_path, file_type, metadata
            )

        # Fall back to sentence-based chunking
        return self._chunk_by_sentences(
            content, source_path, file_type, metadata
        )

    def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
        """Split content into paragraphs.

        Returns dicts with ``content``, ``tokens``, and 1-based
        ``start_line``/``end_line``.

        NOTE(review): line accounting assumes exactly one blank line
        between paragraphs (`+1` below); runs of multiple blank lines
        will skew subsequent start_line values — confirm acceptable.
        """
        paragraphs: list[dict[str, Any]] = []

        # Split on double newlines (paragraph boundaries)
        raw_paras = re.split(r"\n\s*\n", content)

        line_num = 1
        for para in raw_paras:
            para = para.strip()
            if not para:
                continue

            para_lines = para.count("\n") + 1
            paragraphs.append({
                "content": para,
                "tokens": self.count_tokens(para),
                "start_line": line_num,
                "end_line": line_num + para_lines - 1,
            })
            line_num += para_lines + 1  # +1 for blank line between paragraphs

        return paragraphs

    def _chunk_by_paragraphs(
        self,
        paragraphs: list[dict[str, Any]],
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by combining paragraphs up to size limit.

        Accumulates paragraphs until the next one would exceed
        ``self.chunk_size`` tokens; oversized paragraphs are delegated to
        ``_split_large_text``.  The last paragraph of a flushed chunk is
        carried over as overlap when it fits in ``self.chunk_overlap``.
        """
        chunks: list[Chunk] = []
        current_paras: list[str] = []
        current_tokens = 0
        chunk_start = paragraphs[0]["start_line"] if paragraphs else 1
        chunk_end = chunk_start

        for para in paragraphs:
            para_content = para["content"]
            para_tokens = para["tokens"]

            # Handle paragraphs larger than chunk size
            if para_tokens > self.chunk_size:
                # Flush current content
                if current_paras:
                    chunk_text = "\n\n".join(current_paras)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text,
                            source_path=source_path,
                            start_line=chunk_start,
                            end_line=chunk_end,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_paras = []
                    current_tokens = 0

                # Split large paragraph
                sub_chunks = self._split_large_text(
                    para_content,
                    source_path,
                    file_type,
                    metadata,
                    para["start_line"],
                )
                chunks.extend(sub_chunks)
                # Next chunk begins just past the oversized paragraph.
                chunk_start = para["end_line"] + 1
                chunk_end = chunk_start
                continue

            # Check if adding paragraph exceeds limit
            if current_tokens + para_tokens > self.chunk_size and current_paras:
                chunk_text = "\n\n".join(current_paras)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=chunk_start,
                        end_line=chunk_end,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap: keep last paragraph if small enough
                overlap_para = None
                if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
                    overlap_para = current_paras[-1]

                current_paras = [overlap_para] if overlap_para else []
                current_tokens = self.count_tokens(overlap_para) if overlap_para else 0
                chunk_start = para["start_line"]

            current_paras.append(para_content)
            current_tokens += para_tokens
            chunk_end = para["end_line"]

        # Final chunk
        if current_paras:
            chunk_text = "\n\n".join(current_paras)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=chunk_start,
                    end_line=chunk_end,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks

    def _chunk_by_sentences(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by sentences.

        Sentences are accumulated up to ``self.chunk_size`` tokens with
        sentence-level overlap; sentences that alone exceed the limit are
        truncated and tagged ``"truncated": True``.

        NOTE: line numbers are placeholders here — intermediate chunks use
        start_line=end_line=1; only the final chunk's end_line reflects
        the content's line count.
        """
        sentences = self._split_sentences(content)

        if not sentences:
            return []

        chunks: list[Chunk] = []
        current_sentences: list[str] = []
        current_tokens = 0

        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)

            # Handle sentences larger than chunk size
            if sentence_tokens > self.chunk_size:
                if current_sentences:
                    chunk_text = " ".join(current_sentences)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text,
                            source_path=source_path,
                            start_line=1,
                            end_line=1,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_sentences = []
                    current_tokens = 0

                # Truncate large sentence
                truncated = self.truncate_to_tokens(sentence, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated,
                        source_path=source_path,
                        start_line=1,
                        end_line=1,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                continue

            # Check if adding sentence exceeds limit
            if current_tokens + sentence_tokens > self.chunk_size and current_sentences:
                chunk_text = " ".join(current_sentences)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=1,
                        end_line=1,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap: keep last sentence if small enough
                overlap = None
                if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
                    overlap = current_sentences[-1]

                current_sentences = [overlap] if overlap else []
                current_tokens = self.count_tokens(overlap) if overlap else 0

            current_sentences.append(sentence)
            current_tokens += sentence_tokens

        # Final chunk
        if current_sentences:
            chunk_text = " ".join(current_sentences)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=1,
                    end_line=content.count("\n") + 1,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks

    def _split_sentences(self, text: str) -> list[str]:
        """Split text into sentences.

        Heuristic regex split: after '.', '!' or '?' followed by an
        uppercase letter or end of text, and on newline boundaries.
        Production might use nltk or spacy for better accuracy.
        """
        # Handle common sentence endings
        sentence_pattern = re.compile(
            r"(?<=[.!?])\s+(?=[A-Z])|"  # Standard sentence ending
            r"(?<=[.!?])\s*$|"  # End of text
            r"(?<=\n)\s*(?=\S)"  # Newlines as boundaries
        )

        sentences = sentence_pattern.split(text)
        return [s.strip() for s in sentences if s.strip()]

    def _split_large_text(
        self,
        text: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Split text that exceeds chunk size.

        Prefers sentence-based splitting; degrades to word-based splitting
        for text with no sentence boundaries.

        NOTE(review): the sentence path does not forward ``base_line``, so
        those chunks report placeholder line numbers — confirm intended.
        """
        # First try sentences
        sentences = self._split_sentences(text)

        if len(sentences) > 1:
            return self._chunk_by_sentences(
                text, source_path, file_type, metadata
            )

        # Fall back to word-based splitting
        return self._chunk_by_words(
            text, source_path, file_type, metadata, base_line
        )

    def _chunk_by_words(
        self,
        text: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Last resort: chunk by words.

        Accumulates whitespace-separated words up to ``self.chunk_size``
        tokens, carrying a trailing-word overlap of up to
        ``self.chunk_overlap`` tokens into the next chunk.  All chunks
        record ``base_line`` for both start and end line.
        """
        words = text.split()
        chunks: list[Chunk] = []
        current_words: list[str] = []
        current_tokens = 0

        for word in words:
            # Count the trailing space so joined-chunk sizes stay honest.
            word_tokens = self.count_tokens(word + " ")

            if current_tokens + word_tokens > self.chunk_size and current_words:
                chunk_text = " ".join(current_words)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Word overlap
                # Walk backwards collecting trailing words until the
                # overlap token budget is spent.
                overlap_count = 0
                overlap_words: list[str] = []
                for w in reversed(current_words):
                    w_tokens = self.count_tokens(w + " ")
                    if overlap_count + w_tokens > self.chunk_overlap:
                        break
                    overlap_words.insert(0, w)
                    overlap_count += w_tokens

                current_words = overlap_words
                current_tokens = overlap_count

            current_words.append(word)
            current_tokens += word_tokens

        # Final chunk
        if current_words:
            chunk_text = " ".join(current_words)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=base_line,
                    end_line=base_line,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks
|
||||
331
mcp-servers/knowledge-base/collection_manager.py
Normal file
331
mcp-servers/knowledge-base/collection_manager.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Collection management for Knowledge Base MCP Server.
|
||||
|
||||
Provides operations for managing document collections including
|
||||
ingestion, deletion, and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from chunking.base import ChunkerFactory, get_chunker_factory
|
||||
from config import Settings, get_settings
|
||||
from database import DatabaseManager, get_database_manager
|
||||
from embeddings import EmbeddingGenerator, get_embedding_generator
|
||||
from models import (
|
||||
ChunkType,
|
||||
CollectionStats,
|
||||
DeleteRequest,
|
||||
DeleteResult,
|
||||
FileType,
|
||||
IngestRequest,
|
||||
IngestResult,
|
||||
ListCollectionsResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CollectionManager:
    """
    Manages knowledge base collections.

    Handles document ingestion, chunking, embedding generation,
    and collection operations.  Dependencies are lazily resolved from
    module-level singletons unless injected via the constructor.
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
        chunker_factory: ChunkerFactory | None = None,
    ) -> None:
        """Initialize collection manager.

        Args:
            settings: Optional settings override (defaults to global).
            database: Optional database manager (lazy global if omitted).
            embeddings: Optional embedding generator (lazy global if omitted).
            chunker_factory: Optional chunker factory (lazy global if omitted).
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings
        self._chunker_factory = chunker_factory

    @property
    def database(self) -> DatabaseManager:
        """Database manager, resolved lazily on first access."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Embedding generator, resolved lazily on first access."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    @property
    def chunker_factory(self) -> ChunkerFactory:
        """Chunker factory, resolved lazily on first access."""
        if self._chunker_factory is None:
            self._chunker_factory = get_chunker_factory()
        return self._chunker_factory

    async def ingest(self, request: IngestRequest) -> IngestResult:
        """
        Ingest content into the knowledge base.

        Chunks the content, generates embeddings, and stores them.
        Errors are captured and returned in the result rather than raised.

        Args:
            request: Ingest request with content and options

        Returns:
            Ingest result with created chunk IDs
        """
        try:
            # Chunk the content
            chunks = self.chunker_factory.chunk_content(
                content=request.content,
                source_path=request.source_path,
                file_type=request.file_type,
                chunk_type=request.chunk_type,
                metadata=request.metadata,
            )

            if not chunks:
                # Nothing to store (e.g. blank content) is still a success.
                return IngestResult(
                    success=True,
                    chunks_created=0,
                    embeddings_generated=0,
                    source_path=request.source_path,
                    collection=request.collection,
                    chunk_ids=[],
                )

            # Extract chunk contents for embedding
            chunk_texts = [chunk.content for chunk in chunks]

            # Generate embeddings
            embeddings_list = await self.embeddings.generate_batch(
                texts=chunk_texts,
                project_id=request.project_id,
                agent_id=request.agent_id,
            )

            # Store embeddings
            chunk_ids: list[str] = []
            # strict=True guarantees chunk/embedding counts match.
            for chunk, embedding in zip(chunks, embeddings_list, strict=True):
                # Build metadata with chunk info (chunk metadata wins on
                # key collisions with the request-level metadata).
                chunk_metadata = {
                    **request.metadata,
                    **chunk.metadata,
                    "token_count": chunk.token_count,
                }

                chunk_id = await self.database.store_embedding(
                    project_id=request.project_id,
                    collection=request.collection,
                    content=chunk.content,
                    embedding=embedding,
                    chunk_type=chunk.chunk_type,
                    source_path=chunk.source_path or request.source_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    file_type=chunk.file_type or request.file_type,
                    metadata=chunk_metadata,
                )
                chunk_ids.append(chunk_id)

            # Lazy %-style args avoid formatting when the level is disabled.
            logger.info(
                "Ingested %d chunks into collection '%s' for project %s",
                len(chunks),
                request.collection,
                request.project_id,
            )

            return IngestResult(
                success=True,
                chunks_created=len(chunks),
                embeddings_generated=len(embeddings_list),
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=chunk_ids,
            )

        except Exception as e:
            # logger.exception preserves the traceback (logger.error with an
            # f-string would discard it).
            logger.exception("Ingest error: %s", e)
            return IngestResult(
                success=False,
                chunks_created=0,
                embeddings_generated=0,
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=[],
                error=str(e),
            )

    async def delete(self, request: DeleteRequest) -> DeleteResult:
        """
        Delete content from the knowledge base.

        Supports deletion by source path, collection, or chunk IDs, in
        that order of precedence (chunk_ids > source_path > collection).

        Args:
            request: Delete request with target specification

        Returns:
            Delete result with count of deleted chunks
        """
        try:
            deleted_count = 0

            if request.chunk_ids:
                # Delete specific chunks
                deleted_count = await self.database.delete_by_ids(
                    project_id=request.project_id,
                    chunk_ids=request.chunk_ids,
                )
            elif request.source_path:
                # Delete by source path
                deleted_count = await self.database.delete_by_source(
                    project_id=request.project_id,
                    source_path=request.source_path,
                    collection=request.collection,
                )
            elif request.collection:
                # Delete entire collection
                deleted_count = await self.database.delete_collection(
                    project_id=request.project_id,
                    collection=request.collection,
                )
            else:
                return DeleteResult(
                    success=False,
                    chunks_deleted=0,
                    error="Must specify chunk_ids, source_path, or collection",
                )

            logger.info(
                "Deleted %d chunks for project %s",
                deleted_count,
                request.project_id,
            )

            return DeleteResult(
                success=True,
                chunks_deleted=deleted_count,
            )

        except Exception as e:
            # Preserve the traceback in logs; report the message to callers.
            logger.exception("Delete error: %s", e)
            return DeleteResult(
                success=False,
                chunks_deleted=0,
                error=str(e),
            )

    async def list_collections(self, project_id: str) -> ListCollectionsResponse:
        """
        List all collections for a project.

        Args:
            project_id: Project ID

        Returns:
            List of collection info
        """
        collections = await self.database.list_collections(project_id)

        return ListCollectionsResponse(
            project_id=project_id,
            collections=collections,
            total_collections=len(collections),
        )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """
        Get statistics for a collection.

        Args:
            project_id: Project ID
            collection: Collection name

        Returns:
            Collection statistics
        """
        return await self.database.get_collection_stats(project_id, collection)

    async def update_document(
        self,
        project_id: str,
        agent_id: str,
        source_path: str,
        content: str,
        collection: str = "default",
        chunk_type: ChunkType = ChunkType.TEXT,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> IngestResult:
        """
        Update a document by replacing existing chunks.

        Deletes existing chunks for the source path and ingests new content.

        NOTE(review): delete-then-ingest is not atomic — if ingest fails
        after deletion the old chunks are already gone. Confirm whether a
        transactional replace is needed.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            source_path: Source file path
            content: New content
            collection: Collection name
            chunk_type: Type of content
            file_type: File type for code chunking
            metadata: Additional metadata

        Returns:
            Ingest result
        """
        # First delete existing chunks for this source
        await self.database.delete_by_source(
            project_id=project_id,
            source_path=source_path,
            collection=collection,
        )

        # Then ingest new content
        request = IngestRequest(
            project_id=project_id,
            agent_id=agent_id,
            content=content,
            source_path=source_path,
            collection=collection,
            chunk_type=chunk_type,
            file_type=file_type,
            metadata=metadata or {},
        )

        return await self.ingest(request)

    async def cleanup_expired(self) -> int:
        """
        Remove expired embeddings from all collections.

        Returns:
            Number of embeddings removed
        """
        return await self.database.cleanup_expired()
|
||||
|
||||
|
||||
# Global collection manager instance (lazy initialization)
|
||||
_collection_manager: CollectionManager | None = None
|
||||
|
||||
|
||||
def get_collection_manager() -> CollectionManager:
    """Return the process-wide CollectionManager, creating it on first use."""
    global _collection_manager
    manager = _collection_manager
    if manager is None:
        manager = CollectionManager()
        _collection_manager = manager
    return manager
|
||||
|
||||
|
||||
def reset_collection_manager() -> None:
    """Reset the global collection manager (for testing).

    The next call to :func:`get_collection_manager` will construct a
    fresh instance.
    """
    global _collection_manager
    _collection_manager = None
|
||||
138
mcp-servers/knowledge-base/config.py
Normal file
138
mcp-servers/knowledge-base/config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Configuration for Knowledge Base MCP Server.
|
||||
|
||||
Uses pydantic-settings for environment variable loading.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment.

    All variables use the ``KB_`` prefix (see ``model_config``), e.g.
    ``KB_PORT`` or ``KB_DATABASE_URL``; a local ``.env`` file is also read.
    """

    # Server settings
    # NOTE: 0.0.0.0 binds all interfaces — intended for containerized use.
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8002, description="Server port")
    debug: bool = Field(default=False, description="Debug mode")

    # Database settings
    database_url: str = Field(
        default="postgresql://postgres:postgres@localhost:5432/syndarix",
        description="PostgreSQL connection URL with pgvector extension",
    )
    database_pool_size: int = Field(default=10, description="Connection pool size")
    database_pool_max_overflow: int = Field(
        default=20, description="Max overflow connections"
    )

    # Redis settings
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL",
    )

    # LLM Gateway settings (for embeddings)
    llm_gateway_url: str = Field(
        default="http://localhost:8001",
        description="LLM Gateway MCP server URL",
    )

    # Embedding settings
    # NOTE(review): text-embedding-3-large natively emits 3072 dimensions;
    # 1536 matches text-embedding-3-small or a dimension-reduced request —
    # confirm the model/dimension pairing matches the DB vector column.
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="Default embedding model",
    )
    embedding_dimension: int = Field(
        default=1536,
        description="Embedding vector dimension",
    )
    embedding_batch_size: int = Field(
        default=100,
        description="Max texts per embedding batch",
    )
    embedding_cache_ttl: int = Field(
        default=86400,
        description="Embedding cache TTL in seconds (24 hours)",
    )

    # Chunking settings
    # Per-content-type chunk sizes/overlaps, in tokens.
    code_chunk_size: int = Field(
        default=500,
        description="Target tokens per code chunk",
    )
    code_chunk_overlap: int = Field(
        default=50,
        description="Token overlap between code chunks",
    )
    markdown_chunk_size: int = Field(
        default=800,
        description="Target tokens per markdown chunk",
    )
    markdown_chunk_overlap: int = Field(
        default=100,
        description="Token overlap between markdown chunks",
    )
    text_chunk_size: int = Field(
        default=400,
        description="Target tokens per text chunk",
    )
    text_chunk_overlap: int = Field(
        default=50,
        description="Token overlap between text chunks",
    )

    # Search settings
    search_default_limit: int = Field(
        default=10,
        description="Default number of search results",
    )
    search_max_limit: int = Field(
        default=100,
        description="Maximum number of search results",
    )
    semantic_threshold: float = Field(
        default=0.7,
        description="Minimum similarity score for semantic search",
    )
    # NOTE(review): defaults sum to 1.0 but nothing enforces it — confirm
    # whether hybrid search normalizes the weights.
    hybrid_semantic_weight: float = Field(
        default=0.7,
        description="Weight for semantic results in hybrid search",
    )
    hybrid_keyword_weight: float = Field(
        default=0.3,
        description="Weight for keyword results in hybrid search",
    )

    # Storage settings
    embedding_ttl_days: int = Field(
        default=30,
        description="TTL for embedding records in days (0 = no expiry)",
    )

    model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"}
|
||||
|
||||
|
||||
# Process-wide Settings singleton, created on first access.
_settings: Settings | None = None


def get_settings() -> Settings:
    """Return the shared Settings instance, constructing it lazily."""
    global _settings
    if _settings is not None:
        return _settings
    _settings = Settings()
    return _settings
||||
|
||||
|
||||
def reset_settings() -> None:
    """Drop the cached global Settings so tests start from a clean slate.

    The next get_settings() call will construct a fresh instance.
    """
    global _settings
    _settings = None
||||
|
||||
|
||||
def is_test_mode() -> bool:
    """Check if running in test mode (IS_TEST env var set to a truthy value)."""
    flag = os.environ.get("IS_TEST", "")
    return flag.lower() in {"true", "1", "yes"}
|
||||
774
mcp-servers/knowledge-base/database.py
Normal file
774
mcp-servers/knowledge-base/database.py
Normal file
@@ -0,0 +1,774 @@
|
||||
"""
|
||||
Database management for Knowledge Base MCP Server.
|
||||
|
||||
Handles PostgreSQL connections with pgvector extension for
|
||||
vector similarity search operations.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from pgvector.asyncpg import register_vector
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
CollectionNotFoundError,
|
||||
DatabaseConnectionError,
|
||||
DatabaseQueryError,
|
||||
ErrorCode,
|
||||
KnowledgeBaseError,
|
||||
)
|
||||
from models import (
|
||||
ChunkType,
|
||||
CollectionInfo,
|
||||
CollectionStats,
|
||||
FileType,
|
||||
KnowledgeEmbedding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
    """
    Manages PostgreSQL connections and vector operations.

    Uses asyncpg for async operations and pgvector for
    vector similarity search (cosine distance via the `<=>` operator).
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize database manager.

        Args:
            settings: Optional settings override; defaults to the global settings.
        """
        self._settings = settings or get_settings()
        self._pool: asyncpg.Pool | None = None  # type: ignore[type-arg]

    @property
    def pool(self) -> asyncpg.Pool:  # type: ignore[type-arg]
        """Get connection pool, raising if not initialized.

        Raises:
            DatabaseConnectionError: If initialize() has not been called.
        """
        if self._pool is None:
            raise DatabaseConnectionError("Database pool not initialized")
        return self._pool

    async def initialize(self) -> None:
        """Create the connection pool and initialize the schema.

        Raises:
            DatabaseConnectionError: If pool creation or schema setup fails.
        """
        try:
            self._pool = await asyncpg.create_pool(
                self._settings.database_url,
                min_size=2,
                max_size=self._settings.database_pool_size,
                max_inactive_connection_lifetime=300,
                init=self._init_connection,
            )
            logger.info("Database pool created successfully")

            await self._create_schema()
            logger.info("Database schema initialized")

        except Exception as e:
            logger.error(f"Failed to initialize database: {e}")
            raise DatabaseConnectionError(
                message=f"Failed to initialize database: {e}",
                cause=e,
            )

    async def _init_connection(self, conn: asyncpg.Connection) -> None:  # type: ignore[type-arg]
        """Configure a new connection with pgvector and JSONB codecs."""
        # Local import keeps the module's import surface unchanged.
        import json

        await register_vector(conn)
        # FIX: asyncpg transfers jsonb values as text by default, so writing
        # Python dicts to the `metadata` column (and reading them back as the
        # dicts that the search methods assume) requires an explicit codec.
        await conn.set_type_codec(
            "jsonb",
            encoder=json.dumps,
            decoder=json.loads,
            schema="pg_catalog",
        )

    async def _create_schema(self) -> None:
        """Create the embeddings table and its indexes if they do not exist."""
        async with self.pool.acquire() as conn:
            # Enable pgvector extension
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")

            # Main embeddings table. `embedding` is nullable so rows can be
            # inserted before their vector is computed.
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS knowledge_embeddings (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    project_id VARCHAR(255) NOT NULL,
                    collection VARCHAR(255) NOT NULL DEFAULT 'default',
                    content TEXT NOT NULL,
                    embedding vector(1536),
                    chunk_type VARCHAR(50) NOT NULL,
                    source_path TEXT,
                    start_line INTEGER,
                    end_line INTEGER,
                    file_type VARCHAR(50),
                    metadata JSONB DEFAULT '{}',
                    content_hash VARCHAR(64),
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    expires_at TIMESTAMPTZ
                )
            """)

            # B-tree indexes for the common filter combinations.
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection
                ON knowledge_embeddings(project_id, collection)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_source_path
                ON knowledge_embeddings(project_id, source_path)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash
                ON knowledge_embeddings(project_id, content_hash)
            """)

            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type
                ON knowledge_embeddings(project_id, chunk_type)
            """)

            # HNSW index for approximate nearest-neighbor search with cosine
            # distance — dramatically improves semantic search performance.
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw
                ON knowledge_embeddings
                USING hnsw (embedding vector_cosine_ops)
                WITH (m = 16, ef_construction = 64)
            """)

            # GIN index backing the full-text keyword search.
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts
                ON knowledge_embeddings
                USING gin(to_tsvector('english', content))
            """)

    async def close(self) -> None:
        """Close the connection pool (idempotent)."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database pool closed")

    @asynccontextmanager
    async def acquire(self) -> Any:
        """Acquire a connection from the pool as an async context manager."""
        async with self.pool.acquire() as conn:
            yield conn

    @staticmethod
    def compute_content_hash(content: str) -> str:
        """Compute SHA-256 hash of content for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()

    @staticmethod
    def _affected_rows(status: str) -> int:
        """Parse the row count from an asyncpg command tag like 'DELETE 3'."""
        return int(status.split()[-1])

    @staticmethod
    def _row_to_embedding(row: Any) -> KnowledgeEmbedding:
        """Convert a knowledge_embeddings row into a KnowledgeEmbedding model."""
        return KnowledgeEmbedding(
            id=str(row["id"]),
            project_id=row["project_id"],
            collection=row["collection"],
            content=row["content"],
            embedding=list(row["embedding"]) if row["embedding"] is not None else [],
            chunk_type=ChunkType(row["chunk_type"]),
            source_path=row["source_path"],
            start_line=row["start_line"],
            end_line=row["end_line"],
            file_type=FileType(row["file_type"]) if row["file_type"] else None,
            metadata=row["metadata"] or {},
            content_hash=row["content_hash"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            expires_at=row["expires_at"],
        )

    def _ttl_expiry(self) -> datetime | None:
        """Return the expires_at timestamp implied by settings, or None."""
        if self._settings.embedding_ttl_days > 0:
            return datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )
        return None

    async def store_embedding(
        self,
        project_id: str,
        collection: str,
        content: str,
        embedding: list[float],
        chunk_type: ChunkType,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Store an embedding, updating in place if identical content exists.

        Deduplication is by (project_id, collection, content_hash); a hit
        refreshes the vector, metadata, and expiry of the existing row.

        Returns:
            The ID of the stored (or updated) embedding.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        content_hash = self.compute_content_hash(content)
        metadata = metadata or {}
        expires_at = self._ttl_expiry()

        try:
            async with self.acquire() as conn:
                # Check for duplicate content
                existing = await conn.fetchval(
                    """
                    SELECT id FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2 AND content_hash = $3
                    """,
                    project_id,
                    collection,
                    content_hash,
                )

                if existing:
                    # Same content already stored: refresh rather than insert.
                    await conn.execute(
                        """
                        UPDATE knowledge_embeddings
                        SET embedding = $1, updated_at = NOW(), expires_at = $2,
                            metadata = $3, source_path = $4, start_line = $5,
                            end_line = $6, file_type = $7
                        WHERE id = $8
                        """,
                        embedding,
                        expires_at,
                        metadata,
                        source_path,
                        start_line,
                        end_line,
                        file_type.value if file_type else None,
                        existing,
                    )
                    logger.debug(f"Updated existing embedding: {existing}")
                    return str(existing)

                # Insert new embedding
                embedding_id = await conn.fetchval(
                    """
                    INSERT INTO knowledge_embeddings
                    (project_id, collection, content, embedding, chunk_type,
                     source_path, start_line, end_line, file_type, metadata,
                     content_hash, expires_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                    RETURNING id
                    """,
                    project_id,
                    collection,
                    content,
                    embedding,
                    chunk_type.value,
                    source_path,
                    start_line,
                    end_line,
                    file_type.value if file_type else None,
                    metadata,
                    content_hash,
                    expires_at,
                )
                logger.debug(f"Stored new embedding: {embedding_id}")
                return str(embedding_id)

        except asyncpg.PostgresError as e:
            logger.error(f"Database error storing embedding: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embedding: {e}",
                cause=e,
            )

    async def store_embeddings_batch(
        self,
        embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]],
    ) -> list[str]:
        """
        Store multiple embeddings in a batch.

        Args:
            embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata)

        Returns:
            List of created embedding IDs.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        if not embeddings:
            return []

        ids = []
        expires_at = self._ttl_expiry()

        try:
            async with self.acquire() as conn:
                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
                    content_hash = self.compute_content_hash(content)
                    source_path = metadata.get("source_path")
                    start_line = metadata.get("start_line")
                    end_line = metadata.get("end_line")
                    # FIX: normalize file_type to its string value. Callers may
                    # put either a FileType enum or a plain string in metadata;
                    # store_embedding() always persists the string value.
                    raw_file_type = metadata.get("file_type")
                    file_type = (
                        raw_file_type.value
                        if isinstance(raw_file_type, FileType)
                        else raw_file_type
                    )

                    # NOTE(review): without a unique constraint covering
                    # (project_id, collection, content_hash) this bare
                    # ON CONFLICT never fires — confirm intended dedup behavior.
                    embedding_id = await conn.fetchval(
                        """
                        INSERT INTO knowledge_embeddings
                        (project_id, collection, content, embedding, chunk_type,
                         source_path, start_line, end_line, file_type, metadata,
                         content_hash, expires_at)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                        ON CONFLICT DO NOTHING
                        RETURNING id
                        """,
                        project_id,
                        collection,
                        content,
                        embedding,
                        chunk_type.value,
                        source_path,
                        start_line,
                        end_line,
                        file_type,
                        metadata,
                        content_hash,
                        expires_at,
                    )
                    if embedding_id:
                        ids.append(str(embedding_id))

            logger.info(f"Stored {len(ids)} embeddings in batch")
            return ids

        except asyncpg.PostgresError as e:
            logger.error(f"Database error in batch store: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embeddings batch: {e}",
                cause=e,
            )

    async def semantic_search(
        self,
        project_id: str,
        query_embedding: list[float],
        collection: str | None = None,
        limit: int = 10,
        threshold: float = 0.7,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform semantic (vector) search using cosine similarity.

        Args:
            project_id: Project scope for the search.
            query_embedding: Query vector to compare against.
            collection: Optional collection filter.
            limit: Maximum number of results.
            threshold: Minimum cosine similarity (1 - distance) to include.
            file_types: Optional file-type filter.

        Returns:
            List of (embedding, similarity_score) tuples, best first.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters
                query = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        1 - (embedding <=> $1) as similarity
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                    AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query_embedding, project_id]
                param_idx = 3

                if collection:
                    query += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1

                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    query += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1

                # FIX: the similarity cutoff was previously appended as a
                # HAVING clause without GROUP BY, which PostgreSQL rejects
                # when non-aggregated columns are selected. It is a plain
                # row filter and belongs in WHERE.
                query += f"""
                    AND 1 - (embedding <=> $1) >= ${param_idx}
                    ORDER BY similarity DESC
                    LIMIT ${param_idx + 1}
                """
                params.extend([threshold, limit])

                rows = await conn.fetch(query, *params)

                return [
                    (self._row_to_embedding(row), float(row["similarity"]))
                    for row in rows
                ]

        except asyncpg.PostgresError as e:
            logger.error(f"Semantic search error: {e}")
            raise DatabaseQueryError(
                message=f"Semantic search failed: {e}",
                cause=e,
            )

    async def keyword_search(
        self,
        project_id: str,
        query: str,
        collection: str | None = None,
        limit: int = 10,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform full-text keyword search via PostgreSQL tsvector matching.

        Args:
            project_id: Project scope for the search.
            query: Natural-language query (parsed with plainto_tsquery).
            collection: Optional collection filter.
            limit: Maximum number of results.
            file_types: Optional file-type filter.

        Returns:
            List of (embedding, relevance_score) tuples, best first; scores
            are ts_rank values clamped to [0, 1].

        Raises:
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters
                sql = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        ts_rank(to_tsvector('english', content),
                                plainto_tsquery('english', $1)) as relevance
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                    AND to_tsvector('english', content) @@ plainto_tsquery('english', $1)
                    AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query, project_id]
                param_idx = 3

                if collection:
                    sql += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1

                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    sql += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1

                sql += f" ORDER BY relevance DESC LIMIT ${param_idx}"
                params.append(limit)

                rows = await conn.fetch(sql, *params)

                # ts_rank is unbounded above; clamp to a 0-1 scale so scores
                # are comparable with semantic similarity in hybrid search.
                return [
                    (self._row_to_embedding(row), min(1.0, float(row["relevance"])))
                    for row in rows
                ]

        except asyncpg.PostgresError as e:
            logger.error(f"Keyword search error: {e}")
            raise DatabaseQueryError(
                message=f"Keyword search failed: {e}",
                cause=e,
            )

    async def delete_by_source(
        self,
        project_id: str,
        source_path: str,
        collection: str | None = None,
    ) -> int:
        """Delete all embeddings for a source path.

        Args:
            project_id: Project scope.
            source_path: Source file/document path whose chunks to remove.
            collection: Optional collection filter; all collections if None.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                if collection:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
                        """,
                        project_id,
                        source_path,
                        collection,
                    )
                else:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2
                        """,
                        project_id,
                        source_path,
                    )
                count = self._affected_rows(result)
                logger.info(f"Deleted {count} embeddings for source: {source_path}")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Delete error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by source: {e}",
                cause=e,
            )

    async def delete_collection(
        self,
        project_id: str,
        collection: str,
    ) -> int:
        """Delete an entire collection.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    """,
                    project_id,
                    collection,
                )
                count = self._affected_rows(result)
                logger.info(f"Deleted collection {collection}: {count} embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Delete collection error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete collection: {e}",
                cause=e,
            )

    async def delete_by_ids(
        self,
        project_id: str,
        chunk_ids: list[str],
    ) -> int:
        """Delete specific embeddings by ID.

        Args:
            project_id: Project scope (rows in other projects are untouched).
            chunk_ids: UUID strings of the rows to delete.

        Returns:
            Number of rows deleted.

        Raises:
            KnowledgeBaseError: If any ID is not a valid UUID.
            DatabaseQueryError: On any database failure.
        """
        if not chunk_ids:
            return 0

        try:
            # Convert string IDs to UUIDs; ValueError here means bad input,
            # not a database problem.
            uuids = [uuid.UUID(cid) for cid in chunk_ids]

            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND id = ANY($2)
                    """,
                    project_id,
                    uuids,
                )
                count = self._affected_rows(result)
                logger.info(f"Deleted {count} embeddings by ID")
                return count

        except ValueError as e:
            raise KnowledgeBaseError(
                message=f"Invalid chunk ID format: {e}",
                code=ErrorCode.INVALID_REQUEST,
            )
        except asyncpg.PostgresError as e:
            logger.error(f"Delete by IDs error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by IDs: {e}",
                cause=e,
            )

    async def list_collections(
        self,
        project_id: str,
    ) -> list[CollectionInfo]:
        """List all non-expired collections for a project with summary stats.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                rows = await conn.fetch(
                    """
                    SELECT
                        collection,
                        COUNT(*) as chunk_count,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types,
                        MIN(created_at) as created_at,
                        MAX(updated_at) as updated_at
                    FROM knowledge_embeddings
                    WHERE project_id = $1
                    AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY collection
                    ORDER BY collection
                    """,
                    project_id,
                )

                return [
                    CollectionInfo(
                        name=row["collection"],
                        project_id=project_id,
                        chunk_count=row["chunk_count"],
                        total_tokens=row["total_tokens"] or 0,
                        file_types=row["file_types"] or [],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                    )
                    for row in rows
                ]

        except asyncpg.PostgresError as e:
            logger.error(f"List collections error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to list collections: {e}",
                cause=e,
            )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """Get detailed statistics for a collection.

        Raises:
            CollectionNotFoundError: If the collection has no rows at all.
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                # Existence check is deliberately not expiry-filtered: a
                # collection whose rows have all expired still "exists".
                exists = await conn.fetchval(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM knowledge_embeddings
                        WHERE project_id = $1 AND collection = $2
                    )
                    """,
                    project_id,
                    collection,
                )

                if not exists:
                    raise CollectionNotFoundError(collection, project_id)

                # Aggregate stats over non-expired rows only.
                row = await conn.fetchrow(
                    """
                    SELECT
                        COUNT(*) as chunk_count,
                        COUNT(DISTINCT source_path) as unique_sources,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size,
                        MIN(created_at) as oldest_chunk,
                        MAX(created_at) as newest_chunk
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    AND (expires_at IS NULL OR expires_at > NOW())
                    """,
                    project_id,
                    collection,
                )

                # Get chunk type breakdown
                chunk_rows = await conn.fetch(
                    """
                    SELECT chunk_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY chunk_type
                    """,
                    project_id,
                    collection,
                )
                chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows}

                # Get file type breakdown
                file_rows = await conn.fetch(
                    """
                    SELECT file_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    AND file_type IS NOT NULL
                    AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY file_type
                    """,
                    project_id,
                    collection,
                )
                file_types = {r["file_type"]: r["count"] for r in file_rows}

                return CollectionStats(
                    collection=collection,
                    project_id=project_id,
                    chunk_count=row["chunk_count"],
                    unique_sources=row["unique_sources"],
                    total_tokens=row["total_tokens"] or 0,
                    avg_chunk_size=float(row["avg_chunk_size"] or 0),
                    chunk_types=chunk_types,
                    file_types=file_types,
                    oldest_chunk=row["oldest_chunk"],
                    newest_chunk=row["newest_chunk"],
                )

        except CollectionNotFoundError:
            raise
        except asyncpg.PostgresError as e:
            logger.error(f"Get collection stats error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to get collection stats: {e}",
                cause=e,
            )

    async def cleanup_expired(self) -> int:
        """Remove expired embeddings.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseQueryError: On any database failure.
        """
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE expires_at IS NOT NULL AND expires_at < NOW()
                    """
                )
                count = self._affected_rows(result)
                if count > 0:
                    logger.info(f"Cleaned up {count} expired embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Cleanup error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to cleanup expired: {e}",
                cause=e,
            )
||||
|
||||
|
||||
# Process-wide DatabaseManager singleton, created on first access.
_db_manager: DatabaseManager | None = None


def get_database_manager() -> DatabaseManager:
    """Return the shared DatabaseManager, constructing it lazily."""
    global _db_manager
    if _db_manager is not None:
        return _db_manager
    _db_manager = DatabaseManager()
    return _db_manager
||||
|
||||
|
||||
def reset_database_manager() -> None:
    """Drop the cached global DatabaseManager so tests start fresh.

    The next get_database_manager() call will construct a new instance.
    """
    global _db_manager
    _db_manager = None
|
||||
426
mcp-servers/knowledge-base/embeddings.py
Normal file
426
mcp-servers/knowledge-base/embeddings.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
Embedding generation for Knowledge Base MCP Server.
|
||||
|
||||
Generates vector embeddings via the LLM Gateway MCP server
|
||||
with caching support using Redis.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
import redis.asyncio as redis
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
EmbeddingDimensionMismatchError,
|
||||
EmbeddingGenerationError,
|
||||
ErrorCode,
|
||||
KnowledgeBaseError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingGenerator:
|
||||
"""
|
||||
Generates embeddings via LLM Gateway.
|
||||
|
||||
Features:
|
||||
- Batched embedding generation
|
||||
- Redis caching for deduplication
|
||||
- Automatic retry on transient failures
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings | None = None) -> None:
|
||||
"""Initialize embedding generator."""
|
||||
self._settings = settings or get_settings()
|
||||
self._redis: redis.Redis | None = None # type: ignore[type-arg]
|
||||
self._http_client: httpx.AsyncClient | None = None
|
||||
|
||||
@property
|
||||
def redis_client(self) -> redis.Redis: # type: ignore[type-arg]
|
||||
"""Get Redis client, raising if not initialized."""
|
||||
if self._redis is None:
|
||||
raise KnowledgeBaseError(
|
||||
message="Redis client not initialized",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
@property
|
||||
def http_client(self) -> httpx.AsyncClient:
|
||||
"""Get HTTP client, raising if not initialized."""
|
||||
if self._http_client is None:
|
||||
raise KnowledgeBaseError(
|
||||
message="HTTP client not initialized",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
return self._http_client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize Redis and HTTP clients."""
|
||||
try:
|
||||
self._redis = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection
|
||||
await self._redis.ping() # type: ignore[misc]
|
||||
logger.info("Redis connection established for embedding cache")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis connection failed, caching disabled: {e}")
|
||||
self._redis = None
|
||||
|
||||
self._http_client = httpx.AsyncClient(
|
||||
base_url=self._settings.llm_gateway_url,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
logger.info("Embedding generator initialized")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close connections."""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
|
||||
if self._http_client:
|
||||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
logger.info("Embedding generator closed")
|
||||
|
||||
def _cache_key(self, text: str) -> str:
|
||||
"""Generate cache key for a text."""
|
||||
text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
|
||||
model = self._settings.embedding_model
|
||||
return f"kb:emb:{model}:{text_hash}"
|
||||
|
||||
async def _get_cached(self, text: str) -> list[float] | None:
|
||||
"""Get cached embedding if available."""
|
||||
if not self._redis:
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._cache_key(text)
|
||||
cached = await self._redis.get(key)
|
||||
if cached:
|
||||
logger.debug(f"Cache hit for embedding: {key[:20]}...")
|
||||
return json.loads(cached)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache read error: {e}")
|
||||
return None
|
||||
|
||||
async def _set_cached(self, text: str, embedding: list[float]) -> None:
|
||||
"""Cache an embedding."""
|
||||
if not self._redis:
|
||||
return
|
||||
|
||||
try:
|
||||
key = self._cache_key(text)
|
||||
await self._redis.setex(
|
||||
key,
|
||||
self._settings.embedding_cache_ttl,
|
||||
json.dumps(embedding),
|
||||
)
|
||||
logger.debug(f"Cached embedding: {key[:20]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write error: {e}")
|
||||
|
||||
async def _get_cached_batch(
|
||||
self, texts: list[str]
|
||||
) -> tuple[list[list[float] | None], list[int]]:
|
||||
"""
|
||||
Get cached embeddings for a batch of texts.
|
||||
|
||||
Returns:
|
||||
Tuple of (embeddings list with None for misses, indices of misses)
|
||||
"""
|
||||
if not self._redis:
|
||||
return [None] * len(texts), list(range(len(texts)))
|
||||
|
||||
try:
|
||||
keys = [self._cache_key(text) for text in texts]
|
||||
cached_values = await self._redis.mget(keys)
|
||||
|
||||
embeddings: list[list[float] | None] = []
|
||||
missing_indices: list[int] = []
|
||||
|
||||
for i, cached in enumerate(cached_values):
|
||||
if cached:
|
||||
embeddings.append(json.loads(cached))
|
||||
else:
|
||||
embeddings.append(None)
|
||||
missing_indices.append(i)
|
||||
|
||||
cache_hits = len(texts) - len(missing_indices)
|
||||
if cache_hits > 0:
|
||||
logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")
|
||||
|
||||
return embeddings, missing_indices
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Batch cache read error: {e}")
|
||||
return [None] * len(texts), list(range(len(texts)))
|
||||
|
||||
async def _set_cached_batch(
|
||||
self, texts: list[str], embeddings: list[list[float]]
|
||||
) -> None:
|
||||
"""Cache a batch of embeddings."""
|
||||
if not self._redis or len(texts) != len(embeddings):
|
||||
return
|
||||
|
||||
try:
|
||||
pipe = self._redis.pipeline()
|
||||
for text, embedding in zip(texts, embeddings, strict=True):
|
||||
key = self._cache_key(text)
|
||||
pipe.setex(
|
||||
key,
|
||||
self._settings.embedding_cache_ttl,
|
||||
json.dumps(embedding),
|
||||
)
|
||||
await pipe.execute()
|
||||
logger.debug(f"Cached {len(texts)} embeddings in batch")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Batch cache write error: {e}")
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
text: str,
|
||||
project_id: str = "system",
|
||||
agent_id: str = "knowledge-base",
|
||||
) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
project_id: Project ID for cost attribution
|
||||
agent_id: Agent ID for cost attribution
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
# Check cache first
|
||||
cached = await self._get_cached(text)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# Generate via LLM Gateway
|
||||
embeddings = await self._call_llm_gateway(
|
||||
[text], project_id, agent_id
|
||||
)
|
||||
|
||||
if not embeddings:
|
||||
raise EmbeddingGenerationError(
|
||||
message="No embedding returned from LLM Gateway",
|
||||
texts_count=1,
|
||||
)
|
||||
|
||||
embedding = embeddings[0]
|
||||
|
||||
# Validate dimension
|
||||
if len(embedding) != self._settings.embedding_dimension:
|
||||
raise EmbeddingDimensionMismatchError(
|
||||
expected=self._settings.embedding_dimension,
|
||||
actual=len(embedding),
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
await self._set_cached(text, embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
async def generate_batch(
    self,
    texts: list[str],
    project_id: str = "system",
    agent_id: str = "knowledge-base",
) -> list[list[float]]:
    """
    Generate embeddings for multiple texts.

    Uses caching and batches requests to LLM Gateway.

    Args:
        texts: List of texts to embed
        project_id: Project ID for cost attribution
        agent_id: Agent ID for cost attribution

    Returns:
        List of embedding vectors, one per input text, in input order

    Raises:
        EmbeddingDimensionMismatchError: If any generated vector does not
            match the configured embedding dimension.
    """
    if not texts:
        return []

    # Check cache: cached_embeddings is aligned with `texts` (None for
    # misses); missing_indices lists the positions that need generation.
    cached_embeddings, missing_indices = await self._get_cached_batch(texts)

    # If all cached, return immediately
    if not missing_indices:
        logger.debug(f"All {len(texts)} embeddings served from cache")
        return [e for e in cached_embeddings if e is not None]

    # Get texts that need embedding, preserving their relative order
    texts_to_embed = [texts[i] for i in missing_indices]

    # Generate embeddings in gateway-sized batches
    new_embeddings: list[list[float]] = []
    batch_size = self._settings.embedding_batch_size

    for i in range(0, len(texts_to_embed), batch_size):
        batch = texts_to_embed[i : i + batch_size]
        batch_embeddings = await self._call_llm_gateway(
            batch, project_id, agent_id
        )
        new_embeddings.extend(batch_embeddings)

    # Validate dimensions before anything is cached
    for embedding in new_embeddings:
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )

    # Cache new embeddings (best-effort; failures only logged there)
    await self._set_cached_batch(texts_to_embed, new_embeddings)

    # Combine cached and new embeddings in input order: hits come from
    # cached_embeddings, misses are consumed from new_embeddings in the
    # same order they were generated (new_idx tracks consumption).
    result: list[list[float]] = []
    new_idx = 0

    for i in range(len(texts)):
        if cached_embeddings[i] is not None:
            result.append(cached_embeddings[i])  # type: ignore[arg-type]
        else:
            result.append(new_embeddings[new_idx])
            new_idx += 1

    logger.info(
        f"Generated {len(new_embeddings)} embeddings, "
        f"{len(texts) - len(missing_indices)} from cache"
    )

    return result
|
||||
|
||||
async def _call_llm_gateway(
    self,
    texts: list[str],
    project_id: str,
    agent_id: str,
) -> list[list[float]]:
    """
    Call LLM Gateway to generate embeddings.

    Uses JSON-RPC 2.0 protocol to call the embedding tool.

    Args:
        texts: Texts to embed in a single gateway call
        project_id: Project ID for cost attribution
        agent_id: Agent ID for cost attribution

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        EmbeddingGenerationError: On transport failures, gateway-reported
            errors, malformed responses, or a vector-count mismatch.
    """
    try:
        # JSON-RPC 2.0 request for embedding tool
        request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": "generate_embeddings",
                "arguments": {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "texts": texts,
                    "model": self._settings.embedding_model,
                },
            },
            "id": 1,  # one request per HTTP call, so a static id is fine
        }

        response = await self.http_client.post("/mcp", json=request)
        response.raise_for_status()

        result = response.json()

        # JSON-RPC errors arrive with HTTP 200; surface them explicitly
        if "error" in result:
            error = result["error"]
            raise EmbeddingGenerationError(
                message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
                texts_count=len(texts),
                details=error.get("data"),
            )

        # Extract embeddings from response
        content = result.get("result", {}).get("content", [])
        if not content:
            raise EmbeddingGenerationError(
                message="Empty response from LLM Gateway",
                texts_count=len(texts),
            )

        # Parse the response content
        # LLM Gateway returns embeddings in content[0].text as JSON
        embeddings_data = content[0].get("text", "")
        if isinstance(embeddings_data, str):
            embeddings_data = json.loads(embeddings_data)

        embeddings = embeddings_data.get("embeddings", [])

        # The gateway must return exactly one vector per input text
        if len(embeddings) != len(texts):
            raise EmbeddingGenerationError(
                message=f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                texts_count=len(texts),
            )

        return embeddings

    except httpx.HTTPStatusError as e:
        # Non-2xx status from raise_for_status above
        logger.error(f"LLM Gateway HTTP error: {e}")
        raise EmbeddingGenerationError(
            message=f"LLM Gateway request failed: {e.response.status_code}",
            texts_count=len(texts),
            cause=e,
        )
    except httpx.RequestError as e:
        # Connection-level failure (DNS, refused, timeout)
        logger.error(f"LLM Gateway request error: {e}")
        raise EmbeddingGenerationError(
            message=f"Failed to connect to LLM Gateway: {e}",
            texts_count=len(texts),
            cause=e,
        )
    except json.JSONDecodeError as e:
        # Body or content[0].text was not valid JSON
        logger.error(f"Invalid JSON response from LLM Gateway: {e}")
        raise EmbeddingGenerationError(
            message="Invalid response format from LLM Gateway",
            texts_count=len(texts),
            cause=e,
        )
    except EmbeddingGenerationError:
        # Re-raise our own errors unchanged so the catch-all below
        # does not double-wrap them.
        raise
    except Exception as e:
        logger.error(f"Unexpected error generating embeddings: {e}")
        raise EmbeddingGenerationError(
            message=f"Unexpected error: {e}",
            texts_count=len(texts),
            cause=e,
        )
|
||||
|
||||
|
||||
# Process-wide EmbeddingGenerator singleton, created on first use.
_embedding_generator: EmbeddingGenerator | None = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Return the shared EmbeddingGenerator, creating it on first access."""
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator


def reset_embedding_generator() -> None:
    """Drop the shared EmbeddingGenerator so tests start from a clean slate."""
    global _embedding_generator
    _embedding_generator = None
|
||||
409
mcp-servers/knowledge-base/exceptions.py
Normal file
409
mcp-servers/knowledge-base/exceptions.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
Custom exceptions for Knowledge Base MCP Server.
|
||||
|
||||
Provides structured error handling with error codes and details.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
    """Error codes for Knowledge Base operations.

    Values are stable, ``KB_``-prefixed string identifiers intended for
    programmatic handling in error responses (see KnowledgeBaseError.to_dict).
    """

    # General errors
    UNKNOWN_ERROR = "KB_UNKNOWN_ERROR"
    INVALID_REQUEST = "KB_INVALID_REQUEST"
    INTERNAL_ERROR = "KB_INTERNAL_ERROR"

    # Database errors
    DATABASE_CONNECTION_ERROR = "KB_DATABASE_CONNECTION_ERROR"
    DATABASE_QUERY_ERROR = "KB_DATABASE_QUERY_ERROR"
    DATABASE_INTEGRITY_ERROR = "KB_DATABASE_INTEGRITY_ERROR"

    # Embedding errors
    EMBEDDING_GENERATION_ERROR = "KB_EMBEDDING_GENERATION_ERROR"
    EMBEDDING_DIMENSION_MISMATCH = "KB_EMBEDDING_DIMENSION_MISMATCH"
    EMBEDDING_RATE_LIMIT = "KB_EMBEDDING_RATE_LIMIT"

    # Chunking errors
    CHUNKING_ERROR = "KB_CHUNKING_ERROR"
    UNSUPPORTED_FILE_TYPE = "KB_UNSUPPORTED_FILE_TYPE"
    FILE_TOO_LARGE = "KB_FILE_TOO_LARGE"
    ENCODING_ERROR = "KB_ENCODING_ERROR"

    # Search errors
    SEARCH_ERROR = "KB_SEARCH_ERROR"
    INVALID_SEARCH_TYPE = "KB_INVALID_SEARCH_TYPE"
    SEARCH_TIMEOUT = "KB_SEARCH_TIMEOUT"

    # Collection errors
    COLLECTION_NOT_FOUND = "KB_COLLECTION_NOT_FOUND"
    COLLECTION_ALREADY_EXISTS = "KB_COLLECTION_ALREADY_EXISTS"

    # Document errors
    DOCUMENT_NOT_FOUND = "KB_DOCUMENT_NOT_FOUND"
    DOCUMENT_ALREADY_EXISTS = "KB_DOCUMENT_ALREADY_EXISTS"
    INVALID_DOCUMENT = "KB_INVALID_DOCUMENT"

    # Project errors
    PROJECT_NOT_FOUND = "KB_PROJECT_NOT_FOUND"
    PROJECT_ACCESS_DENIED = "KB_PROJECT_ACCESS_DENIED"
|
||||
|
||||
|
||||
class KnowledgeBaseError(Exception):
    """Root of the Knowledge Base exception hierarchy.

    Carries a machine-readable ``code``, an optional ``details`` payload,
    and the originating ``cause`` exception alongside the human-readable
    message. All custom exceptions inherit from this class.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.UNKNOWN_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        """
        Initialize Knowledge Base error.

        Args:
            message: Human-readable error message
            code: Error code for programmatic handling
            details: Additional error details
            cause: Original exception that caused this error
        """
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}
        self.cause = cause

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-ready dict; details included only if set."""
        payload: dict[str, Any] = {
            "error": self.code.value,
            "message": self.message,
        }
        if self.details:
            payload["details"] = self.details
        return payload

    def __str__(self) -> str:
        """Short form: ``[CODE] message``."""
        return f"[{self.code.value}] {self.message}"

    def __repr__(self) -> str:
        """Unambiguous form for debugging."""
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}("
            f"message={self.message!r}, "
            f"code={self.code.value!r}, "
            f"details={self.details!r})"
        )
|
||||
|
||||
|
||||
# Database Errors
|
||||
|
||||
|
||||
class DatabaseError(KnowledgeBaseError):
    """Common base for errors raised by database operations."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.DATABASE_QUERY_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Same contract as KnowledgeBaseError, with a database default code.
        super().__init__(message, code, details, cause)
|
||||
|
||||
|
||||
class DatabaseConnectionError(DatabaseError):
    """Raised when a connection to the database cannot be established."""

    def __init__(
        self,
        message: str = "Failed to connect to database",
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(
            message, ErrorCode.DATABASE_CONNECTION_ERROR, details, cause
        )
|
||||
|
||||
|
||||
class DatabaseQueryError(DatabaseError):
    """Raised when a database query fails to execute."""

    def __init__(
        self,
        message: str,
        query: str | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Attach the offending query text to details when provided.
        info = details or {}
        if query:
            info["query"] = query
        super().__init__(message, ErrorCode.DATABASE_QUERY_ERROR, info, cause)
|
||||
|
||||
|
||||
# Embedding Errors
|
||||
|
||||
|
||||
class EmbeddingError(KnowledgeBaseError):
    """Common base for errors raised during embedding work."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.EMBEDDING_GENERATION_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Same contract as KnowledgeBaseError, with an embedding default code.
        super().__init__(message, code, details, cause)
|
||||
|
||||
|
||||
class EmbeddingGenerationError(EmbeddingError):
    """Raised when embedding generation via the LLM Gateway fails."""

    def __init__(
        self,
        message: str = "Failed to generate embeddings",
        texts_count: int | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Record how many texts were in the failed request, when known.
        info = details or {}
        if texts_count is not None:
            info["texts_count"] = texts_count
        super().__init__(
            message, ErrorCode.EMBEDDING_GENERATION_ERROR, info, cause
        )
|
||||
|
||||
|
||||
class EmbeddingDimensionMismatchError(EmbeddingError):
    """Raised when a generated vector's length differs from the configured one."""

    def __init__(
        self,
        expected: int,
        actual: int,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["expected_dimension"] = expected
        info["actual_dimension"] = actual
        super().__init__(
            f"Embedding dimension mismatch: expected {expected}, got {actual}",
            ErrorCode.EMBEDDING_DIMENSION_MISMATCH,
            info,
        )
|
||||
|
||||
|
||||
# Chunking Errors
|
||||
|
||||
|
||||
class ChunkingError(KnowledgeBaseError):
    """Common base for errors raised while chunking content."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.CHUNKING_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Same contract as KnowledgeBaseError, with a chunking default code.
        super().__init__(message, code, details, cause)
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(ChunkingError):
    """Raised for a file type the chunkers cannot handle."""

    def __init__(
        self,
        file_type: str,
        supported_types: list[str] | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["file_type"] = file_type
        if supported_types:
            info["supported_types"] = supported_types
        super().__init__(
            f"Unsupported file type: {file_type}",
            ErrorCode.UNSUPPORTED_FILE_TYPE,
            info,
        )
|
||||
|
||||
|
||||
class FileTooLargeError(ChunkingError):
    """Raised when a file exceeds the maximum allowed size."""

    def __init__(
        self,
        file_size: int,
        max_size: int,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["file_size"] = file_size
        info["max_size"] = max_size
        super().__init__(
            f"File too large: {file_size} bytes exceeds limit of {max_size} bytes",
            ErrorCode.FILE_TOO_LARGE,
            info,
        )
|
||||
|
||||
|
||||
class EncodingError(ChunkingError):
    """Raised when file content cannot be decoded."""

    def __init__(
        self,
        message: str = "Failed to decode file content",
        encoding: str | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Record the attempted encoding, when known.
        info = details or {}
        if encoding:
            info["encoding"] = encoding
        super().__init__(message, ErrorCode.ENCODING_ERROR, info, cause)
|
||||
|
||||
|
||||
# Search Errors
|
||||
|
||||
|
||||
class SearchError(KnowledgeBaseError):
    """Common base for errors raised during search operations."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.SEARCH_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Same contract as KnowledgeBaseError, with a search default code.
        super().__init__(message, code, details, cause)
|
||||
|
||||
|
||||
class InvalidSearchTypeError(SearchError):
    """Raised when an unrecognized search type is requested."""

    def __init__(
        self,
        search_type: str,
        valid_types: list[str] | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["search_type"] = search_type
        if valid_types:
            info["valid_types"] = valid_types
        super().__init__(
            f"Invalid search type: {search_type}",
            ErrorCode.INVALID_SEARCH_TYPE,
            info,
        )
|
||||
|
||||
|
||||
class SearchTimeoutError(SearchError):
    """Raised when a search operation exceeds its time budget."""

    def __init__(
        self,
        timeout: float,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["timeout"] = timeout
        super().__init__(
            f"Search timed out after {timeout} seconds",
            ErrorCode.SEARCH_TIMEOUT,
            info,
        )
|
||||
|
||||
|
||||
# Collection Errors
|
||||
|
||||
|
||||
class CollectionError(KnowledgeBaseError):
    """Base class for collection-related errors.

    Marker class: behavior is inherited unchanged from KnowledgeBaseError;
    it exists so callers can catch all collection failures in one clause.
    """

    pass
|
||||
|
||||
|
||||
class CollectionNotFoundError(CollectionError):
    """Raised when the named collection does not exist."""

    def __init__(
        self,
        collection: str,
        project_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["collection"] = collection
        if project_id:
            info["project_id"] = project_id
        super().__init__(
            f"Collection not found: {collection}",
            ErrorCode.COLLECTION_NOT_FOUND,
            info,
        )
|
||||
|
||||
|
||||
# Document Errors
|
||||
|
||||
|
||||
class DocumentError(KnowledgeBaseError):
    """Base class for document-related errors.

    Marker class: behavior is inherited unchanged from KnowledgeBaseError;
    it exists so callers can catch all document failures in one clause.
    """

    pass
|
||||
|
||||
|
||||
class DocumentNotFoundError(DocumentError):
    """Raised when no document exists at the given source path."""

    def __init__(
        self,
        source_path: str,
        project_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["source_path"] = source_path
        if project_id:
            info["project_id"] = project_id
        super().__init__(
            f"Document not found: {source_path}",
            ErrorCode.DOCUMENT_NOT_FOUND,
            info,
        )
|
||||
|
||||
|
||||
class InvalidDocumentError(DocumentError):
    """Raised when document content fails validation."""

    def __init__(
        self,
        message: str = "Invalid document content",
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(
            message, ErrorCode.INVALID_DOCUMENT, details, cause
        )
|
||||
|
||||
|
||||
# Project Errors
|
||||
|
||||
|
||||
class ProjectError(KnowledgeBaseError):
    """Base class for project-related errors.

    Marker class: behavior is inherited unchanged from KnowledgeBaseError;
    it exists so callers can catch all project failures in one clause.
    """

    pass
|
||||
|
||||
|
||||
class ProjectNotFoundError(ProjectError):
    """Raised when the referenced project does not exist."""

    def __init__(
        self,
        project_id: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["project_id"] = project_id
        super().__init__(
            f"Project not found: {project_id}",
            ErrorCode.PROJECT_NOT_FOUND,
            info,
        )
|
||||
|
||||
|
||||
class ProjectAccessDeniedError(ProjectError):
    """Raised when the caller may not access the referenced project."""

    def __init__(
        self,
        project_id: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        info = details or {}
        info["project_id"] = project_id
        super().__init__(
            f"Access denied to project: {project_id}",
            ErrorCode.PROJECT_ACCESS_DENIED,
            info,
        )
|
||||
321
mcp-servers/knowledge-base/models.py
Normal file
321
mcp-servers/knowledge-base/models.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Data models for Knowledge Base MCP Server.
|
||||
|
||||
Defines database models, Pydantic schemas, and data structures
|
||||
for RAG operations with pgvector.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SearchType(str, Enum):
    """Types of search supported.

    Selected per-request via SearchRequest.search_type.
    """

    SEMANTIC = "semantic"  # Vector similarity search
    KEYWORD = "keyword"  # Full-text search
    HYBRID = "hybrid"  # Combined semantic + keyword
|
||||
|
||||
|
||||
class ChunkType(str, Enum):
    """Types of content chunks.

    Stored alongside each embedding (KnowledgeEmbedding.chunk_type) and
    reported in search results.
    """

    CODE = "code"
    MARKDOWN = "markdown"
    TEXT = "text"
    DOCUMENTATION = "documentation"
|
||||
|
||||
|
||||
class FileType(str, Enum):
    """Supported file types for chunking.

    Mapped from file extensions via FILE_EXTENSION_MAP.
    """

    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"
    GO = "go"
    RUST = "rust"
    JAVA = "java"
    MARKDOWN = "markdown"
    TEXT = "text"
    JSON = "json"
    YAML = "yaml"
    TOML = "toml"
|
||||
|
||||
|
||||
# File extension to FileType mapping. Keys are lowercase extensions with a
# leading dot; related extensions (.js/.jsx, .yaml/.yml, .md/.mdx) collapse
# to the same FileType.
FILE_EXTENSION_MAP: dict[str, FileType] = {
    ".py": FileType.PYTHON,
    ".js": FileType.JAVASCRIPT,
    ".jsx": FileType.JAVASCRIPT,
    ".ts": FileType.TYPESCRIPT,
    ".tsx": FileType.TYPESCRIPT,
    ".go": FileType.GO,
    ".rs": FileType.RUST,
    ".java": FileType.JAVA,
    ".md": FileType.MARKDOWN,
    ".mdx": FileType.MARKDOWN,
    ".txt": FileType.TEXT,
    ".json": FileType.JSON,
    ".yaml": FileType.YAML,
    ".yml": FileType.YAML,
    ".toml": FileType.TOML,
}
|
||||
|
||||
|
||||
@dataclass
class Chunk:
    """A chunk of content ready for embedding.

    Produced by the chunkers; ``to_dict`` flattens the enums so the chunk
    can be serialized as-is.
    """

    content: str
    chunk_type: ChunkType
    file_type: FileType | None = None
    source_path: str | None = None
    start_line: int | None = None
    end_line: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    token_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (enum members become their values)."""
        return dict(
            content=self.content,
            chunk_type=self.chunk_type.value,
            file_type=None if self.file_type is None else self.file_type.value,
            source_path=self.source_path,
            start_line=self.start_line,
            end_line=self.end_line,
            metadata=self.metadata,
            token_count=self.token_count,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeEmbedding:
|
||||
"""
|
||||
A knowledge embedding stored in the database.
|
||||
|
||||
Represents a chunk of content with its vector embedding.
|
||||
"""
|
||||
|
||||
id: str
|
||||
project_id: str
|
||||
collection: str
|
||||
content: str
|
||||
embedding: list[float]
|
||||
chunk_type: ChunkType
|
||||
source_path: str | None = None
|
||||
start_line: int | None = None
|
||||
end_line: int | None = None
|
||||
file_type: FileType | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
content_hash: str | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary (excluding embedding for size)."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"project_id": self.project_id,
|
||||
"collection": self.collection,
|
||||
"content": self.content,
|
||||
"chunk_type": self.chunk_type.value,
|
||||
"source_path": self.source_path,
|
||||
"start_line": self.start_line,
|
||||
"end_line": self.end_line,
|
||||
"file_type": self.file_type.value if self.file_type else None,
|
||||
"metadata": self.metadata,
|
||||
"content_hash": self.content_hash,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
}
|
||||
|
||||
|
||||
# Pydantic Request/Response Models
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
    """Payload for adding new content to the knowledge base."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    content: str = Field(..., description="Content to ingest")
    source_path: str | None = Field(default=None, description="Source file path for reference")
    collection: str = Field(default="default", description="Collection to store in")
    chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="Type of content")
    file_type: FileType | None = Field(default=None, description="File type for code chunking")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
|
||||
class IngestResult(BaseModel):
    """Outcome of an ingest operation."""

    success: bool = Field(..., description="Whether ingest succeeded")
    chunks_created: int = Field(default=0, description="Number of chunks created")
    embeddings_generated: int = Field(default=0, description="Number of embeddings generated")
    source_path: str | None = Field(default=None, description="Source path ingested")
    collection: str = Field(default="default", description="Collection stored in")
    chunk_ids: list[str] = Field(default_factory=list, description="IDs of created chunks")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Payload for querying the knowledge base."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    query: str = Field(..., description="Search query")
    search_type: SearchType = Field(default=SearchType.HYBRID, description="Type of search")
    collection: str | None = Field(default=None, description="Collection to search (None = all)")
    limit: int = Field(default=10, ge=1, le=100, description="Max results")
    threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="Minimum similarity score")
    file_types: list[FileType] | None = Field(default=None, description="Filter by file types")
    include_metadata: bool = Field(default=True, description="Include metadata in results")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """A single search hit, flattened for transport."""

    id: str = Field(..., description="Chunk ID")
    content: str = Field(..., description="Chunk content")
    score: float = Field(..., description="Relevance score (0-1)")
    source_path: str | None = Field(default=None, description="Source file path")
    start_line: int | None = Field(default=None, description="Start line in source")
    end_line: int | None = Field(default=None, description="End line in source")
    chunk_type: str = Field(..., description="Type of chunk")
    file_type: str | None = Field(default=None, description="File type")
    collection: str = Field(..., description="Collection name")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    @classmethod
    def from_embedding(
        cls, embedding: KnowledgeEmbedding, score: float
    ) -> "SearchResult":
        """Build a SearchResult view of a stored KnowledgeEmbedding."""
        file_type = embedding.file_type.value if embedding.file_type else None
        return cls(
            id=embedding.id,
            content=embedding.content,
            score=score,
            source_path=embedding.source_path,
            start_line=embedding.start_line,
            end_line=embedding.end_line,
            chunk_type=embedding.chunk_type.value,
            file_type=file_type,
            collection=embedding.collection,
            metadata=embedding.metadata,
        )
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
    """Envelope returned by a search operation."""

    query: str = Field(..., description="Original query")
    search_type: str = Field(..., description="Type of search performed")
    results: list[SearchResult] = Field(default_factory=list, description="Search results")
    total_results: int = Field(default=0, description="Total results found")
    search_time_ms: float = Field(default=0.0, description="Search time in ms")
|
||||
|
||||
|
||||
class DeleteRequest(BaseModel):
    """Payload for removing content; exactly how is chosen by which
    optional selector (source_path / collection / chunk_ids) is set."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    source_path: str | None = Field(default=None, description="Delete by source path")
    collection: str | None = Field(default=None, description="Delete entire collection")
    chunk_ids: list[str] | None = Field(default=None, description="Delete specific chunks")
|
||||
|
||||
|
||||
class DeleteResult(BaseModel):
    """Outcome of a delete operation."""

    success: bool = Field(..., description="Whether delete succeeded")
    chunks_deleted: int = Field(default=0, description="Number of chunks deleted")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class CollectionInfo(BaseModel):
    """Summary description of one collection."""

    name: str = Field(..., description="Collection name")
    project_id: str = Field(..., description="Project ID")
    chunk_count: int = Field(default=0, description="Number of chunks")
    total_tokens: int = Field(default=0, description="Total tokens stored")
    file_types: list[str] = Field(default_factory=list, description="File types in collection")
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")
|
||||
|
||||
|
||||
class ListCollectionsResponse(BaseModel):
    """Envelope returned when listing a project's collections."""

    project_id: str = Field(..., description="Project ID")
    collections: list[CollectionInfo] = Field(default_factory=list, description="Collections in project")
    total_collections: int = Field(default=0, description="Total count")
|
||||
|
||||
|
||||
class CollectionStats(BaseModel):
    """Aggregate statistics for a single collection."""

    collection: str = Field(..., description="Collection name")
    project_id: str = Field(..., description="Project ID")
    chunk_count: int = Field(default=0, description="Number of chunks")
    unique_sources: int = Field(default=0, description="Unique source files")
    total_tokens: int = Field(default=0, description="Total tokens")
    avg_chunk_size: float = Field(default=0.0, description="Average chunk size")
    chunk_types: dict[str, int] = Field(default_factory=dict, description="Count by chunk type")
    file_types: dict[str, int] = Field(default_factory=dict, description="Count by file type")
    oldest_chunk: datetime | None = Field(default=None, description="Oldest chunk timestamp")
    newest_chunk: datetime | None = Field(default=None, description="Newest chunk timestamp")
|
||||
@@ -4,21 +4,101 @@ version = "0.1.0"
|
||||
description = "Syndarix Knowledge Base MCP Server - RAG with pgvector for semantic search"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastmcp>=0.1.0",
|
||||
"fastmcp>=2.0.0",
|
||||
"asyncpg>=0.29.0",
|
||||
"pgvector>=0.3.0",
|
||||
"redis>=5.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"tiktoken>=0.7.0",
|
||||
"httpx>=0.27.0",
|
||||
"uvicorn>=0.30.0",
|
||||
"fastapi>=0.115.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"fakeredis>=2.25.0",
|
||||
"ruff>=0.8.0",
|
||||
"mypy>=1.11.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
knowledge-base = "server:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
exclude = ["tests/", "*.md", "Dockerfile"]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
include = ["*.py", "pyproject.toml"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
"B904", # raise from in except (too noisy)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["config", "models", "exceptions", "database", "embeddings", "chunking", "search", "collections"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
testpaths = ["tests"]
|
||||
addopts = "-v --tb=short"
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["."]
|
||||
omit = ["tests/*", "conftest.py"]
|
||||
branch = true
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise NotImplementedError",
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
fail_under = 55
|
||||
show_missing = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
warn_return_any = false
|
||||
warn_unused_ignores = false
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "tests.*"
|
||||
disallow_untyped_defs = false
|
||||
ignore_errors = true
|
||||
|
||||
285
mcp-servers/knowledge-base/search.py
Normal file
285
mcp-servers/knowledge-base/search.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Search implementations for Knowledge Base MCP Server.
|
||||
|
||||
Provides semantic (vector), keyword (full-text), and hybrid search
|
||||
capabilities over the knowledge base.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from config import Settings, get_settings
|
||||
from database import DatabaseManager, get_database_manager
|
||||
from embeddings import EmbeddingGenerator, get_embedding_generator
|
||||
from exceptions import InvalidSearchTypeError, SearchError
|
||||
from models import (
|
||||
SearchRequest,
|
||||
SearchResponse,
|
||||
SearchResult,
|
||||
SearchType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearchEngine:
    """
    Unified search engine supporting multiple search types.

    Features:
    - Semantic search using vector similarity
    - Keyword search using full-text search
    - Hybrid search combining both approaches
    - Configurable result fusion and weighting
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
    ) -> None:
        """Initialize search engine.

        Args:
            settings: Settings override; falls back to the global settings.
            database: Database manager override; resolved lazily if omitted.
            embeddings: Embedding generator override; resolved lazily if omitted.
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings

    @property
    def database(self) -> DatabaseManager:
        """Get database manager (resolved lazily on first access)."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator (resolved lazily on first access)."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    async def search(self, request: SearchRequest) -> SearchResponse:
        """
        Execute a search request.

        Args:
            request: Search request with query and options

        Returns:
            Search response with results

        Raises:
            InvalidSearchTypeError: If the request carries an unknown search type.
            SearchError: If the underlying search fails for any other reason.
        """
        # perf_counter() is monotonic; time.time() can jump on wall-clock
        # adjustments and would then report a bogus (even negative) duration.
        start_time = time.perf_counter()

        try:
            if request.search_type == SearchType.SEMANTIC:
                results = await self._semantic_search(request)
            elif request.search_type == SearchType.KEYWORD:
                results = await self._keyword_search(request)
            elif request.search_type == SearchType.HYBRID:
                results = await self._hybrid_search(request)
            else:
                raise InvalidSearchTypeError(
                    search_type=request.search_type,
                    valid_types=[t.value for t in SearchType],
                )

            search_time_ms = (time.perf_counter() - start_time) * 1000

            # Lazy %-style args: only formatted when INFO is actually enabled.
            logger.info(
                "Search completed: type=%s, results=%d, time=%.1fms",
                request.search_type.value,
                len(results),
                search_time_ms,
            )

            return SearchResponse(
                query=request.query,
                search_type=request.search_type.value,
                results=results,
                total_results=len(results),
                search_time_ms=search_time_ms,
            )

        except InvalidSearchTypeError:
            # Re-raise unchanged so callers can distinguish a bad request
            # from an execution failure.
            raise
        except Exception as e:
            logger.error("Search error: %s", e)
            raise SearchError(
                message=f"Search failed: {e}",
                cause=e,
            )

    async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute semantic (vector) search."""
        # Generate an embedding for the query text
        query_embedding = await self.embeddings.generate(
            text=request.query,
            project_id=request.project_id,
            agent_id=request.agent_id,
        )

        # Vector-similarity search in the database
        results = await self.database.semantic_search(
            project_id=request.project_id,
            query_embedding=query_embedding,
            collection=request.collection,
            limit=request.limit,
            threshold=request.threshold,
            file_types=request.file_types,
        )

        # Convert (embedding, score) pairs to SearchResult
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
        ]

    async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute keyword (full-text) search."""
        results = await self.database.keyword_search(
            project_id=request.project_id,
            query=request.query,
            collection=request.collection,
            limit=request.limit,
            file_types=request.file_types,
        )

        # Filter by threshold (keyword search scores are normalized)
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
            if score >= request.threshold
        ]

    async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute hybrid search combining semantic and keyword.

        Uses Reciprocal Rank Fusion (RRF) for result combination.
        """
        # Over-fetch so the fusion step has enough candidates to merge
        fusion_limit = min(request.limit * 2, 100)

        # Sub-request for the semantic leg of the fusion
        semantic_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.SEMANTIC,
            collection=request.collection,
            limit=fusion_limit,
            threshold=request.threshold * 0.8,  # Lower threshold for fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        # Sub-request for the keyword leg of the fusion
        keyword_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.KEYWORD,
            collection=request.collection,
            limit=fusion_limit,
            threshold=0.0,  # No threshold for keyword, we'll filter after fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        # Execute both searches
        semantic_results = await self._semantic_search(semantic_request)
        keyword_results = await self._keyword_search(keyword_request)

        # Fuse results using RRF
        fused = self._reciprocal_rank_fusion(
            semantic_results=semantic_results,
            keyword_results=keyword_results,
            semantic_weight=self._settings.hybrid_semantic_weight,
            keyword_weight=self._settings.hybrid_keyword_weight,
        )

        # Apply the caller's threshold and limit to the fused ranking
        return [
            result for result in fused
            if result.score >= request.threshold
        ][:request.limit]

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF constant
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) for each result list.
        """
        scores: dict[str, float] = {}
        results_by_id: dict[str, SearchResult] = {}

        # Accumulate weighted RRF contributions from the semantic ranking
        for rank, result in enumerate(semantic_results, start=1):
            scores[result.id] = scores.get(result.id, 0.0) + semantic_weight / (k + rank)
            results_by_id[result.id] = result

        # ...and from the keyword ranking (keep the first-seen result object)
        for rank, result in enumerate(keyword_results, start=1):
            scores[result.id] = scores.get(result.id, 0.0) + keyword_weight / (k + rank)
            if result.id not in results_by_id:
                results_by_id[result.id] = result

        # Explicit empty guard: nothing to rank or normalize
        if not scores:
            return []

        # Sort by combined score, then normalize scores to the 0-1 range
        sorted_ids = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
        max_score = max(scores.values())

        # Rebuild each result with its normalized fused score
        final_results: list[SearchResult] = []
        for result_id in sorted_ids:
            result = results_by_id[result_id]
            final_results.append(
                SearchResult(
                    id=result.id,
                    content=result.content,
                    score=scores[result_id] / max_score,
                    source_path=result.source_path,
                    start_line=result.start_line,
                    end_line=result.end_line,
                    chunk_type=result.chunk_type,
                    file_type=result.file_type,
                    collection=result.collection,
                    metadata=result.metadata,
                )
            )

        return final_results
|
||||
|
||||
|
||||
# Global search engine instance (lazy initialization)
|
||||
_search_engine: SearchEngine | None = None
|
||||
|
||||
|
||||
def get_search_engine() -> SearchEngine:
    """Return the process-wide SearchEngine, creating it on first use."""
    global _search_engine
    engine = _search_engine
    if engine is None:
        engine = SearchEngine()
        _search_engine = engine
    return engine
|
||||
|
||||
|
||||
def reset_search_engine() -> None:
    """Discard the cached global search engine so tests start fresh."""
    global _search_engine
    _search_engine = None
|
||||
@@ -1,162 +1,569 @@
|
||||
"""
|
||||
Syndarix Knowledge Base MCP Server.
|
||||
Knowledge Base MCP Server.
|
||||
|
||||
Provides RAG capabilities with:
|
||||
- pgvector for semantic search
|
||||
- Per-project collection isolation
|
||||
- Hybrid search (vector + keyword)
|
||||
- Chunking strategies for code, markdown, and text
|
||||
|
||||
Per ADR-008: Knowledge Base RAG Architecture.
|
||||
Provides RAG capabilities with pgvector for semantic search,
|
||||
intelligent chunking, and collection management.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastmcp import FastMCP
|
||||
from pydantic import Field
|
||||
|
||||
# Create MCP server
|
||||
mcp = FastMCP(
|
||||
"syndarix-knowledge-base",
|
||||
from collection_manager import CollectionManager, get_collection_manager
|
||||
from collections.abc import AsyncIterator
|
||||
from config import get_settings
|
||||
from database import DatabaseManager, get_database_manager
|
||||
from embeddings import EmbeddingGenerator, get_embedding_generator
|
||||
from exceptions import KnowledgeBaseError
|
||||
from models import (
|
||||
ChunkType,
|
||||
DeleteRequest,
|
||||
FileType,
|
||||
IngestRequest,
|
||||
SearchRequest,
|
||||
SearchType,
|
||||
)
|
||||
from search import SearchEngine, get_search_engine
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global instances
|
||||
_database: DatabaseManager | None = None
|
||||
_embeddings: EmbeddingGenerator | None = None
|
||||
_search: SearchEngine | None = None
|
||||
_collections: CollectionManager | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan handler.

    Initializes the module-level singletons (database pool, embedding
    client, search engine, collection manager) before the app serves
    requests, and closes the embedding client and database pool on
    shutdown. Runs once per application lifetime.
    """
    global _database, _embeddings, _search, _collections

    logger.info("Starting Knowledge Base MCP Server...")

    # Initialize database (pool must exist before anything queries it)
    _database = get_database_manager()
    await _database.initialize()

    # Initialize embedding generator
    _embeddings = get_embedding_generator()
    await _embeddings.initialize()

    # Initialize search engine (resolves database/embeddings lazily)
    _search = get_search_engine()

    # Initialize collection manager
    _collections = get_collection_manager()

    logger.info("Knowledge Base MCP Server started successfully")

    # Hand control to the running application
    yield

    # Cleanup — close in reverse dependency order; guards allow a partial
    # startup to shut down cleanly.
    logger.info("Shutting down Knowledge Base MCP Server...")

    if _embeddings:
        await _embeddings.close()

    if _database:
        await _database.close()

    logger.info("Knowledge Base MCP Server shut down")
|
||||
|
||||
|
||||
# Create FastMCP server
|
||||
mcp = FastMCP("syndarix-knowledge-base")
|
||||
|
||||
# Create FastAPI app with lifespan
|
||||
app = FastAPI(
|
||||
title="Knowledge Base MCP Server",
|
||||
description="RAG with pgvector for semantic search",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Report service health, including a live database probe."""
    payload: dict[str, Any] = {
        "status": "healthy",
        "service": "knowledge-base",
        "version": "0.1.0",
    }

    # Probe the database only when the pool has been initialized; a probe
    # failure downgrades the overall status instead of raising.
    try:
        if not (_database and _database._pool):
            payload["database"] = "not initialized"
        else:
            async with _database.acquire() as conn:
                await conn.fetchval("SELECT 1")
            payload["database"] = "connected"
    except Exception as e:
        payload["database"] = f"error: {e}"
        payload["status"] = "degraded"

    return payload
|
||||
|
||||
|
||||
# MCP Tools
|
||||
|
||||
|
||||
@mcp.tool()
async def search_knowledge(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    query: str = Field(..., description="Search query"),
    search_type: str = Field(
        default="hybrid",
        description="Search type: semantic, keyword, or hybrid",
    ),
    collection: str | None = Field(
        default=None,
        description="Collection to search (None = all)",
    ),
    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum number of results",
    ),
    threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum similarity score",
    ),
    file_types: list[str] | None = Field(
        default=None,
        description="Filter by file types (python, javascript, etc.)",
    ),
) -> dict[str, Any]:
    """
    Search the knowledge base for relevant content.

    Supports semantic (vector), keyword (full-text), and hybrid search.
    Returns chunks ranked by relevance to the query.
    """
    try:
        # Parse search type
        try:
            search_type_enum = SearchType(search_type.lower())
        except ValueError:
            valid_types = [t.value for t in SearchType]
            return {
                "success": False,
                "error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
            }

        # Parse file types. Consistency fix: also list the valid values,
        # matching the error shape of the other tools, instead of only
        # echoing the raw ValueError text.
        file_type_enums = None
        if file_types:
            try:
                file_type_enums = [FileType(ft.lower()) for ft in file_types]
            except ValueError as e:
                valid_types = [t.value for t in FileType]
                return {
                    "success": False,
                    "error": f"Invalid file type: {e}. Valid types: {valid_types}",
                }

        request = SearchRequest(
            project_id=project_id,
            agent_id=agent_id,
            query=query,
            search_type=search_type_enum,
            collection=collection,
            limit=limit,
            threshold=threshold,
            file_types=file_type_enums,
        )

        response = await _search.search(request)  # type: ignore[union-attr]

        return {
            "success": True,
            "query": response.query,
            "search_type": response.search_type,
            "results": [
                {
                    "id": r.id,
                    "content": r.content,
                    "score": r.score,
                    "source_path": r.source_path,
                    "start_line": r.start_line,
                    "end_line": r.end_line,
                    "chunk_type": r.chunk_type,
                    "file_type": r.file_type,
                    "collection": r.collection,
                    "metadata": r.metadata,
                }
                for r in response.results
            ],
            "total_results": response.total_results,
            "search_time_ms": response.search_time_ms,
        }

    except KnowledgeBaseError as e:
        # Domain errors carry a structured code for the caller
        logger.error(f"Search error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected search error: {e}")
        return {
            "success": False,
            "error": str(e),
        }
|
||||
|
||||
|
||||
@mcp.tool()
async def ingest_content(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    content: str = Field(..., description="Content to ingest"),
    source_path: str | None = Field(
        default=None,
        description="Source file path for reference",
    ),
    collection: str = Field(
        default="default",
        description="Collection to store in",
    ),
    chunk_type: str = Field(
        default="text",
        description="Content type: code, markdown, or text",
    ),
    file_type: str | None = Field(
        default=None,
        description="File type for code chunking (python, javascript, etc.)",
    ),
    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata to store",
    ),
) -> dict[str, Any]:
    """
    Ingest content into the knowledge base.

    Content is automatically chunked based on type, embedded using
    the LLM Gateway, and stored in pgvector for search.
    """
    try:
        # Validate the chunk-type string against the enum
        try:
            parsed_chunk_type = ChunkType(chunk_type.lower())
        except ValueError:
            valid_types = [t.value for t in ChunkType]
            return {
                "success": False,
                "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
            }

        # Validate the optional file-type string against the enum
        parsed_file_type = None
        if file_type:
            try:
                parsed_file_type = FileType(file_type.lower())
            except ValueError:
                valid_types = [t.value for t in FileType]
                return {
                    "success": False,
                    "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
                }

        ingest_request = IngestRequest(
            project_id=project_id,
            agent_id=agent_id,
            content=content,
            source_path=source_path,
            collection=collection,
            chunk_type=parsed_chunk_type,
            file_type=parsed_file_type,
            metadata=metadata or {},
        )

        result = await _collections.ingest(ingest_request)  # type: ignore[union-attr]

        return {
            "success": result.success,
            "chunks_created": result.chunks_created,
            "embeddings_generated": result.embeddings_generated,
            "source_path": result.source_path,
            "collection": result.collection,
            "chunk_ids": result.chunk_ids,
            "error": result.error,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Ingest error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected ingest error: {e}")
        return {
            "success": False,
            "error": str(e),
        }
|
||||
|
||||
|
||||
@mcp.tool()
async def delete_content(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    source_path: str | None = Field(
        default=None,
        description="Delete by source file path",
    ),
    collection: str | None = Field(
        default=None,
        description="Delete entire collection",
    ),
    chunk_ids: list[str] | None = Field(
        default=None,
        description="Delete specific chunk IDs",
    ),
) -> dict[str, Any]:
    """
    Delete content from the knowledge base.

    Specify either source_path, collection, or chunk_ids to delete.
    """
    try:
        # Build the request; selector validation lives in the model layer
        delete_request = DeleteRequest(
            project_id=project_id,
            agent_id=agent_id,
            source_path=source_path,
            collection=collection,
            chunk_ids=chunk_ids,
        )

        result = await _collections.delete(delete_request)  # type: ignore[union-attr]

        return {
            "success": result.success,
            "chunks_deleted": result.chunks_deleted,
            "error": result.error,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Delete error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected delete error: {e}")
        return {
            "success": False,
            "error": str(e),
        }
|
||||
|
||||
|
||||
@mcp.tool()
async def list_collections(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),  # noqa: ARG001
) -> dict[str, Any]:
    """
    List all collections in a project's knowledge base.

    Returns collection names with chunk counts and file types.
    """
    try:
        result = await _collections.list_collections(project_id)  # type: ignore[union-attr]

        # Flatten each collection record into a JSON-friendly dict
        def _summarize(c: Any) -> dict[str, Any]:
            return {
                "name": c.name,
                "chunk_count": c.chunk_count,
                "total_tokens": c.total_tokens,
                "file_types": c.file_types,
                "created_at": c.created_at.isoformat(),
                "updated_at": c.updated_at.isoformat(),
            }

        return {
            "success": True,
            "project_id": result.project_id,
            "collections": [_summarize(c) for c in result.collections],
            "total_collections": result.total_collections,
        }

    except KnowledgeBaseError as e:
        logger.error(f"List collections error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected list collections error: {e}")
        return {
            "success": False,
            "error": str(e),
        }
|
||||
|
||||
|
||||
@mcp.tool()
async def get_collection_stats(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),  # noqa: ARG001
    collection: str = Field(..., description="Collection name"),
) -> dict[str, Any]:
    """
    Get detailed statistics for a collection.

    Returns chunk counts, token totals, and type breakdowns.
    """
    try:
        stats = await _collections.get_collection_stats(project_id, collection)  # type: ignore[union-attr]

        # Timestamps may be absent for an empty collection — emit None then
        oldest = stats.oldest_chunk.isoformat() if stats.oldest_chunk else None
        newest = stats.newest_chunk.isoformat() if stats.newest_chunk else None

        return {
            "success": True,
            "collection": stats.collection,
            "project_id": stats.project_id,
            "chunk_count": stats.chunk_count,
            "unique_sources": stats.unique_sources,
            "total_tokens": stats.total_tokens,
            "avg_chunk_size": stats.avg_chunk_size,
            "chunk_types": stats.chunk_types,
            "file_types": stats.file_types,
            "oldest_chunk": oldest,
            "newest_chunk": newest,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Get collection stats error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected get collection stats error: {e}")
        return {
            "success": False,
            "error": str(e),
        }
|
||||
|
||||
|
||||
@mcp.tool()
async def update_document(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    source_path: str = Field(..., description="Source file path"),
    content: str = Field(..., description="New content"),
    collection: str = Field(
        default="default",
        description="Collection name",
    ),
    chunk_type: str = Field(
        default="text",
        description="Content type: code, markdown, or text",
    ),
    file_type: str | None = Field(
        default=None,
        description="File type for code chunking",
    ),
    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata",
    ),
) -> dict[str, Any]:
    """
    Update a document in the knowledge base.

    Replaces all existing chunks for the source path with new content.
    """
    try:
        # Validate the chunk-type string against the enum
        try:
            parsed_chunk_type = ChunkType(chunk_type.lower())
        except ValueError:
            valid_types = [t.value for t in ChunkType]
            return {
                "success": False,
                "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
            }

        # Validate the optional file-type string against the enum
        parsed_file_type = None
        if file_type:
            try:
                parsed_file_type = FileType(file_type.lower())
            except ValueError:
                valid_types = [t.value for t in FileType]
                return {
                    "success": False,
                    "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
                }

        result = await _collections.update_document(  # type: ignore[union-attr]
            project_id=project_id,
            agent_id=agent_id,
            source_path=source_path,
            content=content,
            collection=collection,
            chunk_type=parsed_chunk_type,
            file_type=parsed_file_type,
            metadata=metadata,
        )

        return {
            "success": result.success,
            "chunks_created": result.chunks_created,
            "embeddings_generated": result.embeddings_generated,
            "source_path": result.source_path,
            "collection": result.collection,
            "chunk_ids": result.chunk_ids,
            "error": result.error,
        }

    except KnowledgeBaseError as e:
        logger.error(f"Update document error: {e}")
        return {
            "success": False,
            "error": e.message,
            "code": e.code.value,
        }
    except Exception as e:
        logger.error(f"Unexpected update document error: {e}")
        return {
            "success": False,
            "error": str(e),
        }
|
||||
|
||||
|
||||
def main() -> None:
    """Start the uvicorn server using the application settings."""
    import uvicorn

    cfg = get_settings()

    # Run via import string so uvicorn's reload mode works in debug.
    uvicorn.run(
        "server:app",
        host=cfg.host,
        port=cfg.port,
        reload=cfg.debug,
        log_level="info",
    )
|
||||
|
||||
|
||||
# Script entry point: start the HTTP server.
if __name__ == "__main__":
    main()
|
||||
|
||||
1
mcp-servers/knowledge-base/tests/__init__.py
Normal file
1
mcp-servers/knowledge-base/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Knowledge Base MCP Server."""
|
||||
282
mcp-servers/knowledge-base/tests/conftest.py
Normal file
282
mcp-servers/knowledge-base/tests/conftest.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Test fixtures for Knowledge Base MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Add parent directory (the knowledge-base server root) to sys.path so the
# server modules are importable regardless of where pytest is invoked from.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Set test mode before importing modules — presumably the config module reads
# these at import time, so they must be in place first (verify against config).
os.environ["IS_TEST"] = "true"
os.environ["KB_DATABASE_URL"] = "postgresql://test:test@localhost:5432/test"
os.environ["KB_REDIS_URL"] = "redis://localhost:6379/0"
os.environ["KB_LLM_GATEWAY_URL"] = "http://localhost:8001"
|
||||
|
||||
|
||||
@pytest.fixture
def settings():
    """Build a fresh, test-scoped Settings object.

    Resets the cached global settings first so state does not leak
    between tests.
    """
    from config import Settings, reset_settings

    reset_settings()

    overrides = {
        "host": "127.0.0.1",
        "port": 8002,
        "debug": True,
        "database_url": "postgresql://test:test@localhost:5432/test",
        "redis_url": "redis://localhost:6379/0",
        "llm_gateway_url": "http://localhost:8001",
        "embedding_dimension": 1536,
        "code_chunk_size": 500,
        "code_chunk_overlap": 50,
        "markdown_chunk_size": 800,
        "markdown_chunk_overlap": 100,
        "text_chunk_size": 400,
        "text_chunk_overlap": 50,
    }
    return Settings(**overrides)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_database():
    """Build a fully stubbed DatabaseManager with canned async results."""
    from database import DatabaseManager

    db = MagicMock(spec=DatabaseManager)
    db._pool = MagicMock()
    db.acquire = MagicMock(return_value=AsyncMock())

    # Async coroutine methods with no meaningful return value.
    db.initialize = AsyncMock()
    db.close = AsyncMock()
    db.get_collection_stats = AsyncMock()

    # Async methods with canned return values, keyed by method name.
    canned_returns = {
        "store_embedding": "test-id-123",
        "store_embeddings_batch": ["id-1", "id-2"],
        "semantic_search": [],
        "keyword_search": [],
        "delete_by_source": 1,
        "delete_collection": 5,
        "delete_by_ids": 2,
        "list_collections": [],
        "cleanup_expired": 0,
    }
    for method_name, value in canned_returns.items():
        setattr(db, method_name, AsyncMock(return_value=value))

    return db
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embeddings():
    """Build a stubbed EmbeddingGenerator returning constant vectors."""
    from embeddings import EmbeddingGenerator

    dimension = 1536

    def constant_vector() -> list[float]:
        # Deterministic embedding matching the configured dimension.
        return [0.1] * dimension

    emb = MagicMock(spec=EmbeddingGenerator)
    emb.initialize = AsyncMock()
    emb.close = AsyncMock()
    emb.generate = AsyncMock(return_value=constant_vector())
    # Batch generation returns one vector per input text.
    emb.generate_batch = AsyncMock(
        side_effect=lambda texts, **_kwargs: [constant_vector() for _ in texts]
    )

    return emb
|
||||
|
||||
|
||||
@pytest.fixture
def mock_redis():
    """Provide an in-memory async Redis double (no server required)."""
    import fakeredis.aioredis as fake_aioredis

    return fake_aioredis.FakeRedis()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_python_code():
    """Return a small Python module used to exercise the code chunker."""
    module_source = '''"""Sample module for testing."""

import os
from typing import Any


class Calculator:
    """A simple calculator class."""

    def __init__(self, initial: int = 0) -> None:
        """Initialize calculator."""
        self.value = initial

    def add(self, x: int) -> int:
        """Add a value."""
        self.value += x
        return self.value

    def subtract(self, x: int) -> int:
        """Subtract a value."""
        self.value -= x
        return self.value


def helper_function(data: dict[str, Any]) -> str:
    """A helper function."""
    return str(data)


async def async_function() -> None:
    """An async function."""
    pass
'''
    return module_source
|
||||
|
||||
|
||||
@pytest.fixture
def sample_markdown():
    """Return nested-heading Markdown used to exercise the markdown chunker."""
    document = '''# Project Documentation

This is the main documentation for our project.

## Getting Started

To get started, follow these steps:

1. Install dependencies
2. Configure settings
3. Run the application

### Prerequisites

You'll need the following installed:

- Python 3.12+
- PostgreSQL
- Redis

```python
# Example code
def main():
    print("Hello, World!")
```

## API Reference

### Search Endpoint

The search endpoint allows you to query the knowledge base.

**Endpoint:** `POST /api/search`

**Request:**
```json
{
    "query": "your search query",
    "limit": 10
}
```

## Contributing

We welcome contributions! Please see our contributing guide.
'''
    return document
|
||||
|
||||
|
||||
@pytest.fixture
def sample_text():
    """Return multi-paragraph plain text used to exercise the text chunker."""
    paragraphs = '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.

Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.

When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.

The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''
    return paragraphs
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunk():
    """Build a representative code Chunk instance."""
    from models import Chunk, ChunkType, FileType

    fields = {
        "content": "def hello():\n print('Hello')",
        "chunk_type": ChunkType.CODE,
        "file_type": FileType.PYTHON,
        "source_path": "/test/hello.py",
        "start_line": 1,
        "end_line": 2,
        "metadata": {"function": "hello"},
        "token_count": 15,
    }
    return Chunk(**fields)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_embedding():
    """Build a representative stored KnowledgeEmbedding row."""
    from models import ChunkType, FileType, KnowledgeEmbedding

    fields = {
        "id": "test-id-123",
        "project_id": "proj-123",
        "collection": "default",
        "content": "def hello():\n print('Hello')",
        # 1536 dimensions matches the configured embedding size.
        "embedding": [0.1] * 1536,
        "chunk_type": ChunkType.CODE,
        "source_path": "/test/hello.py",
        "start_line": 1,
        "end_line": 2,
        "file_type": FileType.PYTHON,
        "metadata": {"function": "hello"},
        "content_hash": "abc123",
        "created_at": datetime.now(UTC),
        "updated_at": datetime.now(UTC),
    }
    return KnowledgeEmbedding(**fields)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_ingest_request():
    """Build a representative IngestRequest."""
    from models import ChunkType, FileType, IngestRequest

    fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "content": "def hello():\n print('Hello')",
        "source_path": "/test/hello.py",
        "collection": "default",
        "chunk_type": ChunkType.CODE,
        "file_type": FileType.PYTHON,
        "metadata": {"test": True},
    }
    return IngestRequest(**fields)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_search_request():
    """Build a representative hybrid SearchRequest."""
    from models import SearchRequest, SearchType

    fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "query": "hello function",
        "search_type": SearchType.HYBRID,
        "collection": "default",
        "limit": 10,
        "threshold": 0.7,
    }
    return SearchRequest(**fields)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_delete_request():
    """Build a representative source-path DeleteRequest."""
    from models import DeleteRequest

    fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "source_path": "/test/hello.py",
    }
    return DeleteRequest(**fields)
|
||||
422
mcp-servers/knowledge-base/tests/test_chunking.py
Normal file
422
mcp-servers/knowledge-base/tests/test_chunking.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""Tests for chunking module."""
|
||||
|
||||
|
||||
|
||||
class TestBaseChunker:
    """Exercise behaviour inherited from the chunker base class."""

    def test_count_tokens(self, settings):
        """Token counting returns a small positive count for short text."""
        from chunking.text import TextChunker

        splitter = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        token_count = splitter.count_tokens("Hello, world!")
        # A short greeting tokenises to a handful of tokens, never zero.
        assert token_count > 0
        assert token_count < 10  # Should be about 3-4 tokens

    def test_truncate_to_tokens(self, settings):
        """Truncation never exceeds the requested token budget."""
        from chunking.text import TextChunker

        splitter = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        oversized = "word " * 1000
        clipped = splitter.truncate_to_tokens(oversized, 10)

        assert splitter.count_tokens(clipped) <= 10
|
||||
|
||||
|
||||
class TestCodeChunker:
    """Exercise the code-aware chunker."""

    def test_chunk_python_code(self, settings, sample_python_code):
        """Chunking Python source yields code chunks tagged as Python."""
        from chunking.code import CodeChunker
        from models import ChunkType, FileType

        splitter = CodeChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        pieces = splitter.chunk(
            content=sample_python_code,
            source_path="/test/sample.py",
            file_type=FileType.PYTHON,
        )

        assert len(pieces) > 0
        assert all(piece.chunk_type == ChunkType.CODE for piece in pieces)
        assert all(piece.file_type == FileType.PYTHON for piece in pieces)

    def test_preserves_function_boundaries(self, settings):
        """Chunks carry consistent line numbers around function boundaries."""
        from chunking.code import CodeChunker
        from models import FileType

        source = '''def function_one():
    """First function."""
    return 1

def function_two():
    """Second function."""
    return 2
'''

        splitter = CodeChunker(chunk_size=100, chunk_overlap=10, settings=settings)

        pieces = splitter.chunk(
            content=source,
            source_path="/test/funcs.py",
            file_type=FileType.PYTHON,
        )

        # Each function should ideally land in its own chunk.
        assert len(pieces) >= 1
        for piece in pieces:
            # Every chunk must have an ordered (start, end) line range.
            assert piece.start_line is not None
            assert piece.end_line is not None
            assert piece.start_line <= piece.end_line

    def test_handles_empty_content(self, settings):
        """Empty input produces no chunks rather than raising."""
        from chunking.code import CodeChunker

        splitter = CodeChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        pieces = splitter.chunk(content="", source_path="/test/empty.py")

        assert pieces == []

    def test_chunk_type_is_code(self, settings):
        """The chunk_type property reports CODE."""
        from chunking.code import CodeChunker
        from models import ChunkType

        splitter = CodeChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        assert splitter.chunk_type == ChunkType.CODE
|
||||
|
||||
|
||||
class TestMarkdownChunker:
    """Exercise the markdown-aware chunker."""

    def test_chunk_markdown(self, settings, sample_markdown):
        """Chunking markdown yields MARKDOWN-typed chunks."""
        from chunking.markdown import MarkdownChunker
        from models import ChunkType, FileType

        splitter = MarkdownChunker(chunk_size=800, chunk_overlap=100, settings=settings)

        pieces = splitter.chunk(
            content=sample_markdown,
            source_path="/test/docs.md",
            file_type=FileType.MARKDOWN,
        )

        assert len(pieces) > 0
        assert all(piece.chunk_type == ChunkType.MARKDOWN for piece in pieces)

    def test_respects_heading_hierarchy(self, settings):
        """Sectioned markdown splits into non-empty chunks."""
        from chunking.markdown import MarkdownChunker

        document = '''# Main Title

Introduction paragraph.

## Section One

Content for section one.

### Subsection

More detailed content.

## Section Two

Content for section two.
'''

        splitter = MarkdownChunker(chunk_size=200, chunk_overlap=20, settings=settings)

        pieces = splitter.chunk(
            content=document,
            source_path="/test/docs.md",
        )

        # Should produce at least one chunk per the section layout.
        assert len(pieces) >= 1
        for piece in pieces:
            # Every emitted chunk must carry content.
            assert len(piece.content) > 0

    def test_handles_code_blocks(self, settings):
        """Fenced code blocks survive chunking."""
        from chunking.markdown import MarkdownChunker

        document = '''# Code Example

Here's some code:

```python
def hello():
    print("Hello, World!")
```

End of example.
'''

        splitter = MarkdownChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        pieces = splitter.chunk(
            content=document,
            source_path="/test/code.md",
        )

        assert len(pieces) >= 1
        # The fenced block (or at least its code) must appear somewhere.
        combined = " ".join(piece.content for piece in pieces)
        assert "```python" in combined or "def hello" in combined

    def test_chunk_type_is_markdown(self, settings):
        """The chunk_type property reports MARKDOWN."""
        from chunking.markdown import MarkdownChunker
        from models import ChunkType

        splitter = MarkdownChunker(chunk_size=800, chunk_overlap=100, settings=settings)

        assert splitter.chunk_type == ChunkType.MARKDOWN
|
||||
|
||||
|
||||
class TestTextChunker:
    """Exercise the plain-text chunker."""

    def test_chunk_text(self, settings, sample_text):
        """Chunking prose yields TEXT-typed chunks."""
        from chunking.text import TextChunker
        from models import ChunkType

        splitter = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        pieces = splitter.chunk(
            content=sample_text,
            source_path="/test/docs.txt",
        )

        assert len(pieces) > 0
        assert all(piece.chunk_type == ChunkType.TEXT for piece in pieces)

    def test_respects_paragraph_boundaries(self, settings):
        """Multi-paragraph input still chunks successfully."""
        from chunking.text import TextChunker

        prose = '''First paragraph with some content.

Second paragraph with different content.

Third paragraph to test chunking behavior.
'''

        splitter = TextChunker(chunk_size=100, chunk_overlap=10, settings=settings)

        pieces = splitter.chunk(
            content=prose,
            source_path="/test/text.txt",
        )

        assert len(pieces) >= 1

    def test_handles_single_paragraph(self, settings):
        """Input that fits the budget comes back as a single verbatim chunk."""
        from chunking.text import TextChunker

        prose = "This is a short paragraph."

        splitter = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        pieces = splitter.chunk(content=prose, source_path="/test/short.txt")

        assert len(pieces) == 1
        assert pieces[0].content == prose

    def test_chunk_type_is_text(self, settings):
        """The chunk_type property reports TEXT."""
        from chunking.text import TextChunker
        from models import ChunkType

        splitter = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        assert splitter.chunk_type == ChunkType.TEXT
|
||||
|
||||
|
||||
class TestChunkerFactory:
    """Exercise chunker selection via the factory."""

    def test_get_code_chunker(self, settings):
        """Python file type maps to the code chunker."""
        from chunking.base import ChunkerFactory
        from chunking.code import CodeChunker
        from models import FileType

        factory = ChunkerFactory(settings=settings)
        selected = factory.get_chunker(file_type=FileType.PYTHON)

        assert isinstance(selected, CodeChunker)

    def test_get_markdown_chunker(self, settings):
        """Markdown file type maps to the markdown chunker."""
        from chunking.base import ChunkerFactory
        from chunking.markdown import MarkdownChunker
        from models import FileType

        factory = ChunkerFactory(settings=settings)
        selected = factory.get_chunker(file_type=FileType.MARKDOWN)

        assert isinstance(selected, MarkdownChunker)

    def test_get_text_chunker(self, settings):
        """Text file type maps to the text chunker."""
        from chunking.base import ChunkerFactory
        from chunking.text import TextChunker
        from models import FileType

        factory = ChunkerFactory(settings=settings)
        selected = factory.get_chunker(file_type=FileType.TEXT)

        assert isinstance(selected, TextChunker)

    def test_get_chunker_for_path(self, settings):
        """File extension drives both chunker and file-type detection."""
        from chunking.base import ChunkerFactory
        from chunking.code import CodeChunker
        from chunking.markdown import MarkdownChunker
        from models import FileType

        factory = ChunkerFactory(settings=settings)

        for path, expected_cls, expected_type in (
            ("/test/file.py", CodeChunker, FileType.PYTHON),
            ("/test/docs.md", MarkdownChunker, FileType.MARKDOWN),
        ):
            selected, detected = factory.get_chunker_for_path(path)
            assert isinstance(selected, expected_cls)
            assert detected == expected_type

    def test_chunk_content(self, settings, sample_python_code):
        """chunk_content dispatches by path and returns code chunks."""
        from chunking.base import ChunkerFactory
        from models import ChunkType

        factory = ChunkerFactory(settings=settings)

        pieces = factory.chunk_content(
            content=sample_python_code,
            source_path="/test/sample.py",
        )

        assert len(pieces) > 0
        assert all(piece.chunk_type == ChunkType.CODE for piece in pieces)

    def test_default_to_text_chunker(self, settings):
        """With no hints the factory falls back to the text chunker."""
        from chunking.base import ChunkerFactory
        from chunking.text import TextChunker

        factory = ChunkerFactory(settings=settings)

        assert isinstance(factory.get_chunker(), TextChunker)

    def test_chunker_caching(self, settings):
        """Repeated lookups for the same file type reuse one instance."""
        from chunking.base import ChunkerFactory
        from models import FileType

        factory = ChunkerFactory(settings=settings)

        first = factory.get_chunker(file_type=FileType.PYTHON)
        second = factory.get_chunker(file_type=FileType.PYTHON)

        assert first is second
|
||||
|
||||
|
||||
class TestGlobalChunkerFactory:
    """Exercise the module-level chunker-factory singleton helpers."""

    def test_get_chunker_factory_singleton(self):
        """Successive calls return one shared factory instance."""
        from chunking.base import get_chunker_factory, reset_chunker_factory

        # Start from a clean slate before checking identity.
        reset_chunker_factory()

        assert get_chunker_factory() is get_chunker_factory()

    def test_reset_chunker_factory(self):
        """Resetting forces a fresh factory on the next lookup."""
        from chunking.base import get_chunker_factory, reset_chunker_factory

        before_reset = get_chunker_factory()
        reset_chunker_factory()
        after_reset = get_chunker_factory()

        assert before_reset is not after_reset
|
||||
240
mcp-servers/knowledge-base/tests/test_collection_manager.py
Normal file
240
mcp-servers/knowledge-base/tests/test_collection_manager.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Tests for collection manager module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestCollectionManager:
    """Tests for CollectionManager class.

    The manager under test is wired entirely from mocks (database,
    embeddings, chunker factory), so these are pure unit tests — no
    Postgres, Redis, or gateway is touched.
    """

    @pytest.fixture
    def collection_manager(self, settings, mock_database, mock_embeddings):
        """Create collection manager with mocks."""
        from chunking.base import ChunkerFactory
        from collection_manager import CollectionManager

        mock_chunker_factory = MagicMock(spec=ChunkerFactory)

        # Mock chunk_content to return chunks
        from models import Chunk, ChunkType

        # One canned chunk is enough to drive the ingest pipeline.
        mock_chunker_factory.chunk_content.return_value = [
            Chunk(
                content="def hello(): pass",
                chunk_type=ChunkType.CODE,
                token_count=10,
            )
        ]

        return CollectionManager(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
            chunker_factory=mock_chunker_factory,
        )

    @pytest.mark.asyncio
    async def test_ingest_content(self, collection_manager, sample_ingest_request):
        """Test content ingestion.

        With one mocked chunk, exactly one embedding and one stored id
        are expected, in the request's default collection.
        """
        result = await collection_manager.ingest(sample_ingest_request)

        assert result.success is True
        assert result.chunks_created == 1
        assert result.embeddings_generated == 1
        assert len(result.chunk_ids) == 1
        assert result.collection == "default"

    @pytest.mark.asyncio
    async def test_ingest_empty_content(self, collection_manager):
        """Test ingesting empty content.

        When the chunker yields nothing, ingest should still succeed
        with zero chunks and zero embeddings.
        """
        from models import IngestRequest

        # Mock chunker to return empty list
        collection_manager._chunker_factory.chunk_content.return_value = []

        request = IngestRequest(
            project_id="proj-123",
            agent_id="agent-456",
            content="",
        )

        result = await collection_manager.ingest(request)

        assert result.success is True
        assert result.chunks_created == 0
        assert result.embeddings_generated == 0

    @pytest.mark.asyncio
    async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
        """Test ingest error handling.

        Embedding failures must be reported via the result object, not
        raised to the caller.
        """
        # Make embedding generation fail
        collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")

        result = await collection_manager.ingest(sample_ingest_request)

        assert result.success is False
        assert "Embedding error" in result.error

    @pytest.mark.asyncio
    async def test_delete_by_source(self, collection_manager, sample_delete_request):
        """Test deletion by source path."""
        result = await collection_manager.delete(sample_delete_request)

        assert result.success is True
        assert result.chunks_deleted == 1  # Mock returns 1
        collection_manager._database.delete_by_source.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_collection(self, collection_manager):
        """Test deletion by collection.

        A request naming only a collection should route to
        delete_collection on the database layer.
        """
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            collection="to-delete",
        )

        result = await collection_manager.delete(request)

        assert result.success is True
        collection_manager._database.delete_collection.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_ids(self, collection_manager):
        """Test deletion by chunk IDs.

        A request naming explicit chunk ids should route to
        delete_by_ids on the database layer.
        """
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id-1", "id-2"],
        )

        result = await collection_manager.delete(request)

        assert result.success is True
        collection_manager._database.delete_by_ids.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_no_target(self, collection_manager):
        """Test deletion with no target specified.

        With neither source_path, collection, nor chunk_ids, delete
        must fail with an explanatory error instead of deleting blindly.
        """
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
        )

        result = await collection_manager.delete(request)

        assert result.success is False
        assert "Must specify" in result.error

    @pytest.mark.asyncio
    async def test_list_collections(self, collection_manager):
        """Test listing collections."""
        from models import CollectionInfo

        # Canned database response: two collections for one project.
        collection_manager._database.list_collections.return_value = [
            CollectionInfo(
                name="collection-1",
                project_id="proj-123",
                chunk_count=100,
                total_tokens=50000,
                file_types=["python"],
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
            ),
            CollectionInfo(
                name="collection-2",
                project_id="proj-123",
                chunk_count=50,
                total_tokens=25000,
                file_types=["javascript"],
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
            ),
        ]

        result = await collection_manager.list_collections("proj-123")

        assert result.project_id == "proj-123"
        assert result.total_collections == 2
        assert len(result.collections) == 2

    @pytest.mark.asyncio
    async def test_get_collection_stats(self, collection_manager):
        """Test getting collection statistics."""
        from models import CollectionStats

        expected_stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
        )
        collection_manager._database.get_collection_stats.return_value = expected_stats

        stats = await collection_manager.get_collection_stats("proj-123", "test-collection")

        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        # The manager should pass project and collection straight through.
        collection_manager._database.get_collection_stats.assert_called_once_with(
            "proj-123", "test-collection"
        )

    @pytest.mark.asyncio
    async def test_update_document(self, collection_manager):
        """Test updating a document.

        update_document is expected to be delete-then-reingest; here we
        only verify the delete happened and the overall call succeeded.
        """
        result = await collection_manager.update_document(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
        )

        # Should delete first, then ingest
        collection_manager._database.delete_by_source.assert_called_once()
        assert result.success is True

    @pytest.mark.asyncio
    async def test_cleanup_expired(self, collection_manager):
        """Test cleaning up expired embeddings."""
        collection_manager._database.cleanup_expired.return_value = 10

        count = await collection_manager.cleanup_expired()

        assert count == 10
        collection_manager._database.cleanup_expired.assert_called_once()
|
||||
|
||||
|
||||
class TestGlobalCollectionManager:
    """Exercise the module-level collection-manager singleton helpers."""

    def test_get_collection_manager_singleton(self):
        """Successive calls return one shared manager instance."""
        from collection_manager import get_collection_manager, reset_collection_manager

        # Start from a clean slate before checking identity.
        reset_collection_manager()

        assert get_collection_manager() is get_collection_manager()

    def test_reset_collection_manager(self):
        """Resetting forces a fresh manager on the next lookup."""
        from collection_manager import get_collection_manager, reset_collection_manager

        before_reset = get_collection_manager()
        reset_collection_manager()
        after_reset = get_collection_manager()

        assert before_reset is not after_reset
|
||||
104
mcp-servers/knowledge-base/tests/test_config.py
Normal file
104
mcp-servers/knowledge-base/tests/test_config.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for configuration module."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class TestSettings:
    """Tests for the Settings configuration class."""

    def test_default_values(self, settings):
        """Core defaults match the documented service configuration."""
        assert settings.port == 8002
        assert settings.embedding_dimension == 1536
        assert settings.code_chunk_size == 500
        assert settings.search_default_limit == 10

    def test_env_prefix(self):
        """Settings reads KB_-prefixed environment variables."""
        from config import Settings, reset_settings

        reset_settings()
        os.environ["KB_PORT"] = "9999"
        try:
            settings = Settings()
            assert settings.port == 9999
        finally:
            # Fix: cleanup was previously unconditional code after the
            # assertion, so a failure leaked KB_PORT (and a stale
            # singleton) into every later test. Always clean up.
            del os.environ["KB_PORT"]
            reset_settings()

    def test_embedding_settings(self, settings):
        """Embedding model, batch size, and cache TTL defaults."""
        assert settings.embedding_model == "text-embedding-3-large"
        assert settings.embedding_batch_size == 100
        assert settings.embedding_cache_ttl == 86400

    def test_chunking_settings(self, settings):
        """Per-content-type chunk size and overlap defaults."""
        assert settings.code_chunk_size == 500
        assert settings.code_chunk_overlap == 50
        assert settings.markdown_chunk_size == 800
        assert settings.markdown_chunk_overlap == 100
        assert settings.text_chunk_size == 400
        assert settings.text_chunk_overlap == 50

    def test_search_settings(self, settings):
        """Search limits, similarity threshold, and hybrid weights."""
        assert settings.search_default_limit == 10
        assert settings.search_max_limit == 100
        assert settings.semantic_threshold == 0.7
        assert settings.hybrid_semantic_weight == 0.7
        assert settings.hybrid_keyword_weight == 0.3
|
||||
|
||||
|
||||
class TestGetSettings:
    """Exercise the get_settings singleton helpers."""

    def test_returns_singleton(self):
        """Successive calls return one shared Settings instance."""
        from config import get_settings, reset_settings

        # Start clean so identity is established within this test.
        reset_settings()

        assert get_settings() is get_settings()

    def test_reset_creates_new_instance(self):
        """reset_settings discards the cached instance."""
        from config import get_settings, reset_settings

        before_reset = get_settings()
        reset_settings()
        after_reset = get_settings()

        assert before_reset is not after_reset
|
||||
|
||||
|
||||
class TestIsTestMode:
    """Tests for the is_test_mode helper."""

    def test_returns_true_when_set(self):
        """is_test_mode is True while IS_TEST is set."""
        from config import is_test_mode

        old_value = os.environ.get("IS_TEST")
        os.environ["IS_TEST"] = "true"
        try:
            assert is_test_mode() is True
        finally:
            # Fix: restoration was previously plain code after the
            # assertion, so a failure left IS_TEST overridden for every
            # later test. Also use `is not None` — the old truthiness
            # check (`if old_value:`) failed to restore an empty-string
            # value and deleted it instead.
            if old_value is not None:
                os.environ["IS_TEST"] = old_value
            else:
                del os.environ["IS_TEST"]

    def test_returns_false_when_not_set(self):
        """is_test_mode is False when IS_TEST is absent."""
        from config import is_test_mode

        # pop() both reads and removes in one step, with a safe default.
        old_value = os.environ.pop("IS_TEST", None)
        try:
            assert is_test_mode() is False
        finally:
            # Same fixes as above: restore in finally, and restore even
            # an empty-string previous value.
            if old_value is not None:
                os.environ["IS_TEST"] = old_value
|
||||
245
mcp-servers/knowledge-base/tests/test_embeddings.py
Normal file
245
mcp-servers/knowledge-base/tests/test_embeddings.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Tests for embedding generation module."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEmbeddingGenerator:
    """Tests for EmbeddingGenerator class.

    The generator is exercised with its private collaborators injected
    directly (``_redis``, ``_http_client``) so no network or Redis server
    is needed.  HTTP responses mimic the MCP tool-call envelope:
    ``{"result": {"content": [{"text": <json payload>}]}}``.
    """

    @pytest.fixture
    def mock_http_response(self):
        """Create mock HTTP response carrying one 1536-dim embedding."""
        response = MagicMock()
        response.status_code = 200
        # raise_for_status is a no-op MagicMock so the 200 path succeeds.
        response.raise_for_status = MagicMock()
        response.json.return_value = {
            "result": {
                "content": [
                    {
                        # Payload is JSON-encoded text, as the MCP
                        # protocol delivers tool results.
                        "text": json.dumps({
                            "embeddings": [[0.1] * 1536]
                        })
                    }
                ]
            }
        }
        return response

    @pytest.mark.asyncio
    async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
        """Test generating a single embedding via one HTTP call."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        # Inject the fake cache so no real Redis connection is opened.
        generator._redis = mock_redis

        # Mock HTTP client
        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        # Expect the configured embedding dimension and exactly one
        # backend round-trip.
        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Test generating batch embeddings from a single response."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client with batch response
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.raise_for_status = MagicMock()
        mock_response.json.return_value = {
            "result": {
                "content": [
                    {
                        # Three embeddings returned for three input texts.
                        "text": json.dumps({
                            "embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
                        })
                    }
                ]
            }
        }
        mock_client.post.return_value = mock_response
        generator._http_client = mock_client

        embeddings = await generator.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embeddings) == 3
        assert all(len(e) == 1536 for e in embeddings)

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """Test embedding caching: a cache hit must skip the backend."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Pre-populate cache with a recognizable vector (all 0.5) so we
        # can tell a cached result apart from a freshly generated one.
        cache_key = generator._cache_key("Hello, world!")
        await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))

        # Mock HTTP client (should not be called)
        mock_client = AsyncMock()
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        # Should return cached embedding
        assert len(embedding) == 1536
        assert embedding[0] == 0.5
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Test embedding cache miss falls through to the backend."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="New text not in cache",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == 1536
        # Exactly one HTTP call proves the cache was consulted and missed.
        mock_client.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Test cache key generation: deterministic, text-sensitive, namespaced."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)

        key1 = generator._cache_key("Hello")
        key2 = generator._cache_key("Hello")
        key3 = generator._cache_key("World")

        # Same text -> same key; different text -> different key.
        assert key1 == key2
        assert key1 != key3
        # Keys live under the knowledge-base embedding namespace.
        assert key1.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """Test embedding dimension validation rejects wrong-size vectors."""
        from embeddings import EmbeddingGenerator
        from exceptions import EmbeddingDimensionMismatchError

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client with wrong dimension
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.raise_for_status = MagicMock()
        mock_response.json.return_value = {
            "result": {
                "content": [
                    {
                        "text": json.dumps({
                            "embeddings": [[0.1] * 768]  # Wrong dimension
                        })
                    }
                ]
            }
        }
        mock_client.post.return_value = mock_response
        generator._http_client = mock_client

        with pytest.raises(EmbeddingDimensionMismatchError):
            await generator.generate(
                text="Test text",
                project_id="proj-123",
                agent_id="agent-456",
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """Test generating embeddings for empty batch returns [] without I/O."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        embeddings = await generator.generate_batch(
            texts=[],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert embeddings == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """Test initialize and close methods manage the HTTP client lifecycle."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)

        # Mock successful initialization: redis.from_url is patched so
        # initialize() never opens a real connection; ping() succeeding
        # lets initialization complete.
        with patch("embeddings.redis.from_url") as mock_redis_from_url:
            mock_redis_client = AsyncMock()
            mock_redis_client.ping = AsyncMock()
            mock_redis_from_url.return_value = mock_redis_client

            await generator.initialize()

            assert generator._http_client is not None

            await generator.close()

            # close() is expected to drop the client reference.
            assert generator._http_client is None
|
||||
class TestGlobalEmbeddingGenerator:
    """Tests for the process-wide embedding generator accessor."""

    def test_get_embedding_generator_singleton(self):
        """Repeated lookups return one shared instance."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        # Clear any instance cached by earlier tests first.
        reset_embedding_generator()
        assert get_embedding_generator() is get_embedding_generator()

    def test_reset_embedding_generator(self):
        """Resetting discards the cached generator."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        previous = get_embedding_generator()
        reset_embedding_generator()
        current = get_embedding_generator()
        assert previous is not current
307
mcp-servers/knowledge-base/tests/test_exceptions.py
Normal file
307
mcp-servers/knowledge-base/tests/test_exceptions.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Tests for exception classes."""
|
||||
|
||||
|
||||
|
||||
class TestErrorCode:
    """Tests for ErrorCode enum."""

    def test_error_code_values(self):
        """Every member carries its KB_-prefixed wire string."""
        from exceptions import ErrorCode

        expected = {
            ErrorCode.UNKNOWN_ERROR: "KB_UNKNOWN_ERROR",
            ErrorCode.DATABASE_CONNECTION_ERROR: "KB_DATABASE_CONNECTION_ERROR",
            ErrorCode.EMBEDDING_GENERATION_ERROR: "KB_EMBEDDING_GENERATION_ERROR",
            ErrorCode.CHUNKING_ERROR: "KB_CHUNKING_ERROR",
            ErrorCode.SEARCH_ERROR: "KB_SEARCH_ERROR",
            ErrorCode.COLLECTION_NOT_FOUND: "KB_COLLECTION_NOT_FOUND",
            ErrorCode.DOCUMENT_NOT_FOUND: "KB_DOCUMENT_NOT_FOUND",
        }
        for member, value in expected.items():
            assert member.value == value
||||
class TestKnowledgeBaseError:
    """Tests for the base exception class."""

    def test_basic_error(self):
        """A bare error keeps message/code and empty details, no cause."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Something went wrong",
            code=ErrorCode.UNKNOWN_ERROR,
        )

        assert exc.message == "Something went wrong"
        assert exc.code == ErrorCode.UNKNOWN_ERROR
        # Defaults: no structured details, no chained cause.
        assert exc.details == {}
        assert exc.cause is None

    def test_error_with_details(self):
        """Structured details are stored verbatim."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Query failed",
            code=ErrorCode.DATABASE_QUERY_ERROR,
            details={"query": "SELECT * FROM table", "error_code": 42},
        )

        assert exc.details["query"] == "SELECT * FROM table"
        assert exc.details["error_code"] == 42

    def test_error_with_cause(self):
        """An underlying exception is kept as the cause."""
        from exceptions import ErrorCode, KnowledgeBaseError

        original = ValueError("Original error")
        wrapped = KnowledgeBaseError(
            message="Wrapped error",
            code=ErrorCode.INTERNAL_ERROR,
            cause=original,
        )

        assert wrapped.cause is original
        assert isinstance(wrapped.cause, ValueError)

    def test_to_dict(self):
        """to_dict() serializes code, message and details."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"field": "value"},
        )

        payload = exc.to_dict()

        assert payload["error"] == "KB_INVALID_REQUEST"
        assert payload["message"] == "Test error"
        assert payload["details"]["field"] == "value"

    def test_str_representation(self):
        """str() renders '[CODE] message'."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
        )

        assert str(exc) == "[KB_INVALID_REQUEST] Test error"

    def test_repr_representation(self):
        """repr() names the class and includes message and code."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"key": "value"},
        )

        rendered = repr(exc)
        for fragment in ("KnowledgeBaseError", "Test error", "KB_INVALID_REQUEST"):
            assert fragment in rendered
||||
class TestDatabaseErrors:
    """Tests for database-related exceptions."""

    def test_database_connection_error(self):
        """Connection errors carry the right code and the given details."""
        from exceptions import DatabaseConnectionError, ErrorCode

        exc = DatabaseConnectionError(
            message="Cannot connect to database",
            details={"host": "localhost", "port": 5432},
        )

        assert exc.code == ErrorCode.DATABASE_CONNECTION_ERROR
        assert exc.details["host"] == "localhost"

    def test_database_connection_error_default_message(self):
        """Omitting the message falls back to the class default."""
        from exceptions import DatabaseConnectionError

        assert DatabaseConnectionError().message == "Failed to connect to database"

    def test_database_query_error(self):
        """Query errors record the failing SQL in their details."""
        from exceptions import DatabaseQueryError, ErrorCode

        exc = DatabaseQueryError(
            message="Query failed",
            query="SELECT * FROM missing_table",
        )

        assert exc.code == ErrorCode.DATABASE_QUERY_ERROR
        assert exc.details["query"] == "SELECT * FROM missing_table"
||||
class TestEmbeddingErrors:
    """Tests for embedding-related exceptions."""

    def test_embedding_generation_error(self):
        """Generation failures carry the affected text count."""
        from exceptions import EmbeddingGenerationError, ErrorCode

        exc = EmbeddingGenerationError(
            message="Failed to generate",
            texts_count=10,
        )

        assert exc.code == ErrorCode.EMBEDDING_GENERATION_ERROR
        assert exc.details["texts_count"] == 10

    def test_embedding_dimension_mismatch(self):
        """Dimension mismatches surface both sizes in message and details."""
        from exceptions import EmbeddingDimensionMismatchError, ErrorCode

        exc = EmbeddingDimensionMismatchError(expected=1536, actual=768)

        assert exc.code == ErrorCode.EMBEDDING_DIMENSION_MISMATCH
        # The auto-built message mentions both dimensions.
        assert "expected 1536" in exc.message
        assert "got 768" in exc.message
        assert exc.details["expected_dimension"] == 1536
        assert exc.details["actual_dimension"] == 768
||||
class TestChunkingErrors:
    """Tests for chunking-related exceptions."""

    def test_unsupported_file_type_error(self):
        """Unknown extensions report the offending and supported types."""
        from exceptions import ErrorCode, UnsupportedFileTypeError

        exc = UnsupportedFileTypeError(
            file_type=".xyz",
            supported_types=[".py", ".js", ".md"],
        )

        assert exc.code == ErrorCode.UNSUPPORTED_FILE_TYPE
        assert exc.details["file_type"] == ".xyz"
        assert len(exc.details["supported_types"]) == 3

    def test_file_too_large_error(self):
        """Oversized inputs record actual and maximum byte sizes."""
        from exceptions import ErrorCode, FileTooLargeError

        exc = FileTooLargeError(file_size=10_000_000, max_size=1_000_000)

        assert exc.code == ErrorCode.FILE_TOO_LARGE
        assert exc.details["file_size"] == 10_000_000
        assert exc.details["max_size"] == 1_000_000

    def test_encoding_error(self):
        """Decode failures record the attempted encoding."""
        from exceptions import EncodingError, ErrorCode

        exc = EncodingError(
            message="Cannot decode file",
            encoding="utf-8",
        )

        assert exc.code == ErrorCode.ENCODING_ERROR
        assert exc.details["encoding"] == "utf-8"
||||
class TestSearchErrors:
    """Tests for search-related exceptions."""

    def test_invalid_search_type_error(self):
        """Bad search types report the value and the allowed set."""
        from exceptions import ErrorCode, InvalidSearchTypeError

        exc = InvalidSearchTypeError(
            search_type="invalid",
            valid_types=["semantic", "keyword", "hybrid"],
        )

        assert exc.code == ErrorCode.INVALID_SEARCH_TYPE
        assert exc.details["search_type"] == "invalid"
        assert len(exc.details["valid_types"]) == 3

    def test_search_timeout_error(self):
        """Timeouts keep the limit in details and mention it in the message."""
        from exceptions import ErrorCode, SearchTimeoutError

        exc = SearchTimeoutError(timeout=30.0)

        assert exc.code == ErrorCode.SEARCH_TIMEOUT
        assert exc.details["timeout"] == 30.0
        assert "30" in exc.message
||||
class TestCollectionErrors:
    """Tests for collection-related exceptions."""

    def test_collection_not_found_error(self):
        """Missing collections record both collection name and project."""
        from exceptions import CollectionNotFoundError, ErrorCode

        exc = CollectionNotFoundError(
            collection="missing-collection",
            project_id="proj-123",
        )

        assert exc.code == ErrorCode.COLLECTION_NOT_FOUND
        assert exc.details["collection"] == "missing-collection"
        assert exc.details["project_id"] == "proj-123"
||||
class TestDocumentErrors:
    """Tests for document-related exceptions."""

    def test_document_not_found_error(self):
        """Test document not found error records source path and project."""
        from exceptions import DocumentNotFoundError, ErrorCode

        error = DocumentNotFoundError(
            source_path="/path/to/file.py",
            project_id="proj-123",
        )

        assert error.code == ErrorCode.DOCUMENT_NOT_FOUND
        assert error.details["source_path"] == "/path/to/file.py"
        # Previously constructed but never verified: the project id
        # should be carried in the details as well.
        assert error.details["project_id"] == "proj-123"

    def test_invalid_document_error(self):
        """Test invalid document error preserves message and details."""
        from exceptions import ErrorCode, InvalidDocumentError

        error = InvalidDocumentError(
            message="Empty content",
            details={"reason": "no content"},
        )

        assert error.code == ErrorCode.INVALID_DOCUMENT
        # The constructor arguments were previously unasserted; pin them
        # so regressions in message/details plumbing are caught.
        assert error.message == "Empty content"
        assert error.details["reason"] == "no content"
||||
class TestProjectErrors:
    """Tests for project-related exceptions."""

    def test_project_not_found_error(self):
        """Unknown projects keep the id in details."""
        from exceptions import ErrorCode, ProjectNotFoundError

        exc = ProjectNotFoundError(project_id="missing-proj")

        assert exc.code == ErrorCode.PROJECT_NOT_FOUND
        assert exc.details["project_id"] == "missing-proj"

    def test_project_access_denied_error(self):
        """Access denials name the project in the message."""
        from exceptions import ErrorCode, ProjectAccessDeniedError

        exc = ProjectAccessDeniedError(project_id="restricted-proj")

        assert exc.code == ErrorCode.PROJECT_ACCESS_DENIED
        assert "restricted-proj" in exc.message
||||
347
mcp-servers/knowledge-base/tests/test_models.py
Normal file
347
mcp-servers/knowledge-base/tests/test_models.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Tests for data models."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
class TestEnums:
    """Tests for enum classes."""

    def test_search_type_values(self):
        """SearchType members map to their wire strings."""
        from models import SearchType

        pairs = [
            (SearchType.SEMANTIC, "semantic"),
            (SearchType.KEYWORD, "keyword"),
            (SearchType.HYBRID, "hybrid"),
        ]
        for member, value in pairs:
            assert member.value == value

    def test_chunk_type_values(self):
        """ChunkType members map to their wire strings."""
        from models import ChunkType

        pairs = [
            (ChunkType.CODE, "code"),
            (ChunkType.MARKDOWN, "markdown"),
            (ChunkType.TEXT, "text"),
            (ChunkType.DOCUMENTATION, "documentation"),
        ]
        for member, value in pairs:
            assert member.value == value

    def test_file_type_values(self):
        """FileType members map to their wire strings."""
        from models import FileType

        pairs = [
            (FileType.PYTHON, "python"),
            (FileType.JAVASCRIPT, "javascript"),
            (FileType.TYPESCRIPT, "typescript"),
            (FileType.MARKDOWN, "markdown"),
        ]
        for member, value in pairs:
            assert member.value == value
||||
class TestFileExtensionMap:
    """Tests for file extension mapping."""

    def test_python_extensions(self):
        """'.py' maps to the Python file type."""
        from models import FILE_EXTENSION_MAP, FileType

        assert FILE_EXTENSION_MAP[".py"] == FileType.PYTHON

    def test_javascript_extensions(self):
        """Both JS and JSX map to JavaScript."""
        from models import FILE_EXTENSION_MAP, FileType

        for ext in (".js", ".jsx"):
            assert FILE_EXTENSION_MAP[ext] == FileType.JAVASCRIPT

    def test_typescript_extensions(self):
        """Both TS and TSX map to TypeScript."""
        from models import FILE_EXTENSION_MAP, FileType

        for ext in (".ts", ".tsx"):
            assert FILE_EXTENSION_MAP[ext] == FileType.TYPESCRIPT

    def test_markdown_extensions(self):
        """Both MD and MDX map to Markdown."""
        from models import FILE_EXTENSION_MAP, FileType

        for ext in (".md", ".mdx"):
            assert FILE_EXTENSION_MAP[ext] == FileType.MARKDOWN
||||
class TestChunk:
    """Tests for Chunk dataclass."""

    def test_chunk_creation(self, sample_chunk):
        """The fixture exposes exactly the values it was built with."""
        from models import ChunkType, FileType

        chunk = sample_chunk
        assert chunk.content == "def hello():\n print('Hello')"
        assert chunk.chunk_type == ChunkType.CODE
        assert chunk.file_type == FileType.PYTHON
        assert chunk.source_path == "/test/hello.py"
        # Line span and size metadata come through unchanged.
        assert (chunk.start_line, chunk.end_line) == (1, 2)
        assert chunk.token_count == 15

    def test_chunk_to_dict(self, sample_chunk):
        """to_dict() serializes fields, with enums as their string values."""
        serialized = sample_chunk.to_dict()

        assert serialized["content"] == "def hello():\n print('Hello')"
        assert serialized["chunk_type"] == "code"
        assert serialized["file_type"] == "python"
        assert serialized["source_path"] == "/test/hello.py"
        assert (serialized["start_line"], serialized["end_line"]) == (1, 2)
        assert serialized["token_count"] == 15
||||
class TestKnowledgeEmbedding:
    """Tests for KnowledgeEmbedding dataclass."""

    def test_embedding_creation(self, sample_embedding):
        """The fixture carries id, project, collection, and a full vector."""
        emb = sample_embedding
        assert emb.id == "test-id-123"
        assert emb.project_id == "proj-123"
        assert emb.collection == "default"
        assert len(emb.embedding) == 1536

    def test_embedding_to_dict(self, sample_embedding):
        """to_dict() serializes metadata but omits the raw vector."""
        serialized = sample_embedding.to_dict()

        assert serialized["id"] == "test-id-123"
        assert serialized["project_id"] == "proj-123"
        assert serialized["collection"] == "default"
        assert serialized["chunk_type"] == "code"
        assert serialized["file_type"] == "python"
        # The 1536-float vector is deliberately excluded to keep
        # serialized payloads small.
        assert "embedding" not in serialized
||||
class TestIngestRequest:
    """Tests for IngestRequest model."""

    def test_ingest_request_creation(self, sample_ingest_request):
        """The fixture exposes its construction values."""
        from models import ChunkType, FileType

        req = sample_ingest_request
        assert req.project_id == "proj-123"
        assert req.agent_id == "agent-456"
        assert req.chunk_type == ChunkType.CODE
        assert req.file_type == FileType.PYTHON
        assert req.collection == "default"

    def test_ingest_request_defaults(self):
        """Only required fields given: optional ones fall back to defaults."""
        from models import ChunkType, IngestRequest

        req = IngestRequest(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
        )

        # Defaults: "default" collection, plain text, no file type,
        # empty metadata.
        assert req.collection == "default"
        assert req.chunk_type == ChunkType.TEXT
        assert req.file_type is None
        assert req.metadata == {}
||||
class TestIngestResult:
    """Tests for IngestResult model."""

    def test_successful_result(self):
        """A successful ingest reports counts and carries no error."""
        from models import IngestResult

        outcome = IngestResult(
            success=True,
            chunks_created=5,
            embeddings_generated=5,
            source_path="/test/file.py",
            collection="default",
            chunk_ids=["id1", "id2", "id3", "id4", "id5"],
        )

        assert outcome.success is True
        assert outcome.chunks_created == 5
        assert outcome.error is None

    def test_failed_result(self):
        """A failed ingest keeps zero counts and the error string."""
        from models import IngestResult

        outcome = IngestResult(
            success=False,
            chunks_created=0,
            embeddings_generated=0,
            collection="default",
            chunk_ids=[],
            error="Something went wrong",
        )

        assert outcome.success is False
        assert outcome.error == "Something went wrong"
||||
class TestSearchRequest:
    """Tests for SearchRequest model."""

    def test_search_request_creation(self, sample_search_request):
        """The fixture exposes its construction values."""
        from models import SearchType

        req = sample_search_request
        assert req.project_id == "proj-123"
        assert req.query == "hello function"
        assert req.search_type == SearchType.HYBRID
        assert req.limit == 10
        assert req.threshold == 0.7

    def test_search_request_defaults(self):
        """Optional fields fall back to hybrid search with sane limits."""
        from models import SearchRequest, SearchType

        req = SearchRequest(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
        )

        assert req.search_type == SearchType.HYBRID
        assert req.collection is None
        assert req.limit == 10
        assert req.threshold == 0.7
        assert req.file_types is None
||||
class TestSearchResult:
    """Tests for SearchResult model."""

    def test_from_embedding(self, sample_embedding):
        """from_embedding() copies fields and attaches the given score."""
        from models import SearchResult

        hit = SearchResult.from_embedding(sample_embedding, 0.95)

        assert hit.id == "test-id-123"
        assert hit.content == "def hello():\n print('Hello')"
        assert hit.score == 0.95
        assert hit.source_path == "/test/hello.py"
        # Enum fields are flattened to their string values on the result.
        assert hit.chunk_type == "code"
        assert hit.file_type == "python"
||||
class TestSearchResponse:
    """Tests for SearchResponse model."""

    def test_search_response(self):
        """A response wraps its result list with query and timing metadata."""
        from models import SearchResponse, SearchResult

        hits = [
            SearchResult(
                id=hit_id,
                content=content,
                score=score,
                chunk_type=chunk_type,
                collection="default",
            )
            for hit_id, content, score, chunk_type in (
                ("id1", "test content 1", 0.95, "code"),
                ("id2", "test content 2", 0.85, "text"),
            )
        ]

        response = SearchResponse(
            query="test query",
            search_type="hybrid",
            results=hits,
            total_results=2,
            search_time_ms=15.5,
        )

        assert response.query == "test query"
        assert len(response.results) == 2
        assert response.search_time_ms == 15.5
||||
class TestDeleteRequest:
    """Tests for DeleteRequest model."""

    def test_delete_by_source(self, sample_delete_request):
        """Source-path deletion leaves collection and ids unset."""
        req = sample_delete_request
        assert req.project_id == "proj-123"
        assert req.source_path == "/test/hello.py"
        assert req.collection is None
        assert req.chunk_ids is None

    def test_delete_by_collection(self):
        """Collection deletion leaves the source path unset."""
        from models import DeleteRequest

        req = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            collection="to-delete",
        )

        assert req.collection == "to-delete"
        assert req.source_path is None

    def test_delete_by_ids(self):
        """Explicit chunk ids are carried through as given."""
        from models import DeleteRequest

        req = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id1", "id2", "id3"],
        )

        assert len(req.chunk_ids) == 3
||||
class TestCollectionInfo:
    """Tests for CollectionInfo model."""

    def test_collection_info(self):
        """All constructor fields round-trip onto the instance."""
        from models import CollectionInfo

        now = datetime.now(UTC)
        info = CollectionInfo(
            name="test-collection",
            project_id="proj-123",
            chunk_count=100,
            total_tokens=50000,
            file_types=["python", "javascript"],
            created_at=now,
            updated_at=now,
        )

        assert info.name == "test-collection"
        assert info.chunk_count == 100
        assert len(info.file_types) == 2
||||
class TestCollectionStats:
    """Tests for CollectionStats model."""

    def test_collection_stats(self):
        """Aggregate counters and per-type breakdowns round-trip."""
        from models import CollectionStats

        now = datetime.now(UTC)
        stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
            oldest_chunk=now,
            newest_chunk=now,
        )

        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        assert stats.chunk_types["code"] == 60
||||
295
mcp-servers/knowledge-base/tests/test_search.py
Normal file
295
mcp-servers/knowledge-base/tests/test_search.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Tests for search module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSearchEngine:
|
||||
"""Tests for SearchEngine class."""
|
||||
|
||||
@pytest.fixture
|
||||
def search_engine(self, settings, mock_database, mock_embeddings):
|
||||
"""Create search engine with mocks."""
|
||||
from search import SearchEngine
|
||||
|
||||
engine = SearchEngine(
|
||||
settings=settings,
|
||||
database=mock_database,
|
||||
embeddings=mock_embeddings,
|
||||
)
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def sample_db_results(self):
|
||||
"""Create sample database results."""
|
||||
from models import ChunkType, FileType, KnowledgeEmbedding
|
||||
|
||||
return [
|
||||
(
|
||||
KnowledgeEmbedding(
|
||||
id="id-1",
|
||||
project_id="proj-123",
|
||||
collection="default",
|
||||
content="def hello(): pass",
|
||||
embedding=[0.1] * 1536,
|
||||
chunk_type=ChunkType.CODE,
|
||||
source_path="/test/file.py",
|
||||
file_type=FileType.PYTHON,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
),
|
||||
0.95,
|
||||
),
|
||||
(
|
||||
KnowledgeEmbedding(
|
||||
id="id-2",
|
||||
project_id="proj-123",
|
||||
collection="default",
|
||||
content="def world(): pass",
|
||||
embedding=[0.2] * 1536,
|
||||
chunk_type=ChunkType.CODE,
|
||||
source_path="/test/file2.py",
|
||||
file_type=FileType.PYTHON,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
),
|
||||
0.85,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test semantic search."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.SEMANTIC
|
||||
search_engine._database.semantic_search.return_value = sample_db_results
|
||||
|
||||
response = await search_engine.search(sample_search_request)
|
||||
|
||||
assert response.search_type == "semantic"
|
||||
assert len(response.results) == 2
|
||||
assert response.results[0].score == 0.95
|
||||
search_engine._database.semantic_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test keyword search."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.KEYWORD
|
||||
search_engine._database.keyword_search.return_value = sample_db_results
|
||||
|
||||
response = await search_engine.search(sample_search_request)
|
||||
|
||||
assert response.search_type == "keyword"
|
||||
assert len(response.results) == 2
|
||||
search_engine._database.keyword_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test hybrid search."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.HYBRID
|
||||
|
||||
# Both searches return same results for simplicity
|
||||
search_engine._database.semantic_search.return_value = sample_db_results
|
||||
search_engine._database.keyword_search.return_value = sample_db_results
|
||||
|
||||
response = await search_engine.search(sample_search_request)
|
||||
|
||||
assert response.search_type == "hybrid"
|
||||
# Results should be fused
|
||||
assert len(response.results) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test search with collection filter."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.SEMANTIC
|
||||
sample_search_request.collection = "specific-collection"
|
||||
search_engine._database.semantic_search.return_value = sample_db_results
|
||||
|
||||
await search_engine.search(sample_search_request)
|
||||
|
||||
# Verify collection was passed to database
|
||||
call_args = search_engine._database.semantic_search.call_args
|
||||
assert call_args.kwargs["collection"] == "specific-collection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test search with file type filter."""
|
||||
from models import FileType, SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.SEMANTIC
|
||||
sample_search_request.file_types = [FileType.PYTHON]
|
||||
search_engine._database.semantic_search.return_value = sample_db_results
|
||||
|
||||
await search_engine.search(sample_search_request)
|
||||
|
||||
# Verify file types were passed to database
|
||||
call_args = search_engine._database.semantic_search.call_args
|
||||
assert call_args.kwargs["file_types"] == [FileType.PYTHON]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test that search respects result limit."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.SEMANTIC
|
||||
sample_search_request.limit = 1
|
||||
search_engine._database.semantic_search.return_value = sample_db_results[:1]
|
||||
|
||||
response = await search_engine.search(sample_search_request)
|
||||
|
||||
assert len(response.results) <= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
|
||||
"""Test that search records time."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.SEMANTIC
|
||||
search_engine._database.semantic_search.return_value = sample_db_results
|
||||
|
||||
response = await search_engine.search(sample_search_request)
|
||||
|
||||
assert response.search_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_search_type(self, search_engine, sample_search_request):
|
||||
"""Test handling invalid search type."""
|
||||
from exceptions import InvalidSearchTypeError
|
||||
|
||||
# Force invalid search type
|
||||
sample_search_request.search_type = "invalid"
|
||||
|
||||
with pytest.raises((InvalidSearchTypeError, ValueError)):
|
||||
await search_engine.search(sample_search_request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results(self, search_engine, sample_search_request):
|
||||
"""Test search with no results."""
|
||||
from models import SearchType
|
||||
|
||||
sample_search_request.search_type = SearchType.SEMANTIC
|
||||
search_engine._database.semantic_search.return_value = []
|
||||
|
||||
response = await search_engine.search(sample_search_request)
|
||||
|
||||
assert len(response.results) == 0
|
||||
assert response.total_results == 0
|
||||
|
||||
|
||||
class TestReciprocalRankFusion:
    """Tests for reciprocal rank fusion (RRF) of semantic + keyword results."""

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create search engine with mocks (DB/embeddings never hit in these tests)."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    def test_fusion_combines_results(self, search_engine):
        """Test that RRF combines results from both searches."""
        from models import SearchResult

        # "b" appears in both ranked lists, so RRF should boost its rank.
        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]

        keyword = [
            SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # Should have all unique results (union of both lists, deduplicated by id)
        ids = [r.id for r in fused]
        assert "a" in ids
        assert "b" in ids
        assert "c" in ids

        # B should be ranked higher (appears in both)
        b_rank = ids.index("b")
        assert b_rank < 2  # Should be in top 2

    def test_fusion_respects_weights(self, search_engine):
        """Test that RRF respects semantic/keyword weights.

        NOTE(review): this only asserts the result survives fusion under both
        weightings — it does not verify the weights actually change scores or
        ranking. Consider asserting on fused scores to strengthen it.
        """
        from models import SearchResult

        # Same results in same order
        results = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
        ]

        # High semantic weight
        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            results, [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )

        # High keyword weight
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            [], results,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )

        # Both should still return the result
        assert len(fused_semantic_heavy) == 1
        assert len(fused_keyword_heavy) == 1

    def test_fusion_normalizes_scores(self, search_engine):
        """Test that RRF normalizes scores to 0-1."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]

        keyword = [
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # All scores should be between 0 and 1 (inclusive bounds)
        for result in fused:
            assert 0 <= result.score <= 1
|
||||
|
||||
|
||||
class TestGlobalSearchEngine:
    """Tests for the module-level search-engine singleton helpers."""

    def test_get_search_engine_singleton(self):
        """Repeated calls should hand back the very same instance."""
        from search import get_search_engine, reset_search_engine

        reset_search_engine()
        first = get_search_engine()
        assert first is get_search_engine()

    def test_reset_search_engine(self):
        """Resetting should force a brand-new instance on the next call."""
        from search import get_search_engine, reset_search_engine

        before = get_search_engine()
        reset_search_engine()
        assert get_search_engine() is not before
|
||||
357
mcp-servers/knowledge-base/tests/test_server.py
Normal file
357
mcp-servers/knowledge-base/tests/test_server.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for server module and MCP tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestHealthCheck:
    """Tests for health check endpoint.

    NOTE(review): both tests overwrite the module-global ``server._database``
    without restoring it, which can leak state between tests — consider a
    fixture with monkeypatch.
    """

    @pytest.mark.asyncio
    async def test_health_check_healthy(self):
        """Test health check when healthy."""
        import server

        # Create a proper async context manager mock
        # fetchval returns 1 — presumably the probe runs SELECT 1; confirm in server.py
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)

        mock_db = MagicMock()
        mock_db._pool = MagicMock()

        # Make acquire an async context manager
        # (plain AsyncMock() for acquire would return a coroutine, not a CM)
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None
        mock_db.acquire.return_value = mock_cm

        # health_check reads the module-level handle directly
        server._database = mock_db

        result = await server.health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "knowledge-base"
        assert result["database"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self):
        """Test health check without database."""
        import server

        server._database = None

        result = await server.health_check()

        assert result["database"] == "not initialized"
|
||||
|
||||
|
||||
class TestSearchKnowledgeTool:
    """Tests for search_knowledge MCP tool.

    FastMCP wraps tool functions, so tests invoke the underlying callable
    via ``.fn`` and patch the module-global ``server._search`` engine.
    """

    @pytest.mark.asyncio
    async def test_search_success(self):
        """Test successful search."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test query",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test content",
                        score=0.95,
                        source_path="/test/file.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=10.5,
            )
        )
        server._search = mock_search

        # Call the wrapped function via .fn
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=None,
        )

        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95

    @pytest.mark.asyncio
    async def test_search_invalid_type(self):
        """Test search with invalid search type."""
        import server

        # Validation fails before the engine is touched, so no mock is needed.
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid search type" in result["error"]

    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self):
        """Test search with invalid file type."""
        import server

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=["invalid_type"],
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
|
||||
|
||||
|
||||
class TestIngestContentTool:
    """Tests for ingest_content MCP tool.

    Patches the module-global ``server._collections`` manager and calls the
    tool's underlying function via ``.fn`` (FastMCP wraps the callable).
    """

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """Test successful ingestion."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=3,
                embeddings_generated=3,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2", "id-3"],
            )
        )
        server._collections = mock_collections

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test/file.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3

    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self):
        """Test ingest with invalid chunk type."""
        import server

        # Validation rejects the chunk type before the collections manager is used.
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self):
        """Test ingest with invalid file type."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type="invalid",
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
|
||||
|
||||
|
||||
class TestDeleteContentTool:
    """Tests for the delete_content MCP tool."""

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """A successful delete should report the number of removed chunks."""
        import server
        from models import DeleteResult

        collections = MagicMock()
        collections.delete = AsyncMock(
            return_value=DeleteResult(success=True, chunks_deleted=5)
        )
        server._collections = collections

        result = await server.delete_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            collection=None,
            chunk_ids=None,
        )

        assert result["chunks_deleted"] == 5
        assert result["success"] is True
|
||||
|
||||
|
||||
class TestListCollectionsTool:
    """Tests for list_collections MCP tool."""

    @pytest.mark.asyncio
    async def test_list_collections_success(self):
        """Test listing collections."""
        import server
        from models import CollectionInfo, ListCollectionsResponse

        # Patch the module-global collections manager with a canned response.
        mock_collections = MagicMock()
        mock_collections.list_collections = AsyncMock(
            return_value=ListCollectionsResponse(
                project_id="proj-123",
                collections=[
                    CollectionInfo(
                        name="collection-1",
                        project_id="proj-123",
                        chunk_count=100,
                        total_tokens=50000,
                        file_types=["python"],
                        created_at=datetime.now(UTC),
                        updated_at=datetime.now(UTC),
                    )
                ],
                total_collections=1,
            )
        )
        server._collections = mock_collections

        # FastMCP wraps the tool; call the underlying function via .fn
        result = await server.list_collections.fn(
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert result["success"] is True
        assert result["total_collections"] == 1
        assert len(result["collections"]) == 1
|
||||
|
||||
|
||||
class TestGetCollectionStatsTool:
    """Tests for get_collection_stats MCP tool."""

    @pytest.mark.asyncio
    async def test_get_stats_success(self):
        """Test getting collection stats."""
        import server
        from models import CollectionStats

        # Patch the module-global collections manager with canned stats.
        mock_collections = MagicMock()
        mock_collections.get_collection_stats = AsyncMock(
            return_value=CollectionStats(
                collection="test-collection",
                project_id="proj-123",
                chunk_count=100,
                unique_sources=10,
                total_tokens=50000,
                avg_chunk_size=500.0,
                chunk_types={"code": 60, "text": 40},
                file_types={"python": 50, "javascript": 10},
            )
        )
        server._collections = mock_collections

        # FastMCP wraps the tool; call the underlying function via .fn
        result = await server.get_collection_stats.fn(
            project_id="proj-123",
            agent_id="agent-456",
            collection="test-collection",
        )

        assert result["success"] is True
        assert result["chunk_count"] == 100
        assert result["unique_sources"] == 10
|
||||
|
||||
|
||||
class TestUpdateDocumentTool:
    """Tests for update_document MCP tool."""

    @pytest.mark.asyncio
    async def test_update_success(self):
        """Test updating a document."""
        import server
        from models import IngestResult

        # update_document re-ingests, so the manager returns an IngestResult.
        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=2,
                embeddings_generated=2,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2"],
            )
        )
        server._collections = mock_collections

        # FastMCP wraps the tool; call the underlying function via .fn
        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 2

    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self):
        """Test update with invalid chunk type."""
        import server

        # Validation rejects the chunk type before any collection call happens.
        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="test",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
|
||||
2026
mcp-servers/knowledge-base/uv.lock
generated
Normal file
2026
mcp-servers/knowledge-base/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user