Files
fast-next-template/backend/app/services/context/prioritization/ranker.py
Felipe Cardoso 1628eacf2b feat(context): enhance timeout handling, tenant isolation, and budget management
- Added timeout enforcement for token counting, scoring, and compression with detailed error handling.
- Introduced tenant isolation in context caching using project and agent identifiers.
- Enhanced budget management with stricter checks for critical context overspending and buffer limitations.
- Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments.
- Updated default assembly timeout settings for improved performance and reliability.
- Improved XML escaping in Claude adapter for safety against injection attacks.
- Standardized token estimation using model-specific ratios.
2026-01-04 15:52:50 +01:00

331 lines
10 KiB
Python

"""
Context Ranker for Context Management.
Ranks and selects contexts based on scores and budget constraints.
"""
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextPriority
if TYPE_CHECKING:
    # Placeholder for imports needed only by static type checkers;
    # nothing is currently imported here.
    pass

# Logger named after this module; used for force-fit warnings below.
logger = logging.getLogger(__name__)
@dataclass
class RankingResult:
    """
    Outcome of ranking contexts and selecting them under a token budget.

    Attributes:
        selected: Scored contexts chosen for inclusion.
        excluded: Scored contexts that were ranked but not chosen.
        total_tokens: Token total for the selection (ContextRanker.rank
            sets this to the sum of the selected contexts' token counts).
        selection_stats: Free-form statistics describing the selection run.
    """

    selected: list[ScoredContext]
    excluded: list[ScoredContext]
    total_tokens: int
    selection_stats: dict[str, Any] = field(default_factory=dict)

    @property
    def selected_contexts(self) -> list[BaseContext]:
        """Unwrap ``selected``, returning the bare context objects in order."""
        return [scored.context for scored in self.selected]
