forked from cardosofelipe/fast-next-template
feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search:
- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
1
mcp-servers/knowledge-base/tests/__init__.py
Normal file
1
mcp-servers/knowledge-base/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Knowledge Base MCP Server."""
|
||||
282
mcp-servers/knowledge-base/tests/conftest.py
Normal file
282
mcp-servers/knowledge-base/tests/conftest.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Test fixtures for Knowledge Base MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Put the directory that contains tests/ at the front of the import path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Test-mode flag and service URLs must be set before project modules are imported.
os.environ.update(
    {
        "IS_TEST": "true",
        "KB_DATABASE_URL": "postgresql://test:test@localhost:5432/test",
        "KB_REDIS_URL": "redis://localhost:6379/0",
        "KB_LLM_GATEWAY_URL": "http://localhost:8001",
    }
)
|
||||
|
||||
|
||||
@pytest.fixture
def settings():
    """Return a ``Settings`` object built from explicit test values.

    ``reset_settings()`` is called first, then a new ``Settings`` is
    constructed directly from the keyword values below.
    """
    from config import Settings, reset_settings

    reset_settings()

    endpoints = {
        "database_url": "postgresql://test:test@localhost:5432/test",
        "redis_url": "redis://localhost:6379/0",
        "llm_gateway_url": "http://localhost:8001",
    }
    chunk_limits = {
        "code_chunk_size": 500,
        "code_chunk_overlap": 50,
        "markdown_chunk_size": 800,
        "markdown_chunk_overlap": 100,
        "text_chunk_size": 400,
        "text_chunk_overlap": 50,
    }
    return Settings(
        host="127.0.0.1",
        port=8002,
        debug=True,
        embedding_dimension=1536,
        **endpoints,
        **chunk_limits,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_database():
    """Return a ``MagicMock`` spec'd on ``DatabaseManager`` with async stubs.

    Every stubbed method is an ``AsyncMock``; most resolve to a fixed canned
    value, the rest keep ``AsyncMock``'s default return value.
    """
    from database import DatabaseManager

    db = MagicMock(spec=DatabaseManager)
    db._pool = MagicMock()
    db.acquire = MagicMock(return_value=AsyncMock())

    # Method name -> value the awaited call resolves to.
    canned_results = {
        "store_embedding": "test-id-123",
        "store_embeddings_batch": ["id-1", "id-2"],
        "semantic_search": [],
        "keyword_search": [],
        "delete_by_source": 1,
        "delete_collection": 5,
        "delete_by_ids": 2,
        "list_collections": [],
        "cleanup_expired": 0,
    }
    for method_name, resolved in canned_results.items():
        setattr(db, method_name, AsyncMock(return_value=resolved))

    # Stubs without a configured return value.
    for method_name in ("initialize", "close", "get_collection_stats"):
        setattr(db, method_name, AsyncMock())

    return db
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embeddings():
    """Return a ``MagicMock`` spec'd on ``EmbeddingGenerator``.

    ``generate`` resolves to one 1536-element vector; ``generate_batch``
    resolves to one such vector per input text.
    """
    from embeddings import EmbeddingGenerator

    def fake_embedding() -> list[float]:
        # Constant-valued vector with 1536 entries.
        return [0.1] * 1536

    def fake_batch(texts, **_kwargs):
        return [fake_embedding() for _ in texts]

    generator = MagicMock(spec=EmbeddingGenerator)
    generator.initialize = AsyncMock()
    generator.close = AsyncMock()
    generator.generate = AsyncMock(return_value=fake_embedding())
    generator.generate_batch = AsyncMock(side_effect=fake_batch)
    return generator
|
||||
|
||||
|
||||
@pytest.fixture
def mock_redis():
    """Return a ``fakeredis`` async client to stand in for Redis."""
    from fakeredis.aioredis import FakeRedis

    return FakeRedis()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_python_code():
    """Return a small Python module, as text, for chunking tests.

    The sample holds a module docstring, two imports, a class with three
    methods, a plain function and an async function. Indentation inside the
    literal is significant: it keeps the sample valid Python.
    """
    return '''"""Sample module for testing."""

import os
from typing import Any


class Calculator:
    """A simple calculator class."""

    def __init__(self, initial: int = 0) -> None:
        """Initialize calculator."""
        self.value = initial

    def add(self, x: int) -> int:
        """Add a value."""
        self.value += x
        return self.value

    def subtract(self, x: int) -> int:
        """Subtract a value."""
        self.value -= x
        return self.value


def helper_function(data: dict[str, Any]) -> str:
    """A helper function."""
    return str(data)


async def async_function() -> None:
    """An async function."""
    pass
'''
|
||||
|
||||
|
||||
@pytest.fixture
def sample_markdown():
    """Return a Markdown document, as text, for chunking tests.

    The sample has three heading levels, ordered and unordered lists, and two
    fenced code blocks (``python`` and ``json``) whose inner indentation is
    kept intact.
    """
    return '''# Project Documentation

This is the main documentation for our project.

## Getting Started

To get started, follow these steps:

1. Install dependencies
2. Configure settings
3. Run the application

### Prerequisites

You'll need the following installed:

- Python 3.12+
- PostgreSQL
- Redis

```python
# Example code
def main():
    print("Hello, World!")
```

## API Reference

### Search Endpoint

The search endpoint allows you to query the knowledge base.

**Endpoint:** `POST /api/search`

**Request:**
```json
{
    "query": "your search query",
    "limit": 10
}
```

## Contributing

We welcome contributions! Please see our contributing guide.
'''
|
||||
|
||||
|
||||
@pytest.fixture
def sample_text():
    """Return four plain-text paragraphs, separated by blank lines, for chunking tests."""
    prose = '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.

Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.

When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.

The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''
    return prose
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunk():
    """Return a ``Chunk`` describing a two-line Python function."""
    from models import Chunk, ChunkType, FileType

    chunk_fields = {
        "content": "def hello():\n print('Hello')",
        "chunk_type": ChunkType.CODE,
        "file_type": FileType.PYTHON,
        "source_path": "/test/hello.py",
        "start_line": 1,
        "end_line": 2,
        "metadata": {"function": "hello"},
        "token_count": 15,
    }
    return Chunk(**chunk_fields)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_embedding():
    """Return a ``KnowledgeEmbedding`` with a 1536-element vector and UTC timestamps."""
    from models import ChunkType, FileType, KnowledgeEmbedding

    vector = [0.1] * 1536
    # The two timestamps are taken separately, one call each.
    created = datetime.now(UTC)
    updated = datetime.now(UTC)

    return KnowledgeEmbedding(
        id="test-id-123",
        project_id="proj-123",
        collection="default",
        content="def hello():\n print('Hello')",
        embedding=vector,
        chunk_type=ChunkType.CODE,
        source_path="/test/hello.py",
        start_line=1,
        end_line=2,
        file_type=FileType.PYTHON,
        metadata={"function": "hello"},
        content_hash="abc123",
        created_at=created,
        updated_at=updated,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_ingest_request():
    """Return an ``IngestRequest`` for a small Python snippet in the default collection."""
    from models import ChunkType, FileType, IngestRequest

    payload = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "content": "def hello():\n print('Hello')",
        "source_path": "/test/hello.py",
        "collection": "default",
        "chunk_type": ChunkType.CODE,
        "file_type": FileType.PYTHON,
        "metadata": {"test": True},
    }
    return IngestRequest(**payload)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_search_request():
    """Return a hybrid ``SearchRequest`` against the default collection."""
    from models import SearchRequest, SearchType

    query_fields = {
        "project_id": "proj-123",
        "agent_id": "agent-456",
        "query": "hello function",
        "search_type": SearchType.HYBRID,
        "collection": "default",
        "limit": 10,
        "threshold": 0.7,
    }
    return SearchRequest(**query_fields)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_delete_request():
    """Return a ``DeleteRequest`` that targets a single source path."""
    from models import DeleteRequest

    delete_request = DeleteRequest(
        project_id="proj-123",
        agent_id="agent-456",
        source_path="/test/hello.py",
    )
    return delete_request
|
||||
422
mcp-servers/knowledge-base/tests/test_chunking.py
Normal file
422
mcp-servers/knowledge-base/tests/test_chunking.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""Tests for chunking module."""
|
||||
|
||||
|
||||
|
||||
class TestBaseChunker:
    """Exercises the token helpers, using ``TextChunker`` as the concrete chunker."""

    def test_count_tokens(self, settings):
        """A short greeting yields a small, positive token count."""
        from chunking.text import TextChunker

        text_chunker = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        token_total = text_chunker.count_tokens("Hello, world!")

        # Expected to be about 3-4 tokens.
        assert 0 < token_total < 10

    def test_truncate_to_tokens(self, settings):
        """Truncated text never exceeds the requested token limit."""
        from chunking.text import TextChunker

        text_chunker = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)
        oversized = "word " * 1000

        shortened = text_chunker.truncate_to_tokens(oversized, 10)

        assert text_chunker.count_tokens(shortened) <= 10
