feat(knowledge-base): implement Knowledge Base MCP Server (#57)

Implements RAG capabilities with pgvector for semantic search:

- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content,
  list_collections, get_collection_stats, update_document

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 21:33:26 +01:00
parent 18d717e996
commit d0fc7f37ff
26 changed files with 9530 additions and 120 deletions

View File

@@ -0,0 +1,31 @@
# Slim Python 3.12 base keeps the final image small; build tooling is added
# below only to compile any C extensions among the dependencies.
FROM python:3.12-slim
WORKDIR /app
# Install system dependencies
# (apt lists are removed in the same layer to keep the image small)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*
# Install uv for fast package installation
# NOTE(review): copying from the :latest tag is non-reproducible — consider
# pinning a specific uv version for deterministic builds.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
# Copy project files
COPY pyproject.toml ./
COPY *.py ./
COPY chunking/ ./chunking/
# Install dependencies
# --system installs into the image interpreter; --no-cache skips the wheel cache layer.
RUN uv pip install --system --no-cache .
# Create non-root user
RUN useradd --create-home --shell /bin/bash appuser
USER appuser
# Health check
# Probes the HTTP /health endpoint; assumes httpx is among the project's
# declared dependencies installed above — TODO confirm in pyproject.toml.
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"
EXPOSE 8002
CMD ["python", "server.py"]

View File

@@ -0,0 +1,19 @@
"""
Chunking module for Knowledge Base MCP Server.
Provides intelligent content chunking for different file types
with overlap and context preservation.
"""
# Re-export the public chunking API so callers can import from the
# ``chunking`` package directly instead of reaching into submodules.
from chunking.base import BaseChunker, ChunkerFactory
from chunking.code import CodeChunker
from chunking.markdown import MarkdownChunker
from chunking.text import TextChunker

# Explicit public API of the package.
__all__ = [
    "BaseChunker",
    "ChunkerFactory",
    "CodeChunker",
    "MarkdownChunker",
    "TextChunker",
]

View File

@@ -0,0 +1,281 @@
"""
Base chunker implementation.
Provides abstract interface and common utilities for content chunking.
"""
import logging
from abc import ABC, abstractmethod
from typing import Any
import tiktoken
from config import Settings, get_settings
from exceptions import ChunkingError
from models import FILE_EXTENSION_MAP, Chunk, ChunkType, FileType
logger = logging.getLogger(__name__)
class BaseChunker(ABC):
    """
    Abstract base class for content chunkers.

    Concrete subclasses each implement one chunking strategy
    (code, markdown, plain text) by providing :meth:`chunk` and
    :attr:`chunk_type`; the token-counting helpers here are shared.
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize chunker.

        Args:
            chunk_size: Target tokens per chunk
            chunk_overlap: Token overlap between chunks
            settings: Application settings (falls back to get_settings())
        """
        self._settings = settings or get_settings()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # cl100k_base is the encoding used by GPT-4 / text-embedding-3,
        # so token counts here match the embedding model's tokenization.
        self._tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in *text*."""
        return len(self._tokenizer.encode(text))

    def truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """Return *text* cut down to at most *max_tokens* tokens."""
        encoded = self._tokenizer.encode(text)
        if len(encoded) > max_tokens:
            return self._tokenizer.decode(encoded[:max_tokens])
        return text

    @abstractmethod
    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Split content into chunks.

        Args:
            content: Content to chunk
            source_path: Source file path for reference
            file_type: File type for specialized handling
            metadata: Additional metadata to include

        Returns:
            List of Chunk objects
        """

    @property
    @abstractmethod
    def chunk_type(self) -> ChunkType:
        """Get the chunk type this chunker produces."""

    def _create_chunk(
        self,
        content: str,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Chunk:
        """Build a Chunk for *content*, computing its token count."""
        return Chunk(
            content=content,
            chunk_type=self.chunk_type,
            file_type=file_type,
            source_path=source_path,
            start_line=start_line,
            end_line=end_line,
            metadata=metadata or {},
            token_count=self.count_tokens(content),
        )
class ChunkerFactory:
    """
    Factory for creating appropriate chunkers.

    Lazily instantiates one chunker per strategy and selects the best
    one based on an explicit chunk type, a file type, or a file
    extension looked up in FILE_EXTENSION_MAP.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize factory."""
        self._settings = settings or get_settings()
        # Cache of lazily created chunkers, keyed by strategy name.
        self._chunkers: dict[str, BaseChunker] = {}

    def _get_code_chunker(self) -> "BaseChunker":
        """Get or create the code chunker (cached)."""
        from chunking.code import CodeChunker

        if "code" not in self._chunkers:
            self._chunkers["code"] = CodeChunker(
                chunk_size=self._settings.code_chunk_size,
                chunk_overlap=self._settings.code_chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers["code"]

    def _get_markdown_chunker(self) -> "BaseChunker":
        """Get or create the markdown chunker (cached)."""
        from chunking.markdown import MarkdownChunker

        if "markdown" not in self._chunkers:
            self._chunkers["markdown"] = MarkdownChunker(
                chunk_size=self._settings.markdown_chunk_size,
                chunk_overlap=self._settings.markdown_chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers["markdown"]

    def _get_text_chunker(self) -> "BaseChunker":
        """Get or create the text chunker (cached)."""
        from chunking.text import TextChunker

        if "text" not in self._chunkers:
            self._chunkers["text"] = TextChunker(
                chunk_size=self._settings.text_chunk_size,
                chunk_overlap=self._settings.text_chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers["text"]

    def get_chunker(
        self,
        file_type: FileType | None = None,
        chunk_type: ChunkType | None = None,
    ) -> BaseChunker:
        """
        Get appropriate chunker for content type.

        Args:
            file_type: File type to chunk
            chunk_type: Explicit chunk type to use (takes precedence)

        Returns:
            Appropriate chunker instance
        """
        # An explicitly requested chunk type always wins.
        if chunk_type:
            if chunk_type == ChunkType.CODE:
                return self._get_code_chunker()
            if chunk_type == ChunkType.MARKDOWN:
                return self._get_markdown_chunker()
            return self._get_text_chunker()
        # Otherwise, infer from file type.
        if file_type:
            if file_type == FileType.MARKDOWN:
                return self._get_markdown_chunker()
            if file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
                return self._get_text_chunker()
            # Everything else is treated as source code.
            return self._get_code_chunker()
        # Default to text chunker.
        return self._get_text_chunker()

    def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
        """
        Get chunker based on file path extension.

        Args:
            source_path: File path to chunk

        Returns:
            Tuple of (chunker, file_type); file_type is None when the
            extension is missing or not in FILE_EXTENSION_MAP.
        """
        # Extract the last dotted extension, lowercased.
        ext = ""
        if "." in source_path:
            ext = "." + source_path.rsplit(".", 1)[-1].lower()
        file_type = FILE_EXTENSION_MAP.get(ext)
        chunker = self.get_chunker(file_type=file_type)
        return chunker, file_type

    def chunk_content(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        chunk_type: ChunkType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk content using the appropriate strategy.

        Args:
            content: Content to chunk
            source_path: Source file path (used to infer file_type)
            chunk_type: Explicit chunk type; takes precedence over any
                strategy inferred from file_type/source_path
            file_type: File type
            metadata: Additional metadata

        Returns:
            List of chunks

        Raises:
            ChunkingError: if the selected chunker fails
        """
        # Infer the file type from the path when not given. An explicit
        # chunk_type still takes precedence over the inferred strategy
        # (previously chunk_type was silently ignored on this path).
        if source_path and not file_type:
            chunker, file_type = self.get_chunker_for_path(source_path)
            if chunk_type:
                chunker = self.get_chunker(file_type=file_type, chunk_type=chunk_type)
        else:
            chunker = self.get_chunker(file_type=file_type, chunk_type=chunk_type)
        try:
            chunks = chunker.chunk(
                content=content,
                source_path=source_path,
                file_type=file_type,
                metadata=metadata,
            )
            logger.debug(
                f"Chunked content into {len(chunks)} chunks "
                f"(type={chunker.chunk_type.value})"
            )
            return chunks
        except Exception as e:
            logger.error(f"Chunking error: {e}")
            # Chain the original exception for full tracebacks.
            raise ChunkingError(
                message=f"Failed to chunk content: {e}",
                cause=e,
            ) from e
# Process-wide ChunkerFactory singleton, created lazily on first access.
_chunker_factory: ChunkerFactory | None = None


def get_chunker_factory() -> ChunkerFactory:
    """Get the global chunker factory instance."""
    global _chunker_factory
    if _chunker_factory is not None:
        return _chunker_factory
    _chunker_factory = ChunkerFactory()
    return _chunker_factory


def reset_chunker_factory() -> None:
    """Reset the global chunker factory (for testing)."""
    global _chunker_factory
    _chunker_factory = None

View File

@@ -0,0 +1,410 @@
"""
Code-aware chunking implementation.
Provides intelligent chunking for source code that respects
function/class boundaries and preserves context.
"""
import logging
import re
from typing import Any
from chunking.base import BaseChunker
from config import Settings
from models import Chunk, ChunkType, FileType
logger = logging.getLogger(__name__)
# Language-specific patterns for detecting function/class definitions
LANGUAGE_PATTERNS: dict[FileType, dict[str, re.Pattern[str]]] = {
    # Python: defs (incl. async), classes, and decorators so decorator
    # lines can be kept together with the definition they annotate.
    FileType.PYTHON: {
        "function": re.compile(r"^(\s*)(async\s+)?def\s+\w+", re.MULTILINE),
        "class": re.compile(r"^(\s*)class\s+\w+", re.MULTILINE),
        "decorator": re.compile(r"^(\s*)@\w+", re.MULTILINE),
    },
    # JavaScript: function declarations, const/let/var function expressions,
    # classes, and arrow-function assignments.
    FileType.JAVASCRIPT: {
        "function": re.compile(
            r"^(\s*)(export\s+)?(async\s+)?function\s+\w+|"
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?\(",
            re.MULTILINE,
        ),
        "class": re.compile(r"^(\s*)(export\s+)?class\s+\w+", re.MULTILINE),
        "arrow": re.compile(
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?(\([^)]*\)|[^=])\s*=>",
            re.MULTILINE,
        ),
    },
    # TypeScript: adds interface and type-alias declarations.
    FileType.TYPESCRIPT: {
        "function": re.compile(
            r"^(\s*)(export\s+)?(async\s+)?function\s+\w+|"
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*[:<]",
            re.MULTILINE,
        ),
        "class": re.compile(r"^(\s*)(export\s+)?class\s+\w+", re.MULTILINE),
        "interface": re.compile(r"^(\s*)(export\s+)?interface\s+\w+", re.MULTILINE),
        "type": re.compile(r"^(\s*)(export\s+)?type\s+\w+", re.MULTILINE),
    },
    # Go: funcs (incl. receiver methods), struct and interface types.
    # Go top-level declarations are unindented, hence no leading (\s*).
    FileType.GO: {
        "function": re.compile(r"^func\s+(\([^)]+\)\s+)?\w+", re.MULTILINE),
        "struct": re.compile(r"^type\s+\w+\s+struct", re.MULTILINE),
        "interface": re.compile(r"^type\s+\w+\s+interface", re.MULTILINE),
    },
    # Rust: fns (incl. pub/async), structs, impl blocks, and traits.
    FileType.RUST: {
        "function": re.compile(r"^(\s*)(pub\s+)?(async\s+)?fn\s+\w+", re.MULTILINE),
        "struct": re.compile(r"^(\s*)(pub\s+)?struct\s+\w+", re.MULTILINE),
        "impl": re.compile(r"^(\s*)impl\s+", re.MULTILINE),
        "trait": re.compile(r"^(\s*)(pub\s+)?trait\s+\w+", re.MULTILINE),
    },
    # Java: methods, classes, and interfaces.
    # NOTE(review): the method pattern ("type name(") is loose and can also
    # match local declarations or calls at line start — acceptable here since
    # matches only serve as soft split points, not as a parser.
    FileType.JAVA: {
        "method": re.compile(
            r"^(\s*)(public|private|protected)?\s*(static)?\s*\w+\s+\w+\s*\(",
            re.MULTILINE,
        ),
        "class": re.compile(
            r"^(\s*)(public|private|protected)?\s*(abstract)?\s*class\s+\w+",
            re.MULTILINE,
        ),
        "interface": re.compile(
            r"^(\s*)(public|private|protected)?\s*interface\s+\w+",
            re.MULTILINE,
        ),
    },
}
class CodeChunker(BaseChunker):
    """
    Code-aware chunker that respects logical boundaries.

    Features:
    - Detects function/class boundaries via per-language regexes
    - Keeps preceding decorators/comments attached to their definition
    - Splits oversized sections with token overlap
    - Falls back to line-based chunking when no structure is found
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize code chunker."""
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Get chunk type."""
        return ChunkType.CODE

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk code content.

        Tries to respect function/class boundaries, falling back
        to line-based chunking if needed.

        Args:
            content: Source code to chunk
            source_path: Source file path for reference
            file_type: Language; used to select structure regexes
            metadata: Extra metadata copied onto every chunk

        Returns:
            List of Chunk objects (empty for blank content)
        """
        if not content.strip():
            return []
        metadata = metadata or {}
        lines = content.splitlines(keepends=True)
        # Try language-aware chunking when we have regexes for the language.
        if file_type and file_type in LANGUAGE_PATTERNS:
            chunks = self._chunk_by_structure(
                content, lines, file_type, source_path, metadata
            )
            if chunks:
                return chunks
        # Fall back to plain line-based chunking.
        return self._chunk_by_lines(lines, source_path, file_type, metadata)

    def _chunk_by_structure(
        self,
        content: str,
        lines: list[str],
        file_type: FileType,
        source_path: str | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """
        Chunk by detecting code structure (functions, classes).

        Returns an empty list when structure detection isn't useful,
        signalling the caller to fall back to line-based chunking.
        """
        patterns = LANGUAGE_PATTERNS.get(file_type, {})
        if not patterns:
            return []
        # Find all structure boundaries as (0-based line number, type) pairs.
        boundaries: list[tuple[int, str]] = []
        for struct_type, pattern in patterns.items():
            for match in pattern.finditer(content):
                # Convert character position to line number.
                line_num = content[: match.start()].count("\n")
                boundaries.append((line_num, struct_type))
        if not boundaries:
            return []
        boundaries.sort(key=lambda x: x[0])
        # With very few boundaries in a large file, line-based is better.
        if len(boundaries) < 3 and len(lines) > 50:
            return []
        chunks: list[Chunk] = []
        current_start = 0
        # Structure type of the section currently being accumulated
        # (None for any preamble before the first boundary).
        current_type: str | None = None
        for line_num, struct_type in boundaries:
            if line_num <= current_start:
                # Boundary at (or before) the current section start — e.g. a
                # decorator and its function both matching. Keep accumulating;
                # just record a type if we don't have one yet.
                current_type = current_type or struct_type
                continue
            # Split *before* any decorators/comments that belong to the
            # upcoming definition so they stay attached to it. (Bug fix: the
            # original computed this context start but never used it, which
            # also emitted Python decorators as stand-alone one-line chunks.)
            split_at = max(self._find_context_start(lines, line_num), current_start)
            if split_at > current_start:
                self._emit_section(
                    lines[current_start:split_at],
                    current_start,
                    source_path,
                    file_type,
                    self._section_metadata(metadata, current_type),
                    chunks,
                )
            current_start = split_at
            current_type = struct_type
        # Handle remaining content after the last boundary.
        if current_start < len(lines):
            self._emit_section(
                lines[current_start:],
                current_start,
                source_path,
                file_type,
                self._section_metadata(metadata, current_type),
                chunks,
            )
        return chunks

    @staticmethod
    def _section_metadata(
        metadata: dict[str, Any], struct_type: str | None
    ) -> dict[str, Any]:
        """Return metadata labelled with the section's own structure type.

        (Bug fix: sections were previously labelled with the type of the
        *next* structure boundary rather than their own.)
        """
        if struct_type:
            return {**metadata, "structure_type": struct_type}
        return metadata

    def _emit_section(
        self,
        chunk_lines: list[str],
        start: int,
        source_path: str | None,
        file_type: FileType | None,
        section_metadata: dict[str, Any],
        chunks: list[Chunk],
    ) -> None:
        """Append chunk(s) for one structural section, splitting if oversized."""
        chunk_content = "".join(chunk_lines)
        if not chunk_content.strip():
            return
        # Allow 50% slack before splitting, to avoid shredding sections
        # that are only slightly over the target size.
        if self.count_tokens(chunk_content) > self.chunk_size * 1.5:
            chunks.extend(
                self._split_large_chunk(
                    chunk_lines, start, source_path, file_type, section_metadata
                )
            )
        else:
            chunks.append(
                self._create_chunk(
                    content=chunk_content.rstrip(),
                    source_path=source_path,
                    start_line=start + 1,
                    end_line=start + len(chunk_lines),
                    file_type=file_type,
                    metadata=section_metadata,
                )
            )

    def _find_context_start(self, lines: list[str], line_num: int) -> int:
        """Find the start of context (decorators, comments) before *line_num*."""
        start = line_num
        # Look backwards (at most 10 lines) over decorator/comment prefixes.
        for i in range(line_num - 1, max(0, line_num - 10), -1):
            line = lines[i].strip()
            if not line:
                continue
            if line.startswith(("#", "//", "/*", "*", "@", "'")):
                start = i
            else:
                break
        return start

    def _split_large_chunk(
        self,
        chunk_lines: list[str],
        base_line: int,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Split a large chunk into smaller pieces with token overlap."""
        chunks: list[Chunk] = []
        current_lines: list[str] = []
        current_tokens = 0
        chunk_start = 0
        for i, line in enumerate(chunk_lines):
            line_tokens = self.count_tokens(line)
            if current_tokens + line_tokens > self.chunk_size and current_lines:
                chunk_content = "".join(current_lines).rstrip()
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=base_line + chunk_start + 1,
                        end_line=base_line + i,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
                # Carry trailing lines into the next chunk until we have
                # at least chunk_overlap tokens of context.
                overlap_tokens = 0
                overlap_lines: list[str] = []
                for j in range(len(current_lines) - 1, -1, -1):
                    overlap_tokens += self.count_tokens(current_lines[j])
                    if overlap_tokens >= self.chunk_overlap:
                        overlap_lines = current_lines[j:]
                        break
                current_lines = overlap_lines
                current_tokens = sum(self.count_tokens(line) for line in current_lines)
                chunk_start = i - len(overlap_lines)
            current_lines.append(line)
            current_tokens += line_tokens
        # Final chunk with whatever is left.
        if current_lines:
            chunk_content = "".join(current_lines).rstrip()
            if chunk_content.strip():
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=base_line + chunk_start + 1,
                        end_line=base_line + len(chunk_lines),
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
        return chunks

    def _chunk_by_lines(
        self,
        lines: list[str],
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by lines with token overlap (structure-agnostic fallback)."""
        chunks: list[Chunk] = []
        current_lines: list[str] = []
        current_tokens = 0
        chunk_start = 0
        for i, line in enumerate(lines):
            line_tokens = self.count_tokens(line)
            # A single line larger than a whole chunk gets flushed, truncated,
            # and emitted on its own (marked "truncated" in metadata).
            if line_tokens > self.chunk_size:
                if current_lines:
                    chunk_content = "".join(current_lines).rstrip()
                    if chunk_content.strip():
                        chunks.append(
                            self._create_chunk(
                                content=chunk_content,
                                source_path=source_path,
                                start_line=chunk_start + 1,
                                end_line=i,
                                file_type=file_type,
                                metadata=metadata,
                            )
                        )
                    current_lines = []
                    current_tokens = 0
                    chunk_start = i
                truncated = self.truncate_to_tokens(line, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated.rstrip(),
                        source_path=source_path,
                        start_line=i + 1,
                        end_line=i + 1,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                chunk_start = i + 1
                continue
            if current_tokens + line_tokens > self.chunk_size and current_lines:
                chunk_content = "".join(current_lines).rstrip()
                if chunk_content.strip():
                    chunks.append(
                        self._create_chunk(
                            content=chunk_content,
                            source_path=source_path,
                            start_line=chunk_start + 1,
                            end_line=i,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                # Collect trailing lines (front-inserted to preserve order)
                # up to — but not exceeding — chunk_overlap tokens.
                overlap_tokens = 0
                overlap_lines: list[str] = []
                for j in range(len(current_lines) - 1, -1, -1):
                    line_tok = self.count_tokens(current_lines[j])
                    if overlap_tokens + line_tok > self.chunk_overlap:
                        break
                    overlap_lines.insert(0, current_lines[j])
                    overlap_tokens += line_tok
                current_lines = overlap_lines
                current_tokens = overlap_tokens
                chunk_start = i - len(overlap_lines)
            current_lines.append(line)
            current_tokens += line_tokens
        # Final chunk with whatever is left.
        if current_lines:
            chunk_content = "".join(current_lines).rstrip()
            if chunk_content.strip():
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=chunk_start + 1,
                        end_line=len(lines),
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
        return chunks

View File

@@ -0,0 +1,483 @@
"""
Markdown-aware chunking implementation.
Provides intelligent chunking for markdown content that respects
heading hierarchy and preserves document structure.
"""
import logging
import re
from typing import Any
from chunking.base import BaseChunker
from config import Settings
from models import Chunk, ChunkType, FileType
logger = logging.getLogger(__name__)
# Patterns for markdown elements
# ATX heading: 1-6 '#' marks (captured for level), then the heading text.
HEADING_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
# Fenced code block delimiter and horizontal rule.
# NOTE(review): these two patterns are not referenced by MarkdownChunker
# below (it string-matches "```" directly) — confirm whether they are used
# elsewhere or can be removed.
CODE_BLOCK_PATTERN = re.compile(r"^```", re.MULTILINE)
HR_PATTERN = re.compile(r"^(-{3,}|_{3,}|\*{3,})$", re.MULTILINE)
class MarkdownChunker(BaseChunker):
    """
    Markdown-aware chunker that respects document structure.
    Features:
    - Respects heading hierarchy
    - Preserves heading context in chunks
    - Handles code blocks as units
    - Maintains list continuity where possible
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize markdown chunker."""
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Get chunk type."""
        return ChunkType.MARKDOWN

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk markdown content.
        Splits on heading boundaries and preserves heading context.

        Args:
            content: Markdown text to chunk
            source_path: Source file path for reference
            file_type: File type (defaults to MARKDOWN)
            metadata: Additional metadata copied onto every chunk

        Returns:
            List of Chunk objects (empty for blank content)
        """
        if not content.strip():
            return []
        metadata = metadata or {}
        file_type = file_type or FileType.MARKDOWN
        # Split content into sections by headings
        sections = self._split_by_headings(content)
        if not sections:
            # No headings, chunk as plain text
            return self._chunk_text_block(
                content, source_path, file_type, metadata, []
            )
        chunks: list[Chunk] = []
        # Stack of headings enclosing the current section, used to build a
        # breadcrumb-style context string (e.g. "Intro > Setup > Install").
        heading_stack: list[tuple[int, str]] = []  # (level, text)
        for section in sections:
            heading_level = section.get("level", 0)
            heading_text = section.get("heading", "")
            section_content = section.get("content", "")
            start_line = section.get("start_line", 1)
            end_line = section.get("end_line", 1)
            # Update heading stack
            if heading_level > 0:
                # Pop headings of equal or higher level
                while heading_stack and heading_stack[-1][0] >= heading_level:
                    heading_stack.pop()
                heading_stack.append((heading_level, heading_text))
            # Build heading context prefix
            heading_context = " > ".join(h[1] for h in heading_stack)
            section_chunks = self._chunk_section(
                content=section_content,
                heading_context=heading_context,
                heading_level=heading_level,
                heading_text=heading_text,
                start_line=start_line,
                end_line=end_line,
                source_path=source_path,
                file_type=file_type,
                metadata=metadata,
            )
            chunks.extend(section_chunks)
        return chunks

    def _split_by_headings(self, content: str) -> list[dict[str, Any]]:
        """Split content into sections by headings.

        Returns dicts with keys: level, heading, content, start_line,
        end_line. Level 0 marks content before the first heading.
        Headings inside fenced code blocks are ignored.
        """
        sections: list[dict[str, Any]] = []
        lines = content.split("\n")
        current_section: dict[str, Any] = {
            "level": 0,
            "heading": "",
            "content": "",
            "start_line": 1,
            "end_line": 1,
        }
        current_lines: list[str] = []
        in_code_block = False
        for i, line in enumerate(lines):
            # Track code blocks (fence toggles state; unbalanced fences
            # would leave the rest of the document treated as code).
            if line.strip().startswith("```"):
                in_code_block = not in_code_block
                current_lines.append(line)
                continue
            # Skip heading detection in code blocks
            if in_code_block:
                current_lines.append(line)
                continue
            # Check for heading
            heading_match = HEADING_PATTERN.match(line)
            if heading_match:
                # Save previous section
                if current_lines:
                    current_section["content"] = "\n".join(current_lines)
                    current_section["end_line"] = i
                    if current_section["content"].strip():
                        sections.append(current_section)
                # Start new section
                level = len(heading_match.group(1))
                heading_text = heading_match.group(2).strip()
                current_section = {
                    "level": level,
                    "heading": heading_text,
                    "content": "",
                    "start_line": i + 1,
                    "end_line": i + 1,
                }
                current_lines = [line]
            else:
                current_lines.append(line)
        # Save final section
        if current_lines:
            current_section["content"] = "\n".join(current_lines)
            current_section["end_line"] = len(lines)
            if current_section["content"].strip():
                sections.append(current_section)
        return sections

    def _chunk_section(
        self,
        content: str,
        heading_context: str,
        heading_level: int,
        heading_text: str,
        start_line: int,
        end_line: int,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk a single section of markdown.

        Sections that fit within chunk_size become a single chunk carrying
        the heading context in metadata; larger sections are split
        paragraph-wise via _chunk_text_block.
        """
        if not content.strip():
            return []
        token_count = self.count_tokens(content)
        # If section fits in one chunk, return as-is
        if token_count <= self.chunk_size:
            section_metadata = {
                **metadata,
                "heading_context": heading_context,
                "heading_level": heading_level,
                "heading_text": heading_text,
            }
            return [
                self._create_chunk(
                    content=content.strip(),
                    source_path=source_path,
                    start_line=start_line,
                    end_line=end_line,
                    file_type=file_type,
                    metadata=section_metadata,
                )
            ]
        # Need to split - try to split on paragraphs first
        return self._chunk_text_block(
            content,
            source_path,
            file_type,
            {
                **metadata,
                "heading_context": heading_context,
                "heading_level": heading_level,
                "heading_text": heading_text,
            },
            _heading_stack=[(heading_level, heading_text)] if heading_text else [],
            base_line=start_line,
        )

    def _chunk_text_block(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
        _heading_stack: list[tuple[int, str]],
        base_line: int = 1,
    ) -> list[Chunk]:
        """Chunk a block of text by paragraphs.

        Args:
            content: Text to split
            source_path: Source file path for reference
            file_type: File type recorded on each chunk
            metadata: Metadata copied onto every chunk
            _heading_stack: NOTE(review): accepted but never read here —
                presumably reserved for future heading-prefixing; verify.
            base_line: 1-based line number of *content*'s first line in
                the original document.
        """
        # Split into paragraphs (separated by blank lines)
        paragraphs = self._split_into_paragraphs(content)
        if not paragraphs:
            return []
        chunks: list[Chunk] = []
        current_content: list[str] = []
        current_tokens = 0
        chunk_start_line = base_line
        for para_info in paragraphs:
            para_content = para_info["content"]
            para_tokens = para_info["tokens"]
            # NOTE(review): start_line/end_line here are 0-based offsets
            # within *content* — confirm the base_line arithmetic below is
            # consistent with callers.
            para_start = para_info["start_line"]
            # Handle very large paragraphs
            if para_tokens > self.chunk_size:
                # Flush current content
                if current_content:
                    chunk_text = "\n\n".join(current_content)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text.strip(),
                            source_path=source_path,
                            start_line=chunk_start_line,
                            end_line=base_line + para_start - 1,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_content = []
                    current_tokens = 0
                # Split large paragraph by sentences/lines
                sub_chunks = self._split_large_paragraph(
                    para_content,
                    source_path,
                    file_type,
                    metadata,
                    base_line + para_start,
                )
                chunks.extend(sub_chunks)
                chunk_start_line = base_line + para_info["end_line"] + 1
                continue
            # Check if adding this paragraph exceeds limit
            if current_tokens + para_tokens > self.chunk_size and current_content:
                # Create chunk
                chunk_text = "\n\n".join(current_content)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text.strip(),
                        source_path=source_path,
                        start_line=chunk_start_line,
                        end_line=base_line + para_start - 1,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
                # Overlap: include last paragraph if it fits
                if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
                    current_content = [current_content[-1]]
                    current_tokens = self.count_tokens(current_content[-1])
                else:
                    current_content = []
                    current_tokens = 0
                chunk_start_line = base_line + para_start
            current_content.append(para_content)
            current_tokens += para_tokens
        # Final chunk
        if current_content:
            chunk_text = "\n\n".join(current_content)
            end_line_num = base_line + (paragraphs[-1]["end_line"] if paragraphs else 0)
            chunks.append(
                self._create_chunk(
                    content=chunk_text.strip(),
                    source_path=source_path,
                    start_line=chunk_start_line,
                    end_line=end_line_num,
                    file_type=file_type,
                    metadata=metadata,
                )
            )
        return chunks

    def _split_into_paragraphs(self, content: str) -> list[dict[str, Any]]:
        """Split content into paragraphs with metadata.

        Paragraphs are separated by blank lines; a fenced code block is
        kept as a single paragraph. Returned dicts carry content, token
        count, and 0-based start_line/end_line offsets within *content*.
        """
        paragraphs: list[dict[str, Any]] = []
        lines = content.split("\n")
        current_para: list[str] = []
        para_start = 0
        in_code_block = False
        for i, line in enumerate(lines):
            # Track code blocks (keep them as single units)
            if line.strip().startswith("```"):
                if in_code_block:
                    # End of code block
                    current_para.append(line)
                    in_code_block = False
                else:
                    # Start of code block - save previous paragraph
                    if current_para and any(p.strip() for p in current_para):
                        para_content = "\n".join(current_para)
                        paragraphs.append({
                            "content": para_content,
                            "tokens": self.count_tokens(para_content),
                            "start_line": para_start,
                            "end_line": i - 1,
                        })
                    current_para = [line]
                    para_start = i
                    in_code_block = True
                continue
            if in_code_block:
                current_para.append(line)
                continue
            # Empty line indicates paragraph break
            if not line.strip():
                if current_para and any(p.strip() for p in current_para):
                    para_content = "\n".join(current_para)
                    paragraphs.append({
                        "content": para_content,
                        "tokens": self.count_tokens(para_content),
                        "start_line": para_start,
                        "end_line": i - 1,
                    })
                current_para = []
                para_start = i + 1
            else:
                # First non-blank line of a new paragraph fixes its start.
                if not current_para:
                    para_start = i
                current_para.append(line)
        # Final paragraph
        if current_para and any(p.strip() for p in current_para):
            para_content = "\n".join(current_para)
            paragraphs.append({
                "content": para_content,
                "tokens": self.count_tokens(para_content),
                "start_line": para_start,
                "end_line": len(lines) - 1,
            })
        return paragraphs

    def _split_large_paragraph(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Split a large paragraph into smaller chunks by sentences.

        All resulting chunks report start_line == end_line == base_line,
        since sentence positions within the paragraph are not tracked.
        """
        # Try splitting by sentences
        sentences = self._split_into_sentences(content)
        chunks: list[Chunk] = []
        current_content: list[str] = []
        current_tokens = 0
        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)
            # If single sentence is too large, truncate
            if sentence_tokens > self.chunk_size:
                if current_content:
                    chunk_text = " ".join(current_content)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text.strip(),
                            source_path=source_path,
                            start_line=base_line,
                            end_line=base_line,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_content = []
                    current_tokens = 0
                truncated = self.truncate_to_tokens(sentence, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated.strip(),
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                continue
            if current_tokens + sentence_tokens > self.chunk_size and current_content:
                chunk_text = " ".join(current_content)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text.strip(),
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
                # Overlap with last sentence
                if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
                    current_content = [current_content[-1]]
                    current_tokens = self.count_tokens(current_content[-1])
                else:
                    current_content = []
                    current_tokens = 0
            current_content.append(sentence)
            current_tokens += sentence_tokens
        # Final chunk
        if current_content:
            chunk_text = " ".join(current_content)
            chunks.append(
                self._create_chunk(
                    content=chunk_text.strip(),
                    source_path=source_path,
                    start_line=base_line,
                    end_line=base_line,
                    file_type=file_type,
                    metadata=metadata,
                )
            )
        return chunks

    def _split_into_sentences(self, text: str) -> list[str]:
        """Split text into sentences."""
        # Simple sentence splitting on common terminators
        # More sophisticated splitting could use nltk or spacy
        sentence_endings = re.compile(r"(?<=[.!?])\s+")
        sentences = sentence_endings.split(text)
        return [s.strip() for s in sentences if s.strip()]

View File

@@ -0,0 +1,389 @@
"""
Plain text chunking implementation.
Provides simple text chunking with paragraph and sentence
boundary detection.
"""
import logging
import re
from typing import Any
from chunking.base import BaseChunker
from config import Settings
from models import Chunk, ChunkType, FileType
logger = logging.getLogger(__name__)
class TextChunker(BaseChunker):
    """
    Plain text chunker with paragraph awareness.

    Features:
    - Splits on paragraph boundaries
    - Falls back to sentence/word boundaries
    - Configurable overlap for context preservation

    Splitting strategy (in order of preference):
    1. Whole content as a single chunk if it fits within ``chunk_size`` tokens.
    2. Paragraphs (blank-line separated), combined greedily up to the limit.
    3. Sentences, when the content has no paragraph breaks.
    4. Words, as a last resort for a single oversized sentence/paragraph.
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize text chunker.

        Args:
            chunk_size: Target maximum tokens per chunk.
            chunk_overlap: Token budget carried over between adjacent chunks.
            settings: Optional settings override, forwarded to the base class.
        """
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Get chunk type."""
        return ChunkType.TEXT

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk plain text content.

        Tries paragraph boundaries first, then sentences.

        Args:
            content: Raw text to split.
            source_path: Optional originating file path recorded on each chunk.
            file_type: Optional file type recorded on each chunk.
            metadata: Extra metadata attached to every chunk.
                NOTE(review): the same dict object is passed to every chunk —
                assumes _create_chunk copies it or treats it as read-only;
                confirm in the base class.

        Returns:
            List of chunks; empty list for blank/whitespace-only input.
        """
        if not content.strip():
            return []
        metadata = metadata or {}
        # Fast path: the whole content already fits into one chunk.
        total_tokens = self.count_tokens(content)
        if total_tokens <= self.chunk_size:
            return [
                self._create_chunk(
                    content=content.strip(),
                    source_path=source_path,
                    start_line=1,
                    end_line=content.count("\n") + 1,
                    file_type=file_type,
                    metadata=metadata,
                )
            ]
        # Try paragraph-based chunking
        paragraphs = self._split_paragraphs(content)
        if len(paragraphs) > 1:
            return self._chunk_by_paragraphs(
                paragraphs, source_path, file_type, metadata
            )
        # Fall back to sentence-based chunking
        return self._chunk_by_sentences(
            content, source_path, file_type, metadata
        )

    def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
        """Split content into paragraphs.

        Returns one dict per non-empty paragraph with keys:
        ``content``, ``tokens``, ``start_line``, ``end_line``.

        NOTE(review): line accounting assumes exactly one blank line between
        paragraphs (``line_num += para_lines + 1``); runs of multiple blank
        lines will make the reported line numbers drift — confirm this
        approximation is acceptable for citation purposes.
        """
        paragraphs: list[dict[str, Any]] = []
        # Split on double newlines (paragraph boundaries)
        raw_paras = re.split(r"\n\s*\n", content)
        line_num = 1
        for para in raw_paras:
            para = para.strip()
            if not para:
                continue
            para_lines = para.count("\n") + 1
            paragraphs.append({
                "content": para,
                "tokens": self.count_tokens(para),
                "start_line": line_num,
                "end_line": line_num + para_lines - 1,
            })
            line_num += para_lines + 1  # +1 for blank line between paragraphs
        return paragraphs

    def _chunk_by_paragraphs(
        self,
        paragraphs: list[dict[str, Any]],
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by combining paragraphs up to size limit.

        Greedily packs consecutive paragraphs into a chunk; when a chunk is
        flushed, the last paragraph may be carried into the next chunk as
        overlap (only if it fits within ``chunk_overlap`` tokens).
        Paragraphs that alone exceed ``chunk_size`` are delegated to
        :meth:`_split_large_text`.
        """
        chunks: list[Chunk] = []
        current_paras: list[str] = []
        current_tokens = 0
        chunk_start = paragraphs[0]["start_line"] if paragraphs else 1
        chunk_end = chunk_start
        for para in paragraphs:
            para_content = para["content"]
            para_tokens = para["tokens"]
            # Handle paragraphs larger than chunk size
            if para_tokens > self.chunk_size:
                # Flush current content
                if current_paras:
                    chunk_text = "\n\n".join(current_paras)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text,
                            source_path=source_path,
                            start_line=chunk_start,
                            end_line=chunk_end,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_paras = []
                    current_tokens = 0
                # Split large paragraph
                sub_chunks = self._split_large_text(
                    para_content,
                    source_path,
                    file_type,
                    metadata,
                    para["start_line"],
                )
                chunks.extend(sub_chunks)
                chunk_start = para["end_line"] + 1
                chunk_end = chunk_start
                continue
            # Check if adding paragraph exceeds limit
            if current_tokens + para_tokens > self.chunk_size and current_paras:
                chunk_text = "\n\n".join(current_paras)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=chunk_start,
                        end_line=chunk_end,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
                # Overlap: keep last paragraph if small enough
                overlap_para = None
                if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
                    overlap_para = current_paras[-1]
                current_paras = [overlap_para] if overlap_para else []
                current_tokens = self.count_tokens(overlap_para) if overlap_para else 0
                # NOTE(review): chunk_start points at the *new* paragraph even
                # when an overlap paragraph is carried over, so the next
                # chunk's line range slightly understates its true start.
                chunk_start = para["start_line"]
            current_paras.append(para_content)
            current_tokens += para_tokens
            chunk_end = para["end_line"]
        # Final chunk
        if current_paras:
            chunk_text = "\n\n".join(current_paras)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=chunk_start,
                    end_line=chunk_end,
                    file_type=file_type,
                    metadata=metadata,
                )
            )
        return chunks

    def _chunk_by_sentences(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by sentences.

        Line numbers are not tracked at sentence granularity: intermediate
        chunks use start_line=end_line=1; only the final chunk records the
        full content's line span. Sentences that alone exceed ``chunk_size``
        are truncated and flagged with ``metadata["truncated"] = True``.
        """
        sentences = self._split_sentences(content)
        if not sentences:
            return []
        chunks: list[Chunk] = []
        current_sentences: list[str] = []
        current_tokens = 0
        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)
            # Handle sentences larger than chunk size
            if sentence_tokens > self.chunk_size:
                if current_sentences:
                    chunk_text = " ".join(current_sentences)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text,
                            source_path=source_path,
                            start_line=1,
                            end_line=1,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_sentences = []
                    current_tokens = 0
                # Truncate large sentence (content beyond chunk_size is dropped)
                truncated = self.truncate_to_tokens(sentence, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated,
                        source_path=source_path,
                        start_line=1,
                        end_line=1,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                continue
            # Check if adding sentence exceeds limit
            if current_tokens + sentence_tokens > self.chunk_size and current_sentences:
                chunk_text = " ".join(current_sentences)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=1,
                        end_line=1,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
                # Overlap: keep last sentence if small enough
                overlap = None
                if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
                    overlap = current_sentences[-1]
                current_sentences = [overlap] if overlap else []
                current_tokens = self.count_tokens(overlap) if overlap else 0
            current_sentences.append(sentence)
            current_tokens += sentence_tokens
        # Final chunk
        if current_sentences:
            chunk_text = " ".join(current_sentences)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=1,
                    end_line=content.count("\n") + 1,
                    file_type=file_type,
                    metadata=metadata,
                )
            )
        return chunks

    def _split_sentences(self, text: str) -> list[str]:
        """Split text into sentences.

        Heuristic: a sentence ends at ``.``/``!``/``?`` followed by
        whitespace and a capital letter, at end of text, or at a newline.
        """
        # Handle common sentence endings
        # This is a simple approach - production might use nltk or spacy
        sentence_pattern = re.compile(
            r"(?<=[.!?])\s+(?=[A-Z])|"  # Standard sentence ending
            r"(?<=[.!?])\s*$|"  # End of text
            r"(?<=\n)\s*(?=\S)"  # Newlines as boundaries
        )
        sentences = sentence_pattern.split(text)
        return [s.strip() for s in sentences if s.strip()]

    def _split_large_text(
        self,
        text: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Split text that exceeds chunk size.

        NOTE(review): when the text splits into multiple sentences,
        ``base_line`` is not forwarded — _chunk_by_sentences resets line
        numbers to 1, so sub-chunk line info loses the paragraph's offset.
        """
        # First try sentences
        sentences = self._split_sentences(text)
        if len(sentences) > 1:
            return self._chunk_by_sentences(
                text, source_path, file_type, metadata
            )
        # Fall back to word-based splitting
        return self._chunk_by_words(
            text, source_path, file_type, metadata, base_line
        )

    def _chunk_by_words(
        self,
        text: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Last resort: chunk by words.

        Packs whitespace-separated words greedily; overlap is rebuilt from
        the tail of the previous chunk, newest-to-oldest, until the
        ``chunk_overlap`` token budget is exhausted. All chunks report
        ``base_line`` as both start and end line.
        """
        words = text.split()
        chunks: list[Chunk] = []
        current_words: list[str] = []
        current_tokens = 0
        for word in words:
            # Count with a trailing space to approximate in-context token cost.
            word_tokens = self.count_tokens(word + " ")
            if current_tokens + word_tokens > self.chunk_size and current_words:
                chunk_text = " ".join(current_words)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )
                # Word overlap: take trailing words while they fit the budget.
                overlap_count = 0
                overlap_words: list[str] = []
                for w in reversed(current_words):
                    w_tokens = self.count_tokens(w + " ")
                    if overlap_count + w_tokens > self.chunk_overlap:
                        break
                    overlap_words.insert(0, w)
                    overlap_count += w_tokens
                current_words = overlap_words
                current_tokens = overlap_count
            current_words.append(word)
            current_tokens += word_tokens
        # Final chunk
        if current_words:
            chunk_text = " ".join(current_words)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=base_line,
                    end_line=base_line,
                    file_type=file_type,
                    metadata=metadata,
                )
            )
        return chunks

View File

@@ -0,0 +1,331 @@
"""
Collection management for Knowledge Base MCP Server.
Provides operations for managing document collections including
ingestion, deletion, and statistics.
"""
import logging
from typing import Any
from chunking.base import ChunkerFactory, get_chunker_factory
from config import Settings, get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from models import (
ChunkType,
CollectionStats,
DeleteRequest,
DeleteResult,
FileType,
IngestRequest,
IngestResult,
ListCollectionsResponse,
)
logger = logging.getLogger(__name__)
class CollectionManager:
    """
    Manages knowledge base collections.

    Handles document ingestion, chunking, embedding generation,
    and collection operations.

    Dependencies (database, embeddings, chunker factory) are resolved
    lazily from module-level singletons unless injected via the
    constructor, which keeps the manager easy to stub in tests.
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
        chunker_factory: ChunkerFactory | None = None,
    ) -> None:
        """Initialize collection manager.

        Args:
            settings: Optional settings override; defaults to global settings.
            database: Optional database manager (injected for tests).
            embeddings: Optional embedding generator (injected for tests).
            chunker_factory: Optional chunker factory (injected for tests).
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings
        self._chunker_factory = chunker_factory

    @property
    def database(self) -> DatabaseManager:
        """Get database manager (lazily resolved from the global singleton)."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator (lazily resolved from the global singleton)."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    @property
    def chunker_factory(self) -> ChunkerFactory:
        """Get chunker factory (lazily resolved from the global singleton)."""
        if self._chunker_factory is None:
            self._chunker_factory = get_chunker_factory()
        return self._chunker_factory

    async def ingest(self, request: IngestRequest) -> IngestResult:
        """
        Ingest content into the knowledge base.

        Chunks the content, generates embeddings, and stores them.
        Pipeline: chunk -> embed (batch) -> store (one row per chunk).

        Failures are reported via ``IngestResult(success=False, error=...)``
        rather than raised, so callers always get a structured result.

        Args:
            request: Ingest request with content and options

        Returns:
            Ingest result with created chunk IDs
        """
        try:
            # Chunk the content
            chunks = self.chunker_factory.chunk_content(
                content=request.content,
                source_path=request.source_path,
                file_type=request.file_type,
                chunk_type=request.chunk_type,
                metadata=request.metadata,
            )
            # Empty/whitespace-only content yields no chunks; report success
            # with zero work done rather than an error.
            if not chunks:
                return IngestResult(
                    success=True,
                    chunks_created=0,
                    embeddings_generated=0,
                    source_path=request.source_path,
                    collection=request.collection,
                    chunk_ids=[],
                )
            # Extract chunk contents for embedding
            chunk_texts = [chunk.content for chunk in chunks]
            # Generate embeddings
            embeddings_list = await self.embeddings.generate_batch(
                texts=chunk_texts,
                project_id=request.project_id,
                agent_id=request.agent_id,
            )
            # Store embeddings (strict=True guards against a chunk/embedding
            # count mismatch from the generator).
            chunk_ids: list[str] = []
            for chunk, embedding in zip(chunks, embeddings_list, strict=True):
                # Build metadata with chunk info; per-chunk metadata wins
                # over request-level metadata on key collisions.
                chunk_metadata = {
                    **request.metadata,
                    **chunk.metadata,
                    "token_count": chunk.token_count,
                }
                chunk_id = await self.database.store_embedding(
                    project_id=request.project_id,
                    collection=request.collection,
                    content=chunk.content,
                    embedding=embedding,
                    chunk_type=chunk.chunk_type,
                    source_path=chunk.source_path or request.source_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    file_type=chunk.file_type or request.file_type,
                    metadata=chunk_metadata,
                )
                chunk_ids.append(chunk_id)
            logger.info(
                f"Ingested {len(chunks)} chunks into collection '{request.collection}' "
                f"for project {request.project_id}"
            )
            return IngestResult(
                success=True,
                chunks_created=len(chunks),
                embeddings_generated=len(embeddings_list),
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=chunk_ids,
            )
        except Exception as e:
            # Deliberately broad: ingestion errors become a failed result,
            # not an exception, so MCP tool callers get structured output.
            logger.error(f"Ingest error: {e}")
            return IngestResult(
                success=False,
                chunks_created=0,
                embeddings_generated=0,
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=[],
                error=str(e),
            )

    async def delete(self, request: DeleteRequest) -> DeleteResult:
        """
        Delete content from the knowledge base.

        Supports deletion by source path, collection, or chunk IDs.
        Precedence when multiple targets are given:
        chunk_ids > source_path > collection (only the first match is used).

        Args:
            request: Delete request with target specification

        Returns:
            Delete result with count of deleted chunks
        """
        try:
            deleted_count = 0
            if request.chunk_ids:
                # Delete specific chunks
                deleted_count = await self.database.delete_by_ids(
                    project_id=request.project_id,
                    chunk_ids=request.chunk_ids,
                )
            elif request.source_path:
                # Delete by source path
                deleted_count = await self.database.delete_by_source(
                    project_id=request.project_id,
                    source_path=request.source_path,
                    collection=request.collection,
                )
            elif request.collection:
                # Delete entire collection
                deleted_count = await self.database.delete_collection(
                    project_id=request.project_id,
                    collection=request.collection,
                )
            else:
                return DeleteResult(
                    success=False,
                    chunks_deleted=0,
                    error="Must specify chunk_ids, source_path, or collection",
                )
            logger.info(
                f"Deleted {deleted_count} chunks for project {request.project_id}"
            )
            return DeleteResult(
                success=True,
                chunks_deleted=deleted_count,
            )
        except Exception as e:
            # Broad by design: failures surface as a structured result.
            logger.error(f"Delete error: {e}")
            return DeleteResult(
                success=False,
                chunks_deleted=0,
                error=str(e),
            )

    async def list_collections(self, project_id: str) -> ListCollectionsResponse:
        """
        List all collections for a project.

        Args:
            project_id: Project ID

        Returns:
            List of collection info
        """
        collections = await self.database.list_collections(project_id)
        return ListCollectionsResponse(
            project_id=project_id,
            collections=collections,
            total_collections=len(collections),
        )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """
        Get statistics for a collection.

        Args:
            project_id: Project ID
            collection: Collection name

        Returns:
            Collection statistics
        """
        return await self.database.get_collection_stats(project_id, collection)

    async def update_document(
        self,
        project_id: str,
        agent_id: str,
        source_path: str,
        content: str,
        collection: str = "default",
        chunk_type: ChunkType = ChunkType.TEXT,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> IngestResult:
        """
        Update a document by replacing existing chunks.

        Deletes existing chunks for the source path and ingests new content.

        NOTE(review): delete-then-ingest is not atomic — concurrent searches
        may briefly see no chunks for this source, and a failed ingest leaves
        the old chunks deleted. Confirm this is acceptable.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            source_path: Source file path
            content: New content
            collection: Collection name
            chunk_type: Type of content
            file_type: File type for code chunking
            metadata: Additional metadata

        Returns:
            Ingest result
        """
        # First delete existing chunks for this source
        await self.database.delete_by_source(
            project_id=project_id,
            source_path=source_path,
            collection=collection,
        )
        # Then ingest new content
        request = IngestRequest(
            project_id=project_id,
            agent_id=agent_id,
            content=content,
            source_path=source_path,
            collection=collection,
            chunk_type=chunk_type,
            file_type=file_type,
            metadata=metadata or {},
        )
        return await self.ingest(request)

    async def cleanup_expired(self) -> int:
        """
        Remove expired embeddings from all collections.

        Returns:
            Number of embeddings removed
        """
        return await self.database.cleanup_expired()
# Module-level singleton, created lazily on first access.
_collection_manager: CollectionManager | None = None


def get_collection_manager() -> CollectionManager:
    """Return the process-wide CollectionManager, constructing it on first use."""
    global _collection_manager
    if _collection_manager is None:
        _collection_manager = CollectionManager()
    return _collection_manager


def reset_collection_manager() -> None:
    """Drop the cached CollectionManager so the next access rebuilds it (test hook)."""
    global _collection_manager
    _collection_manager = None

View File

@@ -0,0 +1,138 @@
"""
Configuration for Knowledge Base MCP Server.
Uses pydantic-settings for environment variable loading.
"""
import os
from pydantic import Field
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application settings loaded from environment.

    Every field can be overridden via an environment variable with the
    ``KB_`` prefix (see ``model_config``), e.g. ``KB_PORT=9000``.
    """

    # Server settings
    # NOTE: 0.0.0.0 binds all interfaces — intended for container use.
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8002, description="Server port")
    debug: bool = Field(default=False, description="Debug mode")
    # Database settings
    database_url: str = Field(
        default="postgresql://postgres:postgres@localhost:5432/syndarix",
        description="PostgreSQL connection URL with pgvector extension",
    )
    database_pool_size: int = Field(default=10, description="Connection pool size")
    database_pool_max_overflow: int = Field(
        default=20, description="Max overflow connections"
    )
    # Redis settings
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL",
    )
    # LLM Gateway settings (for embeddings)
    llm_gateway_url: str = Field(
        default="http://localhost:8001",
        description="LLM Gateway MCP server URL",
    )
    # Embedding settings
    # NOTE(review): text-embedding-3-large natively produces 3072 dimensions;
    # the 1536 default below assumes the gateway requests reduced dimensions.
    # The DB schema also hardcodes vector(1536) — confirm model, this setting,
    # and the schema stay in sync.
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="Default embedding model",
    )
    embedding_dimension: int = Field(
        default=1536,
        description="Embedding vector dimension",
    )
    embedding_batch_size: int = Field(
        default=100,
        description="Max texts per embedding batch",
    )
    embedding_cache_ttl: int = Field(
        default=86400,
        description="Embedding cache TTL in seconds (24 hours)",
    )
    # Chunking settings (sizes/overlaps are in tokens, per chunker type)
    code_chunk_size: int = Field(
        default=500,
        description="Target tokens per code chunk",
    )
    code_chunk_overlap: int = Field(
        default=50,
        description="Token overlap between code chunks",
    )
    markdown_chunk_size: int = Field(
        default=800,
        description="Target tokens per markdown chunk",
    )
    markdown_chunk_overlap: int = Field(
        default=100,
        description="Token overlap between markdown chunks",
    )
    text_chunk_size: int = Field(
        default=400,
        description="Target tokens per text chunk",
    )
    text_chunk_overlap: int = Field(
        default=50,
        description="Token overlap between text chunks",
    )
    # Search settings
    search_default_limit: int = Field(
        default=10,
        description="Default number of search results",
    )
    search_max_limit: int = Field(
        default=100,
        description="Maximum number of search results",
    )
    semantic_threshold: float = Field(
        default=0.7,
        description="Minimum similarity score for semantic search",
    )
    # Hybrid search weights; defaults sum to 1.0 (presumably intended —
    # nothing here enforces it).
    hybrid_semantic_weight: float = Field(
        default=0.7,
        description="Weight for semantic results in hybrid search",
    )
    hybrid_keyword_weight: float = Field(
        default=0.3,
        description="Weight for keyword results in hybrid search",
    )
    # Storage settings
    embedding_ttl_days: int = Field(
        default=30,
        description="TTL for embedding records in days (0 = no expiry)",
    )
    # KB_ env prefix; unknown env vars are ignored rather than rejected.
    model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"}
# Module-level singleton, built lazily on first access.
_settings: Settings | None = None


def get_settings() -> Settings:
    """Return the process-wide Settings, constructing it on first use."""
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings
def reset_settings() -> None:
    """Drop the cached Settings so the next get_settings() call rebuilds it (test hook)."""
    global _settings
    _settings = None
def is_test_mode() -> bool:
    """Return True when the IS_TEST environment variable holds a truthy value."""
    flag = os.getenv("IS_TEST", "")
    return flag.lower() in {"true", "1", "yes"}

View File

@@ -0,0 +1,774 @@
"""
Database management for Knowledge Base MCP Server.
Handles PostgreSQL connections with pgvector extension for
vector similarity search operations.
"""
import hashlib
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from typing import Any
import asyncpg
from pgvector.asyncpg import register_vector
from config import Settings, get_settings
from exceptions import (
CollectionNotFoundError,
DatabaseConnectionError,
DatabaseQueryError,
ErrorCode,
KnowledgeBaseError,
)
from models import (
ChunkType,
CollectionInfo,
CollectionStats,
FileType,
KnowledgeEmbedding,
)
logger = logging.getLogger(__name__)
class DatabaseManager:
"""
Manages PostgreSQL connections and vector operations.
Uses asyncpg for async operations and pgvector for
vector similarity search.
"""
def __init__(self, settings: Settings | None = None) -> None:
"""Initialize database manager."""
self._settings = settings or get_settings()
self._pool: asyncpg.Pool | None = None # type: ignore[type-arg]
@property
def pool(self) -> asyncpg.Pool: # type: ignore[type-arg]
"""Get connection pool, raising if not initialized."""
if self._pool is None:
raise DatabaseConnectionError("Database pool not initialized")
return self._pool
    async def initialize(self) -> None:
        """Initialize connection pool and create schema.

        Raises:
            DatabaseConnectionError: If pool creation or schema setup fails.
        """
        try:
            self._pool = await asyncpg.create_pool(
                self._settings.database_url,
                min_size=2,
                max_size=self._settings.database_pool_size,
                # Recycle connections idle for more than 5 minutes.
                max_inactive_connection_lifetime=300,
                # Per-connection setup (registers the pgvector codec).
                init=self._init_connection,
            )
            logger.info("Database pool created successfully")
            # Create schema
            await self._create_schema()
            logger.info("Database schema initialized")
        except Exception as e:
            logger.error(f"Failed to initialize database: {e}")
            raise DatabaseConnectionError(
                message=f"Failed to initialize database: {e}",
                cause=e,
            )
async def _init_connection(self, conn: asyncpg.Connection) -> None: # type: ignore[type-arg]
"""Initialize a connection with pgvector support."""
await register_vector(conn)
    async def _create_schema(self) -> None:
        """Create database schema if not exists.

        Idempotent: every statement uses IF NOT EXISTS, so this is safe to
        run on every startup.
        """
        async with self.pool.acquire() as conn:
            # Enable pgvector extension
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            # Create main embeddings table.
            # NOTE(review): vector(1536) hardcodes the embedding dimension —
            # must match Settings.embedding_dimension; confirm they stay in sync.
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS knowledge_embeddings (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    project_id VARCHAR(255) NOT NULL,
                    collection VARCHAR(255) NOT NULL DEFAULT 'default',
                    content TEXT NOT NULL,
                    embedding vector(1536),
                    chunk_type VARCHAR(50) NOT NULL,
                    source_path TEXT,
                    start_line INTEGER,
                    end_line INTEGER,
                    file_type VARCHAR(50),
                    metadata JSONB DEFAULT '{}',
                    content_hash VARCHAR(64),
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    expires_at TIMESTAMPTZ
                )
            """)
            # Create indexes for common queries
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection
                ON knowledge_embeddings(project_id, collection)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_source_path
                ON knowledge_embeddings(project_id, source_path)
            """)
            # Supports the dedup lookup in store_embedding. NOTE(review): this
            # is a plain (non-unique) index, so duplicates are not enforced
            # at the database level.
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash
                ON knowledge_embeddings(project_id, content_hash)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type
                ON knowledge_embeddings(project_id, chunk_type)
            """)
            # Create HNSW index for vector similarity search
            # This dramatically improves search performance
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw
                ON knowledge_embeddings
                USING hnsw (embedding vector_cosine_ops)
                WITH (m = 16, ef_construction = 64)
            """)
            # Create GIN index for full-text search
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts
                ON knowledge_embeddings
                USING gin(to_tsvector('english', content))
            """)
async def close(self) -> None:
"""Close the connection pool."""
if self._pool:
await self._pool.close()
self._pool = None
logger.info("Database pool closed")
@asynccontextmanager
async def acquire(self) -> Any:
"""Acquire a connection from the pool."""
async with self.pool.acquire() as conn:
yield conn
@staticmethod
def compute_content_hash(content: str) -> str:
"""Compute SHA-256 hash of content for deduplication."""
return hashlib.sha256(content.encode()).hexdigest()
    async def store_embedding(
        self,
        project_id: str,
        collection: str,
        content: str,
        embedding: list[float],
        chunk_type: ChunkType,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Store an embedding in the database.

        Deduplicates on (project_id, collection, content_hash): if a row with
        the same content already exists, it is updated in place instead of
        inserting a duplicate.

        NOTE(review): the duplicate check and subsequent insert are not in a
        transaction, so two concurrent ingests of identical content can race
        and both insert — confirm whether that matters here.

        NOTE(review): ``metadata`` is passed as a plain dict for a jsonb
        parameter — this assumes a dict-aware jsonb codec is registered on
        the connection (asyncpg's default expects a JSON string); verify.

        Returns:
            The ID of the stored embedding.

        Raises:
            DatabaseQueryError: On database failure.
        """
        content_hash = self.compute_content_hash(content)
        metadata = metadata or {}
        # Calculate expiration if TTL is set
        expires_at = None
        if self._settings.embedding_ttl_days > 0:
            expires_at = datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )
        try:
            async with self.acquire() as conn:
                # Check for duplicate content
                existing = await conn.fetchval(
                    """
                    SELECT id FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2 AND content_hash = $3
                    """,
                    project_id,
                    collection,
                    content_hash,
                )
                if existing:
                    # Update existing embedding (refreshes TTL and location info)
                    await conn.execute(
                        """
                        UPDATE knowledge_embeddings
                        SET embedding = $1, updated_at = NOW(), expires_at = $2,
                            metadata = $3, source_path = $4, start_line = $5,
                            end_line = $6, file_type = $7
                        WHERE id = $8
                        """,
                        embedding,
                        expires_at,
                        metadata,
                        source_path,
                        start_line,
                        end_line,
                        file_type.value if file_type else None,
                        existing,
                    )
                    logger.debug(f"Updated existing embedding: {existing}")
                    return str(existing)
                # Insert new embedding
                embedding_id = await conn.fetchval(
                    """
                    INSERT INTO knowledge_embeddings
                    (project_id, collection, content, embedding, chunk_type,
                     source_path, start_line, end_line, file_type, metadata,
                     content_hash, expires_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                    RETURNING id
                    """,
                    project_id,
                    collection,
                    content,
                    embedding,
                    chunk_type.value,
                    source_path,
                    start_line,
                    end_line,
                    file_type.value if file_type else None,
                    metadata,
                    content_hash,
                    expires_at,
                )
                logger.debug(f"Stored new embedding: {embedding_id}")
                return str(embedding_id)
        except asyncpg.PostgresError as e:
            logger.error(f"Database error storing embedding: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embedding: {e}",
                cause=e,
            )
    async def store_embeddings_batch(
        self,
        embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]],
    ) -> list[str]:
        """
        Store multiple embeddings in a batch.

        Unlike :meth:`store_embedding`, this does no update-on-duplicate
        pass; rows are inserted one by one within a single acquired
        connection.

        NOTE(review): ``ON CONFLICT DO NOTHING`` has no conflict target and
        the table defines no unique constraint on content_hash, so this does
        NOT actually prevent duplicate content — confirm whether dedup is
        expected here.

        Args:
            embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata)

        Returns:
            List of created embedding IDs. May be shorter than the input if
            any insert is skipped by a conflict.
        """
        if not embeddings:
            return []
        ids = []
        expires_at = None
        if self._settings.embedding_ttl_days > 0:
            expires_at = datetime.now(UTC) + timedelta(
                days=self._settings.embedding_ttl_days
            )
        try:
            async with self.acquire() as conn:
                for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
                    content_hash = self.compute_content_hash(content)
                    # Location info rides inside metadata for the batch API.
                    source_path = metadata.get("source_path")
                    start_line = metadata.get("start_line")
                    end_line = metadata.get("end_line")
                    file_type = metadata.get("file_type")
                    embedding_id = await conn.fetchval(
                        """
                        INSERT INTO knowledge_embeddings
                        (project_id, collection, content, embedding, chunk_type,
                         source_path, start_line, end_line, file_type, metadata,
                         content_hash, expires_at)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                        ON CONFLICT DO NOTHING
                        RETURNING id
                        """,
                        project_id,
                        collection,
                        content,
                        embedding,
                        chunk_type.value,
                        source_path,
                        start_line,
                        end_line,
                        file_type,
                        metadata,
                        content_hash,
                        expires_at,
                    )
                    # fetchval returns None when the insert was skipped.
                    if embedding_id:
                        ids.append(str(embedding_id))
            logger.info(f"Stored {len(ids)} embeddings in batch")
            return ids
        except asyncpg.PostgresError as e:
            logger.error(f"Database error in batch store: {e}")
            raise DatabaseQueryError(
                message=f"Failed to store embeddings batch: {e}",
                cause=e,
            )
async def semantic_search(
self,
project_id: str,
query_embedding: list[float],
collection: str | None = None,
limit: int = 10,
threshold: float = 0.7,
file_types: list[FileType] | None = None,
) -> list[tuple[KnowledgeEmbedding, float]]:
"""
Perform semantic (vector) search.
Returns:
List of (embedding, similarity_score) tuples.
"""
try:
async with self.acquire() as conn:
# Build query with optional filters
query = """
SELECT
id, project_id, collection, content, embedding,
chunk_type, source_path, start_line, end_line,
file_type, metadata, content_hash, created_at,
updated_at, expires_at,
1 - (embedding <=> $1) as similarity
FROM knowledge_embeddings
WHERE project_id = $2
AND (expires_at IS NULL OR expires_at > NOW())
"""
params: list[Any] = [query_embedding, project_id]
param_idx = 3
if collection:
query += f" AND collection = ${param_idx}"
params.append(collection)
param_idx += 1
if file_types:
file_type_values = [ft.value for ft in file_types]
query += f" AND file_type = ANY(${param_idx})"
params.append(file_type_values)
param_idx += 1
query += f"""
HAVING 1 - (embedding <=> $1) >= ${param_idx}
ORDER BY similarity DESC
LIMIT ${param_idx + 1}
"""
params.extend([threshold, limit])
rows = await conn.fetch(query, *params)
results = []
for row in rows:
embedding = KnowledgeEmbedding(
id=str(row["id"]),
project_id=row["project_id"],
collection=row["collection"],
content=row["content"],
embedding=list(row["embedding"]),
chunk_type=ChunkType(row["chunk_type"]),
source_path=row["source_path"],
start_line=row["start_line"],
end_line=row["end_line"],
file_type=FileType(row["file_type"]) if row["file_type"] else None,
metadata=row["metadata"] or {},
content_hash=row["content_hash"],
created_at=row["created_at"],
updated_at=row["updated_at"],
expires_at=row["expires_at"],
)
results.append((embedding, float(row["similarity"])))
return results
except asyncpg.PostgresError as e:
logger.error(f"Semantic search error: {e}")
raise DatabaseQueryError(
message=f"Semantic search failed: {e}",
cause=e,
)
    async def keyword_search(
        self,
        project_id: str,
        query: str,
        collection: str | None = None,
        limit: int = 10,
        file_types: list[FileType] | None = None,
    ) -> list[tuple[KnowledgeEmbedding, float]]:
        """
        Perform full-text keyword search.

        Uses PostgreSQL full-text search (``to_tsvector`` /
        ``plainto_tsquery``, English configuration) ranked with ``ts_rank``.
        Expired rows are excluded.

        Args:
            project_id: Project scope for the search.
            query: Plain-language query (parsed by plainto_tsquery).
            collection: Optional collection filter.
            limit: Maximum number of results.
            file_types: Optional file-type filter.

        Returns:
            List of (embedding, relevance_score) tuples, best match first.

        Raises:
            DatabaseQueryError: On database failure.
        """
        try:
            async with self.acquire() as conn:
                # Build query with optional filters
                sql = """
                    SELECT
                        id, project_id, collection, content, embedding,
                        chunk_type, source_path, start_line, end_line,
                        file_type, metadata, content_hash, created_at,
                        updated_at, expires_at,
                        ts_rank(to_tsvector('english', content),
                                plainto_tsquery('english', $1)) as relevance
                    FROM knowledge_embeddings
                    WHERE project_id = $2
                    AND to_tsvector('english', content) @@ plainto_tsquery('english', $1)
                    AND (expires_at IS NULL OR expires_at > NOW())
                """
                params: list[Any] = [query, project_id]
                param_idx = 3
                if collection:
                    sql += f" AND collection = ${param_idx}"
                    params.append(collection)
                    param_idx += 1
                if file_types:
                    file_type_values = [ft.value for ft in file_types]
                    sql += f" AND file_type = ANY(${param_idx})"
                    params.append(file_type_values)
                    param_idx += 1
                sql += f" ORDER BY relevance DESC LIMIT ${param_idx}"
                params.append(limit)
                rows = await conn.fetch(sql, *params)
                results = []
                for row in rows:
                    embedding = KnowledgeEmbedding(
                        id=str(row["id"]),
                        project_id=row["project_id"],
                        collection=row["collection"],
                        content=row["content"],
                        # Embedding may be NULL for keyword-only rows.
                        embedding=list(row["embedding"]) if row["embedding"] else [],
                        chunk_type=ChunkType(row["chunk_type"]),
                        source_path=row["source_path"],
                        start_line=row["start_line"],
                        end_line=row["end_line"],
                        file_type=FileType(row["file_type"]) if row["file_type"] else None,
                        metadata=row["metadata"] or {},
                        content_hash=row["content_hash"],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        expires_at=row["expires_at"],
                    )
                    # Normalize relevance to 0-1 scale (approximate):
                    # ts_rank is unbounded above, so scores are capped at 1.0.
                    normalized_score = min(1.0, float(row["relevance"]))
                    results.append((embedding, normalized_score))
                return results
        except asyncpg.PostgresError as e:
            logger.error(f"Keyword search error: {e}")
            raise DatabaseQueryError(
                message=f"Keyword search failed: {e}",
                cause=e,
            )
    async def delete_by_source(
        self,
        project_id: str,
        source_path: str,
        collection: str | None = None,
    ) -> int:
        """Delete all embeddings for a source path.

        Args:
            project_id: Project scope (prevents cross-project deletes).
            source_path: Source file path whose chunks are removed.
            collection: Optional collection filter; when omitted, the source
                is removed from every collection in the project.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseQueryError: On database failure.
        """
        try:
            async with self.acquire() as conn:
                if collection:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2 AND collection = $3
                        """,
                        project_id,
                        source_path,
                        collection,
                    )
                else:
                    result = await conn.execute(
                        """
                        DELETE FROM knowledge_embeddings
                        WHERE project_id = $1 AND source_path = $2
                        """,
                        project_id,
                        source_path,
                    )
                # Parse "DELETE N" result (asyncpg returns the status tag)
                count = int(result.split()[-1])
                logger.info(f"Deleted {count} embeddings for source: {source_path}")
                return count
        except asyncpg.PostgresError as e:
            logger.error(f"Delete error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete by source: {e}",
                cause=e,
            )
    async def delete_collection(
        self,
        project_id: str,
        collection: str,
    ) -> int:
        """Delete an entire collection.

        Args:
            project_id: Project scope (prevents cross-project deletes).
            collection: Collection whose rows are all removed.

        Returns:
            Number of rows deleted (0 if the collection did not exist).

        Raises:
            DatabaseQueryError: On database failure.
        """
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                    """,
                    project_id,
                    collection,
                )
                # Parse the "DELETE N" status tag returned by asyncpg.
                count = int(result.split()[-1])
                logger.info(f"Deleted collection {collection}: {count} embeddings")
                return count
        except asyncpg.PostgresError as e:
            logger.error(f"Delete collection error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to delete collection: {e}",
                cause=e,
            )
async def delete_by_ids(
self,
project_id: str,
chunk_ids: list[str],
) -> int:
"""Delete specific embeddings by ID."""
if not chunk_ids:
return 0
try:
# Convert string IDs to UUIDs
uuids = [uuid.UUID(cid) for cid in chunk_ids]
async with self.acquire() as conn:
result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND id = ANY($2)
""",
project_id,
uuids,
)
count = int(result.split()[-1])
logger.info(f"Deleted {count} embeddings by ID")
return count
except ValueError as e:
raise KnowledgeBaseError(
message=f"Invalid chunk ID format: {e}",
code=ErrorCode.INVALID_REQUEST,
)
except asyncpg.PostgresError as e:
logger.error(f"Delete by IDs error: {e}")
raise DatabaseQueryError(
message=f"Failed to delete by IDs: {e}",
cause=e,
)
    async def list_collections(
        self,
        project_id: str,
    ) -> list[CollectionInfo]:
        """List all collections for a project.

        Expired rows are excluded from all aggregates.

        Args:
            project_id: Project scope.

        Returns:
            One CollectionInfo per collection, sorted by collection name.

        Raises:
            DatabaseQueryError: On database failure.
        """
        try:
            async with self.acquire() as conn:
                # total_tokens sums metadata->>'token_count' per row;
                # presumably every ingest path writes that key — rows
                # without it contribute NULL (dropped by SUM).
                rows = await conn.fetch(
                    """
                    SELECT
                        collection,
                        COUNT(*) as chunk_count,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types,
                        MIN(created_at) as created_at,
                        MAX(updated_at) as updated_at
                    FROM knowledge_embeddings
                    WHERE project_id = $1
                    AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY collection
                    ORDER BY collection
                    """,
                    project_id,
                )
                return [
                    CollectionInfo(
                        name=row["collection"],
                        project_id=project_id,
                        chunk_count=row["chunk_count"],
                        total_tokens=row["total_tokens"] or 0,
                        # ARRAY_AGG with FILTER yields NULL (not []) when no
                        # file types exist, hence the fallback.
                        file_types=row["file_types"] or [],
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                    )
                    for row in rows
                ]
        except asyncpg.PostgresError as e:
            logger.error(f"List collections error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to list collections: {e}",
                cause=e,
            )
async def get_collection_stats(
    self,
    project_id: str,
    collection: str,
) -> CollectionStats:
    """Get detailed statistics for a collection.

    Runs four queries on one connection: an existence check, an
    aggregate-stats query, and two GROUP BY breakdowns (chunk type
    and file type). Expired rows are excluded from the stats, but the
    existence check deliberately ignores expiry.

    Raises:
        CollectionNotFoundError: If no row matches the collection.
        DatabaseQueryError: On any Postgres failure.
    """
    try:
        async with self.acquire() as conn:
            # Check if collection exists
            exists = await conn.fetchval(
                """
                SELECT EXISTS(
                    SELECT 1 FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                )
                """,
                project_id,
                collection,
            )
            if not exists:
                raise CollectionNotFoundError(collection, project_id)
            # Get stats (single aggregate row; COALESCE guards NULL sums)
            row = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as chunk_count,
                    COUNT(DISTINCT source_path) as unique_sources,
                    COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                    COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size,
                    MIN(created_at) as oldest_chunk,
                    MAX(created_at) as newest_chunk
                FROM knowledge_embeddings
                WHERE project_id = $1 AND collection = $2
                AND (expires_at IS NULL OR expires_at > NOW())
                """,
                project_id,
                collection,
            )
            # Get chunk type breakdown
            chunk_rows = await conn.fetch(
                """
                SELECT chunk_type, COUNT(*) as count
                FROM knowledge_embeddings
                WHERE project_id = $1 AND collection = $2
                AND (expires_at IS NULL OR expires_at > NOW())
                GROUP BY chunk_type
                """,
                project_id,
                collection,
            )
            chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows}
            # Get file type breakdown (file_type is nullable, so filter)
            file_rows = await conn.fetch(
                """
                SELECT file_type, COUNT(*) as count
                FROM knowledge_embeddings
                WHERE project_id = $1 AND collection = $2
                AND file_type IS NOT NULL
                AND (expires_at IS NULL OR expires_at > NOW())
                GROUP BY file_type
                """,
                project_id,
                collection,
            )
            file_types = {r["file_type"]: r["count"] for r in file_rows}
            return CollectionStats(
                collection=collection,
                project_id=project_id,
                chunk_count=row["chunk_count"],
                unique_sources=row["unique_sources"],
                total_tokens=row["total_tokens"] or 0,
                avg_chunk_size=float(row["avg_chunk_size"] or 0),
                chunk_types=chunk_types,
                file_types=file_types,
                oldest_chunk=row["oldest_chunk"],
                newest_chunk=row["newest_chunk"],
            )
    except CollectionNotFoundError:
        # Re-raise as-is so callers get the specific not-found error.
        raise
    except asyncpg.PostgresError as e:
        logger.error(f"Get collection stats error: {e}")
        raise DatabaseQueryError(
            message=f"Failed to get collection stats: {e}",
            cause=e,
        )
async def cleanup_expired(self) -> int:
    """Delete every embedding whose expiry timestamp has passed.

    Returns the number of rows removed; logs only when work was done.
    """
    try:
        async with self.acquire() as conn:
            status = await conn.execute(
                """
                DELETE FROM knowledge_embeddings
                WHERE expires_at IS NOT NULL AND expires_at < NOW()
                """
            )
        # Row count is the trailing token of the "DELETE n" status tag.
        deleted = int(status.split()[-1])
        if deleted > 0:
            logger.info(f"Cleaned up {deleted} expired embeddings")
        return deleted
    except asyncpg.PostgresError as e:
        logger.error(f"Cleanup error: {e}")
        raise DatabaseQueryError(
            message=f"Failed to cleanup expired: {e}",
            cause=e,
        )
# Global database manager instance (lazy initialization)
_db_manager: DatabaseManager | None = None
def get_database_manager() -> DatabaseManager:
    """Return the process-wide DatabaseManager, creating it on first use."""
    global _db_manager
    if _db_manager is not None:
        return _db_manager
    _db_manager = DatabaseManager()
    return _db_manager
def reset_database_manager() -> None:
    """Discard the cached DatabaseManager so the next getter rebuilds it.

    Intended for test isolation only.
    """
    global _db_manager
    _db_manager = None

View File

@@ -0,0 +1,426 @@
"""
Embedding generation for Knowledge Base MCP Server.
Generates vector embeddings via the LLM Gateway MCP server
with caching support using Redis.
"""
import hashlib
import json
import logging
import httpx
import redis.asyncio as redis
from config import Settings, get_settings
from exceptions import (
EmbeddingDimensionMismatchError,
EmbeddingGenerationError,
ErrorCode,
KnowledgeBaseError,
)
logger = logging.getLogger(__name__)
class EmbeddingGenerator:
    """
    Generates embeddings via LLM Gateway.

    Features:
    - Batched embedding generation
    - Redis caching for deduplication
    - Automatic retry on transient failures

    Redis is best-effort: every cache failure is logged and swallowed so
    embedding generation keeps working without the cache.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize embedding generator.

        Args:
            settings: Optional settings override; falls back to the
                process-wide settings.
        """
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        self._http_client: httpx.AsyncClient | None = None

    @property
    def redis_client(self) -> redis.Redis:  # type: ignore[type-arg]
        """Get Redis client, raising if not initialized."""
        if self._redis is None:
            raise KnowledgeBaseError(
                message="Redis client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._redis

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get HTTP client, raising if not initialized."""
        if self._http_client is None:
            raise KnowledgeBaseError(
                message="HTTP client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._http_client

    async def initialize(self) -> None:
        """Initialize Redis and HTTP clients.

        A Redis connection failure is non-fatal: caching is disabled
        and generation proceeds without it.
        """
        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()  # type: ignore[misc]
            logger.info("Redis connection established for embedding cache")
        except Exception as e:
            logger.warning(f"Redis connection failed, caching disabled: {e}")
            self._redis = None
        self._http_client = httpx.AsyncClient(
            base_url=self._settings.llm_gateway_url,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )
        logger.info("Embedding generator initialized")

    async def close(self) -> None:
        """Close Redis and HTTP connections (idempotent)."""
        if self._redis:
            await self._redis.close()
            self._redis = None
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
        logger.info("Embedding generator closed")

    def _cache_key(self, text: str) -> str:
        """Generate cache key for a text.

        Keys are namespaced by model so a model change never serves
        stale vectors; the text is identified by a truncated SHA-256.
        """
        text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
        model = self._settings.embedding_model
        return f"kb:emb:{model}:{text_hash}"

    async def _get_cached(self, text: str) -> list[float] | None:
        """Get cached embedding if available; None on miss or cache error."""
        if not self._redis:
            return None
        try:
            key = self._cache_key(text)
            cached = await self._redis.get(key)
            if cached:
                logger.debug(f"Cache hit for embedding: {key[:20]}...")
                return json.loads(cached)
            return None
        except Exception as e:
            logger.warning(f"Cache read error: {e}")
            return None

    async def _set_cached(self, text: str, embedding: list[float]) -> None:
        """Cache an embedding with the configured TTL (best-effort)."""
        if not self._redis:
            return
        try:
            key = self._cache_key(text)
            await self._redis.setex(
                key,
                self._settings.embedding_cache_ttl,
                json.dumps(embedding),
            )
            logger.debug(f"Cached embedding: {key[:20]}...")
        except Exception as e:
            logger.warning(f"Cache write error: {e}")

    async def _get_cached_batch(
        self, texts: list[str]
    ) -> tuple[list[list[float] | None], list[int]]:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (embeddings list with None for misses, indices of misses)
        """
        if not self._redis:
            return [None] * len(texts), list(range(len(texts)))
        try:
            keys = [self._cache_key(text) for text in texts]
            cached_values = await self._redis.mget(keys)
            embeddings: list[list[float] | None] = []
            missing_indices: list[int] = []
            for i, cached in enumerate(cached_values):
                if cached:
                    embeddings.append(json.loads(cached))
                else:
                    embeddings.append(None)
                    missing_indices.append(i)
            cache_hits = len(texts) - len(missing_indices)
            if cache_hits > 0:
                logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")
            return embeddings, missing_indices
        except Exception as e:
            # On any cache failure, treat the whole batch as misses.
            logger.warning(f"Batch cache read error: {e}")
            return [None] * len(texts), list(range(len(texts)))

    async def _set_cached_batch(
        self, texts: list[str], embeddings: list[list[float]]
    ) -> None:
        """Cache a batch of embeddings via a single pipeline (best-effort)."""
        if not self._redis or len(texts) != len(embeddings):
            return
        try:
            pipe = self._redis.pipeline()
            for text, embedding in zip(texts, embeddings, strict=True):
                key = self._cache_key(text)
                pipe.setex(
                    key,
                    self._settings.embedding_cache_ttl,
                    json.dumps(embedding),
                )
            await pipe.execute()
            logger.debug(f"Cached {len(texts)} embeddings in batch")
        except Exception as e:
            logger.warning(f"Batch cache write error: {e}")

    async def generate(
        self,
        text: str,
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            Embedding vector

        Raises:
            EmbeddingGenerationError: If the gateway returns no embedding.
            EmbeddingDimensionMismatchError: If the vector length differs
                from the configured dimension.
        """
        # Check cache first. Explicit None check: a present-but-falsy
        # cache value must not be treated as a miss.
        cached = await self._get_cached(text)
        if cached is not None:
            return cached
        # Generate via LLM Gateway
        embeddings = await self._call_llm_gateway(
            [text], project_id, agent_id
        )
        if not embeddings:
            raise EmbeddingGenerationError(
                message="No embedding returned from LLM Gateway",
                texts_count=1,
            )
        embedding = embeddings[0]
        # Validate dimension
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )
        # Cache the result
        await self._set_cached(text, embedding)
        return embedding

    async def generate_batch(
        self,
        texts: list[str],
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Uses caching and batches requests to LLM Gateway.

        Args:
            texts: List of texts to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            List of embedding vectors, positionally aligned with ``texts``

        Raises:
            EmbeddingGenerationError: On gateway failure.
            EmbeddingDimensionMismatchError: On a wrong-sized vector.
        """
        if not texts:
            return []
        # Check cache for existing embeddings
        cached_embeddings, missing_indices = await self._get_cached_batch(texts)
        # If all cached, return immediately
        if not missing_indices:
            logger.debug(f"All {len(texts)} embeddings served from cache")
            return [e for e in cached_embeddings if e is not None]
        # Get texts that need embedding
        texts_to_embed = [texts[i] for i in missing_indices]
        # Generate embeddings in batches sized by settings
        new_embeddings: list[list[float]] = []
        batch_size = self._settings.embedding_batch_size
        for i in range(0, len(texts_to_embed), batch_size):
            batch = texts_to_embed[i : i + batch_size]
            batch_embeddings = await self._call_llm_gateway(
                batch, project_id, agent_id
            )
            new_embeddings.extend(batch_embeddings)
        # Validate dimensions before caching anything
        for embedding in new_embeddings:
            if len(embedding) != self._settings.embedding_dimension:
                raise EmbeddingDimensionMismatchError(
                    expected=self._settings.embedding_dimension,
                    actual=len(embedding),
                )
        # Cache new embeddings
        await self._set_cached_batch(texts_to_embed, new_embeddings)
        # Combine cached and new embeddings in original text order
        result: list[list[float]] = []
        new_idx = 0
        for i in range(len(texts)):
            if cached_embeddings[i] is not None:
                result.append(cached_embeddings[i])  # type: ignore[arg-type]
            else:
                result.append(new_embeddings[new_idx])
                new_idx += 1
        logger.info(
            f"Generated {len(new_embeddings)} embeddings, "
            f"{len(texts) - len(missing_indices)} from cache"
        )
        return result

    async def _call_llm_gateway(
        self,
        texts: list[str],
        project_id: str,
        agent_id: str,
    ) -> list[list[float]]:
        """
        Call LLM Gateway to generate embeddings.

        Uses JSON-RPC 2.0 protocol to call the embedding tool. All
        transport/parse failures are normalized to
        EmbeddingGenerationError.
        """
        try:
            # JSON-RPC 2.0 request for embedding tool
            request = {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "generate_embeddings",
                    "arguments": {
                        "project_id": project_id,
                        "agent_id": agent_id,
                        "texts": texts,
                        "model": self._settings.embedding_model,
                    },
                },
                "id": 1,
            }
            response = await self.http_client.post("/mcp", json=request)
            response.raise_for_status()
            result = response.json()
            if "error" in result:
                error = result["error"]
                raise EmbeddingGenerationError(
                    message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
                    texts_count=len(texts),
                    details=error.get("data"),
                )
            # Extract embeddings from response
            content = result.get("result", {}).get("content", [])
            if not content:
                raise EmbeddingGenerationError(
                    message="Empty response from LLM Gateway",
                    texts_count=len(texts),
                )
            # Parse the response content
            # LLM Gateway returns embeddings in content[0].text as JSON
            embeddings_data = content[0].get("text", "")
            if isinstance(embeddings_data, str):
                embeddings_data = json.loads(embeddings_data)
            embeddings = embeddings_data.get("embeddings", [])
            if len(embeddings) != len(texts):
                raise EmbeddingGenerationError(
                    message=f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                    texts_count=len(texts),
                )
            return embeddings
        except httpx.HTTPStatusError as e:
            logger.error(f"LLM Gateway HTTP error: {e}")
            raise EmbeddingGenerationError(
                message=f"LLM Gateway request failed: {e.response.status_code}",
                texts_count=len(texts),
                cause=e,
            )
        except httpx.RequestError as e:
            logger.error(f"LLM Gateway request error: {e}")
            raise EmbeddingGenerationError(
                message=f"Failed to connect to LLM Gateway: {e}",
                texts_count=len(texts),
                cause=e,
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON response from LLM Gateway: {e}")
            raise EmbeddingGenerationError(
                message="Invalid response format from LLM Gateway",
                texts_count=len(texts),
                cause=e,
            )
        except EmbeddingGenerationError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error generating embeddings: {e}")
            raise EmbeddingGenerationError(
                message=f"Unexpected error: {e}",
                texts_count=len(texts),
                cause=e,
            )
# Global embedding generator instance (lazy initialization)
_embedding_generator: EmbeddingGenerator | None = None
def get_embedding_generator() -> EmbeddingGenerator:
    """Return the process-wide EmbeddingGenerator, creating it lazily."""
    global _embedding_generator
    if _embedding_generator is not None:
        return _embedding_generator
    _embedding_generator = EmbeddingGenerator()
    return _embedding_generator
def reset_embedding_generator() -> None:
    """Drop the cached EmbeddingGenerator (test isolation helper)."""
    global _embedding_generator
    _embedding_generator = None

View File

@@ -0,0 +1,409 @@
"""
Custom exceptions for Knowledge Base MCP Server.
Provides structured error handling with error codes and details.
"""
from enum import Enum
from typing import Any
class ErrorCode(str, Enum):
    """Error codes for Knowledge Base operations.

    Inherits from ``str`` so codes serialize directly in JSON error
    payloads. All values carry a ``KB_`` prefix to namespace them.
    """

    # General errors
    UNKNOWN_ERROR = "KB_UNKNOWN_ERROR"
    INVALID_REQUEST = "KB_INVALID_REQUEST"
    INTERNAL_ERROR = "KB_INTERNAL_ERROR"
    # Database errors
    DATABASE_CONNECTION_ERROR = "KB_DATABASE_CONNECTION_ERROR"
    DATABASE_QUERY_ERROR = "KB_DATABASE_QUERY_ERROR"
    DATABASE_INTEGRITY_ERROR = "KB_DATABASE_INTEGRITY_ERROR"
    # Embedding errors
    EMBEDDING_GENERATION_ERROR = "KB_EMBEDDING_GENERATION_ERROR"
    EMBEDDING_DIMENSION_MISMATCH = "KB_EMBEDDING_DIMENSION_MISMATCH"
    EMBEDDING_RATE_LIMIT = "KB_EMBEDDING_RATE_LIMIT"
    # Chunking errors
    CHUNKING_ERROR = "KB_CHUNKING_ERROR"
    UNSUPPORTED_FILE_TYPE = "KB_UNSUPPORTED_FILE_TYPE"
    FILE_TOO_LARGE = "KB_FILE_TOO_LARGE"
    ENCODING_ERROR = "KB_ENCODING_ERROR"
    # Search errors
    SEARCH_ERROR = "KB_SEARCH_ERROR"
    INVALID_SEARCH_TYPE = "KB_INVALID_SEARCH_TYPE"
    SEARCH_TIMEOUT = "KB_SEARCH_TIMEOUT"
    # Collection errors
    COLLECTION_NOT_FOUND = "KB_COLLECTION_NOT_FOUND"
    COLLECTION_ALREADY_EXISTS = "KB_COLLECTION_ALREADY_EXISTS"
    # Document errors
    DOCUMENT_NOT_FOUND = "KB_DOCUMENT_NOT_FOUND"
    DOCUMENT_ALREADY_EXISTS = "KB_DOCUMENT_ALREADY_EXISTS"
    INVALID_DOCUMENT = "KB_INVALID_DOCUMENT"
    # Project errors
    PROJECT_NOT_FOUND = "KB_PROJECT_NOT_FOUND"
    PROJECT_ACCESS_DENIED = "KB_PROJECT_ACCESS_DENIED"
class KnowledgeBaseError(Exception):
    """
    Base exception for Knowledge Base errors.

    Carries a machine-readable code, an optional details dict, and an
    optional causing exception. All custom exceptions inherit from it.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.UNKNOWN_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        """
        Initialize Knowledge Base error.

        Args:
            message: Human-readable error message
            code: Error code for programmatic handling
            details: Additional error details
            cause: Original exception that caused this error
        """
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}
        self.cause = cause

    def to_dict(self) -> dict[str, Any]:
        """Convert error to dictionary for JSON response."""
        payload: dict[str, Any] = {
            "error": self.code.value,
            "message": self.message,
        }
        # Details are optional in the wire format; omit when empty.
        if self.details:
            payload["details"] = self.details
        return payload

    def __str__(self) -> str:
        """Human-readable form: ``[CODE] message``."""
        return f"[{self.code.value}] {self.message}"

    def __repr__(self) -> str:
        """Detailed representation for debugging."""
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}("
            f"message={self.message!r}, "
            f"code={self.code.value!r}, "
            f"details={self.details!r})"
        )
# Database Errors
class DatabaseError(KnowledgeBaseError):
    """Base class for database-related errors.

    Defaults the error code to ``DATABASE_QUERY_ERROR``.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.DATABASE_QUERY_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details, cause=cause)
class DatabaseConnectionError(DatabaseError):
    """Failed to connect to the database."""

    def __init__(
        self,
        message: str = "Failed to connect to database",
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Connection failures always carry the dedicated connection code.
        super().__init__(
            message, ErrorCode.DATABASE_CONNECTION_ERROR, details, cause
        )
class DatabaseQueryError(DatabaseError):
    """Database query failed.

    Optionally records the offending query text in ``details``.
    """

    def __init__(
        self,
        message: str,
        query: str | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        if query:
            details["query"] = query
        super().__init__(message, ErrorCode.DATABASE_QUERY_ERROR, details, cause)
# Embedding Errors
class EmbeddingError(KnowledgeBaseError):
    """Base class for embedding-related errors.

    Defaults the error code to ``EMBEDDING_GENERATION_ERROR``.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.EMBEDDING_GENERATION_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details, cause=cause)
class EmbeddingGenerationError(EmbeddingError):
    """Failed to generate embeddings.

    Optionally records how many texts were being embedded.
    """

    def __init__(
        self,
        message: str = "Failed to generate embeddings",
        texts_count: int | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        if texts_count is not None:
            details["texts_count"] = texts_count
        super().__init__(message, ErrorCode.EMBEDDING_GENERATION_ERROR, details, cause)
class EmbeddingDimensionMismatchError(EmbeddingError):
    """Embedding dimension doesn't match expected dimension."""

    def __init__(
        self,
        expected: int,
        actual: int,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["expected_dimension"] = expected
        details["actual_dimension"] = actual
        message = f"Embedding dimension mismatch: expected {expected}, got {actual}"
        super().__init__(message, ErrorCode.EMBEDDING_DIMENSION_MISMATCH, details)
# Chunking Errors
class ChunkingError(KnowledgeBaseError):
    """Base class for chunking-related errors.

    Defaults the error code to ``CHUNKING_ERROR``.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.CHUNKING_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details, cause=cause)
class UnsupportedFileTypeError(ChunkingError):
    """File type is not supported for chunking."""

    def __init__(
        self,
        file_type: str,
        supported_types: list[str] | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["file_type"] = file_type
        if supported_types:
            details["supported_types"] = supported_types
        message = f"Unsupported file type: {file_type}"
        super().__init__(message, ErrorCode.UNSUPPORTED_FILE_TYPE, details)
class FileTooLargeError(ChunkingError):
    """File exceeds maximum allowed size."""

    def __init__(
        self,
        file_size: int,
        max_size: int,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["file_size"] = file_size
        details["max_size"] = max_size
        message = f"File too large: {file_size} bytes exceeds limit of {max_size} bytes"
        super().__init__(message, ErrorCode.FILE_TOO_LARGE, details)
class EncodingError(ChunkingError):
    """Failed to decode file content."""

    def __init__(
        self,
        message: str = "Failed to decode file content",
        encoding: str | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        if encoding:
            details["encoding"] = encoding
        super().__init__(message, ErrorCode.ENCODING_ERROR, details, cause)
# Search Errors
class SearchError(KnowledgeBaseError):
    """Base class for search-related errors.

    Defaults the error code to ``SEARCH_ERROR``.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.SEARCH_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details, cause=cause)
class InvalidSearchTypeError(SearchError):
    """Invalid search type specified."""

    def __init__(
        self,
        search_type: str,
        valid_types: list[str] | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["search_type"] = search_type
        if valid_types:
            details["valid_types"] = valid_types
        message = f"Invalid search type: {search_type}"
        super().__init__(message, ErrorCode.INVALID_SEARCH_TYPE, details)
class SearchTimeoutError(SearchError):
    """Search operation timed out."""

    def __init__(
        self,
        timeout: float,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["timeout"] = timeout
        message = f"Search timed out after {timeout} seconds"
        super().__init__(message, ErrorCode.SEARCH_TIMEOUT, details)
# Collection Errors
class CollectionError(KnowledgeBaseError):
    """Base class for collection-related errors."""
    # Grouping type only; adds no behavior over the base class.
class CollectionNotFoundError(CollectionError):
    """Collection does not exist."""

    def __init__(
        self,
        collection: str,
        project_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["collection"] = collection
        if project_id:
            details["project_id"] = project_id
        message = f"Collection not found: {collection}"
        super().__init__(message, ErrorCode.COLLECTION_NOT_FOUND, details)
# Document Errors
class DocumentError(KnowledgeBaseError):
    """Base class for document-related errors."""
    # Grouping type only; adds no behavior over the base class.
class DocumentNotFoundError(DocumentError):
    """Document does not exist."""

    def __init__(
        self,
        source_path: str,
        project_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["source_path"] = source_path
        if project_id:
            details["project_id"] = project_id
        message = f"Document not found: {source_path}"
        super().__init__(message, ErrorCode.DOCUMENT_NOT_FOUND, details)
class InvalidDocumentError(DocumentError):
    """Document content is invalid."""

    def __init__(
        self,
        message: str = "Invalid document content",
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        # Always tagged with the INVALID_DOCUMENT code.
        super().__init__(
            message, ErrorCode.INVALID_DOCUMENT, details, cause
        )
# Project Errors
class ProjectError(KnowledgeBaseError):
    """Base class for project-related errors."""
    # Grouping type only; adds no behavior over the base class.
class ProjectNotFoundError(ProjectError):
    """Project does not exist."""

    def __init__(
        self,
        project_id: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["project_id"] = project_id
        message = f"Project not found: {project_id}"
        super().__init__(message, ErrorCode.PROJECT_NOT_FOUND, details)
class ProjectAccessDeniedError(ProjectError):
    """Access to project is denied."""

    def __init__(
        self,
        project_id: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Copy so we never mutate the caller's details dict in place.
        details = dict(details) if details else {}
        details["project_id"] = project_id
        message = f"Access denied to project: {project_id}"
        super().__init__(message, ErrorCode.PROJECT_ACCESS_DENIED, details)

View File

@@ -0,0 +1,321 @@
"""
Data models for Knowledge Base MCP Server.
Defines database models, Pydantic schemas, and data structures
for RAG operations with pgvector.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class SearchType(str, Enum):
    """Types of search supported.

    Inherits from ``str`` so values serialize directly in JSON payloads.
    """

    SEMANTIC = "semantic"  # Vector similarity search
    KEYWORD = "keyword"  # Full-text search
    HYBRID = "hybrid"  # Combined semantic + keyword
class ChunkType(str, Enum):
    """Types of content chunks.

    Determines which chunking strategy produced a chunk; stored
    alongside each embedding.
    """

    CODE = "code"
    MARKDOWN = "markdown"
    TEXT = "text"
    DOCUMENTATION = "documentation"
class FileType(str, Enum):
    """Supported file types for chunking.

    See FILE_EXTENSION_MAP for the extension-to-type mapping.
    """

    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"
    GO = "go"
    RUST = "rust"
    JAVA = "java"
    MARKDOWN = "markdown"
    TEXT = "text"
    JSON = "json"
    YAML = "yaml"
    TOML = "toml"
# File extension to FileType mapping.
# Keys are lowercase extensions including the leading dot; several
# extensions may map to the same FileType (e.g. .js/.jsx, .yaml/.yml).
FILE_EXTENSION_MAP: dict[str, FileType] = {
    ".py": FileType.PYTHON,
    ".js": FileType.JAVASCRIPT,
    ".jsx": FileType.JAVASCRIPT,
    ".ts": FileType.TYPESCRIPT,
    ".tsx": FileType.TYPESCRIPT,
    ".go": FileType.GO,
    ".rs": FileType.RUST,
    ".java": FileType.JAVA,
    ".md": FileType.MARKDOWN,
    ".mdx": FileType.MARKDOWN,
    ".txt": FileType.TEXT,
    ".json": FileType.JSON,
    ".yaml": FileType.YAML,
    ".yml": FileType.YAML,
    ".toml": FileType.TOML,
}
@dataclass
class Chunk:
    """A chunk of content ready for embedding.

    Produced by the chunkers; carries optional source-location fields
    (path/lines) so search results can point back into the original file.
    """

    content: str
    chunk_type: ChunkType
    file_type: FileType | None = None
    source_path: str | None = None
    start_line: int | None = None
    end_line: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    token_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize the chunk to a plain dictionary (enums as values)."""
        file_type_value = self.file_type.value if self.file_type else None
        return {
            "content": self.content,
            "chunk_type": self.chunk_type.value,
            "file_type": file_type_value,
            "source_path": self.source_path,
            "start_line": self.start_line,
            "end_line": self.end_line,
            "metadata": self.metadata,
            "token_count": self.token_count,
        }
@dataclass
class KnowledgeEmbedding:
    """
    A knowledge embedding stored in the database.

    Represents a chunk of content with its vector embedding.
    """

    id: str
    project_id: str
    collection: str
    content: str
    embedding: list[float]
    chunk_type: ChunkType
    source_path: str | None = None
    start_line: int | None = None
    end_line: int | None = None
    file_type: FileType | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    content_hash: str | None = None
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    expires_at: datetime | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary (excluding embedding for size)."""
        # Normalize optional enum/timestamp fields up front.
        file_type_value = self.file_type.value if self.file_type else None
        expires_value = self.expires_at.isoformat() if self.expires_at else None
        return {
            "id": self.id,
            "project_id": self.project_id,
            "collection": self.collection,
            "content": self.content,
            "chunk_type": self.chunk_type.value,
            "source_path": self.source_path,
            "start_line": self.start_line,
            "end_line": self.end_line,
            "file_type": file_type_value,
            "metadata": self.metadata,
            "content_hash": self.content_hash,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "expires_at": expires_value,
        }
# Pydantic Request/Response Models
class IngestRequest(BaseModel):
    """Request to ingest content into the knowledge base.

    The content is chunked according to ``chunk_type``/``file_type``,
    embedded, and stored under ``collection`` scoped to ``project_id``.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    content: str = Field(..., description="Content to ingest")
    source_path: str | None = Field(
        default=None, description="Source file path for reference"
    )
    collection: str = Field(
        default="default", description="Collection to store in"
    )
    chunk_type: ChunkType = Field(
        default=ChunkType.TEXT, description="Type of content"
    )
    file_type: FileType | None = Field(
        default=None, description="File type for code chunking"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )
class IngestResult(BaseModel):
    """Result of an ingest operation.

    On failure, ``success`` is False and ``error`` holds the message.
    """

    success: bool = Field(..., description="Whether ingest succeeded")
    chunks_created: int = Field(default=0, description="Number of chunks created")
    embeddings_generated: int = Field(
        default=0, description="Number of embeddings generated"
    )
    source_path: str | None = Field(default=None, description="Source path ingested")
    collection: str = Field(default="default", description="Collection stored in")
    chunk_ids: list[str] = Field(
        default_factory=list, description="IDs of created chunks"
    )
    error: str | None = Field(default=None, description="Error message if failed")
class SearchRequest(BaseModel):
    """Request to search the knowledge base.

    ``limit`` is clamped to 1..100 and ``threshold`` to 0..1 by the
    field validators below.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    query: str = Field(..., description="Search query")
    search_type: SearchType = Field(
        default=SearchType.HYBRID, description="Type of search"
    )
    collection: str | None = Field(
        default=None, description="Collection to search (None = all)"
    )
    limit: int = Field(default=10, ge=1, le=100, description="Max results")
    threshold: float = Field(
        default=0.7, ge=0.0, le=1.0, description="Minimum similarity score"
    )
    file_types: list[FileType] | None = Field(
        default=None, description="Filter by file types"
    )
    include_metadata: bool = Field(
        default=True, description="Include metadata in results"
    )
class SearchResult(BaseModel):
    """A single search result with its relevance score and provenance."""

    id: str = Field(..., description="Chunk ID")
    content: str = Field(..., description="Chunk content")
    score: float = Field(..., description="Relevance score (0-1)")
    source_path: str | None = Field(default=None, description="Source file path")
    start_line: int | None = Field(default=None, description="Start line in source")
    end_line: int | None = Field(default=None, description="End line in source")
    chunk_type: str = Field(..., description="Type of chunk")
    file_type: str | None = Field(default=None, description="File type")
    collection: str = Field(..., description="Collection name")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )

    @classmethod
    def from_embedding(
        cls, embedding: KnowledgeEmbedding, score: float
    ) -> "SearchResult":
        """Build a SearchResult from a stored embedding plus its score."""
        # Enums are flattened to their string values for the wire format.
        file_type_value = (
            embedding.file_type.value if embedding.file_type else None
        )
        return cls(
            id=embedding.id,
            content=embedding.content,
            score=score,
            source_path=embedding.source_path,
            start_line=embedding.start_line,
            end_line=embedding.end_line,
            chunk_type=embedding.chunk_type.value,
            file_type=file_type_value,
            collection=embedding.collection,
            metadata=embedding.metadata,
        )
class SearchResponse(BaseModel):
    """Response from a search operation.

    ``search_time_ms`` measures the server-side search duration.
    """

    query: str = Field(..., description="Original query")
    search_type: str = Field(..., description="Type of search performed")
    results: list[SearchResult] = Field(
        default_factory=list, description="Search results"
    )
    total_results: int = Field(default=0, description="Total results found")
    search_time_ms: float = Field(default=0.0, description="Search time in ms")
class DeleteRequest(BaseModel):
    """Request to delete from the knowledge base.

    Exactly one of ``source_path``, ``collection``, or ``chunk_ids`` is
    expected to select the deletion target — presumably validated by the
    handler; TODO confirm.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    source_path: str | None = Field(
        default=None, description="Delete by source path"
    )
    collection: str | None = Field(
        default=None, description="Delete entire collection"
    )
    chunk_ids: list[str] | None = Field(
        default=None, description="Delete specific chunks"
    )
class DeleteResult(BaseModel):
    """Result of a delete operation.

    On failure, ``success`` is False and ``error`` holds the message.
    """

    success: bool = Field(..., description="Whether delete succeeded")
    chunks_deleted: int = Field(default=0, description="Number of chunks deleted")
    error: str | None = Field(default=None, description="Error message if failed")
class CollectionInfo(BaseModel):
    """Information about a collection.

    Populated from the aggregate query in DatabaseManager.list_collections.
    """

    name: str = Field(..., description="Collection name")
    project_id: str = Field(..., description="Project ID")
    chunk_count: int = Field(default=0, description="Number of chunks")
    total_tokens: int = Field(default=0, description="Total tokens stored")
    file_types: list[str] = Field(
        default_factory=list, description="File types in collection"
    )
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")
class ListCollectionsResponse(BaseModel):
    """Response for listing collections."""

    project_id: str = Field(..., description="Project ID")
    collections: list[CollectionInfo] = Field(
        default_factory=list, description="Collections in project"
    )
    total_collections: int = Field(default=0, description="Total count")
class CollectionStats(BaseModel):
    """Statistics for a collection."""

    # Identity of the collection being described.
    collection: str = Field(..., description="Collection name")
    project_id: str = Field(..., description="Project ID")
    # Aggregate size metrics.
    chunk_count: int = Field(default=0, description="Number of chunks")
    unique_sources: int = Field(default=0, description="Unique source files")
    total_tokens: int = Field(default=0, description="Total tokens")
    avg_chunk_size: float = Field(default=0.0, description="Average chunk size")
    # Per-type histograms keyed by type name.
    chunk_types: dict[str, int] = Field(default_factory=dict, description="Count by chunk type")
    file_types: dict[str, int] = Field(default_factory=dict, description="Count by file type")
    # Optional age range of stored chunks.
    oldest_chunk: datetime | None = Field(default=None, description="Oldest chunk timestamp")
    newest_chunk: datetime | None = Field(default=None, description="Newest chunk timestamp")

View File

@@ -4,21 +4,101 @@ version = "0.1.0"
description = "Syndarix Knowledge Base MCP Server - RAG with pgvector for semantic search"
requires-python = ">=3.12"
dependencies = [
"fastmcp>=0.1.0",
"fastmcp>=2.0.0",
"asyncpg>=0.29.0",
"pgvector>=0.3.0",
"redis>=5.0.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"tiktoken>=0.7.0",
"httpx>=0.27.0",
"uvicorn>=0.30.0",
"fastapi>=0.115.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=5.0.0",
"fakeredis>=2.25.0",
"ruff>=0.8.0",
"mypy>=1.11.0",
]
[project.scripts]
knowledge-base = "server:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
exclude = ["tests/", "*.md", "Dockerfile"]
[tool.hatch.build.targets.sdist]
include = ["*.py", "pyproject.toml"]
[tool.ruff]
target-version = "py312"
line-length = 88
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # do not perform function calls in argument defaults
"B904", # raise from in except (too noisy)
]
[tool.ruff.lint.isort]
known-first-party = ["config", "models", "exceptions", "database", "embeddings", "chunking", "search", "collections"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
addopts = "-v --tb=short"
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["."]
omit = ["tests/*", "conftest.py"]
branch = true
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"if __name__ == .__main__.:",
]
fail_under = 55
show_missing = true
[tool.mypy]
python_version = "3.12"
warn_return_any = false
warn_unused_ignores = false
disallow_untyped_defs = true
ignore_missing_imports = true
plugins = ["pydantic.mypy"]
[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
ignore_errors = true

View File

@@ -0,0 +1,285 @@
"""
Search implementations for Knowledge Base MCP Server.
Provides semantic (vector), keyword (full-text), and hybrid search
capabilities over the knowledge base.
"""
import asyncio
import logging
import time

from config import Settings, get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import InvalidSearchTypeError, SearchError
from models import (
    SearchRequest,
    SearchResponse,
    SearchResult,
    SearchType,
)
logger = logging.getLogger(__name__)
class SearchEngine:
    """
    Unified search engine supporting multiple search types.

    Features:
    - Semantic search using vector similarity
    - Keyword search using full-text search
    - Hybrid search combining both approaches (sub-searches run concurrently)
    - Configurable result fusion and weighting
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
    ) -> None:
        """Initialize search engine.

        Args:
            settings: Settings override; defaults to the global settings.
            database: Database manager; resolved lazily from the global
                when omitted.
            embeddings: Embedding generator; resolved lazily from the
                global when omitted.
        """
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings

    @property
    def database(self) -> DatabaseManager:
        """Get database manager (resolved from the global on first use)."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator (resolved from the global on first use)."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    async def search(self, request: SearchRequest) -> SearchResponse:
        """
        Execute a search request.

        Dispatches to semantic, keyword, or hybrid search based on
        request.search_type.

        Args:
            request: Search request with query and options

        Returns:
            Search response with results

        Raises:
            InvalidSearchTypeError: If request.search_type is not recognized.
            SearchError: If the underlying search operation fails.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments (time.time can step backwards).
        start_time = time.perf_counter()
        try:
            if request.search_type == SearchType.SEMANTIC:
                results = await self._semantic_search(request)
            elif request.search_type == SearchType.KEYWORD:
                results = await self._keyword_search(request)
            elif request.search_type == SearchType.HYBRID:
                results = await self._hybrid_search(request)
            else:
                raise InvalidSearchTypeError(
                    search_type=request.search_type,
                    valid_types=[t.value for t in SearchType],
                )
            search_time_ms = (time.perf_counter() - start_time) * 1000
            logger.info(
                f"Search completed: type={request.search_type.value}, "
                f"results={len(results)}, time={search_time_ms:.1f}ms"
            )
            return SearchResponse(
                query=request.query,
                search_type=request.search_type.value,
                results=results,
                total_results=len(results),
                search_time_ms=search_time_ms,
            )
        except InvalidSearchTypeError:
            raise
        except Exception as e:
            logger.error(f"Search error: {e}")
            # Chain the original exception (`from e`) so the full traceback
            # is preserved in addition to the structured `cause` field.
            raise SearchError(
                message=f"Search failed: {e}",
                cause=e,
            ) from e

    async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute semantic (vector) search."""
        # Embed the query text via the embedding service.
        query_embedding = await self.embeddings.generate(
            text=request.query,
            project_id=request.project_id,
            agent_id=request.agent_id,
        )
        # Vector-similarity search against the project-scoped store.
        results = await self.database.semantic_search(
            project_id=request.project_id,
            query_embedding=query_embedding,
            collection=request.collection,
            limit=request.limit,
            threshold=request.threshold,
            file_types=request.file_types,
        )
        # Convert (embedding, score) pairs to SearchResult.
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in results
        ]

    async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
        """Execute keyword (full-text) search."""
        results = await self.database.keyword_search(
            project_id=request.project_id,
            query=request.query,
            collection=request.collection,
            limit=request.limit,
            file_types=request.file_types,
        )
        # Filter by threshold (keyword search scores are normalized).
        filtered = [
            (emb, score) for emb, score in results
            if score >= request.threshold
        ]
        return [
            SearchResult.from_embedding(embedding, score)
            for embedding, score in filtered
        ]

    async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
        """
        Execute hybrid search combining semantic and keyword.

        Both sub-searches run concurrently; their results are combined
        using Reciprocal Rank Fusion (RRF).
        """
        # Over-fetch so the fusion step has enough candidates to rank.
        fusion_limit = min(request.limit * 2, 100)
        # Create modified requests for the sub-searches.
        semantic_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.SEMANTIC,
            collection=request.collection,
            limit=fusion_limit,
            threshold=request.threshold * 0.8,  # Lower threshold for fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )
        keyword_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.KEYWORD,
            collection=request.collection,
            limit=fusion_limit,
            threshold=0.0,  # No threshold for keyword, we'll filter after fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )
        # The two sub-searches are independent, so run them concurrently
        # instead of sequentially to cut hybrid-search latency.
        semantic_results, keyword_results = await asyncio.gather(
            self._semantic_search(semantic_request),
            self._keyword_search(keyword_request),
        )
        # Fuse results using RRF.
        fused = self._reciprocal_rank_fusion(
            semantic_results=semantic_results,
            keyword_results=keyword_results,
            semantic_weight=self._settings.hybrid_semantic_weight,
            keyword_weight=self._settings.hybrid_keyword_weight,
        )
        # Apply the caller's threshold to the normalized fused scores,
        # then trim to the requested limit.
        return [
            result for result in fused
            if result.score >= request.threshold
        ][:request.limit]

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF constant
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) for each result list.
        Final scores are normalized so the best fused result is 1.0.
        """
        scores: dict[str, float] = {}
        results_by_id: dict[str, SearchResult] = {}
        # Accumulate weighted reciprocal ranks from the semantic list.
        for rank, result in enumerate(semantic_results, start=1):
            rrf_score = semantic_weight / (k + rank)
            scores[result.id] = scores.get(result.id, 0) + rrf_score
            results_by_id[result.id] = result
        # ...and from the keyword list (ids in both lists accumulate both).
        for rank, result in enumerate(keyword_results, start=1):
            rrf_score = keyword_weight / (k + rank)
            scores[result.id] = scores.get(result.id, 0) + rrf_score
            if result.id not in results_by_id:
                results_by_id[result.id] = result
        # Sort by combined score, best first.
        sorted_ids = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
        # Normalize scores to 0-1 range.
        max_score = max(scores.values()) if scores else 1.0
        final_results: list[SearchResult] = []
        for result_id in sorted_ids:
            result = results_by_id[result_id]
            normalized_score = scores[result_id] / max_score
            # Build a copy carrying the normalized fused score.
            final_results.append(
                SearchResult(
                    id=result.id,
                    content=result.content,
                    score=normalized_score,
                    source_path=result.source_path,
                    start_line=result.start_line,
                    end_line=result.end_line,
                    chunk_type=result.chunk_type,
                    file_type=result.file_type,
                    collection=result.collection,
                    metadata=result.metadata,
                )
            )
        return final_results
# Global search engine instance (lazy initialization)
_search_engine: SearchEngine | None = None


def get_search_engine() -> SearchEngine:
    """Return the process-wide search engine, creating it on first use."""
    global _search_engine
    engine = _search_engine
    if engine is None:
        engine = SearchEngine()
        _search_engine = engine
    return engine
def reset_search_engine() -> None:
    """Reset the global search engine (for testing)."""
    # Drop the cached instance so the next get_search_engine() call
    # builds a fresh engine — avoids cross-test state leakage.
    global _search_engine
    _search_engine = None

View File

@@ -1,162 +1,569 @@
"""
Syndarix Knowledge Base MCP Server.
Knowledge Base MCP Server.
Provides RAG capabilities with:
- pgvector for semantic search
- Per-project collection isolation
- Hybrid search (vector + keyword)
- Chunking strategies for code, markdown, and text
Per ADR-008: Knowledge Base RAG Architecture.
Provides RAG capabilities with pgvector for semantic search,
intelligent chunking, and collection management.
"""
import os
import logging
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI
from fastmcp import FastMCP
from pydantic import Field
# Create MCP server
mcp = FastMCP(
"syndarix-knowledge-base",
from collection_manager import CollectionManager, get_collection_manager
from collections.abc import AsyncIterator
from config import get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from exceptions import KnowledgeBaseError
from models import (
ChunkType,
DeleteRequest,
FileType,
IngestRequest,
SearchRequest,
SearchType,
)
from search import SearchEngine, get_search_engine
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Global instances
_database: DatabaseManager | None = None
_embeddings: EmbeddingGenerator | None = None
_search: SearchEngine | None = None
_collections: CollectionManager | None = None
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
"""Application lifespan handler."""
global _database, _embeddings, _search, _collections
logger.info("Starting Knowledge Base MCP Server...")
# Initialize database
_database = get_database_manager()
await _database.initialize()
# Initialize embedding generator
_embeddings = get_embedding_generator()
await _embeddings.initialize()
# Initialize search engine
_search = get_search_engine()
# Initialize collection manager
_collections = get_collection_manager()
logger.info("Knowledge Base MCP Server started successfully")
yield
# Cleanup
logger.info("Shutting down Knowledge Base MCP Server...")
if _embeddings:
await _embeddings.close()
if _database:
await _database.close()
logger.info("Knowledge Base MCP Server shut down")
# Create FastMCP server
mcp = FastMCP("syndarix-knowledge-base")
# Create FastAPI app with lifespan
app = FastAPI(
title="Knowledge Base MCP Server",
description="RAG with pgvector for semantic search",
version="0.1.0",
lifespan=lifespan,
)
# Configuration
DATABASE_URL = os.getenv("DATABASE_URL")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint."""
status: dict[str, Any] = {
"status": "healthy",
"service": "knowledge-base",
"version": "0.1.0",
}
# Check database connection
try:
if _database and _database._pool:
async with _database.acquire() as conn:
await conn.fetchval("SELECT 1")
status["database"] = "connected"
else:
status["database"] = "not initialized"
except Exception as e:
status["database"] = f"error: {e}"
status["status"] = "degraded"
return status
# MCP Tools
@mcp.tool()
async def search_knowledge(
project_id: str,
query: str,
top_k: int = 10,
search_type: str = "hybrid",
filters: dict | None = None,
) -> dict:
project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"),
query: str = Field(..., description="Search query"),
search_type: str = Field(
default="hybrid",
description="Search type: semantic, keyword, or hybrid",
),
collection: str | None = Field(
default=None,
description="Collection to search (None = all)",
),
limit: int = Field(
default=10,
ge=1,
le=100,
description="Maximum number of results",
),
threshold: float = Field(
default=0.7,
ge=0.0,
le=1.0,
description="Minimum similarity score",
),
file_types: list[str] | None = Field(
default=None,
description="Filter by file types (python, javascript, etc.)",
),
) -> dict[str, Any]:
"""
Search the project knowledge base.
Search the knowledge base for relevant content.
Args:
project_id: UUID of the project (scopes to project collection)
query: Search query text
top_k: Number of results to return
search_type: Search type (semantic, keyword, hybrid)
filters: Optional filters (file_type, path_prefix, etc.)
Returns:
List of matching documents with scores
Supports semantic (vector), keyword (full-text), and hybrid search.
Returns chunks ranked by relevance to the query.
"""
# TODO: Implement pgvector search
# 1. Generate query embedding via LLM Gateway
# 2. Search project-scoped collection
# 3. Apply filters
# 4. Return results with scores
return {
"status": "not_implemented",
"project_id": project_id,
"query": query,
}
try:
# Parse search type
try:
search_type_enum = SearchType(search_type.lower())
except ValueError:
valid_types = [t.value for t in SearchType]
return {
"success": False,
"error": f"Invalid search type: {search_type}. Valid types: {valid_types}",
}
# Parse file types
file_type_enums = None
if file_types:
try:
file_type_enums = [FileType(ft.lower()) for ft in file_types]
except ValueError as e:
return {
"success": False,
"error": f"Invalid file type: {e}",
}
request = SearchRequest(
project_id=project_id,
agent_id=agent_id,
query=query,
search_type=search_type_enum,
collection=collection,
limit=limit,
threshold=threshold,
file_types=file_type_enums,
)
response = await _search.search(request) # type: ignore[union-attr]
return {
"success": True,
"query": response.query,
"search_type": response.search_type,
"results": [
{
"id": r.id,
"content": r.content,
"score": r.score,
"source_path": r.source_path,
"start_line": r.start_line,
"end_line": r.end_line,
"chunk_type": r.chunk_type,
"file_type": r.file_type,
"collection": r.collection,
"metadata": r.metadata,
}
for r in response.results
],
"total_results": response.total_results,
"search_time_ms": response.search_time_ms,
}
except KnowledgeBaseError as e:
logger.error(f"Search error: {e}")
return {
"success": False,
"error": e.message,
"code": e.code.value,
}
except Exception as e:
logger.error(f"Unexpected search error: {e}")
return {
"success": False,
"error": str(e),
}
@mcp.tool()
async def ingest_document(
project_id: str,
content: str,
source_path: str,
doc_type: str = "text",
metadata: dict | None = None,
) -> dict:
async def ingest_content(
project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"),
content: str = Field(..., description="Content to ingest"),
source_path: str | None = Field(
default=None,
description="Source file path for reference",
),
collection: str = Field(
default="default",
description="Collection to store in",
),
chunk_type: str = Field(
default="text",
description="Content type: code, markdown, or text",
),
file_type: str | None = Field(
default=None,
description="File type for code chunking (python, javascript, etc.)",
),
metadata: dict[str, Any] | None = Field(
default=None,
description="Additional metadata to store",
),
) -> dict[str, Any]:
"""
Ingest a document into the knowledge base.
Ingest content into the knowledge base.
Args:
project_id: UUID of the project
content: Document content
source_path: Original file path for reference
doc_type: Document type (code, markdown, text)
metadata: Additional metadata
Returns:
Ingestion result with chunk count
Content is automatically chunked based on type, embedded using
the LLM Gateway, and stored in pgvector for search.
"""
# TODO: Implement document ingestion
# 1. Apply chunking strategy based on doc_type
# 2. Generate embeddings for chunks
# 3. Store in project collection
return {
"status": "not_implemented",
"project_id": project_id,
"source_path": source_path,
}
try:
# Parse chunk type
try:
chunk_type_enum = ChunkType(chunk_type.lower())
except ValueError:
valid_types = [t.value for t in ChunkType]
return {
"success": False,
"error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
}
# Parse file type
file_type_enum = None
if file_type:
try:
file_type_enum = FileType(file_type.lower())
except ValueError:
valid_types = [t.value for t in FileType]
return {
"success": False,
"error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
}
request = IngestRequest(
project_id=project_id,
agent_id=agent_id,
content=content,
source_path=source_path,
collection=collection,
chunk_type=chunk_type_enum,
file_type=file_type_enum,
metadata=metadata or {},
)
result = await _collections.ingest(request) # type: ignore[union-attr]
return {
"success": result.success,
"chunks_created": result.chunks_created,
"embeddings_generated": result.embeddings_generated,
"source_path": result.source_path,
"collection": result.collection,
"chunk_ids": result.chunk_ids,
"error": result.error,
}
except KnowledgeBaseError as e:
logger.error(f"Ingest error: {e}")
return {
"success": False,
"error": e.message,
"code": e.code.value,
}
except Exception as e:
logger.error(f"Unexpected ingest error: {e}")
return {
"success": False,
"error": str(e),
}
@mcp.tool()
async def ingest_repository(
project_id: str,
repo_path: str,
include_patterns: list[str] | None = None,
exclude_patterns: list[str] | None = None,
) -> dict:
async def delete_content(
project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"),
source_path: str | None = Field(
default=None,
description="Delete by source file path",
),
collection: str | None = Field(
default=None,
description="Delete entire collection",
),
chunk_ids: list[str] | None = Field(
default=None,
description="Delete specific chunk IDs",
),
) -> dict[str, Any]:
"""
Ingest an entire repository into the knowledge base.
Delete content from the knowledge base.
Args:
project_id: UUID of the project
repo_path: Path to the repository
include_patterns: Glob patterns to include (e.g., ["*.py", "*.md"])
exclude_patterns: Glob patterns to exclude (e.g., ["node_modules/*"])
Returns:
Ingestion summary with file and chunk counts
Specify either source_path, collection, or chunk_ids to delete.
"""
# TODO: Implement repository ingestion
return {
"status": "not_implemented",
"project_id": project_id,
"repo_path": repo_path,
}
try:
request = DeleteRequest(
project_id=project_id,
agent_id=agent_id,
source_path=source_path,
collection=collection,
chunk_ids=chunk_ids,
)
result = await _collections.delete(request) # type: ignore[union-attr]
return {
"success": result.success,
"chunks_deleted": result.chunks_deleted,
"error": result.error,
}
except KnowledgeBaseError as e:
logger.error(f"Delete error: {e}")
return {
"success": False,
"error": e.message,
"code": e.code.value,
}
except Exception as e:
logger.error(f"Unexpected delete error: {e}")
return {
"success": False,
"error": str(e),
}
@mcp.tool()
async def delete_document(
project_id: str,
source_path: str,
) -> dict:
async def list_collections(
project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), # noqa: ARG001
) -> dict[str, Any]:
"""
Delete a document from the knowledge base.
List all collections in a project's knowledge base.
Args:
project_id: UUID of the project
source_path: Original file path
Returns:
Deletion result
Returns collection names with chunk counts and file types.
"""
# TODO: Implement document deletion
return {
"status": "not_implemented",
"project_id": project_id,
"source_path": source_path,
}
try:
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
return {
"success": True,
"project_id": result.project_id,
"collections": [
{
"name": c.name,
"chunk_count": c.chunk_count,
"total_tokens": c.total_tokens,
"file_types": c.file_types,
"created_at": c.created_at.isoformat(),
"updated_at": c.updated_at.isoformat(),
}
for c in result.collections
],
"total_collections": result.total_collections,
}
except KnowledgeBaseError as e:
logger.error(f"List collections error: {e}")
return {
"success": False,
"error": e.message,
"code": e.code.value,
}
except Exception as e:
logger.error(f"Unexpected list collections error: {e}")
return {
"success": False,
"error": str(e),
}
@mcp.tool()
async def get_collection_stats(project_id: str) -> dict:
async def get_collection_stats(
project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), # noqa: ARG001
collection: str = Field(..., description="Collection name"),
) -> dict[str, Any]:
"""
Get statistics for a project's knowledge base collection.
Get detailed statistics for a collection.
Args:
project_id: UUID of the project
Returns:
Collection statistics (document count, chunk count, etc.)
Returns chunk counts, token totals, and type breakdowns.
"""
# TODO: Implement collection stats
return {
"status": "not_implemented",
"project_id": project_id,
}
try:
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
return {
"success": True,
"collection": stats.collection,
"project_id": stats.project_id,
"chunk_count": stats.chunk_count,
"unique_sources": stats.unique_sources,
"total_tokens": stats.total_tokens,
"avg_chunk_size": stats.avg_chunk_size,
"chunk_types": stats.chunk_types,
"file_types": stats.file_types,
"oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
"newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
}
except KnowledgeBaseError as e:
logger.error(f"Get collection stats error: {e}")
return {
"success": False,
"error": e.message,
"code": e.code.value,
}
except Exception as e:
logger.error(f"Unexpected get collection stats error: {e}")
return {
"success": False,
"error": str(e),
}
@mcp.tool()
async def update_document(
    project_id: str = Field(..., description="Project ID for scoping"),
    agent_id: str = Field(..., description="Agent ID making the request"),
    source_path: str = Field(..., description="Source file path"),
    content: str = Field(..., description="New content"),
    collection: str = Field(
        default="default",
        description="Collection name",
    ),
    chunk_type: str = Field(
        default="text",
        description="Content type: code, markdown, or text",
    ),
    file_type: str | None = Field(
        default=None,
        description="File type for code chunking",
    ),
    metadata: dict[str, Any] | None = Field(
        default=None,
        description="Additional metadata",
    ),
) -> dict[str, Any]:
    """
    Update a document in the knowledge base.

    Replaces all existing chunks for the source path with new content.
    """
    try:
        # Validate the declared content type up front.
        try:
            parsed_chunk_type = ChunkType(chunk_type.lower())
        except ValueError:
            valid_types = [t.value for t in ChunkType]
            return {
                "success": False,
                "error": f"Invalid chunk type: {chunk_type}. Valid types: {valid_types}",
            }
        # Validate the optional file type the same way.
        parsed_file_type = None
        if file_type:
            try:
                parsed_file_type = FileType(file_type.lower())
            except ValueError:
                valid_types = [t.value for t in FileType]
                return {
                    "success": False,
                    "error": f"Invalid file type: {file_type}. Valid types: {valid_types}",
                }
        # Delegate the replace-and-reingest to the collection manager.
        result = await _collections.update_document(  # type: ignore[union-attr]
            project_id=project_id,
            agent_id=agent_id,
            source_path=source_path,
            content=content,
            collection=collection,
            chunk_type=parsed_chunk_type,
            file_type=parsed_file_type,
            metadata=metadata,
        )
        return {
            "success": result.success,
            "chunks_created": result.chunks_created,
            "embeddings_generated": result.embeddings_generated,
            "source_path": result.source_path,
            "collection": result.collection,
            "chunk_ids": result.chunk_ids,
            "error": result.error,
        }
    except KnowledgeBaseError as e:
        logger.error(f"Update document error: {e}")
        return {"success": False, "error": e.message, "code": e.code.value}
    except Exception as e:
        logger.error(f"Unexpected update document error: {e}")
        return {"success": False, "error": str(e)}
def main() -> None:
    """Run the server."""
    import uvicorn

    cfg = get_settings()
    # Auto-reload only when debug is enabled.
    uvicorn.run(
        "server:app",
        host=cfg.host,
        port=cfg.port,
        reload=cfg.debug,
        log_level="info",
    )
if __name__ == "__main__":
mcp.run()
main()

View File

@@ -0,0 +1 @@
"""Tests for Knowledge Base MCP Server."""

View File

@@ -0,0 +1,282 @@
"""
Test fixtures for Knowledge Base MCP Server.
"""
import os
import sys
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Set test mode before importing modules
os.environ["IS_TEST"] = "true"
os.environ["KB_DATABASE_URL"] = "postgresql://test:test@localhost:5432/test"
os.environ["KB_REDIS_URL"] = "redis://localhost:6379/0"
os.environ["KB_LLM_GATEWAY_URL"] = "http://localhost:8001"
@pytest.fixture
def settings():
    """Create test settings."""
    from config import Settings, reset_settings

    # Clear any cached global settings so each test gets a fresh object.
    reset_settings()
    overrides = {
        "host": "127.0.0.1",
        "port": 8002,
        "debug": True,
        "database_url": "postgresql://test:test@localhost:5432/test",
        "redis_url": "redis://localhost:6379/0",
        "llm_gateway_url": "http://localhost:8001",
        "embedding_dimension": 1536,
        "code_chunk_size": 500,
        "code_chunk_overlap": 50,
        "markdown_chunk_size": 800,
        "markdown_chunk_overlap": 100,
        "text_chunk_size": 400,
        "text_chunk_overlap": 50,
    }
    return Settings(**overrides)
@pytest.fixture
def mock_database():
    """Create mock database manager."""
    from database import DatabaseManager

    mock_db = MagicMock(spec=DatabaseManager)
    mock_db._pool = MagicMock()
    mock_db.acquire = MagicMock(return_value=AsyncMock())
    # Async methods that resolve to a fixed canned value.
    canned_returns = {
        "store_embedding": "test-id-123",
        "store_embeddings_batch": ["id-1", "id-2"],
        "semantic_search": [],
        "keyword_search": [],
        "delete_by_source": 1,
        "delete_collection": 5,
        "delete_by_ids": 2,
        "list_collections": [],
        "cleanup_expired": 0,
    }
    for method_name, value in canned_returns.items():
        setattr(mock_db, method_name, AsyncMock(return_value=value))
    # Async methods with no meaningful return value.
    mock_db.initialize = AsyncMock()
    mock_db.close = AsyncMock()
    mock_db.get_collection_stats = AsyncMock()
    return mock_db
@pytest.fixture
def mock_embeddings():
    """Create mock embedding generator."""
    from embeddings import EmbeddingGenerator

    mock_emb = MagicMock(spec=EmbeddingGenerator)
    mock_emb.initialize = AsyncMock()
    mock_emb.close = AsyncMock()
    # Constant vector matching the configured 1536-dim embeddings.
    dummy_vector = [0.1] * 1536
    mock_emb.generate = AsyncMock(return_value=list(dummy_vector))
    # Batch variant returns one fresh copy of the vector per input text.
    mock_emb.generate_batch = AsyncMock(
        side_effect=lambda texts, **_kwargs: [list(dummy_vector) for _ in texts]
    )
    return mock_emb
@pytest.fixture
def mock_redis():
    """Create mock Redis client."""
    # fakeredis provides an in-memory asyncio-compatible Redis stand-in.
    from fakeredis import aioredis

    return aioredis.FakeRedis()
@pytest.fixture
def sample_python_code():
"""Sample Python code for chunking tests."""
return '''"""Sample module for testing."""
import os
from typing import Any
class Calculator:
"""A simple calculator class."""
def __init__(self, initial: int = 0) -> None:
"""Initialize calculator."""
self.value = initial
def add(self, x: int) -> int:
"""Add a value."""
self.value += x
return self.value
def subtract(self, x: int) -> int:
"""Subtract a value."""
self.value -= x
return self.value
def helper_function(data: dict[str, Any]) -> str:
"""A helper function."""
return str(data)
async def async_function() -> None:
"""An async function."""
pass
'''
@pytest.fixture
def sample_markdown():
"""Sample Markdown content for chunking tests."""
return '''# Project Documentation
This is the main documentation for our project.
## Getting Started
To get started, follow these steps:
1. Install dependencies
2. Configure settings
3. Run the application
### Prerequisites
You'll need the following installed:
- Python 3.12+
- PostgreSQL
- Redis
```python
# Example code
def main():
print("Hello, World!")
```
## API Reference
### Search Endpoint
The search endpoint allows you to query the knowledge base.
**Endpoint:** `POST /api/search`
**Request:**
```json
{
"query": "your search query",
"limit": 10
}
```
## Contributing
We welcome contributions! Please see our contributing guide.
'''
@pytest.fixture
def sample_text():
"""Sample plain text for chunking tests."""
return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.
When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.
The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''
@pytest.fixture
def sample_chunk():
    """Sample chunk for testing."""
    from models import Chunk, ChunkType, FileType

    # A two-line Python snippet tagged with code-chunk metadata.
    payload = {
        "content": "def hello():\n print('Hello')",
        "chunk_type": ChunkType.CODE,
        "file_type": FileType.PYTHON,
        "source_path": "/test/hello.py",
        "start_line": 1,
        "end_line": 2,
        "metadata": {"function": "hello"},
        "token_count": 15,
    }
    return Chunk(**payload)
@pytest.fixture
def sample_embedding():
    """Sample knowledge embedding for testing."""
    from models import ChunkType, FileType, KnowledgeEmbedding

    # Stored-row equivalent of a code chunk with a constant 1536-dim vector.
    payload = {
        "id": "test-id-123",
        "project_id": "proj-123",
        "collection": "default",
        "content": "def hello():\n print('Hello')",
        "embedding": [0.1] * 1536,
        "chunk_type": ChunkType.CODE,
        "source_path": "/test/hello.py",
        "start_line": 1,
        "end_line": 2,
        "file_type": FileType.PYTHON,
        "metadata": {"function": "hello"},
        "content_hash": "abc123",
        "created_at": datetime.now(UTC),
        "updated_at": datetime.now(UTC),
    }
    return KnowledgeEmbedding(**payload)
@pytest.fixture
def sample_ingest_request():
    """Build a representative IngestRequest for ingestion tests."""
    from models import ChunkType, FileType, IngestRequest

    request_fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "content": "def hello():\n print('Hello')",
        "source_path": "/test/hello.py",
        "collection": "default",
        "chunk_type": ChunkType.CODE,
        "file_type": FileType.PYTHON,
        "metadata": {"test": True},
    }
    return IngestRequest(**request_fields)
@pytest.fixture
def sample_search_request():
    """Build a hybrid SearchRequest with typical defaults for tests."""
    from models import SearchRequest, SearchType

    request_fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "query": "hello function",
        "search_type": SearchType.HYBRID,
        "collection": "default",
        "limit": 10,
        "threshold": 0.7,
    }
    return SearchRequest(**request_fields)
@pytest.fixture
def sample_delete_request():
    """Build a DeleteRequest that targets a single source path."""
    from models import DeleteRequest

    request_fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "source_path": "/test/hello.py",
    }
    return DeleteRequest(**request_fields)

View File

@@ -0,0 +1,422 @@
"""Tests for chunking module."""
class TestBaseChunker:
    """Tests for behavior shared by all chunkers, driven via TextChunker."""

    def test_count_tokens(self, settings):
        """count_tokens returns a small positive count for a short string."""
        from chunking.text import TextChunker

        text_chunker = TextChunker(
            chunk_size=400,
            chunk_overlap=50,
            settings=settings,
        )
        token_total = text_chunker.count_tokens("Hello, world!")
        # A short phrase tokenizes to a handful of tokens (about 3-4).
        assert 0 < token_total < 10

    def test_truncate_to_tokens(self, settings):
        """truncate_to_tokens caps text at the requested token budget."""
        from chunking.text import TextChunker

        text_chunker = TextChunker(
            chunk_size=400,
            chunk_overlap=50,
            settings=settings,
        )
        shortened = text_chunker.truncate_to_tokens("word " * 1000, 10)
        assert text_chunker.count_tokens(shortened) <= 10
class TestCodeChunker:
    """Tests for code chunker."""

    def test_chunk_python_code(self, settings, sample_python_code):
        """Test chunking Python code."""
        from chunking.code import CodeChunker
        from models import ChunkType, FileType

        chunker = CodeChunker(
            chunk_size=500,
            chunk_overlap=50,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=sample_python_code,
            source_path="/test/sample.py",
            file_type=FileType.PYTHON,
        )
        # Every chunk produced from Python source must be tagged as
        # code with the PYTHON file type.
        assert len(chunks) > 0
        assert all(c.chunk_type == ChunkType.CODE for c in chunks)
        assert all(c.file_type == FileType.PYTHON for c in chunks)

    def test_preserves_function_boundaries(self, settings):
        """Test that chunker preserves function boundaries."""
        from chunking.code import CodeChunker
        from models import FileType

        code = '''def function_one():
    """First function."""
    return 1


def function_two():
    """Second function."""
    return 2
'''
        # Small chunk_size forces splitting so boundary behavior is visible.
        chunker = CodeChunker(
            chunk_size=100,
            chunk_overlap=10,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=code,
            source_path="/test/funcs.py",
            file_type=FileType.PYTHON,
        )
        # Each function should ideally be in its own chunk
        assert len(chunks) >= 1
        for chunk in chunks:
            # Check chunks carry consistent line-number metadata.
            assert chunk.start_line is not None
            assert chunk.end_line is not None
            assert chunk.start_line <= chunk.end_line

    def test_handles_empty_content(self, settings):
        """Test handling empty content."""
        from chunking.code import CodeChunker

        chunker = CodeChunker(
            chunk_size=500,
            chunk_overlap=50,
            settings=settings,
        )
        # Empty input yields no chunks rather than raising.
        chunks = chunker.chunk(content="", source_path="/test/empty.py")
        assert chunks == []

    def test_chunk_type_is_code(self, settings):
        """Test that chunk_type property returns CODE."""
        from chunking.code import CodeChunker
        from models import ChunkType

        chunker = CodeChunker(
            chunk_size=500,
            chunk_overlap=50,
            settings=settings,
        )
        assert chunker.chunk_type == ChunkType.CODE
class TestMarkdownChunker:
    """Tests for markdown chunker."""

    def test_chunk_markdown(self, settings, sample_markdown):
        """Test chunking markdown content."""
        from chunking.markdown import MarkdownChunker
        from models import ChunkType, FileType

        chunker = MarkdownChunker(
            chunk_size=800,
            chunk_overlap=100,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=sample_markdown,
            source_path="/test/docs.md",
            file_type=FileType.MARKDOWN,
        )
        # All output chunks carry the MARKDOWN chunk type.
        assert len(chunks) > 0
        assert all(c.chunk_type == ChunkType.MARKDOWN for c in chunks)

    def test_respects_heading_hierarchy(self, settings):
        """Test that chunker respects heading hierarchy."""
        from chunking.markdown import MarkdownChunker

        markdown = '''# Main Title

Introduction paragraph.

## Section One

Content for section one.

### Subsection

More detailed content.

## Section Two

Content for section two.
'''
        # Small chunk_size encourages splitting at section boundaries.
        chunker = MarkdownChunker(
            chunk_size=200,
            chunk_overlap=20,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=markdown,
            source_path="/test/docs.md",
        )
        # Should have multiple chunks based on sections
        assert len(chunks) >= 1
        # Metadata should include heading context
        for chunk in chunks:
            # Chunks should have content
            assert len(chunk.content) > 0

    def test_handles_code_blocks(self, settings):
        """Test handling of code blocks in markdown."""
        from chunking.markdown import MarkdownChunker

        markdown = '''# Code Example

Here's some code:

```python
def hello():
    print("Hello, World!")
```

End of example.
'''
        chunker = MarkdownChunker(
            chunk_size=500,
            chunk_overlap=50,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=markdown,
            source_path="/test/code.md",
        )
        # Code blocks should be preserved somewhere in the output, either
        # with the fence marker intact or at least the code itself.
        assert len(chunks) >= 1
        full_content = " ".join(c.content for c in chunks)
        assert "```python" in full_content or "def hello" in full_content

    def test_chunk_type_is_markdown(self, settings):
        """Test that chunk_type property returns MARKDOWN."""
        from chunking.markdown import MarkdownChunker
        from models import ChunkType

        chunker = MarkdownChunker(
            chunk_size=800,
            chunk_overlap=100,
            settings=settings,
        )
        assert chunker.chunk_type == ChunkType.MARKDOWN
class TestTextChunker:
    """Tests for text chunker."""

    def test_chunk_text(self, settings, sample_text):
        """Test chunking plain text."""
        from chunking.text import TextChunker
        from models import ChunkType

        chunker = TextChunker(
            chunk_size=400,
            chunk_overlap=50,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=sample_text,
            source_path="/test/docs.txt",
        )
        # All output chunks carry the TEXT chunk type.
        assert len(chunks) > 0
        assert all(c.chunk_type == ChunkType.TEXT for c in chunks)

    def test_respects_paragraph_boundaries(self, settings):
        """Test that chunker respects paragraph boundaries."""
        from chunking.text import TextChunker

        text = '''First paragraph with some content.

Second paragraph with different content.

Third paragraph to test chunking behavior.
'''
        # Small chunk_size so paragraph splitting can actually occur.
        chunker = TextChunker(
            chunk_size=100,
            chunk_overlap=10,
            settings=settings,
        )
        chunks = chunker.chunk(
            content=text,
            source_path="/test/text.txt",
        )
        assert len(chunks) >= 1

    def test_handles_single_paragraph(self, settings):
        """Test handling of single paragraph that fits in one chunk."""
        from chunking.text import TextChunker

        text = "This is a short paragraph."
        chunker = TextChunker(
            chunk_size=400,
            chunk_overlap=50,
            settings=settings,
        )
        chunks = chunker.chunk(content=text, source_path="/test/short.txt")
        # Content that fits in one chunk is returned verbatim as a
        # single chunk.
        assert len(chunks) == 1
        assert chunks[0].content == text

    def test_chunk_type_is_text(self, settings):
        """Test that chunk_type property returns TEXT."""
        from chunking.text import TextChunker
        from models import ChunkType

        chunker = TextChunker(
            chunk_size=400,
            chunk_overlap=50,
            settings=settings,
        )
        assert chunker.chunk_type == ChunkType.TEXT
class TestChunkerFactory:
    """Tests for the chunker factory's type dispatch and caching."""

    def test_get_code_chunker(self, settings):
        """PYTHON file type resolves to a CodeChunker."""
        from chunking.base import ChunkerFactory
        from chunking.code import CodeChunker
        from models import FileType

        resolved = ChunkerFactory(settings=settings).get_chunker(
            file_type=FileType.PYTHON
        )
        assert isinstance(resolved, CodeChunker)

    def test_get_markdown_chunker(self, settings):
        """MARKDOWN file type resolves to a MarkdownChunker."""
        from chunking.base import ChunkerFactory
        from chunking.markdown import MarkdownChunker
        from models import FileType

        resolved = ChunkerFactory(settings=settings).get_chunker(
            file_type=FileType.MARKDOWN
        )
        assert isinstance(resolved, MarkdownChunker)

    def test_get_text_chunker(self, settings):
        """TEXT file type resolves to a TextChunker."""
        from chunking.base import ChunkerFactory
        from chunking.text import TextChunker
        from models import FileType

        resolved = ChunkerFactory(settings=settings).get_chunker(
            file_type=FileType.TEXT
        )
        assert isinstance(resolved, TextChunker)

    def test_get_chunker_for_path(self, settings):
        """File extensions map to the matching chunker and FileType."""
        from chunking.base import ChunkerFactory
        from chunking.code import CodeChunker
        from chunking.markdown import MarkdownChunker
        from models import FileType

        factory = ChunkerFactory(settings=settings)

        py_chunker, py_type = factory.get_chunker_for_path("/test/file.py")
        assert isinstance(py_chunker, CodeChunker)
        assert py_type == FileType.PYTHON

        md_chunker, md_type = factory.get_chunker_for_path("/test/docs.md")
        assert isinstance(md_chunker, MarkdownChunker)
        assert md_type == FileType.MARKDOWN

    def test_chunk_content(self, settings, sample_python_code):
        """chunk_content dispatches by path and produces code chunks."""
        from chunking.base import ChunkerFactory
        from models import ChunkType

        produced = ChunkerFactory(settings=settings).chunk_content(
            content=sample_python_code,
            source_path="/test/sample.py",
        )
        assert len(produced) > 0
        assert all(c.chunk_type == ChunkType.CODE for c in produced)

    def test_default_to_text_chunker(self, settings):
        """Without a file type, the factory falls back to TextChunker."""
        from chunking.base import ChunkerFactory
        from chunking.text import TextChunker

        fallback = ChunkerFactory(settings=settings).get_chunker()
        assert isinstance(fallback, TextChunker)

    def test_chunker_caching(self, settings):
        """Repeated lookups for one file type return the same instance."""
        from chunking.base import ChunkerFactory
        from models import FileType

        factory = ChunkerFactory(settings=settings)
        first = factory.get_chunker(file_type=FileType.PYTHON)
        second = factory.get_chunker(file_type=FileType.PYTHON)
        assert first is second
class TestGlobalChunkerFactory:
    """Tests for the module-level chunker-factory singleton helpers."""

    def test_get_chunker_factory_singleton(self):
        """After a reset, repeated lookups yield one shared instance."""
        from chunking.base import get_chunker_factory, reset_chunker_factory

        reset_chunker_factory()
        assert get_chunker_factory() is get_chunker_factory()

    def test_reset_chunker_factory(self):
        """Resetting discards the cached instance so a new one is built."""
        from chunking.base import get_chunker_factory, reset_chunker_factory

        before_reset = get_chunker_factory()
        reset_chunker_factory()
        after_reset = get_chunker_factory()
        assert before_reset is not after_reset

View File

@@ -0,0 +1,240 @@
"""Tests for collection manager module."""
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
class TestCollectionManager:
    """Tests for CollectionManager class.

    Every collaborator (database, embeddings, chunker factory) is a
    mock, so these tests exercise orchestration logic only.
    """

    @pytest.fixture
    def collection_manager(self, settings, mock_database, mock_embeddings):
        """Create collection manager with mocks."""
        from chunking.base import ChunkerFactory
        from collection_manager import CollectionManager

        mock_chunker_factory = MagicMock(spec=ChunkerFactory)
        # Mock chunk_content to return chunks
        from models import Chunk, ChunkType

        # Chunking always yields exactly one code chunk, so ingest
        # counts in the tests below are deterministic.
        mock_chunker_factory.chunk_content.return_value = [
            Chunk(
                content="def hello(): pass",
                chunk_type=ChunkType.CODE,
                token_count=10,
            )
        ]
        return CollectionManager(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
            chunker_factory=mock_chunker_factory,
        )

    @pytest.mark.asyncio
    async def test_ingest_content(self, collection_manager, sample_ingest_request):
        """Test content ingestion."""
        result = await collection_manager.ingest(sample_ingest_request)
        # One chunk from the mocked factory -> one embedding, one id.
        assert result.success is True
        assert result.chunks_created == 1
        assert result.embeddings_generated == 1
        assert len(result.chunk_ids) == 1
        assert result.collection == "default"

    @pytest.mark.asyncio
    async def test_ingest_empty_content(self, collection_manager):
        """Test ingesting empty content."""
        from models import IngestRequest

        # Mock chunker to return empty list
        collection_manager._chunker_factory.chunk_content.return_value = []
        request = IngestRequest(
            project_id="proj-123",
            agent_id="agent-456",
            content="",
        )
        result = await collection_manager.ingest(request)
        # Empty input is a no-op success, not an error.
        assert result.success is True
        assert result.chunks_created == 0
        assert result.embeddings_generated == 0

    @pytest.mark.asyncio
    async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
        """Test ingest error handling."""
        # Make embedding generation fail
        collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
        result = await collection_manager.ingest(sample_ingest_request)
        # The failure is reported in the result, not raised.
        assert result.success is False
        assert "Embedding error" in result.error

    @pytest.mark.asyncio
    async def test_delete_by_source(self, collection_manager, sample_delete_request):
        """Test deletion by source path."""
        result = await collection_manager.delete(sample_delete_request)
        assert result.success is True
        assert result.chunks_deleted == 1  # Mock returns 1
        collection_manager._database.delete_by_source.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_collection(self, collection_manager):
        """Test deletion by collection."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            collection="to-delete",
        )
        result = await collection_manager.delete(request)
        assert result.success is True
        collection_manager._database.delete_collection.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_ids(self, collection_manager):
        """Test deletion by chunk IDs."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id-1", "id-2"],
        )
        result = await collection_manager.delete(request)
        assert result.success is True
        collection_manager._database.delete_by_ids.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_no_target(self, collection_manager):
        """Test deletion with no target specified."""
        from models import DeleteRequest

        # No source_path, collection, or chunk_ids: the manager must
        # refuse rather than delete everything.
        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
        )
        result = await collection_manager.delete(request)
        assert result.success is False
        assert "Must specify" in result.error

    @pytest.mark.asyncio
    async def test_list_collections(self, collection_manager):
        """Test listing collections."""
        from models import CollectionInfo

        collection_manager._database.list_collections.return_value = [
            CollectionInfo(
                name="collection-1",
                project_id="proj-123",
                chunk_count=100,
                total_tokens=50000,
                file_types=["python"],
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
            ),
            CollectionInfo(
                name="collection-2",
                project_id="proj-123",
                chunk_count=50,
                total_tokens=25000,
                file_types=["javascript"],
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
            ),
        ]
        result = await collection_manager.list_collections("proj-123")
        assert result.project_id == "proj-123"
        assert result.total_collections == 2
        assert len(result.collections) == 2

    @pytest.mark.asyncio
    async def test_get_collection_stats(self, collection_manager):
        """Test getting collection statistics."""
        from models import CollectionStats

        expected_stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
        )
        collection_manager._database.get_collection_stats.return_value = expected_stats
        stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
        # Stats pass through from the database layer unchanged.
        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        collection_manager._database.get_collection_stats.assert_called_once_with(
            "proj-123", "test-collection"
        )

    @pytest.mark.asyncio
    async def test_update_document(self, collection_manager):
        """Test updating a document."""
        result = await collection_manager.update_document(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
        )
        # Should delete first, then ingest
        collection_manager._database.delete_by_source.assert_called_once()
        assert result.success is True

    @pytest.mark.asyncio
    async def test_cleanup_expired(self, collection_manager):
        """Test cleaning up expired embeddings."""
        collection_manager._database.cleanup_expired.return_value = 10
        count = await collection_manager.cleanup_expired()
        assert count == 10
        collection_manager._database.cleanup_expired.assert_called_once()
