forked from cardosofelipe/fast-next-template
feat(context): enhance performance, caching, and settings management
- Replace hard-coded limits with configurable settings (e.g., cache memory size, truncation strategy, relevance settings). - Optimize parallel execution in token counting, scoring, and reranking for source diversity. - Improve caching logic: - Add per-context locks for safe parallel scoring. - Reuse precomputed fingerprints for cache efficiency. - Make truncation, scoring, and ranker behaviors fully configurable via settings. - Add support for middle truncation, context hash-based hashing, and dynamic token limiting. - Refactor methods for scalability and better error handling. Tests: Updated all affected components with additional test cases.
This commit is contained in:
@@ -237,7 +237,7 @@ class TokenCalculator:
|
||||
"""
|
||||
Count tokens for multiple texts.
|
||||
|
||||
Efficient batch counting with caching.
|
||||
Efficient batch counting with caching and parallel execution.
|
||||
|
||||
Args:
|
||||
texts: List of texts to count
|
||||
@@ -246,13 +246,14 @@ class TokenCalculator:
|
||||
Returns:
|
||||
List of token counts (same order as input)
|
||||
"""
|
||||
results: list[int] = []
|
||||
import asyncio
|
||||
|
||||
for text in texts:
|
||||
count = await self.count_tokens(text, model)
|
||||
results.append(count)
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
return results
|
||||
# Execute all token counts in parallel for better performance
|
||||
tasks = [self.count_tokens(text, model) for text in texts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the token count cache."""
|
||||
|
||||
@@ -54,7 +54,7 @@ class ContextCache:
|
||||
|
||||
# In-memory fallback cache when Redis unavailable
|
||||
self._memory_cache: dict[str, tuple[str, float]] = {}
|
||||
self._max_memory_items = 1000
|
||||
self._max_memory_items = self._settings.cache_memory_max_items
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection."""
|
||||
@@ -100,7 +100,7 @@ class ContextCache:
|
||||
Compute a fingerprint for a context assembly request.
|
||||
|
||||
The fingerprint is based on:
|
||||
- Context content and metadata
|
||||
- Context content hash and metadata (not full content for performance)
|
||||
- Query string
|
||||
- Target model
|
||||
|
||||
@@ -112,12 +112,13 @@ class ContextCache:
|
||||
Returns:
|
||||
32-character hex fingerprint
|
||||
"""
|
||||
# Build a deterministic representation
|
||||
# Build a deterministic representation using content hashes for performance
|
||||
# This avoids JSON serializing potentially large content strings
|
||||
context_data = []
|
||||
for ctx in contexts:
|
||||
context_data.append({
|
||||
"type": ctx.get_type().value,
|
||||
"content": ctx.content,
|
||||
"content_hash": self._hash_content(ctx.content), # Hash instead of full content
|
||||
"source": ctx.source,
|
||||
"priority": ctx.priority, # Already an int
|
||||
})
|
||||
|
||||
@@ -10,6 +10,7 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -45,26 +46,41 @@ class TruncationStrategy:
|
||||
4. Semantic chunking: Keep most relevant chunks
|
||||
"""
|
||||
|
||||
# Default truncation marker
|
||||
TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculator: "TokenCalculator | None" = None,
|
||||
preserve_ratio_start: float = 0.7, # Keep 70% from start by default
|
||||
min_content_length: int = 100, # Minimum characters to keep
|
||||
preserve_ratio_start: float | None = None,
|
||||
min_content_length: int | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize truncation strategy.
|
||||
|
||||
Args:
|
||||
calculator: Token calculator for accurate counting
|
||||
preserve_ratio_start: Ratio of content to keep from start
|
||||
min_content_length: Minimum characters to preserve
|
||||
preserve_ratio_start: Ratio of content to keep from start (overrides settings)
|
||||
min_content_length: Minimum characters to preserve (overrides settings)
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._calculator = calculator
|
||||
self._preserve_ratio_start = preserve_ratio_start
|
||||
self._min_content_length = min_content_length
|
||||
|
||||
# Use provided values or fall back to settings
|
||||
self._preserve_ratio_start = (
|
||||
preserve_ratio_start
|
||||
if preserve_ratio_start is not None
|
||||
else self._settings.truncation_preserve_ratio
|
||||
)
|
||||
self._min_content_length = (
|
||||
min_content_length
|
||||
if min_content_length is not None
|
||||
else self._settings.truncation_min_content_length
|
||||
)
|
||||
|
||||
@property
|
||||
def TRUNCATION_MARKER(self) -> str:
|
||||
"""Get truncation marker from settings."""
|
||||
return self._settings.truncation_marker
|
||||
|
||||
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||
"""Set token calculator."""
|
||||
@@ -125,7 +141,7 @@ class TruncationStrategy:
|
||||
truncated_tokens=truncated_tokens,
|
||||
content=truncated,
|
||||
truncated=True,
|
||||
truncation_ratio=1 - (truncated_tokens / original_tokens),
|
||||
truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
|
||||
)
|
||||
|
||||
async def _truncate_end(
|
||||
@@ -141,10 +157,17 @@ class TruncationStrategy:
|
||||
"""
|
||||
# Binary search for optimal truncation point
|
||||
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||
available_tokens = max_tokens - marker_tokens
|
||||
available_tokens = max(0, max_tokens - marker_tokens)
|
||||
|
||||
# Estimate characters per token
|
||||
chars_per_token = len(content) / await self._count_tokens(content, model)
|
||||
# Edge case: if no tokens available for content, return just the marker
|
||||
if available_tokens <= 0:
|
||||
return self.TRUNCATION_MARKER
|
||||
|
||||
# Estimate characters per token (guard against division by zero)
|
||||
content_tokens = await self._count_tokens(content, model)
|
||||
if content_tokens == 0:
|
||||
return content + self.TRUNCATION_MARKER
|
||||
chars_per_token = len(content) / content_tokens
|
||||
|
||||
# Start with estimated position
|
||||
estimated_chars = int(available_tokens * chars_per_token)
|
||||
@@ -243,7 +266,9 @@ class TruncationStrategy:
|
||||
if current_tokens <= target_tokens:
|
||||
return content
|
||||
|
||||
# Estimate characters
|
||||
# Estimate characters (guard against division by zero)
|
||||
if current_tokens == 0:
|
||||
return content
|
||||
chars_per_token = len(content) / current_tokens
|
||||
estimated_chars = int(target_tokens * chars_per_token)
|
||||
|
||||
|
||||
@@ -104,9 +104,21 @@ class ContextSettings(BaseSettings):
|
||||
le=1.0,
|
||||
description="Compress when budget usage exceeds this percentage",
|
||||
)
|
||||
truncation_suffix: str = Field(
|
||||
default="... [truncated]",
|
||||
description="Suffix to add when truncating content",
|
||||
truncation_marker: str = Field(
|
||||
default="\n\n[...content truncated...]\n\n",
|
||||
description="Marker text to insert where content was truncated",
|
||||
)
|
||||
truncation_preserve_ratio: float = Field(
|
||||
default=0.7,
|
||||
ge=0.1,
|
||||
le=0.9,
|
||||
description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
|
||||
)
|
||||
truncation_min_content_length: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Minimum content length in characters before truncation applies",
|
||||
)
|
||||
summary_model_group: str = Field(
|
||||
default="fast",
|
||||
@@ -128,6 +140,12 @@ class ContextSettings(BaseSettings):
|
||||
default="ctx",
|
||||
description="Redis key prefix for context cache",
|
||||
)
|
||||
cache_memory_max_items: int = Field(
|
||||
default=1000,
|
||||
ge=100,
|
||||
le=100000,
|
||||
description="Maximum items in memory fallback cache when Redis unavailable",
|
||||
)
|
||||
|
||||
# Performance settings
|
||||
max_assembly_time_ms: int = Field(
|
||||
@@ -165,6 +183,28 @@ class ContextSettings(BaseSettings):
|
||||
description="Minimum relevance score for knowledge",
|
||||
)
|
||||
|
||||
# Relevance scoring settings
|
||||
relevance_keyword_fallback_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
|
||||
)
|
||||
relevance_semantic_max_chars: int = Field(
|
||||
default=2000,
|
||||
ge=100,
|
||||
le=10000,
|
||||
description="Maximum content length in chars for semantic similarity computation",
|
||||
)
|
||||
|
||||
# Diversity/ranking settings
|
||||
diversity_max_per_source: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
le=20,
|
||||
description="Maximum contexts from the same source in diversity reranking",
|
||||
)
|
||||
|
||||
# Conversation history settings
|
||||
conversation_max_turns: int = Field(
|
||||
default=20,
|
||||
@@ -253,11 +293,15 @@ class ContextSettings(BaseSettings):
|
||||
"compression": {
|
||||
"threshold": self.compression_threshold,
|
||||
"summary_model_group": self.summary_model_group,
|
||||
"truncation_marker": self.truncation_marker,
|
||||
"truncation_preserve_ratio": self.truncation_preserve_ratio,
|
||||
"truncation_min_content_length": self.truncation_min_content_length,
|
||||
},
|
||||
"cache": {
|
||||
"enabled": self.cache_enabled,
|
||||
"ttl_seconds": self.cache_ttl_seconds,
|
||||
"prefix": self.cache_prefix,
|
||||
"memory_max_items": self.cache_memory_max_items,
|
||||
},
|
||||
"performance": {
|
||||
"max_assembly_time_ms": self.max_assembly_time_ms,
|
||||
@@ -269,6 +313,13 @@ class ContextSettings(BaseSettings):
|
||||
"max_results": self.knowledge_max_results,
|
||||
"min_score": self.knowledge_min_score,
|
||||
},
|
||||
"relevance": {
|
||||
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
|
||||
"semantic_max_chars": self.relevance_semantic_max_chars,
|
||||
},
|
||||
"diversity": {
|
||||
"max_per_source": self.diversity_max_per_source,
|
||||
},
|
||||
"conversation": {
|
||||
"max_turns": self.conversation_max_turns,
|
||||
"recent_priority": self.conversation_recent_priority,
|
||||
|
||||
@@ -214,6 +214,7 @@ class ContextEngine:
|
||||
contexts.extend(custom_contexts)
|
||||
|
||||
# Check cache if enabled
|
||||
fingerprint: str | None = None
|
||||
if use_cache and self._cache.is_enabled:
|
||||
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
|
||||
cached = await self._cache.get_assembled(fingerprint)
|
||||
@@ -232,9 +233,8 @@ class ContextEngine:
|
||||
format_output=format_output,
|
||||
)
|
||||
|
||||
# Cache result if enabled
|
||||
if use_cache and self._cache.is_enabled:
|
||||
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
|
||||
# Cache result if enabled (reuse fingerprint computed above)
|
||||
if use_cache and self._cache.is_enabled and fingerprint is not None:
|
||||
await self._cache.set_assembled(fingerprint, result)
|
||||
|
||||
return result
|
||||
@@ -275,7 +275,8 @@ class ContextEngine:
|
||||
)
|
||||
|
||||
contexts = []
|
||||
for chunk in result.data.get("results", []):
|
||||
results = result.data.get("results", []) if isinstance(result.data, dict) else []
|
||||
for chunk in results:
|
||||
contexts.append(
|
||||
KnowledgeContext(
|
||||
content=chunk.get("content", ""),
|
||||
|
||||
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext
|
||||
|
||||
@@ -45,6 +46,7 @@ class ContextRanker:
|
||||
self,
|
||||
scorer: CompositeScorer | None = None,
|
||||
calculator: TokenCalculator | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context ranker.
|
||||
@@ -52,7 +54,9 @@ class ContextRanker:
|
||||
Args:
|
||||
scorer: Composite scorer for scoring contexts
|
||||
calculator: Token calculator for counting tokens
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._scorer = scorer or CompositeScorer()
|
||||
self._calculator = calculator or TokenCalculator()
|
||||
|
||||
@@ -226,16 +230,32 @@ class ContextRanker:
|
||||
"""
|
||||
Ensure all contexts have token counts.
|
||||
|
||||
Counts tokens in parallel for contexts that don't have counts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to check
|
||||
model: Model for token counting
|
||||
"""
|
||||
for context in contexts:
|
||||
if context.token_count is None:
|
||||
count = await self._calculator.count_tokens(
|
||||
context.content, model
|
||||
)
|
||||
context.token_count = count
|
||||
import asyncio
|
||||
|
||||
# Find contexts needing counts
|
||||
contexts_needing_counts = [
|
||||
ctx for ctx in contexts if ctx.token_count is None
|
||||
]
|
||||
|
||||
if not contexts_needing_counts:
|
||||
return
|
||||
|
||||
# Count all in parallel
|
||||
tasks = [
|
||||
self._calculator.count_tokens(ctx.content, model)
|
||||
for ctx in contexts_needing_counts
|
||||
]
|
||||
counts = await asyncio.gather(*tasks)
|
||||
|
||||
# Assign counts back
|
||||
for ctx, count in zip(contexts_needing_counts, counts):
|
||||
ctx.token_count = count
|
||||
|
||||
def _count_by_type(
|
||||
self, scored_contexts: list[ScoredContext]
|
||||
@@ -255,7 +275,7 @@ class ContextRanker:
|
||||
async def rerank_for_diversity(
|
||||
self,
|
||||
scored_contexts: list[ScoredContext],
|
||||
max_per_source: int = 3,
|
||||
max_per_source: int | None = None,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Rerank to ensure source diversity.
|
||||
@@ -264,11 +284,18 @@ class ContextRanker:
|
||||
|
||||
Args:
|
||||
scored_contexts: Already scored contexts
|
||||
max_per_source: Maximum items per source
|
||||
max_per_source: Maximum items per source (uses settings if None)
|
||||
|
||||
Returns:
|
||||
Reranked contexts
|
||||
"""
|
||||
# Use provided value or fall back to settings
|
||||
effective_max = (
|
||||
max_per_source
|
||||
if max_per_source is not None
|
||||
else self._settings.diversity_max_per_source
|
||||
)
|
||||
|
||||
source_counts: dict[str, int] = {}
|
||||
result: list[ScoredContext] = []
|
||||
deferred: list[ScoredContext] = []
|
||||
@@ -277,7 +304,7 @@ class ContextRanker:
|
||||
source = sc.context.source
|
||||
current_count = source_counts.get(source, 0)
|
||||
|
||||
if current_count < max_per_source:
|
||||
if current_count < effective_max:
|
||||
result.append(sc)
|
||||
source_counts[source] = current_count + 1
|
||||
else:
|
||||
|
||||
@@ -8,6 +8,7 @@ import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
@@ -89,6 +90,11 @@ class CompositeScorer:
|
||||
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
|
||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||
|
||||
# Per-context locks to prevent race conditions during parallel scoring
|
||||
# Uses WeakValueDictionary so locks are garbage collected when not in use
|
||||
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._relevance_scorer.set_mcp_manager(mcp_manager)
|
||||
@@ -128,6 +134,38 @@ class CompositeScorer:
|
||||
self._priority_weight = max(0.0, min(1.0, priority))
|
||||
self._priority_scorer.weight = self._priority_weight
|
||||
|
||||
async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
|
||||
"""
|
||||
Get or create a lock for a specific context.
|
||||
|
||||
Thread-safe access to per-context locks prevents race conditions
|
||||
when the same context is scored concurrently.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to get a lock for
|
||||
|
||||
Returns:
|
||||
asyncio.Lock for the context
|
||||
"""
|
||||
# Fast path: check if lock exists without acquiring main lock
|
||||
if context_id in self._context_locks:
|
||||
lock = self._context_locks.get(context_id)
|
||||
if lock is not None:
|
||||
return lock
|
||||
|
||||
# Slow path: create lock while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock
|
||||
if context_id in self._context_locks:
|
||||
lock = self._context_locks.get(context_id)
|
||||
if lock is not None:
|
||||
return lock
|
||||
|
||||
# Create new lock
|
||||
new_lock = asyncio.Lock()
|
||||
self._context_locks[context_id] = new_lock
|
||||
return new_lock
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
@@ -157,6 +195,9 @@ class CompositeScorer:
|
||||
"""
|
||||
Compute composite score with individual scores.
|
||||
|
||||
Uses per-context locking to prevent race conditions when the same
|
||||
context is scored concurrently in parallel scoring operations.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
@@ -165,46 +206,50 @@ class CompositeScorer:
|
||||
Returns:
|
||||
ScoredContext with all scores
|
||||
"""
|
||||
# Check if context already has a score
|
||||
if context._score is not None:
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=context._score,
|
||||
# Get lock for this specific context to prevent race conditions
|
||||
context_lock = await self._get_context_lock(context.id)
|
||||
|
||||
async with context_lock:
|
||||
# Check if context already has a score (inside lock to prevent races)
|
||||
if context._score is not None:
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=context._score,
|
||||
)
|
||||
|
||||
# Compute individual scores in parallel
|
||||
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||
|
||||
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||
relevance_task, recency_task, priority_task
|
||||
)
|
||||
|
||||
# Compute individual scores in parallel
|
||||
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||
# Compute weighted composite
|
||||
total_weight = (
|
||||
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||
)
|
||||
|
||||
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||
relevance_task, recency_task, priority_task
|
||||
)
|
||||
if total_weight > 0:
|
||||
composite = (
|
||||
relevance_score * self._relevance_weight
|
||||
+ recency_score * self._recency_weight
|
||||
+ priority_score * self._priority_weight
|
||||
) / total_weight
|
||||
else:
|
||||
composite = 0.0
|
||||
|
||||
# Compute weighted composite
|
||||
total_weight = (
|
||||
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||
)
|
||||
# Cache the score on the context (now safe - inside lock)
|
||||
context._score = composite
|
||||
|
||||
if total_weight > 0:
|
||||
composite = (
|
||||
relevance_score * self._relevance_weight
|
||||
+ recency_score * self._recency_weight
|
||||
+ priority_score * self._priority_weight
|
||||
) / total_weight
|
||||
else:
|
||||
composite = 0.0
|
||||
|
||||
# Cache the score on the context
|
||||
context._score = composite
|
||||
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=composite,
|
||||
relevance_score=relevance_score,
|
||||
recency_score=recency_score,
|
||||
priority_score=priority_score,
|
||||
)
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=composite,
|
||||
relevance_score=relevance_score,
|
||||
recency_score=recency_score,
|
||||
priority_score=priority_score,
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, KnowledgeContext
|
||||
from .base import BaseScorer
|
||||
|
||||
@@ -32,7 +33,9 @@ class RelevanceScorer(BaseScorer):
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
weight: float = 1.0,
|
||||
keyword_fallback_weight: float = 0.5,
|
||||
keyword_fallback_weight: float | None = None,
|
||||
semantic_max_chars: int | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize relevance scorer.
|
||||
@@ -40,11 +43,25 @@ class RelevanceScorer(BaseScorer):
|
||||
Args:
|
||||
mcp_manager: MCP manager for Knowledge Base calls
|
||||
weight: Scorer weight for composite scoring
|
||||
keyword_fallback_weight: Max score for keyword-based fallback
|
||||
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
|
||||
semantic_max_chars: Max content length for semantic similarity (overrides settings)
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._settings = settings or get_context_settings()
|
||||
self._mcp = mcp_manager
|
||||
self._keyword_fallback_weight = keyword_fallback_weight
|
||||
|
||||
# Use provided values or fall back to settings
|
||||
self._keyword_fallback_weight = (
|
||||
keyword_fallback_weight
|
||||
if keyword_fallback_weight is not None
|
||||
else self._settings.relevance_keyword_fallback_weight
|
||||
)
|
||||
self._semantic_max_chars = (
|
||||
semantic_max_chars
|
||||
if semantic_max_chars is not None
|
||||
else self._settings.relevance_semantic_max_chars
|
||||
)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
@@ -112,11 +129,11 @@ class RelevanceScorer(BaseScorer):
|
||||
tool="compute_similarity",
|
||||
args={
|
||||
"text1": query,
|
||||
"text2": context.content[:2000], # Limit content length
|
||||
"text2": context.content[: self._semantic_max_chars], # Limit content length
|
||||
},
|
||||
)
|
||||
|
||||
if result.success and result.data:
|
||||
if result.success and isinstance(result.data, dict):
|
||||
similarity = result.data.get("similarity")
|
||||
if similarity is not None:
|
||||
return self.normalize_score(float(similarity))
|
||||
@@ -171,7 +188,7 @@ class RelevanceScorer(BaseScorer):
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
Score multiple contexts in parallel.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
@@ -181,8 +198,10 @@ class RelevanceScorer(BaseScorer):
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
scores = []
|
||||
for context in contexts:
|
||||
score = await self.score(context, query, **kwargs)
|
||||
scores.append(score)
|
||||
return scores
|
||||
import asyncio
|
||||
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
tasks = [self.score(context, query, **kwargs) for context in contexts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
Reference in New Issue
Block a user