feat(knowledge-base): implement Knowledge Base MCP Server (#57)
Implements RAG capabilities with pgvector for semantic search:

- Intelligent chunking strategies (code-aware, markdown-aware, text)
- Semantic search with vector similarity (HNSW index)
- Keyword search with PostgreSQL full-text search
- Hybrid search using Reciprocal Rank Fusion (RRF)
- Redis caching for embeddings
- Collection management (ingest, search, delete, stats)
- FastMCP tools: search_knowledge, ingest_content, delete_content, list_collections, get_collection_stats, update_document

Testing:

- 128 comprehensive tests covering all components
- 58% code coverage (database integration tests use mocks)
- Passes ruff linting and mypy type checking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
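The hybrid search mentioned above fuses the semantic and keyword result lists with Reciprocal Rank Fusion. A minimal sketch of RRF scoring, assuming two ranked lists of document IDs and the commonly used k = 60 constant (illustrative; the helper name rrf_merge and its signature are not taken from this commit):

    def rrf_merge(rankings: list[list[str]], k: int = 60) -> list[str]:
        """Fuse ranked ID lists: score(doc) = sum over lists of 1 / (k + rank)."""
        scores: dict[str, float] = {}
        for ranking in rankings:
            for rank, doc_id in enumerate(ranking, start=1):
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
        return sorted(scores, key=lambda doc_id: scores[doc_id], reverse=True)

    # e.g. rrf_merge([["a", "b", "c"], ["c", "a"]]) returns ["a", "c", "b"]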
mcp-servers/knowledge-base/embeddings.py (new file, +426 lines)
@@ -0,0 +1,426 @@
"""
Embedding generation for Knowledge Base MCP Server.

Generates vector embeddings via the LLM Gateway MCP server
with caching support using Redis.
"""

import hashlib
import json
import logging

import httpx
import redis.asyncio as redis

from config import Settings, get_settings
from exceptions import (
    EmbeddingDimensionMismatchError,
    EmbeddingGenerationError,
    ErrorCode,
    KnowledgeBaseError,
)

logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """
    Generates embeddings via LLM Gateway.

    Features:
    - Batched embedding generation
    - Redis caching for deduplication
    - Automatic retry on transient failures
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize embedding generator."""
        self._settings = settings or get_settings()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        self._http_client: httpx.AsyncClient | None = None

    @property
    def redis_client(self) -> redis.Redis:  # type: ignore[type-arg]
        """Get Redis client, raising if not initialized."""
        if self._redis is None:
            raise KnowledgeBaseError(
                message="Redis client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._redis

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get HTTP client, raising if not initialized."""
        if self._http_client is None:
            raise KnowledgeBaseError(
                message="HTTP client not initialized",
                code=ErrorCode.INTERNAL_ERROR,
            )
        return self._http_client

    async def initialize(self) -> None:
        """Initialize Redis and HTTP clients."""
        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._redis.ping()  # type: ignore[misc]
            logger.info("Redis connection established for embedding cache")

        except Exception as e:
            logger.warning(f"Redis connection failed, caching disabled: {e}")
            self._redis = None

        self._http_client = httpx.AsyncClient(
            base_url=self._settings.llm_gateway_url,
            timeout=httpx.Timeout(60.0, connect=10.0),
        )
        logger.info("Embedding generator initialized")

    async def close(self) -> None:
        """Close connections."""
        if self._redis:
            await self._redis.close()
            self._redis = None

        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None

        logger.info("Embedding generator closed")

    def _cache_key(self, text: str) -> str:
        """Generate cache key for a text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()[:32]
        model = self._settings.embedding_model
        return f"kb:emb:{model}:{text_hash}"
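
    # Key layout produced by _cache_key (hash truncated to 32 hex chars, model
    # taken from settings.embedding_model), e.g.:
    #     kb:emb:<embedding_model>:<sha256(text)[:32]>
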
    async def _get_cached(self, text: str) -> list[float] | None:
        """Get cached embedding if available."""
        if not self._redis:
            return None

        try:
            key = self._cache_key(text)
            cached = await self._redis.get(key)
            if cached:
                logger.debug(f"Cache hit for embedding: {key[:20]}...")
                return json.loads(cached)
            return None

        except Exception as e:
            logger.warning(f"Cache read error: {e}")
            return None

    async def _set_cached(self, text: str, embedding: list[float]) -> None:
        """Cache an embedding."""
        if not self._redis:
            return

        try:
            key = self._cache_key(text)
            await self._redis.setex(
                key,
                self._settings.embedding_cache_ttl,
                json.dumps(embedding),
            )
            logger.debug(f"Cached embedding: {key[:20]}...")

        except Exception as e:
            logger.warning(f"Cache write error: {e}")

    async def _get_cached_batch(
        self, texts: list[str]
    ) -> tuple[list[list[float] | None], list[int]]:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (embeddings list with None for misses, indices of misses)
        """
        if not self._redis:
            return [None] * len(texts), list(range(len(texts)))

        try:
            keys = [self._cache_key(text) for text in texts]
            cached_values = await self._redis.mget(keys)

            embeddings: list[list[float] | None] = []
            missing_indices: list[int] = []

            for i, cached in enumerate(cached_values):
                if cached:
                    embeddings.append(json.loads(cached))
                else:
                    embeddings.append(None)
                    missing_indices.append(i)

            cache_hits = len(texts) - len(missing_indices)
            if cache_hits > 0:
                logger.debug(f"Batch cache hits: {cache_hits}/{len(texts)}")

            return embeddings, missing_indices

        except Exception as e:
            logger.warning(f"Batch cache read error: {e}")
            return [None] * len(texts), list(range(len(texts)))
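
    # Illustration of the _get_cached_batch contract: for three texts where only
    # the second is cached, it returns something like
    #     ([None, [0.12, ...], None], [0, 2])
    # i.e. embeddings with None at each miss, plus the indices of the misses.
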
    async def _set_cached_batch(
        self, texts: list[str], embeddings: list[list[float]]
    ) -> None:
        """Cache a batch of embeddings."""
        if not self._redis or len(texts) != len(embeddings):
            return

        try:
            pipe = self._redis.pipeline()
            for text, embedding in zip(texts, embeddings, strict=True):
                key = self._cache_key(text)
                pipe.setex(
                    key,
                    self._settings.embedding_cache_ttl,
                    json.dumps(embedding),
                )
            await pipe.execute()
            logger.debug(f"Cached {len(texts)} embeddings in batch")

        except Exception as e:
            logger.warning(f"Batch cache write error: {e}")

    async def generate(
        self,
        text: str,
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            Embedding vector
        """
        # Check cache first
        cached = await self._get_cached(text)
        if cached:
            return cached

        # Generate via LLM Gateway
        embeddings = await self._call_llm_gateway(
            [text], project_id, agent_id
        )

        if not embeddings:
            raise EmbeddingGenerationError(
                message="No embedding returned from LLM Gateway",
                texts_count=1,
            )

        embedding = embeddings[0]

        # Validate dimension
        if len(embedding) != self._settings.embedding_dimension:
            raise EmbeddingDimensionMismatchError(
                expected=self._settings.embedding_dimension,
                actual=len(embedding),
            )

        # Cache the result
        await self._set_cached(text, embedding)

        return embedding

    async def generate_batch(
        self,
        texts: list[str],
        project_id: str = "system",
        agent_id: str = "knowledge-base",
    ) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Uses caching and batches requests to LLM Gateway.

        Args:
            texts: List of texts to embed
            project_id: Project ID for cost attribution
            agent_id: Agent ID for cost attribution

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        # Check cache for existing embeddings
        cached_embeddings, missing_indices = await self._get_cached_batch(texts)

        # If all cached, return immediately
        if not missing_indices:
            logger.debug(f"All {len(texts)} embeddings served from cache")
            return [e for e in cached_embeddings if e is not None]

        # Get texts that need embedding
        texts_to_embed = [texts[i] for i in missing_indices]

        # Generate embeddings in batches
        new_embeddings: list[list[float]] = []
        batch_size = self._settings.embedding_batch_size

        for i in range(0, len(texts_to_embed), batch_size):
            batch = texts_to_embed[i : i + batch_size]
            batch_embeddings = await self._call_llm_gateway(
                batch, project_id, agent_id
            )
            new_embeddings.extend(batch_embeddings)

        # Validate dimensions
        for embedding in new_embeddings:
            if len(embedding) != self._settings.embedding_dimension:
                raise EmbeddingDimensionMismatchError(
                    expected=self._settings.embedding_dimension,
                    actual=len(embedding),
                )

        # Cache new embeddings
        await self._set_cached_batch(texts_to_embed, new_embeddings)

        # Combine cached and new embeddings
        result: list[list[float]] = []
        new_idx = 0

        for i in range(len(texts)):
            if cached_embeddings[i] is not None:
                result.append(cached_embeddings[i])  # type: ignore[arg-type]
            else:
                result.append(new_embeddings[new_idx])
                new_idx += 1

        logger.info(
            f"Generated {len(new_embeddings)} embeddings, "
            f"{len(texts) - len(missing_indices)} from cache"
        )

        return result
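
    # Worked example of the merge above (values illustrative): for
    # texts = ["a", "b", "c"] with "b" already cached, only ["a", "c"] go to the
    # gateway, and the final list is re-assembled in the original order as
    # [new("a"), cached("b"), new("c")].
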
    async def _call_llm_gateway(
        self,
        texts: list[str],
        project_id: str,
        agent_id: str,
    ) -> list[list[float]]:
        """
        Call LLM Gateway to generate embeddings.

        Uses JSON-RPC 2.0 protocol to call the embedding tool.
        """
        try:
            # JSON-RPC 2.0 request for embedding tool
            request = {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "generate_embeddings",
                    "arguments": {
                        "project_id": project_id,
                        "agent_id": agent_id,
                        "texts": texts,
                        "model": self._settings.embedding_model,
                    },
                },
                "id": 1,
            }

            response = await self.http_client.post("/mcp", json=request)
            response.raise_for_status()

            result = response.json()

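            # Response shape assumed by the parsing below (illustrative sketch,
            # not copied from the LLM Gateway spec):
            #     {"jsonrpc": "2.0", "id": 1,
            #      "result": {"content": [
            #          {"type": "text", "text": "{\"embeddings\": [[0.1, ...], ...]}"}]}}
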
if "error" in result:
|
||||
error = result["error"]
|
||||
raise EmbeddingGenerationError(
|
||||
message=f"LLM Gateway error: {error.get('message', 'Unknown')}",
|
||||
texts_count=len(texts),
|
||||
details=error.get("data"),
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
content = result.get("result", {}).get("content", [])
|
||||
if not content:
|
||||
raise EmbeddingGenerationError(
|
||||
message="Empty response from LLM Gateway",
|
||||
texts_count=len(texts),
|
||||
)
|
||||
|
||||
# Parse the response content
|
||||
# LLM Gateway returns embeddings in content[0].text as JSON
|
||||
embeddings_data = content[0].get("text", "")
|
||||
if isinstance(embeddings_data, str):
|
||||
embeddings_data = json.loads(embeddings_data)
|
||||
|
||||
embeddings = embeddings_data.get("embeddings", [])
|
||||
|
||||
if len(embeddings) != len(texts):
|
||||
raise EmbeddingGenerationError(
|
||||
message=f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
|
||||
texts_count=len(texts),
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"LLM Gateway HTTP error: {e}")
|
||||
raise EmbeddingGenerationError(
|
||||
message=f"LLM Gateway request failed: {e.response.status_code}",
|
||||
texts_count=len(texts),
|
||||
cause=e,
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"LLM Gateway request error: {e}")
|
||||
raise EmbeddingGenerationError(
|
||||
message=f"Failed to connect to LLM Gateway: {e}",
|
||||
texts_count=len(texts),
|
||||
cause=e,
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON response from LLM Gateway: {e}")
|
||||
raise EmbeddingGenerationError(
|
||||
message="Invalid response format from LLM Gateway",
|
||||
texts_count=len(texts),
|
||||
cause=e,
|
||||
)
|
||||
except EmbeddingGenerationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error generating embeddings: {e}")
|
||||
raise EmbeddingGenerationError(
|
||||
message=f"Unexpected error: {e}",
|
||||
texts_count=len(texts),
|
||||
cause=e,
|
||||
)
|
||||
|
||||
|
||||
# Global embedding generator instance (lazy initialization)
_embedding_generator: EmbeddingGenerator | None = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Get the global embedding generator instance."""
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator


def reset_embedding_generator() -> None:
    """Reset the global embedding generator (for testing)."""
    global _embedding_generator
    _embedding_generator = None
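
# Usage sketch (assumes a reachable Redis instance and an LLM Gateway at
# settings.llm_gateway_url; illustrative, not part of this commit):
#
#     generator = get_embedding_generator()
#     await generator.initialize()
#     try:
#         vectors = await generator.generate_batch(["first chunk", "second chunk"])
#     finally:
#         await generator.close()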