forked from cardosofelipe/fast-next-template
Compare commits
7 Commits
18d717e996
...
f6194b3e19
| Author | SHA1 | Date |
|---|---|---|
| | f6194b3e19 | |
| | 6bb376a336 | |
| | cd7a9ccbdf | |
| | 953af52d0e | |
| | e6e98d4ed1 | |
| | ca5f5e3383 | |
| | d0fc7f37ff | |
CLAUDE.md (31 lines changed)
@@ -83,6 +83,37 @@ docs/
 3. **Testing Required**: All code must be tested, aim for >90% coverage
 4. **Code Review**: Must pass multi-agent review before merge
 5. **No Direct Commits**: Never commit directly to `main` or `dev`
+6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
+
+### CRITICAL: Stack Verification Before Merge
+
+**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
+
+Before considering ANY issue complete:
+
+```bash
+# 1. Start the dev stack
+make dev
+
+# 2. Wait for backend to be healthy, check logs
+docker compose -f docker-compose.dev.yml logs backend --tail=100
+
+# 3. Start frontend
+cd frontend && npm run dev
+
+# 4. Verify both are running without errors
+```
+
+**The issue is NOT done if:**
+- Backend crashes on startup (import errors, missing dependencies)
+- Frontend fails to compile or render
+- Health checks fail
+- Any error appears in logs
+
+**Why this matters:**
+- Tests run in isolation and may pass despite broken imports
+- Docker builds cache layers and may hide dependency issues
+- A single `ModuleNotFoundError` renders all test coverage meaningless
+
 ### Common Commands
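The "wait for backend to be healthy" step above can be scripted rather than eyeballed. A minimal sketch, assuming the backend exposes `GET /health` on port 8000 (matching the backend Dockerfile `HEALTHCHECK` below) and that `httpx` is available locally; the function name is illustrative, not part of the repo:

```python
# Poll the backend health endpoint until it answers 200 or a timeout expires.
# Assumes http://localhost:8000/health, as used by the Dockerfile HEALTHCHECK.
import time

import httpx


def wait_for_backend(url: str = "http://localhost:8000/health", timeout: float = 60.0) -> None:
    """Block until the backend reports healthy, or raise after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            response = httpx.get(url, timeout=2.0)
            if response.status_code == 200:
                print("backend is healthy")
                return
        except httpx.HTTPError:
            pass  # backend not up yet; keep polling
        time.sleep(2.0)
    raise RuntimeError(f"backend did not become healthy within {timeout}s")


if __name__ == "__main__":
    wait_for_backend()
```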
@@ -7,7 +7,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     PYTHONPATH=/app \
     UV_COMPILE_BYTECODE=1 \
     UV_LINK_MODE=copy \
-    UV_NO_CACHE=1
+    UV_NO_CACHE=1 \
+    UV_PROJECT_ENVIRONMENT=/opt/venv \
+    VIRTUAL_ENV=/opt/venv \
+    PATH="/opt/venv/bin:$PATH"
 
 # Install system dependencies and uv
 RUN apt-get update && \
@@ -20,7 +23,7 @@ RUN apt-get update && \
 # Copy dependency files
 COPY pyproject.toml uv.lock ./
 
-# Install dependencies using uv (development mode with dev dependencies)
+# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts)
 RUN uv sync --extra dev --frozen
 
 # Copy application code
@@ -45,7 +48,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     PYTHONPATH=/app \
     UV_COMPILE_BYTECODE=1 \
     UV_LINK_MODE=copy \
-    UV_NO_CACHE=1
+    UV_NO_CACHE=1 \
+    UV_PROJECT_ENVIRONMENT=/opt/venv \
+    VIRTUAL_ENV=/opt/venv \
+    PATH="/opt/venv/bin:$PATH"
 
 # Install system dependencies and uv
 RUN apt-get update && \
@@ -58,7 +64,7 @@ RUN apt-get update && \
 # Copy dependency files
 COPY pyproject.toml uv.lock ./
 
-# Install only production dependencies using uv (no dev dependencies)
+# Install only production dependencies using uv into /opt/venv
 RUN uv sync --frozen --no-dev
 
 # Copy application code
@@ -67,7 +73,7 @@ COPY entrypoint.sh /usr/local/bin/
 RUN chmod +x /usr/local/bin/entrypoint.sh
 
 # Set ownership to non-root user
-RUN chown -R appuser:appuser /app
+RUN chown -R appuser:appuser /app /opt/venv
 
 # Switch to non-root user
 USER appuser
@@ -77,4 +83,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
     CMD curl -f http://localhost:8000/health || exit 1
 
 ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
-CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
@@ -1,11 +1,11 @@
 #!/bin/bash
 set -e
 
-# Ensure the project's virtualenv binaries are on PATH so commands like
-# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
-# installs the env into /app/.venv in our containers.
-if [ -d "/app/.venv/bin" ]; then
-    export PATH="/app/.venv/bin:$PATH"
+# Ensure the virtualenv binaries are on PATH. Dependencies are installed
+# to /opt/venv (not /app/.venv) to survive bind mounts in development.
+if [ -d "/opt/venv/bin" ]; then
+    export PATH="/opt/venv/bin:$PATH"
+    export VIRTUAL_ENV="/opt/venv"
 fi
 
 # Only the backend service should run migrations and init_db
@@ -40,8 +40,7 @@ services:
     volumes:
       - ./backend:/app
       - ./uploads:/app/uploads
-      # Exclude local .venv from bind mount to use container's .venv
-      - /app/.venv
+      # Note: venv is at /opt/venv (not /app/.venv) so bind mount doesn't affect it
     ports:
       - "8000:8000"
     env_file:
@@ -76,7 +75,6 @@ services:
       target: development
     volumes:
       - ./backend:/app
-      - /app/.venv
     env_file:
       - .env
     environment:
@@ -99,7 +97,6 @@ services:
       target: development
     volumes:
       - ./backend:/app
-      - /app/.venv
     env_file:
       - .env
     environment:
@@ -122,7 +119,6 @@ services:
       target: development
     volumes:
       - ./backend:/app
-      - /app/.venv
     env_file:
       - .env
     environment:
@@ -145,7 +141,6 @@ services:
       target: development
     volumes:
       - ./backend:/app
-      - /app/.venv
     env_file:
       - .env
     environment:
@@ -214,9 +214,9 @@ test(frontend): add unit tests for ProjectDashboard
 
 ---
 
-## Phase 2+ Implementation Workflow
+## Rigorous Implementation Workflow
 
-**For complex infrastructure issues (Phase 2 MCP, core systems), follow this rigorous process:**
+**This workflow applies to ALL feature implementations. Follow this process rigorously:**
 
 ### 1. Branch Setup
 ```bash
@@ -257,12 +257,42 @@ Before closing an issue, perform deep review from multiple angles:
 
 **No stone unturned. No sloppy results. No unreviewed work.**
 
-### 5. Final Validation
+### 5. Stack Verification (CRITICAL - NON-NEGOTIABLE)
+
+**ALWAYS run the full stack and verify it boots correctly before considering ANY work done.**
+
+```bash
+# Start the full development stack
+make dev
+
+# Check backend logs for startup errors
+docker compose -f docker-compose.dev.yml logs backend --tail=100
+
+# Start frontend separately
+cd frontend && npm run dev
+
+# Check frontend console for errors
+```
+
+**A feature is NOT complete if:**
+- The stack doesn't boot
+- There are import errors in logs
+- Health checks fail
+- Any component crashes on startup
+
+**This rule exists because:**
+- Tests can pass but the application won't start (import errors, missing deps)
+- 90% test coverage is worthless if the app crashes on boot
+- Docker builds can mask local issues
+
+### 6. Final Validation Checklist
 - [ ] All tests pass (unit, integration, E2E)
 - [ ] Type checking passes
 - [ ] Linting passes
+- [ ] **Stack boots successfully** (backend + frontend)
+- [ ] **Logs show no errors**
+- [ ] Coverage meets threshold (>90% backend, >90% frontend)
 - [ ] Documentation updated
-- [ ] Coverage meets threshold
 - [ ] Issue checklist 100% complete
 - [ ] Multi-agent review passed
 
mcp-servers/knowledge-base/Dockerfile (new file, 31 lines)
@@ -0,0 +1,31 @@
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install uv for fast package installation
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv

# Copy project files
COPY pyproject.toml ./
COPY *.py ./
COPY chunking/ ./chunking/

# Install dependencies
RUN uv pip install --system --no-cache .

# Create non-root user
RUN useradd --create-home --shell /bin/bash appuser
USER appuser

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"

EXPOSE 8002

CMD ["python", "server.py"]
mcp-servers/knowledge-base/chunking/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""
Chunking module for Knowledge Base MCP Server.

Provides intelligent content chunking for different file types
with overlap and context preservation.
"""

from chunking.base import BaseChunker, ChunkerFactory
from chunking.code import CodeChunker
from chunking.markdown import MarkdownChunker
from chunking.text import TextChunker

__all__ = [
    "BaseChunker",
    "ChunkerFactory",
    "CodeChunker",
    "MarkdownChunker",
    "TextChunker",
]
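For orientation, a hedged usage sketch of the package's public surface, assuming it runs inside the MCP server where `config.get_settings()` and the `models` module resolve and `tiktoken` is installed; the sample document and printed fields are illustrative:

```python
# Select a chunker via the factory and chunk a small markdown document.
from chunking import ChunkerFactory
from models import FileType

factory = ChunkerFactory()

document = """# Setup

Install dependencies with uv, then run the server.

## Health check

The server answers on /health."""

chunks = factory.chunk_content(
    content=document,
    source_path="docs/setup.md",   # if file_type were omitted, the .md extension would drive selection
    file_type=FileType.MARKDOWN,
)

for chunk in chunks:
    print(chunk.chunk_type, chunk.token_count, chunk.metadata.get("heading_context"))
```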
mcp-servers/knowledge-base/chunking/base.py (new file, 281 lines)
@@ -0,0 +1,281 @@
"""
Base chunker implementation.

Provides abstract interface and common utilities for content chunking.
"""

import logging
from abc import ABC, abstractmethod
from typing import Any

import tiktoken

from config import Settings, get_settings
from exceptions import ChunkingError
from models import FILE_EXTENSION_MAP, Chunk, ChunkType, FileType

logger = logging.getLogger(__name__)


class BaseChunker(ABC):
    """
    Abstract base class for content chunkers.

    Subclasses implement specific chunking strategies for
    different content types (code, markdown, text).
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize chunker.

        Args:
            chunk_size: Target tokens per chunk
            chunk_overlap: Token overlap between chunks
            settings: Application settings
        """
        self._settings = settings or get_settings()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Use cl100k_base encoding (GPT-4/text-embedding-3)
        self._tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Count tokens in text."""
        return len(self._tokenizer.encode(text))

    def truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """Truncate text to max tokens."""
        tokens = self._tokenizer.encode(text)
        if len(tokens) <= max_tokens:
            return text
        return self._tokenizer.decode(tokens[:max_tokens])

    @abstractmethod
    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Split content into chunks.

        Args:
            content: Content to chunk
            source_path: Source file path for reference
            file_type: File type for specialized handling
            metadata: Additional metadata to include

        Returns:
            List of Chunk objects
        """
        pass

    @property
    @abstractmethod
    def chunk_type(self) -> ChunkType:
        """Get the chunk type this chunker produces."""
        pass

    def _create_chunk(
        self,
        content: str,
        source_path: str | None = None,
        start_line: int | None = None,
        end_line: int | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Chunk:
        """Create a chunk with token count."""
        token_count = self.count_tokens(content)
        return Chunk(
            content=content,
            chunk_type=self.chunk_type,
            file_type=file_type,
            source_path=source_path,
            start_line=start_line,
            end_line=end_line,
            metadata=metadata or {},
            token_count=token_count,
        )


class ChunkerFactory:
    """
    Factory for creating appropriate chunkers.

    Selects the best chunker based on file type or content.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize factory."""
        self._settings = settings or get_settings()
        self._chunkers: dict[str, BaseChunker] = {}

    def _get_code_chunker(self) -> "BaseChunker":
        """Get or create code chunker."""
        from chunking.code import CodeChunker

        if "code" not in self._chunkers:
            self._chunkers["code"] = CodeChunker(
                chunk_size=self._settings.code_chunk_size,
                chunk_overlap=self._settings.code_chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers["code"]

    def _get_markdown_chunker(self) -> "BaseChunker":
        """Get or create markdown chunker."""
        from chunking.markdown import MarkdownChunker

        if "markdown" not in self._chunkers:
            self._chunkers["markdown"] = MarkdownChunker(
                chunk_size=self._settings.markdown_chunk_size,
                chunk_overlap=self._settings.markdown_chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers["markdown"]

    def _get_text_chunker(self) -> "BaseChunker":
        """Get or create text chunker."""
        from chunking.text import TextChunker

        if "text" not in self._chunkers:
            self._chunkers["text"] = TextChunker(
                chunk_size=self._settings.text_chunk_size,
                chunk_overlap=self._settings.text_chunk_overlap,
                settings=self._settings,
            )
        return self._chunkers["text"]

    def get_chunker(
        self,
        file_type: FileType | None = None,
        chunk_type: ChunkType | None = None,
    ) -> BaseChunker:
        """
        Get appropriate chunker for content type.

        Args:
            file_type: File type to chunk
            chunk_type: Explicit chunk type to use

        Returns:
            Appropriate chunker instance
        """
        # If explicit chunk type specified, use it
        if chunk_type:
            if chunk_type == ChunkType.CODE:
                return self._get_code_chunker()
            elif chunk_type == ChunkType.MARKDOWN:
                return self._get_markdown_chunker()
            else:
                return self._get_text_chunker()

        # Otherwise, infer from file type
        if file_type:
            if file_type == FileType.MARKDOWN:
                return self._get_markdown_chunker()
            elif file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
                return self._get_text_chunker()
            else:
                # Code files
                return self._get_code_chunker()

        # Default to text chunker
        return self._get_text_chunker()

    def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
        """
        Get chunker based on file path extension.

        Args:
            source_path: File path to chunk

        Returns:
            Tuple of (chunker, file_type)
        """
        # Extract extension
        ext = ""
        if "." in source_path:
            ext = "." + source_path.rsplit(".", 1)[-1].lower()

        file_type = FILE_EXTENSION_MAP.get(ext)
        chunker = self.get_chunker(file_type=file_type)

        return chunker, file_type

    def chunk_content(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        chunk_type: ChunkType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk content using appropriate strategy.

        Args:
            content: Content to chunk
            source_path: Source file path
            file_type: File type
            chunk_type: Explicit chunk type
            metadata: Additional metadata

        Returns:
            List of chunks
        """
        # If we have a source path but no file type, infer it
        if source_path and not file_type:
            chunker, file_type = self.get_chunker_for_path(source_path)
        else:
            chunker = self.get_chunker(file_type=file_type, chunk_type=chunk_type)

        try:
            chunks = chunker.chunk(
                content=content,
                source_path=source_path,
                file_type=file_type,
                metadata=metadata,
            )

            logger.debug(
                f"Chunked content into {len(chunks)} chunks "
                f"(type={chunker.chunk_type.value})"
            )

            return chunks

        except Exception as e:
            logger.error(f"Chunking error: {e}")
            raise ChunkingError(
                message=f"Failed to chunk content: {e}",
                cause=e,
            )


# Global chunker factory instance
_chunker_factory: ChunkerFactory | None = None


def get_chunker_factory() -> ChunkerFactory:
    """Get the global chunker factory instance."""
    global _chunker_factory
    if _chunker_factory is None:
        _chunker_factory = ChunkerFactory()
    return _chunker_factory


def reset_chunker_factory() -> None:
    """Reset the global chunker factory (for testing)."""
    global _chunker_factory
    _chunker_factory = None
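The base class measures chunk sizes and overlaps in tokens rather than characters. A standalone sketch of the same `tiktoken` calls that `count_tokens()` and `truncate_to_tokens()` wrap (sample text and printed values are illustrative):

```python
# Token accounting with tiktoken's cl100k_base encoding.
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

text = "Chunk sizes are measured in tokens, so short lines can still be expensive."
tokens = encoding.encode(text)
print(len(text), "characters ->", len(tokens), "tokens")

# truncate_to_tokens() is just encode, slice, decode:
max_tokens = 8
truncated = encoding.decode(tokens[:max_tokens])
print(truncated)
```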
mcp-servers/knowledge-base/chunking/code.py (new file, 410 lines)
@@ -0,0 +1,410 @@
"""
Code-aware chunking implementation.

Provides intelligent chunking for source code that respects
function/class boundaries and preserves context.
"""

import logging
import re
from typing import Any

from chunking.base import BaseChunker
from config import Settings
from models import Chunk, ChunkType, FileType

logger = logging.getLogger(__name__)


# Language-specific patterns for detecting function/class definitions
LANGUAGE_PATTERNS: dict[FileType, dict[str, re.Pattern[str]]] = {
    FileType.PYTHON: {
        "function": re.compile(r"^(\s*)(async\s+)?def\s+\w+", re.MULTILINE),
        "class": re.compile(r"^(\s*)class\s+\w+", re.MULTILINE),
        "decorator": re.compile(r"^(\s*)@\w+", re.MULTILINE),
    },
    FileType.JAVASCRIPT: {
        "function": re.compile(
            r"^(\s*)(export\s+)?(async\s+)?function\s+\w+|"
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?\(",
            re.MULTILINE,
        ),
        "class": re.compile(r"^(\s*)(export\s+)?class\s+\w+", re.MULTILINE),
        "arrow": re.compile(
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?(\([^)]*\)|[^=])\s*=>",
            re.MULTILINE,
        ),
    },
    FileType.TYPESCRIPT: {
        "function": re.compile(
            r"^(\s*)(export\s+)?(async\s+)?function\s+\w+|"
            r"^(\s*)(export\s+)?(const|let|var)\s+\w+\s*[:<]",
            re.MULTILINE,
        ),
        "class": re.compile(r"^(\s*)(export\s+)?class\s+\w+", re.MULTILINE),
        "interface": re.compile(r"^(\s*)(export\s+)?interface\s+\w+", re.MULTILINE),
        "type": re.compile(r"^(\s*)(export\s+)?type\s+\w+", re.MULTILINE),
    },
    FileType.GO: {
        "function": re.compile(r"^func\s+(\([^)]+\)\s+)?\w+", re.MULTILINE),
        "struct": re.compile(r"^type\s+\w+\s+struct", re.MULTILINE),
        "interface": re.compile(r"^type\s+\w+\s+interface", re.MULTILINE),
    },
    FileType.RUST: {
        "function": re.compile(r"^(\s*)(pub\s+)?(async\s+)?fn\s+\w+", re.MULTILINE),
        "struct": re.compile(r"^(\s*)(pub\s+)?struct\s+\w+", re.MULTILINE),
        "impl": re.compile(r"^(\s*)impl\s+", re.MULTILINE),
        "trait": re.compile(r"^(\s*)(pub\s+)?trait\s+\w+", re.MULTILINE),
    },
    FileType.JAVA: {
        "method": re.compile(
            r"^(\s*)(public|private|protected)?\s*(static)?\s*\w+\s+\w+\s*\(",
            re.MULTILINE,
        ),
        "class": re.compile(
            r"^(\s*)(public|private|protected)?\s*(abstract)?\s*class\s+\w+",
            re.MULTILINE,
        ),
        "interface": re.compile(
            r"^(\s*)(public|private|protected)?\s*interface\s+\w+",
            re.MULTILINE,
        ),
    },
}


class CodeChunker(BaseChunker):
    """
    Code-aware chunker that respects logical boundaries.

    Features:
    - Detects function/class boundaries
    - Preserves decorator/annotation context
    - Handles nested structures
    - Falls back to line-based chunking when needed
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize code chunker."""
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Get chunk type."""
        return ChunkType.CODE

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk code content.

        Tries to respect function/class boundaries, falling back
        to line-based chunking if needed.
        """
        if not content.strip():
            return []

        metadata = metadata or {}
        lines = content.splitlines(keepends=True)

        # Try language-aware chunking if we have patterns
        if file_type and file_type in LANGUAGE_PATTERNS:
            chunks = self._chunk_by_structure(
                content, lines, file_type, source_path, metadata
            )
            if chunks:
                return chunks

        # Fall back to line-based chunking
        return self._chunk_by_lines(lines, source_path, file_type, metadata)

    def _chunk_by_structure(
        self,
        content: str,
        lines: list[str],
        file_type: FileType,
        source_path: str | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """
        Chunk by detecting code structure (functions, classes).

        Returns empty list if structure detection isn't useful.
        """
        patterns = LANGUAGE_PATTERNS.get(file_type, {})
        if not patterns:
            return []

        # Find all structure boundaries
        boundaries: list[tuple[int, str]] = []  # (line_number, type)

        for struct_type, pattern in patterns.items():
            for match in pattern.finditer(content):
                # Convert character position to line number
                line_num = content[:match.start()].count("\n")
                boundaries.append((line_num, struct_type))

        if not boundaries:
            return []

        # Sort boundaries by line number
        boundaries.sort(key=lambda x: x[0])

        # If we have very few boundaries, line-based might be better
        if len(boundaries) < 3 and len(lines) > 50:
            return []

        # Create chunks based on boundaries
        chunks: list[Chunk] = []
        current_start = 0

        for _i, (line_num, struct_type) in enumerate(boundaries):
            # Check if we need to create a chunk before this boundary
            if line_num > current_start:
                # Include any preceding comments/decorators
                actual_start = self._find_context_start(lines, line_num)
                if actual_start < current_start:
                    actual_start = current_start

                chunk_lines = lines[current_start:line_num]
                chunk_content = "".join(chunk_lines)

                if chunk_content.strip():
                    token_count = self.count_tokens(chunk_content)

                    # If chunk is too large, split it
                    if token_count > self.chunk_size * 1.5:
                        sub_chunks = self._split_large_chunk(
                            chunk_lines, current_start, source_path, file_type, metadata
                        )
                        chunks.extend(sub_chunks)
                    elif token_count > 0:
                        chunks.append(
                            self._create_chunk(
                                content=chunk_content.rstrip(),
                                source_path=source_path,
                                start_line=current_start + 1,
                                end_line=line_num,
                                file_type=file_type,
                                metadata={**metadata, "structure_type": struct_type},
                            )
                        )

                current_start = line_num

        # Handle remaining content
        if current_start < len(lines):
            chunk_lines = lines[current_start:]
            chunk_content = "".join(chunk_lines)

            if chunk_content.strip():
                token_count = self.count_tokens(chunk_content)

                if token_count > self.chunk_size * 1.5:
                    sub_chunks = self._split_large_chunk(
                        chunk_lines, current_start, source_path, file_type, metadata
                    )
                    chunks.extend(sub_chunks)
                else:
                    chunks.append(
                        self._create_chunk(
                            content=chunk_content.rstrip(),
                            source_path=source_path,
                            start_line=current_start + 1,
                            end_line=len(lines),
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )

        return chunks

    def _find_context_start(self, lines: list[str], line_num: int) -> int:
        """Find the start of context (decorators, comments) before a line."""
        start = line_num

        # Look backwards for decorators/comments
        for i in range(line_num - 1, max(0, line_num - 10), -1):
            line = lines[i].strip()
            if not line:
                continue
            if line.startswith(("#", "//", "/*", "*", "@", "'")):
                start = i
            else:
                break

        return start

    def _split_large_chunk(
        self,
        chunk_lines: list[str],
        base_line: int,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Split a large chunk into smaller pieces with overlap."""
        chunks: list[Chunk] = []
        current_lines: list[str] = []
        current_tokens = 0
        chunk_start = 0

        for i, line in enumerate(chunk_lines):
            line_tokens = self.count_tokens(line)

            if current_tokens + line_tokens > self.chunk_size and current_lines:
                # Create chunk
                chunk_content = "".join(current_lines).rstrip()
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=base_line + chunk_start + 1,
                        end_line=base_line + i,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Calculate overlap
                overlap_tokens = 0
                overlap_lines: list[str] = []
                for j in range(len(current_lines) - 1, -1, -1):
                    overlap_tokens += self.count_tokens(current_lines[j])
                    if overlap_tokens >= self.chunk_overlap:
                        overlap_lines = current_lines[j:]
                        break

                current_lines = overlap_lines
                current_tokens = sum(self.count_tokens(line) for line in current_lines)
                chunk_start = i - len(overlap_lines)

            current_lines.append(line)
            current_tokens += line_tokens

        # Final chunk
        if current_lines:
            chunk_content = "".join(current_lines).rstrip()
            if chunk_content.strip():
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=base_line + chunk_start + 1,
                        end_line=base_line + len(chunk_lines),
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

        return chunks

    def _chunk_by_lines(
        self,
        lines: list[str],
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by lines with overlap."""
        chunks: list[Chunk] = []
        current_lines: list[str] = []
        current_tokens = 0
        chunk_start = 0

        for i, line in enumerate(lines):
            line_tokens = self.count_tokens(line)

            # If this line alone exceeds chunk size, handle specially
            if line_tokens > self.chunk_size:
                # Flush current chunk
                if current_lines:
                    chunk_content = "".join(current_lines).rstrip()
                    if chunk_content.strip():
                        chunks.append(
                            self._create_chunk(
                                content=chunk_content,
                                source_path=source_path,
                                start_line=chunk_start + 1,
                                end_line=i,
                                file_type=file_type,
                                metadata=metadata,
                            )
                        )
                    current_lines = []
                    current_tokens = 0
                    chunk_start = i

                # Truncate and add long line
                truncated = self.truncate_to_tokens(line, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated.rstrip(),
                        source_path=source_path,
                        start_line=i + 1,
                        end_line=i + 1,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                chunk_start = i + 1
                continue

            if current_tokens + line_tokens > self.chunk_size and current_lines:
                # Create chunk
                chunk_content = "".join(current_lines).rstrip()
                if chunk_content.strip():
                    chunks.append(
                        self._create_chunk(
                            content=chunk_content,
                            source_path=source_path,
                            start_line=chunk_start + 1,
                            end_line=i,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )

                # Calculate overlap
                overlap_tokens = 0
                overlap_lines: list[str] = []
                for j in range(len(current_lines) - 1, -1, -1):
                    line_tok = self.count_tokens(current_lines[j])
                    if overlap_tokens + line_tok > self.chunk_overlap:
                        break
                    overlap_lines.insert(0, current_lines[j])
                    overlap_tokens += line_tok

                current_lines = overlap_lines
                current_tokens = overlap_tokens
                chunk_start = i - len(overlap_lines)

            current_lines.append(line)
            current_tokens += line_tokens

        # Final chunk
        if current_lines:
            chunk_content = "".join(current_lines).rstrip()
            if chunk_content.strip():
                chunks.append(
                    self._create_chunk(
                        content=chunk_content,
                        source_path=source_path,
                        start_line=chunk_start + 1,
                        end_line=len(lines),
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

        return chunks
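A standalone sketch of the boundary detection used by `_chunk_by_structure`: the Python patterns from `LANGUAGE_PATTERNS` find definition starts, and each match's character offset is converted to a line number by counting the newlines before it. The sample source and the expected output shown in the comment are illustrative only:

```python
# Map regex matches for functions/classes/decorators to 0-based line numbers.
import re

source = '''import os

@lru_cache
def load(path):
    return os.stat(path)

class Loader:
    def __init__(self):
        self.cache = {}
'''

patterns = {
    "function": re.compile(r"^(\s*)(async\s+)?def\s+\w+", re.MULTILINE),
    "class": re.compile(r"^(\s*)class\s+\w+", re.MULTILINE),
    "decorator": re.compile(r"^(\s*)@\w+", re.MULTILINE),
}

boundaries = []
for struct_type, pattern in patterns.items():
    for match in pattern.finditer(source):
        # Count newlines before the match to get its line number.
        line_num = source[:match.start()].count("\n")
        boundaries.append((line_num, struct_type))

boundaries.sort(key=lambda b: b[0])
print(boundaries)  # [(2, 'decorator'), (3, 'function'), (6, 'class'), (7, 'function')]
```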
mcp-servers/knowledge-base/chunking/markdown.py (new file, 483 lines)
@@ -0,0 +1,483 @@
"""
Markdown-aware chunking implementation.

Provides intelligent chunking for markdown content that respects
heading hierarchy and preserves document structure.
"""

import logging
import re
from typing import Any

from chunking.base import BaseChunker
from config import Settings
from models import Chunk, ChunkType, FileType

logger = logging.getLogger(__name__)

# Patterns for markdown elements
HEADING_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
CODE_BLOCK_PATTERN = re.compile(r"^```", re.MULTILINE)
HR_PATTERN = re.compile(r"^(-{3,}|_{3,}|\*{3,})$", re.MULTILINE)


class MarkdownChunker(BaseChunker):
    """
    Markdown-aware chunker that respects document structure.

    Features:
    - Respects heading hierarchy
    - Preserves heading context in chunks
    - Handles code blocks as units
    - Maintains list continuity where possible
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize markdown chunker."""
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Get chunk type."""
        return ChunkType.MARKDOWN

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk markdown content.

        Splits on heading boundaries and preserves heading context.
        """
        if not content.strip():
            return []

        metadata = metadata or {}
        file_type = file_type or FileType.MARKDOWN

        # Split content into sections by headings
        sections = self._split_by_headings(content)

        if not sections:
            # No headings, chunk as plain text
            return self._chunk_text_block(
                content, source_path, file_type, metadata, []
            )

        chunks: list[Chunk] = []
        heading_stack: list[tuple[int, str]] = []  # (level, text)

        for section in sections:
            heading_level = section.get("level", 0)
            heading_text = section.get("heading", "")
            section_content = section.get("content", "")
            start_line = section.get("start_line", 1)
            end_line = section.get("end_line", 1)

            # Update heading stack
            if heading_level > 0:
                # Pop headings of equal or higher level
                while heading_stack and heading_stack[-1][0] >= heading_level:
                    heading_stack.pop()
                heading_stack.append((heading_level, heading_text))

            # Build heading context prefix
            heading_context = " > ".join(h[1] for h in heading_stack)

            section_chunks = self._chunk_section(
                content=section_content,
                heading_context=heading_context,
                heading_level=heading_level,
                heading_text=heading_text,
                start_line=start_line,
                end_line=end_line,
                source_path=source_path,
                file_type=file_type,
                metadata=metadata,
            )
            chunks.extend(section_chunks)

        return chunks

    def _split_by_headings(self, content: str) -> list[dict[str, Any]]:
        """Split content into sections by headings."""
        sections: list[dict[str, Any]] = []
        lines = content.split("\n")

        current_section: dict[str, Any] = {
            "level": 0,
            "heading": "",
            "content": "",
            "start_line": 1,
            "end_line": 1,
        }
        current_lines: list[str] = []
        in_code_block = False

        for i, line in enumerate(lines):
            # Track code blocks
            if line.strip().startswith("```"):
                in_code_block = not in_code_block
                current_lines.append(line)
                continue

            # Skip heading detection in code blocks
            if in_code_block:
                current_lines.append(line)
                continue

            # Check for heading
            heading_match = HEADING_PATTERN.match(line)
            if heading_match:
                # Save previous section
                if current_lines:
                    current_section["content"] = "\n".join(current_lines)
                    current_section["end_line"] = i
                    if current_section["content"].strip():
                        sections.append(current_section)

                # Start new section
                level = len(heading_match.group(1))
                heading_text = heading_match.group(2).strip()
                current_section = {
                    "level": level,
                    "heading": heading_text,
                    "content": "",
                    "start_line": i + 1,
                    "end_line": i + 1,
                }
                current_lines = [line]
            else:
                current_lines.append(line)

        # Save final section
        if current_lines:
            current_section["content"] = "\n".join(current_lines)
            current_section["end_line"] = len(lines)
            if current_section["content"].strip():
                sections.append(current_section)

        return sections

    def _chunk_section(
        self,
        content: str,
        heading_context: str,
        heading_level: int,
        heading_text: str,
        start_line: int,
        end_line: int,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk a single section of markdown."""
        if not content.strip():
            return []

        token_count = self.count_tokens(content)

        # If section fits in one chunk, return as-is
        if token_count <= self.chunk_size:
            section_metadata = {
                **metadata,
                "heading_context": heading_context,
                "heading_level": heading_level,
                "heading_text": heading_text,
            }
            return [
                self._create_chunk(
                    content=content.strip(),
                    source_path=source_path,
                    start_line=start_line,
                    end_line=end_line,
                    file_type=file_type,
                    metadata=section_metadata,
                )
            ]

        # Need to split - try to split on paragraphs first
        return self._chunk_text_block(
            content,
            source_path,
            file_type,
            {
                **metadata,
                "heading_context": heading_context,
                "heading_level": heading_level,
                "heading_text": heading_text,
            },
            _heading_stack=[(heading_level, heading_text)] if heading_text else [],
            base_line=start_line,
        )

    def _chunk_text_block(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
        _heading_stack: list[tuple[int, str]],
        base_line: int = 1,
    ) -> list[Chunk]:
        """Chunk a block of text by paragraphs."""
        # Split into paragraphs (separated by blank lines)
        paragraphs = self._split_into_paragraphs(content)

        if not paragraphs:
            return []

        chunks: list[Chunk] = []
        current_content: list[str] = []
        current_tokens = 0
        chunk_start_line = base_line

        for para_info in paragraphs:
            para_content = para_info["content"]
            para_tokens = para_info["tokens"]
            para_start = para_info["start_line"]

            # Handle very large paragraphs
            if para_tokens > self.chunk_size:
                # Flush current content
                if current_content:
                    chunk_text = "\n\n".join(current_content)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text.strip(),
                            source_path=source_path,
                            start_line=chunk_start_line,
                            end_line=base_line + para_start - 1,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_content = []
                    current_tokens = 0

                # Split large paragraph by sentences/lines
                sub_chunks = self._split_large_paragraph(
                    para_content,
                    source_path,
                    file_type,
                    metadata,
                    base_line + para_start,
                )
                chunks.extend(sub_chunks)
                chunk_start_line = base_line + para_info["end_line"] + 1
                continue

            # Check if adding this paragraph exceeds limit
            if current_tokens + para_tokens > self.chunk_size and current_content:
                # Create chunk
                chunk_text = "\n\n".join(current_content)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text.strip(),
                        source_path=source_path,
                        start_line=chunk_start_line,
                        end_line=base_line + para_start - 1,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap: include last paragraph if it fits
                if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
                    current_content = [current_content[-1]]
                    current_tokens = self.count_tokens(current_content[-1])
                else:
                    current_content = []
                    current_tokens = 0

                chunk_start_line = base_line + para_start

            current_content.append(para_content)
            current_tokens += para_tokens

        # Final chunk
        if current_content:
            chunk_text = "\n\n".join(current_content)
            end_line_num = base_line + (paragraphs[-1]["end_line"] if paragraphs else 0)
            chunks.append(
                self._create_chunk(
                    content=chunk_text.strip(),
                    source_path=source_path,
                    start_line=chunk_start_line,
                    end_line=end_line_num,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks

    def _split_into_paragraphs(self, content: str) -> list[dict[str, Any]]:
        """Split content into paragraphs with metadata."""
        paragraphs: list[dict[str, Any]] = []
        lines = content.split("\n")

        current_para: list[str] = []
        para_start = 0
        in_code_block = False

        for i, line in enumerate(lines):
            # Track code blocks (keep them as single units)
            if line.strip().startswith("```"):
                if in_code_block:
                    # End of code block
                    current_para.append(line)
                    in_code_block = False
                else:
                    # Start of code block - save previous paragraph
                    if current_para and any(p.strip() for p in current_para):
                        para_content = "\n".join(current_para)
                        paragraphs.append({
                            "content": para_content,
                            "tokens": self.count_tokens(para_content),
                            "start_line": para_start,
                            "end_line": i - 1,
                        })
                    current_para = [line]
                    para_start = i
                    in_code_block = True
                continue

            if in_code_block:
                current_para.append(line)
                continue

            # Empty line indicates paragraph break
            if not line.strip():
                if current_para and any(p.strip() for p in current_para):
                    para_content = "\n".join(current_para)
                    paragraphs.append({
                        "content": para_content,
                        "tokens": self.count_tokens(para_content),
                        "start_line": para_start,
                        "end_line": i - 1,
                    })
                current_para = []
                para_start = i + 1
            else:
                if not current_para:
                    para_start = i
                current_para.append(line)

        # Final paragraph
        if current_para and any(p.strip() for p in current_para):
            para_content = "\n".join(current_para)
            paragraphs.append({
                "content": para_content,
                "tokens": self.count_tokens(para_content),
                "start_line": para_start,
                "end_line": len(lines) - 1,
            })

        return paragraphs

    def _split_large_paragraph(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType,
        metadata: dict[str, Any],
        base_line: int,
    ) -> list[Chunk]:
        """Split a large paragraph into smaller chunks."""
        # Try splitting by sentences
        sentences = self._split_into_sentences(content)

        chunks: list[Chunk] = []
        current_content: list[str] = []
        current_tokens = 0

        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)

            # If single sentence is too large, truncate
            if sentence_tokens > self.chunk_size:
                if current_content:
                    chunk_text = " ".join(current_content)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text.strip(),
                            source_path=source_path,
                            start_line=base_line,
                            end_line=base_line,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_content = []
                    current_tokens = 0

                truncated = self.truncate_to_tokens(sentence, self.chunk_size)
                chunks.append(
                    self._create_chunk(
                        content=truncated.strip(),
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata={**metadata, "truncated": True},
                    )
                )
                continue

            if current_tokens + sentence_tokens > self.chunk_size and current_content:
                chunk_text = " ".join(current_content)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text.strip(),
                        source_path=source_path,
                        start_line=base_line,
                        end_line=base_line,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap with last sentence
                if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
                    current_content = [current_content[-1]]
                    current_tokens = self.count_tokens(current_content[-1])
                else:
                    current_content = []
                    current_tokens = 0

            current_content.append(sentence)
            current_tokens += sentence_tokens

        # Final chunk
        if current_content:
            chunk_text = " ".join(current_content)
            chunks.append(
                self._create_chunk(
                    content=chunk_text.strip(),
                    source_path=source_path,
                    start_line=base_line,
                    end_line=base_line,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks

    def _split_into_sentences(self, text: str) -> list[str]:
        """Split text into sentences."""
        # Simple sentence splitting on common terminators
        # More sophisticated splitting could use nltk or spacy
        sentence_endings = re.compile(r"(?<=[.!?])\s+")
        sentences = sentence_endings.split(text)
        return [s.strip() for s in sentences if s.strip()]
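A standalone sketch of how the chunker builds the `heading_context` metadata: a stack of (level, text) pairs is popped back to the nearest ancestor before each new heading, then joined with " > ". The sample headings and the expected output in the comments are illustrative:

```python
# Reproduce the heading-stack logic from MarkdownChunker.chunk().
headings = [(1, "Guide"), (2, "Install"), (3, "With uv"), (2, "Run")]

heading_stack: list[tuple[int, str]] = []
for level, text in headings:
    # Pop headings of equal or higher level before pushing the new one.
    while heading_stack and heading_stack[-1][0] >= level:
        heading_stack.pop()
    heading_stack.append((level, text))
    print(" > ".join(h[1] for h in heading_stack))

# Guide
# Guide > Install
# Guide > Install > With uv
# Guide > Run
```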
mcp-servers/knowledge-base/chunking/text.py (new file, 389 lines)
@@ -0,0 +1,389 @@
"""
Plain text chunking implementation.

Provides simple text chunking with paragraph and sentence
boundary detection.
"""

import logging
import re
from typing import Any

from chunking.base import BaseChunker
from config import Settings
from models import Chunk, ChunkType, FileType

logger = logging.getLogger(__name__)


class TextChunker(BaseChunker):
    """
    Plain text chunker with paragraph awareness.

    Features:
    - Splits on paragraph boundaries
    - Falls back to sentence/word boundaries
    - Configurable overlap for context preservation
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        settings: Settings | None = None,
    ) -> None:
        """Initialize text chunker."""
        super().__init__(chunk_size, chunk_overlap, settings)

    @property
    def chunk_type(self) -> ChunkType:
        """Get chunk type."""
        return ChunkType.TEXT

    def chunk(
        self,
        content: str,
        source_path: str | None = None,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """
        Chunk plain text content.

        Tries paragraph boundaries first, then sentences.
        """
        if not content.strip():
            return []

        metadata = metadata or {}

        # Check if content fits in a single chunk
        total_tokens = self.count_tokens(content)
        if total_tokens <= self.chunk_size:
            return [
                self._create_chunk(
                    content=content.strip(),
                    source_path=source_path,
                    start_line=1,
                    end_line=content.count("\n") + 1,
                    file_type=file_type,
                    metadata=metadata,
                )
            ]

        # Try paragraph-based chunking
        paragraphs = self._split_paragraphs(content)
        if len(paragraphs) > 1:
            return self._chunk_by_paragraphs(
                paragraphs, source_path, file_type, metadata
            )

        # Fall back to sentence-based chunking
        return self._chunk_by_sentences(
            content, source_path, file_type, metadata
        )

    def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
        """Split content into paragraphs."""
        paragraphs: list[dict[str, Any]] = []

        # Split on double newlines (paragraph boundaries)
        raw_paras = re.split(r"\n\s*\n", content)

        line_num = 1
        for para in raw_paras:
            para = para.strip()
            if not para:
                continue

            para_lines = para.count("\n") + 1
            paragraphs.append({
                "content": para,
                "tokens": self.count_tokens(para),
                "start_line": line_num,
                "end_line": line_num + para_lines - 1,
            })
            line_num += para_lines + 1  # +1 for blank line between paragraphs

        return paragraphs

    def _chunk_by_paragraphs(
        self,
        paragraphs: list[dict[str, Any]],
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by combining paragraphs up to size limit."""
        chunks: list[Chunk] = []
        current_paras: list[str] = []
        current_tokens = 0
        chunk_start = paragraphs[0]["start_line"] if paragraphs else 1
        chunk_end = chunk_start

        for para in paragraphs:
            para_content = para["content"]
            para_tokens = para["tokens"]

            # Handle paragraphs larger than chunk size
            if para_tokens > self.chunk_size:
                # Flush current content
                if current_paras:
                    chunk_text = "\n\n".join(current_paras)
                    chunks.append(
                        self._create_chunk(
                            content=chunk_text,
                            source_path=source_path,
                            start_line=chunk_start,
                            end_line=chunk_end,
                            file_type=file_type,
                            metadata=metadata,
                        )
                    )
                    current_paras = []
                    current_tokens = 0

                # Split large paragraph
                sub_chunks = self._split_large_text(
                    para_content,
                    source_path,
                    file_type,
                    metadata,
                    para["start_line"],
                )
                chunks.extend(sub_chunks)
                chunk_start = para["end_line"] + 1
                chunk_end = chunk_start
                continue

            # Check if adding paragraph exceeds limit
            if current_tokens + para_tokens > self.chunk_size and current_paras:
                chunk_text = "\n\n".join(current_paras)
                chunks.append(
                    self._create_chunk(
                        content=chunk_text,
                        source_path=source_path,
                        start_line=chunk_start,
                        end_line=chunk_end,
                        file_type=file_type,
                        metadata=metadata,
                    )
                )

                # Overlap: keep last paragraph if small enough
                overlap_para = None
                if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
                    overlap_para = current_paras[-1]

                current_paras = [overlap_para] if overlap_para else []
                current_tokens = self.count_tokens(overlap_para) if overlap_para else 0
                chunk_start = para["start_line"]

            current_paras.append(para_content)
            current_tokens += para_tokens
            chunk_end = para["end_line"]

        # Final chunk
        if current_paras:
            chunk_text = "\n\n".join(current_paras)
            chunks.append(
                self._create_chunk(
                    content=chunk_text,
                    source_path=source_path,
                    start_line=chunk_start,
                    end_line=chunk_end,
                    file_type=file_type,
                    metadata=metadata,
                )
            )

        return chunks

    def _chunk_by_sentences(
        self,
        content: str,
        source_path: str | None,
        file_type: FileType | None,
        metadata: dict[str, Any],
    ) -> list[Chunk]:
        """Chunk by sentences."""
        sentences = self._split_sentences(content)

        if not sentences:
            return []

        chunks: list[Chunk] = []
        current_sentences: list[str] = []
        current_tokens = 0

        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)

            # Handle sentences larger than chunk size
            if sentence_tokens > self.chunk_size:
|
||||||
|
if current_sentences:
|
||||||
|
chunk_text = " ".join(current_sentences)
|
||||||
|
chunks.append(
|
||||||
|
self._create_chunk(
|
||||||
|
content=chunk_text,
|
||||||
|
source_path=source_path,
|
||||||
|
start_line=1,
|
||||||
|
end_line=1,
|
||||||
|
file_type=file_type,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_sentences = []
|
||||||
|
current_tokens = 0
|
||||||
|
|
||||||
|
# Truncate large sentence
|
||||||
|
truncated = self.truncate_to_tokens(sentence, self.chunk_size)
|
||||||
|
chunks.append(
|
||||||
|
self._create_chunk(
|
||||||
|
content=truncated,
|
||||||
|
source_path=source_path,
|
||||||
|
start_line=1,
|
||||||
|
end_line=1,
|
||||||
|
file_type=file_type,
|
||||||
|
metadata={**metadata, "truncated": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if adding sentence exceeds limit
|
||||||
|
if current_tokens + sentence_tokens > self.chunk_size and current_sentences:
|
||||||
|
chunk_text = " ".join(current_sentences)
|
||||||
|
chunks.append(
|
||||||
|
self._create_chunk(
|
||||||
|
content=chunk_text,
|
||||||
|
source_path=source_path,
|
||||||
|
start_line=1,
|
||||||
|
end_line=1,
|
||||||
|
file_type=file_type,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Overlap: keep last sentence if small enough
|
||||||
|
overlap = None
|
||||||
|
if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
|
||||||
|
overlap = current_sentences[-1]
|
||||||
|
|
||||||
|
current_sentences = [overlap] if overlap else []
|
||||||
|
current_tokens = self.count_tokens(overlap) if overlap else 0
|
||||||
|
|
||||||
|
current_sentences.append(sentence)
|
||||||
|
current_tokens += sentence_tokens
|
||||||
|
|
||||||
|
# Final chunk
|
||||||
|
if current_sentences:
|
||||||
|
chunk_text = " ".join(current_sentences)
|
||||||
|
chunks.append(
|
||||||
|
self._create_chunk(
|
||||||
|
content=chunk_text,
|
||||||
|
source_path=source_path,
|
||||||
|
start_line=1,
|
||||||
|
end_line=content.count("\n") + 1,
|
||||||
|
file_type=file_type,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _split_sentences(self, text: str) -> list[str]:
|
||||||
|
"""Split text into sentences."""
|
||||||
|
# Handle common sentence endings
|
||||||
|
# This is a simple approach - production might use nltk or spacy
|
||||||
|
sentence_pattern = re.compile(
|
||||||
|
r"(?<=[.!?])\s+(?=[A-Z])|" # Standard sentence ending
|
||||||
|
r"(?<=[.!?])\s*$|" # End of text
|
||||||
|
r"(?<=\n)\s*(?=\S)" # Newlines as boundaries
|
||||||
|
)
|
||||||
|
|
||||||
|
sentences = sentence_pattern.split(text)
|
||||||
|
return [s.strip() for s in sentences if s.strip()]
|
||||||
|
|
||||||
|
def _split_large_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
source_path: str | None,
|
||||||
|
file_type: FileType | None,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
base_line: int,
|
||||||
|
) -> list[Chunk]:
|
||||||
|
"""Split text that exceeds chunk size."""
|
||||||
|
# First try sentences
|
||||||
|
sentences = self._split_sentences(text)
|
||||||
|
|
||||||
|
if len(sentences) > 1:
|
||||||
|
return self._chunk_by_sentences(
|
||||||
|
text, source_path, file_type, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fall back to word-based splitting
|
||||||
|
return self._chunk_by_words(
|
||||||
|
text, source_path, file_type, metadata, base_line
|
||||||
|
)
|
||||||
|
|
||||||
|
def _chunk_by_words(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
source_path: str | None,
|
||||||
|
file_type: FileType | None,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
base_line: int,
|
||||||
|
) -> list[Chunk]:
|
||||||
|
"""Last resort: chunk by words."""
|
||||||
|
words = text.split()
|
||||||
|
chunks: list[Chunk] = []
|
||||||
|
current_words: list[str] = []
|
||||||
|
current_tokens = 0
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
word_tokens = self.count_tokens(word + " ")
|
||||||
|
|
||||||
|
if current_tokens + word_tokens > self.chunk_size and current_words:
|
||||||
|
chunk_text = " ".join(current_words)
|
||||||
|
chunks.append(
|
||||||
|
self._create_chunk(
|
||||||
|
content=chunk_text,
|
||||||
|
source_path=source_path,
|
||||||
|
start_line=base_line,
|
||||||
|
end_line=base_line,
|
||||||
|
file_type=file_type,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Word overlap
|
||||||
|
overlap_count = 0
|
||||||
|
overlap_words: list[str] = []
|
||||||
|
for w in reversed(current_words):
|
||||||
|
w_tokens = self.count_tokens(w + " ")
|
||||||
|
if overlap_count + w_tokens > self.chunk_overlap:
|
||||||
|
break
|
||||||
|
overlap_words.insert(0, w)
|
||||||
|
overlap_count += w_tokens
|
||||||
|
|
||||||
|
current_words = overlap_words
|
||||||
|
current_tokens = overlap_count
|
||||||
|
|
||||||
|
current_words.append(word)
|
||||||
|
current_tokens += word_tokens
|
||||||
|
|
||||||
|
# Final chunk
|
||||||
|
if current_words:
|
||||||
|
chunk_text = " ".join(current_words)
|
||||||
|
chunks.append(
|
||||||
|
self._create_chunk(
|
||||||
|
content=chunk_text,
|
||||||
|
source_path=source_path,
|
||||||
|
start_line=base_line,
|
||||||
|
end_line=base_line,
|
||||||
|
file_type=file_type,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return chunks
|
||||||
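A minimal usage sketch of the chunker above. The module path `chunking.text` and the concrete document text are assumptions; the constructor values mirror the `text_chunk_size`/`text_chunk_overlap` defaults defined in config.py, and the `Chunk` fields read here are the ones populated by `_create_chunk`.

```python
# Sketch: chunk a plain-text document with the TextChunker shown above.
# Assumes the module is importable as chunking.text and that token counting
# is supplied by BaseChunker.count_tokens.
from chunking.text import TextChunker

chunker = TextChunker(chunk_size=400, chunk_overlap=50)

document = "First paragraph about pgvector.\n\nSecond paragraph about chunk overlap."
chunks = chunker.chunk(document, source_path="docs/notes.txt")

for chunk in chunks:
    # Each Chunk carries its text plus line bookkeeping for citing the source.
    print(chunk.start_line, chunk.end_line, chunk.content[:40])
```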
mcp-servers/knowledge-base/collection_manager.py (new file, 382 lines)
@@ -0,0 +1,382 @@
"""
Collection management for Knowledge Base MCP Server.

Provides operations for managing document collections including
ingestion, deletion, and statistics.
"""

import logging
from typing import Any

from chunking.base import ChunkerFactory, get_chunker_factory
from config import Settings, get_settings
from database import DatabaseManager, get_database_manager
from embeddings import EmbeddingGenerator, get_embedding_generator
from models import (
    ChunkType,
    CollectionStats,
    DeleteRequest,
    DeleteResult,
    FileType,
    IngestRequest,
    IngestResult,
    ListCollectionsResponse,
)

logger = logging.getLogger(__name__)


class CollectionManager:
    """
    Manages knowledge base collections.

    Handles document ingestion, chunking, embedding generation,
    and collection operations.
    """

    def __init__(
        self,
        settings: Settings | None = None,
        database: DatabaseManager | None = None,
        embeddings: EmbeddingGenerator | None = None,
        chunker_factory: ChunkerFactory | None = None,
    ) -> None:
        """Initialize collection manager."""
        self._settings = settings or get_settings()
        self._database = database
        self._embeddings = embeddings
        self._chunker_factory = chunker_factory

    @property
    def database(self) -> DatabaseManager:
        """Get database manager."""
        if self._database is None:
            self._database = get_database_manager()
        return self._database

    @property
    def embeddings(self) -> EmbeddingGenerator:
        """Get embedding generator."""
        if self._embeddings is None:
            self._embeddings = get_embedding_generator()
        return self._embeddings

    @property
    def chunker_factory(self) -> ChunkerFactory:
        """Get chunker factory."""
        if self._chunker_factory is None:
            self._chunker_factory = get_chunker_factory()
        return self._chunker_factory

    async def ingest(self, request: IngestRequest) -> IngestResult:
        """
        Ingest content into the knowledge base.

        Chunks the content, generates embeddings, and stores them.

        Args:
            request: Ingest request with content and options

        Returns:
            Ingest result with created chunk IDs
        """
        try:
            # Chunk the content
            chunks = self.chunker_factory.chunk_content(
                content=request.content,
                source_path=request.source_path,
                file_type=request.file_type,
                chunk_type=request.chunk_type,
                metadata=request.metadata,
            )

            if not chunks:
                return IngestResult(
                    success=True,
                    chunks_created=0,
                    embeddings_generated=0,
                    source_path=request.source_path,
                    collection=request.collection,
                    chunk_ids=[],
                )

            # Extract chunk contents for embedding
            chunk_texts = [chunk.content for chunk in chunks]

            # Generate embeddings
            embeddings_list = await self.embeddings.generate_batch(
                texts=chunk_texts,
                project_id=request.project_id,
                agent_id=request.agent_id,
            )

            # Store embeddings
            chunk_ids: list[str] = []
            for chunk, embedding in zip(chunks, embeddings_list, strict=True):
                # Build metadata with chunk info
                chunk_metadata = {
                    **request.metadata,
                    **chunk.metadata,
                    "token_count": chunk.token_count,
                }

                chunk_id = await self.database.store_embedding(
                    project_id=request.project_id,
                    collection=request.collection,
                    content=chunk.content,
                    embedding=embedding,
                    chunk_type=chunk.chunk_type,
                    source_path=chunk.source_path or request.source_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    file_type=chunk.file_type or request.file_type,
                    metadata=chunk_metadata,
                )
                chunk_ids.append(chunk_id)

            logger.info(
                f"Ingested {len(chunks)} chunks into collection '{request.collection}' "
                f"for project {request.project_id}"
            )

            return IngestResult(
                success=True,
                chunks_created=len(chunks),
                embeddings_generated=len(embeddings_list),
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=chunk_ids,
            )

        except Exception as e:
            logger.error(f"Ingest error: {e}")
            return IngestResult(
                success=False,
                chunks_created=0,
                embeddings_generated=0,
                source_path=request.source_path,
                collection=request.collection,
                chunk_ids=[],
                error=str(e),
            )

    async def delete(self, request: DeleteRequest) -> DeleteResult:
        """
        Delete content from the knowledge base.

        Supports deletion by source path, collection, or chunk IDs.

        Args:
            request: Delete request with target specification

        Returns:
            Delete result with count of deleted chunks
        """
        try:
            deleted_count = 0

            if request.chunk_ids:
                # Delete specific chunks
                deleted_count = await self.database.delete_by_ids(
                    project_id=request.project_id,
                    chunk_ids=request.chunk_ids,
                )
            elif request.source_path:
                # Delete by source path
                deleted_count = await self.database.delete_by_source(
                    project_id=request.project_id,
                    source_path=request.source_path,
                    collection=request.collection,
                )
            elif request.collection:
                # Delete entire collection
                deleted_count = await self.database.delete_collection(
                    project_id=request.project_id,
                    collection=request.collection,
                )
            else:
                return DeleteResult(
                    success=False,
                    chunks_deleted=0,
                    error="Must specify chunk_ids, source_path, or collection",
                )

            logger.info(
                f"Deleted {deleted_count} chunks for project {request.project_id}"
            )

            return DeleteResult(
                success=True,
                chunks_deleted=deleted_count,
            )

        except Exception as e:
            logger.error(f"Delete error: {e}")
            return DeleteResult(
                success=False,
                chunks_deleted=0,
                error=str(e),
            )

    async def list_collections(self, project_id: str) -> ListCollectionsResponse:
        """
        List all collections for a project.

        Args:
            project_id: Project ID

        Returns:
            List of collection info
        """
        collections = await self.database.list_collections(project_id)

        return ListCollectionsResponse(
            project_id=project_id,
            collections=collections,
            total_collections=len(collections),
        )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """
        Get statistics for a collection.

        Args:
            project_id: Project ID
            collection: Collection name

        Returns:
            Collection statistics
        """
        return await self.database.get_collection_stats(project_id, collection)

    async def update_document(
        self,
        project_id: str,
        agent_id: str,
        source_path: str,
        content: str,
        collection: str = "default",
        chunk_type: ChunkType = ChunkType.TEXT,
        file_type: FileType | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> IngestResult:
        """
        Update a document by atomically replacing existing chunks.

        Uses a database transaction to delete existing chunks and insert new ones
        atomically, preventing race conditions during concurrent updates.

        Args:
            project_id: Project ID
            agent_id: Agent ID
            source_path: Source file path
            content: New content
            collection: Collection name
            chunk_type: Type of content
            file_type: File type for code chunking
            metadata: Additional metadata

        Returns:
            Ingest result
        """
        request_metadata = metadata or {}

        # Chunk the content
        chunks = self.chunker_factory.chunk_content(
            content=content,
            source_path=source_path,
            file_type=file_type,
            chunk_type=chunk_type,
            metadata=request_metadata,
        )

        if not chunks:
            # No chunks = delete existing and return empty result
            await self.database.delete_by_source(
                project_id=project_id,
                source_path=source_path,
                collection=collection,
            )
            return IngestResult(
                success=True,
                chunks_created=0,
                embeddings_generated=0,
                source_path=source_path,
                collection=collection,
                chunk_ids=[],
            )

        # Generate embeddings for new chunks
        chunk_texts = [chunk.content for chunk in chunks]
        embeddings_list = await self.embeddings.generate_batch(
            texts=chunk_texts,
            project_id=project_id,
            agent_id=agent_id,
        )

        # Build embeddings data for transactional replace
        embeddings_data = []
        for chunk, embedding in zip(chunks, embeddings_list, strict=True):
            chunk_metadata = {
                **request_metadata,
                **chunk.metadata,
                "token_count": chunk.token_count,
                "source_path": chunk.source_path or source_path,
                "start_line": chunk.start_line,
                "end_line": chunk.end_line,
                "file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
            }
            embeddings_data.append((
                chunk.content,
                embedding,
                chunk.chunk_type,
                chunk_metadata,
            ))

        # Atomically replace old embeddings with new ones
        _, chunk_ids = await self.database.replace_source_embeddings(
            project_id=project_id,
            source_path=source_path,
            collection=collection,
            embeddings=embeddings_data,
        )

        return IngestResult(
            success=True,
            chunks_created=len(chunk_ids),
            embeddings_generated=len(embeddings_list),
            source_path=source_path,
            collection=collection,
            chunk_ids=chunk_ids,
        )

    async def cleanup_expired(self) -> int:
        """
        Remove expired embeddings from all collections.

        Returns:
            Number of embeddings removed
        """
        return await self.database.cleanup_expired()


# Global collection manager instance (lazy initialization)
_collection_manager: CollectionManager | None = None


def get_collection_manager() -> CollectionManager:
    """Get the global collection manager instance."""
    global _collection_manager
    if _collection_manager is None:
        _collection_manager = CollectionManager()
    return _collection_manager


def reset_collection_manager() -> None:
    """Reset the global collection manager (for testing)."""
    global _collection_manager
    _collection_manager = None
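A hedged sketch of an ingest call from an async caller. The `IngestRequest` field names mirror the attributes that `CollectionManager.ingest()` reads above; the concrete values, and the assumption that the database manager and embedding generator were initialized at server startup, are illustrative only.

```python
# Sketch only: assumes DatabaseManager.initialize() and the embedding
# generator's initialize() have already run during server startup.
import asyncio

from collection_manager import get_collection_manager
from models import ChunkType, IngestRequest


async def main() -> None:
    manager = get_collection_manager()
    request = IngestRequest(          # field names taken from ingest() above
        project_id="demo-project",    # illustrative IDs
        agent_id="demo-agent",
        collection="default",
        content="# Architecture\n\nChunks and their embeddings live in pgvector.",
        source_path="docs/architecture.md",
        chunk_type=ChunkType.TEXT,
        metadata={"ingested_by": "example"},
    )
    result = await manager.ingest(request)
    print(result.success, result.chunks_created, result.chunk_ids)


asyncio.run(main())
```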
mcp-servers/knowledge-base/config.py (new file, 152 lines)
@@ -0,0 +1,152 @@
"""
Configuration for Knowledge Base MCP Server.

Uses pydantic-settings for environment variable loading.
"""

import os

from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment."""

    # Server settings
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8002, description="Server port")
    debug: bool = Field(default=False, description="Debug mode")

    # Database settings
    database_url: str = Field(
        default="postgresql://postgres:postgres@localhost:5432/syndarix",
        description="PostgreSQL connection URL with pgvector extension",
    )
    database_pool_size: int = Field(default=10, description="Connection pool size")
    database_pool_max_overflow: int = Field(
        default=20, description="Max overflow connections"
    )

    # Redis settings
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL",
    )

    # LLM Gateway settings (for embeddings)
    llm_gateway_url: str = Field(
        default="http://localhost:8001",
        description="LLM Gateway MCP server URL",
    )

    # Embedding settings
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="Default embedding model",
    )
    embedding_dimension: int = Field(
        default=1536,
        description="Embedding vector dimension",
    )
    embedding_batch_size: int = Field(
        default=100,
        description="Max texts per embedding batch",
    )
    embedding_cache_ttl: int = Field(
        default=86400,
        description="Embedding cache TTL in seconds (24 hours)",
    )

    # Chunking settings
    code_chunk_size: int = Field(
        default=500,
        description="Target tokens per code chunk",
    )
    code_chunk_overlap: int = Field(
        default=50,
        description="Token overlap between code chunks",
    )
    markdown_chunk_size: int = Field(
        default=800,
        description="Target tokens per markdown chunk",
    )
    markdown_chunk_overlap: int = Field(
        default=100,
        description="Token overlap between markdown chunks",
    )
    text_chunk_size: int = Field(
        default=400,
        description="Target tokens per text chunk",
    )
    text_chunk_overlap: int = Field(
        default=50,
        description="Token overlap between text chunks",
    )

    # Search settings
    search_default_limit: int = Field(
        default=10,
        description="Default number of search results",
    )
    search_max_limit: int = Field(
        default=100,
        description="Maximum number of search results",
    )
    semantic_threshold: float = Field(
        default=0.7,
        description="Minimum similarity score for semantic search",
    )
    hybrid_semantic_weight: float = Field(
        default=0.7,
        description="Weight for semantic results in hybrid search",
    )
    hybrid_keyword_weight: float = Field(
        default=0.3,
        description="Weight for keyword results in hybrid search",
    )

    # Storage settings
    embedding_ttl_days: int = Field(
        default=30,
        description="TTL for embedding records in days (0 = no expiry)",
    )

    # Content size limits (DoS prevention)
    max_document_size: int = Field(
        default=10 * 1024 * 1024,  # 10 MB
        description="Maximum size of a single document in bytes",
    )
    max_batch_size: int = Field(
        default=100,
        description="Maximum number of documents in a batch operation",
    )
    max_batch_total_size: int = Field(
        default=50 * 1024 * 1024,  # 50 MB
        description="Maximum total size of all documents in a batch",
    )

    model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"}


# Global settings instance (lazy initialization)
_settings: Settings | None = None


def get_settings() -> Settings:
    """Get the global settings instance."""
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def reset_settings() -> None:
    """Reset the global settings (for testing)."""
    global _settings
    _settings = None


def is_test_mode() -> bool:
    """Check if running in test mode."""
    return os.getenv("IS_TEST", "").lower() in ("true", "1", "yes")
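Because `model_config` sets `env_prefix="KB_"`, each field above can be overridden through a `KB_`-prefixed environment variable or the `.env` file. A minimal sketch of that override path, using the helpers defined in this module (the specific values are illustrative):

```python
import os

from config import get_settings, reset_settings

# Field names map to KB_-prefixed variables, e.g. KB_PORT, KB_DATABASE_URL.
os.environ["KB_PORT"] = "9002"
os.environ["KB_TEXT_CHUNK_SIZE"] = "600"

reset_settings()           # drop any cached instance
settings = get_settings()  # re-reads the environment
assert settings.port == 9002
assert settings.text_chunk_size == 600
```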
mcp-servers/knowledge-base/database.py (new file, 870 lines)
@@ -0,0 +1,870 @@
|
|||||||
|
"""
|
||||||
|
Database management for Knowledge Base MCP Server.
|
||||||
|
|
||||||
|
Handles PostgreSQL connections with pgvector extension for
|
||||||
|
vector similarity search operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from pgvector.asyncpg import register_vector
|
||||||
|
|
||||||
|
from config import Settings, get_settings
|
||||||
|
from exceptions import (
|
||||||
|
CollectionNotFoundError,
|
||||||
|
DatabaseConnectionError,
|
||||||
|
DatabaseQueryError,
|
||||||
|
ErrorCode,
|
||||||
|
KnowledgeBaseError,
|
||||||
|
)
|
||||||
|
from models import (
|
||||||
|
ChunkType,
|
||||||
|
CollectionInfo,
|
||||||
|
CollectionStats,
|
||||||
|
FileType,
|
||||||
|
KnowledgeEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseManager:
|
||||||
|
"""
|
||||||
|
Manages PostgreSQL connections and vector operations.
|
||||||
|
|
||||||
|
Uses asyncpg for async operations and pgvector for
|
||||||
|
vector similarity search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings | None = None) -> None:
|
||||||
|
"""Initialize database manager."""
|
||||||
|
self._settings = settings or get_settings()
|
||||||
|
self._pool: asyncpg.Pool | None = None # type: ignore[type-arg]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pool(self) -> asyncpg.Pool: # type: ignore[type-arg]
|
||||||
|
"""Get connection pool, raising if not initialized."""
|
||||||
|
if self._pool is None:
|
||||||
|
raise DatabaseConnectionError("Database pool not initialized")
|
||||||
|
return self._pool
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize connection pool and create schema."""
|
||||||
|
try:
|
||||||
|
self._pool = await asyncpg.create_pool(
|
||||||
|
self._settings.database_url,
|
||||||
|
min_size=2,
|
||||||
|
max_size=self._settings.database_pool_size,
|
||||||
|
max_inactive_connection_lifetime=300,
|
||||||
|
init=self._init_connection,
|
||||||
|
)
|
||||||
|
logger.info("Database pool created successfully")
|
||||||
|
|
||||||
|
# Create schema
|
||||||
|
await self._create_schema()
|
||||||
|
logger.info("Database schema initialized")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize database: {e}")
|
||||||
|
raise DatabaseConnectionError(
|
||||||
|
message=f"Failed to initialize database: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _init_connection(self, conn: asyncpg.Connection) -> None: # type: ignore[type-arg]
|
||||||
|
"""Initialize a connection with pgvector support."""
|
||||||
|
await register_vector(conn)
|
||||||
|
|
||||||
|
async def _create_schema(self) -> None:
|
||||||
|
"""Create database schema if not exists."""
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
# Enable pgvector extension
|
||||||
|
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
|
|
||||||
|
# Create main embeddings table
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS knowledge_embeddings (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
project_id VARCHAR(255) NOT NULL,
|
||||||
|
collection VARCHAR(255) NOT NULL DEFAULT 'default',
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
embedding vector(1536),
|
||||||
|
chunk_type VARCHAR(50) NOT NULL,
|
||||||
|
source_path TEXT,
|
||||||
|
start_line INTEGER,
|
||||||
|
end_line INTEGER,
|
||||||
|
file_type VARCHAR(50),
|
||||||
|
metadata JSONB DEFAULT '{}',
|
||||||
|
content_hash VARCHAR(64),
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
expires_at TIMESTAMPTZ
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create indexes for common queries
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_embeddings_project_collection
|
||||||
|
ON knowledge_embeddings(project_id, collection)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_embeddings_source_path
|
||||||
|
ON knowledge_embeddings(project_id, source_path)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_embeddings_content_hash
|
||||||
|
ON knowledge_embeddings(project_id, content_hash)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_embeddings_chunk_type
|
||||||
|
ON knowledge_embeddings(project_id, chunk_type)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create HNSW index for vector similarity search
|
||||||
|
# This dramatically improves search performance
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_embeddings_vector_hnsw
|
||||||
|
ON knowledge_embeddings
|
||||||
|
USING hnsw (embedding vector_cosine_ops)
|
||||||
|
WITH (m = 16, ef_construction = 64)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create GIN index for full-text search
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_embeddings_content_fts
|
||||||
|
ON knowledge_embeddings
|
||||||
|
USING gin(to_tsvector('english', content))
|
||||||
|
""")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the connection pool."""
|
||||||
|
if self._pool:
|
||||||
|
await self._pool.close()
|
||||||
|
self._pool = None
|
||||||
|
logger.info("Database pool closed")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def acquire(self) -> Any:
|
||||||
|
"""Acquire a connection from the pool."""
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_content_hash(content: str) -> str:
|
||||||
|
"""Compute SHA-256 hash of content for deduplication."""
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()
|
||||||
|
|
||||||
|
async def store_embedding(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
collection: str,
|
||||||
|
content: str,
|
||||||
|
embedding: list[float],
|
||||||
|
chunk_type: ChunkType,
|
||||||
|
source_path: str | None = None,
|
||||||
|
start_line: int | None = None,
|
||||||
|
end_line: int | None = None,
|
||||||
|
file_type: FileType | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Store an embedding in the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ID of the stored embedding.
|
||||||
|
"""
|
||||||
|
content_hash = self.compute_content_hash(content)
|
||||||
|
metadata = metadata or {}
|
||||||
|
|
||||||
|
# Calculate expiration if TTL is set
|
||||||
|
expires_at = None
|
||||||
|
if self._settings.embedding_ttl_days > 0:
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
days=self._settings.embedding_ttl_days
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
# Check for duplicate content
|
||||||
|
existing = await conn.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT id FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1 AND collection = $2 AND content_hash = $3
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
collection,
|
||||||
|
content_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# Update existing embedding
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE knowledge_embeddings
|
||||||
|
SET embedding = $1, updated_at = NOW(), expires_at = $2,
|
||||||
|
metadata = $3, source_path = $4, start_line = $5,
|
||||||
|
end_line = $6, file_type = $7
|
||||||
|
WHERE id = $8
|
||||||
|
""",
|
||||||
|
embedding,
|
||||||
|
expires_at,
|
||||||
|
metadata,
|
||||||
|
source_path,
|
||||||
|
start_line,
|
||||||
|
end_line,
|
||||||
|
file_type.value if file_type else None,
|
||||||
|
existing,
|
||||||
|
)
|
||||||
|
logger.debug(f"Updated existing embedding: {existing}")
|
||||||
|
return str(existing)
|
||||||
|
|
||||||
|
# Insert new embedding
|
||||||
|
embedding_id = await conn.fetchval(
|
||||||
|
"""
|
||||||
|
INSERT INTO knowledge_embeddings
|
||||||
|
(project_id, collection, content, embedding, chunk_type,
|
||||||
|
source_path, start_line, end_line, file_type, metadata,
|
||||||
|
content_hash, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||||
|
RETURNING id
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
collection,
|
||||||
|
content,
|
||||||
|
embedding,
|
||||||
|
chunk_type.value,
|
||||||
|
source_path,
|
||||||
|
start_line,
|
||||||
|
end_line,
|
||||||
|
file_type.value if file_type else None,
|
||||||
|
metadata,
|
||||||
|
content_hash,
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
logger.debug(f"Stored new embedding: {embedding_id}")
|
||||||
|
return str(embedding_id)
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Database error storing embedding: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Failed to store embedding: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def store_embeddings_batch(
|
||||||
|
self,
|
||||||
|
embeddings: list[tuple[str, str, str, list[float], ChunkType, dict[str, Any]]],
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Store multiple embeddings in a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: List of (project_id, collection, content, embedding, chunk_type, metadata)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of created embedding IDs.
|
||||||
|
"""
|
||||||
|
if not embeddings:
|
||||||
|
return []
|
||||||
|
|
||||||
|
ids = []
|
||||||
|
expires_at = None
|
||||||
|
if self._settings.embedding_ttl_days > 0:
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
days=self._settings.embedding_ttl_days
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
# Wrap in transaction for all-or-nothing batch semantics
|
||||||
|
async with conn.transaction():
|
||||||
|
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
|
||||||
|
content_hash = self.compute_content_hash(content)
|
||||||
|
source_path = metadata.get("source_path")
|
||||||
|
start_line = metadata.get("start_line")
|
||||||
|
end_line = metadata.get("end_line")
|
||||||
|
file_type = metadata.get("file_type")
|
||||||
|
|
||||||
|
embedding_id = await conn.fetchval(
|
||||||
|
"""
|
||||||
|
INSERT INTO knowledge_embeddings
|
||||||
|
(project_id, collection, content, embedding, chunk_type,
|
||||||
|
source_path, start_line, end_line, file_type, metadata,
|
||||||
|
content_hash, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
RETURNING id
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
collection,
|
||||||
|
content,
|
||||||
|
embedding,
|
||||||
|
chunk_type.value,
|
||||||
|
source_path,
|
||||||
|
start_line,
|
||||||
|
end_line,
|
||||||
|
file_type,
|
||||||
|
metadata,
|
||||||
|
content_hash,
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
if embedding_id:
|
||||||
|
ids.append(str(embedding_id))
|
||||||
|
|
||||||
|
logger.info(f"Stored {len(ids)} embeddings in batch")
|
||||||
|
return ids
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Database error in batch store: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Failed to store embeddings batch: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def semantic_search(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
query_embedding: list[float],
|
||||||
|
collection: str | None = None,
|
||||||
|
limit: int = 10,
|
||||||
|
threshold: float = 0.7,
|
||||||
|
file_types: list[FileType] | None = None,
|
||||||
|
) -> list[tuple[KnowledgeEmbedding, float]]:
|
||||||
|
"""
|
||||||
|
Perform semantic (vector) search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (embedding, similarity_score) tuples.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
# Build query with optional filters using CTE to filter by similarity
|
||||||
|
# We use a CTE to compute similarity once, then filter in outer query
|
||||||
|
inner_query = """
|
||||||
|
SELECT
|
||||||
|
id, project_id, collection, content, embedding,
|
||||||
|
chunk_type, source_path, start_line, end_line,
|
||||||
|
file_type, metadata, content_hash, created_at,
|
||||||
|
updated_at, expires_at,
|
||||||
|
1 - (embedding <=> $1) as similarity
|
||||||
|
FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $2
|
||||||
|
AND (expires_at IS NULL OR expires_at > NOW())
|
||||||
|
"""
|
||||||
|
params: list[Any] = [query_embedding, project_id]
|
||||||
|
param_idx = 3
|
||||||
|
|
||||||
|
if collection:
|
||||||
|
inner_query += f" AND collection = ${param_idx}"
|
||||||
|
params.append(collection)
|
||||||
|
param_idx += 1
|
||||||
|
|
||||||
|
if file_types:
|
||||||
|
file_type_values = [ft.value for ft in file_types]
|
||||||
|
inner_query += f" AND file_type = ANY(${param_idx})"
|
||||||
|
params.append(file_type_values)
|
||||||
|
param_idx += 1
|
||||||
|
|
||||||
|
# Wrap in CTE and filter by threshold in outer query
|
||||||
|
query = f"""
|
||||||
|
WITH scored AS ({inner_query})
|
||||||
|
SELECT * FROM scored
|
||||||
|
WHERE similarity >= ${param_idx}
|
||||||
|
ORDER BY similarity DESC
|
||||||
|
LIMIT ${param_idx + 1}
|
||||||
|
"""
|
||||||
|
params.extend([threshold, limit])
|
||||||
|
|
||||||
|
rows = await conn.fetch(query, *params)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
embedding = KnowledgeEmbedding(
|
||||||
|
id=str(row["id"]),
|
||||||
|
project_id=row["project_id"],
|
||||||
|
collection=row["collection"],
|
||||||
|
content=row["content"],
|
||||||
|
embedding=list(row["embedding"]),
|
||||||
|
chunk_type=ChunkType(row["chunk_type"]),
|
||||||
|
source_path=row["source_path"],
|
||||||
|
start_line=row["start_line"],
|
||||||
|
end_line=row["end_line"],
|
||||||
|
file_type=FileType(row["file_type"]) if row["file_type"] else None,
|
||||||
|
metadata=row["metadata"] or {},
|
||||||
|
content_hash=row["content_hash"],
|
||||||
|
created_at=row["created_at"],
|
||||||
|
updated_at=row["updated_at"],
|
||||||
|
expires_at=row["expires_at"],
|
||||||
|
)
|
||||||
|
results.append((embedding, float(row["similarity"])))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Semantic search error: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Semantic search failed: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def keyword_search(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
query: str,
|
||||||
|
collection: str | None = None,
|
||||||
|
limit: int = 10,
|
||||||
|
file_types: list[FileType] | None = None,
|
||||||
|
) -> list[tuple[KnowledgeEmbedding, float]]:
|
||||||
|
"""
|
||||||
|
Perform full-text keyword search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (embedding, relevance_score) tuples.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
# Build query with optional filters
|
||||||
|
sql = """
|
||||||
|
SELECT
|
||||||
|
id, project_id, collection, content, embedding,
|
||||||
|
chunk_type, source_path, start_line, end_line,
|
||||||
|
file_type, metadata, content_hash, created_at,
|
||||||
|
updated_at, expires_at,
|
||||||
|
ts_rank(to_tsvector('english', content),
|
||||||
|
plainto_tsquery('english', $1)) as relevance
|
||||||
|
FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $2
|
||||||
|
AND to_tsvector('english', content) @@ plainto_tsquery('english', $1)
|
||||||
|
AND (expires_at IS NULL OR expires_at > NOW())
|
||||||
|
"""
|
||||||
|
params: list[Any] = [query, project_id]
|
||||||
|
param_idx = 3
|
||||||
|
|
||||||
|
if collection:
|
||||||
|
sql += f" AND collection = ${param_idx}"
|
||||||
|
params.append(collection)
|
||||||
|
param_idx += 1
|
||||||
|
|
||||||
|
if file_types:
|
||||||
|
file_type_values = [ft.value for ft in file_types]
|
||||||
|
sql += f" AND file_type = ANY(${param_idx})"
|
||||||
|
params.append(file_type_values)
|
||||||
|
param_idx += 1
|
||||||
|
|
||||||
|
sql += f" ORDER BY relevance DESC LIMIT ${param_idx}"
|
||||||
|
params.append(limit)
|
||||||
|
|
||||||
|
rows = await conn.fetch(sql, *params)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
embedding = KnowledgeEmbedding(
|
||||||
|
id=str(row["id"]),
|
||||||
|
project_id=row["project_id"],
|
||||||
|
collection=row["collection"],
|
||||||
|
content=row["content"],
|
||||||
|
embedding=list(row["embedding"]) if row["embedding"] else [],
|
||||||
|
chunk_type=ChunkType(row["chunk_type"]),
|
||||||
|
source_path=row["source_path"],
|
||||||
|
start_line=row["start_line"],
|
||||||
|
end_line=row["end_line"],
|
||||||
|
file_type=FileType(row["file_type"]) if row["file_type"] else None,
|
||||||
|
metadata=row["metadata"] or {},
|
||||||
|
content_hash=row["content_hash"],
|
||||||
|
created_at=row["created_at"],
|
||||||
|
updated_at=row["updated_at"],
|
||||||
|
expires_at=row["expires_at"],
|
||||||
|
)
|
||||||
|
# Normalize relevance to 0-1 scale (approximate)
|
||||||
|
normalized_score = min(1.0, float(row["relevance"]))
|
||||||
|
results.append((embedding, normalized_score))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Keyword search error: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Keyword search failed: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_by_source(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
source_path: str,
|
||||||
|
collection: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Delete all embeddings for a source path."""
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
if collection:
|
||||||
|
result = await conn.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1 AND source_path = $2 AND collection = $3
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
source_path,
|
||||||
|
collection,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await conn.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1 AND source_path = $2
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
source_path,
|
||||||
|
)
|
||||||
|
# Parse "DELETE N" result
|
||||||
|
count = int(result.split()[-1])
|
||||||
|
logger.info(f"Deleted {count} embeddings for source: {source_path}")
|
||||||
|
return count
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Delete error: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Failed to delete by source: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def replace_source_embeddings(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
source_path: str,
|
||||||
|
collection: str,
|
||||||
|
embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
|
||||||
|
) -> tuple[int, list[str]]:
|
||||||
|
"""
|
||||||
|
Atomically replace all embeddings for a source path.
|
||||||
|
|
||||||
|
Deletes existing embeddings and inserts new ones in a single transaction,
|
||||||
|
preventing race conditions during document updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
source_path: Source file path being updated
|
||||||
|
collection: Collection name
|
||||||
|
embeddings: List of (content, embedding, chunk_type, metadata)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (deleted_count, new_embedding_ids)
|
||||||
|
"""
|
||||||
|
expires_at = None
|
||||||
|
if self._settings.embedding_ttl_days > 0:
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
days=self._settings.embedding_ttl_days
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
# Use transaction for atomic replace
|
||||||
|
async with conn.transaction():
|
||||||
|
# First, delete existing embeddings for this source
|
||||||
|
delete_result = await conn.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1 AND source_path = $2 AND collection = $3
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
source_path,
|
||||||
|
collection,
|
||||||
|
)
|
||||||
|
deleted_count = int(delete_result.split()[-1])
|
||||||
|
|
||||||
|
# Then insert new embeddings
|
||||||
|
new_ids = []
|
||||||
|
for content, embedding, chunk_type, metadata in embeddings:
|
||||||
|
content_hash = self.compute_content_hash(content)
|
||||||
|
start_line = metadata.get("start_line")
|
||||||
|
end_line = metadata.get("end_line")
|
||||||
|
file_type = metadata.get("file_type")
|
||||||
|
|
||||||
|
embedding_id = await conn.fetchval(
|
||||||
|
"""
|
||||||
|
INSERT INTO knowledge_embeddings
|
||||||
|
(project_id, collection, content, embedding, chunk_type,
|
||||||
|
source_path, start_line, end_line, file_type, metadata,
|
||||||
|
content_hash, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||||
|
RETURNING id
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
collection,
|
||||||
|
content,
|
||||||
|
embedding,
|
||||||
|
chunk_type.value,
|
||||||
|
source_path,
|
||||||
|
start_line,
|
||||||
|
end_line,
|
||||||
|
file_type,
|
||||||
|
metadata,
|
||||||
|
content_hash,
|
||||||
|
expires_at,
|
||||||
|
)
|
||||||
|
if embedding_id:
|
||||||
|
new_ids.append(str(embedding_id))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Replaced source {source_path}: deleted {deleted_count}, "
|
||||||
|
f"inserted {len(new_ids)} embeddings"
|
||||||
|
)
|
||||||
|
return deleted_count, new_ids
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Replace source error: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Failed to replace source embeddings: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_collection(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
collection: str,
|
||||||
|
) -> int:
|
||||||
|
"""Delete an entire collection."""
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1 AND collection = $2
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
collection,
|
||||||
|
)
|
||||||
|
count = int(result.split()[-1])
|
||||||
|
logger.info(f"Deleted collection {collection}: {count} embeddings")
|
||||||
|
return count
|
||||||
|
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Delete collection error: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Failed to delete collection: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_by_ids(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
chunk_ids: list[str],
|
||||||
|
) -> int:
|
||||||
|
"""Delete specific embeddings by ID."""
|
||||||
|
if not chunk_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert string IDs to UUIDs
|
||||||
|
uuids = [uuid.UUID(cid) for cid in chunk_ids]
|
||||||
|
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1 AND id = ANY($2)
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
uuids,
|
||||||
|
)
|
||||||
|
count = int(result.split()[-1])
|
||||||
|
logger.info(f"Deleted {count} embeddings by ID")
|
||||||
|
return count
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise KnowledgeBaseError(
|
||||||
|
message=f"Invalid chunk ID format: {e}",
|
||||||
|
code=ErrorCode.INVALID_REQUEST,
|
||||||
|
)
|
||||||
|
except asyncpg.PostgresError as e:
|
||||||
|
logger.error(f"Delete by IDs error: {e}")
|
||||||
|
raise DatabaseQueryError(
|
||||||
|
message=f"Failed to delete by IDs: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_collections(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
) -> list[CollectionInfo]:
|
||||||
|
"""List all collections for a project."""
|
||||||
|
try:
|
||||||
|
async with self.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
collection,
|
||||||
|
COUNT(*) as chunk_count,
|
||||||
|
COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
|
||||||
|
ARRAY_AGG(DISTINCT file_type) FILTER (WHERE file_type IS NOT NULL) as file_types,
|
||||||
|
MIN(created_at) as created_at,
|
||||||
|
MAX(updated_at) as updated_at
|
||||||
|
FROM knowledge_embeddings
|
||||||
|
WHERE project_id = $1
|
||||||
|
AND (expires_at IS NULL OR expires_at > NOW())
|
||||||
|
GROUP BY collection
|
||||||
|
ORDER BY collection
|
||||||
|
""",
|
||||||
|
project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
CollectionInfo(
|
||||||
|
name=row["collection"],
|
||||||
|
project_id=project_id,
|
||||||
|
chunk_count=row["chunk_count"],
|
||||||
|
total_tokens=row["total_tokens"] or 0,
|
||||||
|
file_types=row["file_types"] or [],
|
||||||
|
created_at=row["created_at"],
|
||||||
|
                    updated_at=row["updated_at"],
                )
                for row in rows
            ]

        except asyncpg.PostgresError as e:
            logger.error(f"List collections error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to list collections: {e}",
                cause=e,
            )

    async def get_collection_stats(
        self,
        project_id: str,
        collection: str,
    ) -> CollectionStats:
        """Get detailed statistics for a collection."""
        try:
            async with self.acquire() as conn:
                # Check if collection exists
                exists = await conn.fetchval(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM knowledge_embeddings
                        WHERE project_id = $1 AND collection = $2
                    )
                    """,
                    project_id,
                    collection,
                )

                if not exists:
                    raise CollectionNotFoundError(collection, project_id)

                # Get stats
                row = await conn.fetchrow(
                    """
                    SELECT
                        COUNT(*) as chunk_count,
                        COUNT(DISTINCT source_path) as unique_sources,
                        COALESCE(SUM((metadata->>'token_count')::int), 0) as total_tokens,
                        COALESCE(AVG(LENGTH(content)), 0) as avg_chunk_size,
                        MIN(created_at) as oldest_chunk,
                        MAX(created_at) as newest_chunk
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                    """,
                    project_id,
                    collection,
                )

                # Get chunk type breakdown
                chunk_rows = await conn.fetch(
                    """
                    SELECT chunk_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY chunk_type
                    """,
                    project_id,
                    collection,
                )
                chunk_types = {r["chunk_type"]: r["count"] for r in chunk_rows}

                # Get file type breakdown
                file_rows = await conn.fetch(
                    """
                    SELECT file_type, COUNT(*) as count
                    FROM knowledge_embeddings
                    WHERE project_id = $1 AND collection = $2
                        AND file_type IS NOT NULL
                        AND (expires_at IS NULL OR expires_at > NOW())
                    GROUP BY file_type
                    """,
                    project_id,
                    collection,
                )
                file_types = {r["file_type"]: r["count"] for r in file_rows}

                return CollectionStats(
                    collection=collection,
                    project_id=project_id,
                    chunk_count=row["chunk_count"],
                    unique_sources=row["unique_sources"],
                    total_tokens=row["total_tokens"] or 0,
                    avg_chunk_size=float(row["avg_chunk_size"] or 0),
                    chunk_types=chunk_types,
                    file_types=file_types,
                    oldest_chunk=row["oldest_chunk"],
                    newest_chunk=row["newest_chunk"],
                )

        except CollectionNotFoundError:
            raise
        except asyncpg.PostgresError as e:
            logger.error(f"Get collection stats error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to get collection stats: {e}",
                cause=e,
            )

    async def cleanup_expired(self) -> int:
        """Remove expired embeddings."""
        try:
            async with self.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM knowledge_embeddings
                    WHERE expires_at IS NOT NULL AND expires_at < NOW()
                    """
                )
                count = int(result.split()[-1])
                if count > 0:
                    logger.info(f"Cleaned up {count} expired embeddings")
                return count

        except asyncpg.PostgresError as e:
            logger.error(f"Cleanup error: {e}")
            raise DatabaseQueryError(
                message=f"Failed to cleanup expired: {e}",
                cause=e,
            )


# Global database manager instance (lazy initialization)
_db_manager: DatabaseManager | None = None


def get_database_manager() -> DatabaseManager:
    """Get the global database manager instance."""
    global _db_manager
    if _db_manager is None:
        _db_manager = DatabaseManager()
    return _db_manager


def reset_database_manager() -> None:
    """Reset the global database manager (for testing)."""
    global _db_manager
    _db_manager = None
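A minimal usage sketch (not part of the diff) of the collection-stats and cleanup helpers above. It assumes the manager's connection pool has already been set up elsewhere in the server lifecycle; the project and collection names are illustrative only.

```python
import asyncio

from database import get_database_manager
from exceptions import CollectionNotFoundError


async def report_collection(project_id: str, collection: str) -> None:
    db = get_database_manager()
    try:
        stats = await db.get_collection_stats(project_id, collection)
        print(f"{collection}: {stats.chunk_count} chunks, {stats.unique_sources} sources")
    except CollectionNotFoundError as e:
        # Raised before any stats query when the collection has no rows
        print(f"missing collection: {e}")
    removed = await db.cleanup_expired()
    print(f"expired embeddings removed: {removed}")


if __name__ == "__main__":
    asyncio.run(report_collection("demo-project", "default"))
```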
426  mcp-servers/knowledge-base/embeddings.py  Normal file
@@ -0,0 +1,426 @@
"""
Embedding generation for Knowledge Base MCP Server.

Generates vector embeddings via the LLM Gateway MCP server
with caching support using Redis.
"""

import hashlib
import json
import logging

import httpx
import redis.asyncio as redis

from config import Settings, get_settings
from exceptions import (
    EmbeddingDimensionMismatchError,
    EmbeddingGenerationError,
    ErrorCode,
    KnowledgeBaseError,
)

logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """
    Generates embeddings via LLM Gateway.

    Features:
    - Batched embedding generation
    - Redis caching for deduplication
    - Automatic retry on transient failures
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize embedding generator."""
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        self._http_client: httpx.AsyncClient | None = None

    @property
    def redis_client(self) -> redis.Redis:  # type: ignore[type-arg]
        """Get Redis client, raising if not initialized."""
        if self._redis is None:
            raise KnowledgeBaseError(
                message="Redis client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._redis

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get HTTP client, raising if not initialized."""
        if self._http_client is None:
            raise KnowledgeBaseError(
                message="HTTP client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._http_client

    async def initialize(self) -> None:
        """Initialize Redis and HTTP clients."""
        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()  # type: ignore[misc]
            logger.info("Redis connection established for embedding cache")

        except Exception as e:
            logger.warning(f"Redis connection failed, caching disabled: {e}")
            self._redis = None

        self._http_client = httpx.AsyncClient(
            base_url=self._settings.llm_gateway_url,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )
        logger.info("Embedding generator initialized")

    async def close(self) -> None:
        """Close connections."""
        if self._redis:
            await self._redis.close()
            self._redis = None

        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None

        logger.info("Embedding generator closed")

    def _cache_key(self, text: str) -> str:
        """Generate cache key for a text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
        model = self._settings.embedding_model
        return f"kb:emb:{model}:{text_hash}"

    async def _get_cached(self, text: str) -> list[float] | None:
        """Get cached embedding if available."""
        if not self._redis:
            return None

        try:
            key = self._cache_key(text)
            cached = await self._redis.get(key)
            if cached:
                logger.debug(f"Cache hit for embedding: {key[:20]}...")
                return json.loads(cached)
            return None

        except Exception as e:
            logger.warning(f"Cache read error: {e}")
            return None

    async def _set_cached(self, text: str, embedding: list[float]) -> None:
        """Cache an embedding."""
        if not self._redis:
            return

        try:
            key = self._cache_key(text)
            await self._redis.setex(
                key,
                self._settings.embedding_cache_ttl,
                json.dumps(embedding),
            )
            logger.debug(f"Cached embedding: {key[:20]}...")

        except Exception as e:
            logger.warning(f"Cache write error: {e}")

    async def _get_cached_batch(
        self, texts: list[str]
    ) -> tuple[list[list[float] | None], list[int]]:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (embeddings list with None for misses, indices of misses)
        """
        if not self._redis:
            return [None] * len(texts), list(range(len(texts)))

        try:
            keys = [self._cache_key(text) for text in texts]
            cached_values = await self._redis.mget(keys)

            embeddings: list[list[float] | None] = []
            missing_indices: list[int] = []

            for i, cached in enumerate(cached_values):
                if cached:
                    embeddings.append(json.loads(cached))
                else:
                    embeddings.append(None)
                    missing_indices.append(i)

            cache_hits = len(texts) - len(missing_indices)
            if cache_hits > 0:
                logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")

            return embeddings, missing_indices

        except Exception as e:
            logger.warning(f"Batch cache read error: {e}")
            return [None] * len(texts), list(range(len(texts)))

    async def _set_cached_batch(
        self, texts: list[str], embeddings: list[list[float]]
    ) -> None:
        """Cache a batch of embeddings."""
        if not self._redis or len(texts) != len(embeddings):
            return

        try:
            pipe = self._redis.pipeline()
            for text, embedding in zip(texts, embeddings, strict=True):
                key = self._cache_key(text)
                pipe.setex(
                    key,
                    self._settings.embedding_cache_ttl,
                    json.dumps(embedding),
                )
            await pipe.execute()
            logger.debug(f"Cached {len(texts)} embeddings in batch")

        except Exception as e:
            logger.warning(f"Batch cache write error: {e}")

    async def generate(
        self,
        text: str,
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            Embedding vector
        """
        # Check cache first
        cached = await self._get_cached(text)
        if cached:
            return cached

        # Generate via LLM Gateway
        embeddings = await self._call_llm_gateway(
            [text], project_id, agent_id
        )

        if not embeddings:
            raise EmbeddingGenerationError(
                message="No embedding returned from LLM Gateway",
                texts_count=1,
            )

        embedding = embeddings[0]

        # Validate dimension
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )

        # Cache the result
        await self._set_cached(text, embedding)

        return embedding

    async def generate_batch(
        self,
        texts: list[str],
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Uses caching and batches requests to LLM Gateway.

        Args:
            texts: List of texts to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for existing embeddings
        cached_embeddings, missing_indices = await self._get_cached_batch(texts)

        # If all cached, return immediately
        if not missing_indices:
            logger.debug(f"All {len(texts)} embeddings served from cache")
            return [e for e in cached_embeddings if e is not None]

        # Get texts that need embedding
        texts_to_embed = [texts[i] for i in missing_indices]

        # Generate embeddings in batches
        new_embeddings: list[list[float]] = []
        batch_size = self._settings.embedding_batch_size

        for i in range(0, len(texts_to_embed), batch_size):
            batch = texts_to_embed[i : i + batch_size]
            batch_embeddings = await self._call_llm_gateway(
                batch, project_id, agent_id
            )
            new_embeddings.extend(batch_embeddings)

        # Validate dimensions
        for embedding in new_embeddings:
            if len(embedding) != self._settings.embedding_dimension:
                raise EmbeddingDimensionMismatchError(
                    expected=self._settings.embedding_dimension,
                    actual=len(embedding),
                )

        # Cache new embeddings
        await self._set_cached_batch(texts_to_embed, new_embeddings)

        # Combine cached and new embeddings
        result: list[list[float]] = []
        new_idx = 0

        for i in range(len(texts)):
            if cached_embeddings[i] is not None:
                result.append(cached_embeddings[i])  # type: ignore[arg-type]
            else:
                result.append(new_embeddings[new_idx])
                new_idx += 1

        logger.info(
            f"Generated {len(new_embeddings)} embeddings, "
            f"{len(texts) - len(missing_indices)} from cache"
        )

        return result

    async def _call_llm_gateway(
        self,
        texts: list[str],
        project_id: str,
        agent_id: str,
    ) -> list[list[float]]:
        """
        Call LLM Gateway to generate embeddings.

        Uses JSON-RPC 2.0 protocol to call the embedding tool.
        """
        try:
            # JSON-RPC 2.0 request for embedding tool
            request = {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "generate_embeddings",
                    "arguments": {
                        "project_id": project_id,
                        "agent_id": agent_id,
                        "texts": texts,
                        "model": self._settings.embedding_model,
                    },
                },
                "id": 1,
            }

            response = await self.http_client.post("/mcp", json=request)
            response.raise_for_status()

            result = response.json()

            if "error" in result:
                error = result["error"]
                raise EmbeddingGenerationError(
                    message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
                    texts_count=len(texts),
                    details=error.get("data"),
                )

            # Extract embeddings from response
            content = result.get("result", {}).get("content", [])
            if not content:
                raise EmbeddingGenerationError(
                    message="Empty response from LLM Gateway",
                    texts_count=len(texts),
                )

            # Parse the response content
            # LLM Gateway returns embeddings in content[0].text as JSON
            embeddings_data = content[0].get("text", "")
            if isinstance(embeddings_data, str):
                embeddings_data = json.loads(embeddings_data)

            embeddings = embeddings_data.get("embeddings", [])

            if len(embeddings) != len(texts):
                raise EmbeddingGenerationError(
                    message=f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                    texts_count=len(texts),
                )

            return embeddings

        except httpx.HTTPStatusError as e:
            logger.error(f"LLM Gateway HTTP error: {e}")
            raise EmbeddingGenerationError(
                message=f"LLM Gateway request failed: {e.response.status_code}",
                texts_count=len(texts),
                cause=e,
            )
        except httpx.RequestError as e:
            logger.error(f"LLM Gateway request error: {e}")
            raise EmbeddingGenerationError(
                message=f"Failed to connect to LLM Gateway: {e}",
                texts_count=len(texts),
                cause=e,
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON response from LLM Gateway: {e}")
            raise EmbeddingGenerationError(
                message="Invalid response format from LLM Gateway",
                texts_count=len(texts),
                cause=e,
            )
        except EmbeddingGenerationError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error generating embeddings: {e}")
            raise EmbeddingGenerationError(
                message=f"Unexpected error: {e}",
                texts_count=len(texts),
                cause=e,
            )


# Global embedding generator instance (lazy initialization)
_embedding_generator: EmbeddingGenerator | None = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Get the global embedding generator instance."""
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator


def reset_embedding_generator() -> None:
    """Reset the global embedding generator (for testing)."""
    global _embedding_generator
    _embedding_generator = None
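A minimal usage sketch (not part of the diff) of the generator above. It assumes the Redis and LLM Gateway URLs come from `Settings`; if Redis is unreachable, `initialize()` logs a warning and the generator falls back to uncached generation. The project and agent IDs are illustrative placeholders.

```python
import asyncio

from embeddings import get_embedding_generator


async def embed_docs() -> None:
    generator = get_embedding_generator()
    await generator.initialize()  # connects the Redis cache and HTTP client
    try:
        vectors = await generator.generate_batch(
            ["pgvector stores embeddings", "RRF fuses ranked result lists"],
            project_id="demo-project",
            agent_id="demo-agent",
        )
        print(len(vectors), "embeddings of dimension", len(vectors[0]))
    finally:
        await generator.close()


if __name__ == "__main__":
    asyncio.run(embed_docs())
```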
409  mcp-servers/knowledge-base/exceptions.py  Normal file
@@ -0,0 +1,409 @@
"""
Custom exceptions for Knowledge Base MCP Server.

Provides structured error handling with error codes and details.
"""

from enum import Enum
from typing import Any


class ErrorCode(str, Enum):
    """Error codes for Knowledge Base operations."""

    # General errors
    UNKNOWN_ERROR = "KB_UNKNOWN_ERROR"
    INVALID_REQUEST = "KB_INVALID_REQUEST"
    INTERNAL_ERROR = "KB_INTERNAL_ERROR"

    # Database errors
    DATABASE_CONNECTION_ERROR = "KB_DATABASE_CONNECTION_ERROR"
    DATABASE_QUERY_ERROR = "KB_DATABASE_QUERY_ERROR"
    DATABASE_INTEGRITY_ERROR = "KB_DATABASE_INTEGRITY_ERROR"

    # Embedding errors
    EMBEDDING_GENERATION_ERROR = "KB_EMBEDDING_GENERATION_ERROR"
    EMBEDDING_DIMENSION_MISMATCH = "KB_EMBEDDING_DIMENSION_MISMATCH"
    EMBEDDING_RATE_LIMIT = "KB_EMBEDDING_RATE_LIMIT"

    # Chunking errors
    CHUNKING_ERROR = "KB_CHUNKING_ERROR"
    UNSUPPORTED_FILE_TYPE = "KB_UNSUPPORTED_FILE_TYPE"
    FILE_TOO_LARGE = "KB_FILE_TOO_LARGE"
    ENCODING_ERROR = "KB_ENCODING_ERROR"

    # Search errors
    SEARCH_ERROR = "KB_SEARCH_ERROR"
    INVALID_SEARCH_TYPE = "KB_INVALID_SEARCH_TYPE"
    SEARCH_TIMEOUT = "KB_SEARCH_TIMEOUT"

    # Collection errors
    COLLECTION_NOT_FOUND = "KB_COLLECTION_NOT_FOUND"
    COLLECTION_ALREADY_EXISTS = "KB_COLLECTION_ALREADY_EXISTS"

    # Document errors
    DOCUMENT_NOT_FOUND = "KB_DOCUMENT_NOT_FOUND"
    DOCUMENT_ALREADY_EXISTS = "KB_DOCUMENT_ALREADY_EXISTS"
    INVALID_DOCUMENT = "KB_INVALID_DOCUMENT"

    # Project errors
    PROJECT_NOT_FOUND = "KB_PROJECT_NOT_FOUND"
    PROJECT_ACCESS_DENIED = "KB_PROJECT_ACCESS_DENIED"


class KnowledgeBaseError(Exception):
    """
    Base exception for Knowledge Base errors.

    All custom exceptions inherit from this class.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.UNKNOWN_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        """
        Initialize Knowledge Base error.

        Args:
            message: Human-readable error message
            code: Error code for programmatic handling
            details: Additional error details
            cause: Original exception that caused this error
        """
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}
        self.cause = cause

    def to_dict(self) -> dict[str, Any]:
        """Convert error to dictionary for JSON response."""
        result: dict[str, Any] = {
            "error": self.code.value,
            "message": self.message,
        }
        if self.details:
            result["details"] = self.details
        return result

    def __str__(self) -> str:
        """String representation."""
        return f"[{self.code.value}] {self.message}"

    def __repr__(self) -> str:
        """Detailed representation."""
        return (
            f"{self.__class__.__name__}("
            f"message={self.message!r}, "
            f"code={self.code.value!r}, "
            f"details={self.details!r})"
        )


# Database Errors


class DatabaseError(KnowledgeBaseError):
    """Base class for database-related errors."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.DATABASE_QUERY_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message, code, details, cause)


class DatabaseConnectionError(DatabaseError):
    """Failed to connect to the database."""

    def __init__(
        self,
        message: str = "Failed to connect to database",
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message, ErrorCode.DATABASE_CONNECTION_ERROR, details, cause)


class DatabaseQueryError(DatabaseError):
    """Database query failed."""

    def __init__(
        self,
        message: str,
        query: str | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        details = details or {}
        if query:
            details["query"] = query
        super().__init__(message, ErrorCode.DATABASE_QUERY_ERROR, details, cause)


# Embedding Errors


class EmbeddingError(KnowledgeBaseError):
    """Base class for embedding-related errors."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.EMBEDDING_GENERATION_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message, code, details, cause)


class EmbeddingGenerationError(EmbeddingError):
    """Failed to generate embeddings."""

    def __init__(
        self,
        message: str = "Failed to generate embeddings",
        texts_count: int | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        details = details or {}
        if texts_count is not None:
            details["texts_count"] = texts_count
        super().__init__(message, ErrorCode.EMBEDDING_GENERATION_ERROR, details, cause)


class EmbeddingDimensionMismatchError(EmbeddingError):
    """Embedding dimension doesn't match expected dimension."""

    def __init__(
        self,
        expected: int,
        actual: int,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["expected_dimension"] = expected
        details["actual_dimension"] = actual
        message = f"Embedding dimension mismatch: expected {expected}, got {actual}"
        super().__init__(message, ErrorCode.EMBEDDING_DIMENSION_MISMATCH, details)


# Chunking Errors


class ChunkingError(KnowledgeBaseError):
    """Base class for chunking-related errors."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.CHUNKING_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message, code, details, cause)


class UnsupportedFileTypeError(ChunkingError):
    """File type is not supported for chunking."""

    def __init__(
        self,
        file_type: str,
        supported_types: list[str] | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["file_type"] = file_type
        if supported_types:
            details["supported_types"] = supported_types
        message = f"Unsupported file type: {file_type}"
        super().__init__(message, ErrorCode.UNSUPPORTED_FILE_TYPE, details)


class FileTooLargeError(ChunkingError):
    """File exceeds maximum allowed size."""

    def __init__(
        self,
        file_size: int,
        max_size: int,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["file_size"] = file_size
        details["max_size"] = max_size
        message = f"File too large: {file_size} bytes exceeds limit of {max_size} bytes"
        super().__init__(message, ErrorCode.FILE_TOO_LARGE, details)


class EncodingError(ChunkingError):
    """Failed to decode file content."""

    def __init__(
        self,
        message: str = "Failed to decode file content",
        encoding: str | None = None,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        details = details or {}
        if encoding:
            details["encoding"] = encoding
        super().__init__(message, ErrorCode.ENCODING_ERROR, details, cause)


# Search Errors


class SearchError(KnowledgeBaseError):
    """Base class for search-related errors."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.SEARCH_ERROR,
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message, code, details, cause)


class InvalidSearchTypeError(SearchError):
    """Invalid search type specified."""

    def __init__(
        self,
        search_type: str,
        valid_types: list[str] | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["search_type"] = search_type
        if valid_types:
            details["valid_types"] = valid_types
        message = f"Invalid search type: {search_type}"
        super().__init__(message, ErrorCode.INVALID_SEARCH_TYPE, details)


class SearchTimeoutError(SearchError):
    """Search operation timed out."""

    def __init__(
        self,
        timeout: float,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["timeout"] = timeout
        message = f"Search timed out after {timeout} seconds"
        super().__init__(message, ErrorCode.SEARCH_TIMEOUT, details)


# Collection Errors


class CollectionError(KnowledgeBaseError):
    """Base class for collection-related errors."""

    pass


class CollectionNotFoundError(CollectionError):
    """Collection does not exist."""

    def __init__(
        self,
        collection: str,
        project_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["collection"] = collection
        if project_id:
            details["project_id"] = project_id
        message = f"Collection not found: {collection}"
        super().__init__(message, ErrorCode.COLLECTION_NOT_FOUND, details)


# Document Errors


class DocumentError(KnowledgeBaseError):
    """Base class for document-related errors."""

    pass


class DocumentNotFoundError(DocumentError):
    """Document does not exist."""

    def __init__(
        self,
        source_path: str,
        project_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["source_path"] = source_path
        if project_id:
            details["project_id"] = project_id
        message = f"Document not found: {source_path}"
        super().__init__(message, ErrorCode.DOCUMENT_NOT_FOUND, details)


class InvalidDocumentError(DocumentError):
    """Document content is invalid."""

    def __init__(
        self,
        message: str = "Invalid document content",
        details: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message, ErrorCode.INVALID_DOCUMENT, details, cause)


# Project Errors


class ProjectError(KnowledgeBaseError):
    """Base class for project-related errors."""

    pass


class ProjectNotFoundError(ProjectError):
    """Project does not exist."""

    def __init__(
        self,
        project_id: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["project_id"] = project_id
        message = f"Project not found: {project_id}"
        super().__init__(message, ErrorCode.PROJECT_NOT_FOUND, details)


class ProjectAccessDeniedError(ProjectError):
    """Access to project is denied."""

    def __init__(
        self,
        project_id: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        details = details or {}
        details["project_id"] = project_id
        message = f"Access denied to project: {project_id}"
        super().__init__(message, ErrorCode.PROJECT_ACCESS_DENIED, details)
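A minimal sketch (not part of the diff) of how a request handler might surface these exceptions as structured JSON using the `to_dict()` contract defined above; the handler function and IDs are hypothetical.

```python
import json

from exceptions import CollectionNotFoundError, KnowledgeBaseError


def handle_missing_collection(project_id: str, collection: str) -> str:
    try:
        raise CollectionNotFoundError(collection, project_id)
    except KnowledgeBaseError as e:
        # Produces {"error": "KB_COLLECTION_NOT_FOUND", "message": "...", "details": {...}}
        return json.dumps(e.to_dict())


print(handle_missing_collection("demo-project", "missing"))
```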
321  mcp-servers/knowledge-base/models.py  Normal file
@@ -0,0 +1,321 @@
"""
Data models for Knowledge Base MCP Server.

Defines database models, Pydantic schemas, and data structures
for RAG operations with pgvector.
"""

from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class SearchType(str, Enum):
    """Types of search supported."""

    SEMANTIC = "semantic"  # Vector similarity search
    KEYWORD = "keyword"  # Full-text search
    HYBRID = "hybrid"  # Combined semantic + keyword


class ChunkType(str, Enum):
    """Types of content chunks."""

    CODE = "code"
    MARKDOWN = "markdown"
    TEXT = "text"
    DOCUMENTATION = "documentation"


class FileType(str, Enum):
    """Supported file types for chunking."""

    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"
    GO = "go"
    RUST = "rust"
    JAVA = "java"
    MARKDOWN = "markdown"
    TEXT = "text"
    JSON = "json"
    YAML = "yaml"
    TOML = "toml"


# File extension to FileType mapping
FILE_EXTENSION_MAP: dict[str, FileType] = {
    ".py": FileType.PYTHON,
    ".js": FileType.JAVASCRIPT,
    ".jsx": FileType.JAVASCRIPT,
    ".ts": FileType.TYPESCRIPT,
    ".tsx": FileType.TYPESCRIPT,
    ".go": FileType.GO,
    ".rs": FileType.RUST,
    ".java": FileType.JAVA,
    ".md": FileType.MARKDOWN,
    ".mdx": FileType.MARKDOWN,
    ".txt": FileType.TEXT,
    ".json": FileType.JSON,
    ".yaml": FileType.YAML,
    ".yml": FileType.YAML,
    ".toml": FileType.TOML,
}


@dataclass
class Chunk:
    """A chunk of content ready for embedding."""

    content: str
    chunk_type: ChunkType
    file_type: FileType | None = None
    source_path: str | None = None
    start_line: int | None = None
    end_line: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    token_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "content": self.content,
            "chunk_type": self.chunk_type.value,
            "file_type": self.file_type.value if self.file_type else None,
            "source_path": self.source_path,
            "start_line": self.start_line,
            "end_line": self.end_line,
            "metadata": self.metadata,
            "token_count": self.token_count,
        }


@dataclass
class KnowledgeEmbedding:
    """
    A knowledge embedding stored in the database.

    Represents a chunk of content with its vector embedding.
    """

    id: str
    project_id: str
    collection: str
    content: str
    embedding: list[float]
    chunk_type: ChunkType
    source_path: str | None = None
    start_line: int | None = None
    end_line: int | None = None
    file_type: FileType | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    content_hash: str | None = None
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    expires_at: datetime | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary (excluding embedding for size)."""
        return {
            "id": self.id,
            "project_id": self.project_id,
            "collection": self.collection,
            "content": self.content,
            "chunk_type": self.chunk_type.value,
            "source_path": self.source_path,
            "start_line": self.start_line,
            "end_line": self.end_line,
            "file_type": self.file_type.value if self.file_type else None,
            "metadata": self.metadata,
            "content_hash": self.content_hash,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
        }


# Pydantic Request/Response Models


class IngestRequest(BaseModel):
    """Request to ingest content into the knowledge base."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    content: str = Field(..., description="Content to ingest")
    source_path: str | None = Field(
        default=None, description="Source file path for reference"
    )
    collection: str = Field(
        default="default", description="Collection to store in"
    )
    chunk_type: ChunkType = Field(
        default=ChunkType.TEXT, description="Type of content"
    )
    file_type: FileType | None = Field(
        default=None, description="File type for code chunking"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )


class IngestResult(BaseModel):
    """Result of an ingest operation."""

    success: bool = Field(..., description="Whether ingest succeeded")
    chunks_created: int = Field(default=0, description="Number of chunks created")
    embeddings_generated: int = Field(
        default=0, description="Number of embeddings generated"
    )
    source_path: str | None = Field(default=None, description="Source path ingested")
    collection: str = Field(default="default", description="Collection stored in")
    chunk_ids: list[str] = Field(
        default_factory=list, description="IDs of created chunks"
    )
    error: str | None = Field(default=None, description="Error message if failed")


class SearchRequest(BaseModel):
    """Request to search the knowledge base."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    query: str = Field(..., description="Search query")
    search_type: SearchType = Field(
        default=SearchType.HYBRID, description="Type of search"
    )
    collection: str | None = Field(
        default=None, description="Collection to search (None = all)"
    )
    limit: int = Field(default=10, ge=1, le=100, description="Max results")
    threshold: float = Field(
        default=0.7, ge=0.0, le=1.0, description="Minimum similarity score"
    )
    file_types: list[FileType] | None = Field(
        default=None, description="Filter by file types"
    )
    include_metadata: bool = Field(
        default=True, description="Include metadata in results"
    )


class SearchResult(BaseModel):
    """A single search result."""

    id: str = Field(..., description="Chunk ID")
    content: str = Field(..., description="Chunk content")
    score: float = Field(..., description="Relevance score (0-1)")
    source_path: str | None = Field(default=None, description="Source file path")
    start_line: int | None = Field(default=None, description="Start line in source")
    end_line: int | None = Field(default=None, description="End line in source")
    chunk_type: str = Field(..., description="Type of chunk")
    file_type: str | None = Field(default=None, description="File type")
    collection: str = Field(..., description="Collection name")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )

    @classmethod
    def from_embedding(
        cls, embedding: KnowledgeEmbedding, score: float
    ) -> "SearchResult":
        """Create SearchResult from KnowledgeEmbedding."""
        return cls(
            id=embedding.id,
            content=embedding.content,
            score=score,
            source_path=embedding.source_path,
            start_line=embedding.start_line,
            end_line=embedding.end_line,
            chunk_type=embedding.chunk_type.value,
            file_type=embedding.file_type.value if embedding.file_type else None,
            collection=embedding.collection,
            metadata=embedding.metadata,
        )


class SearchResponse(BaseModel):
    """Response from a search operation."""

    query: str = Field(..., description="Original query")
    search_type: str = Field(..., description="Type of search performed")
    results: list[SearchResult] = Field(
        default_factory=list, description="Search results"
    )
    total_results: int = Field(default=0, description="Total results found")
    search_time_ms: float = Field(default=0.0, description="Search time in ms")


class DeleteRequest(BaseModel):
    """Request to delete from the knowledge base."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    source_path: str | None = Field(
        default=None, description="Delete by source path"
    )
    collection: str | None = Field(
        default=None, description="Delete entire collection"
    )
    chunk_ids: list[str] | None = Field(
        default=None, description="Delete specific chunks"
    )


class DeleteResult(BaseModel):
    """Result of a delete operation."""

    success: bool = Field(..., description="Whether delete succeeded")
    chunks_deleted: int = Field(default=0, description="Number of chunks deleted")
    error: str | None = Field(default=None, description="Error message if failed")


class CollectionInfo(BaseModel):
    """Information about a collection."""

    name: str = Field(..., description="Collection name")
    project_id: str = Field(..., description="Project ID")
    chunk_count: int = Field(default=0, description="Number of chunks")
    total_tokens: int = Field(default=0, description="Total tokens stored")
    file_types: list[str] = Field(
        default_factory=list, description="File types in collection"
    )
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")


class ListCollectionsResponse(BaseModel):
    """Response for listing collections."""

    project_id: str = Field(..., description="Project ID")
    collections: list[CollectionInfo] = Field(
        default_factory=list, description="Collections in project"
    )
    total_collections: int = Field(default=0, description="Total count")


class CollectionStats(BaseModel):
    """Statistics for a collection."""

    collection: str = Field(..., description="Collection name")
    project_id: str = Field(..., description="Project ID")
    chunk_count: int = Field(default=0, description="Number of chunks")
    unique_sources: int = Field(default=0, description="Unique source files")
    total_tokens: int = Field(default=0, description="Total tokens")
    avg_chunk_size: float = Field(default=0.0, description="Average chunk size")
    chunk_types: dict[str, int] = Field(
        default_factory=dict, description="Count by chunk type"
    )
    file_types: dict[str, int] = Field(
        default_factory=dict, description="Count by file type"
    )
    oldest_chunk: datetime | None = Field(
        default=None, description="Oldest chunk timestamp"
    )
    newest_chunk: datetime | None = Field(
        default=None, description="Newest chunk timestamp"
    )
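A minimal sketch (not part of the diff): resolving a `FileType` from a path via `FILE_EXTENSION_MAP` and building a hybrid `SearchRequest` with the defaults defined above. The path, query, and IDs are illustrative only.

```python
from pathlib import Path

from models import FILE_EXTENSION_MAP, FileType, SearchRequest, SearchType

suffix = Path("docs/architecture.md").suffix.lower()
file_type = FILE_EXTENSION_MAP.get(suffix, FileType.TEXT)  # ".md" -> FileType.MARKDOWN

request = SearchRequest(
    project_id="demo-project",
    agent_id="demo-agent",
    query="how does hybrid search fuse results?",
    search_type=SearchType.HYBRID,  # the default, shown for clarity
    limit=5,
    file_types=[file_type],
)
print(request.model_dump())
```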
mcp-servers/knowledge-base/pyproject.toml
@@ -4,21 +4,101 @@ version = "0.1.0"
 description = "Syndarix Knowledge Base MCP Server - RAG with pgvector for semantic search"
 requires-python = ">=3.12"
 dependencies = [
-    "fastmcp>=0.1.0",
+    "fastmcp>=2.0.0",
     "asyncpg>=0.29.0",
     "pgvector>=0.3.0",
     "redis>=5.0.0",
     "pydantic>=2.0.0",
     "pydantic-settings>=2.0.0",
+    "tiktoken>=0.7.0",
+    "httpx>=0.27.0",
+    "uvicorn>=0.30.0",
+    "fastapi>=0.115.0",
 ]

 [project.optional-dependencies]
 dev = [
     "pytest>=8.0.0",
-    "pytest-asyncio>=0.23.0",
+    "pytest-asyncio>=0.24.0",
+    "pytest-cov>=5.0.0",
+    "fakeredis>=2.25.0",
     "ruff>=0.8.0",
+    "mypy>=1.11.0",
 ]

+[project.scripts]
+knowledge-base = "server:main"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+exclude = ["tests/", "*.md", "Dockerfile"]
+
+[tool.hatch.build.targets.sdist]
+include = ["*.py", "pyproject.toml"]
+
 [tool.ruff]
 target-version = "py312"
 line-length = 88

+[tool.ruff.lint]
+select = [
+    "E",   # pycodestyle errors
+    "W",   # pycodestyle warnings
+    "F",   # pyflakes
+    "I",   # isort
+    "B",   # flake8-bugbear
+    "C4",  # flake8-comprehensions
+    "UP",  # pyupgrade
+    "ARG", # flake8-unused-arguments
+    "SIM", # flake8-simplify
+]
+ignore = [
+    "E501", # line too long (handled by formatter)
+    "B008", # do not perform function calls in argument defaults
+    "B904", # raise from in except (too noisy)
+]
+
+[tool.ruff.lint.isort]
+known-first-party = ["config", "models", "exceptions", "database", "embeddings", "chunking", "search", "collections"]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
+testpaths = ["tests"]
+addopts = "-v --tb=short"
+filterwarnings = [
+    "ignore::DeprecationWarning",
+]
+
+[tool.coverage.run]
+source = ["."]
+omit = ["tests/*", "conftest.py"]
+branch = true
+
+[tool.coverage.report]
+exclude_lines = [
+    "pragma: no cover",
+    "def __repr__",
+    "raise NotImplementedError",
+    "if TYPE_CHECKING:",
+    "if __name__ == .__main__.:",
+]
+fail_under = 55
+show_missing = true
+
+[tool.mypy]
+python_version = "3.12"
+warn_return_any = false
+warn_unused_ignores = false
+disallow_untyped_defs = true
+ignore_missing_imports = true
+plugins = ["pydantic.mypy"]
+
+[[tool.mypy.overrides]]
+module = "tests.*"
+disallow_untyped_defs = false
+ignore_errors = true
289  mcp-servers/knowledge-base/search.py  Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
"""
|
||||||
|
Search implementations for Knowledge Base MCP Server.
|
||||||
|
|
||||||
|
Provides semantic (vector), keyword (full-text), and hybrid search
|
||||||
|
capabilities over the knowledge base.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from config import Settings, get_settings
|
||||||
|
from database import DatabaseManager, get_database_manager
|
||||||
|
from embeddings import EmbeddingGenerator, get_embedding_generator
|
||||||
|
from exceptions import InvalidSearchTypeError, SearchError
|
||||||
|
from models import (
|
||||||
|
SearchRequest,
|
||||||
|
SearchResponse,
|
||||||
|
SearchResult,
|
||||||
|
SearchType,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchEngine:
|
||||||
|
"""
|
||||||
|
Unified search engine supporting multiple search types.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Semantic search using vector similarity
|
||||||
|
- Keyword search using full-text search
|
||||||
|
- Hybrid search combining both approaches
|
||||||
|
- Configurable result fusion and weighting
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
settings: Settings | None = None,
|
||||||
|
database: DatabaseManager | None = None,
|
||||||
|
embeddings: EmbeddingGenerator | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize search engine."""
|
||||||
|
self._settings = settings or get_settings()
|
||||||
|
self._database = database
|
||||||
|
self._embeddings = embeddings
|
||||||
|
|
||||||
|
@property
|
||||||
|
def database(self) -> DatabaseManager:
|
||||||
|
"""Get database manager."""
|
||||||
|
if self._database is None:
|
||||||
|
self._database = get_database_manager()
|
||||||
|
return self._database
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embeddings(self) -> EmbeddingGenerator:
|
||||||
|
"""Get embedding generator."""
|
||||||
|
if self._embeddings is None:
|
||||||
|
self._embeddings = get_embedding_generator()
|
||||||
|
return self._embeddings
|
||||||
|
|
||||||
|
async def search(self, request: SearchRequest) -> SearchResponse:
|
||||||
|
"""
|
||||||
|
Execute a search request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Search request with query and options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Search response with results
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if request.search_type == SearchType.SEMANTIC:
|
||||||
|
results = await self._semantic_search(request)
|
||||||
|
elif request.search_type == SearchType.KEYWORD:
|
||||||
|
results = await self._keyword_search(request)
|
||||||
|
elif request.search_type == SearchType.HYBRID:
|
||||||
|
results = await self._hybrid_search(request)
|
||||||
|
else:
|
||||||
|
raise InvalidSearchTypeError(
|
||||||
|
search_type=request.search_type,
|
||||||
|
valid_types=[t.value for t in SearchType],
|
||||||
|
)
|
||||||
|
|
||||||
|
search_time_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Search completed: type={request.search_type.value}, "
|
||||||
|
f"results={len(results)}, time={search_time_ms:.1f}ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
return SearchResponse(
|
||||||
|
query=request.query,
|
||||||
|
search_type=request.search_type.value,
|
||||||
|
results=results,
|
||||||
|
total_results=len(results),
|
||||||
|
search_time_ms=search_time_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except InvalidSearchTypeError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Search error: {e}")
|
||||||
|
raise SearchError(
|
||||||
|
message=f"Search failed: {e}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _semantic_search(self, request: SearchRequest) -> list[SearchResult]:
|
||||||
|
"""Execute semantic (vector) search."""
|
||||||
|
# Generate embedding for query
|
||||||
|
query_embedding = await self.embeddings.generate(
|
||||||
|
text=request.query,
|
||||||
|
project_id=request.project_id,
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search database
|
||||||
|
results = await self.database.semantic_search(
|
||||||
|
project_id=request.project_id,
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
collection=request.collection,
|
||||||
|
limit=request.limit,
|
||||||
|
threshold=request.threshold,
|
||||||
|
file_types=request.file_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to SearchResult
|
||||||
|
return [
|
||||||
|
SearchResult.from_embedding(embedding, score)
|
||||||
|
for embedding, score in results
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _keyword_search(self, request: SearchRequest) -> list[SearchResult]:
|
||||||
|
"""Execute keyword (full-text) search."""
|
||||||
|
results = await self.database.keyword_search(
|
||||||
|
project_id=request.project_id,
|
||||||
|
query=request.query,
|
||||||
|
collection=request.collection,
|
||||||
|
limit=request.limit,
|
||||||
|
file_types=request.file_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by threshold (keyword search scores are normalized)
|
||||||
|
filtered = [
|
||||||
|
(emb, score) for emb, score in results
|
||||||
|
if score >= request.threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
SearchResult.from_embedding(embedding, score)
|
||||||
|
for embedding, score in filtered
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _hybrid_search(self, request: SearchRequest) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Execute hybrid search combining semantic and keyword.
|
||||||
|
|
||||||
|
Uses Reciprocal Rank Fusion (RRF) for result combination.
|
||||||
|
        Executes both searches concurrently for better performance.
        """
        # Execute both searches with higher limits for fusion
        fusion_limit = min(request.limit * 2, 100)

        # Create modified request for sub-searches
        semantic_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.SEMANTIC,
            collection=request.collection,
            limit=fusion_limit,
            threshold=request.threshold * 0.8,  # Lower threshold for fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        keyword_request = SearchRequest(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            search_type=SearchType.KEYWORD,
            collection=request.collection,
            limit=fusion_limit,
            threshold=0.0,  # No threshold for keyword, we'll filter after fusion
            file_types=request.file_types,
            include_metadata=request.include_metadata,
        )

        # Execute searches concurrently for better performance
        semantic_results, keyword_results = await asyncio.gather(
            self._semantic_search(semantic_request),
            self._keyword_search(keyword_request),
        )

        # Fuse results using RRF
        fused = self._reciprocal_rank_fusion(
            semantic_results=semantic_results,
            keyword_results=keyword_results,
            semantic_weight=self._settings.hybrid_semantic_weight,
            keyword_weight=self._settings.hybrid_keyword_weight,
        )

        # Filter by threshold and limit
        filtered = [
            result for result in fused
            if result.score >= request.threshold
        ][:request.limit]

        return filtered

    def _reciprocal_rank_fusion(
        self,
        semantic_results: list[SearchResult],
        keyword_results: list[SearchResult],
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
        k: int = 60,  # RRF constant
    ) -> list[SearchResult]:
        """
        Combine results using Reciprocal Rank Fusion.

        RRF score = sum(weight / (k + rank)) for each result list.
        """
        # Calculate RRF scores
        scores: dict[str, float] = {}
        results_by_id: dict[str, SearchResult] = {}

        # Process semantic results
        for rank, result in enumerate(semantic_results, start=1):
            rrf_score = semantic_weight / (k + rank)
            scores[result.id] = scores.get(result.id, 0) + rrf_score
            results_by_id[result.id] = result

        # Process keyword results
        for rank, result in enumerate(keyword_results, start=1):
            rrf_score = keyword_weight / (k + rank)
            scores[result.id] = scores.get(result.id, 0) + rrf_score
            if result.id not in results_by_id:
                results_by_id[result.id] = result

        # Sort by combined score
        sorted_ids = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)

        # Normalize scores to 0-1 range
        max_score = max(scores.values()) if scores else 1.0

        # Create final results with normalized scores
        final_results: list[SearchResult] = []
        for result_id in sorted_ids:
            result = results_by_id[result_id]
            normalized_score = scores[result_id] / max_score
            # Create new result with updated score
            final_results.append(
                SearchResult(
                    id=result.id,
                    content=result.content,
                    score=normalized_score,
                    source_path=result.source_path,
                    start_line=result.start_line,
                    end_line=result.end_line,
                    chunk_type=result.chunk_type,
                    file_type=result.file_type,
                    collection=result.collection,
                    metadata=result.metadata,
                )
            )

        return final_results


# Global search engine instance (lazy initialization)
_search_engine: SearchEngine | None = None


def get_search_engine() -> SearchEngine:
    """Get the global search engine instance."""
    global _search_engine
    if _search_engine is None:
        _search_engine = SearchEngine()
    return _search_engine


def reset_search_engine() -> None:
    """Reset the global search engine (for testing)."""
    global _search_engine
    _search_engine = None
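The fusion step above is self-contained enough to exercise on its own. The following is a minimal, hypothetical sketch (independent of the `SearchEngine` class and its `SearchResult` model, which are not reproduced here) of the same weighted reciprocal-rank-fusion scoring that `_reciprocal_rank_fusion` implements: each ranked list contributes `weight / (k + rank)` per document, scores for the same id are summed across lists, and the combined scores are normalized to the 0-1 range.

```python
# Standalone illustration of weighted Reciprocal Rank Fusion (RRF).
# Hypothetical helper for illustration only; not part of the knowledge-base server code.

def rrf_fuse(
    semantic_ids: list[str],
    keyword_ids: list[str],
    semantic_weight: float = 0.7,
    keyword_weight: float = 0.3,
    k: int = 60,
) -> list[tuple[str, float]]:
    """Fuse two ranked id lists; return (id, normalized_score) sorted by score."""
    scores: dict[str, float] = {}
    for weight, ranked in ((semantic_weight, semantic_ids), (keyword_weight, keyword_ids)):
        for rank, doc_id in enumerate(ranked, start=1):
            # Each list contributes weight / (k + rank); shared ids accumulate.
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)

    # Normalize so the best-scoring document gets 1.0.
    max_score = max(scores.values()) if scores else 1.0
    return sorted(
        ((doc_id, score / max_score) for doc_id, score in scores.items()),
        key=lambda pair: pair[1],
        reverse=True,
    )


# A document ranked highly by both searches ends up first with score 1.0.
print(rrf_fuse(["a", "b", "c"], ["b", "a", "d"]))
```

The `k = 60` constant damps the influence of rank differences deep in the lists, which is why the production code can run the sub-searches with a generous `fusion_limit` and still filter by the caller's threshold only after fusion.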
File diff suppressed because it is too large

1 mcp-servers/knowledge-base/tests/__init__.py Normal file
@@ -0,0 +1 @@
"""Tests for Knowledge Base MCP Server."""
283 mcp-servers/knowledge-base/tests/conftest.py Normal file
@@ -0,0 +1,283 @@
"""
Test fixtures for Knowledge Base MCP Server.
"""

import os
import sys
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock

import pytest

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Set test mode before importing modules
os.environ["IS_TEST"] = "true"
os.environ["KB_DATABASE_URL"] = "postgresql://test:test@localhost:5432/test"
os.environ["KB_REDIS_URL"] = "redis://localhost:6379/0"
os.environ["KB_LLM_GATEWAY_URL"] = "http://localhost:8001"


@pytest.fixture
def settings():
    """Create test settings."""
    from config import Settings, reset_settings

    reset_settings()
    return Settings(
        host="127.0.0.1",
        port=8002,
        debug=True,
        database_url="postgresql://test:test@localhost:5432/test",
        redis_url="redis://localhost:6379/0",
        llm_gateway_url="http://localhost:8001",
        embedding_dimension=1536,
        code_chunk_size=500,
        code_chunk_overlap=50,
        markdown_chunk_size=800,
        markdown_chunk_overlap=100,
        text_chunk_size=400,
        text_chunk_overlap=50,
    )


@pytest.fixture
def mock_database():
    """Create mock database manager."""
    from database import DatabaseManager

    mock_db = MagicMock(spec=DatabaseManager)
    mock_db._pool = MagicMock()
    mock_db.acquire = MagicMock(return_value=AsyncMock())

    # Mock database methods
    mock_db.initialize = AsyncMock()
    mock_db.close = AsyncMock()
    mock_db.store_embedding = AsyncMock(return_value="test-id-123")
    mock_db.store_embeddings_batch = AsyncMock(return_value=["id-1", "id-2"])
    mock_db.semantic_search = AsyncMock(return_value=[])
    mock_db.keyword_search = AsyncMock(return_value=[])
    mock_db.delete_by_source = AsyncMock(return_value=1)
    mock_db.delete_collection = AsyncMock(return_value=5)
    mock_db.delete_by_ids = AsyncMock(return_value=2)
    mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))
    mock_db.list_collections = AsyncMock(return_value=[])
    mock_db.get_collection_stats = AsyncMock()
    mock_db.cleanup_expired = AsyncMock(return_value=0)

    return mock_db


@pytest.fixture
def mock_embeddings():
    """Create mock embedding generator."""
    from embeddings import EmbeddingGenerator

    mock_emb = MagicMock(spec=EmbeddingGenerator)
    mock_emb.initialize = AsyncMock()
    mock_emb.close = AsyncMock()

    # Generate fake embeddings (1536 dimensions)
    def fake_embedding() -> list[float]:
        return [0.1] * 1536

    mock_emb.generate = AsyncMock(return_value=fake_embedding())
    mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])

    return mock_emb


@pytest.fixture
def mock_redis():
    """Create mock Redis client."""
    import fakeredis.aioredis

    return fakeredis.aioredis.FakeRedis()


@pytest.fixture
def sample_python_code():
    """Sample Python code for chunking tests."""
    return '''"""Sample module for testing."""

import os
from typing import Any


class Calculator:
    """A simple calculator class."""

    def __init__(self, initial: int = 0) -> None:
        """Initialize calculator."""
        self.value = initial

    def add(self, x: int) -> int:
        """Add a value."""
        self.value += x
        return self.value

    def subtract(self, x: int) -> int:
        """Subtract a value."""
        self.value -= x
        return self.value


def helper_function(data: dict[str, Any]) -> str:
    """A helper function."""
    return str(data)


async def async_function() -> None:
    """An async function."""
    pass
'''


@pytest.fixture
def sample_markdown():
    """Sample Markdown content for chunking tests."""
    return '''# Project Documentation

This is the main documentation for our project.

## Getting Started

To get started, follow these steps:

1. Install dependencies
2. Configure settings
3. Run the application

### Prerequisites

You'll need the following installed:

- Python 3.12+
- PostgreSQL
- Redis

```python
# Example code
def main():
    print("Hello, World!")
```

## API Reference

### Search Endpoint

The search endpoint allows you to query the knowledge base.

**Endpoint:** `POST /api/search`

**Request:**
```json
{
    "query": "your search query",
    "limit": 10
}
```

## Contributing

We welcome contributions! Please see our contributing guide.
'''


@pytest.fixture
def sample_text():
    """Sample plain text for chunking tests."""
    return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.

Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.

When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.

The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''


@pytest.fixture
def sample_chunk():
    """Sample chunk for testing."""
    from models import Chunk, ChunkType, FileType

    return Chunk(
        content="def hello():\n    print('Hello')",
        chunk_type=ChunkType.CODE,
        file_type=FileType.PYTHON,
        source_path="/test/hello.py",
        start_line=1,
        end_line=2,
        metadata={"function": "hello"},
        token_count=15,
    )


@pytest.fixture
def sample_embedding():
    """Sample knowledge embedding for testing."""
    from models import ChunkType, FileType, KnowledgeEmbedding

    return KnowledgeEmbedding(
        id="test-id-123",
        project_id="proj-123",
        collection="default",
        content="def hello():\n    print('Hello')",
        embedding=[0.1] * 1536,
        chunk_type=ChunkType.CODE,
        source_path="/test/hello.py",
        start_line=1,
        end_line=2,
        file_type=FileType.PYTHON,
        metadata={"function": "hello"},
        content_hash="abc123",
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )


@pytest.fixture
def sample_ingest_request():
    """Sample ingest request for testing."""
    from models import ChunkType, FileType, IngestRequest

    return IngestRequest(
        project_id="proj-123",
        agent_id="agent-456",
        content="def hello():\n    print('Hello')",
        source_path="/test/hello.py",
        collection="default",
        chunk_type=ChunkType.CODE,
        file_type=FileType.PYTHON,
        metadata={"test": True},
    )


@pytest.fixture
def sample_search_request():
    """Sample search request for testing."""
    from models import SearchRequest, SearchType

    return SearchRequest(
        project_id="proj-123",
        agent_id="agent-456",
        query="hello function",
        search_type=SearchType.HYBRID,
        collection="default",
        limit=10,
        threshold=0.7,
    )


@pytest.fixture
def sample_delete_request():
    """Sample delete request for testing."""
    from models import DeleteRequest

    return DeleteRequest(
        project_id="proj-123",
        agent_id="agent-456",
        source_path="/test/hello.py",
    )
422 mcp-servers/knowledge-base/tests/test_chunking.py Normal file
@@ -0,0 +1,422 @@
"""Tests for chunking module."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseChunker:
|
||||||
|
"""Tests for base chunker functionality."""
|
||||||
|
|
||||||
|
def test_count_tokens(self, settings):
|
||||||
|
"""Test token counting."""
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
|
||||||
|
chunker = TextChunker(
|
||||||
|
chunk_size=400,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simple text should count tokens
|
||||||
|
tokens = chunker.count_tokens("Hello, world!")
|
||||||
|
assert tokens > 0
|
||||||
|
assert tokens < 10 # Should be about 3-4 tokens
|
||||||
|
|
||||||
|
def test_truncate_to_tokens(self, settings):
|
||||||
|
"""Test truncating text to token limit."""
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
|
||||||
|
chunker = TextChunker(
|
||||||
|
chunk_size=400,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
long_text = "word " * 1000
|
||||||
|
truncated = chunker.truncate_to_tokens(long_text, 10)
|
||||||
|
|
||||||
|
assert chunker.count_tokens(truncated) <= 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeChunker:
|
||||||
|
"""Tests for code chunker."""
|
||||||
|
|
||||||
|
def test_chunk_python_code(self, settings, sample_python_code):
|
||||||
|
"""Test chunking Python code."""
|
||||||
|
from chunking.code import CodeChunker
|
||||||
|
from models import ChunkType, FileType
|
||||||
|
|
||||||
|
chunker = CodeChunker(
|
||||||
|
chunk_size=500,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=sample_python_code,
|
||||||
|
source_path="/test/sample.py",
|
||||||
|
file_type=FileType.PYTHON,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(c.chunk_type == ChunkType.CODE for c in chunks)
|
||||||
|
assert all(c.file_type == FileType.PYTHON for c in chunks)
|
||||||
|
|
||||||
|
def test_preserves_function_boundaries(self, settings):
|
||||||
|
"""Test that chunker preserves function boundaries."""
|
||||||
|
from chunking.code import CodeChunker
|
||||||
|
from models import FileType
|
||||||
|
|
||||||
|
code = '''def function_one():
|
||||||
|
"""First function."""
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def function_two():
|
||||||
|
"""Second function."""
|
||||||
|
return 2
|
||||||
|
'''
|
||||||
|
|
||||||
|
chunker = CodeChunker(
|
||||||
|
chunk_size=100,
|
||||||
|
chunk_overlap=10,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=code,
|
||||||
|
source_path="/test/funcs.py",
|
||||||
|
file_type=FileType.PYTHON,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each function should ideally be in its own chunk
|
||||||
|
assert len(chunks) >= 1
|
||||||
|
for chunk in chunks:
|
||||||
|
# Check chunks have line numbers
|
||||||
|
assert chunk.start_line is not None
|
||||||
|
assert chunk.end_line is not None
|
||||||
|
assert chunk.start_line <= chunk.end_line
|
||||||
|
|
||||||
|
def test_handles_empty_content(self, settings):
|
||||||
|
"""Test handling empty content."""
|
||||||
|
from chunking.code import CodeChunker
|
||||||
|
|
||||||
|
chunker = CodeChunker(
|
||||||
|
chunk_size=500,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(content="", source_path="/test/empty.py")
|
||||||
|
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_chunk_type_is_code(self, settings):
|
||||||
|
"""Test that chunk_type property returns CODE."""
|
||||||
|
from chunking.code import CodeChunker
|
||||||
|
from models import ChunkType
|
||||||
|
|
||||||
|
chunker = CodeChunker(
|
||||||
|
chunk_size=500,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunker.chunk_type == ChunkType.CODE
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarkdownChunker:
|
||||||
|
"""Tests for markdown chunker."""
|
||||||
|
|
||||||
|
def test_chunk_markdown(self, settings, sample_markdown):
|
||||||
|
"""Test chunking markdown content."""
|
||||||
|
from chunking.markdown import MarkdownChunker
|
||||||
|
from models import ChunkType, FileType
|
||||||
|
|
||||||
|
chunker = MarkdownChunker(
|
||||||
|
chunk_size=800,
|
||||||
|
chunk_overlap=100,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=sample_markdown,
|
||||||
|
source_path="/test/docs.md",
|
||||||
|
file_type=FileType.MARKDOWN,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(c.chunk_type == ChunkType.MARKDOWN for c in chunks)
|
||||||
|
|
||||||
|
def test_respects_heading_hierarchy(self, settings):
|
||||||
|
"""Test that chunker respects heading hierarchy."""
|
||||||
|
from chunking.markdown import MarkdownChunker
|
||||||
|
|
||||||
|
markdown = '''# Main Title
|
||||||
|
|
||||||
|
Introduction paragraph.
|
||||||
|
|
||||||
|
## Section One
|
||||||
|
|
||||||
|
Content for section one.
|
||||||
|
|
||||||
|
### Subsection
|
||||||
|
|
||||||
|
More detailed content.
|
||||||
|
|
||||||
|
## Section Two
|
||||||
|
|
||||||
|
Content for section two.
|
||||||
|
'''
|
||||||
|
|
||||||
|
chunker = MarkdownChunker(
|
||||||
|
chunk_size=200,
|
||||||
|
chunk_overlap=20,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=markdown,
|
||||||
|
source_path="/test/docs.md",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have multiple chunks based on sections
|
||||||
|
assert len(chunks) >= 1
|
||||||
|
# Metadata should include heading context
|
||||||
|
for chunk in chunks:
|
||||||
|
# Chunks should have content
|
||||||
|
assert len(chunk.content) > 0
|
||||||
|
|
||||||
|
def test_handles_code_blocks(self, settings):
|
||||||
|
"""Test handling of code blocks in markdown."""
|
||||||
|
from chunking.markdown import MarkdownChunker
|
||||||
|
|
||||||
|
markdown = '''# Code Example
|
||||||
|
|
||||||
|
Here's some code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def hello():
|
||||||
|
print("Hello, World!")
|
||||||
|
```
|
||||||
|
|
||||||
|
End of example.
|
||||||
|
'''
|
||||||
|
|
||||||
|
chunker = MarkdownChunker(
|
||||||
|
chunk_size=500,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=markdown,
|
||||||
|
source_path="/test/code.md",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Code blocks should be preserved
|
||||||
|
assert len(chunks) >= 1
|
||||||
|
full_content = " ".join(c.content for c in chunks)
|
||||||
|
assert "```python" in full_content or "def hello" in full_content
|
||||||
|
|
||||||
|
def test_chunk_type_is_markdown(self, settings):
|
||||||
|
"""Test that chunk_type property returns MARKDOWN."""
|
||||||
|
from chunking.markdown import MarkdownChunker
|
||||||
|
from models import ChunkType
|
||||||
|
|
||||||
|
chunker = MarkdownChunker(
|
||||||
|
chunk_size=800,
|
||||||
|
chunk_overlap=100,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunker.chunk_type == ChunkType.MARKDOWN
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextChunker:
|
||||||
|
"""Tests for text chunker."""
|
||||||
|
|
||||||
|
def test_chunk_text(self, settings, sample_text):
|
||||||
|
"""Test chunking plain text."""
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
from models import ChunkType
|
||||||
|
|
||||||
|
chunker = TextChunker(
|
||||||
|
chunk_size=400,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=sample_text,
|
||||||
|
source_path="/test/docs.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(c.chunk_type == ChunkType.TEXT for c in chunks)
|
||||||
|
|
||||||
|
def test_respects_paragraph_boundaries(self, settings):
|
||||||
|
"""Test that chunker respects paragraph boundaries."""
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
|
||||||
|
text = '''First paragraph with some content.
|
||||||
|
|
||||||
|
Second paragraph with different content.
|
||||||
|
|
||||||
|
Third paragraph to test chunking behavior.
|
||||||
|
'''
|
||||||
|
|
||||||
|
chunker = TextChunker(
|
||||||
|
chunk_size=100,
|
||||||
|
chunk_overlap=10,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(
|
||||||
|
content=text,
|
||||||
|
source_path="/test/text.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) >= 1
|
||||||
|
|
||||||
|
def test_handles_single_paragraph(self, settings):
|
||||||
|
"""Test handling of single paragraph that fits in one chunk."""
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
|
||||||
|
text = "This is a short paragraph."
|
||||||
|
|
||||||
|
chunker = TextChunker(
|
||||||
|
chunk_size=400,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = chunker.chunk(content=text, source_path="/test/short.txt")
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0].content == text
|
||||||
|
|
||||||
|
def test_chunk_type_is_text(self, settings):
|
||||||
|
"""Test that chunk_type property returns TEXT."""
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
from models import ChunkType
|
||||||
|
|
||||||
|
chunker = TextChunker(
|
||||||
|
chunk_size=400,
|
||||||
|
chunk_overlap=50,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunker.chunk_type == ChunkType.TEXT
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkerFactory:
|
||||||
|
"""Tests for chunker factory."""
|
||||||
|
|
||||||
|
def test_get_code_chunker(self, settings):
|
||||||
|
"""Test getting code chunker."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from chunking.code import CodeChunker
|
||||||
|
from models import FileType
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
chunker = factory.get_chunker(file_type=FileType.PYTHON)
|
||||||
|
|
||||||
|
assert isinstance(chunker, CodeChunker)
|
||||||
|
|
||||||
|
def test_get_markdown_chunker(self, settings):
|
||||||
|
"""Test getting markdown chunker."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from chunking.markdown import MarkdownChunker
|
||||||
|
from models import FileType
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
chunker = factory.get_chunker(file_type=FileType.MARKDOWN)
|
||||||
|
|
||||||
|
assert isinstance(chunker, MarkdownChunker)
|
||||||
|
|
||||||
|
def test_get_text_chunker(self, settings):
|
||||||
|
"""Test getting text chunker."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
from models import FileType
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
chunker = factory.get_chunker(file_type=FileType.TEXT)
|
||||||
|
|
||||||
|
assert isinstance(chunker, TextChunker)
|
||||||
|
|
||||||
|
def test_get_chunker_for_path(self, settings):
|
||||||
|
"""Test getting chunker based on file path."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from chunking.code import CodeChunker
|
||||||
|
from chunking.markdown import MarkdownChunker
|
||||||
|
from models import FileType
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
|
||||||
|
chunker, file_type = factory.get_chunker_for_path("/test/file.py")
|
||||||
|
assert isinstance(chunker, CodeChunker)
|
||||||
|
assert file_type == FileType.PYTHON
|
||||||
|
|
||||||
|
chunker, file_type = factory.get_chunker_for_path("/test/docs.md")
|
||||||
|
assert isinstance(chunker, MarkdownChunker)
|
||||||
|
assert file_type == FileType.MARKDOWN
|
||||||
|
|
||||||
|
def test_chunk_content(self, settings, sample_python_code):
|
||||||
|
"""Test chunk_content convenience method."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from models import ChunkType
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
|
||||||
|
chunks = factory.chunk_content(
|
||||||
|
content=sample_python_code,
|
||||||
|
source_path="/test/sample.py",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(c.chunk_type == ChunkType.CODE for c in chunks)
|
||||||
|
|
||||||
|
def test_default_to_text_chunker(self, settings):
|
||||||
|
"""Test defaulting to text chunker."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from chunking.text import TextChunker
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
chunker = factory.get_chunker()
|
||||||
|
|
||||||
|
assert isinstance(chunker, TextChunker)
|
||||||
|
|
||||||
|
def test_chunker_caching(self, settings):
|
||||||
|
"""Test that factory caches chunker instances."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from models import FileType
|
||||||
|
|
||||||
|
factory = ChunkerFactory(settings=settings)
|
||||||
|
|
||||||
|
chunker1 = factory.get_chunker(file_type=FileType.PYTHON)
|
||||||
|
chunker2 = factory.get_chunker(file_type=FileType.PYTHON)
|
||||||
|
|
||||||
|
assert chunker1 is chunker2
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalChunkerFactory:
|
||||||
|
"""Tests for global chunker factory."""
|
||||||
|
|
||||||
|
def test_get_chunker_factory_singleton(self):
|
||||||
|
"""Test that get_chunker_factory returns singleton."""
|
||||||
|
from chunking.base import get_chunker_factory, reset_chunker_factory
|
||||||
|
|
||||||
|
reset_chunker_factory()
|
||||||
|
factory1 = get_chunker_factory()
|
||||||
|
factory2 = get_chunker_factory()
|
||||||
|
|
||||||
|
assert factory1 is factory2
|
||||||
|
|
||||||
|
def test_reset_chunker_factory(self):
|
||||||
|
"""Test resetting chunker factory."""
|
||||||
|
from chunking.base import get_chunker_factory, reset_chunker_factory
|
||||||
|
|
||||||
|
factory1 = get_chunker_factory()
|
||||||
|
reset_chunker_factory()
|
||||||
|
factory2 = get_chunker_factory()
|
||||||
|
|
||||||
|
assert factory1 is not factory2
241 mcp-servers/knowledge-base/tests/test_collection_manager.py Normal file
@@ -0,0 +1,241 @@
"""Tests for collection manager module."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestCollectionManager:
|
||||||
|
"""Tests for CollectionManager class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def collection_manager(self, settings, mock_database, mock_embeddings):
|
||||||
|
"""Create collection manager with mocks."""
|
||||||
|
from chunking.base import ChunkerFactory
|
||||||
|
from collection_manager import CollectionManager
|
||||||
|
|
||||||
|
mock_chunker_factory = MagicMock(spec=ChunkerFactory)
|
||||||
|
|
||||||
|
# Mock chunk_content to return chunks
|
||||||
|
from models import Chunk, ChunkType
|
||||||
|
|
||||||
|
mock_chunker_factory.chunk_content.return_value = [
|
||||||
|
Chunk(
|
||||||
|
content="def hello(): pass",
|
||||||
|
chunk_type=ChunkType.CODE,
|
||||||
|
token_count=10,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return CollectionManager(
|
||||||
|
settings=settings,
|
||||||
|
database=mock_database,
|
||||||
|
embeddings=mock_embeddings,
|
||||||
|
chunker_factory=mock_chunker_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ingest_content(self, collection_manager, sample_ingest_request):
|
||||||
|
"""Test content ingestion."""
|
||||||
|
result = await collection_manager.ingest(sample_ingest_request)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.chunks_created == 1
|
||||||
|
assert result.embeddings_generated == 1
|
||||||
|
assert len(result.chunk_ids) == 1
|
||||||
|
assert result.collection == "default"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ingest_empty_content(self, collection_manager):
|
||||||
|
"""Test ingesting empty content."""
|
||||||
|
from models import IngestRequest
|
||||||
|
|
||||||
|
# Mock chunker to return empty list
|
||||||
|
collection_manager._chunker_factory.chunk_content.return_value = []
|
||||||
|
|
||||||
|
request = IngestRequest(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
content="",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await collection_manager.ingest(request)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.chunks_created == 0
|
||||||
|
assert result.embeddings_generated == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
|
||||||
|
"""Test ingest error handling."""
|
||||||
|
# Make embedding generation fail
|
||||||
|
collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
|
||||||
|
|
||||||
|
result = await collection_manager.ingest(sample_ingest_request)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Embedding error" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_by_source(self, collection_manager, sample_delete_request):
|
||||||
|
"""Test deletion by source path."""
|
||||||
|
result = await collection_manager.delete(sample_delete_request)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.chunks_deleted == 1 # Mock returns 1
|
||||||
|
collection_manager._database.delete_by_source.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_by_collection(self, collection_manager):
|
||||||
|
"""Test deletion by collection."""
|
||||||
|
from models import DeleteRequest
|
||||||
|
|
||||||
|
request = DeleteRequest(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
collection="to-delete",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await collection_manager.delete(request)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
collection_manager._database.delete_collection.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_by_ids(self, collection_manager):
|
||||||
|
"""Test deletion by chunk IDs."""
|
||||||
|
from models import DeleteRequest
|
||||||
|
|
||||||
|
request = DeleteRequest(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
chunk_ids=["id-1", "id-2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await collection_manager.delete(request)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
collection_manager._database.delete_by_ids.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_no_target(self, collection_manager):
|
||||||
|
"""Test deletion with no target specified."""
|
||||||
|
from models import DeleteRequest
|
||||||
|
|
||||||
|
request = DeleteRequest(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await collection_manager.delete(request)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Must specify" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_collections(self, collection_manager):
|
||||||
|
"""Test listing collections."""
|
||||||
|
from models import CollectionInfo
|
||||||
|
|
||||||
|
collection_manager._database.list_collections.return_value = [
|
||||||
|
CollectionInfo(
|
||||||
|
name="collection-1",
|
||||||
|
project_id="proj-123",
|
||||||
|
chunk_count=100,
|
||||||
|
total_tokens=50000,
|
||||||
|
file_types=["python"],
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
),
|
||||||
|
CollectionInfo(
|
||||||
|
name="collection-2",
|
||||||
|
project_id="proj-123",
|
||||||
|
chunk_count=50,
|
||||||
|
total_tokens=25000,
|
||||||
|
file_types=["javascript"],
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await collection_manager.list_collections("proj-123")
|
||||||
|
|
||||||
|
assert result.project_id == "proj-123"
|
||||||
|
assert result.total_collections == 2
|
||||||
|
assert len(result.collections) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_collection_stats(self, collection_manager):
|
||||||
|
"""Test getting collection statistics."""
|
||||||
|
from models import CollectionStats
|
||||||
|
|
||||||
|
expected_stats = CollectionStats(
|
||||||
|
collection="test-collection",
|
||||||
|
project_id="proj-123",
|
||||||
|
chunk_count=100,
|
||||||
|
unique_sources=10,
|
||||||
|
total_tokens=50000,
|
||||||
|
avg_chunk_size=500.0,
|
||||||
|
chunk_types={"code": 60, "text": 40},
|
||||||
|
file_types={"python": 50, "javascript": 10},
|
||||||
|
)
|
||||||
|
collection_manager._database.get_collection_stats.return_value = expected_stats
|
||||||
|
|
||||||
|
stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
|
||||||
|
|
||||||
|
assert stats.chunk_count == 100
|
||||||
|
assert stats.unique_sources == 10
|
||||||
|
collection_manager._database.get_collection_stats.assert_called_once_with(
|
||||||
|
"proj-123", "test-collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_document(self, collection_manager):
|
||||||
|
"""Test updating a document with atomic replace."""
|
||||||
|
result = await collection_manager.update_document(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
source_path="/test/file.py",
|
||||||
|
content="def updated(): pass",
|
||||||
|
collection="default",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should use atomic replace (delete + insert in transaction)
|
||||||
|
collection_manager._database.replace_source_embeddings.assert_called_once()
|
||||||
|
assert result.success is True
|
||||||
|
assert len(result.chunk_ids) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired(self, collection_manager):
|
||||||
|
"""Test cleaning up expired embeddings."""
|
||||||
|
collection_manager._database.cleanup_expired.return_value = 10
|
||||||
|
|
||||||
|
count = await collection_manager.cleanup_expired()
|
||||||
|
|
||||||
|
assert count == 10
|
||||||
|
collection_manager._database.cleanup_expired.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalCollectionManager:
|
||||||
|
"""Tests for global collection manager."""
|
||||||
|
|
||||||
|
def test_get_collection_manager_singleton(self):
|
||||||
|
"""Test that get_collection_manager returns singleton."""
|
||||||
|
from collection_manager import get_collection_manager, reset_collection_manager
|
||||||
|
|
||||||
|
reset_collection_manager()
|
||||||
|
manager1 = get_collection_manager()
|
||||||
|
manager2 = get_collection_manager()
|
||||||
|
|
||||||
|
assert manager1 is manager2
|
||||||
|
|
||||||
|
def test_reset_collection_manager(self):
|
||||||
|
"""Test resetting collection manager."""
|
||||||
|
from collection_manager import get_collection_manager, reset_collection_manager
|
||||||
|
|
||||||
|
manager1 = get_collection_manager()
|
||||||
|
reset_collection_manager()
|
||||||
|
manager2 = get_collection_manager()
|
||||||
|
|
||||||
|
assert manager1 is not manager2
104 mcp-servers/knowledge-base/tests/test_config.py Normal file
@@ -0,0 +1,104 @@
"""Tests for configuration module."""

import os


class TestSettings:
    """Tests for Settings class."""

    def test_default_values(self, settings):
        """Test default configuration values."""
        assert settings.port == 8002
        assert settings.embedding_dimension == 1536
        assert settings.code_chunk_size == 500
        assert settings.search_default_limit == 10

    def test_env_prefix(self):
        """Test environment variable prefix."""
        from config import Settings, reset_settings

        reset_settings()
        os.environ["KB_PORT"] = "9999"

        settings = Settings()
        assert settings.port == 9999

        # Cleanup
        del os.environ["KB_PORT"]
        reset_settings()

    def test_embedding_settings(self, settings):
        """Test embedding-related settings."""
        assert settings.embedding_model == "text-embedding-3-large"
        assert settings.embedding_batch_size == 100
        assert settings.embedding_cache_ttl == 86400

    def test_chunking_settings(self, settings):
        """Test chunking-related settings."""
        assert settings.code_chunk_size == 500
        assert settings.code_chunk_overlap == 50
        assert settings.markdown_chunk_size == 800
        assert settings.markdown_chunk_overlap == 100
        assert settings.text_chunk_size == 400
        assert settings.text_chunk_overlap == 50

    def test_search_settings(self, settings):
        """Test search-related settings."""
        assert settings.search_default_limit == 10
        assert settings.search_max_limit == 100
        assert settings.semantic_threshold == 0.7
        assert settings.hybrid_semantic_weight == 0.7
        assert settings.hybrid_keyword_weight == 0.3


class TestGetSettings:
    """Tests for get_settings function."""

    def test_returns_singleton(self):
        """Test that get_settings returns singleton."""
        from config import get_settings, reset_settings

        reset_settings()
        settings1 = get_settings()
        settings2 = get_settings()
        assert settings1 is settings2

    def test_reset_creates_new_instance(self):
        """Test that reset_settings clears the singleton."""
        from config import get_settings, reset_settings

        settings1 = get_settings()
        reset_settings()
        settings2 = get_settings()
        assert settings1 is not settings2


class TestIsTestMode:
    """Tests for is_test_mode function."""

    def test_returns_true_when_set(self):
        """Test returns True when IS_TEST is set."""
        from config import is_test_mode

        old_value = os.environ.get("IS_TEST")
        os.environ["IS_TEST"] = "true"

        assert is_test_mode() is True

        if old_value:
            os.environ["IS_TEST"] = old_value
        else:
            del os.environ["IS_TEST"]

    def test_returns_false_when_not_set(self):
        """Test returns False when IS_TEST is not set."""
        from config import is_test_mode

        old_value = os.environ.get("IS_TEST")
        if old_value:
            del os.environ["IS_TEST"]

        assert is_test_mode() is False

        if old_value:
            os.environ["IS_TEST"] = old_value
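`test_env_prefix` above relies on the service reading its configuration from `KB_`-prefixed environment variables (the same convention used by the `KB_DATABASE_URL` and `KB_REDIS_URL` variables set in conftest.py). The sketch below shows how such a Settings class is typically declared with pydantic-settings; the field names mirror the ones asserted in the tests, but the actual class in `config.py` is not reproduced in this diff, so treat the exact shape as an assumption.

```python
# Hypothetical sketch of a pydantic-settings class using the KB_ env prefix.
# The real config.Settings may differ; only the env-prefix mechanism is shown.
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="KB_")

    host: str = "0.0.0.0"
    port: int = 8002
    search_default_limit: int = 10
    hybrid_semantic_weight: float = 0.7
    hybrid_keyword_weight: float = 0.3


# With KB_PORT=9999 in the environment, Settings().port == 9999,
# which is exactly what test_env_prefix asserts.
```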
245 mcp-servers/knowledge-base/tests/test_embeddings.py Normal file
@@ -0,0 +1,245 @@
"""Tests for embedding generation module."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingGenerator:
|
||||||
|
"""Tests for EmbeddingGenerator class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_http_response(self):
|
||||||
|
"""Create mock HTTP response."""
|
||||||
|
response = MagicMock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.raise_for_status = MagicMock()
|
||||||
|
response.json.return_value = {
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": json.dumps({
|
||||||
|
"embeddings": [[0.1] * 1536]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return response
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
|
||||||
|
"""Test generating a single embedding."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
generator._redis = mock_redis
|
||||||
|
|
||||||
|
# Mock HTTP client
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_http_response
|
||||||
|
generator._http_client = mock_client
|
||||||
|
|
||||||
|
embedding = await generator.generate(
|
||||||
|
text="Hello, world!",
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(embedding) == 1536
|
||||||
|
mock_client.post.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_batch_embeddings(self, settings, mock_redis):
|
||||||
|
"""Test generating batch embeddings."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
generator._redis = mock_redis
|
||||||
|
|
||||||
|
# Mock HTTP client with batch response
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": json.dumps({
|
||||||
|
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
generator._http_client = mock_client
|
||||||
|
|
||||||
|
embeddings = await generator.generate_batch(
|
||||||
|
texts=["Text 1", "Text 2", "Text 3"],
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(embeddings) == 3
|
||||||
|
assert all(len(e) == 1536 for e in embeddings)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_caching(self, settings, mock_redis):
|
||||||
|
"""Test embedding caching."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
generator._redis = mock_redis
|
||||||
|
|
||||||
|
# Pre-populate cache
|
||||||
|
cache_key = generator._cache_key("Hello, world!")
|
||||||
|
await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))
|
||||||
|
|
||||||
|
# Mock HTTP client (should not be called)
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
generator._http_client = mock_client
|
||||||
|
|
||||||
|
embedding = await generator.generate(
|
||||||
|
text="Hello, world!",
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return cached embedding
|
||||||
|
assert len(embedding) == 1536
|
||||||
|
assert embedding[0] == 0.5
|
||||||
|
mock_client.post.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_miss(self, settings, mock_redis, mock_http_response):
|
||||||
|
"""Test embedding cache miss."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
generator._redis = mock_redis
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_http_response
|
||||||
|
generator._http_client = mock_client
|
||||||
|
|
||||||
|
embedding = await generator.generate(
|
||||||
|
text="New text not in cache",
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(embedding) == 1536
|
||||||
|
mock_client.post.assert_called_once()
|
||||||
|
|
||||||
|
def test_cache_key_generation(self, settings):
|
||||||
|
"""Test cache key generation."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
|
||||||
|
key1 = generator._cache_key("Hello")
|
||||||
|
key2 = generator._cache_key("Hello")
|
||||||
|
key3 = generator._cache_key("World")
|
||||||
|
|
||||||
|
assert key1 == key2
|
||||||
|
assert key1 != key3
|
||||||
|
assert key1.startswith("kb:emb:")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dimension_validation(self, settings, mock_redis):
|
||||||
|
"""Test embedding dimension validation."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
from exceptions import EmbeddingDimensionMismatchError
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
generator._redis = mock_redis
|
||||||
|
|
||||||
|
# Mock HTTP client with wrong dimension
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": json.dumps({
|
||||||
|
"embeddings": [[0.1] * 768] # Wrong dimension
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
generator._http_client = mock_client
|
||||||
|
|
||||||
|
with pytest.raises(EmbeddingDimensionMismatchError):
|
||||||
|
await generator.generate(
|
||||||
|
text="Test text",
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_batch(self, settings, mock_redis):
|
||||||
|
"""Test generating embeddings for empty batch."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
generator._redis = mock_redis
|
||||||
|
|
||||||
|
embeddings = await generator.generate_batch(
|
||||||
|
texts=[],
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embeddings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_and_close(self, settings):
|
||||||
|
"""Test initialize and close methods."""
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
generator = EmbeddingGenerator(settings=settings)
|
||||||
|
|
||||||
|
# Mock successful initialization
|
||||||
|
with patch("embeddings.redis.from_url") as mock_redis_from_url:
|
||||||
|
mock_redis_client = AsyncMock()
|
||||||
|
mock_redis_client.ping = AsyncMock()
|
||||||
|
mock_redis_from_url.return_value = mock_redis_client
|
||||||
|
|
||||||
|
await generator.initialize()
|
||||||
|
|
||||||
|
assert generator._http_client is not None
|
||||||
|
|
||||||
|
await generator.close()
|
||||||
|
|
||||||
|
assert generator._http_client is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalEmbeddingGenerator:
|
||||||
|
"""Tests for global embedding generator."""
|
||||||
|
|
||||||
|
def test_get_embedding_generator_singleton(self):
|
||||||
|
"""Test that get_embedding_generator returns singleton."""
|
||||||
|
from embeddings import get_embedding_generator, reset_embedding_generator
|
||||||
|
|
||||||
|
reset_embedding_generator()
|
||||||
|
gen1 = get_embedding_generator()
|
||||||
|
gen2 = get_embedding_generator()
|
||||||
|
|
||||||
|
assert gen1 is gen2
|
||||||
|
|
||||||
|
def test_reset_embedding_generator(self):
|
||||||
|
"""Test resetting embedding generator."""
|
||||||
|
from embeddings import get_embedding_generator, reset_embedding_generator
|
||||||
|
|
||||||
|
gen1 = get_embedding_generator()
|
||||||
|
reset_embedding_generator()
|
||||||
|
gen2 = get_embedding_generator()
|
||||||
|
|
||||||
|
assert gen1 is not gen2
307 mcp-servers/knowledge-base/tests/test_exceptions.py Normal file
@@ -0,0 +1,307 @@
"""Tests for exception classes."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorCode:
|
||||||
|
"""Tests for ErrorCode enum."""
|
||||||
|
|
||||||
|
def test_error_code_values(self):
|
||||||
|
"""Test error code values."""
|
||||||
|
from exceptions import ErrorCode
|
||||||
|
|
||||||
|
assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
|
||||||
|
assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
|
||||||
|
assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
|
||||||
|
assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
|
||||||
|
assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
|
||||||
|
assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"
|
||||||
|
assert ErrorCode.DOCUMENT_NOT_FOUND.value == "KB_DOCUMENT_NOT_FOUND"
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeBaseError:
|
||||||
|
"""Tests for base exception class."""
|
||||||
|
|
||||||
|
def test_basic_error(self):
|
||||||
|
"""Test basic error creation."""
|
||||||
|
from exceptions import ErrorCode, KnowledgeBaseError
|
||||||
|
|
||||||
|
error = KnowledgeBaseError(
|
||||||
|
message="Something went wrong",
|
||||||
|
code=ErrorCode.UNKNOWN_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.message == "Something went wrong"
|
||||||
|
assert error.code == ErrorCode.UNKNOWN_ERROR
|
||||||
|
assert error.details == {}
|
||||||
|
assert error.cause is None
|
||||||
|
|
||||||
|
def test_error_with_details(self):
|
||||||
|
"""Test error with details."""
|
||||||
|
from exceptions import ErrorCode, KnowledgeBaseError
|
||||||
|
|
||||||
|
error = KnowledgeBaseError(
|
||||||
|
message="Query failed",
|
||||||
|
code=ErrorCode.DATABASE_QUERY_ERROR,
|
||||||
|
details={"query": "SELECT * FROM table", "error_code": 42},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.details["query"] == "SELECT * FROM table"
|
||||||
|
assert error.details["error_code"] == 42
|
||||||
|
|
||||||
|
def test_error_with_cause(self):
|
||||||
|
"""Test error with underlying cause."""
|
||||||
|
from exceptions import ErrorCode, KnowledgeBaseError
|
||||||
|
|
||||||
|
original = ValueError("Original error")
|
||||||
|
error = KnowledgeBaseError(
|
||||||
|
message="Wrapped error",
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
cause=original,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.cause is original
|
||||||
|
assert isinstance(error.cause, ValueError)
|
||||||
|
|
||||||
|
def test_to_dict(self):
|
||||||
|
"""Test to_dict method."""
|
||||||
|
from exceptions import ErrorCode, KnowledgeBaseError
|
||||||
|
|
||||||
|
error = KnowledgeBaseError(
|
||||||
|
message="Test error",
|
||||||
|
code=ErrorCode.INVALID_REQUEST,
|
||||||
|
details={"field": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = error.to_dict()
|
||||||
|
|
||||||
|
assert result["error"] == "KB_INVALID_REQUEST"
|
||||||
|
assert result["message"] == "Test error"
|
||||||
|
assert result["details"]["field"] == "value"
|
||||||
|
|
||||||
|
def test_str_representation(self):
|
||||||
|
"""Test string representation."""
|
||||||
|
from exceptions import ErrorCode, KnowledgeBaseError
|
||||||
|
|
||||||
|
error = KnowledgeBaseError(
|
||||||
|
message="Test error",
|
||||||
|
code=ErrorCode.INVALID_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert str(error) == "[KB_INVALID_REQUEST] Test error"
|
||||||
|
|
||||||
|
def test_repr_representation(self):
|
||||||
|
"""Test repr representation."""
|
||||||
|
from exceptions import ErrorCode, KnowledgeBaseError
|
||||||
|
|
||||||
|
error = KnowledgeBaseError(
|
||||||
|
message="Test error",
|
||||||
|
code=ErrorCode.INVALID_REQUEST,
|
||||||
|
details={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(error)
|
||||||
|
assert "KnowledgeBaseError" in repr_str
|
||||||
|
assert "Test error" in repr_str
|
||||||
|
assert "KB_INVALID_REQUEST" in repr_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabaseErrors:
|
||||||
|
"""Tests for database-related exceptions."""
|
||||||
|
|
||||||
|
def test_database_connection_error(self):
|
||||||
|
"""Test database connection error."""
|
||||||
|
from exceptions import DatabaseConnectionError, ErrorCode
|
||||||
|
|
||||||
|
error = DatabaseConnectionError(
|
||||||
|
message="Cannot connect to database",
|
||||||
|
details={"host": "localhost", "port": 5432},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.code == ErrorCode.DATABASE_CONNECTION_ERROR
|
||||||
|
assert error.details["host"] == "localhost"
|
||||||
|
|
||||||
|
def test_database_connection_error_default_message(self):
|
||||||
|
"""Test database connection error with default message."""
|
||||||
|
from exceptions import DatabaseConnectionError
|
||||||
|
|
||||||
|
error = DatabaseConnectionError()
|
||||||
|
|
||||||
|
assert error.message == "Failed to connect to database"
|
||||||
|
|
||||||
|
def test_database_query_error(self):
|
||||||
|
"""Test database query error."""
|
||||||
|
from exceptions import DatabaseQueryError, ErrorCode
|
||||||
|
|
||||||
|
error = DatabaseQueryError(
|
||||||
|
message="Query failed",
|
||||||
|
query="SELECT * FROM missing_table",
|
||||||
|
)
|
||||||
|
|
||||||
|
        assert error.code == ErrorCode.DATABASE_QUERY_ERROR
        assert error.details["query"] == "SELECT * FROM missing_table"


class TestEmbeddingErrors:
    """Tests for embedding-related exceptions."""

    def test_embedding_generation_error(self):
        """Test embedding generation error."""
        from exceptions import EmbeddingGenerationError, ErrorCode

        error = EmbeddingGenerationError(
            message="Failed to generate",
            texts_count=10,
        )

        assert error.code == ErrorCode.EMBEDDING_GENERATION_ERROR
        assert error.details["texts_count"] == 10

    def test_embedding_dimension_mismatch(self):
        """Test embedding dimension mismatch error."""
        from exceptions import EmbeddingDimensionMismatchError, ErrorCode

        error = EmbeddingDimensionMismatchError(
            expected=1536,
            actual=768,
        )

        assert error.code == ErrorCode.EMBEDDING_DIMENSION_MISMATCH
        assert "expected 1536" in error.message
        assert "got 768" in error.message
        assert error.details["expected_dimension"] == 1536
        assert error.details["actual_dimension"] == 768


class TestChunkingErrors:
    """Tests for chunking-related exceptions."""

    def test_unsupported_file_type_error(self):
        """Test unsupported file type error."""
        from exceptions import ErrorCode, UnsupportedFileTypeError

        error = UnsupportedFileTypeError(
            file_type=".xyz",
            supported_types=[".py", ".js", ".md"],
        )

        assert error.code == ErrorCode.UNSUPPORTED_FILE_TYPE
        assert error.details["file_type"] == ".xyz"
        assert len(error.details["supported_types"]) == 3

    def test_file_too_large_error(self):
        """Test file too large error."""
        from exceptions import ErrorCode, FileTooLargeError

        error = FileTooLargeError(
            file_size=10_000_000,
            max_size=1_000_000,
        )

        assert error.code == ErrorCode.FILE_TOO_LARGE
        assert error.details["file_size"] == 10_000_000
        assert error.details["max_size"] == 1_000_000

    def test_encoding_error(self):
        """Test encoding error."""
        from exceptions import EncodingError, ErrorCode

        error = EncodingError(
            message="Cannot decode file",
            encoding="utf-8",
        )

        assert error.code == ErrorCode.ENCODING_ERROR
        assert error.details["encoding"] == "utf-8"


class TestSearchErrors:
    """Tests for search-related exceptions."""

    def test_invalid_search_type_error(self):
        """Test invalid search type error."""
        from exceptions import ErrorCode, InvalidSearchTypeError

        error = InvalidSearchTypeError(
            search_type="invalid",
            valid_types=["semantic", "keyword", "hybrid"],
        )

        assert error.code == ErrorCode.INVALID_SEARCH_TYPE
        assert error.details["search_type"] == "invalid"
        assert len(error.details["valid_types"]) == 3

    def test_search_timeout_error(self):
        """Test search timeout error."""
        from exceptions import ErrorCode, SearchTimeoutError

        error = SearchTimeoutError(timeout=30.0)

        assert error.code == ErrorCode.SEARCH_TIMEOUT
        assert error.details["timeout"] == 30.0
        assert "30" in error.message


class TestCollectionErrors:
    """Tests for collection-related exceptions."""

    def test_collection_not_found_error(self):
        """Test collection not found error."""
        from exceptions import CollectionNotFoundError, ErrorCode

        error = CollectionNotFoundError(
            collection="missing-collection",
            project_id="proj-123",
        )

        assert error.code == ErrorCode.COLLECTION_NOT_FOUND
        assert error.details["collection"] == "missing-collection"
        assert error.details["project_id"] == "proj-123"


class TestDocumentErrors:
    """Tests for document-related exceptions."""

    def test_document_not_found_error(self):
        """Test document not found error."""
        from exceptions import DocumentNotFoundError, ErrorCode

        error = DocumentNotFoundError(
            source_path="/path/to/file.py",
            project_id="proj-123",
        )

        assert error.code == ErrorCode.DOCUMENT_NOT_FOUND
        assert error.details["source_path"] == "/path/to/file.py"

    def test_invalid_document_error(self):
        """Test invalid document error."""
        from exceptions import ErrorCode, InvalidDocumentError

        error = InvalidDocumentError(
            message="Empty content",
            details={"reason": "no content"},
        )

        assert error.code == ErrorCode.INVALID_DOCUMENT


class TestProjectErrors:
    """Tests for project-related exceptions."""

    def test_project_not_found_error(self):
        """Test project not found error."""
        from exceptions import ErrorCode, ProjectNotFoundError

        error = ProjectNotFoundError(project_id="missing-proj")

        assert error.code == ErrorCode.PROJECT_NOT_FOUND
        assert error.details["project_id"] == "missing-proj"

    def test_project_access_denied_error(self):
        """Test project access denied error."""
        from exceptions import ErrorCode, ProjectAccessDeniedError

        error = ProjectAccessDeniedError(project_id="restricted-proj")

        assert error.code == ErrorCode.PROJECT_ACCESS_DENIED
        assert "restricted-proj" in error.message
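The exception tests above all assert the same three attributes: a machine-readable `code`, a human-readable `message`, and a structured `details` dict. The `exceptions.py` module itself is not part of this compare view, so the following is only a sketch of the base-class pattern those assertions imply; the base-class name `KnowledgeBaseError` and the constructor signature are assumptions, not the actual implementation.

```python
# Illustrative sketch only: exceptions.py is not shown in this diff. The shape below is
# inferred from the assertions in test_exceptions.py (every error exposes .code,
# .message and a .details dict); names not asserted on are assumptions.
from enum import Enum
from typing import Any


class ErrorCode(str, Enum):
    PROJECT_NOT_FOUND = "PROJECT_NOT_FOUND"
    PROJECT_ACCESS_DENIED = "PROJECT_ACCESS_DENIED"


class KnowledgeBaseError(Exception):
    """Base error carrying a machine-readable code and structured details."""

    def __init__(self, message: str, code: ErrorCode, details: dict[str, Any] | None = None):
        super().__init__(message)
        self.message = message
        self.code = code
        self.details = details or {}


class ProjectNotFoundError(KnowledgeBaseError):
    def __init__(self, project_id: str):
        super().__init__(
            message=f"Project '{project_id}' not found",
            code=ErrorCode.PROJECT_NOT_FOUND,
            details={"project_id": project_id},
        )
```

Keeping the per-error context in `details` rather than interpolated only into the message is what lets the tests (and API error handlers) assert on individual fields.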
347  mcp-servers/knowledge-base/tests/test_models.py  Normal file
@@ -0,0 +1,347 @@
"""Tests for data models."""

from datetime import UTC, datetime


class TestEnums:
    """Tests for enum classes."""

    def test_search_type_values(self):
        """Test SearchType enum values."""
        from models import SearchType

        assert SearchType.SEMANTIC.value == "semantic"
        assert SearchType.KEYWORD.value == "keyword"
        assert SearchType.HYBRID.value == "hybrid"

    def test_chunk_type_values(self):
        """Test ChunkType enum values."""
        from models import ChunkType

        assert ChunkType.CODE.value == "code"
        assert ChunkType.MARKDOWN.value == "markdown"
        assert ChunkType.TEXT.value == "text"
        assert ChunkType.DOCUMENTATION.value == "documentation"

    def test_file_type_values(self):
        """Test FileType enum values."""
        from models import FileType

        assert FileType.PYTHON.value == "python"
        assert FileType.JAVASCRIPT.value == "javascript"
        assert FileType.TYPESCRIPT.value == "typescript"
        assert FileType.MARKDOWN.value == "markdown"


class TestFileExtensionMap:
    """Tests for file extension mapping."""

    def test_python_extensions(self):
        """Test Python file extensions."""
        from models import FILE_EXTENSION_MAP, FileType

        assert FILE_EXTENSION_MAP[".py"] == FileType.PYTHON

    def test_javascript_extensions(self):
        """Test JavaScript file extensions."""
        from models import FILE_EXTENSION_MAP, FileType

        assert FILE_EXTENSION_MAP[".js"] == FileType.JAVASCRIPT
        assert FILE_EXTENSION_MAP[".jsx"] == FileType.JAVASCRIPT

    def test_typescript_extensions(self):
        """Test TypeScript file extensions."""
        from models import FILE_EXTENSION_MAP, FileType

        assert FILE_EXTENSION_MAP[".ts"] == FileType.TYPESCRIPT
        assert FILE_EXTENSION_MAP[".tsx"] == FileType.TYPESCRIPT

    def test_markdown_extensions(self):
        """Test Markdown file extensions."""
        from models import FILE_EXTENSION_MAP, FileType

        assert FILE_EXTENSION_MAP[".md"] == FileType.MARKDOWN
        assert FILE_EXTENSION_MAP[".mdx"] == FileType.MARKDOWN


class TestChunk:
    """Tests for Chunk dataclass."""

    def test_chunk_creation(self, sample_chunk):
        """Test chunk creation."""
        from models import ChunkType, FileType

        assert sample_chunk.content == "def hello():\n print('Hello')"
        assert sample_chunk.chunk_type == ChunkType.CODE
        assert sample_chunk.file_type == FileType.PYTHON
        assert sample_chunk.source_path == "/test/hello.py"
        assert sample_chunk.start_line == 1
        assert sample_chunk.end_line == 2
        assert sample_chunk.token_count == 15

    def test_chunk_to_dict(self, sample_chunk):
        """Test chunk to_dict method."""
        result = sample_chunk.to_dict()

        assert result["content"] == "def hello():\n print('Hello')"
        assert result["chunk_type"] == "code"
        assert result["file_type"] == "python"
        assert result["source_path"] == "/test/hello.py"
        assert result["start_line"] == 1
        assert result["end_line"] == 2
        assert result["token_count"] == 15


class TestKnowledgeEmbedding:
    """Tests for KnowledgeEmbedding dataclass."""

    def test_embedding_creation(self, sample_embedding):
        """Test embedding creation."""
        assert sample_embedding.id == "test-id-123"
        assert sample_embedding.project_id == "proj-123"
        assert sample_embedding.collection == "default"
        assert len(sample_embedding.embedding) == 1536

    def test_embedding_to_dict(self, sample_embedding):
        """Test embedding to_dict method."""
        result = sample_embedding.to_dict()

        assert result["id"] == "test-id-123"
        assert result["project_id"] == "proj-123"
        assert result["collection"] == "default"
        assert result["chunk_type"] == "code"
        assert result["file_type"] == "python"
        assert "embedding" not in result  # Embedding excluded for size


class TestIngestRequest:
    """Tests for IngestRequest model."""

    def test_ingest_request_creation(self, sample_ingest_request):
        """Test ingest request creation."""
        from models import ChunkType, FileType

        assert sample_ingest_request.project_id == "proj-123"
        assert sample_ingest_request.agent_id == "agent-456"
        assert sample_ingest_request.chunk_type == ChunkType.CODE
        assert sample_ingest_request.file_type == FileType.PYTHON
        assert sample_ingest_request.collection == "default"

    def test_ingest_request_defaults(self):
        """Test ingest request default values."""
        from models import ChunkType, IngestRequest

        request = IngestRequest(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
        )

        assert request.collection == "default"
        assert request.chunk_type == ChunkType.TEXT
        assert request.file_type is None
        assert request.metadata == {}


class TestIngestResult:
    """Tests for IngestResult model."""

    def test_successful_result(self):
        """Test successful ingest result."""
        from models import IngestResult

        result = IngestResult(
            success=True,
            chunks_created=5,
            embeddings_generated=5,
            source_path="/test/file.py",
            collection="default",
            chunk_ids=["id1", "id2", "id3", "id4", "id5"],
        )

        assert result.success is True
        assert result.chunks_created == 5
        assert result.error is None

    def test_failed_result(self):
        """Test failed ingest result."""
        from models import IngestResult

        result = IngestResult(
            success=False,
            chunks_created=0,
            embeddings_generated=0,
            collection="default",
            chunk_ids=[],
            error="Something went wrong",
        )

        assert result.success is False
        assert result.error == "Something went wrong"


class TestSearchRequest:
    """Tests for SearchRequest model."""

    def test_search_request_creation(self, sample_search_request):
        """Test search request creation."""
        from models import SearchType

        assert sample_search_request.project_id == "proj-123"
        assert sample_search_request.query == "hello function"
        assert sample_search_request.search_type == SearchType.HYBRID
        assert sample_search_request.limit == 10
        assert sample_search_request.threshold == 0.7

    def test_search_request_defaults(self):
        """Test search request default values."""
        from models import SearchRequest, SearchType

        request = SearchRequest(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
        )

        assert request.search_type == SearchType.HYBRID
        assert request.collection is None
        assert request.limit == 10
        assert request.threshold == 0.7
        assert request.file_types is None


class TestSearchResult:
    """Tests for SearchResult model."""

    def test_from_embedding(self, sample_embedding):
        """Test creating SearchResult from KnowledgeEmbedding."""
        from models import SearchResult

        result = SearchResult.from_embedding(sample_embedding, 0.95)

        assert result.id == "test-id-123"
        assert result.content == "def hello():\n print('Hello')"
        assert result.score == 0.95
        assert result.source_path == "/test/hello.py"
        assert result.chunk_type == "code"
        assert result.file_type == "python"


class TestSearchResponse:
    """Tests for SearchResponse model."""

    def test_search_response(self):
        """Test search response creation."""
        from models import SearchResponse, SearchResult

        results = [
            SearchResult(
                id="id1",
                content="test content 1",
                score=0.95,
                chunk_type="code",
                collection="default",
            ),
            SearchResult(
                id="id2",
                content="test content 2",
                score=0.85,
                chunk_type="text",
                collection="default",
            ),
        ]

        response = SearchResponse(
            query="test query",
            search_type="hybrid",
            results=results,
            total_results=2,
            search_time_ms=15.5,
        )

        assert response.query == "test query"
        assert len(response.results) == 2
        assert response.search_time_ms == 15.5


class TestDeleteRequest:
    """Tests for DeleteRequest model."""

    def test_delete_by_source(self, sample_delete_request):
        """Test delete request by source path."""
        assert sample_delete_request.project_id == "proj-123"
        assert sample_delete_request.source_path == "/test/hello.py"
        assert sample_delete_request.collection is None
        assert sample_delete_request.chunk_ids is None

    def test_delete_by_collection(self):
        """Test delete request by collection."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            collection="to-delete",
        )

        assert request.collection == "to-delete"
        assert request.source_path is None

    def test_delete_by_ids(self):
        """Test delete request by chunk IDs."""
        from models import DeleteRequest

        request = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id1", "id2", "id3"],
        )

        assert len(request.chunk_ids) == 3


class TestCollectionInfo:
    """Tests for CollectionInfo model."""

    def test_collection_info(self):
        """Test collection info creation."""
        from models import CollectionInfo

        info = CollectionInfo(
            name="test-collection",
            project_id="proj-123",
            chunk_count=100,
            total_tokens=50000,
            file_types=["python", "javascript"],
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )

        assert info.name == "test-collection"
        assert info.chunk_count == 100
        assert len(info.file_types) == 2


class TestCollectionStats:
    """Tests for CollectionStats model."""

    def test_collection_stats(self):
        """Test collection stats creation."""
        from models import CollectionStats

        stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
            oldest_chunk=datetime.now(UTC),
            newest_chunk=datetime.now(UTC),
        )

        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        assert stats.chunk_types["code"] == 60
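Several of the model tests depend on fixtures (`sample_chunk`, `sample_embedding`, `sample_ingest_request`, `sample_delete_request`) that live in the suite's `conftest.py`, which is not shown in this compare view. Reconstructed purely from the values asserted in `TestChunk` above, such a fixture could look like the sketch below; any constructor arguments beyond the asserted fields are assumptions.

```python
# Hypothetical conftest.py fixture, inferred only from the assertions in TestChunk.
# The real fixture definitions are not part of this diff.
import pytest

from models import Chunk, ChunkType, FileType


@pytest.fixture
def sample_chunk() -> Chunk:
    # Field values mirror what test_chunk_creation and test_chunk_to_dict assert on.
    return Chunk(
        content="def hello():\n print('Hello')",
        chunk_type=ChunkType.CODE,
        file_type=FileType.PYTHON,
        source_path="/test/hello.py",
        start_line=1,
        end_line=2,
        token_count=15,
    )
```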
295  mcp-servers/knowledge-base/tests/test_search.py  Normal file
@@ -0,0 +1,295 @@
"""Tests for search module."""

from datetime import UTC, datetime

import pytest


class TestSearchEngine:
    """Tests for SearchEngine class."""

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create search engine with mocks."""
        from search import SearchEngine

        engine = SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )
        return engine

    @pytest.fixture
    def sample_db_results(self):
        """Create sample database results."""
        from models import ChunkType, FileType, KnowledgeEmbedding

        return [
            (
                KnowledgeEmbedding(
                    id="id-1",
                    project_id="proj-123",
                    collection="default",
                    content="def hello(): pass",
                    embedding=[0.1] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path="/test/file.py",
                    file_type=FileType.PYTHON,
                    created_at=datetime.now(UTC),
                    updated_at=datetime.now(UTC),
                ),
                0.95,
            ),
            (
                KnowledgeEmbedding(
                    id="id-2",
                    project_id="proj-123",
                    collection="default",
                    content="def world(): pass",
                    embedding=[0.2] * 1536,
                    chunk_type=ChunkType.CODE,
                    source_path="/test/file2.py",
                    file_type=FileType.PYTHON,
                    created_at=datetime.now(UTC),
                    updated_at=datetime.now(UTC),
                ),
                0.85,
            ),
        ]

    @pytest.mark.asyncio
    async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
        """Test semantic search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "semantic"
        assert len(response.results) == 2
        assert response.results[0].score == 0.95
        search_engine._database.semantic_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
        """Test keyword search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.KEYWORD
        search_engine._database.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "keyword"
        assert len(response.results) == 2
        search_engine._database.keyword_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
        """Test hybrid search."""
        from models import SearchType

        sample_search_request.search_type = SearchType.HYBRID

        # Both searches return same results for simplicity
        search_engine._database.semantic_search.return_value = sample_db_results
        search_engine._database.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "hybrid"
        # Results should be fused
        assert len(response.results) >= 1

    @pytest.mark.asyncio
    async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
        """Test search with collection filter."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.collection = "specific-collection"
        search_engine._database.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Verify collection was passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["collection"] == "specific-collection"

    @pytest.mark.asyncio
    async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
        """Test search with file type filter."""
        from models import FileType, SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.file_types = [FileType.PYTHON]
        search_engine._database.semantic_search.return_value = sample_db_results

        await search_engine.search(sample_search_request)

        # Verify file types were passed to database
        call_args = search_engine._database.semantic_search.call_args
        assert call_args.kwargs["file_types"] == [FileType.PYTHON]

    @pytest.mark.asyncio
    async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
        """Test that search respects result limit."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.limit = 1
        search_engine._database.semantic_search.return_value = sample_db_results[:1]

        response = await search_engine.search(sample_search_request)

        assert len(response.results) <= 1

    @pytest.mark.asyncio
    async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
        """Test that search records time."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_time_ms > 0

    @pytest.mark.asyncio
    async def test_invalid_search_type(self, search_engine, sample_search_request):
        """Test handling invalid search type."""
        from exceptions import InvalidSearchTypeError

        # Force invalid search type
        sample_search_request.search_type = "invalid"

        with pytest.raises((InvalidSearchTypeError, ValueError)):
            await search_engine.search(sample_search_request)

    @pytest.mark.asyncio
    async def test_empty_results(self, search_engine, sample_search_request):
        """Test search with no results."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = []

        response = await search_engine.search(sample_search_request)

        assert len(response.results) == 0
        assert response.total_results == 0


class TestReciprocalRankFusion:
    """Tests for reciprocal rank fusion."""

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create search engine with mocks."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    def test_fusion_combines_results(self, search_engine):
        """Test that RRF combines results from both searches."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]

        keyword = [
            SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # Should have all unique results
        ids = [r.id for r in fused]
        assert "a" in ids
        assert "b" in ids
        assert "c" in ids

        # B should be ranked higher (appears in both)
        b_rank = ids.index("b")
        assert b_rank < 2  # Should be in top 2

    def test_fusion_respects_weights(self, search_engine):
        """Test that RRF respects semantic/keyword weights."""
        from models import SearchResult

        # Same results in same order
        results = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
        ]

        # High semantic weight
        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            results, [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )

        # High keyword weight
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            [], results,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )

        # Both should still return the result
        assert len(fused_semantic_heavy) == 1
        assert len(fused_keyword_heavy) == 1

    def test_fusion_normalizes_scores(self, search_engine):
        """Test that RRF normalizes scores to 0-1."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]

        keyword = [
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # All scores should be between 0 and 1
        for result in fused:
            assert 0 <= result.score <= 1


class TestGlobalSearchEngine:
    """Tests for global search engine."""

    def test_get_search_engine_singleton(self):
        """Test that get_search_engine returns singleton."""
        from search import get_search_engine, reset_search_engine

        reset_search_engine()
        engine1 = get_search_engine()
        engine2 = get_search_engine()

        assert engine1 is engine2

    def test_reset_search_engine(self):
        """Test resetting search engine."""
        from search import get_search_engine, reset_search_engine

        engine1 = get_search_engine()
        reset_search_engine()
        engine2 = get_search_engine()

        assert engine1 is not engine2
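The `TestReciprocalRankFusion` cases pin down the observable behaviour of the hybrid fusion step: every unique id survives, ids returned by both searches rank higher, and fused scores stay in the 0 to 1 range. The actual `SearchEngine._reciprocal_rank_fusion` implementation is not visible in this compare view, so the following is only a sketch that satisfies those properties; the rank constant `k=60`, the default weights, and the max-score normalization are assumptions.

```python
# Sketch of a weighted reciprocal rank fusion consistent with the tests above.
# Not the repository's implementation; k, default weights, and normalization are assumed.
from models import SearchResult


def reciprocal_rank_fusion(
    semantic: list[SearchResult],
    keyword: list[SearchResult],
    semantic_weight: float = 0.7,
    keyword_weight: float = 0.3,
    k: int = 60,
) -> list[SearchResult]:
    scores: dict[str, float] = {}
    by_id: dict[str, SearchResult] = {}

    # Each list contributes weight / (k + rank); ids present in both lists accumulate more.
    for weight, results in ((semantic_weight, semantic), (keyword_weight, keyword)):
        for rank, result in enumerate(results, start=1):
            scores[result.id] = scores.get(result.id, 0.0) + weight / (k + rank)
            by_id[result.id] = result

    if not scores:
        return []

    # Normalize so the best fused result scores 1.0 and downstream thresholds keep working.
    max_score = max(scores.values())
    fused: list[SearchResult] = []
    for rid, score in sorted(scores.items(), key=lambda item: item[1], reverse=True):
        result = by_id[rid]
        result.score = score / max_score
        fused.append(result)
    return fused
```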
655  mcp-servers/knowledge-base/tests/test_server.py  Normal file
@@ -0,0 +1,655 @@
"""Tests for server module and MCP tools."""

import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient


class TestHealthCheck:
    """Tests for health check endpoint."""

    @pytest.mark.asyncio
    async def test_health_check_healthy(self):
        """Test health check when all dependencies are connected."""
        import server

        # Create a proper async context manager mock for database
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)

        mock_db = MagicMock()
        mock_db._pool = MagicMock()

        # Make acquire an async context manager
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None
        mock_db.acquire.return_value = mock_cm

        # Mock Redis
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)

        # Mock HTTP client for LLM Gateway
        mock_http_response = AsyncMock()
        mock_http_response.status_code = 200
        mock_http_client = AsyncMock()
        mock_http_client.get = AsyncMock(return_value=mock_http_response)

        # Mock embeddings with Redis and HTTP client
        mock_embeddings = MagicMock()
        mock_embeddings._redis = mock_redis
        mock_embeddings._http_client = mock_http_client

        server._database = mock_db
        server._embeddings = mock_embeddings

        result = await server.health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "knowledge-base"
        assert result["dependencies"]["database"] == "connected"
        assert result["dependencies"]["redis"] == "connected"
        assert result["dependencies"]["llm_gateway"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self):
        """Test health check without database - should be unhealthy."""
        import server

        server._database = None
        server._embeddings = None

        result = await server.health_check()

        assert result["status"] == "unhealthy"
        assert result["dependencies"]["database"] == "not initialized"

    @pytest.mark.asyncio
    async def test_health_check_degraded(self):
        """Test health check with database but no Redis - should be degraded."""
        import server

        # Create a proper async context manager mock for database
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)

        mock_db = MagicMock()
        mock_db._pool = MagicMock()

        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None
        mock_db.acquire.return_value = mock_cm

        # Mock embeddings without Redis
        mock_embeddings = MagicMock()
        mock_embeddings._redis = None
        mock_embeddings._http_client = None

        server._database = mock_db
        server._embeddings = mock_embeddings

        result = await server.health_check()

        assert result["status"] == "degraded"
        assert result["dependencies"]["database"] == "connected"
        assert result["dependencies"]["redis"] == "not initialized"


class TestSearchKnowledgeTool:
    """Tests for search_knowledge MCP tool."""

    @pytest.mark.asyncio
    async def test_search_success(self):
        """Test successful search."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test query",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test content",
                        score=0.95,
                        source_path="/test/file.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=10.5,
            )
        )
        server._search = mock_search

        # Call the wrapped function via .fn
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=None,
        )

        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95

    @pytest.mark.asyncio
    async def test_search_invalid_type(self):
        """Test search with invalid search type."""
        import server

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid search type" in result["error"]

    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self):
        """Test search with invalid file type."""
        import server

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=["invalid_type"],
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]


class TestIngestContentTool:
    """Tests for ingest_content MCP tool."""

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """Test successful ingestion."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=3,
                embeddings_generated=3,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2", "id-3"],
            )
        )
        server._collections = mock_collections

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test/file.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3

    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self):
        """Test ingest with invalid chunk type."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self):
        """Test ingest with invalid file type."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type="invalid",
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]


class TestDeleteContentTool:
    """Tests for delete_content MCP tool."""

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """Test successful deletion."""
        import server
        from models import DeleteResult

        mock_collections = MagicMock()
        mock_collections.delete = AsyncMock(
            return_value=DeleteResult(
                success=True,
                chunks_deleted=5,
            )
        )
        server._collections = mock_collections

        result = await server.delete_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            collection=None,
            chunk_ids=None,
        )

        assert result["success"] is True
        assert result["chunks_deleted"] == 5


class TestListCollectionsTool:
    """Tests for list_collections MCP tool."""

    @pytest.mark.asyncio
    async def test_list_collections_success(self):
        """Test listing collections."""
        import server
        from models import CollectionInfo, ListCollectionsResponse

        mock_collections = MagicMock()
        mock_collections.list_collections = AsyncMock(
            return_value=ListCollectionsResponse(
                project_id="proj-123",
                collections=[
                    CollectionInfo(
                        name="collection-1",
                        project_id="proj-123",
                        chunk_count=100,
                        total_tokens=50000,
                        file_types=["python"],
                        created_at=datetime.now(UTC),
                        updated_at=datetime.now(UTC),
                    )
                ],
                total_collections=1,
            )
        )
        server._collections = mock_collections

        result = await server.list_collections.fn(
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert result["success"] is True
        assert result["total_collections"] == 1
        assert len(result["collections"]) == 1


class TestGetCollectionStatsTool:
    """Tests for get_collection_stats MCP tool."""

    @pytest.mark.asyncio
    async def test_get_stats_success(self):
        """Test getting collection stats."""
        import server
        from models import CollectionStats

        mock_collections = MagicMock()
        mock_collections.get_collection_stats = AsyncMock(
            return_value=CollectionStats(
                collection="test-collection",
                project_id="proj-123",
                chunk_count=100,
                unique_sources=10,
                total_tokens=50000,
                avg_chunk_size=500.0,
                chunk_types={"code": 60, "text": 40},
                file_types={"python": 50, "javascript": 10},
            )
        )
        server._collections = mock_collections

        result = await server.get_collection_stats.fn(
            project_id="proj-123",
            agent_id="agent-456",
            collection="test-collection",
        )

        assert result["success"] is True
        assert result["chunk_count"] == 100
        assert result["unique_sources"] == 10


class TestUpdateDocumentTool:
    """Tests for update_document MCP tool."""

    @pytest.mark.asyncio
    async def test_update_success(self):
        """Test updating a document."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=2,
                embeddings_generated=2,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2"],
            )
        )
        server._collections = mock_collections

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 2

    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self):
        """Test update with invalid chunk type."""
        import server

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="test",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]


class TestMCPToolsEndpoint:
    """Tests for /mcp/tools endpoint."""

    def test_list_mcp_tools(self):
        """Test listing available MCP tools."""
        import server

        client = TestClient(server.app)
        response = client.get("/mcp/tools")

        assert response.status_code == 200
        data = response.json()
        assert "tools" in data
        assert len(data["tools"]) == 6  # 6 tools registered

        tool_names = [t["name"] for t in data["tools"]]
        assert "search_knowledge" in tool_names
        assert "ingest_content" in tool_names
        assert "delete_content" in tool_names
        assert "list_collections" in tool_names
        assert "get_collection_stats" in tool_names
        assert "update_document" in tool_names

    def test_tool_has_schema(self):
        """Test that each tool has input schema."""
        import server

        client = TestClient(server.app)
        response = client.get("/mcp/tools")

        data = response.json()
        for tool in data["tools"]:
            assert "inputSchema" in tool
            assert "type" in tool["inputSchema"]
            assert tool["inputSchema"]["type"] == "object"


class TestMCPRPCEndpoint:
    """Tests for /mcp JSON-RPC endpoint."""

    def test_valid_jsonrpc_request(self):
        """Test valid JSON-RPC request."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test",
                        score=0.9,
                        source_path="/test.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=5.0,
            )
        )
        server._search = mock_search

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "query": "test",
                },
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["jsonrpc"] == "2.0"
        assert data["id"] == 1
        assert "result" in data
        assert data["result"]["success"] is True

    def test_invalid_jsonrpc_version(self):
        """Test request with invalid JSON-RPC version."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "1.0",
                "method": "search_knowledge",
                "params": {},
                "id": 1,
            },
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600
        assert "jsonrpc must be '2.0'" in data["error"]["message"]

    def test_missing_method(self):
        """Test request without method."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "params": {},
                "id": 1,
            },
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600
        assert "method is required" in data["error"]["message"]

    def test_unknown_method(self):
        """Test request with unknown method."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "unknown_method",
                "params": {},
                "id": 1,
            },
        )

        assert response.status_code == 404
        data = response.json()
        assert data["error"]["code"] == -32601
        assert "Method not found" in data["error"]["message"]

    def test_invalid_params(self):
        """Test request with invalid params."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {"invalid_param": "value"},  # Missing required params
                "id": 1,
            },
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32602


class TestContentSizeLimits:
    """Tests for content size validation."""

    @pytest.mark.asyncio
    async def test_ingest_rejects_oversized_content(self):
        """Test that ingest rejects content exceeding size limit."""
        import server
        from config import get_settings

        settings = get_settings()
        # Create content larger than max size
        oversized_content = "x" * (settings.max_document_size + 1)

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content=oversized_content,
            chunk_type="text",
        )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]

    @pytest.mark.asyncio
    async def test_update_rejects_oversized_content(self):
        """Test that update rejects content exceeding size limit."""
        import server
        from config import get_settings

        settings = get_settings()
        oversized_content = "x" * (settings.max_document_size + 1)

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test.py",
            content=oversized_content,
            chunk_type="text",
        )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_accepts_valid_size_content(self):
        """Test that ingest accepts content within size limit."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=1,
                embeddings_generated=1,
                source_path="/test.py",
                collection="default",
                chunk_ids=["id-1"],
            )
        )
        server._collections = mock_collections

        # Small content that's within limits
        # Pass all parameters to avoid Field default resolution issues
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
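Taken together, the `/mcp` tests document the JSON-RPC 2.0 contract the server exposes: a `{"jsonrpc": "2.0", "method": ..., "params": ..., "id": ...}` envelope, a `result` payload on success, and error codes -32600, -32601, and -32602 for bad requests, unknown methods, and invalid params. A minimal client-side sketch that follows this contract is shown below; the base URL and port are placeholders and not taken from this diff.

```python
# Client sketch for the /mcp JSON-RPC endpoint exercised by TestMCPRPCEndpoint.
# Only the envelope shape and error codes come from the tests; the URL is a placeholder.
import httpx


def search_knowledge(base_url: str, project_id: str, agent_id: str, query: str) -> dict:
    payload = {
        "jsonrpc": "2.0",
        "method": "search_knowledge",
        "params": {"project_id": project_id, "agent_id": agent_id, "query": query},
        "id": 1,
    }
    response = httpx.post(f"{base_url}/mcp", json=payload, timeout=30.0)
    body = response.json()
    if "error" in body:
        # -32600 invalid request, -32601 method not found, -32602 invalid params
        raise RuntimeError(f"JSON-RPC error {body['error']['code']}: {body['error']['message']}")
    return body["result"]


# Example usage (placeholder URL):
# results = search_knowledge("http://localhost:8000", "proj-123", "agent-456", "hello function")
```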
2026  mcp-servers/knowledge-base/uv.lock  (generated)  Normal file
File diff suppressed because it is too large.