"""
|
|
Embedding generation for Knowledge Base MCP Server.
|
|
|
|
Generates vector embeddings via the LLM Gateway MCP server
|
|
with caching support using Redis.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
|
|
import httpx
|
|
import redis.asyncio as redis
|
|
|
|
from config import Settings, get_settings
|
|
from exceptions import (
|
|
EmbeddingDimensionMismatchError,
|
|
EmbeddingGenerationError,
|
|
ErrorCode,
|
|
KnowledgeBaseError,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingGenerator:
    """
    Generates embeddings via LLM Gateway.

    Features:
    - Batched embedding generation
    - Redis caching for deduplication
    - Automatic retry on transient failures
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize embedding generator."""
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        self._http_client: httpx.AsyncClient | None = None

    @property
    def redis_client(self) -> redis.Redis:  # type: ignore[type-arg]
        """Get Redis client, raising if not initialized."""
        if self._redis is None:
            raise KnowledgeBaseError(
                message="Redis client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._redis

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get HTTP client, raising if not initialized."""
        if self._http_client is None:
            raise KnowledgeBaseError(
                message="HTTP client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._http_client

    async def initialize(self) -> None:
        """Initialize Redis and HTTP clients."""
        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()  # type: ignore[misc]
            logger.info("Redis connection established for embedding cache")

        except Exception as e:
            logger.warning(f"Redis connection failed, caching disabled: {e}")
            self._redis = None

        self._http_client = httpx.AsyncClient(
            base_url=self._settings.llm_gateway_url,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )
        logger.info("Embedding generator initialized")

    async def close(self) -> None:
        """Close connections."""
        if self._redis:
            await self._redis.close()
            self._redis = None

        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None

        logger.info("Embedding generator closed")

    def _cache_key(self, text: str) -> str:
        """Generate cache key for a text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
        model = self._settings.embedding_model
        return f"kb:emb:{model}:{text_hash}"

    async def _get_cached(self, text: str) -> list[float] | None:
        """Get cached embedding if available."""
        if not self._redis:
            return None

        try:
            key = self._cache_key(text)
            cached = await self._redis.get(key)
            if cached:
                logger.debug(f"Cache hit for embedding: {key[:20]}...")
                return json.loads(cached)
            return None

        except Exception as e:
            logger.warning(f"Cache read error: {e}")
            return None

    async def _set_cached(self, text: str, embedding: list[float]) -> None:
        """Cache an embedding."""
        if not self._redis:
            return

        try:
            key = self._cache_key(text)
            await self._redis.setex(
                key,
                self._settings.embedding_cache_ttl,
                json.dumps(embedding),
            )
            logger.debug(f"Cached embedding: {key[:20]}...")

        except Exception as e:
            logger.warning(f"Cache write error: {e}")

    async def _get_cached_batch(
        self, texts: list[str]
    ) -> tuple[list[list[float] | None], list[int]]:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (embeddings list with None for misses, indices of misses)
        """
        if not self._redis:
            return [None] * len(texts), list(range(len(texts)))

        try:
            keys = [self._cache_key(text) for text in texts]
            cached_values = await self._redis.mget(keys)

            embeddings: list[list[float] | None] = []
            missing_indices: list[int] = []

            for i, cached in enumerate(cached_values):
                if cached:
                    embeddings.append(json.loads(cached))
                else:
                    embeddings.append(None)
                    missing_indices.append(i)

            cache_hits = len(texts) - len(missing_indices)
            if cache_hits > 0:
                logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")

            return embeddings, missing_indices

        except Exception as e:
            logger.warning(f"Batch cache read error: {e}")
            return [None] * len(texts), list(range(len(texts)))

    async def _set_cached_batch(
        self, texts: list[str], embeddings: list[list[float]]
    ) -> None:
        """Cache a batch of embeddings."""
        if not self._redis or len(texts) != len(embeddings):
            return

        try:
            pipe = self._redis.pipeline()
            for text, embedding in zip(texts, embeddings, strict=True):
                key = self._cache_key(text)
                pipe.setex(
                    key,
                    self._settings.embedding_cache_ttl,
                    json.dumps(embedding),
                )
            await pipe.execute()
            logger.debug(f"Cached {len(texts)} embeddings in batch")

        except Exception as e:
            logger.warning(f"Batch cache write error: {e}")

    async def generate(
        self,
        text: str,
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            Embedding vector
        """
        # Check cache first
        cached = await self._get_cached(text)
        if cached:
            return cached

        # Generate via LLM Gateway
        embeddings = await self._call_llm_gateway([text], project_id, agent_id)

        if not embeddings:
            raise EmbeddingGenerationError(
                message="No embedding returned from LLM Gateway",
                texts_count=1,
            )

        embedding = embeddings[0]

        # Validate dimension
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )

        # Cache the result
        await self._set_cached(text, embedding)

        return embedding

    async def generate_batch(
        self,
        texts: list[str],
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Uses caching and batches requests to LLM Gateway.

        Args:
            texts: List of texts to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for existing embeddings
        cached_embeddings, missing_indices = await self._get_cached_batch(texts)

        # If all cached, return immediately
        if not missing_indices:
            logger.debug(f"All {len(texts)} embeddings served from cache")
            return [e for e in cached_embeddings if e is not None]

        # Get texts that need embedding
        texts_to_embed = [texts[i] for i in missing_indices]

        # Generate embeddings in batches
        new_embeddings: list[list[float]] = []
        batch_size = self._settings.embedding_batch_size

        for i in range(0, len(texts_to_embed), batch_size):
            batch = texts_to_embed[i : i + batch_size]
            batch_embeddings = await self._call_llm_gateway(batch, project_id, agent_id)
            new_embeddings.extend(batch_embeddings)

        # Validate dimensions
        for embedding in new_embeddings:
            if len(embedding) != self._settings.embedding_dimension:
                raise EmbeddingDimensionMismatchError(
                    expected=self._settings.embedding_dimension,
                    actual=len(embedding),
                )

        # Cache new embeddings
        await self._set_cached_batch(texts_to_embed, new_embeddings)

        # Combine cached and new embeddings
        result: list[list[float]] = []
        new_idx = 0

        for i in range(len(texts)):
            if cached_embeddings[i] is not None:
                result.append(cached_embeddings[i])  # type: ignore[arg-type]
            else:
                result.append(new_embeddings[new_idx])
                new_idx += 1

        logger.info(
            f"Generated {len(new_embeddings)} embeddings, "
            f"{len(texts) - len(missing_indices)} from cache"
        )

        return result

    async def _call_llm_gateway(
        self,
        texts: list[str],
        project_id: str,
        agent_id: str,
    ) -> list[list[float]]:
        """
        Call LLM Gateway to generate embeddings.

        Uses JSON-RPC 2.0 protocol to call the embedding tool.
        """
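        # Sketch of the success response shape that the parsing below assumes
        # (illustrative values; the exact envelope is defined by the LLM Gateway):
        #   {"jsonrpc": "2.0", "id": 1,
        #    "result": {"content": [{"text": "{\"embeddings\": [[0.01, ...], ...]}"}]}}
        # Error replies instead carry a top-level "error" object with "message"
        # and optional "data" fields.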
        try:
            # JSON-RPC 2.0 request for embedding tool
            request = {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "generate_embeddings",
                    "arguments": {
                        "project_id": project_id,
                        "agent_id": agent_id,
                        "texts": texts,
                        "model": self._settings.embedding_model,
                    },
                },
                "id": 1,
            }

            response = await self.http_client.post("/mcp", json=request)
            response.raise_for_status()

            result = response.json()

            if "error" in result:
                error = result["error"]
                raise EmbeddingGenerationError(
                    message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
                    texts_count=len(texts),
                    details=error.get("data"),
                )

            # Extract embeddings from response
            content = result.get("result", {}).get("content", [])
            if not content:
                raise EmbeddingGenerationError(
                    message="Empty response from LLM Gateway",
                    texts_count=len(texts),
                )

            # Parse the response content
            # LLM Gateway returns embeddings in content[0].text as JSON
            embeddings_data = content[0].get("text", "")
            if isinstance(embeddings_data, str):
                embeddings_data = json.loads(embeddings_data)

            embeddings = embeddings_data.get("embeddings", [])

            if len(embeddings) != len(texts):
                raise EmbeddingGenerationError(
                    message=f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                    texts_count=len(texts),
                )

            return embeddings

        except httpx.HTTPStatusError as e:
            logger.error(f"LLM Gateway HTTP error: {e}")
            raise EmbeddingGenerationError(
                message=f"LLM Gateway request failed: {e.response.status_code}",
                texts_count=len(texts),
                cause=e,
            )
        except httpx.RequestError as e:
            logger.error(f"LLM Gateway request error: {e}")
            raise EmbeddingGenerationError(
                message=f"Failed to connect to LLM Gateway: {e}",
                texts_count=len(texts),
                cause=e,
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON response from LLM Gateway: {e}")
            raise EmbeddingGenerationError(
                message="Invalid response format from LLM Gateway",
                texts_count=len(texts),
                cause=e,
            )
        except EmbeddingGenerationError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error generating embeddings: {e}")
            raise EmbeddingGenerationError(
                message=f"Unexpected error: {e}",
                texts_count=len(texts),
                cause=e,
            )


# Global embedding generator instance (lazy initialization)
_embedding_generator: EmbeddingGenerator | None = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Get the global embedding generator instance."""
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator


def reset_embedding_generator() -> None:
    """Reset the global embedding generator (for testing)."""
    global _embedding_generator
    _embedding_generator = None
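

# Minimal usage sketch (illustrative only, not part of the module's public
# surface). It assumes a reachable LLM Gateway and, optionally, Redis as
# configured via Settings / get_settings(); if Redis is unavailable the
# generator simply runs with caching disabled.
if __name__ == "__main__":  # pragma: no cover
    import asyncio

    async def _demo() -> None:
        generator = get_embedding_generator()
        await generator.initialize()
        try:
            vectors = await generator.generate_batch(
                ["What is a knowledge base?", "How are embeddings cached?"]
            )
            print(f"Generated {len(vectors)} embeddings of dimension {len(vectors[0])}")
        finally:
            await generator.close()
            reset_embedding_generator()

    asyncio.run(_demo())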