syndarix/mcp-servers/knowledge-base/embeddings.py
Commit d0fc7f37ff (Felipe Cardoso): feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search:

- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF; see the sketch after this list)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content,
  list_collections, get_collection_stats, update_document
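
As background, the hybrid search bullet refers to the standard RRF formula: each
document's fused score is the sum of 1 / (k + rank) over the semantic and keyword
rankings, so items ranked well by both lists rise to the top. A minimal Python
sketch, assuming the conventional k = 60 constant (function name and signature are
illustrative, not necessarily the code in this commit):

    def reciprocal_rank_fusion(
        rankings: list[list[str]], k: int = 60
    ) -> list[tuple[str, float]]:
        """Fuse ranked id lists: score(d) = sum over lists of 1 / (k + rank(d))."""
        scores: dict[str, float] = {}
        for ranking in rankings:
            for rank, doc_id in enumerate(ranking, start=1):
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)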

Testing:
- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 21:33:26 +01:00


"""
Embedding generation for Knowledge Base MCP Server.
Generates vector embeddings via the LLM Gateway MCP server
with caching support using Redis.
"""
import hashlib
import json
import logging
import httpx
import redis.asyncio as redis
from config import Settings, get_settings
from exceptions import (
EmbeddingDimensionMismatchError,
EmbeddingGenerationError,
ErrorCode,
KnowledgeBaseError,
)
logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """
    Generates embeddings via LLM Gateway.

    Features:
    - Batched embedding generation
    - Redis caching for deduplication
    - Graceful degradation when Redis is unavailable (caching is disabled)
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize embedding generator."""
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        self._http_client: httpx.AsyncClient | None = None

    @property
    def redis_client(self) -> redis.Redis:  # type: ignore[type-arg]
        """Get Redis client, raising if not initialized."""
        if self._redis is None:
            raise KnowledgeBaseError(
                message="Redis client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._redis

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get HTTP client, raising if not initialized."""
        if self._http_client is None:
            raise KnowledgeBaseError(
                message="HTTP client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._http_client

    async def initialize(self) -> None:
        """Initialize Redis and HTTP clients."""
        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()  # type: ignore[misc]
            logger.info("Redis connection established for embedding cache")
        except Exception as e:
            logger.warning(f"Redis connection failed, caching disabled: {e}")
            self._redis = None

        self._http_client = httpx.AsyncClient(
            base_url=self._settings.llm_gateway_url,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )
        logger.info("Embedding generator initialized")

    async def close(self) -> None:
        """Close connections."""
        if self._redis:
            await self._redis.close()
            self._redis = None
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
        logger.info("Embedding generator closed")

    def _cache_key(self, text: str) -> str:
        """Generate cache key for a text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
        model = self._settings.embedding_model
        return f"kb:emb:{model}:{text_hash}"

    async def _get_cached(self, text: str) -> list[float] | None:
        """Get cached embedding if available."""
        if not self._redis:
            return None
        try:
            key = self._cache_key(text)
            cached = await self._redis.get(key)
            if cached:
                logger.debug(f"Cache hit for embedding: {key[:20]}...")
                return json.loads(cached)
            return None
        except Exception as e:
            logger.warning(f"Cache read error: {e}")
            return None

    async def _set_cached(self, text: str, embedding: list[float]) -> None:
        """Cache an embedding."""
        if not self._redis:
            return
        try:
            key = self._cache_key(text)
            await self._redis.setex(
                key,
                self._settings.embedding_cache_ttl,
                json.dumps(embedding),
            )
            logger.debug(f"Cached embedding: {key[:20]}...")
        except Exception as e:
            logger.warning(f"Cache write error: {e}")

    async def _get_cached_batch(
        self, texts: list[str]
    ) -> tuple[list[list[float] | None], list[int]]:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (embeddings list with None for misses, indices of misses)
        """
        if not self._redis:
            return [None] * len(texts), list(range(len(texts)))
        try:
            keys = [self._cache_key(text) for text in texts]
            cached_values = await self._redis.mget(keys)
            embeddings: list[list[float] | None] = []
            missing_indices: list[int] = []
            for i, cached in enumerate(cached_values):
                if cached:
                    embeddings.append(json.loads(cached))
                else:
                    embeddings.append(None)
                    missing_indices.append(i)
            cache_hits = len(texts) - len(missing_indices)
            if cache_hits > 0:
                logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")
            return embeddings, missing_indices
        except Exception as e:
            logger.warning(f"Batch cache read error: {e}")
            return [None] * len(texts), list(range(len(texts)))

    async def _set_cached_batch(
        self, texts: list[str], embeddings: list[list[float]]
    ) -> None:
        """Cache a batch of embeddings."""
        if not self._redis or len(texts) != len(embeddings):
            return
        try:
            pipe = self._redis.pipeline()
            for text, embedding in zip(texts, embeddings, strict=True):
                key = self._cache_key(text)
                pipe.setex(
                    key,
                    self._settings.embedding_cache_ttl,
                    json.dumps(embedding),
                )
            await pipe.execute()
            logger.debug(f"Cached {len(texts)} embeddings in batch")
        except Exception as e:
            logger.warning(f"Batch cache write error: {e}")

    async def generate(
        self,
        text: str,
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            Embedding vector
        """
        # Check cache first
        cached = await self._get_cached(text)
        if cached:
            return cached

        # Generate via LLM Gateway
        embeddings = await self._call_llm_gateway(
            [text], project_id, agent_id
        )
        if not embeddings:
            raise EmbeddingGenerationError(
                message="No embedding returned from LLM Gateway",
                texts_count=1,
            )

        embedding = embeddings[0]

        # Validate dimension
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )

        # Cache the result
        await self._set_cached(text, embedding)
        return embedding

    async def generate_batch(
        self,
        texts: list[str],
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Uses caching and batches requests to LLM Gateway.

        Args:
            texts: List of texts to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for existing embeddings
        cached_embeddings, missing_indices = await self._get_cached_batch(texts)

        # If all cached, return immediately
        if not missing_indices:
            logger.debug(f"All {len(texts)} embeddings served from cache")
            return [e for e in cached_embeddings if e is not None]

        # Get texts that need embedding
        texts_to_embed = [texts[i] for i in missing_indices]

        # Generate embeddings in batches
        new_embeddings: list[list[float]] = []
        batch_size = self._settings.embedding_batch_size
        for i in range(0, len(texts_to_embed), batch_size):
            batch = texts_to_embed[i : i + batch_size]
            batch_embeddings = await self._call_llm_gateway(
                batch, project_id, agent_id
            )
            new_embeddings.extend(batch_embeddings)

        # Validate dimensions
        for embedding in new_embeddings:
            if len(embedding) != self._settings.embedding_dimension:
                raise EmbeddingDimensionMismatchError(
                    expected=self._settings.embedding_dimension,
                    actual=len(embedding),
                )

        # Cache new embeddings
        await self._set_cached_batch(texts_to_embed, new_embeddings)

        # Combine cached and new embeddings
        result: list[list[float]] = []
        new_idx = 0
        for i in range(len(texts)):
            if cached_embeddings[i] is not None:
                result.append(cached_embeddings[i])  # type: ignore[arg-type]
            else:
                result.append(new_embeddings[new_idx])
                new_idx += 1

        logger.info(
            f"Generated {len(new_embeddings)} embeddings, "
            f"{len(texts) - len(missing_indices)} from cache"
        )
        return result

    async def _call_llm_gateway(
        self,
        texts: list[str],
        project_id: str,
        agent_id: str,
    ) -> list[list[float]]:
        """
        Call LLM Gateway to generate embeddings.

        Uses JSON-RPC 2.0 protocol to call the embedding tool.
        """
        try:
            # JSON-RPC 2.0 request for embedding tool
            request = {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "generate_embeddings",
                    "arguments": {
                        "project_id": project_id,
                        "agent_id": agent_id,
                        "texts": texts,
                        "model": self._settings.embedding_model,
                    },
                },
                "id": 1,
            }

            response = await self.http_client.post("/mcp", json=request)
            response.raise_for_status()
            result = response.json()

            if "error" in result:
                error = result["error"]
                raise EmbeddingGenerationError(
                    message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
                    texts_count=len(texts),
                    details=error.get("data"),
                )
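
            # Assumed shape of a successful response (the gateway returns the
            # embeddings JSON-encoded inside content[0]["text"], which is what
            # the parsing below expects):
            #   {"jsonrpc": "2.0", "id": 1,
            #    "result": {"content": [
            #        {"type": "text", "text": '{"embeddings": [[0.1, ...], ...]}'}]}}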
            # Extract embeddings from response
            content = result.get("result", {}).get("content", [])
            if not content:
                raise EmbeddingGenerationError(
                    message="Empty response from LLM Gateway",
                    texts_count=len(texts),
                )

            # Parse the response content.
            # LLM Gateway returns embeddings in content[0].text as JSON.
            embeddings_data = content[0].get("text", "")
            if isinstance(embeddings_data, str):
                embeddings_data = json.loads(embeddings_data)

            embeddings = embeddings_data.get("embeddings", [])
            if len(embeddings) != len(texts):
                raise EmbeddingGenerationError(
                    message=(
                        f"Embedding count mismatch: "
                        f"expected {len(texts)}, got {len(embeddings)}"
                    ),
                    texts_count=len(texts),
                )
            return embeddings

        except httpx.HTTPStatusError as e:
            logger.error(f"LLM Gateway HTTP error: {e}")
            raise EmbeddingGenerationError(
                message=f"LLM Gateway request failed: {e.response.status_code}",
                texts_count=len(texts),
                cause=e,
            )
        except httpx.RequestError as e:
            logger.error(f"LLM Gateway request error: {e}")
            raise EmbeddingGenerationError(
                message=f"Failed to connect to LLM Gateway: {e}",
                texts_count=len(texts),
                cause=e,
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON response from LLM Gateway: {e}")
            raise EmbeddingGenerationError(
                message="Invalid response format from LLM Gateway",
                texts_count=len(texts),
                cause=e,
            )
        except EmbeddingGenerationError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error generating embeddings: {e}")
            raise EmbeddingGenerationError(
                message=f"Unexpected error: {e}",
                texts_count=len(texts),
                cause=e,
            )


# Global embedding generator instance (lazy initialization)
_embedding_generator: EmbeddingGenerator | None = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Get the global embedding generator instance."""
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator


def reset_embedding_generator() -> None:
    """Reset the global embedding generator (for testing)."""
    global _embedding_generator
    _embedding_generator = None
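
A minimal usage sketch (illustrative, not part of the module; assumes a reachable
Redis instance and LLM Gateway as configured in Settings, and that the "demo"
project_id is just a placeholder):

    import asyncio

    from embeddings import get_embedding_generator


    async def main() -> None:
        generator = get_embedding_generator()
        await generator.initialize()
        try:
            # Single text: cached in Redis after the first call.
            vector = await generator.generate("What is pgvector?")
            print(len(vector))  # equals settings.embedding_dimension

            # Batch path: cache hits are reused; misses go to the gateway
            # in chunks of settings.embedding_batch_size.
            vectors = await generator.generate_batch(
                ["chunk one", "chunk two"], project_id="demo"
            )
            print(len(vectors))
        finally:
            await generator.close()


    asyncio.run(main())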