feat(knowledge-base): implement Knowledge Base MCP Server (#57)

Implements RAG capabilities with pgvector for semantic search:

- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content,
  list_collections, get_collection_stats, update_document

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 21:33:26 +01:00
parent 18d717e996
commit d0fc7f37ff
26 changed files with 9530 additions and 120 deletions

View File

@@ -0,0 +1 @@
"""Tests for Knowledge Base MCP Server."""

View File

@@ -0,0 +1,282 @@
"""
Test fixtures for Knowledge Base MCP Server.
"""
import os
import sys
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Set test mode before importing modules
os.environ["IS_TEST"] = "true"
os.environ["KB_DATABASE_URL"] = "postgresql://test:test@localhost:5432/test"
os.environ["KB_REDIS_URL"] = "redis://localhost:6379/0"
os.environ["KB_LLM_GATEWAY_URL"] = "http://localhost:8001"
@pytest.fixture
def settings():
"""Create test settings."""
from config import Settings, reset_settings
reset_settings()
return Settings(
host="127.0.0.1",
port=8002,
debug=True,
database_url="postgresql://test:test@localhost:5432/test",
redis_url="redis://localhost:6379/0",
llm_gateway_url="http://localhost:8001",
embedding_dimension=1536,
code_chunk_size=500,
code_chunk_overlap=50,
markdown_chunk_size=800,
markdown_chunk_overlap=100,
text_chunk_size=400,
text_chunk_overlap=50,
)
@pytest.fixture
def mock_database():
"""Create mock database manager."""
from database import DatabaseManager
mock_db = MagicMock(spec=DatabaseManager)
mock_db._pool = MagicMock()
mock_db.acquire = MagicMock(return_value=AsyncMock())
# Mock database methods
mock_db.initialize = AsyncMock()
mock_db.close = AsyncMock()
mock_db.store_embedding = AsyncMock(return_value="test-id-123")
mock_db.store_embeddings_batch = AsyncMock(return_value=["id-1", "id-2"])
mock_db.semantic_search = AsyncMock(return_value=[])
mock_db.keyword_search = AsyncMock(return_value=[])
mock_db.delete_by_source = AsyncMock(return_value=1)
mock_db.delete_collection = AsyncMock(return_value=5)
mock_db.delete_by_ids = AsyncMock(return_value=2)
mock_db.list_collections = AsyncMock(return_value=[])
mock_db.get_collection_stats = AsyncMock()
mock_db.cleanup_expired = AsyncMock(return_value=0)
return mock_db
@pytest.fixture
def mock_embeddings():
"""Create mock embedding generator."""
from embeddings import EmbeddingGenerator
mock_emb = MagicMock(spec=EmbeddingGenerator)
mock_emb.initialize = AsyncMock()
mock_emb.close = AsyncMock()
# Generate fake embeddings (1536 dimensions)
def fake_embedding() -> list[float]:
return [0.1] * 1536
mock_emb.generate = AsyncMock(return_value=fake_embedding())
mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])
return mock_emb
@pytest.fixture
def mock_redis():
"""Create mock Redis client."""
import fakeredis.aioredis
return fakeredis.aioredis.FakeRedis()
@pytest.fixture
def sample_python_code():
"""Sample Python code for chunking tests."""
return '''"""Sample module for testing."""
import os
from typing import Any
class Calculator:
"""A simple calculator class."""
def __init__(self, initial: int = 0) -> None:
"""Initialize calculator."""
self.value = initial
def add(self, x: int) -> int:
"""Add a value."""
self.value += x
return self.value
def subtract(self, x: int) -> int:
"""Subtract a value."""
self.value -= x
return self.value
def helper_function(data: dict[str, Any]) -> str:
"""A helper function."""
return str(data)
async def async_function() -> None:
"""An async function."""
pass
'''
@pytest.fixture
def sample_markdown():
"""Sample Markdown content for chunking tests."""
return '''# Project Documentation
This is the main documentation for our project.
## Getting Started
To get started, follow these steps:
1. Install dependencies
2. Configure settings
3. Run the application
### Prerequisites
You'll need the following installed:
- Python 3.12+
- PostgreSQL
- Redis
```python
# Example code
def main():
print("Hello, World!")
```
## API Reference
### Search Endpoint
The search endpoint allows you to query the knowledge base.
**Endpoint:** `POST /api/search`
**Request:**
```json
{
"query": "your search query",
"limit": 10
}
```
## Contributing
We welcome contributions! Please see our contributing guide.
'''
@pytest.fixture
def sample_text():
"""Sample plain text for chunking tests."""
return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.
When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.
The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''
@pytest.fixture
def sample_chunk():
"""Sample chunk for testing."""
from models import Chunk, ChunkType, FileType
return Chunk(
content="def hello():\n print('Hello')",
chunk_type=ChunkType.CODE,
file_type=FileType.PYTHON,
source_path="/test/hello.py",
start_line=1,
end_line=2,
metadata={"function": "hello"},
token_count=15,
)
@pytest.fixture
def sample_embedding():
"""Sample knowledge embedding for testing."""
from models import ChunkType, FileType, KnowledgeEmbedding
return KnowledgeEmbedding(
id="test-id-123",
project_id="proj-123",
collection="default",
content="def hello():\n print('Hello')",
embedding=[0.1] * 1536,
chunk_type=ChunkType.CODE,
source_path="/test/hello.py",
start_line=1,
end_line=2,
file_type=FileType.PYTHON,
metadata={"function": "hello"},
content_hash="abc123",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@pytest.fixture
def sample_ingest_request():
"""Sample ingest request for testing."""
from models import ChunkType, FileType, IngestRequest
return IngestRequest(
project_id="proj-123",
agent_id="agent-456",
content="def hello():\n print('Hello')",
source_path="/test/hello.py",
collection="default",
chunk_type=ChunkType.CODE,
file_type=FileType.PYTHON,
metadata={"test": True},
)
@pytest.fixture
def sample_search_request():
"""Sample search request for testing."""
from models import SearchRequest, SearchType
return SearchRequest(
project_id="proj-123",
agent_id="agent-456",
query="hello function",
search_type=SearchType.HYBRID,
collection="default",
limit=10,
threshold=0.7,
)
@pytest.fixture
def sample_delete_request():
"""Sample delete request for testing."""
from models import DeleteRequest
return DeleteRequest(
project_id="proj-123",
agent_id="agent-456",
source_path="/test/hello.py",
)

View File

@@ -0,0 +1,422 @@
"""Tests for chunking module."""
class TestBaseChunker:
"""Tests for base chunker functionality."""
def test_count_tokens(self, settings):
"""Test token counting."""
from chunking.text import TextChunker
chunker = TextChunker(
chunk_size=400,
chunk_overlap=50,
settings=settings,
)
# Simple text should count tokens
tokens = chunker.count_tokens("Hello, world!")
assert tokens > 0
assert tokens < 10 # Should be about 3-4 tokens
def test_truncate_to_tokens(self, settings):
"""Test truncating text to token limit."""
from chunking.text import TextChunker
chunker = TextChunker(
chunk_size=400,
chunk_overlap=50,
settings=settings,
)
long_text = "word " * 1000
truncated = chunker.truncate_to_tokens(long_text, 10)
assert chunker.count_tokens(truncated) <= 10
class TestCodeChunker:
"""Tests for code chunker."""
def test_chunk_python_code(self, settings, sample_python_code):
"""Test chunking Python code."""
from chunking.code import CodeChunker
from models import ChunkType, FileType
chunker = CodeChunker(
chunk_size=500,
chunk_overlap=50,
settings=settings,
)
chunks = chunker.chunk(
content=sample_python_code,
source_path="/test/sample.py",
file_type=FileType.PYTHON,
)
assert len(chunks) > 0
assert all(c.chunk_type == ChunkType.CODE for c in chunks)
assert all(c.file_type == FileType.PYTHON for c in chunks)
def test_preserves_function_boundaries(self, settings):
"""Test that chunker preserves function boundaries."""
from chunking.code import CodeChunker
from models import FileType
code = '''def function_one():
"""First function."""
return 1
def function_two():
"""Second function."""
return 2
'''
chunker = CodeChunker(
chunk_size=100,
chunk_overlap=10,
settings=settings,
)
chunks = chunker.chunk(
content=code,
source_path="/test/funcs.py",
file_type=FileType.PYTHON,
)
# Each function should ideally be in its own chunk
assert len(chunks) >= 1
for chunk in chunks:
# Check chunks have line numbers
assert chunk.start_line is not None
assert chunk.end_line is not None
assert chunk.start_line <= chunk.end_line
def test_handles_empty_content(self, settings):
"""Test handling empty content."""
from chunking.code import CodeChunker
chunker = CodeChunker(
chunk_size=500,
chunk_overlap=50,
settings=settings,
)
chunks = chunker.chunk(content="", source_path="/test/empty.py")
assert chunks == []
def test_chunk_type_is_code(self, settings):
"""Test that chunk_type property returns CODE."""
from chunking.code import CodeChunker
from models import ChunkType
chunker = CodeChunker(
chunk_size=500,
chunk_overlap=50,
settings=settings,
)
assert chunker.chunk_type == ChunkType.CODE
class TestMarkdownChunker:
"""Tests for markdown chunker."""
def test_chunk_markdown(self, settings, sample_markdown):
"""Test chunking markdown content."""
from chunking.markdown import MarkdownChunker
from models import ChunkType, FileType
chunker = MarkdownChunker(
chunk_size=800,
chunk_overlap=100,
settings=settings,
)
chunks = chunker.chunk(
content=sample_markdown,
source_path="/test/docs.md",
file_type=FileType.MARKDOWN,
)
assert len(chunks) > 0
assert all(c.chunk_type == ChunkType.MARKDOWN for c in chunks)
def test_respects_heading_hierarchy(self, settings):
"""Test that chunker respects heading hierarchy."""
from chunking.markdown import MarkdownChunker
markdown = '''# Main Title
Introduction paragraph.
## Section One
Content for section one.
### Subsection
More detailed content.
## Section Two
Content for section two.
'''
chunker = MarkdownChunker(
chunk_size=200,
chunk_overlap=20,
settings=settings,
)
chunks = chunker.chunk(
content=markdown,
source_path="/test/docs.md",
)
# Should have multiple chunks based on sections
assert len(chunks) >= 1
# Metadata should include heading context
for chunk in chunks:
# Chunks should have content
assert len(chunk.content) > 0
def test_handles_code_blocks(self, settings):
"""Test handling of code blocks in markdown."""
from chunking.markdown import MarkdownChunker
markdown = '''# Code Example
Here's some code:
```python
def hello():
print("Hello, World!")
```
End of example.
'''
chunker = MarkdownChunker(
chunk_size=500,
chunk_overlap=50,
settings=settings,
)
chunks = chunker.chunk(
content=markdown,
source_path="/test/code.md",
)
# Code blocks should be preserved
assert len(chunks) >= 1
full_content = " ".join(c.content for c in chunks)
assert "```python" in full_content or "def hello" in full_content
def test_chunk_type_is_markdown(self, settings):
"""Test that chunk_type property returns MARKDOWN."""
from chunking.markdown import MarkdownChunker
from models import ChunkType
chunker = MarkdownChunker(
chunk_size=800,
chunk_overlap=100,
settings=settings,
)
assert chunker.chunk_type == ChunkType.MARKDOWN
class TestTextChunker:
"""Tests for text chunker."""
def test_chunk_text(self, settings, sample_text):
"""Test chunking plain text."""
from chunking.text import TextChunker
from models import ChunkType
chunker = TextChunker(
chunk_size=400,
chunk_overlap=50,
settings=settings,
)
chunks = chunker.chunk(
content=sample_text,
source_path="/test/docs.txt",
)
assert len(chunks) > 0
assert all(c.chunk_type == ChunkType.TEXT for c in chunks)
def test_respects_paragraph_boundaries(self, settings):
"""Test that chunker respects paragraph boundaries."""
from chunking.text import TextChunker
text = '''First paragraph with some content.
Second paragraph with different content.
Third paragraph to test chunking behavior.
'''
chunker = TextChunker(
chunk_size=100,
chunk_overlap=10,
settings=settings,
)
chunks = chunker.chunk(
content=text,
source_path="/test/text.txt",
)
assert len(chunks) >= 1
def test_handles_single_paragraph(self, settings):
"""Test handling of single paragraph that fits in one chunk."""
from chunking.text import TextChunker
text = "This is a short paragraph."
chunker = TextChunker(
chunk_size=400,
chunk_overlap=50,
settings=settings,
)
chunks = chunker.chunk(content=text, source_path="/test/short.txt")
assert len(chunks) == 1
assert chunks[0].content == text
def test_chunk_type_is_text(self, settings):
"""Test that chunk_type property returns TEXT."""
from chunking.text import TextChunker
from models import ChunkType
chunker = TextChunker(
chunk_size=400,
chunk_overlap=50,
settings=settings,
)
assert chunker.chunk_type == ChunkType.TEXT
class TestChunkerFactory:
"""Tests for chunker factory."""
def test_get_code_chunker(self, settings):
"""Test getting code chunker."""
from chunking.base import ChunkerFactory
from chunking.code import CodeChunker
from models import FileType
factory = ChunkerFactory(settings=settings)
chunker = factory.get_chunker(file_type=FileType.PYTHON)
assert isinstance(chunker, CodeChunker)
def test_get_markdown_chunker(self, settings):
"""Test getting markdown chunker."""
from chunking.base import ChunkerFactory
from chunking.markdown import MarkdownChunker
from models import FileType
factory = ChunkerFactory(settings=settings)
chunker = factory.get_chunker(file_type=FileType.MARKDOWN)
assert isinstance(chunker, MarkdownChunker)
def test_get_text_chunker(self, settings):
"""Test getting text chunker."""
from chunking.base import ChunkerFactory
from chunking.text import TextChunker
from models import FileType
factory = ChunkerFactory(settings=settings)
chunker = factory.get_chunker(file_type=FileType.TEXT)
assert isinstance(chunker, TextChunker)
def test_get_chunker_for_path(self, settings):
"""Test getting chunker based on file path."""
from chunking.base import ChunkerFactory
from chunking.code import CodeChunker
from chunking.markdown import MarkdownChunker
from models import FileType
factory = ChunkerFactory(settings=settings)
chunker, file_type = factory.get_chunker_for_path("/test/file.py")
assert isinstance(chunker, CodeChunker)
assert file_type == FileType.PYTHON
chunker, file_type = factory.get_chunker_for_path("/test/docs.md")
assert isinstance(chunker, MarkdownChunker)
assert file_type == FileType.MARKDOWN
def test_chunk_content(self, settings, sample_python_code):
"""Test chunk_content convenience method."""
from chunking.base import ChunkerFactory
from models import ChunkType
factory = ChunkerFactory(settings=settings)
chunks = factory.chunk_content(
content=sample_python_code,
source_path="/test/sample.py",
)
assert len(chunks) > 0
assert all(c.chunk_type == ChunkType.CODE for c in chunks)
def test_default_to_text_chunker(self, settings):
"""Test defaulting to text chunker."""
from chunking.base import ChunkerFactory
from chunking.text import TextChunker
factory = ChunkerFactory(settings=settings)
chunker = factory.get_chunker()
assert isinstance(chunker, TextChunker)
def test_chunker_caching(self, settings):
"""Test that factory caches chunker instances."""
from chunking.base import ChunkerFactory
from models import FileType
factory = ChunkerFactory(settings=settings)
chunker1 = factory.get_chunker(file_type=FileType.PYTHON)
chunker2 = factory.get_chunker(file_type=FileType.PYTHON)
assert chunker1 is chunker2
class TestGlobalChunkerFactory:
"""Tests for global chunker factory."""
def test_get_chunker_factory_singleton(self):
"""Test that get_chunker_factory returns singleton."""
from chunking.base import get_chunker_factory, reset_chunker_factory
reset_chunker_factory()
factory1 = get_chunker_factory()
factory2 = get_chunker_factory()
assert factory1 is factory2
def test_reset_chunker_factory(self):
"""Test resetting chunker factory."""
from chunking.base import get_chunker_factory, reset_chunker_factory
factory1 = get_chunker_factory()
reset_chunker_factory()
factory2 = get_chunker_factory()
assert factory1 is not factory2

View File

@@ -0,0 +1,240 @@
"""Tests for collection manager module."""
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
class TestCollectionManager:
"""Tests for CollectionManager class."""
@pytest.fixture
def collection_manager(self, settings, mock_database, mock_embeddings):
"""Create collection manager with mocks."""
from chunking.base import ChunkerFactory
from collection_manager import CollectionManager
mock_chunker_factory = MagicMock(spec=ChunkerFactory)
# Mock chunk_content to return chunks
from models import Chunk, ChunkType
mock_chunker_factory.chunk_content.return_value = [
Chunk(
content="def hello(): pass",
chunk_type=ChunkType.CODE,
token_count=10,
)
]
return CollectionManager(
settings=settings,
database=mock_database,
embeddings=mock_embeddings,
chunker_factory=mock_chunker_factory,
)
@pytest.mark.asyncio
async def test_ingest_content(self, collection_manager, sample_ingest_request):
"""Test content ingestion."""
result = await collection_manager.ingest(sample_ingest_request)
assert result.success is True
assert result.chunks_created == 1
assert result.embeddings_generated == 1
assert len(result.chunk_ids) == 1
assert result.collection == "default"
@pytest.mark.asyncio
async def test_ingest_empty_content(self, collection_manager):
"""Test ingesting empty content."""
from models import IngestRequest
# Mock chunker to return empty list
collection_manager._chunker_factory.chunk_content.return_value = []
request = IngestRequest(
project_id="proj-123",
agent_id="agent-456",
content="",
)
result = await collection_manager.ingest(request)
assert result.success is True
assert result.chunks_created == 0
assert result.embeddings_generated == 0
@pytest.mark.asyncio
async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
"""Test ingest error handling."""
# Make embedding generation fail
collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
result = await collection_manager.ingest(sample_ingest_request)
assert result.success is False
assert "Embedding error" in result.error
@pytest.mark.asyncio
async def test_delete_by_source(self, collection_manager, sample_delete_request):
"""Test deletion by source path."""
result = await collection_manager.delete(sample_delete_request)
assert result.success is True
assert result.chunks_deleted == 1 # Mock returns 1
collection_manager._database.delete_by_source.assert_called_once()
@pytest.mark.asyncio
async def test_delete_by_collection(self, collection_manager):
"""Test deletion by collection."""
from models import DeleteRequest
request = DeleteRequest(
project_id="proj-123",
agent_id="agent-456",
collection="to-delete",
)
result = await collection_manager.delete(request)
assert result.success is True
collection_manager._database.delete_collection.assert_called_once()
@pytest.mark.asyncio
async def test_delete_by_ids(self, collection_manager):
"""Test deletion by chunk IDs."""
from models import DeleteRequest
request = DeleteRequest(
project_id="proj-123",
agent_id="agent-456",
chunk_ids=["id-1", "id-2"],
)
result = await collection_manager.delete(request)
assert result.success is True
collection_manager._database.delete_by_ids.assert_called_once()
@pytest.mark.asyncio
async def test_delete_no_target(self, collection_manager):
"""Test deletion with no target specified."""
from models import DeleteRequest
request = DeleteRequest(
project_id="proj-123",
agent_id="agent-456",
)
result = await collection_manager.delete(request)
assert result.success is False
assert "Must specify" in result.error
@pytest.mark.asyncio
async def test_list_collections(self, collection_manager):
"""Test listing collections."""
from models import CollectionInfo
collection_manager._database.list_collections.return_value = [
CollectionInfo(
name="collection-1",
project_id="proj-123",
chunk_count=100,
total_tokens=50000,
file_types=["python"],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
),
CollectionInfo(
name="collection-2",
project_id="proj-123",
chunk_count=50,
total_tokens=25000,
file_types=["javascript"],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
),
]
result = await collection_manager.list_collections("proj-123")
assert result.project_id == "proj-123"
assert result.total_collections == 2
assert len(result.collections) == 2
@pytest.mark.asyncio
async def test_get_collection_stats(self, collection_manager):
"""Test getting collection statistics."""
from models import CollectionStats
expected_stats = CollectionStats(
collection="test-collection",
project_id="proj-123",
chunk_count=100,
unique_sources=10,
total_tokens=50000,
avg_chunk_size=500.0,
chunk_types={"code": 60, "text": 40},
file_types={"python": 50, "javascript": 10},
)
collection_manager._database.get_collection_stats.return_value = expected_stats
stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
assert stats.chunk_count == 100
assert stats.unique_sources == 10
collection_manager._database.get_collection_stats.assert_called_once_with(
"proj-123", "test-collection"
)
@pytest.mark.asyncio
async def test_update_document(self, collection_manager):
"""Test updating a document."""
result = await collection_manager.update_document(
project_id="proj-123",
agent_id="agent-456",
source_path="/test/file.py",
content="def updated(): pass",
collection="default",
)
# Should delete first, then ingest
collection_manager._database.delete_by_source.assert_called_once()
assert result.success is True
@pytest.mark.asyncio
async def test_cleanup_expired(self, collection_manager):
"""Test cleaning up expired embeddings."""
collection_manager._database.cleanup_expired.return_value = 10
count = await collection_manager.cleanup_expired()
assert count == 10
collection_manager._database.cleanup_expired.assert_called_once()
class TestGlobalCollectionManager:
"""Tests for global collection manager."""
def test_get_collection_manager_singleton(self):
"""Test that get_collection_manager returns singleton."""
from collection_manager import get_collection_manager, reset_collection_manager
reset_collection_manager()
manager1 = get_collection_manager()
manager2 = get_collection_manager()
assert manager1 is manager2
def test_reset_collection_manager(self):
"""Test resetting collection manager."""
from collection_manager import get_collection_manager, reset_collection_manager
manager1 = get_collection_manager()
reset_collection_manager()
manager2 = get_collection_manager()
assert manager1 is not manager2

View File

@@ -0,0 +1,104 @@
"""Tests for configuration module."""
import os
class TestSettings:
"""Tests for Settings class."""
def test_default_values(self, settings):
"""Test default configuration values."""
assert settings.port == 8002
assert settings.embedding_dimension == 1536
assert settings.code_chunk_size == 500
assert settings.search_default_limit == 10
def test_env_prefix(self):
"""Test environment variable prefix."""
from config import Settings, reset_settings
reset_settings()
os.environ["KB_PORT"] = "9999"
settings = Settings()
assert settings.port == 9999
# Cleanup
del os.environ["KB_PORT"]
reset_settings()
def test_embedding_settings(self, settings):
"""Test embedding-related settings."""
assert settings.embedding_model == "text-embedding-3-large"
assert settings.embedding_batch_size == 100
assert settings.embedding_cache_ttl == 86400
def test_chunking_settings(self, settings):
"""Test chunking-related settings."""
assert settings.code_chunk_size == 500
assert settings.code_chunk_overlap == 50
assert settings.markdown_chunk_size == 800
assert settings.markdown_chunk_overlap == 100
assert settings.text_chunk_size == 400
assert settings.text_chunk_overlap == 50
def test_search_settings(self, settings):
"""Test search-related settings."""
assert settings.search_default_limit == 10
assert settings.search_max_limit == 100
assert settings.semantic_threshold == 0.7
assert settings.hybrid_semantic_weight == 0.7
assert settings.hybrid_keyword_weight == 0.3
class TestGetSettings:
"""Tests for get_settings function."""
def test_returns_singleton(self):
"""Test that get_settings returns singleton."""
from config import get_settings, reset_settings
reset_settings()
settings1 = get_settings()
settings2 = get_settings()
assert settings1 is settings2
def test_reset_creates_new_instance(self):
"""Test that reset_settings clears the singleton."""
from config import get_settings, reset_settings
settings1 = get_settings()
reset_settings()
settings2 = get_settings()
assert settings1 is not settings2
class TestIsTestMode:
"""Tests for is_test_mode function."""
def test_returns_true_when_set(self):
"""Test returns True when IS_TEST is set."""
from config import is_test_mode
old_value = os.environ.get("IS_TEST")
os.environ["IS_TEST"] = "true"
assert is_test_mode() is True
if old_value:
os.environ["IS_TEST"] = old_value
else:
del os.environ["IS_TEST"]
def test_returns_false_when_not_set(self):
"""Test returns False when IS_TEST is not set."""
from config import is_test_mode
old_value = os.environ.get("IS_TEST")
if old_value:
del os.environ["IS_TEST"]
assert is_test_mode() is False
if old_value:
os.environ["IS_TEST"] = old_value

View File

@@ -0,0 +1,245 @@
"""Tests for embedding generation module."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestEmbeddingGenerator:
"""Tests for EmbeddingGenerator class."""
@pytest.fixture
def mock_http_response(self):
"""Create mock HTTP response."""
response = MagicMock()
response.status_code = 200
response.raise_for_status = MagicMock()
response.json.return_value = {
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 1536]
})
}
]
}
}
return response
@pytest.mark.asyncio
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
"""Test generating a single embedding."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
generator._redis = mock_redis
# Mock HTTP client
mock_client = AsyncMock()
mock_client.post.return_value = mock_http_response
generator._http_client = mock_client
embedding = await generator.generate(
text="Hello, world!",
project_id="proj-123",
agent_id="agent-456",
)
assert len(embedding) == 1536
mock_client.post.assert_called_once()
@pytest.mark.asyncio
async def test_generate_batch_embeddings(self, settings, mock_redis):
"""Test generating batch embeddings."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
generator._redis = mock_redis
# Mock HTTP client with batch response
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_response.json.return_value = {
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
})
}
]
}
}
mock_client.post.return_value = mock_response
generator._http_client = mock_client
embeddings = await generator.generate_batch(
texts=["Text 1", "Text 2", "Text 3"],
project_id="proj-123",
agent_id="agent-456",
)
assert len(embeddings) == 3
assert all(len(e) == 1536 for e in embeddings)
@pytest.mark.asyncio
async def test_caching(self, settings, mock_redis):
"""Test embedding caching."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
generator._redis = mock_redis
# Pre-populate cache
cache_key = generator._cache_key("Hello, world!")
await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))
# Mock HTTP client (should not be called)
mock_client = AsyncMock()
generator._http_client = mock_client
embedding = await generator.generate(
text="Hello, world!",
project_id="proj-123",
agent_id="agent-456",
)
# Should return cached embedding
assert len(embedding) == 1536
assert embedding[0] == 0.5
mock_client.post.assert_not_called()
@pytest.mark.asyncio
async def test_cache_miss(self, settings, mock_redis, mock_http_response):
"""Test embedding cache miss."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
generator._redis = mock_redis
mock_client = AsyncMock()
mock_client.post.return_value = mock_http_response
generator._http_client = mock_client
embedding = await generator.generate(
text="New text not in cache",
project_id="proj-123",
agent_id="agent-456",
)
assert len(embedding) == 1536
mock_client.post.assert_called_once()
def test_cache_key_generation(self, settings):
"""Test cache key generation."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
key1 = generator._cache_key("Hello")
key2 = generator._cache_key("Hello")
key3 = generator._cache_key("World")
assert key1 == key2
assert key1 != key3
assert key1.startswith("kb:emb:")
@pytest.mark.asyncio
async def test_dimension_validation(self, settings, mock_redis):
"""Test embedding dimension validation."""
from embeddings import EmbeddingGenerator
from exceptions import EmbeddingDimensionMismatchError
generator = EmbeddingGenerator(settings=settings)
generator._redis = mock_redis
# Mock HTTP client with wrong dimension
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_response.json.return_value = {
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 768] # Wrong dimension
})
}
]
}
}
mock_client.post.return_value = mock_response
generator._http_client = mock_client
with pytest.raises(EmbeddingDimensionMismatchError):
await generator.generate(
text="Test text",
project_id="proj-123",
agent_id="agent-456",
)
@pytest.mark.asyncio
async def test_empty_batch(self, settings, mock_redis):
"""Test generating embeddings for empty batch."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
generator._redis = mock_redis
embeddings = await generator.generate_batch(
texts=[],
project_id="proj-123",
agent_id="agent-456",
)
assert embeddings == []
@pytest.mark.asyncio
async def test_initialize_and_close(self, settings):
"""Test initialize and close methods."""
from embeddings import EmbeddingGenerator
generator = EmbeddingGenerator(settings=settings)
# Mock successful initialization
with patch("embeddings.redis.from_url") as mock_redis_from_url:
mock_redis_client = AsyncMock()
mock_redis_client.ping = AsyncMock()
mock_redis_from_url.return_value = mock_redis_client
await generator.initialize()
assert generator._http_client is not None
await generator.close()
assert generator._http_client is None
class TestGlobalEmbeddingGenerator:
"""Tests for global embedding generator."""
def test_get_embedding_generator_singleton(self):
"""Test that get_embedding_generator returns singleton."""
from embeddings import get_embedding_generator, reset_embedding_generator
reset_embedding_generator()
gen1 = get_embedding_generator()
gen2 = get_embedding_generator()
assert gen1 is gen2
def test_reset_embedding_generator(self):
"""Test resetting embedding generator."""
from embeddings import get_embedding_generator, reset_embedding_generator
gen1 = get_embedding_generator()
reset_embedding_generator()
gen2 = get_embedding_generator()
assert gen1 is not gen2

View File

@@ -0,0 +1,307 @@
"""Tests for exception classes."""
class TestErrorCode:
"""Tests for ErrorCode enum."""
def test_error_code_values(self):
"""Test error code values."""
from exceptions import ErrorCode
assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"
assert ErrorCode.DOCUMENT_NOT_FOUND.value == "KB_DOCUMENT_NOT_FOUND"
class TestKnowledgeBaseError:
"""Tests for base exception class."""
def test_basic_error(self):
"""Test basic error creation."""
from exceptions import ErrorCode, KnowledgeBaseError
error = KnowledgeBaseError(
message="Something went wrong",
code=ErrorCode.UNKNOWN_ERROR,
)
assert error.message == "Something went wrong"
assert error.code == ErrorCode.UNKNOWN_ERROR
assert error.details == {}
assert error.cause is None
def test_error_with_details(self):
"""Test error with details."""
from exceptions import ErrorCode, KnowledgeBaseError
error = KnowledgeBaseError(
message="Query failed",
code=ErrorCode.DATABASE_QUERY_ERROR,
details={"query": "SELECT * FROM table", "error_code": 42},
)
assert error.details["query"] == "SELECT * FROM table"
assert error.details["error_code"] == 42
def test_error_with_cause(self):
"""Test error with underlying cause."""
from exceptions import ErrorCode, KnowledgeBaseError
original = ValueError("Original error")
error = KnowledgeBaseError(
message="Wrapped error",
code=ErrorCode.INTERNAL_ERROR,
cause=original,
)
assert error.cause is original
assert isinstance(error.cause, ValueError)
def test_to_dict(self):
"""Test to_dict method."""
from exceptions import ErrorCode, KnowledgeBaseError
error = KnowledgeBaseError(
message="Test error",
code=ErrorCode.INVALID_REQUEST,
details={"field": "value"},
)
result = error.to_dict()
assert result["error"] == "KB_INVALID_REQUEST"
assert result["message"] == "Test error"
assert result["details"]["field"] == "value"
def test_str_representation(self):
"""Test string representation."""
from exceptions import ErrorCode, KnowledgeBaseError
error = KnowledgeBaseError(
message="Test error",
code=ErrorCode.INVALID_REQUEST,
)
assert str(error) == "[KB_INVALID_REQUEST] Test error"
def test_repr_representation(self):
"""Test repr representation."""
from exceptions import ErrorCode, KnowledgeBaseError
error = KnowledgeBaseError(
message="Test error",
code=ErrorCode.INVALID_REQUEST,
details={"key": "value"},
)
repr_str = repr(error)
assert "KnowledgeBaseError" in repr_str
assert "Test error" in repr_str
assert "KB_INVALID_REQUEST" in repr_str
class TestDatabaseErrors:
"""Tests for database-related exceptions."""
def test_database_connection_error(self):
"""Test database connection error."""
from exceptions import DatabaseConnectionError, ErrorCode
error = DatabaseConnectionError(
message="Cannot connect to database",
details={"host": "localhost", "port": 5432},
)
assert error.code == ErrorCode.DATABASE_CONNECTION_ERROR
assert error.details["host"] == "localhost"
def test_database_connection_error_default_message(self):
"""Test database connection error with default message."""
from exceptions import DatabaseConnectionError
error = DatabaseConnectionError()
assert error.message == "Failed to connect to database"
def test_database_query_error(self):
"""Test database query error."""
from exceptions import DatabaseQueryError, ErrorCode
error = DatabaseQueryError(
message="Query failed",
query="SELECT * FROM missing_table",
)
assert error.code == ErrorCode.DATABASE_QUERY_ERROR
assert error.details["query"] == "SELECT * FROM missing_table"
class TestEmbeddingErrors:
"""Tests for embedding-related exceptions."""
def test_embedding_generation_error(self):
"""Test embedding generation error."""
from exceptions import EmbeddingGenerationError, ErrorCode
error = EmbeddingGenerationError(
message="Failed to generate",
texts_count=10,
)
assert error.code == ErrorCode.EMBEDDING_GENERATION_ERROR
assert error.details["texts_count"] == 10
def test_embedding_dimension_mismatch(self):
"""Test embedding dimension mismatch error."""
from exceptions import EmbeddingDimensionMismatchError, ErrorCode
error = EmbeddingDimensionMismatchError(
expected=1536,
actual=768,
)
assert error.code == ErrorCode.EMBEDDING_DIMENSION_MISMATCH
assert "expected 1536" in error.message
assert "got 768" in error.message
assert error.details["expected_dimension"] == 1536
assert error.details["actual_dimension"] == 768
class TestChunkingErrors:
"""Tests for chunking-related exceptions."""
def test_unsupported_file_type_error(self):
"""Test unsupported file type error."""
from exceptions import ErrorCode, UnsupportedFileTypeError
error = UnsupportedFileTypeError(
file_type=".xyz",
supported_types=[".py", ".js", ".md"],
)
assert error.code == ErrorCode.UNSUPPORTED_FILE_TYPE
assert error.details["file_type"] == ".xyz"
assert len(error.details["supported_types"]) == 3
def test_file_too_large_error(self):
"""Test file too large error."""
from exceptions import ErrorCode, FileTooLargeError
error = FileTooLargeError(
file_size=10_000_000,
max_size=1_000_000,
)
assert error.code == ErrorCode.FILE_TOO_LARGE
assert error.details["file_size"] == 10_000_000
assert error.details["max_size"] == 1_000_000
def test_encoding_error(self):
"""Test encoding error."""
from exceptions import EncodingError, ErrorCode
error = EncodingError(
message="Cannot decode file",
encoding="utf-8",
)
assert error.code == ErrorCode.ENCODING_ERROR
assert error.details["encoding"] == "utf-8"
class TestSearchErrors:
"""Tests for search-related exceptions."""
def test_invalid_search_type_error(self):
"""Test invalid search type error."""
from exceptions import ErrorCode, InvalidSearchTypeError
error = InvalidSearchTypeError(
search_type="invalid",
valid_types=["semantic", "keyword", "hybrid"],
)
assert error.code == ErrorCode.INVALID_SEARCH_TYPE
assert error.details["search_type"] == "invalid"
assert len(error.details["valid_types"]) == 3
def test_search_timeout_error(self):
"""Test search timeout error."""
from exceptions import ErrorCode, SearchTimeoutError
error = SearchTimeoutError(timeout=30.0)
assert error.code == ErrorCode.SEARCH_TIMEOUT
assert error.details["timeout"] == 30.0
assert "30" in error.message
class TestCollectionErrors:
"""Tests for collection-related exceptions."""
def test_collection_not_found_error(self):
"""Test collection not found error."""
from exceptions import CollectionNotFoundError, ErrorCode
error = CollectionNotFoundError(
collection="missing-collection",
project_id="proj-123",
)
assert error.code == ErrorCode.COLLECTION_NOT_FOUND
assert error.details["collection"] == "missing-collection"
assert error.details["project_id"] == "proj-123"
class TestDocumentErrors:
"""Tests for document-related exceptions."""
def test_document_not_found_error(self):
"""Test document not found error."""
from exceptions import DocumentNotFoundError, ErrorCode
error = DocumentNotFoundError(
source_path="/path/to/file.py",
project_id="proj-123",
)
assert error.code == ErrorCode.DOCUMENT_NOT_FOUND
assert error.details["source_path"] == "/path/to/file.py"
def test_invalid_document_error(self):
"""Test invalid document error."""
from exceptions import ErrorCode, InvalidDocumentError
error = InvalidDocumentError(
message="Empty content",
details={"reason": "no content"},
)
assert error.code == ErrorCode.INVALID_DOCUMENT
class TestProjectErrors:
"""Tests for project-related exceptions."""
def test_project_not_found_error(self):
"""Test project not found error."""
from exceptions import ErrorCode, ProjectNotFoundError
error = ProjectNotFoundError(project_id="missing-proj")
assert error.code == ErrorCode.PROJECT_NOT_FOUND
assert error.details["project_id"] == "missing-proj"
def test_project_access_denied_error(self):
"""Test project access denied error."""
from exceptions import ErrorCode, ProjectAccessDeniedError
error = ProjectAccessDeniedError(project_id="restricted-proj")
assert error.code == ErrorCode.PROJECT_ACCESS_DENIED
assert "restricted-proj" in error.message

View File

@@ -0,0 +1,347 @@
"""Tests for data models."""
from datetime import UTC, datetime
class TestEnums:
"""Tests for enum classes."""
def test_search_type_values(self):
"""Test SearchType enum values."""
from models import SearchType
assert SearchType.SEMANTIC.value == "semantic"
assert SearchType.KEYWORD.value == "keyword"
assert SearchType.HYBRID.value == "hybrid"
def test_chunk_type_values(self):
"""Test ChunkType enum values."""
from models import ChunkType
assert ChunkType.CODE.value == "code"
assert ChunkType.MARKDOWN.value == "markdown"
assert ChunkType.TEXT.value == "text"
assert ChunkType.DOCUMENTATION.value == "documentation"
def test_file_type_values(self):
"""Test FileType enum values."""
from models import FileType
assert FileType.PYTHON.value == "python"
assert FileType.JAVASCRIPT.value == "javascript"
assert FileType.TYPESCRIPT.value == "typescript"
assert FileType.MARKDOWN.value == "markdown"
class TestFileExtensionMap:
"""Tests for file extension mapping."""
def test_python_extensions(self):
"""Test Python file extensions."""
from models import FILE_EXTENSION_MAP, FileType
assert FILE_EXTENSION_MAP[".py"] == FileType.PYTHON
def test_javascript_extensions(self):
"""Test JavaScript file extensions."""
from models import FILE_EXTENSION_MAP, FileType
assert FILE_EXTENSION_MAP[".js"] == FileType.JAVASCRIPT
assert FILE_EXTENSION_MAP[".jsx"] == FileType.JAVASCRIPT
def test_typescript_extensions(self):
"""Test TypeScript file extensions."""
from models import FILE_EXTENSION_MAP, FileType
assert FILE_EXTENSION_MAP[".ts"] == FileType.TYPESCRIPT
assert FILE_EXTENSION_MAP[".tsx"] == FileType.TYPESCRIPT
def test_markdown_extensions(self):
"""Test Markdown file extensions."""
from models import FILE_EXTENSION_MAP, FileType
assert FILE_EXTENSION_MAP[".md"] == FileType.MARKDOWN
assert FILE_EXTENSION_MAP[".mdx"] == FileType.MARKDOWN
class TestChunk:
"""Tests for Chunk dataclass."""
def test_chunk_creation(self, sample_chunk):
"""Test chunk creation."""
from models import ChunkType, FileType
assert sample_chunk.content == "def hello():\n print('Hello')"
assert sample_chunk.chunk_type == ChunkType.CODE
assert sample_chunk.file_type == FileType.PYTHON
assert sample_chunk.source_path == "/test/hello.py"
assert sample_chunk.start_line == 1
assert sample_chunk.end_line == 2
assert sample_chunk.token_count == 15
def test_chunk_to_dict(self, sample_chunk):
"""Test chunk to_dict method."""
result = sample_chunk.to_dict()
assert result["content"] == "def hello():\n print('Hello')"
assert result["chunk_type"] == "code"
assert result["file_type"] == "python"
assert result["source_path"] == "/test/hello.py"
assert result["start_line"] == 1
assert result["end_line"] == 2
assert result["token_count"] == 15
class TestKnowledgeEmbedding:
"""Tests for KnowledgeEmbedding dataclass."""
def test_embedding_creation(self, sample_embedding):
"""Test embedding creation."""
assert sample_embedding.id == "test-id-123"
assert sample_embedding.project_id == "proj-123"
assert sample_embedding.collection == "default"
assert len(sample_embedding.embedding) == 1536
def test_embedding_to_dict(self, sample_embedding):
"""Test embedding to_dict method."""
result = sample_embedding.to_dict()
assert result["id"] == "test-id-123"
assert result["project_id"] == "proj-123"
assert result["collection"] == "default"
assert result["chunk_type"] == "code"
assert result["file_type"] == "python"
assert "embedding" not in result # Embedding excluded for size
class TestIngestRequest:
"""Tests for IngestRequest model."""
def test_ingest_request_creation(self, sample_ingest_request):
"""Test ingest request creation."""
from models import ChunkType, FileType
assert sample_ingest_request.project_id == "proj-123"
assert sample_ingest_request.agent_id == "agent-456"
assert sample_ingest_request.chunk_type == ChunkType.CODE
assert sample_ingest_request.file_type == FileType.PYTHON
assert sample_ingest_request.collection == "default"
def test_ingest_request_defaults(self):
"""Test ingest request default values."""
from models import ChunkType, IngestRequest
request = IngestRequest(
project_id="proj-123",
agent_id="agent-456",
content="test content",
)
assert request.collection == "default"
assert request.chunk_type == ChunkType.TEXT
assert request.file_type is None
assert request.metadata == {}
class TestIngestResult:
"""Tests for IngestResult model."""
def test_successful_result(self):
"""Test successful ingest result."""
from models import IngestResult
result = IngestResult(
success=True,
chunks_created=5,
embeddings_generated=5,
source_path="/test/file.py",
collection="default",
chunk_ids=["id1", "id2", "id3", "id4", "id5"],
)
assert result.success is True
assert result.chunks_created == 5
assert result.error is None
def test_failed_result(self):
"""Test failed ingest result."""
from models import IngestResult
result = IngestResult(
success=False,
chunks_created=0,
embeddings_generated=0,
collection="default",
chunk_ids=[],
error="Something went wrong",
)
assert result.success is False
assert result.error == "Something went wrong"
class TestSearchRequest:
"""Tests for SearchRequest model."""
def test_search_request_creation(self, sample_search_request):
"""Test search request creation."""
from models import SearchType
assert sample_search_request.project_id == "proj-123"
assert sample_search_request.query == "hello function"
assert sample_search_request.search_type == SearchType.HYBRID
assert sample_search_request.limit == 10
assert sample_search_request.threshold == 0.7
def test_search_request_defaults(self):
"""Test search request default values."""
from models import SearchRequest, SearchType
request = SearchRequest(
project_id="proj-123",
agent_id="agent-456",
query="test query",
)
assert request.search_type == SearchType.HYBRID
assert request.collection is None
assert request.limit == 10
assert request.threshold == 0.7
assert request.file_types is None
class TestSearchResult:
"""Tests for SearchResult model."""
def test_from_embedding(self, sample_embedding):
"""Test creating SearchResult from KnowledgeEmbedding."""
from models import SearchResult
result = SearchResult.from_embedding(sample_embedding, 0.95)
assert result.id == "test-id-123"
assert result.content == "def hello():\n print('Hello')"
assert result.score == 0.95
assert result.source_path == "/test/hello.py"
assert result.chunk_type == "code"
assert result.file_type == "python"
class TestSearchResponse:
"""Tests for SearchResponse model."""
def test_search_response(self):
"""Test search response creation."""
from models import SearchResponse, SearchResult
results = [
SearchResult(
id="id1",
content="test content 1",
score=0.95,
chunk_type="code",
collection="default",
),
SearchResult(
id="id2",
content="test content 2",
score=0.85,
chunk_type="text",
collection="default",
),
]
response = SearchResponse(
query="test query",
search_type="hybrid",
results=results,
total_results=2,
search_time_ms=15.5,
)
assert response.query == "test query"
assert len(response.results) == 2
assert response.search_time_ms == 15.5
class TestDeleteRequest:
"""Tests for DeleteRequest model."""
def test_delete_by_source(self, sample_delete_request):
"""Test delete request by source path."""
assert sample_delete_request.project_id == "proj-123"
assert sample_delete_request.source_path == "/test/hello.py"
assert sample_delete_request.collection is None
assert sample_delete_request.chunk_ids is None
def test_delete_by_collection(self):
"""Test delete request by collection."""
from models import DeleteRequest
request = DeleteRequest(
project_id="proj-123",
agent_id="agent-456",
collection="to-delete",
)
assert request.collection == "to-delete"
assert request.source_path is None
def test_delete_by_ids(self):
"""Test delete request by chunk IDs."""
from models import DeleteRequest
request = DeleteRequest(
project_id="proj-123",
agent_id="agent-456",
chunk_ids=["id1", "id2", "id3"],
)
assert len(request.chunk_ids) == 3
class TestCollectionInfo:
"""Tests for CollectionInfo model."""
def test_collection_info(self):
"""Test collection info creation."""
from models import CollectionInfo
info = CollectionInfo(
name="test-collection",
project_id="proj-123",
chunk_count=100,
total_tokens=50000,
file_types=["python", "javascript"],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
assert info.name == "test-collection"
assert info.chunk_count == 100
assert len(info.file_types) == 2
class TestCollectionStats:
"""Tests for CollectionStats model."""
def test_collection_stats(self):
"""Test collection stats creation."""
from models import CollectionStats
stats = CollectionStats(
collection="test-collection",
project_id="proj-123",
chunk_count=100,
unique_sources=10,
total_tokens=50000,
avg_chunk_size=500.0,
chunk_types={"code": 60, "text": 40},
file_types={"python": 50, "javascript": 10},
oldest_chunk=datetime.now(UTC),
newest_chunk=datetime.now(UTC),
)
assert stats.chunk_count == 100
assert stats.unique_sources == 10
assert stats.chunk_types["code"] == 60

View File

@@ -0,0 +1,295 @@
"""Tests for search module."""
from datetime import UTC, datetime
import pytest
class TestSearchEngine:
"""Tests for SearchEngine class."""
@pytest.fixture
def search_engine(self, settings, mock_database, mock_embeddings):
"""Create search engine with mocks."""
from search import SearchEngine
engine = SearchEngine(
settings=settings,
database=mock_database,
embeddings=mock_embeddings,
)
return engine
@pytest.fixture
def sample_db_results(self):
"""Create sample database results."""
from models import ChunkType, FileType, KnowledgeEmbedding
return [
(
KnowledgeEmbedding(
id="id-1",
project_id="proj-123",
collection="default",
content="def hello(): pass",
embedding=[0.1] * 1536,
chunk_type=ChunkType.CODE,
source_path="/test/file.py",
file_type=FileType.PYTHON,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
),
0.95,
),
(
KnowledgeEmbedding(
id="id-2",
project_id="proj-123",
collection="default",
content="def world(): pass",
embedding=[0.2] * 1536,
chunk_type=ChunkType.CODE,
source_path="/test/file2.py",
file_type=FileType.PYTHON,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
),
0.85,
),
]
@pytest.mark.asyncio
async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
"""Test semantic search."""
from models import SearchType
sample_search_request.search_type = SearchType.SEMANTIC
search_engine._database.semantic_search.return_value = sample_db_results
response = await search_engine.search(sample_search_request)
assert response.search_type == "semantic"
assert len(response.results) == 2
assert response.results[0].score == 0.95
search_engine._database.semantic_search.assert_called_once()
@pytest.mark.asyncio
async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
"""Test keyword search."""
from models import SearchType
sample_search_request.search_type = SearchType.KEYWORD
search_engine._database.keyword_search.return_value = sample_db_results
response = await search_engine.search(sample_search_request)
assert response.search_type == "keyword"
assert len(response.results) == 2
search_engine._database.keyword_search.assert_called_once()
@pytest.mark.asyncio
async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
"""Test hybrid search."""
from models import SearchType
sample_search_request.search_type = SearchType.HYBRID
# Both searches return same results for simplicity
search_engine._database.semantic_search.return_value = sample_db_results
search_engine._database.keyword_search.return_value = sample_db_results
response = await search_engine.search(sample_search_request)
assert response.search_type == "hybrid"
# Results should be fused
assert len(response.results) >= 1
@pytest.mark.asyncio
async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
"""Test search with collection filter."""
from models import SearchType
sample_search_request.search_type = SearchType.SEMANTIC
sample_search_request.collection = "specific-collection"
search_engine._database.semantic_search.return_value = sample_db_results
await search_engine.search(sample_search_request)
# Verify collection was passed to database
call_args = search_engine._database.semantic_search.call_args
assert call_args.kwargs["collection"] == "specific-collection"
@pytest.mark.asyncio
async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
"""Test search with file type filter."""
from models import FileType, SearchType
sample_search_request.search_type = SearchType.SEMANTIC
sample_search_request.file_types = [FileType.PYTHON]
search_engine._database.semantic_search.return_value = sample_db_results
await search_engine.search(sample_search_request)
# Verify file types were passed to database
call_args = search_engine._database.semantic_search.call_args
assert call_args.kwargs["file_types"] == [FileType.PYTHON]
@pytest.mark.asyncio
async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
"""Test that search respects result limit."""
from models import SearchType
sample_search_request.search_type = SearchType.SEMANTIC
sample_search_request.limit = 1
search_engine._database.semantic_search.return_value = sample_db_results[:1]
response = await search_engine.search(sample_search_request)
assert len(response.results) <= 1
@pytest.mark.asyncio
async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
"""Test that search records time."""
from models import SearchType
sample_search_request.search_type = SearchType.SEMANTIC
search_engine._database.semantic_search.return_value = sample_db_results
response = await search_engine.search(sample_search_request)
assert response.search_time_ms > 0
@pytest.mark.asyncio
async def test_invalid_search_type(self, search_engine, sample_search_request):
"""Test handling invalid search type."""
from exceptions import InvalidSearchTypeError
# Force invalid search type
sample_search_request.search_type = "invalid"
with pytest.raises((InvalidSearchTypeError, ValueError)):
await search_engine.search(sample_search_request)
@pytest.mark.asyncio
async def test_empty_results(self, search_engine, sample_search_request):
"""Test search with no results."""
from models import SearchType
sample_search_request.search_type = SearchType.SEMANTIC
search_engine._database.semantic_search.return_value = []
response = await search_engine.search(sample_search_request)
assert len(response.results) == 0
assert response.total_results == 0
class TestReciprocalRankFusion:
"""Tests for reciprocal rank fusion."""
@pytest.fixture
def search_engine(self, settings, mock_database, mock_embeddings):
"""Create search engine with mocks."""
from search import SearchEngine
return SearchEngine(
settings=settings,
database=mock_database,
embeddings=mock_embeddings,
)
def test_fusion_combines_results(self, search_engine):
"""Test that RRF combines results from both searches."""
from models import SearchResult
semantic = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
]
keyword = [
SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
]
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
# Should have all unique results
ids = [r.id for r in fused]
assert "a" in ids
assert "b" in ids
assert "c" in ids
# B should be ranked higher (appears in both)
b_rank = ids.index("b")
assert b_rank < 2 # Should be in top 2
def test_fusion_respects_weights(self, search_engine):
"""Test that RRF respects semantic/keyword weights."""
from models import SearchResult
# Same results in same order
results = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
]
# High semantic weight
fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
results, [],
semantic_weight=0.9,
keyword_weight=0.1,
)
# High keyword weight
fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
[], results,
semantic_weight=0.1,
keyword_weight=0.9,
)
# Both should still return the result
assert len(fused_semantic_heavy) == 1
assert len(fused_keyword_heavy) == 1
def test_fusion_normalizes_scores(self, search_engine):
"""Test that RRF normalizes scores to 0-1."""
from models import SearchResult
semantic = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
]
keyword = [
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
]
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
# All scores should be between 0 and 1
for result in fused:
assert 0 <= result.score <= 1
class TestGlobalSearchEngine:
"""Tests for global search engine."""
def test_get_search_engine_singleton(self):
"""Test that get_search_engine returns singleton."""
from search import get_search_engine, reset_search_engine
reset_search_engine()
engine1 = get_search_engine()
engine2 = get_search_engine()
assert engine1 is engine2
def test_reset_search_engine(self):
"""Test resetting search engine."""
from search import get_search_engine, reset_search_engine
engine1 = get_search_engine()
reset_search_engine()
engine2 = get_search_engine()
assert engine1 is not engine2

View File

@@ -0,0 +1,357 @@
"""Tests for server module and MCP tools."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
class TestHealthCheck:
"""Tests for health check endpoint."""
@pytest.mark.asyncio
async def test_health_check_healthy(self):
"""Test health check when healthy."""
import server
# Create a proper async context manager mock
mock_conn = AsyncMock()
mock_conn.fetchval = AsyncMock(return_value=1)
mock_db = MagicMock()
mock_db._pool = MagicMock()
# Make acquire an async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = mock_conn
mock_cm.__aexit__.return_value = None
mock_db.acquire.return_value = mock_cm
server._database = mock_db
result = await server.health_check()
assert result["status"] == "healthy"
assert result["service"] == "knowledge-base"
assert result["database"] == "connected"
@pytest.mark.asyncio
async def test_health_check_no_database(self):
"""Test health check without database."""
import server
server._database = None
result = await server.health_check()
assert result["database"] == "not initialized"
class TestSearchKnowledgeTool:
"""Tests for search_knowledge MCP tool."""
@pytest.mark.asyncio
async def test_search_success(self):
"""Test successful search."""
import server
from models import SearchResponse, SearchResult
mock_search = MagicMock()
mock_search.search = AsyncMock(
return_value=SearchResponse(
query="test query",
search_type="hybrid",
results=[
SearchResult(
id="id-1",
content="Test content",
score=0.95,
source_path="/test/file.py",
chunk_type="code",
collection="default",
)
],
total_results=1,
search_time_ms=10.5,
)
)
server._search = mock_search
# Call the wrapped function via .fn
result = await server.search_knowledge.fn(
project_id="proj-123",
agent_id="agent-456",
query="test query",
search_type="hybrid",
collection=None,
limit=10,
threshold=0.7,
file_types=None,
)
assert result["success"] is True
assert len(result["results"]) == 1
assert result["results"][0]["score"] == 0.95
@pytest.mark.asyncio
async def test_search_invalid_type(self):
"""Test search with invalid search type."""
import server
result = await server.search_knowledge.fn(
project_id="proj-123",
agent_id="agent-456",
query="test",
search_type="invalid",
)
assert result["success"] is False
assert "Invalid search type" in result["error"]
@pytest.mark.asyncio
async def test_search_invalid_file_type(self):
"""Test search with invalid file type."""
import server
result = await server.search_knowledge.fn(
project_id="proj-123",
agent_id="agent-456",
query="test",
search_type="hybrid",
collection=None,
limit=10,
threshold=0.7,
file_types=["invalid_type"],
)
assert result["success"] is False
assert "Invalid file type" in result["error"]
class TestIngestContentTool:
"""Tests for ingest_content MCP tool."""
@pytest.mark.asyncio
async def test_ingest_success(self):
"""Test successful ingestion."""
import server
from models import IngestResult
mock_collections = MagicMock()
mock_collections.ingest = AsyncMock(
return_value=IngestResult(
success=True,
chunks_created=3,
embeddings_generated=3,
source_path="/test/file.py",
collection="default",
chunk_ids=["id-1", "id-2", "id-3"],
)
)
server._collections = mock_collections
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content="def hello(): pass",
source_path="/test/file.py",
collection="default",
chunk_type="text",
file_type=None,
metadata=None,
)
assert result["success"] is True
assert result["chunks_created"] == 3
assert len(result["chunk_ids"]) == 3
@pytest.mark.asyncio
async def test_ingest_invalid_chunk_type(self):
"""Test ingest with invalid chunk type."""
import server
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content="test content",
chunk_type="invalid",
)
assert result["success"] is False
assert "Invalid chunk type" in result["error"]
@pytest.mark.asyncio
async def test_ingest_invalid_file_type(self):
"""Test ingest with invalid file type."""
import server
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content="test content",
source_path=None,
collection="default",
chunk_type="text",
file_type="invalid",
metadata=None,
)
assert result["success"] is False
assert "Invalid file type" in result["error"]
class TestDeleteContentTool:
"""Tests for delete_content MCP tool."""
@pytest.mark.asyncio
async def test_delete_success(self):
"""Test successful deletion."""
import server
from models import DeleteResult
mock_collections = MagicMock()
mock_collections.delete = AsyncMock(
return_value=DeleteResult(
success=True,
chunks_deleted=5,
)
)
server._collections = mock_collections
result = await server.delete_content.fn(
project_id="proj-123",
agent_id="agent-456",
source_path="/test/file.py",
collection=None,
chunk_ids=None,
)
assert result["success"] is True
assert result["chunks_deleted"] == 5
class TestListCollectionsTool:
"""Tests for list_collections MCP tool."""
@pytest.mark.asyncio
async def test_list_collections_success(self):
"""Test listing collections."""
import server
from models import CollectionInfo, ListCollectionsResponse
mock_collections = MagicMock()
mock_collections.list_collections = AsyncMock(
return_value=ListCollectionsResponse(
project_id="proj-123",
collections=[
CollectionInfo(
name="collection-1",
project_id="proj-123",
chunk_count=100,
total_tokens=50000,
file_types=["python"],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
],
total_collections=1,
)
)
server._collections = mock_collections
result = await server.list_collections.fn(
project_id="proj-123",
agent_id="agent-456",
)
assert result["success"] is True
assert result["total_collections"] == 1
assert len(result["collections"]) == 1
class TestGetCollectionStatsTool:
"""Tests for get_collection_stats MCP tool."""
@pytest.mark.asyncio
async def test_get_stats_success(self):
"""Test getting collection stats."""
import server
from models import CollectionStats
mock_collections = MagicMock()
mock_collections.get_collection_stats = AsyncMock(
return_value=CollectionStats(
collection="test-collection",
project_id="proj-123",
chunk_count=100,
unique_sources=10,
total_tokens=50000,
avg_chunk_size=500.0,
chunk_types={"code": 60, "text": 40},
file_types={"python": 50, "javascript": 10},
)
)
server._collections = mock_collections
result = await server.get_collection_stats.fn(
project_id="proj-123",
agent_id="agent-456",
collection="test-collection",
)
assert result["success"] is True
assert result["chunk_count"] == 100
assert result["unique_sources"] == 10
class TestUpdateDocumentTool:
"""Tests for update_document MCP tool."""
@pytest.mark.asyncio
async def test_update_success(self):
"""Test updating a document."""
import server
from models import IngestResult
mock_collections = MagicMock()
mock_collections.update_document = AsyncMock(
return_value=IngestResult(
success=True,
chunks_created=2,
embeddings_generated=2,
source_path="/test/file.py",
collection="default",
chunk_ids=["id-1", "id-2"],
)
)
server._collections = mock_collections
result = await server.update_document.fn(
project_id="proj-123",
agent_id="agent-456",
source_path="/test/file.py",
content="def updated(): pass",
collection="default",
chunk_type="text",
file_type=None,
metadata=None,
)
assert result["success"] is True
assert result["chunks_created"] == 2
@pytest.mark.asyncio
async def test_update_invalid_chunk_type(self):
"""Test update with invalid chunk type."""
import server
result = await server.update_document.fn(
project_id="proj-123",
agent_id="agent-456",
source_path="/test/file.py",
content="test",
chunk_type="invalid",
)
assert result["success"] is False
assert "Invalid chunk type" in result["error"]