class TestGlobalCollectionManager:
    """Tests for the module-level collection-manager singleton helpers."""

    def test_get_collection_manager_singleton(self):
        """After a reset, repeated lookups yield one shared instance."""
        from collection_manager import get_collection_manager, reset_collection_manager

        reset_collection_manager()
        assert get_collection_manager() is get_collection_manager()

    def test_reset_collection_manager(self):
        """Resetting discards the cached instance so a new one is built."""
        from collection_manager import get_collection_manager, reset_collection_manager

        before_reset = get_collection_manager()
        reset_collection_manager()
        after_reset = get_collection_manager()
        assert before_reset is not after_reset

View File

@@ -0,0 +1,104 @@
"""Tests for configuration module."""
import os
class TestSettings:
    """Tests for Settings class."""

    def test_default_values(self, settings):
        """Default configuration matches the documented service setup."""
        assert settings.port == 8002
        assert settings.embedding_dimension == 1536
        assert settings.code_chunk_size == 500
        assert settings.search_default_limit == 10

    def test_env_prefix(self):
        """Settings reads values from KB_-prefixed environment variables."""
        from config import Settings, reset_settings

        reset_settings()
        os.environ["KB_PORT"] = "9999"
        try:
            settings = Settings()
            assert settings.port == 9999
        finally:
            # Restore the environment even if the assertion fails, so a
            # failure here cannot leak KB_PORT into later tests.
            del os.environ["KB_PORT"]
            reset_settings()

    def test_embedding_settings(self, settings):
        """Embedding model, batch size, and cache TTL defaults."""
        assert settings.embedding_model == "text-embedding-3-large"
        assert settings.embedding_batch_size == 100
        assert settings.embedding_cache_ttl == 86400

    def test_chunking_settings(self, settings):
        """Per-content-type chunk size and overlap defaults."""
        assert settings.code_chunk_size == 500
        assert settings.code_chunk_overlap == 50
        assert settings.markdown_chunk_size == 800
        assert settings.markdown_chunk_overlap == 100
        assert settings.text_chunk_size == 400
        assert settings.text_chunk_overlap == 50

    def test_search_settings(self, settings):
        """Search limits, similarity threshold, and hybrid weights."""
        assert settings.search_default_limit == 10
        assert settings.search_max_limit == 100
        assert settings.semantic_threshold == 0.7
        assert settings.hybrid_semantic_weight == 0.7
        assert settings.hybrid_keyword_weight == 0.3
