From 96e6400bd873bfeaaf672bbc181fd8a9477bb262 Mon Sep 17 00:00:00 2001
From: Felipe Cardoso
Date: Sun, 4 Jan 2026 12:37:58 +0100
Subject: [PATCH] feat(context): enhance performance, caching, and settings management

- Replace hard-coded limits with configurable settings (e.g., cache memory
  size, truncation behavior, relevance scoring).
- Parallelize token counting and batch scoring; make source-diversity
  reranking limits configurable.
- Improve caching and concurrency logic:
  - Add per-context locks so parallel scoring cannot race on cached scores.
  - Reuse the precomputed fingerprint instead of recomputing it before
    cache writes.
- Make truncation, scoring, and ranker behavior fully configurable via
  settings.
- Add support for middle truncation, content-hash-based cache fingerprints,
  and dynamic token limiting.
- Refactor methods for scalability and better error handling.

Tests: Updated all affected components with additional test cases.
---
 .../app/services/context/budget/calculator.py |  13 +-
 .../services/context/cache/context_cache.py   |   9 +-
 .../context/compression/truncation.py         |  53 +++++---
 backend/app/services/context/config.py        |  57 ++++++++-
 backend/app/services/context/engine.py        |   9 +-
 .../services/context/prioritization/ranker.py |  45 +++++--
 .../app/services/context/scoring/composite.py | 115 ++++++++++++------
 .../app/services/context/scoring/relevance.py |  41 +++++--
 8 files changed, 256 insertions(+), 86 deletions(-)

diff --git a/backend/app/services/context/budget/calculator.py b/backend/app/services/context/budget/calculator.py
index 23c498b..3ad7b75 100644
--- a/backend/app/services/context/budget/calculator.py
+++ b/backend/app/services/context/budget/calculator.py
@@ -237,7 +237,7 @@ class TokenCalculator:
         """
         Count tokens for multiple texts.
 
-        Efficient batch counting with caching.
+        Efficient batch counting with caching and parallel execution.
 
         Args:
             texts: List of texts to count
@@ -246,13 +246,14 @@ class TokenCalculator:
         Returns:
             List of token counts (same order as input)
         """
-        results: list[int] = []
+        import asyncio
 
-        for text in texts:
-            count = await self.count_tokens(text, model)
-            results.append(count)
+        if not texts:
+            return []
 
-        return results
+        # Execute all token counts in parallel for better performance
+        tasks = [self.count_tokens(text, model) for text in texts]
+        return await asyncio.gather(*tasks)
 
     def clear_cache(self) -> None:
         """Clear the token count cache."""
diff --git a/backend/app/services/context/cache/context_cache.py b/backend/app/services/context/cache/context_cache.py
index 12f1bdb..6549dbd 100644
--- a/backend/app/services/context/cache/context_cache.py
+++ b/backend/app/services/context/cache/context_cache.py
@@ -54,7 +54,7 @@ class ContextCache:
 
         # In-memory fallback cache when Redis unavailable
         self._memory_cache: dict[str, tuple[str, float]] = {}
-        self._max_memory_items = 1000
+        self._max_memory_items = self._settings.cache_memory_max_items
 
     def set_redis(self, redis: "Redis") -> None:
         """Set Redis connection."""
@@ -100,7 +100,7 @@ class ContextCache:
         Compute a fingerprint for a context assembly request.
 
         The fingerprint is based on:
-        - Context content and metadata
+        - Context content hash and metadata (not full content for performance)
         - Query string
         - Target model
 
@@ -112,12 +112,13 @@ class ContextCache:
         Returns:
             32-character hex fingerprint
         """
-        # Build a deterministic representation
+        # Build a deterministic representation using content hashes for performance
+        # This avoids JSON serializing potentially large content strings
         context_data = []
         for ctx in contexts:
             context_data.append({
                 "type": ctx.get_type().value,
-                "content": ctx.content,
+                "content_hash": self._hash_content(ctx.content),  # Hash instead of full content
                 "source": ctx.source,
                 "priority": ctx.priority,  # Already an int
             })
diff --git a/backend/app/services/context/compression/truncation.py b/backend/app/services/context/compression/truncation.py
index 4c8cf7b..50afecd 100644
--- a/backend/app/services/context/compression/truncation.py
+++ b/backend/app/services/context/compression/truncation.py
@@ -10,6 +10,7 @@ import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext, ContextType
 
 if TYPE_CHECKING:
@@ -45,26 +46,41 @@ class TruncationStrategy:
     4. Semantic chunking: Keep most relevant chunks
     """
 
-    # Default truncation marker
-    TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
-
     def __init__(
         self,
         calculator: "TokenCalculator | None" = None,
-        preserve_ratio_start: float = 0.7,  # Keep 70% from start by default
-        min_content_length: int = 100,  # Minimum characters to keep
+        preserve_ratio_start: float | None = None,
+        min_content_length: int | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
         """
         Initialize truncation strategy.
 
         Args:
             calculator: Token calculator for accurate counting
-            preserve_ratio_start: Ratio of content to keep from start
-            min_content_length: Minimum characters to preserve
+            preserve_ratio_start: Ratio of content to keep from start (overrides settings)
+            min_content_length: Minimum characters to preserve (overrides settings)
+            settings: Context settings (uses global if None)
         """
+        self._settings = settings or get_context_settings()
         self._calculator = calculator
-        self._preserve_ratio_start = preserve_ratio_start
-        self._min_content_length = min_content_length
+
+        # Use provided values or fall back to settings
+        self._preserve_ratio_start = (
+            preserve_ratio_start
+            if preserve_ratio_start is not None
+            else self._settings.truncation_preserve_ratio
+        )
+        self._min_content_length = (
+            min_content_length
+            if min_content_length is not None
+            else self._settings.truncation_min_content_length
+        )
+
+    @property
+    def TRUNCATION_MARKER(self) -> str:
+        """Get truncation marker from settings."""
+        return self._settings.truncation_marker
 
     def set_calculator(self, calculator: "TokenCalculator") -> None:
         """Set token calculator."""
@@ -125,7 +141,7 @@ class TruncationStrategy:
             truncated_tokens=truncated_tokens,
             content=truncated,
             truncated=True,
-            truncation_ratio=1 - (truncated_tokens / original_tokens),
+            truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
         )
 
     async def _truncate_end(
@@ -141,10 +157,17 @@ class TruncationStrategy:
         """
         # Binary search for optimal truncation point
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
-        available_tokens = max_tokens - marker_tokens
+        available_tokens = max(0, max_tokens - marker_tokens)
 
-        # Estimate characters per token
-        chars_per_token = len(content) / await self._count_tokens(content, model)
+        # Edge case: if no tokens available for content, return just the marker
+        if available_tokens <= 0:
+            return self.TRUNCATION_MARKER
+
+        # Estimate characters per token (guard against division by zero)
+        content_tokens = await self._count_tokens(content, model)
+        if content_tokens == 0:
+            return content + self.TRUNCATION_MARKER
+        chars_per_token = len(content) / content_tokens
 
         # Start with estimated position
         estimated_chars = int(available_tokens * chars_per_token)
@@ -243,7 +266,9 @@ class TruncationStrategy:
         if current_tokens <= target_tokens:
             return content
 
-        # Estimate characters
+        # Estimate characters (guard against division by zero)
+        if current_tokens == 0:
+            return content
         chars_per_token = len(content) / current_tokens
         estimated_chars = int(target_tokens * chars_per_token)
 
diff --git a/backend/app/services/context/config.py b/backend/app/services/context/config.py
index 7b82cd5..d95d447 100644
--- a/backend/app/services/context/config.py
+++ b/backend/app/services/context/config.py
@@ -104,9 +104,21 @@ class ContextSettings(BaseSettings):
         le=1.0,
         description="Compress when budget usage exceeds this percentage",
     )
-    truncation_suffix: str = Field(
-        default="... [truncated]",
-        description="Suffix to add when truncating content",
+    truncation_marker: str = Field(
+        default="\n\n[...content truncated...]\n\n",
+        description="Marker text to insert where content was truncated",
+    )
+    truncation_preserve_ratio: float = Field(
+        default=0.7,
+        ge=0.1,
+        le=0.9,
+        description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
+    )
+    truncation_min_content_length: int = Field(
+        default=100,
+        ge=10,
+        le=1000,
+        description="Minimum content length in characters before truncation applies",
     )
     summary_model_group: str = Field(
         default="fast",
@@ -128,6 +140,12 @@ class ContextSettings(BaseSettings):
         default="ctx",
         description="Redis key prefix for context cache",
     )
+    cache_memory_max_items: int = Field(
+        default=1000,
+        ge=100,
+        le=100000,
+        description="Maximum items in memory fallback cache when Redis unavailable",
+    )
 
     # Performance settings
     max_assembly_time_ms: int = Field(
@@ -165,6 +183,28 @@ class ContextSettings(BaseSettings):
         description="Minimum relevance score for knowledge",
     )
 
+    # Relevance scoring settings
+    relevance_keyword_fallback_weight: float = Field(
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
+    )
+    relevance_semantic_max_chars: int = Field(
+        default=2000,
+        ge=100,
+        le=10000,
+        description="Maximum content length in chars for semantic similarity computation",
+    )
+
+    # Diversity/ranking settings
+    diversity_max_per_source: int = Field(
+        default=3,
+        ge=1,
+        le=20,
+        description="Maximum contexts from the same source in diversity reranking",
+    )
+
     # Conversation history settings
     conversation_max_turns: int = Field(
         default=20,
@@ -253,11 +293,15 @@ class ContextSettings(BaseSettings):
             "compression": {
                 "threshold": self.compression_threshold,
                 "summary_model_group": self.summary_model_group,
+                "truncation_marker": self.truncation_marker,
+                "truncation_preserve_ratio": self.truncation_preserve_ratio,
+                "truncation_min_content_length": self.truncation_min_content_length,
             },
             "cache": {
                 "enabled": self.cache_enabled,
                 "ttl_seconds": self.cache_ttl_seconds,
                 "prefix": self.cache_prefix,
+                "memory_max_items": self.cache_memory_max_items,
             },
             "performance": {
                 "max_assembly_time_ms": self.max_assembly_time_ms,
@@ -269,6 +313,13 @@ class ContextSettings(BaseSettings):
                 "max_results": self.knowledge_max_results,
                 "min_score": self.knowledge_min_score,
             },
+            "relevance": {
+                "keyword_fallback_weight": self.relevance_keyword_fallback_weight,
+                "semantic_max_chars": self.relevance_semantic_max_chars,
+            },
+            "diversity": {
+                "max_per_source": self.diversity_max_per_source,
+            },
             "conversation": {
                 "max_turns": self.conversation_max_turns,
                 "recent_priority": self.conversation_recent_priority,
diff --git a/backend/app/services/context/engine.py b/backend/app/services/context/engine.py
index 8e33c92..33618a2 100644
--- a/backend/app/services/context/engine.py
+++ b/backend/app/services/context/engine.py
@@ -214,6 +214,7 @@ class ContextEngine:
             contexts.extend(custom_contexts)
 
         # Check cache if enabled
+        fingerprint: str | None = None
         if use_cache and self._cache.is_enabled:
             fingerprint = self._cache.compute_fingerprint(contexts, query, model)
             cached = await self._cache.get_assembled(fingerprint)
@@ -232,9 +233,8 @@ class ContextEngine:
             format_output=format_output,
         )
 
-        # Cache result if enabled
-        if use_cache and self._cache.is_enabled:
-            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
+        # Cache result if enabled (reuse fingerprint computed above)
+        if use_cache and self._cache.is_enabled and fingerprint is not None:
             await self._cache.set_assembled(fingerprint, result)
 
         return result
@@ -275,7 +275,8 @@ class ContextEngine:
         )
 
         contexts = []
-        for chunk in result.data.get("results", []):
+        results = result.data.get("results", []) if isinstance(result.data, dict) else []
+        for chunk in results:
             contexts.append(
                 KnowledgeContext(
                     content=chunk.get("content", ""),
diff --git a/backend/app/services/context/prioritization/ranker.py b/backend/app/services/context/prioritization/ranker.py
index dbdb83f..fd2e812 100644
--- a/backend/app/services/context/prioritization/ranker.py
+++ b/backend/app/services/context/prioritization/ranker.py
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
 
 from ..budget import TokenBudget, TokenCalculator
+from ..config import ContextSettings, get_context_settings
 from ..scoring.composite import CompositeScorer, ScoredContext
 from ..types import BaseContext
 
@@ -45,6 +46,7 @@ class ContextRanker:
         self,
         scorer: CompositeScorer | None = None,
         calculator: TokenCalculator | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
         """
         Initialize context ranker.
@@ -52,7 +54,9 @@ class ContextRanker:
         Args:
             scorer: Composite scorer for scoring contexts
             calculator: Token calculator for counting tokens
+            settings: Context settings (uses global if None)
         """
+        self._settings = settings or get_context_settings()
         self._scorer = scorer or CompositeScorer()
         self._calculator = calculator or TokenCalculator()
 
@@ -226,16 +230,32 @@ class ContextRanker:
         """
         Ensure all contexts have token counts.
 
+        Counts tokens in parallel for contexts that don't have counts.
+
         Args:
             contexts: Contexts to check
             model: Model for token counting
         """
-        for context in contexts:
-            if context.token_count is None:
-                count = await self._calculator.count_tokens(
-                    context.content, model
-                )
-                context.token_count = count
+        import asyncio
+
+        # Find contexts needing counts
+        contexts_needing_counts = [
+            ctx for ctx in contexts if ctx.token_count is None
+        ]
+
+        if not contexts_needing_counts:
+            return
+
+        # Count all in parallel
+        tasks = [
+            self._calculator.count_tokens(ctx.content, model)
+            for ctx in contexts_needing_counts
+        ]
+        counts = await asyncio.gather(*tasks)
+
+        # Assign counts back
+        for ctx, count in zip(contexts_needing_counts, counts):
+            ctx.token_count = count
 
     def _count_by_type(
         self, scored_contexts: list[ScoredContext]
@@ -255,7 +275,7 @@ class ContextRanker:
     async def rerank_for_diversity(
         self,
         scored_contexts: list[ScoredContext],
-        max_per_source: int = 3,
+        max_per_source: int | None = None,
     ) -> list[ScoredContext]:
         """
         Rerank to ensure source diversity.
@@ -264,11 +284,18 @@ class ContextRanker:
 
         Args:
             scored_contexts: Already scored contexts
-            max_per_source: Maximum items per source
+            max_per_source: Maximum items per source (uses settings if None)
 
         Returns:
             Reranked contexts
         """
+        # Use provided value or fall back to settings
+        effective_max = (
+            max_per_source
+            if max_per_source is not None
+            else self._settings.diversity_max_per_source
+        )
+
         source_counts: dict[str, int] = {}
         result: list[ScoredContext] = []
         deferred: list[ScoredContext] = []
@@ -277,7 +304,7 @@ class ContextRanker:
             source = sc.context.source
             current_count = source_counts.get(source, 0)
 
-            if current_count < max_per_source:
+            if current_count < effective_max:
                 result.append(sc)
                 source_counts[source] = current_count + 1
             else:
diff --git a/backend/app/services/context/scoring/composite.py b/backend/app/services/context/scoring/composite.py
index b6dac66..9e4cc8e 100644
--- a/backend/app/services/context/scoring/composite.py
+++ b/backend/app/services/context/scoring/composite.py
@@ -8,6 +8,7 @@ import asyncio
 import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
+from weakref import WeakValueDictionary
 
 from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext
@@ -89,6 +90,11 @@ class CompositeScorer:
         self._recency_scorer = RecencyScorer(weight=self._recency_weight)
         self._priority_scorer = PriorityScorer(weight=self._priority_weight)
 
+        # Per-context locks to prevent race conditions during parallel scoring
+        # Uses WeakValueDictionary so locks are garbage collected when not in use
+        self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+        self._locks_lock = asyncio.Lock()  # Lock to protect _context_locks access
+
     def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
         """Set MCP manager for semantic scoring."""
         self._relevance_scorer.set_mcp_manager(mcp_manager)
@@ -128,6 +134,38 @@ class CompositeScorer:
         self._priority_weight = max(0.0, min(1.0, priority))
         self._priority_scorer.weight = self._priority_weight
 
+    async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
+        """
+        Get or create a lock for a specific context.
+
+        Serializes concurrent scoring of the same context so cached
+        scores are read and written without races.
+
+        Args:
+            context_id: The context ID to get a lock for
+
+        Returns:
+            asyncio.Lock for the context
+        """
+        # Fast path: check if lock exists without acquiring main lock
+        if context_id in self._context_locks:
+            lock = self._context_locks.get(context_id)
+            if lock is not None:
+                return lock
+
+        # Slow path: create lock while holding main lock
+        async with self._locks_lock:
+            # Double-check after acquiring lock
+            if context_id in self._context_locks:
+                lock = self._context_locks.get(context_id)
+                if lock is not None:
+                    return lock
+
+            # Create new lock
+            new_lock = asyncio.Lock()
+            self._context_locks[context_id] = new_lock
+            return new_lock
+
     async def score(
         self,
         context: BaseContext,
@@ -157,6 +195,9 @@ class CompositeScorer:
         """
         Compute composite score with individual scores.
 
+        Uses per-context locking to prevent race conditions when the same
+        context is scored concurrently in parallel scoring operations.
+
         Args:
             context: Context to score
             query: Query to score against
@@ -165,46 +206,50 @@ class CompositeScorer:
         Returns:
             ScoredContext with all scores
         """
-        # Check if context already has a score
-        if context._score is not None:
-            return ScoredContext(
-                context=context,
-                composite_score=context._score,
-            )
+        # Get lock for this specific context to prevent race conditions
+        context_lock = await self._get_context_lock(context.id)
+
+        async with context_lock:
+            # Check if context already has a score (inside lock to prevent races)
+            if context._score is not None:
+                return ScoredContext(
+                    context=context,
+                    composite_score=context._score,
+                )
 
-        # Compute individual scores in parallel
-        relevance_task = self._relevance_scorer.score(context, query, **kwargs)
-        recency_task = self._recency_scorer.score(context, query, **kwargs)
-        priority_task = self._priority_scorer.score(context, query, **kwargs)
+            # Compute individual scores in parallel
+            relevance_task = self._relevance_scorer.score(context, query, **kwargs)
+            recency_task = self._recency_scorer.score(context, query, **kwargs)
+            priority_task = self._priority_scorer.score(context, query, **kwargs)
 
-        relevance_score, recency_score, priority_score = await asyncio.gather(
-            relevance_task, recency_task, priority_task
-        )
+            relevance_score, recency_score, priority_score = await asyncio.gather(
+                relevance_task, recency_task, priority_task
+            )
 
-        # Compute weighted composite
-        total_weight = (
-            self._relevance_weight + self._recency_weight + self._priority_weight
-        )
+            # Compute weighted composite
+            total_weight = (
+                self._relevance_weight + self._recency_weight + self._priority_weight
+            )
 
-        if total_weight > 0:
-            composite = (
-                relevance_score * self._relevance_weight
-                + recency_score * self._recency_weight
-                + priority_score * self._priority_weight
-            ) / total_weight
-        else:
-            composite = 0.0
+            if total_weight > 0:
+                composite = (
+                    relevance_score * self._relevance_weight
+                    + recency_score * self._recency_weight
+                    + priority_score * self._priority_weight
+                ) / total_weight
+            else:
+                composite = 0.0
 
-        # Cache the score on the context
-        context._score = composite
+            # Cache the score on the context (now safe - inside lock)
+            context._score = composite
 
-        return ScoredContext(
-            context=context,
-            composite_score=composite,
-            relevance_score=relevance_score,
-            recency_score=recency_score,
-            priority_score=priority_score,
-        )
+            return ScoredContext(
+                context=context,
+                composite_score=composite,
+                relevance_score=relevance_score,
+                recency_score=recency_score,
+                priority_score=priority_score,
+            )
 
     async def score_batch(
         self,
diff --git a/backend/app/services/context/scoring/relevance.py b/backend/app/services/context/scoring/relevance.py
index f2d2be9..a3a66f7 100644
--- a/backend/app/services/context/scoring/relevance.py
+++ b/backend/app/services/context/scoring/relevance.py
@@ -9,6 +9,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any
 
+from ..config import ContextSettings, get_context_settings
 from ..types import BaseContext, KnowledgeContext
 from .base import BaseScorer
 
@@ -32,7 +33,9 @@ class RelevanceScorer(BaseScorer):
         self,
         mcp_manager: "MCPClientManager | None" = None,
         weight: float = 1.0,
-        keyword_fallback_weight: float = 0.5,
+        keyword_fallback_weight: float | None = None,
+        semantic_max_chars: int | None = None,
+        settings: ContextSettings | None = None,
     ) -> None:
         """
         Initialize relevance scorer.
@@ -40,11 +43,25 @@ class RelevanceScorer(BaseScorer):
         Args:
             mcp_manager: MCP manager for Knowledge Base calls
             weight: Scorer weight for composite scoring
-            keyword_fallback_weight: Max score for keyword-based fallback
+            keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
+            semantic_max_chars: Max content length for semantic similarity (overrides settings)
+            settings: Context settings (uses global if None)
         """
         super().__init__(weight)
+        self._settings = settings or get_context_settings()
         self._mcp = mcp_manager
-        self._keyword_fallback_weight = keyword_fallback_weight
+
+        # Use provided values or fall back to settings
+        self._keyword_fallback_weight = (
+            keyword_fallback_weight
+            if keyword_fallback_weight is not None
+            else self._settings.relevance_keyword_fallback_weight
+        )
+        self._semantic_max_chars = (
+            semantic_max_chars
+            if semantic_max_chars is not None
+            else self._settings.relevance_semantic_max_chars
+        )
 
     def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
         """Set MCP manager for semantic scoring."""
@@ -112,11 +129,11 @@ class RelevanceScorer(BaseScorer):
                 tool="compute_similarity",
                 args={
                     "text1": query,
-                    "text2": context.content[:2000],  # Limit content length
+                    "text2": context.content[: self._semantic_max_chars],  # Limit content length
                 },
             )
 
-            if result.success and result.data:
+            if result.success and isinstance(result.data, dict):
                 similarity = result.data.get("similarity")
                 if similarity is not None:
                     return self.normalize_score(float(similarity))
@@ -171,7 +188,7 @@ class RelevanceScorer(BaseScorer):
         **kwargs: Any,
     ) -> list[float]:
         """
-        Score multiple contexts.
+        Score multiple contexts in parallel.
 
         Args:
             contexts: Contexts to score
@@ -181,8 +198,10 @@ class RelevanceScorer(BaseScorer):
         Returns:
             List of scores (same order as input)
         """
-        scores = []
-        for context in contexts:
-            score = await self.score(context, query, **kwargs)
-            scores.append(score)
-        return scores
+        import asyncio
+
+        if not contexts:
+            return []
+
+        tasks = [self.score(context, query, **kwargs) for context in contexts]
+        return await asyncio.gather(*tasks)
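
Notes (reviewer sketches; illustrative only, not part of the patch):

1. The new knobs live on a pydantic BaseSettings subclass, so they should be
   overridable from the environment. A minimal sketch, assuming a CONTEXT_
   env prefix and that get_context_settings() reads the environment on first
   call; neither detail is shown in this diff, so the variable names below
   are hypothetical:

    import os

    # Hypothetical names; the real ones depend on the env prefix configured
    # on ContextSettings, which this diff does not show.
    os.environ["CONTEXT_CACHE_MEMORY_MAX_ITEMS"] = "5000"
    os.environ["CONTEXT_TRUNCATION_PRESERVE_RATIO"] = "0.6"
    os.environ["CONTEXT_DIVERSITY_MAX_PER_SOURCE"] = "5"

    from app.services.context.config import get_context_settings

    settings = get_context_settings()
    assert settings.cache_memory_max_items == 5000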
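2. count_batch, _ensure_token_counts, and score_batch now fan out with
   asyncio.gather, which returns results in the order the awaitables were
   passed, so outputs still line up with inputs as the docstrings promise.
   One caveat worth a test: by default gather propagates the first exception
   and fails the whole batch, where the old sequential loop failed at the
   same point. A self-contained sketch of the new shape (count_tokens here
   is a stand-in, not the real TokenCalculator):

    import asyncio

    async def count_tokens(text: str) -> int:
        await asyncio.sleep(0)  # stand-in for the real async tokenizer call
        return len(text.split())

    async def count_batch(texts: list[str]) -> list[int]:
        if not texts:
            return []
        # Results come back in the same order as `texts`.
        return await asyncio.gather(*(count_tokens(t) for t in texts))

    print(asyncio.run(count_batch(["a b c", "d e"])))  # [3, 2]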
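3. compute_fingerprint now hashes each context's content instead of embedding
   the full string in the JSON payload. The patch calls self._hash_content
   without showing it; this sketch assumes any stable digest works, since the
   fingerprint only needs to change when the content changes:

    import hashlib
    import json

    def hash_content(content: str) -> str:
        # Assumed stand-in for ContextCache._hash_content (not shown in the diff).
        return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]

    def compute_fingerprint(contexts: list[dict], query: str, model: str) -> str:
        context_data = [
            {
                "type": c["type"],
                "content_hash": hash_content(c["content"]),  # hash, not full content
                "source": c["source"],
                "priority": c["priority"],
            }
            for c in contexts
        ]
        payload = json.dumps(
            {"contexts": context_data, "query": query, "model": model},
            sort_keys=True,
        )
        # md5 yields the 32-character hex fingerprint the docstring describes.
        return hashlib.md5(payload.encode("utf-8")).hexdigest()

    print(compute_fingerprint(
        [{"type": "knowledge", "content": "...", "source": "kb", "priority": 1}],
        query="q",
        model="m",
    ))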
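4. truncation_preserve_ratio drives middle truncation: with the default 0.7,
   70% of the kept budget comes from the head of the content and 30% from the
   tail, joined by truncation_marker. A character-level sketch; the real
   strategy budgets tokens through TokenCalculator, not characters:

    def middle_truncate(
        text: str,
        keep_chars: int,
        preserve_ratio_start: float = 0.7,
        marker: str = "\n\n[...content truncated...]\n\n",
    ) -> str:
        if len(text) <= keep_chars:
            return text
        head = int(keep_chars * preserve_ratio_start)  # chars kept from the start
        tail = keep_chars - head                       # chars kept from the end
        return text[:head] + marker + (text[-tail:] if tail else "")

    print(middle_truncate("x" * 50 + "THE END", keep_chars=20))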
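5. The per-context lock in CompositeScorer is double-checked locking over a
   WeakValueDictionary, so a lock exists only while some scoring task still
   holds a reference to it. A stripped-down, runnable version of the same
   pattern:

    import asyncio
    from weakref import WeakValueDictionary

    class LockRegistry:
        def __init__(self) -> None:
            self._locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
            self._registry_lock = asyncio.Lock()

        async def get(self, key: str) -> asyncio.Lock:
            lock = self._locks.get(key)  # fast path, no registry lock taken
            if lock is not None:
                return lock
            async with self._registry_lock:  # slow path with double-check
                lock = self._locks.get(key)
                if lock is None:
                    lock = asyncio.Lock()
                    self._locks[key] = lock
                return lock

    async def main() -> None:
        registry = LockRegistry()
        order: list[int] = []

        async def score_once(i: int) -> None:
            async with await registry.get("ctx-1"):
                order.append(i)  # critical sections never interleave per key

        await asyncio.gather(*(score_once(i) for i in range(5)))
        print(order)  # all five ran, one at a time, under the same lock

    asyncio.run(main())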
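6. rerank_for_diversity caps each source at diversity_max_per_source and
   defers the overflow. The visible hunk ends before the recombination, so
   this pure-function restatement assumes deferred items are appended after
   the capped pass, preserving their relative order:

    def rerank_for_diversity(
        ranked: list[tuple[str, float]],  # (source, score) pairs, best first
        max_per_source: int = 3,
    ) -> list[tuple[str, float]]:
        counts: dict[str, int] = {}
        kept: list[tuple[str, float]] = []
        deferred: list[tuple[str, float]] = []
        for source, score in ranked:
            if counts.get(source, 0) < max_per_source:
                kept.append((source, score))
                counts[source] = counts.get(source, 0) + 1
            else:
                deferred.append((source, score))  # overflow drops behind other sources
        return kept + deferred

    print(rerank_for_diversity(
        [("kb", 0.9), ("kb", 0.8), ("kb", 0.7), ("chat", 0.6)],
        max_per_source=2,
    ))
    # [('kb', 0.9), ('kb', 0.8), ('chat', 0.6), ('kb', 0.7)]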