feat(context): enhance performance, caching, and settings management
- Replace hard-coded limits with configurable settings (e.g., cache memory size, truncation strategy, relevance settings).
- Optimize parallel execution in token counting, scoring, and reranking for source diversity.
- Improve caching logic:
  - Add per-context locks for safe parallel scoring.
  - Reuse precomputed fingerprints for cache efficiency.
- Make truncation, scoring, and ranker behaviors fully configurable via settings.
- Add support for middle truncation, content-hash-based fingerprinting, and dynamic token limiting.
- Refactor methods for scalability and better error handling.

Tests: updated all affected components with additional test cases.
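Note on the pattern used throughout this commit: constructors take optional per-instance overrides and fall back to the shared ContextSettings. A minimal usage sketch (import paths are illustrative, not taken from the repo):

    from context.config import ContextSettings
    from context.compression.truncation import TruncationStrategy

    # Explicit settings object; fields are validated by pydantic
    settings = ContextSettings(truncation_preserve_ratio=0.8)

    # Constructor override wins over the settings value
    strategy = TruncationStrategy(preserve_ratio_start=0.6, settings=settings)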
@@ -237,7 +237,7 @@ class TokenCalculator:
         """
         Count tokens for multiple texts.
 
-        Efficient batch counting with caching.
+        Efficient batch counting with caching and parallel execution.
 
         Args:
            texts: List of texts to count
@@ -246,13 +246,14 @@ class TokenCalculator:
        Returns:
            List of token counts (same order as input)
        """
-        results: list[int] = []
+        import asyncio
 
-        for text in texts:
-            count = await self.count_tokens(text, model)
-            results.append(count)
+        if not texts:
+            return []
 
-        return results
+        # Execute all token counts in parallel for better performance
+        tasks = [self.count_tokens(text, model) for text in texts]
+        return await asyncio.gather(*tasks)
 
     def clear_cache(self) -> None:
         """Clear the token count cache."""
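Note: asyncio.gather returns results in the order the awaitables were passed, so the parallel rewrite preserves the documented "same order as input" contract. A self-contained sketch of that guarantee (count_stub is a stand-in for count_tokens):

    import asyncio

    async def count_stub(text: str) -> int:
        await asyncio.sleep(0)  # simulate async tokenizer work
        return len(text.split())

    async def main() -> None:
        texts = ["a b", "c", "d e f"]
        counts = await asyncio.gather(*(count_stub(t) for t in texts))
        assert counts == [2, 1, 3]  # order matches the input list

    asyncio.run(main())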
@@ -54,7 +54,7 @@ class ContextCache:
 
         # In-memory fallback cache when Redis unavailable
         self._memory_cache: dict[str, tuple[str, float]] = {}
-        self._max_memory_items = 1000
+        self._max_memory_items = self._settings.cache_memory_max_items
 
     def set_redis(self, redis: "Redis") -> None:
         """Set Redis connection."""
@@ -100,7 +100,7 @@ class ContextCache:
         Compute a fingerprint for a context assembly request.
 
         The fingerprint is based on:
-        - Context content and metadata
+        - Context content hash and metadata (not full content for performance)
        - Query string
        - Target model
 
@@ -112,12 +112,13 @@ class ContextCache:
        Returns:
            32-character hex fingerprint
        """
-        # Build a deterministic representation
+        # Build a deterministic representation using content hashes for performance
+        # This avoids JSON serializing potentially large content strings
        context_data = []
        for ctx in contexts:
            context_data.append({
                "type": ctx.get_type().value,
-                "content": ctx.content,
+                "content_hash": self._hash_content(ctx.content),  # Hash instead of full content
                "source": ctx.source,
                "priority": ctx.priority,  # Already an int
            })
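Note: the hunk calls self._hash_content(...) but its definition is outside the diff. A plausible shape for such a helper (an assumption, not this commit's code):

    import hashlib

    def _hash_content(self, content: str) -> str:
        # A short stable digest is enough for a cache fingerprint and
        # avoids serializing large content strings into the key
        return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]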
@@ -10,6 +10,7 @@ import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext, ContextType
 
 if TYPE_CHECKING:
@@ -45,26 +46,41 @@ class TruncationStrategy:
     4. Semantic chunking: Keep most relevant chunks
     """
 
-    # Default truncation marker
-    TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
 
     def __init__(
         self,
         calculator: "TokenCalculator | None" = None,
-        preserve_ratio_start: float = 0.7,  # Keep 70% from start by default
-        min_content_length: int = 100,  # Minimum characters to keep
+        preserve_ratio_start: float | None = None,
+        min_content_length: int | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
        """
        Initialize truncation strategy.
 
        Args:
            calculator: Token calculator for accurate counting
-            preserve_ratio_start: Ratio of content to keep from start
-            min_content_length: Minimum characters to preserve
+            preserve_ratio_start: Ratio of content to keep from start (overrides settings)
+            min_content_length: Minimum characters to preserve (overrides settings)
+            settings: Context settings (uses global if None)
        """
+        self._settings = settings or get_context_settings()
        self._calculator = calculator
-        self._preserve_ratio_start = preserve_ratio_start
-        self._min_content_length = min_content_length
+
+        # Use provided values or fall back to settings
+        self._preserve_ratio_start = (
+            preserve_ratio_start
+            if preserve_ratio_start is not None
+            else self._settings.truncation_preserve_ratio
+        )
+        self._min_content_length = (
+            min_content_length
+            if min_content_length is not None
+            else self._settings.truncation_min_content_length
+        )
+
+    @property
+    def TRUNCATION_MARKER(self) -> str:
+        """Get truncation marker from settings."""
+        return self._settings.truncation_marker
 
     def set_calculator(self, calculator: "TokenCalculator") -> None:
         """Set token calculator."""
@@ -125,7 +141,7 @@ class TruncationStrategy:
            truncated_tokens=truncated_tokens,
            content=truncated,
            truncated=True,
-            truncation_ratio=1 - (truncated_tokens / original_tokens),
+            truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
        )
 
    async def _truncate_end(
@@ -141,10 +157,17 @@ class TruncationStrategy:
        """
        # Binary search for optimal truncation point
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
-        available_tokens = max_tokens - marker_tokens
+        available_tokens = max(0, max_tokens - marker_tokens)
 
-        # Estimate characters per token
-        chars_per_token = len(content) / await self._count_tokens(content, model)
+        # Edge case: if no tokens available for content, return just the marker
+        if available_tokens <= 0:
+            return self.TRUNCATION_MARKER
+
+        # Estimate characters per token (guard against division by zero)
+        content_tokens = await self._count_tokens(content, model)
+        if content_tokens == 0:
+            return content + self.TRUNCATION_MARKER
+        chars_per_token = len(content) / content_tokens
 
        # Start with estimated position
        estimated_chars = int(available_tokens * chars_per_token)
@@ -243,7 +266,9 @@ class TruncationStrategy:
        if current_tokens <= target_tokens:
            return content
 
-        # Estimate characters
+        # Estimate characters (guard against division by zero)
+        if current_tokens == 0:
+            return content
        chars_per_token = len(content) / current_tokens
        estimated_chars = int(target_tokens * chars_per_token)
 
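Note: for intuition on truncation_preserve_ratio, a 0.7 ratio with room for 1000 content characters keeps roughly the first 700 and last 300 characters around the marker. A standalone sketch of that split (illustrative only, not the module's implementation):

    def middle_truncate(content: str, max_chars: int, marker: str, preserve_ratio: float = 0.7) -> str:
        if len(content) <= max_chars:
            return content
        available = max(0, max_chars - len(marker))
        head = int(available * preserve_ratio)   # kept from the start
        tail = available - head                  # kept from the end
        return content[:head] + marker + (content[-tail:] if tail > 0 else "")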
@@ -104,9 +104,21 @@ class ContextSettings(BaseSettings):
        le=1.0,
        description="Compress when budget usage exceeds this percentage",
    )
-    truncation_suffix: str = Field(
-        default="... [truncated]",
-        description="Suffix to add when truncating content",
+    truncation_marker: str = Field(
+        default="\n\n[...content truncated...]\n\n",
+        description="Marker text to insert where content was truncated",
+    )
+    truncation_preserve_ratio: float = Field(
+        default=0.7,
+        ge=0.1,
+        le=0.9,
+        description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
+    )
+    truncation_min_content_length: int = Field(
+        default=100,
+        ge=10,
+        le=1000,
+        description="Minimum content length in characters before truncation applies",
    )
    summary_model_group: str = Field(
        default="fast",
@@ -128,6 +140,12 @@ class ContextSettings(BaseSettings):
        default="ctx",
        description="Redis key prefix for context cache",
    )
+    cache_memory_max_items: int = Field(
+        default=1000,
+        ge=100,
+        le=100000,
+        description="Maximum items in memory fallback cache when Redis unavailable",
+    )
 
    # Performance settings
    max_assembly_time_ms: int = Field(
@@ -165,6 +183,28 @@ class ContextSettings(BaseSettings):
        description="Minimum relevance score for knowledge",
    )
 
+    # Relevance scoring settings
+    relevance_keyword_fallback_weight: float = Field(
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
+    )
+    relevance_semantic_max_chars: int = Field(
+        default=2000,
+        ge=100,
+        le=10000,
+        description="Maximum content length in chars for semantic similarity computation",
+    )
+
+    # Diversity/ranking settings
+    diversity_max_per_source: int = Field(
+        default=3,
+        ge=1,
+        le=20,
+        description="Maximum contexts from the same source in diversity reranking",
+    )
+
    # Conversation history settings
    conversation_max_turns: int = Field(
        default=20,
@@ -253,11 +293,15 @@ class ContextSettings(BaseSettings):
            "compression": {
                "threshold": self.compression_threshold,
                "summary_model_group": self.summary_model_group,
+                "truncation_marker": self.truncation_marker,
+                "truncation_preserve_ratio": self.truncation_preserve_ratio,
+                "truncation_min_content_length": self.truncation_min_content_length,
            },
            "cache": {
                "enabled": self.cache_enabled,
                "ttl_seconds": self.cache_ttl_seconds,
                "prefix": self.cache_prefix,
+                "memory_max_items": self.cache_memory_max_items,
            },
            "performance": {
                "max_assembly_time_ms": self.max_assembly_time_ms,
@@ -269,6 +313,13 @@ class ContextSettings(BaseSettings):
                "max_results": self.knowledge_max_results,
                "min_score": self.knowledge_min_score,
            },
+            "relevance": {
+                "keyword_fallback_weight": self.relevance_keyword_fallback_weight,
+                "semantic_max_chars": self.relevance_semantic_max_chars,
+            },
+            "diversity": {
+                "max_per_source": self.diversity_max_per_source,
+            },
            "conversation": {
                "max_turns": self.conversation_max_turns,
                "recent_priority": self.conversation_recent_priority,
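Note: since ContextSettings extends pydantic's BaseSettings, the new fields can be set programmatically and are validated against their ge/le bounds (assuming the remaining fields all have defaults, as the visible ones do):

    settings = ContextSettings(cache_memory_max_items=5000)
    assert settings.cache_memory_max_items == 5000

    # Out-of-range values are rejected (cache_memory_max_items has ge=100)
    try:
        ContextSettings(cache_memory_max_items=10)
    except Exception as exc:  # pydantic.ValidationError
        print(f"rejected: {exc}")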
@@ -214,6 +214,7 @@ class ContextEngine:
        contexts.extend(custom_contexts)
 
        # Check cache if enabled
+        fingerprint: str | None = None
        if use_cache and self._cache.is_enabled:
            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
            cached = await self._cache.get_assembled(fingerprint)
@@ -232,9 +233,8 @@ class ContextEngine:
            format_output=format_output,
        )
 
-        # Cache result if enabled
-        if use_cache and self._cache.is_enabled:
-            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
+        # Cache result if enabled (reuse fingerprint computed above)
+        if use_cache and self._cache.is_enabled and fingerprint is not None:
            await self._cache.set_assembled(fingerprint, result)
 
        return result
@@ -275,7 +275,8 @@ class ContextEngine:
        )
 
        contexts = []
-        for chunk in result.data.get("results", []):
+        results = result.data.get("results", []) if isinstance(result.data, dict) else []
+        for chunk in results:
            contexts.append(
                KnowledgeContext(
                    content=chunk.get("content", ""),
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
 
 from ..budget import TokenBudget, TokenCalculator
+from ..config import ContextSettings, get_context_settings
 from ..scoring.composite import CompositeScorer, ScoredContext
 from ..types import BaseContext
 
@@ -45,6 +46,7 @@ class ContextRanker:
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
+        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize context ranker.
@@ -52,7 +54,9 @@ class ContextRanker:
        Args:
            scorer: Composite scorer for scoring contexts
            calculator: Token calculator for counting tokens
+            settings: Context settings (uses global if None)
        """
+        self._settings = settings or get_context_settings()
        self._scorer = scorer or CompositeScorer()
        self._calculator = calculator or TokenCalculator()
 
@@ -226,16 +230,32 @@ class ContextRanker:
        """
        Ensure all contexts have token counts.
 
+        Counts tokens in parallel for contexts that don't have counts.
+
        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
-        for context in contexts:
-            if context.token_count is None:
-                count = await self._calculator.count_tokens(
-                    context.content, model
-                )
-                context.token_count = count
+        import asyncio
+
+        # Find contexts needing counts
+        contexts_needing_counts = [
+            ctx for ctx in contexts if ctx.token_count is None
+        ]
+
+        if not contexts_needing_counts:
+            return
+
+        # Count all in parallel
+        tasks = [
+            self._calculator.count_tokens(ctx.content, model)
+            for ctx in contexts_needing_counts
+        ]
+        counts = await asyncio.gather(*tasks)
+
+        # Assign counts back
+        for ctx, count in zip(contexts_needing_counts, counts):
+            ctx.token_count = count
 
    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
@@ -255,7 +275,7 @@ class ContextRanker:
    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
-        max_per_source: int = 3,
+        max_per_source: int | None = None,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.
@@ -264,11 +284,18 @@ class ContextRanker:
 
        Args:
            scored_contexts: Already scored contexts
-            max_per_source: Maximum items per source
+            max_per_source: Maximum items per source (uses settings if None)
 
        Returns:
            Reranked contexts
        """
+        # Use provided value or fall back to settings
+        effective_max = (
+            max_per_source
+            if max_per_source is not None
+            else self._settings.diversity_max_per_source
+        )
+
        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []
@@ -277,7 +304,7 @@ class ContextRanker:
            source = sc.context.source
            current_count = source_counts.get(source, 0)
 
-            if current_count < max_per_source:
+            if current_count < effective_max:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
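Note: to make the reranking concrete, with effective_max = 3 the fourth item from the same source is deferred (the truncated else: branch suggests deferred items are appended after the diverse ones). A standalone sketch of the same selection loop:

    def rerank(items: list[tuple[str, float]], max_per_source: int = 3) -> list[tuple[str, float]]:
        source_counts: dict[str, int] = {}
        kept: list[tuple[str, float]] = []
        deferred: list[tuple[str, float]] = []
        for source, score in items:
            if source_counts.get(source, 0) < max_per_source:
                kept.append((source, score))
                source_counts[source] = source_counts.get(source, 0) + 1
            else:
                deferred.append((source, score))
        return kept + deferred

    ranked = rerank([("A", 0.9), ("A", 0.8), ("A", 0.7), ("A", 0.6), ("B", 0.5)])
    # -> three As, then B, then the fourth A at the end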
@@ -8,6 +8,7 @@ import asyncio
 import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
+from weakref import WeakValueDictionary
 
 from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext
@@ -89,6 +90,11 @@ class CompositeScorer:
        self._recency_scorer = RecencyScorer(weight=self._recency_weight)
        self._priority_scorer = PriorityScorer(weight=self._priority_weight)
 
+        # Per-context locks to prevent race conditions during parallel scoring
+        # Uses WeakValueDictionary so locks are garbage collected when not in use
+        self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+        self._locks_lock = asyncio.Lock()  # Lock to protect _context_locks access
+
    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for semantic scoring."""
        self._relevance_scorer.set_mcp_manager(mcp_manager)
@@ -128,6 +134,38 @@ class CompositeScorer:
        self._priority_weight = max(0.0, min(1.0, priority))
        self._priority_scorer.weight = self._priority_weight
 
+    async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
+        """
+        Get or create a lock for a specific context.
+
+        Thread-safe access to per-context locks prevents race conditions
+        when the same context is scored concurrently.
+
+        Args:
+            context_id: The context ID to get a lock for
+
+        Returns:
+            asyncio.Lock for the context
+        """
+        # Fast path: check if lock exists without acquiring main lock
+        if context_id in self._context_locks:
+            lock = self._context_locks.get(context_id)
+            if lock is not None:
+                return lock
+
+        # Slow path: create lock while holding main lock
+        async with self._locks_lock:
+            # Double-check after acquiring lock
+            if context_id in self._context_locks:
+                lock = self._context_locks.get(context_id)
+                if lock is not None:
+                    return lock
+
+            # Create new lock
+            new_lock = asyncio.Lock()
+            self._context_locks[context_id] = new_lock
+            return new_lock
+
    async def score(
        self,
        context: BaseContext,
@@ -157,6 +195,9 @@ class CompositeScorer:
        """
        Compute composite score with individual scores.
 
+        Uses per-context locking to prevent race conditions when the same
+        context is scored concurrently in parallel scoring operations.
+
        Args:
            context: Context to score
            query: Query to score against
@@ -165,7 +206,11 @@ class CompositeScorer:
        Returns:
            ScoredContext with all scores
        """
-        # Check if context already has a score
+        # Get lock for this specific context to prevent race conditions
+        context_lock = await self._get_context_lock(context.id)
+
+        async with context_lock:
+            # Check if context already has a score (inside lock to prevent races)
        if context._score is not None:
            return ScoredContext(
                context=context,
@@ -195,7 +240,7 @@ class CompositeScorer:
        else:
            composite = 0.0
 
-        # Cache the score on the context
+        # Cache the score on the context (now safe - inside lock)
        context._score = composite
 
        return ScoredContext(
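Note: the point of the per-context lock is that when the same context is scored by many concurrent tasks, only the first computes; the rest read the cached _score inside the lock. A minimal sketch of the pattern with stand-in classes (the real code keys locks by context.id in a WeakValueDictionary):

    import asyncio

    class FakeContext:
        def __init__(self) -> None:
            self._score = None
            self.computations = 0

    async def score(ctx: FakeContext, lock: asyncio.Lock) -> float:
        async with lock:
            if ctx._score is None:       # checked inside the lock
                ctx.computations += 1
                await asyncio.sleep(0)   # simulate expensive scoring
                ctx._score = 0.42
            return ctx._score

    async def main() -> None:
        ctx, lock = FakeContext(), asyncio.Lock()
        await asyncio.gather(*(score(ctx, lock) for _ in range(10)))
        assert ctx.computations == 1     # scored once despite 10 callers

    asyncio.run(main())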
@@ -9,6 +9,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any
 
+from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext, KnowledgeContext
 from .base import BaseScorer
 
@@ -32,7 +33,9 @@ class RelevanceScorer(BaseScorer):
        self,
        mcp_manager: "MCPClientManager | None" = None,
        weight: float = 1.0,
-        keyword_fallback_weight: float = 0.5,
+        keyword_fallback_weight: float | None = None,
+        semantic_max_chars: int | None = None,
+        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize relevance scorer.
@@ -40,11 +43,25 @@ class RelevanceScorer(BaseScorer):
        Args:
            mcp_manager: MCP manager for Knowledge Base calls
            weight: Scorer weight for composite scoring
-            keyword_fallback_weight: Max score for keyword-based fallback
+            keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
+            semantic_max_chars: Max content length for semantic similarity (overrides settings)
+            settings: Context settings (uses global if None)
        """
        super().__init__(weight)
+        self._settings = settings or get_context_settings()
        self._mcp = mcp_manager
-        self._keyword_fallback_weight = keyword_fallback_weight
+
+        # Use provided values or fall back to settings
+        self._keyword_fallback_weight = (
+            keyword_fallback_weight
+            if keyword_fallback_weight is not None
+            else self._settings.relevance_keyword_fallback_weight
+        )
+        self._semantic_max_chars = (
+            semantic_max_chars
+            if semantic_max_chars is not None
+            else self._settings.relevance_semantic_max_chars
+        )
 
    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for semantic scoring."""
@@ -112,11 +129,11 @@ class RelevanceScorer(BaseScorer):
            tool="compute_similarity",
            args={
                "text1": query,
-                "text2": context.content[:2000],  # Limit content length
+                "text2": context.content[: self._semantic_max_chars],  # Limit content length
            },
        )
 
-        if result.success and result.data:
+        if result.success and isinstance(result.data, dict):
            similarity = result.data.get("similarity")
            if similarity is not None:
                return self.normalize_score(float(similarity))
@@ -171,7 +188,7 @@ class RelevanceScorer(BaseScorer):
        **kwargs: Any,
    ) -> list[float]:
        """
-        Score multiple contexts.
+        Score multiple contexts in parallel.
 
        Args:
            contexts: Contexts to score
@@ -181,8 +198,10 @@ class RelevanceScorer(BaseScorer):
        Returns:
            List of scores (same order as input)
        """
-        scores = []
-        for context in contexts:
-            score = await self.score(context, query, **kwargs)
-            scores.append(score)
-        return scores
+        import asyncio
+
+        if not contexts:
+            return []
+
+        tasks = [self.score(context, query, **kwargs) for context in contexts]
+        return await asyncio.gather(*tasks)
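Note: the keyword fallback itself is outside these hunks; relevance_keyword_fallback_weight only caps whatever lexical score it produces, so the fallback can never outrank true semantic similarity. A plausible sketch of such a fallback (an assumption, not this module's code):

    import re

    def keyword_fallback_score(query: str, content: str, fallback_weight: float = 0.5) -> float:
        query_terms = set(re.findall(r"\w+", query.lower()))
        content_terms = set(re.findall(r"\w+", content.lower()))
        if not query_terms:
            return 0.0
        overlap = len(query_terms & content_terms) / len(query_terms)
        return overlap * fallback_weight  # capped at the configured weight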