class TestGetSettings:
    """Tests for the get_settings singleton accessor."""

    def test_returns_singleton(self):
        """After a reset, repeated calls return one shared Settings."""
        from config import get_settings, reset_settings

        reset_settings()
        assert get_settings() is get_settings()

    def test_reset_creates_new_instance(self):
        """reset_settings discards the cached instance."""
        from config import get_settings, reset_settings

        before_reset = get_settings()
        reset_settings()
        after_reset = get_settings()
        assert before_reset is not after_reset
class TestIsTestMode:
    """Tests for is_test_mode function."""

    def test_returns_true_when_set(self):
        """is_test_mode is True when IS_TEST is present in the environment."""
        from config import is_test_mode

        old_value = os.environ.get("IS_TEST")
        os.environ["IS_TEST"] = "true"
        try:
            assert is_test_mode() is True
        finally:
            # Restore in a finally so a failed assertion cannot leak
            # IS_TEST into later tests. Compare against None (not
            # truthiness) so an empty-string previous value is restored
            # rather than dropped.
            if old_value is not None:
                os.environ["IS_TEST"] = old_value
            else:
                del os.environ["IS_TEST"]

    def test_returns_false_when_not_set(self):
        """is_test_mode is False when IS_TEST is absent."""
        from config import is_test_mode

        # pop() both reads and removes in one step; None means "was unset".
        old_value = os.environ.pop("IS_TEST", None)
        try:
            assert is_test_mode() is False
        finally:
            if old_value is not None:
                os.environ["IS_TEST"] = old_value

