Files
syndarix/mcp-servers/knowledge-base/tests/test_embeddings.py
Felipe Cardoso d0fc7f37ff feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search:

- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content,
  list_collections, get_collection_stats, update_document

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 21:33:26 +01:00

246 lines
7.7 KiB
Python

"""Tests for embedding generation module."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestEmbeddingGenerator:
    """Tests for EmbeddingGenerator class."""

    @staticmethod
    def _make_mock_response(embeddings):
        """Build a mock HTTP response whose payload carries ``embeddings``.

        The envelope mirrors what these tests feed the generator:
        ``result.content[0].text`` is a JSON string with an ``embeddings`` list.
        """
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {
            "result": {
                "content": [
                    {"text": json.dumps({"embeddings": embeddings})},
                ]
            }
        }
        return response

    @staticmethod
    def _make_generator(settings, redis_client=None):
        """Create an EmbeddingGenerator, optionally wiring in a Redis client."""
        # Imported lazily, as in the original tests, so module import happens
        # after test configuration is in place.
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        if redis_client is not None:
            generator._redis = redis_client
        return generator

    @staticmethod
    def _attach_http_client(generator, response=None):
        """Swap in an AsyncMock HTTP client; ``post`` returns ``response`` if given."""
        mock_client = AsyncMock()
        if response is not None:
            mock_client.post.return_value = response
        generator._http_client = mock_client
        return mock_client

    @pytest.fixture
    def mock_http_response(self):
        """Create mock HTTP response with a single 1536-dim embedding."""
        return self._make_mock_response([[0.1] * 1536])

    @pytest.mark.asyncio
    async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
        """Test generating a single embedding."""
        generator = self._make_generator(settings, mock_redis)
        mock_client = self._attach_http_client(generator, mock_http_response)

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Test generating batch embeddings."""
        generator = self._make_generator(settings, mock_redis)
        self._attach_http_client(
            generator,
            self._make_mock_response([[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]),
        )

        results = await generator.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(results) == 3
        assert all(len(e) == 1536 for e in results)
        # Each result must line up with its input text, in order.
        assert [e[0] for e in results] == [0.1, 0.2, 0.3]

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """Test embedding caching."""
        generator = self._make_generator(settings, mock_redis)

        # Pre-populate cache
        cache_key = generator._cache_key("Hello, world!")
        await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))

        # Mock HTTP client (should not be called)
        mock_client = self._attach_http_client(generator)

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        # Should return cached embedding
        assert len(embedding) == 1536
        assert embedding[0] == 0.5
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Test embedding cache miss."""
        generator = self._make_generator(settings, mock_redis)
        mock_client = self._attach_http_client(generator, mock_http_response)

        embedding = await generator.generate(
            text="New text not in cache",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Test cache key generation."""
        generator = self._make_generator(settings)

        key1 = generator._cache_key("Hello")
        key2 = generator._cache_key("Hello")
        key3 = generator._cache_key("World")

        assert key1 == key2
        assert key1 != key3
        assert key1.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """Test embedding dimension validation."""
        from exceptions import EmbeddingDimensionMismatchError

        generator = self._make_generator(settings, mock_redis)
        # Mock HTTP client with wrong dimension
        self._attach_http_client(
            generator, self._make_mock_response([[0.1] * 768])
        )

        with pytest.raises(EmbeddingDimensionMismatchError):
            await generator.generate(
                text="Test text",
                project_id="proj-123",
                agent_id="agent-456",
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """Test generating embeddings for empty batch."""
        generator = self._make_generator(settings, mock_redis)

        results = await generator.generate_batch(
            texts=[],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert results == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """Test initialize and close methods."""
        generator = self._make_generator(settings)

        # Mock successful initialization
        with patch("embeddings.redis.from_url") as mock_redis_from_url:
            mock_redis_client = AsyncMock()
            mock_redis_client.ping = AsyncMock()
            mock_redis_from_url.return_value = mock_redis_client

            await generator.initialize()
            try:
                assert generator._http_client is not None
            finally:
                # Release clients even if the assertion above fails.
                await generator.close()
            assert generator._http_client is None
class TestGlobalEmbeddingGenerator:
    """Tests for global embedding generator."""

    @pytest.fixture(autouse=True)
    def _isolate_global_generator(self):
        """Reset the module-level generator before and after each test.

        get_embedding_generator() keeps returning the same object until
        reset_embedding_generator() is called. Without this fixture, whatever
        an earlier test left behind leaks in here, and the instance created
        here leaks into later tests.
        """
        from embeddings import reset_embedding_generator

        reset_embedding_generator()
        yield
        reset_embedding_generator()

    def test_get_embedding_generator_singleton(self):
        """Test that get_embedding_generator returns singleton."""
        from embeddings import get_embedding_generator

        gen1 = get_embedding_generator()
        gen2 = get_embedding_generator()

        assert gen1 is gen2

    def test_reset_embedding_generator(self):
        """Test resetting embedding generator."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        gen1 = get_embedding_generator()
        reset_embedding_generator()
        gen2 = get_embedding_generator()

        assert gen1 is not gen2