|
||||
|
||||
|
||||
class TestCodeChunker:
    """Tests for code chunker."""

    def test_chunk_python_code(self, settings, sample_python_code):
        """Chunking the Python sample yields CODE chunks tagged as PYTHON."""
        from chunking.code import CodeChunker
        from models import ChunkType, FileType

        chunker = CodeChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        chunks = chunker.chunk(
            content=sample_python_code,
            source_path="/test/sample.py",
            file_type=FileType.PYTHON,
        )

        assert len(chunks) > 0
        assert all(c.chunk_type == ChunkType.CODE for c in chunks)
        assert all(c.file_type == FileType.PYTHON for c in chunks)

    def test_preserves_function_boundaries(self, settings):
        """Every chunk of a two-function module carries a valid line range.

        The sample is kept as valid, indented Python so the chunker sees real
        function bodies.
        """
        from chunking.code import CodeChunker
        from models import FileType

        code = '''def function_one():
    """First function."""
    return 1

def function_two():
    """Second function."""
    return 2
'''

        chunker = CodeChunker(chunk_size=100, chunk_overlap=10, settings=settings)

        chunks = chunker.chunk(
            content=code,
            source_path="/test/funcs.py",
            file_type=FileType.PYTHON,
        )

        # Each function should ideally be in its own chunk; only the weaker
        # "at least one chunk" property is asserted.
        assert len(chunks) >= 1
        for chunk in chunks:
            # Check chunks have line numbers, in order.
            assert chunk.start_line is not None
            assert chunk.end_line is not None
            assert chunk.start_line <= chunk.end_line

    def test_handles_empty_content(self, settings):
        """Empty content produces an empty chunk list."""
        from chunking.code import CodeChunker

        chunker = CodeChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        chunks = chunker.chunk(content="", source_path="/test/empty.py")

        assert chunks == []

    def test_chunk_type_is_code(self, settings):
        """The ``chunk_type`` property returns CODE."""
        from chunking.code import CodeChunker
        from models import ChunkType

        chunker = CodeChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        assert chunker.chunk_type == ChunkType.CODE
|
||||
|
||||
|
||||
class TestMarkdownChunker:
    """Tests for markdown chunker."""

    def test_chunk_markdown(self, settings, sample_markdown):
        """Chunking the Markdown sample yields only MARKDOWN chunks."""
        from chunking.markdown import MarkdownChunker
        from models import ChunkType, FileType

        chunker = MarkdownChunker(chunk_size=800, chunk_overlap=100, settings=settings)

        chunks = chunker.chunk(
            content=sample_markdown,
            source_path="/test/docs.md",
            file_type=FileType.MARKDOWN,
        )

        assert len(chunks) > 0
        assert all(c.chunk_type == ChunkType.MARKDOWN for c in chunks)

    def test_respects_heading_hierarchy(self, settings):
        """A document with nested headings yields non-empty chunks."""
        from chunking.markdown import MarkdownChunker

        markdown = '''# Main Title

Introduction paragraph.

## Section One

Content for section one.

### Subsection

More detailed content.

## Section Two

Content for section two.
'''

        chunker = MarkdownChunker(chunk_size=200, chunk_overlap=20, settings=settings)

        chunks = chunker.chunk(content=markdown, source_path="/test/docs.md")

        # Only presence and non-emptiness are asserted here, not the exact
        # section split.
        assert len(chunks) >= 1
        for chunk in chunks:
            assert len(chunk.content) > 0

    def test_handles_code_blocks(self, settings):
        """A fenced code block survives chunking.

        The fenced sample keeps its inner indentation so it is a real Python
        snippet.
        """
        from chunking.markdown import MarkdownChunker

        markdown = '''# Code Example

Here's some code:

```python
def hello():
    print("Hello, World!")
```

End of example.
'''

        chunker = MarkdownChunker(chunk_size=500, chunk_overlap=50, settings=settings)

        chunks = chunker.chunk(content=markdown, source_path="/test/code.md")

        # Code blocks should be preserved somewhere in the combined output.
        assert len(chunks) >= 1
        full_content = " ".join(c.content for c in chunks)
        assert "```python" in full_content or "def hello" in full_content

    def test_chunk_type_is_markdown(self, settings):
        """The ``chunk_type`` property returns MARKDOWN."""
        from chunking.markdown import MarkdownChunker
        from models import ChunkType

        chunker = MarkdownChunker(chunk_size=800, chunk_overlap=100, settings=settings)

        assert chunker.chunk_type == ChunkType.MARKDOWN
|
||||
|
||||
|
||||
class TestTextChunker:
    """Covers plain-text chunking through ``TextChunker``."""

    def test_chunk_text(self, settings, sample_text):
        """The multi-paragraph sample produces only TEXT chunks."""
        from chunking.text import TextChunker
        from models import ChunkType

        text_chunker = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        pieces = text_chunker.chunk(content=sample_text, source_path="/test/docs.txt")

        assert len(pieces) > 0
        assert all(piece.chunk_type == ChunkType.TEXT for piece in pieces)

    def test_respects_paragraph_boundaries(self, settings):
        """Three short paragraphs produce at least one chunk."""
        from chunking.text import TextChunker

        text = '''First paragraph with some content.

Second paragraph with different content.

Third paragraph to test chunking behavior.
'''

        text_chunker = TextChunker(chunk_size=100, chunk_overlap=10, settings=settings)

        pieces = text_chunker.chunk(content=text, source_path="/test/text.txt")

        assert len(pieces) >= 1

    def test_handles_single_paragraph(self, settings):
        """A paragraph that fits in one chunk comes back unchanged as one chunk."""
        from chunking.text import TextChunker

        short_text = "This is a short paragraph."
        text_chunker = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        pieces = text_chunker.chunk(content=short_text, source_path="/test/short.txt")

        assert len(pieces) == 1
        assert pieces[0].content == short_text

    def test_chunk_type_is_text(self, settings):
        """The ``chunk_type`` property returns TEXT."""
        from chunking.text import TextChunker
        from models import ChunkType

        text_chunker = TextChunker(chunk_size=400, chunk_overlap=50, settings=settings)

        assert text_chunker.chunk_type == ChunkType.TEXT
|
||||
|
||||
|
||||
class TestChunkerFactory:
    """Covers chunker selection and caching in ``ChunkerFactory``."""

    def test_get_code_chunker(self, settings):
        """PYTHON maps to ``CodeChunker``."""
        from chunking.base import ChunkerFactory
        from chunking.code import CodeChunker
        from models import FileType

        selected = ChunkerFactory(settings=settings).get_chunker(file_type=FileType.PYTHON)

        assert isinstance(selected, CodeChunker)

    def test_get_markdown_chunker(self, settings):
        """MARKDOWN maps to ``MarkdownChunker``."""
        from chunking.base import ChunkerFactory
        from chunking.markdown import MarkdownChunker
        from models import FileType

        selected = ChunkerFactory(settings=settings).get_chunker(file_type=FileType.MARKDOWN)

        assert isinstance(selected, MarkdownChunker)

    def test_get_text_chunker(self, settings):
        """TEXT maps to ``TextChunker``."""
        from chunking.base import ChunkerFactory
        from chunking.text import TextChunker
        from models import FileType

        selected = ChunkerFactory(settings=settings).get_chunker(file_type=FileType.TEXT)

        assert isinstance(selected, TextChunker)

    def test_get_chunker_for_path(self, settings):
        """The file extension picks both the chunker and the reported file type."""
        from chunking.base import ChunkerFactory
        from chunking.code import CodeChunker
        from chunking.markdown import MarkdownChunker
        from models import FileType

        factory = ChunkerFactory(settings=settings)

        py_chunker, py_type = factory.get_chunker_for_path("/test/file.py")
        assert isinstance(py_chunker, CodeChunker)
        assert py_type == FileType.PYTHON

        md_chunker, md_type = factory.get_chunker_for_path("/test/docs.md")
        assert isinstance(md_chunker, MarkdownChunker)
        assert md_type == FileType.MARKDOWN

    def test_chunk_content(self, settings, sample_python_code):
        """``chunk_content`` on a ``.py`` path yields CODE chunks."""
        from chunking.base import ChunkerFactory
        from models import ChunkType

        factory = ChunkerFactory(settings=settings)

        pieces = factory.chunk_content(
            content=sample_python_code,
            source_path="/test/sample.py",
        )

        assert len(pieces) > 0
        assert all(piece.chunk_type == ChunkType.CODE for piece in pieces)

    def test_default_to_text_chunker(self, settings):
        """With no arguments ``get_chunker`` returns a ``TextChunker``."""
        from chunking.base import ChunkerFactory
        from chunking.text import TextChunker

        fallback = ChunkerFactory(settings=settings).get_chunker()

        assert isinstance(fallback, TextChunker)

    def test_chunker_caching(self, settings):
        """Two requests for the same file type return the same instance."""
        from chunking.base import ChunkerFactory
        from models import FileType

        factory = ChunkerFactory(settings=settings)

        first = factory.get_chunker(file_type=FileType.PYTHON)
        second = factory.get_chunker(file_type=FileType.PYTHON)

        assert first is second