View File

@@ -0,0 +1,245 @@
"""Tests for embedding generation module."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestEmbeddingGenerator:
    """Tests for EmbeddingGenerator class.

    The generator's Redis cache and HTTP client are replaced with mocks;
    responses mimic an MCP-style payload whose content text is a JSON
    document holding an "embeddings" list.
    """

    @pytest.fixture
    def mock_http_response(self):
        """Create mock HTTP response."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        # Single 1536-dim embedding wrapped in the expected envelope.
        response.json.return_value = {
            "result": {
                "content": [
                    {
                        "text": json.dumps({
                            "embeddings": [[0.1] * 1536]
                        })
                    }
                ]
            }
        }
        return response

    @pytest.mark.asyncio
    async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
        """Test generating a single embedding."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis
        # Mock HTTP client
        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client
        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )
        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Test generating batch embeddings."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis
        # Mock HTTP client with batch response
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.raise_for_status = MagicMock()
        # Three distinct embeddings, one per input text.
        mock_response.json.return_value = {
            "result": {
                "content": [
                    {
                        "text": json.dumps({
                            "embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
                        })
                    }
                ]
            }
        }
        mock_client.post.return_value = mock_response
        generator._http_client = mock_client
        embeddings = await generator.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )
        assert len(embeddings) == 3
        assert all(len(e) == 1536 for e in embeddings)

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """Test embedding caching."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis
        # Pre-populate cache
        cache_key = generator._cache_key("Hello, world!")
        await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))
        # Mock HTTP client (should not be called)
        mock_client = AsyncMock()
        generator._http_client = mock_client
        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )
        # Should return cached embedding
        assert len(embedding) == 1536
        assert embedding[0] == 0.5
        # A cache hit must never trigger an HTTP round trip.
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Test embedding cache miss."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis
        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client
        embedding = await generator.generate(
            text="New text not in cache",
            project_id="proj-123",
            agent_id="agent-456",
        )
        # A miss falls through to exactly one HTTP call.
        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Test cache key generation."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        key1 = generator._cache_key("Hello")
        key2 = generator._cache_key("Hello")
        key3 = generator._cache_key("World")
        # Keys are deterministic per text, distinct across texts, and
        # namespaced under the kb:emb: prefix.
        assert key1 == key2
        assert key1 != key3
        assert key1.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """Test embedding dimension validation."""
        from embeddings import EmbeddingGenerator
        from exceptions import EmbeddingDimensionMismatchError

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis
        # Mock HTTP client with wrong dimension
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.raise_for_status = MagicMock()
        mock_response.json.return_value = {
            "result": {
                "content": [
                    {
                        "text": json.dumps({
                            "embeddings": [[0.1] * 768]  # Wrong dimension
                        })
                    }
                ]
            }
        }
        mock_client.post.return_value = mock_response
        generator._http_client = mock_client
        # A 768-dim vector against a 1536-dim config must raise.
        with pytest.raises(EmbeddingDimensionMismatchError):
            await generator.generate(
                text="Test text",
                project_id="proj-123",
                agent_id="agent-456",
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """Test generating embeddings for empty batch."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis
        embeddings = await generator.generate_batch(
            texts=[],
            project_id="proj-123",
            agent_id="agent-456",
        )
        # Empty input short-circuits to an empty result.
        assert embeddings == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """Test initialize and close methods."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        # Mock successful initialization
        with patch("embeddings.redis.from_url") as mock_redis_from_url:
            mock_redis_client = AsyncMock()
            mock_redis_client.ping = AsyncMock()
            mock_redis_from_url.return_value = mock_redis_client
            await generator.initialize()
            assert generator._http_client is not None
            await generator.close()
            # close() releases the HTTP client reference.
            assert generator._http_client is None
class TestGlobalEmbeddingGenerator:
    """Tests for the module-level embedding-generator singleton helpers."""

    def test_get_embedding_generator_singleton(self):
        """After a reset, repeated lookups yield one shared instance."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        reset_embedding_generator()
        assert get_embedding_generator() is get_embedding_generator()

    def test_reset_embedding_generator(self):
        """Resetting discards the cached instance so a new one is built."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        before_reset = get_embedding_generator()
        reset_embedding_generator()
        after_reset = get_embedding_generator()
        assert before_reset is not after_reset

View File

@@ -0,0 +1,307 @@
"""Tests for exception classes."""
class TestErrorCode:
    """Tests for ErrorCode enum."""

    def test_error_code_values(self):
        """Each error code serializes with the KB_ prefix convention."""
        from exceptions import ErrorCode

        expected_values = {
            ErrorCode.UNKNOWN_ERROR: "KB_UNKNOWN_ERROR",
            ErrorCode.DATABASE_CONNECTION_ERROR: "KB_DATABASE_CONNECTION_ERROR",
            ErrorCode.EMBEDDING_GENERATION_ERROR: "KB_EMBEDDING_GENERATION_ERROR",
            ErrorCode.CHUNKING_ERROR: "KB_CHUNKING_ERROR",
            ErrorCode.SEARCH_ERROR: "KB_SEARCH_ERROR",
            ErrorCode.COLLECTION_NOT_FOUND: "KB_COLLECTION_NOT_FOUND",
            ErrorCode.DOCUMENT_NOT_FOUND: "KB_DOCUMENT_NOT_FOUND",
        }
        for code, value in expected_values.items():
            assert code.value == value
class TestKnowledgeBaseError:
    """Tests for the KnowledgeBaseError base exception."""

    def test_basic_error(self):
        """A minimal error carries message and code with empty extras."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Something went wrong",
            code=ErrorCode.UNKNOWN_ERROR,
        )
        assert exc.message == "Something went wrong"
        assert exc.code == ErrorCode.UNKNOWN_ERROR
        assert exc.details == {}
        assert exc.cause is None

    def test_error_with_details(self):
        """Arbitrary detail entries are stored verbatim."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Query failed",
            code=ErrorCode.DATABASE_QUERY_ERROR,
            details={"query": "SELECT * FROM table", "error_code": 42},
        )
        stored = exc.details
        assert stored["query"] == "SELECT * FROM table"
        assert stored["error_code"] == 42

    def test_error_with_cause(self):
        """The underlying exception is kept as-is on .cause."""
        from exceptions import ErrorCode, KnowledgeBaseError

        root_cause = ValueError("Original error")
        exc = KnowledgeBaseError(
            message="Wrapped error",
            code=ErrorCode.INTERNAL_ERROR,
            cause=root_cause,
        )
        assert exc.cause is root_cause
        assert isinstance(exc.cause, ValueError)

    def test_to_dict(self):
        """to_dict exposes code string, message, and details."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"field": "value"},
        )
        payload = exc.to_dict()
        assert payload["error"] == "KB_INVALID_REQUEST"
        assert payload["message"] == "Test error"
        assert payload["details"]["field"] == "value"

    def test_str_representation(self):
        """str() formats as '[CODE] message'."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
        )
        assert str(exc) == "[KB_INVALID_REQUEST] Test error"

    def test_repr_representation(self):
        """repr() names the class and includes message and code."""
        from exceptions import ErrorCode, KnowledgeBaseError

        exc = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"key": "value"},
        )
        rendered = repr(exc)
        for fragment in ("KnowledgeBaseError", "Test error", "KB_INVALID_REQUEST"):
            assert fragment in rendered
