Files
syndarix/mcp-servers/knowledge-base/tests/test_embeddings.py
Felipe Cardoso 51404216ae refactor(knowledge-base mcp server): adjust formatting for consistency and readability
Improved code formatting, line breaks, and indentation across chunking logic and multiple test modules to enhance code clarity and maintain consistent style. No functional changes made.
2026-01-06 17:20:31 +01:00

244 lines
7.6 KiB
Python

"""Tests for embedding generation module."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestEmbeddingGenerator:
    """Tests for EmbeddingGenerator class."""

    @staticmethod
    def _mcp_embedding_response(embeddings):
        """Build a mock HTTP response carrying ``embeddings`` in an MCP envelope.

        The body mirrors the shape every test in this class feeds the generator:
        ``{"result": {"content": [{"text": json.dumps({"embeddings": ...})}]}}``.

        Args:
            embeddings: List of embedding vectors to serialize into the body.

        Returns:
            A ``MagicMock`` response with ``status_code`` 200, a no-op
            ``raise_for_status`` and ``json()`` returning the envelope.
        """
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {
            "result": {
                "content": [{"text": json.dumps({"embeddings": embeddings})}]
            }
        }
        return response

    @pytest.fixture
    def mock_http_response(self):
        """Create mock HTTP response holding one 1536-dimensional embedding."""
        return self._mcp_embedding_response([[0.1] * 1536])

    @pytest.mark.asyncio
    async def test_generate_single_embedding(
        self, settings, mock_redis, mock_http_response
    ):
        """Test generating a single embedding."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client
        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Test generating batch embeddings."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client with batch response: a distinct vector per text so
        # that mixed-up or duplicated results are detectable.
        mock_client = AsyncMock()
        mock_client.post.return_value = self._mcp_embedding_response(
            [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
        )
        generator._http_client = mock_client

        embeddings = await generator.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embeddings) == 3
        assert all(len(e) == 1536 for e in embeddings)
        # Each input text must receive its own vector, in input order.
        assert [e[0] for e in embeddings] == pytest.approx([0.1, 0.2, 0.3])

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """Test embedding caching."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Pre-populate cache
        cache_key = generator._cache_key("Hello, world!")
        await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536))

        # Mock HTTP client (should not be called)
        mock_client = AsyncMock()
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        # Should return cached embedding
        assert len(embedding) == 1536
        assert embedding[0] == 0.5
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Test embedding cache miss."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="New text not in cache",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == 1536
        mock_client.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Test cache key generation."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)

        key1 = generator._cache_key("Hello")
        key2 = generator._cache_key("Hello")
        key3 = generator._cache_key("World")

        assert key1 == key2
        assert key1 != key3
        assert key1.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """Test embedding dimension validation."""
        from embeddings import EmbeddingGenerator
        from exceptions import EmbeddingDimensionMismatchError

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client with wrong dimension (768 instead of 1536)
        mock_client = AsyncMock()
        mock_client.post.return_value = self._mcp_embedding_response(
            [[0.1] * 768]
        )
        generator._http_client = mock_client

        with pytest.raises(EmbeddingDimensionMismatchError):
            await generator.generate(
                text="Test text",
                project_id="proj-123",
                agent_id="agent-456",
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """Test generating embeddings for empty batch."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        embeddings = await generator.generate_batch(
            texts=[],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert embeddings == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """Test initialize and close methods."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)

        # Mock successful initialization
        with patch("embeddings.redis.from_url") as mock_redis_from_url:
            mock_redis_client = AsyncMock()
            mock_redis_client.ping = AsyncMock()
            mock_redis_from_url.return_value = mock_redis_client

            await generator.initialize()
            try:
                assert generator._http_client is not None
            finally:
                # Release the client created by initialize() even when the
                # assertion above fails, so it does not leak across tests.
                await generator.close()

        assert generator._http_client is None
class TestGlobalEmbeddingGenerator:
    """Tests for global embedding generator."""

    @pytest.fixture(autouse=True)
    def _isolate_global_generator(self):
        """Reset the module-level embedding generator around every test.

        get_embedding_generator() hands out a shared instance. Without this
        reset, an instance created by these tests would leak into later tests,
        and state left by earlier tests would leak in, so results would depend
        on test ordering.
        """
        from embeddings import reset_embedding_generator

        reset_embedding_generator()
        yield
        reset_embedding_generator()

    def test_get_embedding_generator_singleton(self):
        """Test that get_embedding_generator returns singleton."""
        from embeddings import get_embedding_generator

        gen1 = get_embedding_generator()
        gen2 = get_embedding_generator()

        assert gen1 is gen2

    def test_reset_embedding_generator(self):
        """Test resetting embedding generator."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        gen1 = get_embedding_generator()
        reset_embedding_generator()
        gen2 = get_embedding_generator()

        assert gen1 is not gen2