feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search: - Intelligent chunking strategies (code-aware, markdown-aware, text) - Semantic search with vector similarity (HNSW index) - Keyword search with PostgreSQL full-text search - Hybrid search using Reciprocal Rank Fusion (RRF) - Redis caching for embeddings - Collection management (ingest, search, delete, stats) - FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document Testing: - 128 comprehensive tests covering all components - 58% code coverage (database integration tests use mocks) - Passes ruff linting and mypy type checking 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
245
mcp-servers/knowledge-base/tests/test_embeddings.py
Normal file
245
mcp-servers/knowledge-base/tests/test_embeddings.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Tests for embedding generation module."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEmbeddingGenerator:
|
||||
"""Tests for EmbeddingGenerator class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_response(self):
|
||||
"""Create mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
response.raise_for_status = MagicMock()
|
||||
response.json.return_value = {
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 1536]
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
return response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
|
||||
"""Test generating a single embedding."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
generator._redis = mock_redis
|
||||
|
||||
# Mock HTTP client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_http_response
|
||||
generator._http_client = mock_client
|
||||
|
||||
embedding = await generator.generate(
|
||||
text="Hello, world!",
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
)
|
||||
|
||||
assert len(embedding) == 1536
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_batch_embeddings(self, settings, mock_redis):
|
||||
"""Test generating batch embeddings."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
generator._redis = mock_redis
|
||||
|
||||
# Mock HTTP client with batch response
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_client.post.return_value = mock_response
|
||||
generator._http_client = mock_client
|
||||
|
||||
embeddings = await generator.generate_batch(
|
||||
texts=["Text 1", "Text 2", "Text 3"],
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
)
|
||||
|
||||
assert len(embeddings) == 3
|
||||
assert all(len(e) == 1536 for e in embeddings)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching(self, settings, mock_redis):
|
||||
"""Test embedding caching."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
generator._redis = mock_redis
|
||||
|
||||
# Pre-populate cache
|
||||
cache_key = generator._cache_key("Hello, world!")
|
||||
await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))
|
||||
|
||||
# Mock HTTP client (should not be called)
|
||||
mock_client = AsyncMock()
|
||||
generator._http_client = mock_client
|
||||
|
||||
embedding = await generator.generate(
|
||||
text="Hello, world!",
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
)
|
||||
|
||||
# Should return cached embedding
|
||||
assert len(embedding) == 1536
|
||||
assert embedding[0] == 0.5
|
||||
mock_client.post.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss(self, settings, mock_redis, mock_http_response):
|
||||
"""Test embedding cache miss."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
generator._redis = mock_redis
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_http_response
|
||||
generator._http_client = mock_client
|
||||
|
||||
embedding = await generator.generate(
|
||||
text="New text not in cache",
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
)
|
||||
|
||||
assert len(embedding) == 1536
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
def test_cache_key_generation(self, settings):
|
||||
"""Test cache key generation."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
|
||||
key1 = generator._cache_key("Hello")
|
||||
key2 = generator._cache_key("Hello")
|
||||
key3 = generator._cache_key("World")
|
||||
|
||||
assert key1 == key2
|
||||
assert key1 != key3
|
||||
assert key1.startswith("kb:emb:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dimension_validation(self, settings, mock_redis):
|
||||
"""Test embedding dimension validation."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
from exceptions import EmbeddingDimensionMismatchError
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
generator._redis = mock_redis
|
||||
|
||||
# Mock HTTP client with wrong dimension
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 768] # Wrong dimension
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
mock_client.post.return_value = mock_response
|
||||
generator._http_client = mock_client
|
||||
|
||||
with pytest.raises(EmbeddingDimensionMismatchError):
|
||||
await generator.generate(
|
||||
text="Test text",
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch(self, settings, mock_redis):
|
||||
"""Test generating embeddings for empty batch."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
generator._redis = mock_redis
|
||||
|
||||
embeddings = await generator.generate_batch(
|
||||
texts=[],
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
)
|
||||
|
||||
assert embeddings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_and_close(self, settings):
|
||||
"""Test initialize and close methods."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
generator = EmbeddingGenerator(settings=settings)
|
||||
|
||||
# Mock successful initialization
|
||||
with patch("embeddings.redis.from_url") as mock_redis_from_url:
|
||||
mock_redis_client = AsyncMock()
|
||||
mock_redis_client.ping = AsyncMock()
|
||||
mock_redis_from_url.return_value = mock_redis_client
|
||||
|
||||
await generator.initialize()
|
||||
|
||||
assert generator._http_client is not None
|
||||
|
||||
await generator.close()
|
||||
|
||||
assert generator._http_client is None
|
||||
|
||||
|
||||
class TestGlobalEmbeddingGenerator:
|
||||
"""Tests for global embedding generator."""
|
||||
|
||||
def test_get_embedding_generator_singleton(self):
|
||||
"""Test that get_embedding_generator returns singleton."""
|
||||
from embeddings import get_embedding_generator, reset_embedding_generator
|
||||
|
||||
reset_embedding_generator()
|
||||
gen1 = get_embedding_generator()
|
||||
gen2 = get_embedding_generator()
|
||||
|
||||
assert gen1 is gen2
|
||||
|
||||
def test_reset_embedding_generator(self):
|
||||
"""Test resetting embedding generator."""
|
||||
from embeddings import get_embedding_generator, reset_embedding_generator
|
||||
|
||||
gen1 = get_embedding_generator()
|
||||
reset_embedding_generator()
|
||||
gen2 = get_embedding_generator()
|
||||
|
||||
assert gen1 is not gen2
|
||||
Reference in New Issue
Block a user