class TestDatabaseErrors:
    """Tests for database-related exceptions."""

    def test_database_connection_error(self):
        """Connection errors carry their code and detail entries."""
        from exceptions import DatabaseConnectionError, ErrorCode

        exc = DatabaseConnectionError(
            message="Cannot connect to database",
            details={"host": "localhost", "port": 5432},
        )
        assert exc.code == ErrorCode.DATABASE_CONNECTION_ERROR
        assert exc.details["host"] == "localhost"

    def test_database_connection_error_default_message(self):
        """Without arguments, a sensible default message is used."""
        from exceptions import DatabaseConnectionError

        assert DatabaseConnectionError().message == "Failed to connect to database"

    def test_database_query_error(self):
        """Query errors record the offending SQL in details."""
        from exceptions import DatabaseQueryError, ErrorCode

        exc = DatabaseQueryError(
            message="Query failed",
            query="SELECT * FROM missing_table",
        )
        assert exc.code == ErrorCode.DATABASE_QUERY_ERROR
        assert exc.details["query"] == "SELECT * FROM missing_table"
class TestEmbeddingErrors:
    """Tests for embedding-related exceptions."""

    def test_embedding_generation_error(self):
        """Generation errors record how many texts were involved."""
        from exceptions import EmbeddingGenerationError, ErrorCode

        exc = EmbeddingGenerationError(
            message="Failed to generate",
            texts_count=10,
        )
        assert exc.code == ErrorCode.EMBEDDING_GENERATION_ERROR
        assert exc.details["texts_count"] == 10

    def test_embedding_dimension_mismatch(self):
        """Mismatch errors embed both dimensions in message and details."""
        from exceptions import EmbeddingDimensionMismatchError, ErrorCode

        exc = EmbeddingDimensionMismatchError(
            expected=1536,
            actual=768,
        )
        assert exc.code == ErrorCode.EMBEDDING_DIMENSION_MISMATCH
        for fragment in ("expected 1536", "got 768"):
            assert fragment in exc.message
        assert exc.details["expected_dimension"] == 1536
        assert exc.details["actual_dimension"] == 768