|
||||
|
||||
|
||||
class TestGlobalChunkerFactory:
    """Covers the module-level chunker factory accessors."""

    def test_get_chunker_factory_singleton(self):
        """Consecutive ``get_chunker_factory`` calls return the same object."""
        from chunking.base import get_chunker_factory, reset_chunker_factory

        reset_chunker_factory()

        first = get_chunker_factory()
        second = get_chunker_factory()

        assert first is second

    def test_reset_chunker_factory(self):
        """After ``reset_chunker_factory`` a different object is returned."""
        from chunking.base import get_chunker_factory, reset_chunker_factory

        before_reset = get_chunker_factory()
        reset_chunker_factory()
        after_reset = get_chunker_factory()

        assert before_reset is not after_reset
|
||||
240
mcp-servers/knowledge-base/tests/test_collection_manager.py
Normal file
240
mcp-servers/knowledge-base/tests/test_collection_manager.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Tests for collection manager module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestCollectionManager:
    """Covers ``CollectionManager`` against mocked database, embeddings and chunker."""

    @pytest.fixture
    def collection_manager(self, settings, mock_database, mock_embeddings):
        """Build a ``CollectionManager`` whose chunker factory yields one CODE chunk."""
        from chunking.base import ChunkerFactory
        from collection_manager import CollectionManager
        from models import Chunk, ChunkType

        single_chunk = Chunk(
            content="def hello(): pass",
            chunk_type=ChunkType.CODE,
            token_count=10,
        )
        factory_stub = MagicMock(spec=ChunkerFactory)
        factory_stub.chunk_content.return_value = [single_chunk]

        return CollectionManager(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
            chunker_factory=factory_stub,
        )

    @pytest.mark.asyncio
    async def test_ingest_content(self, collection_manager, sample_ingest_request):
        """Ingesting one chunk reports one chunk, one embedding and one id."""
        outcome = await collection_manager.ingest(sample_ingest_request)

        assert outcome.success is True
        assert outcome.chunks_created == 1
        assert outcome.embeddings_generated == 1
        assert len(outcome.chunk_ids) == 1
        assert outcome.collection == "default"

    @pytest.mark.asyncio
    async def test_ingest_empty_content(self, collection_manager):
        """When the chunker yields nothing, ingest succeeds with zero counts."""
        from models import IngestRequest

        # Override the fixture's chunker so no chunks come back.
        collection_manager._chunker_factory.chunk_content.return_value = []
        empty_request = IngestRequest(
            project_id="proj-123",
            agent_id="agent-456",
            content="",
        )

        outcome = await collection_manager.ingest(empty_request)

        assert outcome.success is True
        assert outcome.chunks_created == 0
        assert outcome.embeddings_generated == 0

    @pytest.mark.asyncio
    async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
        """An embedding failure is reported in the result, not raised."""
        failure = Exception("Embedding error")
        collection_manager._embeddings.generate_batch.side_effect = failure

        outcome = await collection_manager.ingest(sample_ingest_request)

        assert outcome.success is False
        assert "Embedding error" in outcome.error

    @pytest.mark.asyncio
    async def test_delete_by_source(self, collection_manager, sample_delete_request):
        """A source-path delete goes through ``delete_by_source``."""
        outcome = await collection_manager.delete(sample_delete_request)

        assert outcome.success is True
        # The mocked delete_by_source resolves to 1.
        assert outcome.chunks_deleted == 1
        collection_manager._database.delete_by_source.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_collection(self, collection_manager):
        """A collection delete goes through ``delete_collection``."""
        from models import DeleteRequest

        by_collection = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            collection="to-delete",
        )

        outcome = await collection_manager.delete(by_collection)

        assert outcome.success is True
        collection_manager._database.delete_collection.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_by_ids(self, collection_manager):
        """A chunk-id delete goes through ``delete_by_ids``."""
        from models import DeleteRequest

        by_ids = DeleteRequest(
            project_id="proj-123",
            agent_id="agent-456",
            chunk_ids=["id-1", "id-2"],
        )

        outcome = await collection_manager.delete(by_ids)

        assert outcome.success is True
        collection_manager._database.delete_by_ids.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_no_target(self, collection_manager):
        """A delete request with no target fails with a "Must specify" error."""
        from models import DeleteRequest

        untargeted = DeleteRequest(project_id="proj-123", agent_id="agent-456")

        outcome = await collection_manager.delete(untargeted)

        assert outcome.success is False
        assert "Must specify" in outcome.error

    @pytest.mark.asyncio
    async def test_list_collections(self, collection_manager):
        """Two collections from the database are reported as two in the response."""
        from models import CollectionInfo

        first = CollectionInfo(
            name="collection-1",
            project_id="proj-123",
            chunk_count=100,
            total_tokens=50000,
            file_types=["python"],
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        second = CollectionInfo(
            name="collection-2",
            project_id="proj-123",
            chunk_count=50,
            total_tokens=25000,
            file_types=["javascript"],
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        collection_manager._database.list_collections.return_value = [first, second]

        listing = await collection_manager.list_collections("proj-123")

        assert listing.project_id == "proj-123"
        assert listing.total_collections == 2
        assert len(listing.collections) == 2

    @pytest.mark.asyncio
    async def test_get_collection_stats(self, collection_manager):
        """Stats come from the database, which is queried by project and collection."""
        from models import CollectionStats

        canned_stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
        )
        database = collection_manager._database
        database.get_collection_stats.return_value = canned_stats

        stats = await collection_manager.get_collection_stats("proj-123", "test-collection")

        assert stats.chunk_count == 100
        assert stats.unique_sources == 10
        database.get_collection_stats.assert_called_once_with("proj-123", "test-collection")

    @pytest.mark.asyncio
    async def test_update_document(self, collection_manager):
        """Updating a document deletes by source once and reports success."""
        outcome = await collection_manager.update_document(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
        )

        # Should delete first, then ingest.
        collection_manager._database.delete_by_source.assert_called_once()
        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_cleanup_expired(self, collection_manager):
        """``cleanup_expired`` returns the count reported by the database."""
        database = collection_manager._database
        database.cleanup_expired.return_value = 10

        removed = await collection_manager.cleanup_expired()

        assert removed == 10
        database.cleanup_expired.assert_called_once()
|
||||
|
||||
|
||||
class TestGlobalCollectionManager:
    """Covers the module-level collection manager accessors."""

    def test_get_collection_manager_singleton(self):
        """Consecutive ``get_collection_manager`` calls return the same object."""
        from collection_manager import get_collection_manager, reset_collection_manager

        reset_collection_manager()

        first = get_collection_manager()
        second = get_collection_manager()

        assert first is second

    def test_reset_collection_manager(self):
        """After ``reset_collection_manager`` a different object is returned."""
        from collection_manager import get_collection_manager, reset_collection_manager

        before_reset = get_collection_manager()
        reset_collection_manager()
        after_reset = get_collection_manager()

        assert before_reset is not after_reset
|
||||
104
mcp-servers/knowledge-base/tests/test_config.py
Normal file
104
mcp-servers/knowledge-base/tests/test_config.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for configuration module."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class TestSettings:
    """Tests for Settings class."""

    def test_default_values(self, settings):
        """Check headline values on the ``settings`` fixture."""
        assert settings.port == 8002
        assert settings.embedding_dimension == 1536
        assert settings.code_chunk_size == 500
        assert settings.search_default_limit == 10

    def test_env_prefix(self):
        """Check that the ``KB_PORT`` environment variable sets ``Settings.port``.

        ``KB_PORT`` is put back to its prior state in ``finally`` so a failing
        assertion cannot leak the override into later tests.
        """
        from config import Settings, reset_settings

        reset_settings()
        previous_port = os.environ.get("KB_PORT")
        os.environ["KB_PORT"] = "9999"
        try:
            settings = Settings()
            assert settings.port == 9999
        finally:
            # Restore a pre-existing value instead of unconditionally deleting.
            if previous_port is None:
                del os.environ["KB_PORT"]
            else:
                os.environ["KB_PORT"] = previous_port
            reset_settings()

    def test_embedding_settings(self, settings):
        """Test embedding-related settings."""
        assert settings.embedding_model == "text-embedding-3-large"
        assert settings.embedding_batch_size == 100
        assert settings.embedding_cache_ttl == 86400

    def test_chunking_settings(self, settings):
        """Test chunking-related settings."""
        assert settings.code_chunk_size == 500
        assert settings.code_chunk_overlap == 50
        assert settings.markdown_chunk_size == 800
        assert settings.markdown_chunk_overlap == 100
        assert settings.text_chunk_size == 400
        assert settings.text_chunk_overlap == 50

    def test_search_settings(self, settings):
        """Test search-related settings."""
        assert settings.search_default_limit == 10
        assert settings.search_max_limit == 100
        assert settings.semantic_threshold == 0.7
        assert settings.hybrid_semantic_weight == 0.7
        assert settings.hybrid_keyword_weight == 0.3
|
||||
|
||||
|
||||
class TestGetSettings:
    """Identity checks for the object returned by ``get_settings``."""

    def test_returns_singleton(self):
        """Two calls after a reset hand back the same object."""
        from config import get_settings, reset_settings

        reset_settings()
        first, second = get_settings(), get_settings()

        assert first is second

    def test_reset_creates_new_instance(self):
        """A reset between two calls yields a different object."""
        from config import get_settings, reset_settings

        before_reset = get_settings()
        reset_settings()
        after_reset = get_settings()

        assert before_reset is not after_reset
|
||||
|
||||
|
||||
class TestIsTestMode:
    """Tests for is_test_mode function.

    Both tests temporarily modify ``IS_TEST`` in ``os.environ``. The original
    value is always restored in a ``finally`` block so that a failing
    assertion cannot change the environment seen by later tests.
    """

    def test_returns_true_when_set(self):
        """Test returns True when IS_TEST is set."""
        from config import is_test_mode

        old_value = os.environ.get("IS_TEST")
        os.environ["IS_TEST"] = "true"

        try:
            assert is_test_mode() is True
        finally:
            # ``is None`` (not truthiness) so an empty-string value is
            # restored instead of being deleted.
            if old_value is None:
                os.environ.pop("IS_TEST", None)
            else:
                os.environ["IS_TEST"] = old_value

    def test_returns_false_when_not_set(self):
        """Test returns False when IS_TEST is not set."""
        from config import is_test_mode

        # pop() removes the variable whatever its value, including "".
        old_value = os.environ.pop("IS_TEST", None)

        try:
            assert is_test_mode() is False
        finally:
            if old_value is not None:
                os.environ["IS_TEST"] = old_value
|
||||
245
mcp-servers/knowledge-base/tests/test_embeddings.py
Normal file
245
mcp-servers/knowledge-base/tests/test_embeddings.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Tests for embedding generation module."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEmbeddingGenerator:
    """Exercise EmbeddingGenerator with a mocked Redis and a mocked HTTP client."""

    @staticmethod
    def _gateway_response(vectors):
        """Build a fake HTTP 200 response whose JSON body wraps *vectors*.

        The payload has the shape used throughout these tests:
        ``result -> content[0] -> text``, where ``text`` is a JSON string
        holding an ``embeddings`` list.
        """
        fake = MagicMock()
        fake.status_code = 200
        fake.raise_for_status = MagicMock()
        encoded = json.dumps({"embeddings": vectors})
        fake.json.return_value = {"result": {"content": [{"text": encoded}]}}
        return fake

    @pytest.fixture
    def mock_http_response(self):
        """Fake response carrying one 1536-dimensional vector."""
        return self._gateway_response([[0.1] * 1536])

    @pytest.mark.asyncio
    async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
        """One text in, one 1536-long vector out, exactly one POST issued."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        # Stand-in for the HTTP client.
        http = AsyncMock()
        http.post.return_value = mock_http_response
        gen._http_client = http

        vector = await gen.generate(
            text="Hello, world!", project_id="proj-123", agent_id="agent-456"
        )

        assert len(vector) == 1536
        http.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Three texts in, three 1536-long vectors out."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        # HTTP client stand-in answering with three vectors at once.
        http = AsyncMock()
        http.post.return_value = self._gateway_response(
            [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
        )
        gen._http_client = http

        vectors = await gen.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(vectors) == 3
        assert all(len(vec) == 1536 for vec in vectors)

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """A value already stored under the cache key is returned without any POST."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        # Seed the cache with a recognisable vector for this exact text.
        seeded_key = gen._cache_key("Hello, world!")
        await mock_redis.setex(seeded_key, 3600, json.dumps([0.5] * 1536))

        # HTTP stand-in that must stay untouched.
        http = AsyncMock()
        gen._http_client = http

        vector = await gen.generate(
            text="Hello, world!", project_id="proj-123", agent_id="agent-456"
        )

        # The seeded value comes back and no request was made.
        assert len(vector) == 1536
        assert vector[0] == 0.5
        http.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Text absent from the cache triggers exactly one POST."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        http = AsyncMock()
        http.post.return_value = mock_http_response
        gen._http_client = http

        vector = await gen.generate(
            text="New text not in cache", project_id="proj-123", agent_id="agent-456"
        )

        assert len(vector) == 1536
        http.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Keys are stable per text, differ across texts, and start with ``kb:emb:``."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)

        hello_key = gen._cache_key("Hello")
        hello_key_again = gen._cache_key("Hello")
        world_key = gen._cache_key("World")

        assert hello_key == hello_key_again
        assert hello_key != world_key
        assert hello_key.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """A 768-long vector from the gateway raises EmbeddingDimensionMismatchError."""
        from embeddings import EmbeddingGenerator
        from exceptions import EmbeddingDimensionMismatchError

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        # HTTP stand-in answering with a vector of the wrong length.
        http = AsyncMock()
        http.post.return_value = self._gateway_response([[0.1] * 768])
        gen._http_client = http

        with pytest.raises(EmbeddingDimensionMismatchError):
            await gen.generate(
                text="Test text", project_id="proj-123", agent_id="agent-456"
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """An empty list of texts yields an empty list of vectors."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        vectors = await gen.generate_batch(
            texts=[], project_id="proj-123", agent_id="agent-456"
        )

        assert vectors == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """``initialize`` sets ``_http_client``; ``close`` clears it again."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)

        # Replace the Redis factory so initialization needs no real server.
        with patch("embeddings.redis.from_url") as redis_factory:
            fake_redis = AsyncMock()
            fake_redis.ping = AsyncMock()
            redis_factory.return_value = fake_redis

            await gen.initialize()
            assert gen._http_client is not None

            await gen.close()
            assert gen._http_client is None
|
||||
|
||||
|
||||
class TestGlobalEmbeddingGenerator:
    """Identity checks for the object returned by ``get_embedding_generator``."""

    def test_get_embedding_generator_singleton(self):
        """Two lookups after a reset hand back the same object."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        reset_embedding_generator()
        first, second = get_embedding_generator(), get_embedding_generator()

        assert first is second

    def test_reset_embedding_generator(self):
        """A reset between two lookups yields a different object."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        before_reset = get_embedding_generator()
        reset_embedding_generator()
        after_reset = get_embedding_generator()

        assert before_reset is not after_reset
|
||||
307
mcp-servers/knowledge-base/tests/test_exceptions.py
Normal file
307
mcp-servers/knowledge-base/tests/test_exceptions.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Tests for exception classes."""
|
||||
|
||||
|
||||
|
||||
class TestErrorCode:
    """Checks on the string values carried by ErrorCode members."""

    def test_error_code_values(self):
        """Each listed member exposes its ``KB_``-prefixed string as ``.value``."""
        from exceptions import ErrorCode

        expectations = [
            (ErrorCode.UNKNOWN_ERROR, "KB_UNKNOWN_ERROR"),
            (ErrorCode.DATABASE_CONNECTION_ERROR, "KB_DATABASE_CONNECTION_ERROR"),
            (ErrorCode.EMBEDDING_GENERATION_ERROR, "KB_EMBEDDING_GENERATION_ERROR"),
            (ErrorCode.CHUNKING_ERROR, "KB_CHUNKING_ERROR"),
            (ErrorCode.SEARCH_ERROR, "KB_SEARCH_ERROR"),
            (ErrorCode.COLLECTION_NOT_FOUND, "KB_COLLECTION_NOT_FOUND"),
            (ErrorCode.DOCUMENT_NOT_FOUND, "KB_DOCUMENT_NOT_FOUND"),
        ]

        for member, wire_value in expectations:
            assert member.value == wire_value
|
||||
|
||||
|
||||
class TestKnowledgeBaseError:
    """Checks on KnowledgeBaseError attributes and its dict/str/repr forms."""

    def test_basic_error(self):
        """Message and code are stored; details default to ``{}`` and cause to ``None``."""
        from exceptions import ErrorCode, KnowledgeBaseError

        err = KnowledgeBaseError(message="Something went wrong", code=ErrorCode.UNKNOWN_ERROR)

        assert err.message == "Something went wrong"
        assert err.code == ErrorCode.UNKNOWN_ERROR
        assert err.details == {}
        assert err.cause is None

    def test_error_with_details(self):
        """Entries passed as ``details`` can be read back by key."""
        from exceptions import ErrorCode, KnowledgeBaseError

        supplied = {"query": "SELECT * FROM table", "error_code": 42}
        err = KnowledgeBaseError(
            message="Query failed",
            code=ErrorCode.DATABASE_QUERY_ERROR,
            details=supplied,
        )

        assert err.details["query"] == "SELECT * FROM table"
        assert err.details["error_code"] == 42

    def test_error_with_cause(self):
        """The exception passed as ``cause`` is kept as the same object."""
        from exceptions import ErrorCode, KnowledgeBaseError

        root = ValueError("Original error")
        err = KnowledgeBaseError(
            message="Wrapped error", code=ErrorCode.INTERNAL_ERROR, cause=root
        )

        assert err.cause is root
        assert isinstance(err.cause, ValueError)

    def test_to_dict(self):
        """``to_dict`` exposes the code string, the message and the details."""
        from exceptions import ErrorCode, KnowledgeBaseError

        payload = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"field": "value"},
        ).to_dict()

        assert payload["error"] == "KB_INVALID_REQUEST"
        assert payload["message"] == "Test error"
        assert payload["details"]["field"] == "value"

    def test_str_representation(self):
        """``str()`` renders as ``[CODE] message``."""
        from exceptions import ErrorCode, KnowledgeBaseError

        err = KnowledgeBaseError(message="Test error", code=ErrorCode.INVALID_REQUEST)
        rendered = str(err)

        assert rendered == "[KB_INVALID_REQUEST] Test error"

    def test_repr_representation(self):
        """``repr()`` mentions the class name, the message and the code string."""
        from exceptions import ErrorCode, KnowledgeBaseError

        err = KnowledgeBaseError(
            message="Test error",
            code=ErrorCode.INVALID_REQUEST,
            details={"key": "value"},
        )
        rendered = repr(err)

        for fragment in ("KnowledgeBaseError", "Test error", "KB_INVALID_REQUEST"):
            assert fragment in rendered
|
||||
|
||||
|
||||
class TestDatabaseErrors:
    """Checks on the database exception classes."""

    def test_database_connection_error(self):
        """The connection error carries its own code and the supplied details."""
        from exceptions import DatabaseConnectionError, ErrorCode

        connection_info = {"host": "localhost", "port": 5432}
        err = DatabaseConnectionError(
            message="Cannot connect to database", details=connection_info
        )

        assert err.code == ErrorCode.DATABASE_CONNECTION_ERROR
        assert err.details["host"] == "localhost"

    def test_database_connection_error_default_message(self):
        """Constructed without arguments, the error uses its built-in message."""
        from exceptions import DatabaseConnectionError

        assert DatabaseConnectionError().message == "Failed to connect to database"

    def test_database_query_error(self):
        """The query error records the offending query under ``details["query"]``."""
        from exceptions import DatabaseQueryError, ErrorCode

        failing_sql = "SELECT * FROM missing_table"
        err = DatabaseQueryError(message="Query failed", query=failing_sql)

        assert err.code == ErrorCode.DATABASE_QUERY_ERROR
        assert err.details["query"] == "SELECT * FROM missing_table"
|
||||
|
||||
|
||||
class TestEmbeddingErrors:
    """Checks on the embedding exception classes."""

    def test_embedding_generation_error(self):
        """The generation error records ``texts_count`` in its details."""
        from exceptions import EmbeddingGenerationError, ErrorCode

        err = EmbeddingGenerationError(message="Failed to generate", texts_count=10)

        assert err.code == ErrorCode.EMBEDDING_GENERATION_ERROR
        assert err.details["texts_count"] == 10

    def test_embedding_dimension_mismatch(self):
        """The mismatch error reports both dimensions in message and details."""
        from exceptions import EmbeddingDimensionMismatchError, ErrorCode

        err = EmbeddingDimensionMismatchError(expected=1536, actual=768)

        assert err.code == ErrorCode.EMBEDDING_DIMENSION_MISMATCH
        for fragment in ("expected 1536", "got 768"):
            assert fragment in err.message
        assert err.details["expected_dimension"] == 1536
        assert err.details["actual_dimension"] == 768
|
||||
|
||||
|
||||
class TestChunkingErrors:
    """Checks on the chunking exception classes."""

    def test_unsupported_file_type_error(self):
        """The error records the rejected type and the list of supported ones."""
        from exceptions import ErrorCode, UnsupportedFileTypeError

        accepted = [".py", ".js", ".md"]
        err = UnsupportedFileTypeError(file_type=".xyz", supported_types=accepted)

        assert err.code == ErrorCode.UNSUPPORTED_FILE_TYPE
        assert err.details["file_type"] == ".xyz"
        assert len(err.details["supported_types"]) == 3

    def test_file_too_large_error(self):
        """The error records both the actual size and the allowed maximum."""
        from exceptions import ErrorCode, FileTooLargeError

        err = FileTooLargeError(file_size=10_000_000, max_size=1_000_000)

        assert err.code == ErrorCode.FILE_TOO_LARGE
        assert err.details["file_size"] == 10_000_000
        assert err.details["max_size"] == 1_000_000

    def test_encoding_error(self):
        """The error records the encoding name in its details."""
        from exceptions import EncodingError, ErrorCode

        err = EncodingError(message="Cannot decode file", encoding="utf-8")

        assert err.code == ErrorCode.ENCODING_ERROR
        assert err.details["encoding"] == "utf-8"
|
||||
|
||||
|
||||
class TestSearchErrors:
    """Checks on the search exception classes."""

    def test_invalid_search_type_error(self):
        """The error records the rejected search type and the valid choices."""
        from exceptions import ErrorCode, InvalidSearchTypeError

        choices = ["semantic", "keyword", "hybrid"]
        err = InvalidSearchTypeError(search_type="invalid", valid_types=choices)

        assert err.code == ErrorCode.INVALID_SEARCH_TYPE
        assert err.details["search_type"] == "invalid"
        assert len(err.details["valid_types"]) == 3

    def test_search_timeout_error(self):
        """The timeout value appears in both the details and the message."""
        from exceptions import ErrorCode, SearchTimeoutError

        err = SearchTimeoutError(timeout=30.0)

        assert err.code == ErrorCode.SEARCH_TIMEOUT
        assert err.details["timeout"] == 30.0
        assert "30" in err.message
|
||||
|
||||
|
||||
class TestCollectionErrors:
    """Checks on the collection exception classes."""

    def test_collection_not_found_error(self):
        """The error records the collection name and the project id."""
        from exceptions import CollectionNotFoundError, ErrorCode

        err = CollectionNotFoundError(collection="missing-collection", project_id="proj-123")
        recorded = err.details

        assert err.code == ErrorCode.COLLECTION_NOT_FOUND
        assert recorded["collection"] == "missing-collection"
        assert recorded["project_id"] == "proj-123"
|
||||
|
||||
|
||||
class TestDocumentErrors:
    """Checks on the document exception classes."""

    def test_document_not_found_error(self):
        """The error records the source path of the missing document."""
        from exceptions import DocumentNotFoundError, ErrorCode

        err = DocumentNotFoundError(source_path="/path/to/file.py", project_id="proj-123")

        assert err.code == ErrorCode.DOCUMENT_NOT_FOUND
        assert err.details["source_path"] == "/path/to/file.py"

    def test_invalid_document_error(self):
        """The error carries the INVALID_DOCUMENT code."""
        from exceptions import ErrorCode, InvalidDocumentError

        reason = {"reason": "no content"}
        err = InvalidDocumentError(message="Empty content", details=reason)

        assert err.code == ErrorCode.INVALID_DOCUMENT
|
||||
|
||||
|
||||
class TestProjectErrors:
    """Checks on the project exception classes."""

    def test_project_not_found_error(self):
        """The error records the missing project id in its details."""
        from exceptions import ErrorCode, ProjectNotFoundError

        err = ProjectNotFoundError(project_id="missing-proj")

        assert err.code == ErrorCode.PROJECT_NOT_FOUND
        assert err.details["project_id"] == "missing-proj"

    def test_project_access_denied_error(self):
        """The error message names the project that was refused."""
        from exceptions import ErrorCode, ProjectAccessDeniedError

        err = ProjectAccessDeniedError(project_id="restricted-proj")

        assert err.code == ErrorCode.PROJECT_ACCESS_DENIED
        assert "restricted-proj" in err.message
|
||||
347
mcp-servers/knowledge-base/tests/test_models.py
Normal file
347
mcp-servers/knowledge-base/tests/test_models.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Tests for data models."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
class TestEnums:
    """Checks on the string values of the model enums."""

    def test_search_type_values(self):
        """SearchType members map to their lowercase names."""
        from models import SearchType

        pairs = (
            (SearchType.SEMANTIC, "semantic"),
            (SearchType.KEYWORD, "keyword"),
            (SearchType.HYBRID, "hybrid"),
        )
        for member, text in pairs:
            assert member.value == text

    def test_chunk_type_values(self):
        """ChunkType members map to their lowercase names."""
        from models import ChunkType

        pairs = (
            (ChunkType.CODE, "code"),
            (ChunkType.MARKDOWN, "markdown"),
            (ChunkType.TEXT, "text"),
            (ChunkType.DOCUMENTATION, "documentation"),
        )
        for member, text in pairs:
            assert member.value == text

    def test_file_type_values(self):
        """FileType members map to their lowercase names."""
        from models import FileType

        pairs = (
            (FileType.PYTHON, "python"),
            (FileType.JAVASCRIPT, "javascript"),
            (FileType.TYPESCRIPT, "typescript"),
            (FileType.MARKDOWN, "markdown"),
        )
        for member, text in pairs:
            assert member.value == text
|
||||
|
||||
|
||||
class TestFileExtensionMap:
    """Checks that FILE_EXTENSION_MAP resolves extensions to the right FileType."""

    def test_python_extensions(self):
        """``.py`` resolves to PYTHON."""
        from models import FILE_EXTENSION_MAP, FileType

        resolved = FILE_EXTENSION_MAP[".py"]
        assert resolved == FileType.PYTHON

    def test_javascript_extensions(self):
        """``.js`` and ``.jsx`` resolve to JAVASCRIPT."""
        from models import FILE_EXTENSION_MAP, FileType

        for extension in (".js", ".jsx"):
            assert FILE_EXTENSION_MAP[extension] == FileType.JAVASCRIPT

    def test_typescript_extensions(self):
        """``.ts`` and ``.tsx`` resolve to TYPESCRIPT."""
        from models import FILE_EXTENSION_MAP, FileType

        for extension in (".ts", ".tsx"):
            assert FILE_EXTENSION_MAP[extension] == FileType.TYPESCRIPT

    def test_markdown_extensions(self):
        """``.md`` and ``.mdx`` resolve to MARKDOWN."""
        from models import FILE_EXTENSION_MAP, FileType

        for extension in (".md", ".mdx"):
            assert FILE_EXTENSION_MAP[extension] == FileType.MARKDOWN
|
||||
|
||||
|
||||
class TestChunk:
    """Checks on the ``sample_chunk`` fixture and its dict form."""

    def test_chunk_creation(self, sample_chunk):
        """The fixture exposes the expected content, types, path, lines and token count."""
        from models import ChunkType, FileType

        chunk = sample_chunk

        assert chunk.content == "def hello():\n print('Hello')"
        assert chunk.chunk_type == ChunkType.CODE
        assert chunk.file_type == FileType.PYTHON
        assert chunk.source_path == "/test/hello.py"
        assert chunk.start_line == 1
        assert chunk.end_line == 2
        assert chunk.token_count == 15

    def test_chunk_to_dict(self, sample_chunk):
        """``to_dict`` yields plain values, with the enum fields as strings."""
        as_dict = sample_chunk.to_dict()

        wanted = [
            ("content", "def hello():\n print('Hello')"),
            ("chunk_type", "code"),
            ("file_type", "python"),
            ("source_path", "/test/hello.py"),
            ("start_line", 1),
            ("end_line", 2),
            ("token_count", 15),
        ]
        for key, value in wanted:
            assert as_dict[key] == value
|
||||
|
||||
|
||||
class TestKnowledgeEmbedding:
    """Checks on the ``sample_embedding`` fixture and its dict form."""

    def test_embedding_creation(self, sample_embedding):
        """The fixture exposes the expected ids, collection and vector length."""
        record = sample_embedding

        assert record.id == "test-id-123"
        assert record.project_id == "proj-123"
        assert record.collection == "default"
        assert len(record.embedding) == 1536

    def test_embedding_to_dict(self, sample_embedding):
        """``to_dict`` exposes the metadata fields and leaves out the vector."""
        as_dict = sample_embedding.to_dict()

        wanted = [
            ("id", "test-id-123"),
            ("project_id", "proj-123"),
            ("collection", "default"),
            ("chunk_type", "code"),
            ("file_type", "python"),
        ]
        for key, value in wanted:
            assert as_dict[key] == value
        # The dict form carries no "embedding" key.
        assert "embedding" not in as_dict
|
||||
|
||||
|
||||
class TestIngestRequest:
    """Checks on IngestRequest field values and defaults."""

    def test_ingest_request_creation(self, sample_ingest_request):
        """The fixture exposes the expected ids, types and collection."""
        from models import ChunkType, FileType

        req = sample_ingest_request

        assert req.project_id == "proj-123"
        assert req.agent_id == "agent-456"
        assert req.chunk_type == ChunkType.CODE
        assert req.file_type == FileType.PYTHON
        assert req.collection == "default"

    def test_ingest_request_defaults(self):
        """Only the required fields given: the optional ones take their defaults."""
        from models import ChunkType, IngestRequest

        required_only = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "content": "test content",
        }
        req = IngestRequest(**required_only)

        assert req.collection == "default"
        assert req.chunk_type == ChunkType.TEXT
        assert req.file_type is None
        assert req.metadata == {}
|
||||
|
||||
|
||||
class TestIngestResult:
    """Checks on IngestResult for a successful and a failed ingest."""

    def test_successful_result(self):
        """A success result keeps its counts and has no error."""
        from models import IngestResult

        created_ids = ["id1", "id2", "id3", "id4", "id5"]
        outcome = IngestResult(
            success=True,
            chunks_created=5,
            embeddings_generated=5,
            source_path="/test/file.py",
            collection="default",
            chunk_ids=created_ids,
        )

        assert outcome.success is True
        assert outcome.chunks_created == 5
        assert outcome.error is None

    def test_failed_result(self):
        """A failure result keeps its error text."""
        from models import IngestResult

        outcome = IngestResult(
            success=False,
            chunks_created=0,
            embeddings_generated=0,
            collection="default",
            chunk_ids=[],
            error="Something went wrong",
        )

        assert outcome.success is False
        assert outcome.error == "Something went wrong"
|
||||
|
||||
|
||||
class TestSearchRequest:
    """Checks on SearchRequest field values and defaults."""

    def test_search_request_creation(self, sample_search_request):
        """The fixture exposes the expected project, query, type, limit and threshold."""
        from models import SearchType

        req = sample_search_request

        assert req.project_id == "proj-123"
        assert req.query == "hello function"
        assert req.search_type == SearchType.HYBRID
        assert req.limit == 10
        assert req.threshold == 0.7

    def test_search_request_defaults(self):
        """Only the required fields given: the optional ones take their defaults."""
        from models import SearchRequest, SearchType

        required_only = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "query": "test query",
        }
        req = SearchRequest(**required_only)

        assert req.search_type == SearchType.HYBRID
        assert req.collection is None
        assert req.limit == 10
        assert req.threshold == 0.7
        assert req.file_types is None
|
||||
|
||||
|
||||
class TestSearchResult:
    """Checks on building a SearchResult from a KnowledgeEmbedding."""

    def test_from_embedding(self, sample_embedding):
        """``from_embedding`` copies the embedding's fields and attaches the score."""
        from models import SearchResult

        hit = SearchResult.from_embedding(sample_embedding, 0.95)

        wanted = [
            ("id", "test-id-123"),
            ("content", "def hello():\n print('Hello')"),
            ("score", 0.95),
            ("source_path", "/test/hello.py"),
            ("chunk_type", "code"),
            ("file_type", "python"),
        ]
        for attribute, value in wanted:
            assert getattr(hit, attribute) == value
|
||||
|
||||
|
||||
class TestSearchResponse:
    """Checks on SearchResponse construction."""

    def test_search_response(self):
        """A response built from two results keeps the query, results and timing."""
        from models import SearchResponse, SearchResult

        top_hit = SearchResult(
            id="id1",
            content="test content 1",
            score=0.95,
            chunk_type="code",
            collection="default",
        )
        second_hit = SearchResult(
            id="id2",
            content="test content 2",
            score=0.85,
            chunk_type="text",
            collection="default",
        )

        reply = SearchResponse(
            query="test query",
            search_type="hybrid",
            results=[top_hit, second_hit],
            total_results=2,
            search_time_ms=15.5,
        )

        assert reply.query == "test query"
        assert len(reply.results) == 2
        assert reply.search_time_ms == 15.5
|
||||
|
||||
|
||||
class TestDeleteRequest:
    """Checks on DeleteRequest for each way of selecting what to delete."""

    def test_delete_by_source(self, sample_delete_request):
        """The fixture targets a source path and leaves the other selectors unset."""
        req = sample_delete_request

        assert req.project_id == "proj-123"
        assert req.source_path == "/test/hello.py"
        assert req.collection is None
        assert req.chunk_ids is None

    def test_delete_by_collection(self):
        """Selecting by collection leaves the source path unset."""
        from models import DeleteRequest

        req = DeleteRequest(
            project_id="proj-123", agent_id="agent-456", collection="to-delete"
        )

        assert req.collection == "to-delete"
        assert req.source_path is None

    def test_delete_by_ids(self):
        """Selecting by chunk ids keeps all the supplied ids."""
        from models import DeleteRequest

        targets = ["id1", "id2", "id3"]
        req = DeleteRequest(project_id="proj-123", agent_id="agent-456", chunk_ids=targets)

        assert len(req.chunk_ids) == 3
|
||||
|
||||
|
||||
class TestCollectionInfo:
    """Checks on CollectionInfo construction."""

    def test_collection_info(self):
        """A populated CollectionInfo keeps its name, chunk count and file types."""
        from models import CollectionInfo

        created = datetime.now(UTC)
        updated = datetime.now(UTC)
        summary = CollectionInfo(
            name="test-collection",
            project_id="proj-123",
            chunk_count=100,
            total_tokens=50000,
            file_types=["python", "javascript"],
            created_at=created,
            updated_at=updated,
        )

        assert summary.name == "test-collection"
        assert summary.chunk_count == 100
        assert len(summary.file_types) == 2
|
||||
|
||||
|
||||
class TestCollectionStats:
    """Checks on CollectionStats construction."""

    def test_collection_stats(self):
        """A populated CollectionStats keeps its counts and per-type breakdown."""
        from models import CollectionStats

        per_chunk_type = {"code": 60, "text": 40}
        per_file_type = {"python": 50, "javascript": 10}
        oldest = datetime.now(UTC)
        newest = datetime.now(UTC)

        summary = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types=per_chunk_type,
            file_types=per_file_type,
            oldest_chunk=oldest,
            newest_chunk=newest,
        )

        assert summary.chunk_count == 100
        assert summary.unique_sources == 10
        assert summary.chunk_types["code"] == 60
|
||||
295
mcp-servers/knowledge-base/tests/test_search.py
Normal file
295
mcp-servers/knowledge-base/tests/test_search.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Tests for search module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSearchEngine:
    """Drive ``SearchEngine.search`` against mocked database and embeddings."""

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Build a ``SearchEngine`` wired to the mock collaborators."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    @pytest.fixture
    def sample_db_results(self):
        """Return two ``(KnowledgeEmbedding, score)`` rows used as canned DB output."""
        from models import ChunkType, FileType, KnowledgeEmbedding

        def make_row(row_id, content, fill, source_path, score):
            # Only the id, content, vector fill, path and score differ per row.
            record = KnowledgeEmbedding(
                id=row_id,
                project_id="proj-123",
                collection="default",
                content=content,
                embedding=[fill] * 1536,
                chunk_type=ChunkType.CODE,
                source_path=source_path,
                file_type=FileType.PYTHON,
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
            )
            return (record, score)

        return [
            make_row("id-1", "def hello(): pass", 0.1, "/test/file.py", 0.95),
            make_row("id-2", "def world(): pass", 0.2, "/test/file2.py", 0.85),
        ]

    @pytest.mark.asyncio
    async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
        """A semantic request hits the semantic backend once and keeps its scores."""
        from models import SearchType

        db = search_engine._database
        db.semantic_search.return_value = sample_db_results
        sample_search_request.search_type = SearchType.SEMANTIC

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "semantic"
        assert len(response.results) == 2
        assert response.results[0].score == 0.95
        db.semantic_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
        """A keyword request hits the keyword backend exactly once."""
        from models import SearchType

        db = search_engine._database
        db.keyword_search.return_value = sample_db_results
        sample_search_request.search_type = SearchType.KEYWORD

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "keyword"
        assert len(response.results) == 2
        db.keyword_search.assert_called_once()

    @pytest.mark.asyncio
    async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
        """A hybrid request yields at least one fused result."""
        from models import SearchType

        sample_search_request.search_type = SearchType.HYBRID

        # Feed both backends the same rows to keep the scenario simple.
        db = search_engine._database
        db.semantic_search.return_value = sample_db_results
        db.keyword_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_type == "hybrid"
        # Fusion must leave something behind.
        assert len(response.results) >= 1

    @pytest.mark.asyncio
    async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
        """The requested collection is forwarded to the database as a keyword argument."""
        from models import SearchType

        db = search_engine._database
        db.semantic_search.return_value = sample_db_results
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.collection = "specific-collection"

        await search_engine.search(sample_search_request)

        # Inspect what the engine handed to the database layer.
        forwarded = db.semantic_search.call_args.kwargs
        assert forwarded["collection"] == "specific-collection"

    @pytest.mark.asyncio
    async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
        """The requested file types are forwarded to the database as a keyword argument."""
        from models import FileType, SearchType

        db = search_engine._database
        db.semantic_search.return_value = sample_db_results
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.file_types = [FileType.PYTHON]

        await search_engine.search(sample_search_request)

        # Inspect what the engine handed to the database layer.
        forwarded = db.semantic_search.call_args.kwargs
        assert forwarded["file_types"] == [FileType.PYTHON]

    @pytest.mark.asyncio
    async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
        """With ``limit=1`` the response carries at most one result."""
        from models import SearchType

        single_row = sample_db_results[:1]
        sample_search_request.search_type = SearchType.SEMANTIC
        sample_search_request.limit = 1
        search_engine._database.semantic_search.return_value = single_row

        response = await search_engine.search(sample_search_request)

        assert len(response.results) <= 1

    @pytest.mark.asyncio
    async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
        """The response reports a positive elapsed time."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = sample_db_results

        response = await search_engine.search(sample_search_request)

        assert response.search_time_ms > 0

    @pytest.mark.asyncio
    async def test_invalid_search_type(self, search_engine, sample_search_request):
        """An unknown search type is rejected with an exception."""
        from exceptions import InvalidSearchTypeError

        # Bypass the enum by assigning a raw, unsupported string.
        sample_search_request.search_type = "invalid"

        accepted_errors = (InvalidSearchTypeError, ValueError)
        with pytest.raises(accepted_errors):
            await search_engine.search(sample_search_request)

    @pytest.mark.asyncio
    async def test_empty_results(self, search_engine, sample_search_request):
        """An empty database answer produces an empty response."""
        from models import SearchType

        sample_search_request.search_type = SearchType.SEMANTIC
        search_engine._database.semantic_search.return_value = []

        response = await search_engine.search(sample_search_request)

        assert len(response.results) == 0
        assert response.total_results == 0
|
||||
|
||||
|
||||
class TestReciprocalRankFusion:
    """Tests for ``SearchEngine._reciprocal_rank_fusion``."""

    @pytest.fixture
    def search_engine(self, settings, mock_database, mock_embeddings):
        """Create a search engine wired to the mock collaborators."""
        from search import SearchEngine

        return SearchEngine(
            settings=settings,
            database=mock_database,
            embeddings=mock_embeddings,
        )

    def test_fusion_combines_results(self, search_engine):
        """Fusing two overlapping rankings yields each distinct id exactly once."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        keyword = [
            SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        # Compare the full sorted id list rather than using membership checks:
        # membership alone would still pass if "b" (present in both inputs)
        # were emitted twice or an unexpected id appeared.
        ids = [r.id for r in fused]
        assert sorted(ids) == ["a", "b", "c"]

        # "b" appears in both rankings, so it should land in the top 2.
        b_rank = ids.index("b")
        assert b_rank < 2

    def test_fusion_respects_weights(self, search_engine):
        """Explicit weights are accepted and a lone result survives either way."""
        from models import SearchResult

        # A single result, supplied once per side.
        results = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
        ]

        # Result only on the semantic side, semantic weighted heavily.
        fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
            results,
            [],
            semantic_weight=0.9,
            keyword_weight=0.1,
        )

        # Result only on the keyword side, keyword weighted heavily.
        fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
            [],
            results,
            semantic_weight=0.1,
            keyword_weight=0.9,
        )

        # Either way the single result must come back, and it must be the
        # one that was supplied.
        assert len(fused_semantic_heavy) == 1
        assert len(fused_keyword_heavy) == 1
        assert fused_semantic_heavy[0].id == "a"
        assert fused_keyword_heavy[0].id == "a"

    def test_fusion_normalizes_scores(self, search_engine):
        """Every fused score lies in the closed interval [0, 1]."""
        from models import SearchResult

        semantic = [
            SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
            SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
        ]
        keyword = [
            SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
        ]

        fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

        for result in fused:
            assert 0 <= result.score <= 1
