"""
|
|
Embedding generation for Knowledge Base MCP Server.
|
|
|
|
Generates vector embeddings via the LLM Gateway MCP server
|
|
with caching support using Redis.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
|
|
import httpx
|
|
import redis.asyncio as redis
|
|
|
|
from config import Settings, get_settings
|
|
from exceptions import (
|
|
EmbeddingDimensionMismatchError,
|
|
EmbeddingGenerationError,
|
|
ErrorCode,
|
|
KnowledgeBaseError,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
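
# Settings consumed by this module (all read from config.Settings; listed
# here for reference, gathered from the attribute accesses below):
#   redis_url, llm_gateway_url, embedding_model, embedding_dimension,
#   embedding_batch_size, embedding_cache_ttl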


class EmbeddingGenerator:
    """
    Generates embeddings via LLM Gateway.

    Features:
    - Batched embedding generation
    - Redis caching for deduplication
    - Graceful degradation when the Redis cache is unavailable
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize embedding generator."""
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        self._http_client: httpx.AsyncClient | None = None

    @property
    def redis_client(self) -> redis.Redis:  # type: ignore[type-arg]
        """Get Redis client, raising if not initialized."""
        if self._redis is None:
            raise KnowledgeBaseError(
                message="Redis client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._redis

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get HTTP client, raising if not initialized."""
        if self._http_client is None:
            raise KnowledgeBaseError(
                message="HTTP client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._http_client

    async def initialize(self) -> None:
        """Initialize Redis and HTTP clients."""
        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()  # type: ignore[misc]
            logger.info("Redis connection established for embedding cache")

        except Exception as e:
            logger.warning(f"Redis connection failed, caching disabled: {e}")
            self._redis = None

        self._http_client = httpx.AsyncClient(
            base_url=self._settings.llm_gateway_url,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )
        logger.info("Embedding generator initialized")

    async def close(self) -> None:
        """Close connections."""
        if self._redis:
            await self._redis.close()
            self._redis = None

        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None

        logger.info("Embedding generator closed")

    def _cache_key(self, text: str) -> str:
        """Generate cache key for a text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
        model = self._settings.embedding_model
        return f"kb:emb:{model}:{text_hash}"

    async def _get_cached(self, text: str) -> list[float] | None:
        """Get cached embedding if available."""
        if not self._redis:
            return None

        try:
            key = self._cache_key(text)
            cached = await self._redis.get(key)
            if cached:
                logger.debug(f"Cache hit for embedding: {key[:20]}...")
                return json.loads(cached)
            return None

        except Exception as e:
            logger.warning(f"Cache read error: {e}")
            return None

    async def _set_cached(self, text: str, embedding: list[float]) -> None:
        """Cache an embedding."""
        if not self._redis:
            return

        try:
            key = self._cache_key(text)
            await self._redis.setex(
                key,
                self._settings.embedding_cache_ttl,
                json.dumps(embedding),
            )
            logger.debug(f"Cached embedding: {key[:20]}...")

        except Exception as e:
            logger.warning(f"Cache write error: {e}")

    async def _get_cached_batch(
        self, texts: list[str]
    ) -> tuple[list[list[float] | None], list[int]]:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (embeddings list with None for misses, indices of misses)
        """
        if not self._redis:
            return [None] * len(texts), list(range(len(texts)))

        try:
            keys = [self._cache_key(text) for text in texts]
            cached_values = await self._redis.mget(keys)

            embeddings: list[list[float] | None] = []
            missing_indices: list[int] = []

            for i, cached in enumerate(cached_values):
                if cached:
                    embeddings.append(json.loads(cached))
                else:
                    embeddings.append(None)
                    missing_indices.append(i)

            cache_hits = len(texts) - len(missing_indices)
            if cache_hits > 0:
                logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")

            return embeddings, missing_indices

        except Exception as e:
            logger.warning(f"Batch cache read error: {e}")
            return [None] * len(texts), list(range(len(texts)))

    async def _set_cached_batch(
        self, texts: list[str], embeddings: list[list[float]]
    ) -> None:
        """Cache a batch of embeddings."""
        if not self._redis or len(texts) != len(embeddings):
            return

        try:
            pipe = self._redis.pipeline()
            for text, embedding in zip(texts, embeddings, strict=True):
                key = self._cache_key(text)
                pipe.setex(
                    key,
                    self._settings.embedding_cache_ttl,
                    json.dumps(embedding),
                )
            await pipe.execute()
            logger.debug(f"Cached {len(texts)} embeddings in batch")

        except Exception as e:
            logger.warning(f"Batch cache write error: {e}")

    async def generate(
        self,
        text: str,
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            Embedding vector
        """
        # Check cache first
        cached = await self._get_cached(text)
        if cached:
            return cached

        # Generate via LLM Gateway
        embeddings = await self._call_llm_gateway(
            [text], project_id, agent_id
        )

        if not embeddings:
            raise EmbeddingGenerationError(
                message="No embedding returned from LLM Gateway",
                texts_count=1,
            )

        embedding = embeddings[0]

        # Validate dimension
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )

        # Cache the result
        await self._set_cached(text, embedding)

        return embedding

    async def generate_batch(
        self,
        texts: list[str],
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Uses caching and batches requests to LLM Gateway.

        Args:
            texts: List of texts to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for existing embeddings
        cached_embeddings, missing_indices = await self._get_cached_batch(texts)

        # If all cached, return immediately
        if not missing_indices:
            logger.debug(f"All {len(texts)} embeddings served from cache")
            return [e for e in cached_embeddings if e is not None]

        # Get texts that need embedding
        texts_to_embed = [texts[i] for i in missing_indices]

        # Generate embeddings in batches
        new_embeddings: list[list[float]] = []
        batch_size = self._settings.embedding_batch_size

        for i in range(0, len(texts_to_embed), batch_size):
            batch = texts_to_embed[i : i + batch_size]
            batch_embeddings = await self._call_llm_gateway(
                batch, project_id, agent_id
            )
            new_embeddings.extend(batch_embeddings)

        # Validate dimensions
        for embedding in new_embeddings:
            if len(embedding) != self._settings.embedding_dimension:
                raise EmbeddingDimensionMismatchError(
                    expected=self._settings.embedding_dimension,
                    actual=len(embedding),
                )

        # Cache new embeddings
        await self._set_cached_batch(texts_to_embed, new_embeddings)

        # Combine cached and new embeddings
        result: list[list[float]] = []
        new_idx = 0

        for i in range(len(texts)):
            if cached_embeddings[i] is not None:
                result.append(cached_embeddings[i])  # type: ignore[arg-type]
            else:
                result.append(new_embeddings[new_idx])
                new_idx += 1

        logger.info(
            f"Generated {len(new_embeddings)} embeddings, "
            f"{len(texts) - len(missing_indices)} from cache"
        )

        return result
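
    # Illustrative batching arithmetic: with embedding_batch_size = 100, a
    # call with 250 uncached texts issues three gateway requests covering
    # 100, 100, and 50 texts (the value 100 here is just an example).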

    async def _call_llm_gateway(
        self,
        texts: list[str],
        project_id: str,
        agent_id: str,
    ) -> list[list[float]]:
        """
        Call LLM Gateway to generate embeddings.

        Uses JSON-RPC 2.0 protocol to call the embedding tool.
        """
        try:
            # JSON-RPC 2.0 request for embedding tool
            request = {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "generate_embeddings",
                    "arguments": {
                        "project_id": project_id,
                        "agent_id": agent_id,
                        "texts": texts,
                        "model": self._settings.embedding_model,
                    },
                },
                "id": 1,
            }

            response = await self.http_client.post("/mcp", json=request)
            response.raise_for_status()

            result = response.json()

            if "error" in result:
                error = result["error"]
                raise EmbeddingGenerationError(
                    message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
                    texts_count=len(texts),
                    details=error.get("data"),
                )

            # Extract embeddings from response
            content = result.get("result", {}).get("content", [])
            if not content:
                raise EmbeddingGenerationError(
                    message="Empty response from LLM Gateway",
                    texts_count=len(texts),
                )

            # Parse the response content
            # LLM Gateway returns embeddings in content[0].text as JSON
            embeddings_data = content[0].get("text", "")
            if isinstance(embeddings_data, str):
                embeddings_data = json.loads(embeddings_data)

            embeddings = embeddings_data.get("embeddings", [])

            if len(embeddings) != len(texts):
                raise EmbeddingGenerationError(
                    message=f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                    texts_count=len(texts),
                )

            return embeddings

        except httpx.HTTPStatusError as e:
            logger.error(f"LLM Gateway HTTP error: {e}")
            raise EmbeddingGenerationError(
                message=f"LLM Gateway request failed: {e.response.status_code}",
                texts_count=len(texts),
                cause=e,
            )
        except httpx.RequestError as e:
            logger.error(f"LLM Gateway request error: {e}")
            raise EmbeddingGenerationError(
                message=f"Failed to connect to LLM Gateway: {e}",
                texts_count=len(texts),
                cause=e,
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON response from LLM Gateway: {e}")
            raise EmbeddingGenerationError(
                message="Invalid response format from LLM Gateway",
                texts_count=len(texts),
                cause=e,
            )
        except EmbeddingGenerationError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error generating embeddings: {e}")
            raise EmbeddingGenerationError(
                message=f"Unexpected error: {e}",
                texts_count=len(texts),
                cause=e,
            )
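
    # For reference, the parsing above assumes a gateway response shaped
    # roughly like this (a sketch inferred from the code; the authoritative
    # schema is defined by the LLM Gateway):
    #
    #   {
    #       "jsonrpc": "2.0",
    #       "id": 1,
    #       "result": {
    #           "content": [
    #               {"type": "text", "text": "{\"embeddings\": [[0.1, ...]]}"}
    #           ]
    #       }
    #   }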


# Global embedding generator instance (lazy initialization)
_embedding_generator: EmbeddingGenerator | None = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Get the global embedding generator instance."""
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator


def reset_embedding_generator() -> None:
    """Reset the global embedding generator (for testing)."""
    global _embedding_generator
    _embedding_generator = None
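

# Example usage (a minimal sketch; assumes a reachable Redis instance and an
# LLM Gateway listening at the configured llm_gateway_url):
#
#     import asyncio
#
#     async def main() -> None:
#         generator = get_embedding_generator()
#         await generator.initialize()
#         try:
#             vector = await generator.generate("What is pgvector?")
#             vectors = await generator.generate_batch(["first chunk", "second chunk"])
#             print(len(vector), len(vectors))
#         finally:
#             await generator.close()
#
#     asyncio.run(main())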