class TestChunkingErrors:
    """Tests for chunking-related exceptions."""

    def test_unsupported_file_type_error(self):
        """Unsupported-type errors list the rejected and allowed types."""
        from exceptions import ErrorCode, UnsupportedFileTypeError

        exc = UnsupportedFileTypeError(
            file_type=".xyz",
            supported_types=[".py", ".js", ".md"],
        )
        assert exc.code == ErrorCode.UNSUPPORTED_FILE_TYPE
        assert exc.details["file_type"] == ".xyz"
        assert len(exc.details["supported_types"]) == 3

    def test_file_too_large_error(self):
        """Size errors record both the actual and the maximum size."""
        from exceptions import ErrorCode, FileTooLargeError

        exc = FileTooLargeError(
            file_size=10_000_000,
            max_size=1_000_000,
        )
        assert exc.code == ErrorCode.FILE_TOO_LARGE
        assert exc.details["file_size"] == 10_000_000
        assert exc.details["max_size"] == 1_000_000

    def test_encoding_error(self):
        """Encoding errors record the codec that failed."""
        from exceptions import EncodingError, ErrorCode

        exc = EncodingError(
            message="Cannot decode file",
            encoding="utf-8",
        )
        assert exc.code == ErrorCode.ENCODING_ERROR
        assert exc.details["encoding"] == "utf-8"