class ContextRanker:
    """
    Ranks and selects contexts within budget constraints.

    Uses greedy selection to maximize total score
    while respecting token budgets per context type.
    """

    def __init__(
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize context ranker.

        Args:
            scorer: Composite scorer for scoring contexts
                (a new ``CompositeScorer()`` is created if None)
            calculator: Token calculator for counting tokens
                (a new ``TokenCalculator()`` is created if None)
            settings: Context settings (uses global if None)
        """
        # Explicit ``is None`` checks: only a missing argument falls back to
        # the default, so a caller-supplied object is never silently replaced
        # merely because it happens to evaluate as falsy.
        self._settings = (
            settings if settings is not None else get_context_settings()
        )
        self._scorer = scorer if scorer is not None else CompositeScorer()
        self._calculator = (
            calculator if calculator is not None else TokenCalculator()
        )

    def set_scorer(self, scorer: CompositeScorer) -> None:
        """Set the scorer."""
        self._scorer = scorer

    def set_calculator(self, calculator: TokenCalculator) -> None:
        """Set the token calculator."""
        self._calculator = calculator

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        budget: TokenBudget,
        model: str | None = None,
        ensure_required: bool = True,
        **kwargs: Any,
    ) -> RankingResult:
        """
        Rank and select contexts within budget.

        Required (CRITICAL priority) contexts are placed first, then optional
        contexts are added greedily in descending score order while
        ``budget.can_fit`` allows. Every selected context is recorded on
        ``budget`` via ``budget.allocate``, so the passed-in budget is mutated.
        Contexts missing a token count get one assigned in place.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            budget: Token budget constraints (allocations are recorded on it)
            model: Model for token counting
            ensure_required: If True, always include CRITICAL priority contexts
            **kwargs: Additional scoring parameters

        Returns:
            RankingResult with selected and excluded contexts. Its
            ``selection_stats`` always has the keys ``total_contexts``,
            ``required_count``, ``selected_count``, ``excluded_count``,
            ``total_tokens`` and ``by_type``, including for empty input.

        Raises:
            BudgetExceededError: If a CRITICAL context cannot be force-fitted
                without exceeding the usable budget (total minus response
                reserve and buffer).
        """
        if not contexts:
            # Same stats schema as the non-empty path, so callers can read
            # any key without special-casing empty input.
            return RankingResult(
                selected=[],
                excluded=[],
                total_tokens=0,
                selection_stats=self._build_stats(
                    total_contexts=0,
                    required_count=0,
                    selected=[],
                    excluded=[],
                    total_tokens=0,
                ),
            )

        # 1. Ensure all contexts have token counts
        await self._ensure_token_counts(contexts, model)

        # 2. Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # 3. Separate required (CRITICAL priority) from optional
        required, optional = self._partition_required(
            scored_contexts, ensure_required
        )

        # 4. Sort optional by score (highest first)
        optional.sort(reverse=True)

        # 5. Greedy selection: required contexts first, then optional ones
        selected: list[ScoredContext] = []
        excluded: list[ScoredContext] = []
        total_tokens = self._select_required(required, budget, selected)
        total_tokens += self._select_optional(optional, budget, selected, excluded)

        stats = self._build_stats(
            total_contexts=len(contexts),
            required_count=len(required),
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
        )

        return RankingResult(
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
            selection_stats=stats,
        )

    @staticmethod
    def _partition_required(
        scored_contexts: list[ScoredContext],
        ensure_required: bool,
    ) -> tuple[list[ScoredContext], list[ScoredContext]]:
        """Split scored contexts into ``(required, optional)`` lists."""
        if not ensure_required:
            return [], list(scored_contexts)

        required: list[ScoredContext] = []
        optional: list[ScoredContext] = []
        for sc in scored_contexts:
            # CRITICAL priority (or higher) contexts are always included
            if sc.context.priority >= ContextPriority.CRITICAL.value:
                required.append(sc)
            else:
                optional.append(sc)
        return required, optional

    @staticmethod
    def _select_required(
        required: list[ScoredContext],
        budget: TokenBudget,
        selected: list[ScoredContext],
    ) -> int:
        """
        Append every required context to ``selected``; return tokens used.

        A context that ``budget.can_fit`` rejects is force-allocated, as long
        as the running total of required tokens stays within the usable
        budget.

        Raises:
            BudgetExceededError: If force-fitting would exceed the usable
                budget. NOTE: earlier required contexts have already been
                allocated on ``budget`` by the time this is raised.
        """
        # Usable budget = total minus the reserved portions
        usable_budget = budget.total - budget.response_reserve - budget.buffer
        total_tokens = 0

        for sc in required:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
                continue

            # Force-fit CRITICAL contexts if needed, but check total budget first
            if total_tokens + token_count > usable_budget:
                # Even CRITICAL contexts cannot exceed total model context window
                raise BudgetExceededError(
                    message=(
                        f"CRITICAL contexts exceed total budget. "
                        f"Context '{sc.context.source}' ({token_count} tokens) "
                        f"would exceed usable budget of {usable_budget} tokens."
                    ),
                    allocated=usable_budget,
                    requested=total_tokens + token_count,
                    context_type="CRITICAL_OVERFLOW",
                )

            budget.allocate(context_type, token_count, force=True)
            selected.append(sc)
            total_tokens += token_count
            logger.warning(
                "Force-fitted CRITICAL context: %s (%s tokens)",
                sc.context.source,
                token_count,
            )

        return total_tokens

    @staticmethod
    def _select_optional(
        optional: list[ScoredContext],
        budget: TokenBudget,
        selected: list[ScoredContext],
        excluded: list[ScoredContext],
    ) -> int:
        """Greedily add optional contexts that fit; return tokens used."""
        total_tokens = 0
        for sc in optional:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()
            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                excluded.append(sc)
        return total_tokens

    def _build_stats(
        self,
        total_contexts: int,
        required_count: int,
        selected: list[ScoredContext],
        excluded: list[ScoredContext],
        total_tokens: int,
    ) -> dict[str, Any]:
        """Build the ``selection_stats`` dict returned by ``rank``."""
        return {
            "total_contexts": total_contexts,
            "required_count": required_count,
            "selected_count": len(selected),
            "excluded_count": len(excluded),
            "total_tokens": total_tokens,
            "by_type": self._count_by_type(selected),
        }

    async def rank_simple(
        self,
        contexts: list[BaseContext],
        query: str,
        max_tokens: int,
        model: str | None = None,
        **kwargs: Any,
    ) -> list[BaseContext]:
        """
        Simple ranking without budget per type.

        Selects top contexts by score until max tokens reached. A context
        that does not fit is skipped rather than ending the scan, so smaller,
        lower-scored contexts may still be selected after it.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            max_tokens: Maximum total tokens
            model: Model for token counting
            **kwargs: Additional scoring parameters

        Returns:
            Selected contexts (in score order)
        """
        if not contexts:
            return []

        # Ensure token counts
        await self._ensure_token_counts(contexts, model)

        # Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # Sort by score (highest first)
        scored_contexts.sort(reverse=True)

        # Greedy selection (no early break; see docstring)
        selected: list[BaseContext] = []
        total_tokens = 0
        for sc in scored_contexts:
            token_count = sc.context.token_count or 0
            if total_tokens + token_count <= max_tokens:
                selected.append(sc.context)
                total_tokens += token_count

        return selected

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """
        Ensure all contexts have token counts.

        Counts tokens concurrently (via ``asyncio.gather``) for contexts whose
        ``token_count`` is None, and writes each result back onto the context
        object in place.

        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
        import asyncio

        # Find contexts needing counts
        contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
        if not contexts_needing_counts:
            return

        # Count all concurrently; gather preserves input order
        tasks = [
            self._calculator.count_tokens(ctx.content, model)
            for ctx in contexts_needing_counts
        ]
        counts = await asyncio.gather(*tasks)

        # Assign counts back
        for ctx, count in zip(contexts_needing_counts, counts, strict=True):
            ctx.token_count = count

    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
    ) -> dict[str, dict[str, int]]:
        """
        Count contexts and sum their tokens, grouped by context type.

        Returns:
            Mapping of ``get_type().value`` to ``{"count": ..., "tokens": ...}``.
        """
        by_type: dict[str, dict[str, int]] = {}
        for sc in scored_contexts:
            type_name = sc.context.get_type().value
            if type_name not in by_type:
                by_type[type_name] = {"count": 0, "tokens": 0}
            by_type[type_name]["count"] += 1
            by_type[type_name]["tokens"] += sc.context.token_count or 0
        return by_type

    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
        max_per_source: int | None = None,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.

        Prevents too many items from the same source. The first
        ``max_per_source`` items of each source keep their input order. Any
        further items from that source are moved to the end, still in their
        original relative order. No items are dropped.

        Args:
            scored_contexts: Already scored contexts
            max_per_source: Maximum items per source (uses settings if None)

        Returns:
            Reranked contexts
        """
        # Use provided value or fall back to settings
        effective_max = (
            max_per_source
            if max_per_source is not None
            else self._settings.diversity_max_per_source
        )

        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []

        for sc in scored_contexts:
            source = sc.context.source
            current_count = source_counts.get(source, 0)
            if current_count < effective_max:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
                deferred.append(sc)

        # Add deferred items at the end
        result.extend(deferred)
        return result