- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`); see the sketch below.
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
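
A minimal sketch of the `ClassVar` pattern referenced above. The class and attribute names (`TruncationStrategy`, `DEFAULT_MARKER`, `truncate`) are illustrative only and not taken from the changed modules; the point is that a class-level constant gets an explicit `ClassVar` annotation instead of being left as a bare hard-coded value.

```python
from typing import ClassVar


class TruncationStrategy:
    # Annotated as ClassVar so type checkers treat the marker as shared,
    # class-level configuration rather than a per-instance attribute.
    DEFAULT_MARKER: ClassVar[str] = "... [truncated]"

    def truncate(self, text: str, max_chars: int) -> str:
        if len(text) <= max_chars:
            return text
        keep = max(0, max_chars - len(self.DEFAULT_MARKER))
        return text[:keep] + self.DEFAULT_MARKER
```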
"""
|
|
Relevance Scorer for Context Management.
|
|
|
|
Scores context based on semantic similarity to the query.
|
|
Uses Knowledge Base embeddings when available.
|
|
"""
|
|
|
|
import logging
|
|
import re
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from ..config import ContextSettings, get_context_settings
|
|
from ..types import BaseContext, KnowledgeContext
|
|
from .base import BaseScorer
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.mcp.client_manager import MCPClientManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RelevanceScorer(BaseScorer):
|
|
"""
|
|
Scores context based on relevance to query.
|
|
|
|
Uses multiple strategies:
|
|
1. Pre-computed scores (from RAG results)
|
|
2. MCP-based semantic similarity (via Knowledge Base)
|
|
3. Keyword matching fallback
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
mcp_manager: "MCPClientManager | None" = None,
|
|
weight: float = 1.0,
|
|
keyword_fallback_weight: float | None = None,
|
|
semantic_max_chars: int | None = None,
|
|
settings: ContextSettings | None = None,
|
|
) -> None:
|
|
"""
|
|
Initialize relevance scorer.
|
|
|
|
Args:
|
|
mcp_manager: MCP manager for Knowledge Base calls
|
|
weight: Scorer weight for composite scoring
|
|
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
|
|
semantic_max_chars: Max content length for semantic similarity (overrides settings)
|
|
settings: Context settings (uses global if None)
|
|
"""
|
|
super().__init__(weight)
|
|
self._settings = settings or get_context_settings()
|
|
self._mcp = mcp_manager
|
|
|
|
# Use provided values or fall back to settings
|
|
self._keyword_fallback_weight = (
|
|
keyword_fallback_weight
|
|
if keyword_fallback_weight is not None
|
|
else self._settings.relevance_keyword_fallback_weight
|
|
)
|
|
self._semantic_max_chars = (
|
|
semantic_max_chars
|
|
if semantic_max_chars is not None
|
|
else self._settings.relevance_semantic_max_chars
|
|
)
|
|
|
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
|
"""Set MCP manager for semantic scoring."""
|
|
self._mcp = mcp_manager
|
|
|
|
async def score(
|
|
self,
|
|
context: BaseContext,
|
|
query: str,
|
|
**kwargs: Any,
|
|
) -> float:
|
|
"""
|
|
Score context relevance to query.
|
|
|
|
Args:
|
|
context: Context to score
|
|
query: Query to score against
|
|
**kwargs: Additional parameters
|
|
|
|
Returns:
|
|
Relevance score between 0.0 and 1.0
|
|
"""
|
|
# 1. Check for pre-computed relevance score
|
|
if (
|
|
isinstance(context, KnowledgeContext)
|
|
and context.relevance_score is not None
|
|
):
|
|
return self.normalize_score(context.relevance_score)
|
|
|
|
# 2. Check metadata for score
|
|
if "relevance_score" in context.metadata:
|
|
return self.normalize_score(context.metadata["relevance_score"])
|
|
|
|
if "score" in context.metadata:
|
|
return self.normalize_score(context.metadata["score"])
|
|
|
|
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
|
|
# Note: This requires the knowledge-base MCP server to implement compute_similarity
|
|
if self._mcp is not None:
|
|
try:
|
|
score = await self._compute_semantic_similarity(context, query)
|
|
if score is not None:
|
|
return score
|
|
except Exception as e:
|
|
# Log at debug level since this is expected if compute_similarity
|
|
# tool is not implemented in the Knowledge Base server
|
|
logger.debug(
|
|
f"Semantic scoring unavailable, using keyword fallback: {e}"
|
|
)
|
|
|
|
# 4. Fall back to keyword matching
|
|
return self._compute_keyword_score(context, query)
|
|
|
|
async def _compute_semantic_similarity(
|
|
self,
|
|
context: BaseContext,
|
|
query: str,
|
|
) -> float | None:
|
|
"""
|
|
Compute semantic similarity using Knowledge Base embeddings.
|
|
|
|
Args:
|
|
context: Context to score
|
|
query: Query to compare
|
|
|
|
Returns:
|
|
Similarity score or None if unavailable
|
|
"""
|
|
if self._mcp is None:
|
|
return None
|
|
|
|
try:
|
|
# Use Knowledge Base's search capability to compute similarity
|
|
result = await self._mcp.call_tool(
|
|
server="knowledge-base",
|
|
tool="compute_similarity",
|
|
args={
|
|
"text1": query,
|
|
"text2": context.content[
|
|
: self._semantic_max_chars
|
|
], # Limit content length
|
|
},
|
|
)
|
|
|
|
if result.success and isinstance(result.data, dict):
|
|
similarity = result.data.get("similarity")
|
|
if similarity is not None:
|
|
return self.normalize_score(float(similarity))
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Semantic similarity computation failed: {e}")
|
|
|
|
return None
|
|
|
|
def _compute_keyword_score(
|
|
self,
|
|
context: BaseContext,
|
|
query: str,
|
|
) -> float:
|
|
"""
|
|
Compute relevance score based on keyword matching.
|
|
|
|
Simple but fast fallback when semantic search is unavailable.
|
|
|
|
Args:
|
|
context: Context to score
|
|
query: Query to match
|
|
|
|
Returns:
|
|
Keyword-based relevance score
|
|
"""
|
|
if not query or not context.content:
|
|
return 0.0
|
|
|
|
# Extract keywords from query
|
|
query_lower = query.lower()
|
|
content_lower = context.content.lower()
|
|
|
|
# Simple word tokenization
|
|
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
|
|
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
|
|
|
|
if not query_words:
|
|
return 0.0
|
|
|
|
# Calculate overlap
|
|
common_words = query_words & content_words
|
|
overlap_ratio = len(common_words) / len(query_words)
|
|
|
|
# Apply fallback weight ceiling
|
|
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
|
|
|
|
async def score_batch(
|
|
self,
|
|
contexts: list[BaseContext],
|
|
query: str,
|
|
**kwargs: Any,
|
|
) -> list[float]:
|
|
"""
|
|
Score multiple contexts in parallel.
|
|
|
|
Args:
|
|
contexts: Contexts to score
|
|
query: Query to score against
|
|
**kwargs: Additional parameters
|
|
|
|
Returns:
|
|
List of scores (same order as input)
|
|
"""
|
|
import asyncio
|
|
|
|
if not contexts:
|
|
return []
|
|
|
|
tasks = [self.score(context, query, **kwargs) for context in contexts]
|
|
return await asyncio.gather(*tasks)
|
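

# --- Usage sketch (illustrative only, not part of the original module) --------
# A minimal example of how a caller might rank contexts with this scorer using
# only the keyword fallback (no MCP manager attached). It relies solely on the
# RelevanceScorer API defined above; the helper name and return shape are
# assumptions for illustration.


async def _example_rank_by_relevance(
    contexts: list[BaseContext],
    query: str,
) -> list[tuple[BaseContext, float]]:
    scorer = RelevanceScorer(weight=1.0)
    scores = await scorer.score_batch(contexts, query)
    # Pair each context with its score and sort highest-relevance first
    return sorted(zip(contexts, scores), key=lambda pair: pair[1], reverse=True)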