"""Tests for embedding generation module.""" import json from unittest.mock import AsyncMock, MagicMock, patch import pytest class TestEmbeddingGenerator: """Tests for EmbeddingGenerator class.""" @pytest.fixture def mock_http_response(self): """Create mock HTTP response.""" response = MagicMock() response.status_code = 200 response.raise_for_status = MagicMock() response.json.return_value = { "result": { "content": [{"text": json.dumps({"embeddings": [[0.1] * 1536]})}] } } return response @pytest.mark.asyncio async def test_generate_single_embedding( self, settings, mock_redis, mock_http_response ): """Test generating a single embedding.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) generator._redis = mock_redis # Mock HTTP client mock_client = AsyncMock() mock_client.post.return_value = mock_http_response generator._http_client = mock_client embedding = await generator.generate( text="Hello, world!", project_id="proj-123", agent_id="agent-456", ) assert len(embedding) == 1536 mock_client.post.assert_called_once() @pytest.mark.asyncio async def test_generate_batch_embeddings(self, settings, mock_redis): """Test generating batch embeddings.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) generator._redis = mock_redis # Mock HTTP client with batch response mock_client = AsyncMock() mock_response = MagicMock() mock_response.status_code = 200 mock_response.raise_for_status = MagicMock() mock_response.json.return_value = { "result": { "content": [ { "text": json.dumps( {"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]} ) } ] } } mock_client.post.return_value = mock_response generator._http_client = mock_client embeddings = await generator.generate_batch( texts=["Text 1", "Text 2", "Text 3"], project_id="proj-123", agent_id="agent-456", ) assert len(embeddings) == 3 assert all(len(e) == 1536 for e in embeddings) @pytest.mark.asyncio async def test_caching(self, settings, mock_redis): """Test embedding caching.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) generator._redis = mock_redis # Pre-populate cache cache_key = generator._cache_key("Hello, world!") await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * 1536)) # Mock HTTP client (should not be called) mock_client = AsyncMock() generator._http_client = mock_client embedding = await generator.generate( text="Hello, world!", project_id="proj-123", agent_id="agent-456", ) # Should return cached embedding assert len(embedding) == 1536 assert embedding[0] == 0.5 mock_client.post.assert_not_called() @pytest.mark.asyncio async def test_cache_miss(self, settings, mock_redis, mock_http_response): """Test embedding cache miss.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) generator._redis = mock_redis mock_client = AsyncMock() mock_client.post.return_value = mock_http_response generator._http_client = mock_client embedding = await generator.generate( text="New text not in cache", project_id="proj-123", agent_id="agent-456", ) assert len(embedding) == 1536 mock_client.post.assert_called_once() def test_cache_key_generation(self, settings): """Test cache key generation.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) key1 = generator._cache_key("Hello") key2 = generator._cache_key("Hello") key3 = generator._cache_key("World") assert key1 == key2 assert key1 != key3 assert key1.startswith("kb:emb:") @pytest.mark.asyncio async def test_dimension_validation(self, settings, mock_redis): """Test embedding dimension validation.""" from embeddings import EmbeddingGenerator from exceptions import EmbeddingDimensionMismatchError generator = EmbeddingGenerator(settings=settings) generator._redis = mock_redis # Mock HTTP client with wrong dimension mock_client = AsyncMock() mock_response = MagicMock() mock_response.status_code = 200 mock_response.raise_for_status = MagicMock() mock_response.json.return_value = { "result": { "content": [ { "text": json.dumps( { "embeddings": [[0.1] * 768] # Wrong dimension } ) } ] } } mock_client.post.return_value = mock_response generator._http_client = mock_client with pytest.raises(EmbeddingDimensionMismatchError): await generator.generate( text="Test text", project_id="proj-123", agent_id="agent-456", ) @pytest.mark.asyncio async def test_empty_batch(self, settings, mock_redis): """Test generating embeddings for empty batch.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) generator._redis = mock_redis embeddings = await generator.generate_batch( texts=[], project_id="proj-123", agent_id="agent-456", ) assert embeddings == [] @pytest.mark.asyncio async def test_initialize_and_close(self, settings): """Test initialize and close methods.""" from embeddings import EmbeddingGenerator generator = EmbeddingGenerator(settings=settings) # Mock successful initialization with patch("embeddings.redis.from_url") as mock_redis_from_url: mock_redis_client = AsyncMock() mock_redis_client.ping = AsyncMock() mock_redis_from_url.return_value = mock_redis_client await generator.initialize() assert generator._http_client is not None await generator.close() assert generator._http_client is None class TestGlobalEmbeddingGenerator: """Tests for global embedding generator.""" def test_get_embedding_generator_singleton(self): """Test that get_embedding_generator returns singleton.""" from embeddings import get_embedding_generator, reset_embedding_generator reset_embedding_generator() gen1 = get_embedding_generator() gen2 = get_embedding_generator() assert gen1 is gen2 def test_reset_embedding_generator(self): """Test resetting embedding generator.""" from embeddings import get_embedding_generator, reset_embedding_generator gen1 = get_embedding_generator() reset_embedding_generator() gen2 = get_embedding_generator() assert gen1 is not gen2