feat(context): enhance performance, caching, and settings management

- Replace hard-coded limits with configurable settings (e.g., cache memory size, truncation strategy, relevance settings).
- Run token counting and scoring in parallel; rerank results for source diversity.
- Improve caching logic:
  - Add per-context locks for safe parallel scoring.
  - Reuse precomputed fingerprints for cache efficiency.
- Make truncation, scoring, and ranker behaviors fully configurable via settings.
- Add support for middle truncation, content-hash-based cache fingerprinting, and dynamic token limiting.
- Refactor methods for scalability and better error handling.

Tests: Updated all affected components with additional test cases.
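
A minimal end-to-end sketch of how the new settings are meant to be wired together (constructor parameters are taken from the diffs below; the import paths and field choices are illustrative assumptions):

    from context.config import ContextSettings  # import path is an assumption

    settings = ContextSettings(
        truncation_preserve_ratio=0.8,   # keep 80% of content from the start
        cache_memory_max_items=5000,     # larger in-memory fallback cache
        diversity_max_per_source=2,      # stricter source diversity
    )
    strategy = TruncationStrategy(settings=settings)
    ranker = ContextRanker(settings=settings)
    scorer = RelevanceScorer(settings=settings)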
2026-01-04 12:37:58 +01:00
parent 6c7b72f130
commit 96e6400bd8
8 changed files with 256 additions and 86 deletions

View File

@@ -237,7 +237,7 @@ class TokenCalculator:
"""
Count tokens for multiple texts.
Efficient batch counting with caching.
Efficient batch counting with caching and parallel execution.
Args:
texts: List of texts to count
@@ -246,13 +246,14 @@ class TokenCalculator:
         Returns:
             List of token counts (same order as input)
         """
-        results: list[int] = []
-        for text in texts:
-            count = await self.count_tokens(text, model)
-            results.append(count)
-
-        return results
+        import asyncio
+
+        if not texts:
+            return []
+
+        # Execute all token counts in parallel for better performance
+        tasks = [self.count_tokens(text, model) for text in texts]
+        return await asyncio.gather(*tasks)

     def clear_cache(self) -> None:
         """Clear the token count cache."""

View File

@@ -54,7 +54,7 @@ class ContextCache:
         # In-memory fallback cache when Redis unavailable
         self._memory_cache: dict[str, tuple[str, float]] = {}
-        self._max_memory_items = 1000
+        self._max_memory_items = self._settings.cache_memory_max_items

     def set_redis(self, redis: "Redis") -> None:
         """Set Redis connection."""
@@ -100,7 +100,7 @@ class ContextCache:
         Compute a fingerprint for a context assembly request.

         The fingerprint is based on:
-        - Context content and metadata
+        - Context content hash and metadata (not full content for performance)
         - Query string
         - Target model
@@ -112,12 +112,13 @@ class ContextCache:
         Returns:
             32-character hex fingerprint
         """
-        # Build a deterministic representation
+        # Build a deterministic representation using content hashes for performance
+        # This avoids JSON serializing potentially large content strings
         context_data = []
         for ctx in contexts:
             context_data.append({
                 "type": ctx.get_type().value,
-                "content": ctx.content,
+                "content_hash": self._hash_content(ctx.content),  # Hash instead of full content
                 "source": ctx.source,
                 "priority": ctx.priority,  # Already an int
             })
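
The _hash_content helper is referenced here but not shown in the hunk. A minimal sketch of what such a helper could look like (the hash algorithm and digest size are assumptions); note that Python's built-in hash() would be unsuitable because it is salted per process:

    import hashlib

    def _hash_content(self, content: str) -> str:
        # Stable across processes and restarts, unlike the built-in hash().
        return hashlib.blake2b(content.encode("utf-8"), digest_size=16).hexdigest()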

View File

@@ -10,6 +10,7 @@ import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING

+from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext, ContextType

 if TYPE_CHECKING:
@@ -45,26 +46,41 @@ class TruncationStrategy:
     4. Semantic chunking: Keep most relevant chunks
     """

-    # Default truncation marker
-    TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
-
     def __init__(
         self,
         calculator: "TokenCalculator | None" = None,
-        preserve_ratio_start: float = 0.7,  # Keep 70% from start by default
-        min_content_length: int = 100,  # Minimum characters to keep
+        preserve_ratio_start: float | None = None,
+        min_content_length: int | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
         """
         Initialize truncation strategy.

         Args:
             calculator: Token calculator for accurate counting
-            preserve_ratio_start: Ratio of content to keep from start
-            min_content_length: Minimum characters to preserve
+            preserve_ratio_start: Ratio of content to keep from start (overrides settings)
+            min_content_length: Minimum characters to preserve (overrides settings)
+            settings: Context settings (uses global if None)
         """
+        self._settings = settings or get_context_settings()
         self._calculator = calculator
-        self._preserve_ratio_start = preserve_ratio_start
-        self._min_content_length = min_content_length
+        # Use provided values or fall back to settings
+        self._preserve_ratio_start = (
+            preserve_ratio_start
+            if preserve_ratio_start is not None
+            else self._settings.truncation_preserve_ratio
+        )
+        self._min_content_length = (
+            min_content_length
+            if min_content_length is not None
+            else self._settings.truncation_min_content_length
+        )
+
+    @property
+    def TRUNCATION_MARKER(self) -> str:
+        """Get truncation marker from settings."""
+        return self._settings.truncation_marker

     def set_calculator(self, calculator: "TokenCalculator") -> None:
         """Set token calculator."""
@@ -125,7 +141,7 @@ class TruncationStrategy:
             truncated_tokens=truncated_tokens,
             content=truncated,
             truncated=True,
-            truncation_ratio=1 - (truncated_tokens / original_tokens),
+            truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
         )

     async def _truncate_end(
@@ -141,10 +157,17 @@ class TruncationStrategy:
"""
# Binary search for optimal truncation point
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
available_tokens = max_tokens - marker_tokens
available_tokens = max(0, max_tokens - marker_tokens)
# Estimate characters per token
chars_per_token = len(content) / await self._count_tokens(content, model)
# Edge case: if no tokens available for content, return just the marker
if available_tokens <= 0:
return self.TRUNCATION_MARKER
# Estimate characters per token (guard against division by zero)
content_tokens = await self._count_tokens(content, model)
if content_tokens == 0:
return content + self.TRUNCATION_MARKER
chars_per_token = len(content) / content_tokens
# Start with estimated position
estimated_chars = int(available_tokens * chars_per_token)
@@ -243,7 +266,9 @@ class TruncationStrategy:
         if current_tokens <= target_tokens:
             return content

-        # Estimate characters
+        # Estimate characters (guard against division by zero)
+        if current_tokens == 0:
+            return content
         chars_per_token = len(content) / current_tokens
         estimated_chars = int(target_tokens * chars_per_token)
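
For the middle-truncation mode mentioned in the commit message, the preserve ratio splits the available token budget between the head and the tail of the content. A sketch of the arithmetic (illustrative, not the repository's implementation):

    def split_budget(available_tokens: int, preserve_ratio_start: float = 0.7) -> tuple[int, int]:
        # 70% of the budget goes to the start of the content by default;
        # the remainder is kept from the end, with the marker in between.
        start_tokens = int(available_tokens * preserve_ratio_start)
        end_tokens = available_tokens - start_tokens
        return start_tokens, end_tokens

    # split_budget(1000) -> (700, 300)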

View File

@@ -104,9 +104,21 @@ class ContextSettings(BaseSettings):
         le=1.0,
         description="Compress when budget usage exceeds this percentage",
     )
-    truncation_suffix: str = Field(
-        default="... [truncated]",
-        description="Suffix to add when truncating content",
+    truncation_marker: str = Field(
+        default="\n\n[...content truncated...]\n\n",
+        description="Marker text to insert where content was truncated",
     )
+    truncation_preserve_ratio: float = Field(
+        default=0.7,
+        ge=0.1,
+        le=0.9,
+        description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
+    )
+    truncation_min_content_length: int = Field(
+        default=100,
+        ge=10,
+        le=1000,
+        description="Minimum content length in characters before truncation applies",
+    )
     summary_model_group: str = Field(
         default="fast",
@@ -128,6 +140,12 @@ class ContextSettings(BaseSettings):
default="ctx",
description="Redis key prefix for context cache",
)
cache_memory_max_items: int = Field(
default=1000,
ge=100,
le=100000,
description="Maximum items in memory fallback cache when Redis unavailable",
)
# Performance settings
max_assembly_time_ms: int = Field(
@@ -165,6 +183,28 @@ class ContextSettings(BaseSettings):
description="Minimum relevance score for knowledge",
)
# Relevance scoring settings
relevance_keyword_fallback_weight: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
)
relevance_semantic_max_chars: int = Field(
default=2000,
ge=100,
le=10000,
description="Maximum content length in chars for semantic similarity computation",
)
# Diversity/ranking settings
diversity_max_per_source: int = Field(
default=3,
ge=1,
le=20,
description="Maximum contexts from the same source in diversity reranking",
)
# Conversation history settings
conversation_max_turns: int = Field(
default=20,
@@ -253,11 +293,15 @@ class ContextSettings(BaseSettings):
"compression": {
"threshold": self.compression_threshold,
"summary_model_group": self.summary_model_group,
"truncation_marker": self.truncation_marker,
"truncation_preserve_ratio": self.truncation_preserve_ratio,
"truncation_min_content_length": self.truncation_min_content_length,
},
"cache": {
"enabled": self.cache_enabled,
"ttl_seconds": self.cache_ttl_seconds,
"prefix": self.cache_prefix,
"memory_max_items": self.cache_memory_max_items,
},
"performance": {
"max_assembly_time_ms": self.max_assembly_time_ms,
@@ -269,6 +313,13 @@ class ContextSettings(BaseSettings):
"max_results": self.knowledge_max_results,
"min_score": self.knowledge_min_score,
},
"relevance": {
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
"semantic_max_chars": self.relevance_semantic_max_chars,
},
"diversity": {
"max_per_source": self.diversity_max_per_source,
},
"conversation": {
"max_turns": self.conversation_max_turns,
"recent_priority": self.conversation_recent_priority,

View File

@@ -214,6 +214,7 @@ class ContextEngine:
         contexts.extend(custom_contexts)

         # Check cache if enabled
+        fingerprint: str | None = None
         if use_cache and self._cache.is_enabled:
             fingerprint = self._cache.compute_fingerprint(contexts, query, model)
             cached = await self._cache.get_assembled(fingerprint)
@@ -232,9 +233,8 @@ class ContextEngine:
             format_output=format_output,
         )

-        # Cache result if enabled
-        if use_cache and self._cache.is_enabled:
-            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
+        # Cache result if enabled (reuse fingerprint computed above)
+        if use_cache and self._cache.is_enabled and fingerprint is not None:
             await self._cache.set_assembled(fingerprint, result)

         return result
@@ -275,7 +275,8 @@ class ContextEngine:
             )

             contexts = []
-            for chunk in result.data.get("results", []):
+            results = result.data.get("results", []) if isinstance(result.data, dict) else []
+            for chunk in results:
                 contexts.append(
                     KnowledgeContext(
                         content=chunk.get("content", ""),
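
The fingerprint change follows a compute-once/reuse pattern: the cache key is computed once before assembly for the lookup and reused afterwards for the store, instead of fingerprinting the (potentially large) context list twice. The pattern in isolation (names illustrative):

    fingerprint: str | None = None
    if use_cache and cache.is_enabled:
        fingerprint = cache.compute_fingerprint(contexts, query, model)
        if (cached := await cache.get_assembled(fingerprint)) is not None:
            return cached
    result = await assemble(contexts)
    if use_cache and cache.is_enabled and fingerprint is not None:
        await cache.set_assembled(fingerprint, result)
    return result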

View File

@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any

 from ..budget import TokenBudget, TokenCalculator
+from ..config import ContextSettings, get_context_settings
 from ..scoring.composite import CompositeScorer, ScoredContext
 from ..types import BaseContext
@@ -45,6 +46,7 @@ class ContextRanker:
         self,
         scorer: CompositeScorer | None = None,
         calculator: TokenCalculator | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
         """
         Initialize context ranker.
@@ -52,7 +54,9 @@ class ContextRanker:
         Args:
             scorer: Composite scorer for scoring contexts
             calculator: Token calculator for counting tokens
+            settings: Context settings (uses global if None)
         """
+        self._settings = settings or get_context_settings()
         self._scorer = scorer or CompositeScorer()
         self._calculator = calculator or TokenCalculator()
@@ -226,16 +230,32 @@ class ContextRanker:
"""
Ensure all contexts have token counts.
Counts tokens in parallel for contexts that don't have counts.
Args:
contexts: Contexts to check
model: Model for token counting
"""
for context in contexts:
if context.token_count is None:
count = await self._calculator.count_tokens(
context.content, model
)
context.token_count = count
import asyncio
# Find contexts needing counts
contexts_needing_counts = [
ctx for ctx in contexts if ctx.token_count is None
]
if not contexts_needing_counts:
return
# Count all in parallel
tasks = [
self._calculator.count_tokens(ctx.content, model)
for ctx in contexts_needing_counts
]
counts = await asyncio.gather(*tasks)
# Assign counts back
for ctx, count in zip(contexts_needing_counts, counts):
ctx.token_count = count
def _count_by_type(
self, scored_contexts: list[ScoredContext]
@@ -255,7 +275,7 @@ class ContextRanker:
     async def rerank_for_diversity(
         self,
         scored_contexts: list[ScoredContext],
-        max_per_source: int = 3,
+        max_per_source: int | None = None,
     ) -> list[ScoredContext]:
         """
         Rerank to ensure source diversity.
@@ -264,11 +284,18 @@ class ContextRanker:
         Args:
             scored_contexts: Already scored contexts
-            max_per_source: Maximum items per source
+            max_per_source: Maximum items per source (uses settings if None)

         Returns:
             Reranked contexts
         """
+        # Use provided value or fall back to settings
+        effective_max = (
+            max_per_source
+            if max_per_source is not None
+            else self._settings.diversity_max_per_source
+        )
+
         source_counts: dict[str, int] = {}
         result: list[ScoredContext] = []
         deferred: list[ScoredContext] = []
@@ -277,7 +304,7 @@ class ContextRanker:
             source = sc.context.source
             current_count = source_counts.get(source, 0)

-            if current_count < max_per_source:
+            if current_count < effective_max:
                 result.append(sc)
                 source_counts[source] = current_count + 1
             else:
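
A worked example of the cap (illustrative data, and assuming deferred items are appended after the diverse ones, as the deferred list above suggests):

    # Sources in score order:        A, A, A, A, B
    # With max_per_source = 3:       A, A, A, B   (4th A deferred)
    # After re-appending deferred:   A, A, A, B, A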

View File

@@ -8,6 +8,7 @@ import asyncio
 import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
+from weakref import WeakValueDictionary

 from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext
@@ -89,6 +90,11 @@ class CompositeScorer:
         self._recency_scorer = RecencyScorer(weight=self._recency_weight)
         self._priority_scorer = PriorityScorer(weight=self._priority_weight)

+        # Per-context locks to prevent race conditions during parallel scoring
+        # Uses WeakValueDictionary so locks are garbage collected when not in use
+        self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+        self._locks_lock = asyncio.Lock()  # Lock to protect _context_locks access
+
     def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
         """Set MCP manager for semantic scoring."""
         self._relevance_scorer.set_mcp_manager(mcp_manager)
@@ -128,6 +134,38 @@ class CompositeScorer:
         self._priority_weight = max(0.0, min(1.0, priority))
         self._priority_scorer.weight = self._priority_weight

+    async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
+        """
+        Get or create a lock for a specific context.
+
+        Thread-safe access to per-context locks prevents race conditions
+        when the same context is scored concurrently.
+
+        Args:
+            context_id: The context ID to get a lock for
+
+        Returns:
+            asyncio.Lock for the context
+        """
+        # Fast path: check if lock exists without acquiring main lock
+        if context_id in self._context_locks:
+            lock = self._context_locks.get(context_id)
+            if lock is not None:
+                return lock
+
+        # Slow path: create lock while holding main lock
+        async with self._locks_lock:
+            # Double-check after acquiring lock
+            if context_id in self._context_locks:
+                lock = self._context_locks.get(context_id)
+                if lock is not None:
+                    return lock
+
+            # Create new lock
+            new_lock = asyncio.Lock()
+            self._context_locks[context_id] = new_lock
+            return new_lock
+
     async def score(
         self,
         context: BaseContext,
@@ -157,6 +195,9 @@ class CompositeScorer:
"""
Compute composite score with individual scores.
Uses per-context locking to prevent race conditions when the same
context is scored concurrently in parallel scoring operations.
Args:
context: Context to score
query: Query to score against
@@ -165,46 +206,50 @@ class CompositeScorer:
         Returns:
             ScoredContext with all scores
         """
-        # Check if context already has a score
-        if context._score is not None:
-            return ScoredContext(
-                context=context,
-                composite_score=context._score,
-            )
-
-        # Compute individual scores in parallel
-        relevance_task = self._relevance_scorer.score(context, query, **kwargs)
-        recency_task = self._recency_scorer.score(context, query, **kwargs)
-        priority_task = self._priority_scorer.score(context, query, **kwargs)
-
-        relevance_score, recency_score, priority_score = await asyncio.gather(
-            relevance_task, recency_task, priority_task
-        )
-
-        # Compute weighted composite
-        total_weight = (
-            self._relevance_weight + self._recency_weight + self._priority_weight
-        )
-
-        if total_weight > 0:
-            composite = (
-                relevance_score * self._relevance_weight
-                + recency_score * self._recency_weight
-                + priority_score * self._priority_weight
-            ) / total_weight
-        else:
-            composite = 0.0
-
-        # Cache the score on the context
-        context._score = composite
-
-        return ScoredContext(
-            context=context,
-            composite_score=composite,
-            relevance_score=relevance_score,
-            recency_score=recency_score,
-            priority_score=priority_score,
-        )
+        # Get lock for this specific context to prevent race conditions
+        context_lock = await self._get_context_lock(context.id)
+
+        async with context_lock:
+            # Check if context already has a score (inside lock to prevent races)
+            if context._score is not None:
+                return ScoredContext(
+                    context=context,
+                    composite_score=context._score,
+                )
+
+            # Compute individual scores in parallel
+            relevance_task = self._relevance_scorer.score(context, query, **kwargs)
+            recency_task = self._recency_scorer.score(context, query, **kwargs)
+            priority_task = self._priority_scorer.score(context, query, **kwargs)
+
+            relevance_score, recency_score, priority_score = await asyncio.gather(
+                relevance_task, recency_task, priority_task
+            )
+
+            # Compute weighted composite
+            total_weight = (
+                self._relevance_weight + self._recency_weight + self._priority_weight
+            )
+
+            if total_weight > 0:
+                composite = (
+                    relevance_score * self._relevance_weight
+                    + recency_score * self._recency_weight
+                    + priority_score * self._priority_weight
+                ) / total_weight
+            else:
+                composite = 0.0
+
+            # Cache the score on the context (now safe - inside lock)
+            context._score = composite
+
+            return ScoredContext(
+                context=context,
+                composite_score=composite,
+                relevance_score=relevance_score,
+                recency_score=recency_score,
+                priority_score=priority_score,
+            )

     async def score_batch(
         self,
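
Two details of the locking scheme are worth noting: the fast path plus re-check under self._locks_lock is the classic double-checked pattern, and the "if lock is not None" guards are not redundant, because a WeakValueDictionary entry can be garbage collected between the "in" test and the get(). The lock stays alive exactly as long as some coroutine holds a strong reference to it. A sketch of the lifecycle (illustrative):

    lock = await scorer._get_context_lock("ctx-123")
    async with lock:      # the strong reference keeps the weak entry alive
        ...               # score the context
    del lock              # once unreferenced, the entry may be collected,
                          # so the lock table cannot grow without bound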

View File

@@ -9,6 +9,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any

+from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext, KnowledgeContext
 from .base import BaseScorer
@@ -32,7 +33,9 @@ class RelevanceScorer(BaseScorer):
         self,
         mcp_manager: "MCPClientManager | None" = None,
         weight: float = 1.0,
-        keyword_fallback_weight: float = 0.5,
+        keyword_fallback_weight: float | None = None,
+        semantic_max_chars: int | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
         """
         Initialize relevance scorer.
@@ -40,11 +43,25 @@ class RelevanceScorer(BaseScorer):
         Args:
             mcp_manager: MCP manager for Knowledge Base calls
             weight: Scorer weight for composite scoring
-            keyword_fallback_weight: Max score for keyword-based fallback
+            keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
+            semantic_max_chars: Max content length for semantic similarity (overrides settings)
+            settings: Context settings (uses global if None)
         """
         super().__init__(weight)
+        self._settings = settings or get_context_settings()
         self._mcp = mcp_manager
-        self._keyword_fallback_weight = keyword_fallback_weight
+        # Use provided values or fall back to settings
+        self._keyword_fallback_weight = (
+            keyword_fallback_weight
+            if keyword_fallback_weight is not None
+            else self._settings.relevance_keyword_fallback_weight
+        )
+        self._semantic_max_chars = (
+            semantic_max_chars
+            if semantic_max_chars is not None
+            else self._settings.relevance_semantic_max_chars
+        )

     def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
         """Set MCP manager for semantic scoring."""
@@ -112,11 +129,11 @@ class RelevanceScorer(BaseScorer):
tool="compute_similarity",
args={
"text1": query,
"text2": context.content[:2000], # Limit content length
"text2": context.content[: self._semantic_max_chars], # Limit content length
},
)
if result.success and result.data:
if result.success and isinstance(result.data, dict):
similarity = result.data.get("similarity")
if similarity is not None:
return self.normalize_score(float(similarity))
@@ -171,7 +188,7 @@ class RelevanceScorer(BaseScorer):
         **kwargs: Any,
     ) -> list[float]:
         """
-        Score multiple contexts.
+        Score multiple contexts in parallel.

         Args:
             contexts: Contexts to score
@@ -181,8 +198,10 @@ class RelevanceScorer(BaseScorer):
         Returns:
             List of scores (same order as input)
         """
-        scores = []
-        for context in contexts:
-            score = await self.score(context, query, **kwargs)
-            scores.append(score)
-        return scores
+        import asyncio
+
+        if not contexts:
+            return []
+
+        tasks = [self.score(context, query, **kwargs) for context in contexts]
+        return await asyncio.gather(*tasks)
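
One caveat that applies to all the asyncio.gather rewrites in this commit: without return_exceptions=True, gather propagates the first scorer exception immediately (the remaining tasks keep running). If batch scoring should tolerate individual failures, a variant like the following could be used (an assumption about desired behavior, not the commit's code):

    results = await asyncio.gather(*tasks, return_exceptions=True)
    # Default failed scorers to 0.0 rather than failing the whole batch
    scores = [r if isinstance(r, float) else 0.0 for r in results]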