forked from cardosofelipe/fast-next-template
Improved code formatting, line breaks, and indentation across chunking logic and multiple test modules to enhance code clarity and maintain consistent style. No functional changes made.
244 lines
7.6 KiB
Python
244 lines
7.6 KiB
Python
"""Tests for embedding generation module."""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestEmbeddingGenerator:
    """Unit tests for ``EmbeddingGenerator``: HTTP calls, caching, validation, lifecycle."""

    @staticmethod
    def _mcp_response(vectors):
        """Build a mocked HTTP 200 response whose MCP result wraps *vectors*.

        The JSON body has the shape
        ``{"result": {"content": [{"text": <json of {"embeddings": vectors}>}]}}``,
        which is the payload these tests feed to the generator.
        """
        fake = MagicMock()
        fake.status_code = 200
        fake.raise_for_status = MagicMock()
        body_text = json.dumps({"embeddings": vectors})
        fake.json.return_value = {"result": {"content": [{"text": body_text}]}}
        return fake

    @staticmethod
    def _client_returning(response):
        """Return an ``AsyncMock`` HTTP client whose awaited ``post`` yields *response*."""
        client = AsyncMock()
        client.post.return_value = response
        return client

    @pytest.fixture
    def mock_http_response(self):
        """Mocked HTTP response carrying a single 1536-element embedding."""
        return self._mcp_response([[0.1] * 1536])

    @pytest.mark.asyncio
    async def test_generate_single_embedding(
        self, settings, mock_redis, mock_http_response
    ):
        """One input text produces one 1536-element vector via exactly one POST."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis
        http = self._client_returning(mock_http_response)
        gen._http_client = http

        vector = await gen.generate(
            text="Hello, world!", project_id="proj-123", agent_id="agent-456"
        )

        assert len(vector) == 1536
        http.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Three input texts produce three 1536-element vectors."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        batch_vectors = [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
        gen._http_client = self._client_returning(self._mcp_response(batch_vectors))

        result = await gen.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(result) == 3
        assert all(len(vec) == 1536 for vec in result)

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """A pre-seeded cache entry is returned without touching the HTTP client."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        # Seed the cache under the generator's own key for this text.
        key = gen._cache_key("Hello, world!")
        await mock_redis.setex(key, 3600, json.dumps([0.5] * 1536))

        # Bare client: any call to it would show the cache was bypassed.
        http = AsyncMock()
        gen._http_client = http

        vector = await gen.generate(
            text="Hello, world!", project_id="proj-123", agent_id="agent-456"
        )

        assert len(vector) == 1536
        assert vector[0] == 0.5
        http.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Uncached text falls through to a single HTTP POST."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis
        http = self._client_returning(mock_http_response)
        gen._http_client = http

        vector = await gen.generate(
            text="New text not in cache", project_id="proj-123", agent_id="agent-456"
        )

        assert len(vector) == 1536
        http.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Cache keys are deterministic per text, distinct across texts, and namespaced."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)

        first, again, other = [gen._cache_key(t) for t in ("Hello", "Hello", "World")]

        assert first == again
        assert first != other
        assert first.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """A vector of unexpected length raises ``EmbeddingDimensionMismatchError``."""
        from embeddings import EmbeddingGenerator
        from exceptions import EmbeddingDimensionMismatchError

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        # 768 elements, while the other tests in this class expect 1536.
        gen._http_client = self._client_returning(self._mcp_response([[0.1] * 768]))

        with pytest.raises(EmbeddingDimensionMismatchError):
            await gen.generate(
                text="Test text", project_id="proj-123", agent_id="agent-456"
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """An empty batch yields an empty list."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)
        gen._redis = mock_redis

        result = await gen.generate_batch(
            texts=[], project_id="proj-123", agent_id="agent-456"
        )

        assert result == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """``initialize`` creates the HTTP client and ``close`` clears it again."""
        from embeddings import EmbeddingGenerator

        gen = EmbeddingGenerator(settings=settings)

        # Stand in for the Redis connection so initialization succeeds offline.
        with patch("embeddings.redis.from_url") as from_url:
            fake_redis = AsyncMock()
            fake_redis.ping = AsyncMock()
            from_url.return_value = fake_redis

            await gen.initialize()
            assert gen._http_client is not None

            await gen.close()
            assert gen._http_client is None
|
|
|
|
|
|
class TestGlobalEmbeddingGenerator:
    """Tests for the module-level ``get_embedding_generator`` singleton."""

    @pytest.fixture(autouse=True)
    def _isolate_global_generator(self):
        """Reset the module-wide generator before and after every test here.

        Both tests replace the module's singleton. Without the teardown reset,
        the instance a test leaves behind stays the global one and leaks into
        whatever test runs next. The setup reset also keeps these tests
        independent of earlier ones.
        """
        from embeddings import reset_embedding_generator

        reset_embedding_generator()
        yield
        reset_embedding_generator()

    def test_get_embedding_generator_singleton(self):
        """Repeated calls to ``get_embedding_generator`` return the same instance."""
        from embeddings import get_embedding_generator

        first = get_embedding_generator()
        second = get_embedding_generator()

        assert first is second

    def test_reset_embedding_generator(self):
        """After ``reset_embedding_generator`` the next call yields a new instance."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        before = get_embedding_generator()
        reset_embedding_generator()
        after = get_embedding_generator()

        assert before is not after
|