|
||||
|
||||
|
||||
class TestGlobalSearchEngine:
    """Tests for the module-level search engine accessor in ``search``.

    Each test resets the global engine in a ``finally`` block so the
    instance it created is not left behind for tests that run later.
    """

    def test_get_search_engine_singleton(self):
        """Two consecutive ``get_search_engine()`` calls return the same object."""
        from search import get_search_engine, reset_search_engine

        # Start from a clean slate regardless of what ran before.
        reset_search_engine()
        try:
            engine1 = get_search_engine()
            engine2 = get_search_engine()

            assert engine1 is engine2
        finally:
            reset_search_engine()

    def test_reset_search_engine(self):
        """After ``reset_search_engine()`` the accessor returns a different object."""
        from search import get_search_engine, reset_search_engine

        try:
            engine1 = get_search_engine()
            reset_search_engine()
            engine2 = get_search_engine()

            assert engine1 is not engine2
        finally:
            reset_search_engine()
|
||||
357
mcp-servers/knowledge-base/tests/test_server.py
Normal file
357
mcp-servers/knowledge-base/tests/test_server.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for server module and MCP tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestHealthCheck:
    """Tests for ``server.health_check``.

    ``server._database`` is swapped through ``monkeypatch`` so the original
    value is restored after each test; plain assignment would leave the
    module pointing at a mock (or ``None``) for every test that runs later.
    """

    @pytest.mark.asyncio
    async def test_health_check_healthy(self, monkeypatch):
        """A database whose connection answers reports a healthy status."""
        import server

        # Connection double whose fetchval() resolves to 1.
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)

        mock_db = MagicMock()
        mock_db._pool = MagicMock()

        # acquire() must hand back an async context manager yielding mock_conn.
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None
        mock_db.acquire.return_value = mock_cm

        # raising=False: whether the attribute exists at import time is not
        # visible from this file, so do not make the patch depend on it.
        monkeypatch.setattr(server, "_database", mock_db, raising=False)

        result = await server.health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "knowledge-base"
        assert result["database"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self, monkeypatch):
        """With no database configured the check reports it as not initialized."""
        import server

        monkeypatch.setattr(server, "_database", None, raising=False)

        result = await server.health_check()

        assert result["database"] == "not initialized"
