forked from cardosofelipe/fast-next-template
Implements RAG capabilities with pgvector for semantic search:
- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
246 lines
7.7 KiB
Python
246 lines
7.7 KiB
Python
"""Tests for embedding generation module."""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestEmbeddingGenerator:
    """Tests for EmbeddingGenerator class."""

    # Dimension the tests treat as valid: 1536-wide vectors are accepted and
    # 768-wide ones must raise EmbeddingDimensionMismatchError, so the
    # `settings` fixture is presumably configured for 1536 — confirm in conftest.
    EMBEDDING_DIM = 1536

    @staticmethod
    def _mcp_embeddings_response(vectors):
        """Build a mock HTTP response shaped like an MCP tool-call result.

        The embeddings are JSON-encoded and nested as text under
        ``result.content[0].text``; this is the payload shape every test in
        this class feeds to the generator's mocked HTTP client.

        Args:
            vectors: List of embedding vectors placed under ``"embeddings"``.

        Returns:
            A ``MagicMock`` with ``status_code == 200``, a no-op
            ``raise_for_status`` and ``json()`` returning the payload.
        """
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {
            "result": {
                "content": [
                    {"text": json.dumps({"embeddings": vectors})},
                ]
            }
        }
        return response

    @pytest.fixture
    def mock_http_response(self):
        """Create mock HTTP response carrying a single valid embedding."""
        return self._mcp_embeddings_response([[0.1] * self.EMBEDDING_DIM])

    @pytest.mark.asyncio
    async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
        """Test generating a single embedding."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client
        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == self.EMBEDDING_DIM
        mock_client.post.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch_embeddings(self, settings, mock_redis):
        """Test generating batch embeddings, preserving input order."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client with batch response; each vector is filled with a
        # distinct value so the order of the results can be checked.
        mock_client = AsyncMock()
        mock_client.post.return_value = self._mcp_embeddings_response(
            [[0.1] * self.EMBEDDING_DIM, [0.2] * self.EMBEDDING_DIM, [0.3] * self.EMBEDDING_DIM]
        )
        generator._http_client = mock_client

        embeddings = await generator.generate_batch(
            texts=["Text 1", "Text 2", "Text 3"],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embeddings) == 3
        assert all(len(e) == self.EMBEDDING_DIM for e in embeddings)
        # Result i must correspond to input text i.
        assert [e[0] for e in embeddings] == [0.1, 0.2, 0.3]

    @pytest.mark.asyncio
    async def test_caching(self, settings, mock_redis):
        """Test that a cached embedding is returned without an HTTP call."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Pre-populate cache
        cache_key = generator._cache_key("Hello, world!")
        await mock_redis.setex(cache_key, 3600, json.dumps([0.5] * self.EMBEDDING_DIM))

        # Mock HTTP client (should not be called)
        mock_client = AsyncMock()
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="Hello, world!",
            project_id="proj-123",
            agent_id="agent-456",
        )

        # Should return cached embedding
        assert len(embedding) == self.EMBEDDING_DIM
        assert embedding[0] == 0.5
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self, settings, mock_redis, mock_http_response):
        """Test that an uncached text triggers exactly one HTTP call."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_http_response
        generator._http_client = mock_client

        embedding = await generator.generate(
            text="New text not in cache",
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert len(embedding) == self.EMBEDDING_DIM
        mock_client.post.assert_called_once()

    def test_cache_key_generation(self, settings):
        """Test cache keys are deterministic, text-dependent and namespaced."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)

        key1 = generator._cache_key("Hello")
        key2 = generator._cache_key("Hello")
        key3 = generator._cache_key("World")

        assert key1 == key2
        assert key1 != key3
        assert key1.startswith("kb:emb:")

    @pytest.mark.asyncio
    async def test_dimension_validation(self, settings, mock_redis):
        """Test that an embedding of the wrong dimension is rejected."""
        from embeddings import EmbeddingGenerator
        from exceptions import EmbeddingDimensionMismatchError

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        # Mock HTTP client with wrong dimension
        mock_client = AsyncMock()
        mock_client.post.return_value = self._mcp_embeddings_response([[0.1] * 768])
        generator._http_client = mock_client

        with pytest.raises(EmbeddingDimensionMismatchError):
            await generator.generate(
                text="Test text",
                project_id="proj-123",
                agent_id="agent-456",
            )

    @pytest.mark.asyncio
    async def test_empty_batch(self, settings, mock_redis):
        """Test generating embeddings for empty batch."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)
        generator._redis = mock_redis

        embeddings = await generator.generate_batch(
            texts=[],
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert embeddings == []

    @pytest.mark.asyncio
    async def test_initialize_and_close(self, settings):
        """Test initialize creates an HTTP client and close releases it."""
        from embeddings import EmbeddingGenerator

        generator = EmbeddingGenerator(settings=settings)

        # Mock successful initialization
        with patch("embeddings.redis.from_url") as mock_redis_from_url:
            mock_redis_client = AsyncMock()
            mock_redis_client.ping = AsyncMock()
            mock_redis_from_url.return_value = mock_redis_client

            await generator.initialize()

            # Always close, even if the assertion fails, so the HTTP client
            # created by initialize() is not leaked into later tests.
            try:
                assert generator._http_client is not None
            finally:
                await generator.close()

            assert generator._http_client is None
|
|
|
|
|
|
class TestGlobalEmbeddingGenerator:
    """Tests for global embedding generator."""

    @pytest.fixture(autouse=True)
    def _reset_global_generator(self):
        """Reset the module-level generator before and after each test.

        These tests create and replace the shared instance returned by
        ``get_embedding_generator()``. Resetting on both sides keeps them
        independent of execution order, and stops them leaking an instance
        into other tests.
        """
        from embeddings import reset_embedding_generator

        reset_embedding_generator()
        yield
        reset_embedding_generator()

    def test_get_embedding_generator_singleton(self):
        """Test that get_embedding_generator returns singleton."""
        from embeddings import get_embedding_generator

        gen1 = get_embedding_generator()
        gen2 = get_embedding_generator()

        assert gen1 is gen2

    def test_reset_embedding_generator(self):
        """Test that resetting yields a fresh instance on the next call."""
        from embeddings import get_embedding_generator, reset_embedding_generator

        gen1 = get_embedding_generator()
        reset_embedding_generator()
        gen2 = get_embedding_generator()

        assert gen1 is not gen2
|