class TestSearchErrors:
    """Tests for search-related exceptions."""

    def test_invalid_search_type_error(self):
        """Invalid-type errors record the bad value and the valid set."""
        from exceptions import ErrorCode, InvalidSearchTypeError

        exc = InvalidSearchTypeError(
            search_type="invalid",
            valid_types=["semantic", "keyword", "hybrid"],
        )
        assert exc.code == ErrorCode.INVALID_SEARCH_TYPE
        assert exc.details["search_type"] == "invalid"
        assert len(exc.details["valid_types"]) == 3

    def test_search_timeout_error(self):
        """Timeout errors surface the limit in details and message."""
        from exceptions import ErrorCode, SearchTimeoutError

        exc = SearchTimeoutError(timeout=30.0)
        assert exc.code == ErrorCode.SEARCH_TIMEOUT
        assert exc.details["timeout"] == 30.0
        assert "30" in exc.message
class TestCollectionErrors:
    """Exceptions for collection lookup failures."""

    def test_collection_not_found_error(self):
        """CollectionNotFoundError records both the collection name and project."""
        from exceptions import CollectionNotFoundError, ErrorCode

        err = CollectionNotFoundError(collection="missing-collection", project_id="proj-123")
        assert err.code == ErrorCode.COLLECTION_NOT_FOUND
        assert err.details["collection"] == "missing-collection"
        assert err.details["project_id"] == "proj-123"