|
||||
|
||||
|
||||
class TestSearchKnowledgeTool:
    """Tests for the ``search_knowledge`` MCP tool.

    Every test installs its own search-engine double via ``monkeypatch``.
    That restores ``server._search`` afterwards and makes each test
    independent of whichever test happened to run before it.
    """

    @pytest.mark.asyncio
    async def test_search_success(self, monkeypatch):
        """A successful engine response is returned as a success payload."""
        import server
        from models import SearchResponse, SearchResult

        canned_response = SearchResponse(
            query="test query",
            search_type="hybrid",
            results=[
                SearchResult(
                    id="id-1",
                    content="Test content",
                    score=0.95,
                    source_path="/test/file.py",
                    chunk_type="code",
                    collection="default",
                )
            ],
            total_results=1,
            search_time_ms=10.5,
        )
        mock_search = MagicMock()
        mock_search.search = AsyncMock(return_value=canned_response)
        # raising=False: whether the attribute exists at import time is not
        # visible from this file.
        monkeypatch.setattr(server, "_search", mock_search, raising=False)

        # The decorated tool exposes the underlying coroutine as ``.fn``.
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=None,
        )

        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95

    @pytest.mark.asyncio
    async def test_search_invalid_type(self, monkeypatch):
        """An unknown search type yields an error payload."""
        import server

        # NOTE(review): whether the tool consults ``_search`` before
        # validating its arguments is not visible here; installing a double
        # keeps the outcome independent of state left by other tests.
        monkeypatch.setattr(server, "_search", MagicMock(), raising=False)

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid search type" in result["error"]

    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self, monkeypatch):
        """An unknown file type in the filter yields an error payload."""
        import server

        # Same isolation rationale as in test_search_invalid_type.
        monkeypatch.setattr(server, "_search", MagicMock(), raising=False)

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=["invalid_type"],
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
|
||||
|
||||
|
||||
class TestIngestContentTool:
    """Tests for the ``ingest_content`` MCP tool.

    Every test installs its own collection-manager double via
    ``monkeypatch``. That restores ``server._collections`` afterwards and
    makes each test independent of whichever test ran before it.
    """

    @pytest.mark.asyncio
    async def test_ingest_success(self, monkeypatch):
        """A successful ingest result is returned as a success payload."""
        import server
        from models import IngestResult

        canned_result = IngestResult(
            success=True,
            chunks_created=3,
            embeddings_generated=3,
            source_path="/test/file.py",
            collection="default",
            chunk_ids=["id-1", "id-2", "id-3"],
        )
        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(return_value=canned_result)
        # raising=False: whether the attribute exists at import time is not
        # visible from this file.
        monkeypatch.setattr(server, "_collections", mock_collections, raising=False)

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test/file.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3

    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self, monkeypatch):
        """An unknown chunk type yields an error payload."""
        import server

        # NOTE(review): whether the tool consults ``_collections`` before
        # validating its arguments is not visible here; installing a double
        # keeps the outcome independent of state left by other tests.
        monkeypatch.setattr(server, "_collections", MagicMock(), raising=False)

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self, monkeypatch):
        """An unknown file type yields an error payload."""
        import server

        # Same isolation rationale as in test_ingest_invalid_chunk_type.
        monkeypatch.setattr(server, "_collections", MagicMock(), raising=False)

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type="invalid",
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
|
||||
|
||||
|
||||
class TestDeleteContentTool:
    """Tests for the ``delete_content`` MCP tool."""

    @pytest.mark.asyncio
    async def test_delete_success(self, monkeypatch):
        """A successful delete result is returned as a success payload."""
        import server
        from models import DeleteResult

        canned_result = DeleteResult(
            success=True,
            chunks_deleted=5,
        )
        mock_collections = MagicMock()
        mock_collections.delete = AsyncMock(return_value=canned_result)
        # monkeypatch restores ``server._collections`` after the test so the
        # double does not leak into later tests. raising=False because
        # whether the attribute exists at import time is not visible here.
        monkeypatch.setattr(server, "_collections", mock_collections, raising=False)

        result = await server.delete_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            collection=None,
            chunk_ids=None,
        )

        assert result["success"] is True
        assert result["chunks_deleted"] == 5
|
||||
|
||||
|
||||
class TestListCollectionsTool:
    """Tests for the ``list_collections`` MCP tool."""

    @pytest.mark.asyncio
    async def test_list_collections_success(self, monkeypatch):
        """A one-collection listing is returned as a success payload."""
        import server
        from models import CollectionInfo, ListCollectionsResponse

        canned_response = ListCollectionsResponse(
            project_id="proj-123",
            collections=[
                CollectionInfo(
                    name="collection-1",
                    project_id="proj-123",
                    chunk_count=100,
                    total_tokens=50000,
                    file_types=["python"],
                    created_at=datetime.now(UTC),
                    updated_at=datetime.now(UTC),
                )
            ],
            total_collections=1,
        )
        mock_collections = MagicMock()
        mock_collections.list_collections = AsyncMock(return_value=canned_response)
        # monkeypatch restores ``server._collections`` after the test so the
        # double does not leak into later tests. raising=False because
        # whether the attribute exists at import time is not visible here.
        monkeypatch.setattr(server, "_collections", mock_collections, raising=False)

        result = await server.list_collections.fn(
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert result["success"] is True
        assert result["total_collections"] == 1
        assert len(result["collections"]) == 1
|
||||
|
||||
|
||||
class TestGetCollectionStatsTool:
    """Tests for the ``get_collection_stats`` MCP tool."""

    @pytest.mark.asyncio
    async def test_get_stats_success(self, monkeypatch):
        """Collection statistics are returned as a success payload."""
        import server
        from models import CollectionStats

        canned_stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
        )
        mock_collections = MagicMock()
        mock_collections.get_collection_stats = AsyncMock(return_value=canned_stats)
        # monkeypatch restores ``server._collections`` after the test so the
        # double does not leak into later tests. raising=False because
        # whether the attribute exists at import time is not visible here.
        monkeypatch.setattr(server, "_collections", mock_collections, raising=False)

        result = await server.get_collection_stats.fn(
            project_id="proj-123",
            agent_id="agent-456",
            collection="test-collection",
        )

        assert result["success"] is True
        assert result["chunk_count"] == 100
        assert result["unique_sources"] == 10
|
||||
|
||||
|
||||
class TestUpdateDocumentTool:
    """Tests for the ``update_document`` MCP tool.

    Every test installs its own collection-manager double via
    ``monkeypatch``. That restores ``server._collections`` afterwards and
    makes each test independent of whichever test ran before it.
    """

    @pytest.mark.asyncio
    async def test_update_success(self, monkeypatch):
        """A successful update result is returned as a success payload."""
        import server
        from models import IngestResult

        canned_result = IngestResult(
            success=True,
            chunks_created=2,
            embeddings_generated=2,
            source_path="/test/file.py",
            collection="default",
            chunk_ids=["id-1", "id-2"],
        )
        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(return_value=canned_result)
        # raising=False: whether the attribute exists at import time is not
        # visible from this file.
        monkeypatch.setattr(server, "_collections", mock_collections, raising=False)

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 2

    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self, monkeypatch):
        """An unknown chunk type yields an error payload."""
        import server

        # NOTE(review): whether the tool consults ``_collections`` before
        # validating its arguments is not visible here; installing a double
        # keeps the outcome independent of state left by other tests.
        monkeypatch.setattr(server, "_collections", MagicMock(), raising=False)

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="test",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
|
||||
Reference in New Issue
Block a user