class TestDocumentErrors:
    """Tests for document-related exceptions."""

    def test_document_not_found_error(self):
        """Test document not found error."""
        from exceptions import DocumentNotFoundError, ErrorCode

        error = DocumentNotFoundError(
            source_path="/path/to/file.py",
            project_id="proj-123",
        )
        assert error.code == ErrorCode.DOCUMENT_NOT_FOUND
        assert error.details["source_path"] == "/path/to/file.py"

    def test_invalid_document_error(self):
        """Test invalid document error."""
        from exceptions import ErrorCode, InvalidDocumentError

        error = InvalidDocumentError(
            message="Empty content",
            details={"reason": "no content"},
        )
        assert error.code == ErrorCode.INVALID_DOCUMENT
        # Previously only the code was checked; the message and the
        # explicitly-supplied details dict were never asserted, so a
        # regression dropping either would have gone unnoticed.
        assert error.message == "Empty content"
        assert error.details["reason"] == "no content"
class TestProjectErrors:
    """Exceptions for project lookup and access control."""

    def test_project_not_found_error(self):
        """ProjectNotFoundError records the missing project id."""
        from exceptions import ErrorCode, ProjectNotFoundError

        err = ProjectNotFoundError(project_id="missing-proj")
        assert err.code == ErrorCode.PROJECT_NOT_FOUND
        assert err.details["project_id"] == "missing-proj"

    def test_project_access_denied_error(self):
        """ProjectAccessDeniedError names the project in its message."""
        from exceptions import ErrorCode, ProjectAccessDeniedError

        err = ProjectAccessDeniedError(project_id="restricted-proj")
        assert err.code == ErrorCode.PROJECT_ACCESS_DENIED
        assert "restricted-proj" in err.message

View File

@@ -0,0 +1,347 @@
"""Tests for data models."""
from datetime import UTC, datetime
class TestEnums:
    """String values of the public enum types."""

    def test_search_type_values(self):
        """Each SearchType member serializes to its lowercase name."""
        from models import SearchType

        expected = {
            SearchType.SEMANTIC: "semantic",
            SearchType.KEYWORD: "keyword",
            SearchType.HYBRID: "hybrid",
        }
        for member, value in expected.items():
            assert member.value == value

    def test_chunk_type_values(self):
        """Each ChunkType member serializes to its lowercase name."""
        from models import ChunkType

        expected = {
            ChunkType.CODE: "code",
            ChunkType.MARKDOWN: "markdown",
            ChunkType.TEXT: "text",
            ChunkType.DOCUMENTATION: "documentation",
        }
        for member, value in expected.items():
            assert member.value == value

    def test_file_type_values(self):
        """Each FileType member serializes to its lowercase name."""
        from models import FileType

        expected = {
            FileType.PYTHON: "python",
            FileType.JAVASCRIPT: "javascript",
            FileType.TYPESCRIPT: "typescript",
            FileType.MARKDOWN: "markdown",
        }
        for member, value in expected.items():
            assert member.value == value
class TestFileExtensionMap:
    """The extension-to-FileType lookup table."""

    def test_python_extensions(self):
        """`.py` maps to Python."""
        from models import FILE_EXTENSION_MAP, FileType

        assert FILE_EXTENSION_MAP[".py"] == FileType.PYTHON

    def test_javascript_extensions(self):
        """Both `.js` and `.jsx` map to JavaScript."""
        from models import FILE_EXTENSION_MAP, FileType

        for ext in (".js", ".jsx"):
            assert FILE_EXTENSION_MAP[ext] == FileType.JAVASCRIPT

    def test_typescript_extensions(self):
        """Both `.ts` and `.tsx` map to TypeScript."""
        from models import FILE_EXTENSION_MAP, FileType

        for ext in (".ts", ".tsx"):
            assert FILE_EXTENSION_MAP[ext] == FileType.TYPESCRIPT

    def test_markdown_extensions(self):
        """Both `.md` and `.mdx` map to Markdown."""
        from models import FILE_EXTENSION_MAP, FileType

        for ext in (".md", ".mdx"):
            assert FILE_EXTENSION_MAP[ext] == FileType.MARKDOWN
class TestChunk:
    """Behaviour of the Chunk dataclass and its serialization."""

    def test_chunk_creation(self, sample_chunk):
        """The chunk fixture exposes all expected field values."""
        from models import ChunkType, FileType

        assert sample_chunk.content == "def hello():\n print('Hello')"
        assert sample_chunk.chunk_type == ChunkType.CODE
        assert sample_chunk.file_type == FileType.PYTHON
        # Location metadata: source file plus 1-based line span.
        assert (sample_chunk.source_path, sample_chunk.start_line, sample_chunk.end_line) == (
            "/test/hello.py",
            1,
            2,
        )
        assert sample_chunk.token_count == 15

    def test_chunk_to_dict(self, sample_chunk):
        """to_dict() serializes enums to their string values."""
        serialized = sample_chunk.to_dict()
        expected = {
            "content": "def hello():\n print('Hello')",
            "chunk_type": "code",
            "file_type": "python",
            "source_path": "/test/hello.py",
            "start_line": 1,
            "end_line": 2,
            "token_count": 15,
        }
        for key, value in expected.items():
            assert serialized[key] == value
class TestKnowledgeEmbedding:
    """Behaviour of the KnowledgeEmbedding dataclass."""

    def test_embedding_creation(self, sample_embedding):
        """The embedding fixture carries identity fields and a full-size vector."""
        assert (
            sample_embedding.id,
            sample_embedding.project_id,
            sample_embedding.collection,
        ) == ("test-id-123", "proj-123", "default")
        assert len(sample_embedding.embedding) == 1536

    def test_embedding_to_dict(self, sample_embedding):
        """to_dict() serializes metadata but omits the raw vector."""
        serialized = sample_embedding.to_dict()
        expected = {
            "id": "test-id-123",
            "project_id": "proj-123",
            "collection": "default",
            "chunk_type": "code",
            "file_type": "python",
        }
        for key, value in expected.items():
            assert serialized[key] == value
        # The 1536-float vector would bloat payloads, so it is dropped.
        assert "embedding" not in serialized
class TestIngestRequest:
    """Validation of the IngestRequest model."""

    def test_ingest_request_creation(self, sample_ingest_request):
        """The fixture request exposes the expected field values."""
        from models import ChunkType, FileType

        assert sample_ingest_request.project_id == "proj-123"
        assert sample_ingest_request.agent_id == "agent-456"
        assert sample_ingest_request.chunk_type == ChunkType.CODE
        assert sample_ingest_request.file_type == FileType.PYTHON
        assert sample_ingest_request.collection == "default"

    def test_ingest_request_defaults(self):
        """A minimal request falls back to sensible defaults."""
        from models import ChunkType, IngestRequest

        req = IngestRequest(project_id="proj-123", agent_id="agent-456", content="test content")
        assert req.collection == "default"
        assert req.chunk_type == ChunkType.TEXT
        assert req.file_type is None
        assert req.metadata == {}
class TestIngestResult:
    """Success and failure shapes of IngestResult."""

    def test_successful_result(self):
        """A successful result has counts, chunk ids and no error."""
        from models import IngestResult

        res = IngestResult(
            success=True,
            chunks_created=5,
            embeddings_generated=5,
            source_path="/test/file.py",
            collection="default",
            chunk_ids=["id1", "id2", "id3", "id4", "id5"],
        )
        assert res.success is True
        assert res.chunks_created == 5
        assert res.error is None

    def test_failed_result(self):
        """A failed result carries zero counts and an error message."""
        from models import IngestResult

        res = IngestResult(
            success=False,
            chunks_created=0,
            embeddings_generated=0,
            collection="default",
            chunk_ids=[],
            error="Something went wrong",
        )
        assert res.success is False
        assert res.error == "Something went wrong"
class TestSearchRequest:
    """Validation of the SearchRequest model."""

    def test_search_request_creation(self, sample_search_request):
        """The fixture request exposes the expected field values."""
        from models import SearchType

        assert sample_search_request.project_id == "proj-123"
        assert sample_search_request.query == "hello function"
        assert sample_search_request.search_type == SearchType.HYBRID
        assert sample_search_request.limit == 10
        assert sample_search_request.threshold == 0.7

    def test_search_request_defaults(self):
        """A minimal request defaults to hybrid search with standard limits."""
        from models import SearchRequest, SearchType

        req = SearchRequest(project_id="proj-123", agent_id="agent-456", query="test query")
        assert req.search_type == SearchType.HYBRID
        assert req.collection is None
        assert req.limit == 10
        assert req.threshold == 0.7
        assert req.file_types is None
class TestSearchResult:
    """Construction of SearchResult from stored embeddings."""

    def test_from_embedding(self, sample_embedding):
        """from_embedding copies metadata and attaches the given score."""
        from models import SearchResult

        res = SearchResult.from_embedding(sample_embedding, 0.95)
        assert res.id == "test-id-123"
        assert res.content == "def hello():\n print('Hello')"
        assert res.score == 0.95
        assert res.source_path == "/test/hello.py"
        # Enum fields come through as plain strings.
        assert (res.chunk_type, res.file_type) == ("code", "python")
class TestSearchResponse:
    """Assembly of a SearchResponse from individual results."""

    def test_search_response(self):
        """A response keeps its query, results and timing."""
        from models import SearchResponse, SearchResult

        specs = [
            ("id1", "test content 1", 0.95, "code"),
            ("id2", "test content 2", 0.85, "text"),
        ]
        results = [
            SearchResult(id=rid, content=text, score=score, chunk_type=kind, collection="default")
            for rid, text, score, kind in specs
        ]
        response = SearchResponse(
            query="test query",
            search_type="hybrid",
            results=results,
            total_results=2,
            search_time_ms=15.5,
        )
        assert response.query == "test query"
        assert len(response.results) == 2
        assert response.search_time_ms == 15.5
class TestDeleteRequest:
    """The three mutually independent deletion selectors."""

    def test_delete_by_source(self, sample_delete_request):
        """Deleting by source path leaves the other selectors unset."""
        assert sample_delete_request.project_id == "proj-123"
        assert sample_delete_request.source_path == "/test/hello.py"
        assert sample_delete_request.collection is None
        assert sample_delete_request.chunk_ids is None

    def test_delete_by_collection(self):
        """Deleting by collection leaves source_path unset."""
        from models import DeleteRequest

        req = DeleteRequest(project_id="proj-123", agent_id="agent-456", collection="to-delete")
        assert req.collection == "to-delete"
        assert req.source_path is None

    def test_delete_by_ids(self):
        """Deleting by explicit chunk ids keeps the full id list."""
        from models import DeleteRequest

        req = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id1", "id2", "id3"],
        )
        assert len(req.chunk_ids) == 3
class TestCollectionInfo:
    """Construction of the CollectionInfo summary model."""

    def test_collection_info(self):
        """All fields round-trip through the constructor."""
        from models import CollectionInfo

        now = datetime.now(UTC)
        info = CollectionInfo(
            name="test-collection",
            project_id="proj-123",
            chunk_count=100,
            total_tokens=50000,
            file_types=["python", "javascript"],
            created_at=now,
            updated_at=now,
        )
        assert info.name == "test-collection"
        assert info.chunk_count == 100
        assert len(info.file_types) == 2
class TestCollectionStats:
    """Construction of the CollectionStats model."""

    def test_collection_stats(self):
        """Counts, breakdowns and timestamps round-trip through the constructor."""
        from models import CollectionStats

        now = datetime.now(UTC)
        stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
            oldest_chunk=now,
            newest_chunk=now,
        )
        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        assert stats.chunk_types["code"] == 60

View File

@@ -0,0 +1,295 @@
"""Tests for search module."""
from datetime import UTC, datetime
import pytest
class TestSearchEngine:
    """Tests for SearchEngine class.

    The engine is built on mocked database and embeddings backends, so
    these tests cover only orchestration: dispatch by search type,
    filter pass-through, limits, timing, and error handling.
    """
    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create search engine with mocks."""
        from search import SearchEngine
        engine = SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )
        return engine
    @pytest.fixture
    def sample_db_results(self):
        """Create sample database results as (embedding, score) pairs."""
        from models import ChunkType, FileType, KnowledgeEmbedding
        return [
            (
                KnowledgeEmbedding(
                    id="id-1",
                    project_id="proj-123",
                    collection="default",
                    content="def hello(): pass",
                    embedding=[0.1] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path="/test/file.py",
                    file_type=FileType.PYTHON,
                    created_at=datetime.now(UTC),
                    updated_at=datetime.now(UTC),
                ),
                0.95,
            ),
            (
                KnowledgeEmbedding(
                    id="id-2",
                    project_id="proj-123",
                    collection="default",
                    content="def world(): pass",
                    embedding=[0.2] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path="/test/file2.py",
                    file_type=FileType.PYTHON,
                    created_at=datetime.now(UTC),
                    updated_at=datetime.now(UTC),
                ),
                0.85,
            ),
        ]
    @pytest.mark.asyncio
    async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
        """Semantic requests are routed to the database's semantic_search."""
        from models import SearchType
        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results
        response = await search_engine.search(sample_search_request)
        assert response.search_type == "semantic"
        assert len(response.results) == 2
        # Scores from the database are carried through unchanged.
        assert response.results[0].score == 0.95
        search_engine._database.semantic_search.assert_called_once()
    @pytest.mark.asyncio
    async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
        """Keyword requests are routed to the database's keyword_search."""
        from models import SearchType
        sample_search_request.search_type = SearchType.KEYWORD
        search_engine._database.keyword_search.return_value = sample_db_results
        response = await search_engine.search(sample_search_request)
        assert response.search_type == "keyword"
        assert len(response.results) == 2
        search_engine._database.keyword_search.assert_called_once()
    @pytest.mark.asyncio
    async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
        """Hybrid requests run both backends and fuse the results."""
        from models import SearchType
        sample_search_request.search_type = SearchType.HYBRID
        # Both searches return same results for simplicity
        search_engine._database.semantic_search.return_value = sample_db_results
        search_engine._database.keyword_search.return_value = sample_db_results
        response = await search_engine.search(sample_search_request)
        assert response.search_type == "hybrid"
        # Results should be fused
        assert len(response.results) >= 1
    @pytest.mark.asyncio
    async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
        """The collection filter is forwarded to the database layer."""
        from models import SearchType
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.collection = "specific-collection"
        search_engine._database.semantic_search.return_value = sample_db_results
        await search_engine.search(sample_search_request)
        # Verify collection was passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["collection"] == "specific-collection"
    @pytest.mark.asyncio
    async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
        """The file-type filter is forwarded to the database layer."""
        from models import FileType, SearchType
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.file_types = [FileType.PYTHON]
        search_engine._database.semantic_search.return_value = sample_db_results
        await search_engine.search(sample_search_request)
        # Verify file types were passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["file_types"] == [FileType.PYTHON]
    @pytest.mark.asyncio
    async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
        """The response never contains more results than the requested limit."""
        from models import SearchType
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.limit = 1
        search_engine._database.semantic_search.return_value = sample_db_results[:1]
        response = await search_engine.search(sample_search_request)
        assert len(response.results) <= 1
    @pytest.mark.asyncio
    async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
        """Every response carries a positive search_time_ms measurement."""
        from models import SearchType
        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results
        response = await search_engine.search(sample_search_request)
        assert response.search_time_ms > 0
    @pytest.mark.asyncio
    async def test_invalid_search_type(self, search_engine, sample_search_request):
        """An unknown search type raises rather than silently falling back."""
        from exceptions import InvalidSearchTypeError
        # Force invalid search type
        sample_search_request.search_type = "invalid"
        # Either the domain error or a plain ValueError is acceptable here.
        with pytest.raises((InvalidSearchTypeError, ValueError)):
            await search_engine.search(sample_search_request)
    @pytest.mark.asyncio
    async def test_empty_results(self, search_engine, sample_search_request):
        """An empty database result set yields an empty, well-formed response."""
        from models import SearchType
        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = []
        response = await search_engine.search(sample_search_request)
        assert len(response.results) == 0
        assert response.total_results == 0
class TestReciprocalRankFusion:
    """Tests for reciprocal rank fusion (RRF) of semantic + keyword results."""
    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create search engine with mocks."""
        from search import SearchEngine
        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )
    def test_fusion_combines_results(self, search_engine):
        """Test that RRF combines results from both searches."""
        from models import SearchResult
        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        # "b" appears in both lists, so RRF should boost it.
        keyword = [
            SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]
        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
        # Should have all unique results
        ids = [r.id for r in fused]
        assert "a" in ids
        assert "b" in ids
        assert "c" in ids
        # B should be ranked higher (appears in both)
        b_rank = ids.index("b")
        assert b_rank < 2  # Should be in top 2
    def test_fusion_respects_weights(self, search_engine):
        """Test that RRF respects semantic/keyword weights."""
        from models import SearchResult
        # Same results in same order
        results = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
        ]
        # High semantic weight
        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            results, [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )
        # High keyword weight
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            [], results,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )
        # Both should still return the result
        # (weights scale scores but must never drop a result entirely).
        assert len(fused_semantic_heavy) == 1
        assert len(fused_keyword_heavy) == 1
    def test_fusion_normalizes_scores(self, search_engine):
        """Test that RRF normalizes scores to 0-1."""
        from models import SearchResult
        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        keyword = [
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]
        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
        # All scores should be between 0 and 1
        for result in fused:
            assert 0 <= result.score <= 1
class TestGlobalSearchEngine:
    """Tests for the module-level search engine singleton."""

    def test_get_search_engine_singleton(self):
        """get_search_engine returns the same instance on repeated calls."""
        from search import get_search_engine, reset_search_engine

        reset_search_engine()
        engine1 = get_search_engine()
        engine2 = get_search_engine()
        assert engine1 is engine2

    def test_reset_search_engine(self):
        """reset_search_engine forces a fresh instance on the next call."""
        from search import get_search_engine, reset_search_engine

        # Start from a known state: previously this test used whatever
        # engine an earlier test had left in the module global, making
        # it order-dependent.
        reset_search_engine()
        engine1 = get_search_engine()
        reset_search_engine()
        engine2 = get_search_engine()
        assert engine1 is not engine2

View File

@@ -0,0 +1,357 @@
"""Tests for server module and MCP tools."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
class TestHealthCheck:
    """Tests for health check endpoint."""

    @pytest.mark.asyncio
    async def test_health_check_healthy(self):
        """Health check reports healthy when the database pool responds."""
        import server

        # Create a proper async context manager mock for db.acquire().
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)
        mock_db = MagicMock()
        mock_db._pool = MagicMock()
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None
        mock_db.acquire.return_value = mock_cm

        # Save/restore the module global so the mock cannot leak into
        # other tests (the original version left it installed).
        saved = server._database
        server._database = mock_db
        try:
            result = await server.health_check()
        finally:
            server._database = saved

        assert result["status"] == "healthy"
        assert result["service"] == "knowledge-base"
        assert result["database"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self):
        """Health check reports an uninitialized database instead of crashing."""
        import server

        saved = server._database
        server._database = None
        try:
            result = await server.health_check()
        finally:
            server._database = saved

        assert result["database"] == "not initialized"
class TestSearchKnowledgeTool:
    """Tests for search_knowledge MCP tool.

    NOTE(review): these tests install mocks on the `server` module
    globals without restoring them afterwards — consider fixtures that
    save/restore to keep tests order-independent.
    """
    @pytest.mark.asyncio
    async def test_search_success(self):
        """Test successful search."""
        import server
        from models import SearchResponse, SearchResult
        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test query",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test content",
                        score=0.95,
                        source_path="/test/file.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=10.5,
            )
        )
        server._search = mock_search
        # Call the wrapped function via .fn
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=None,
        )
        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95
    @pytest.mark.asyncio
    async def test_search_invalid_type(self):
        """An unknown search_type is rejected before any search runs."""
        import server
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="invalid",
        )
        assert result["success"] is False
        assert "Invalid search type" in result["error"]
    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self):
        """An unknown entry in file_types is rejected with a clear error."""
        import server
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=["invalid_type"],
        )
        assert result["success"] is False
        assert "Invalid file type" in result["error"]
class TestIngestContentTool:
    """Tests for ingest_content MCP tool."""
    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """Test successful ingestion."""
        import server
        from models import IngestResult
        # Stub the collections manager so no database work happens.
        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=3,
                embeddings_generated=3,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2", "id-3"],
            )
        )
        server._collections = mock_collections
        # Call the undecorated tool function via .fn.
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test/file.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )
        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3
    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self):
        """An unknown chunk_type is rejected before ingestion starts."""
        import server
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            chunk_type="invalid",
        )
        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self):
        """An unknown file_type is rejected before ingestion starts."""
        import server
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type="invalid",
            metadata=None,
        )
        assert result["success"] is False
        assert "Invalid file type" in result["error"]
class TestDeleteContentTool:
    """Tests for the delete_content MCP tool."""

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """A successful delete reports the number of removed chunks."""
        import server
        from models import DeleteResult

        fake_collections = MagicMock()
        fake_collections.delete = AsyncMock(
            return_value=DeleteResult(success=True, chunks_deleted=5)
        )
        server._collections = fake_collections

        # Invoke the undecorated tool function directly via .fn.
        response = await server.delete_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            collection=None,
            chunk_ids=None,
        )
        assert response["success"] is True
        assert response["chunks_deleted"] == 5
class TestListCollectionsTool:
    """Tests for list_collections MCP tool."""
    @pytest.mark.asyncio
    async def test_list_collections_success(self):
        """Test listing collections."""
        import server
        from models import CollectionInfo, ListCollectionsResponse
        # Stub the collections manager with one known collection.
        mock_collections = MagicMock()
        mock_collections.list_collections = AsyncMock(
            return_value=ListCollectionsResponse(
                project_id="proj-123",
                collections=[
                    CollectionInfo(
                        name="collection-1",
                        project_id="proj-123",
                        chunk_count=100,
                        total_tokens=50000,
                        file_types=["python"],
                        created_at=datetime.now(UTC),
                        updated_at=datetime.now(UTC),
                    )
                ],
                total_collections=1,
            )
        )
        server._collections = mock_collections
        result = await server.list_collections.fn(
            project_id="proj-123",
            agent_id="agent-456",
        )
        assert result["success"] is True
        assert result["total_collections"] == 1
        assert len(result["collections"]) == 1
class TestGetCollectionStatsTool:
    """Tests for get_collection_stats MCP tool."""
    @pytest.mark.asyncio
    async def test_get_stats_success(self):
        """Test getting collection stats."""
        import server
        from models import CollectionStats
        # Stub the collections manager with fixed statistics.
        mock_collections = MagicMock()
        mock_collections.get_collection_stats = AsyncMock(
            return_value=CollectionStats(
                collection="test-collection",
                project_id="proj-123",
                chunk_count=100,
                unique_sources=10,
                total_tokens=50000,
                avg_chunk_size=500.0,
                chunk_types={"code": 60, "text": 40},
                file_types={"python": 50, "javascript": 10},
            )
        )
        server._collections = mock_collections
        result = await server.get_collection_stats.fn(
            project_id="proj-123",
            agent_id="agent-456",
            collection="test-collection",
        )
        assert result["success"] is True
        assert result["chunk_count"] == 100
        assert result["unique_sources"] == 10
class TestUpdateDocumentTool:
    """Tests for update_document MCP tool."""
    @pytest.mark.asyncio
    async def test_update_success(self):
        """Test updating a document."""
        import server
        from models import IngestResult
        # Stub the collections manager; an update re-ingests the document.
        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=2,
                embeddings_generated=2,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2"],
            )
        )
        server._collections = mock_collections
        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )
        assert result["success"] is True
        assert result["chunks_created"] == 2
    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self):
        """An unknown chunk_type is rejected before the update runs."""
        import server
        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="test",
            chunk_type="invalid",
        )
        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]

2026
mcp-servers/knowledge-base/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff