fast-next-template/backend/app/services/context/prioritization/ranker.py
Felipe Cardoso · 758052dcff
feat(context): improve budget validation and XML safety in ranking and Claude adapter
- Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations.
- Introduced `_get_valid_token_count()` helper to validate and safeguard token counts.
- Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content.
2026-01-04 16:02:18 +01:00
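The Claude adapter itself is not shown on this page; as a minimal sketch of the
escaping described above (assuming scores and raw content are interpolated into
an XML-tagged prompt), Python's stdlib already covers it:

    from xml.sax.saxutils import escape

    def xml_safe(value: object) -> str:
        # Escape &, <, > (plus quotes) so interpolated content
        # cannot break out of its surrounding XML tag.
        return escape(str(value), {'"': "&quot;", "'": "&apos;"})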


"""
Context Ranker for Context Management.
Ranks and selects contexts based on scores and budget constraints.
"""
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextPriority
logger = logging.getLogger(__name__)
@dataclass
class RankingResult:
"""Result of context ranking and selection."""
selected: list[ScoredContext]
excluded: list[ScoredContext]
total_tokens: int
selection_stats: dict[str, Any] = field(default_factory=dict)
@property
def selected_contexts(self) -> list[BaseContext]:
"""Get just the context objects (not scored wrappers)."""
return [s.context for s in self.selected]
class ContextRanker:
"""
Ranks and selects contexts within budget constraints.
    Uses greedy selection (highest score first) to approximately
    maximize total score while respecting per-type token budgets.
"""
def __init__(
self,
scorer: CompositeScorer | None = None,
calculator: TokenCalculator | None = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize context ranker.
Args:
scorer: Composite scorer for scoring contexts
calculator: Token calculator for counting tokens
settings: Context settings (uses global if None)
"""
self._settings = settings or get_context_settings()
self._scorer = scorer or CompositeScorer()
self._calculator = calculator or TokenCalculator()
def set_scorer(self, scorer: CompositeScorer) -> None:
"""Set the scorer."""
self._scorer = scorer
def set_calculator(self, calculator: TokenCalculator) -> None:
"""Set the token calculator."""
self._calculator = calculator
async def rank(
self,
contexts: list[BaseContext],
query: str,
budget: TokenBudget,
model: str | None = None,
ensure_required: bool = True,
**kwargs: Any,
) -> RankingResult:
"""
Rank and select contexts within budget.
Args:
contexts: Contexts to rank
query: Query to rank against
budget: Token budget constraints
model: Model for token counting
ensure_required: If True, always include CRITICAL priority contexts
**kwargs: Additional scoring parameters
Returns:
RankingResult with selected and excluded contexts
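        Example (illustrative only; assumes a budget built elsewhere
        in the package):
            result = await ranker.rank(contexts, "user query", budget)
            kept = result.selected_contexts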
"""
if not contexts:
return RankingResult(
selected=[],
excluded=[],
total_tokens=0,
selection_stats={"total_contexts": 0},
)
# 1. Ensure all contexts have token counts
await self._ensure_token_counts(contexts, model)
# 2. Score all contexts
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
# 3. Separate required (CRITICAL priority) from optional
required: list[ScoredContext] = []
optional: list[ScoredContext] = []
if ensure_required:
for sc in scored_contexts:
                # Contexts at or above CRITICAL priority (150) are always included
if sc.context.priority >= ContextPriority.CRITICAL.value:
required.append(sc)
else:
optional.append(sc)
else:
optional = list(scored_contexts)
# 4. Sort optional by score (highest first)
optional.sort(reverse=True)
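        # Relies on ScoredContext defining rich comparison (presumably by score)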
# 5. Greedy selection
selected: list[ScoredContext] = []
excluded: list[ScoredContext] = []
total_tokens = 0
# Calculate the usable budget (total minus reserved portions)
usable_budget = budget.total - budget.response_reserve - budget.buffer
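        # e.g. total=8000, response_reserve=1500, buffer=500 -> usable_budget=6000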
# Guard against invalid budget configuration
if usable_budget <= 0:
raise BudgetExceededError(
message=(
f"Invalid budget configuration: no usable tokens available. "
f"total={budget.total}, response_reserve={budget.response_reserve}, "
f"buffer={budget.buffer}"
),
allocated=budget.total,
requested=0,
context_type="CONFIGURATION_ERROR",
)
# First, try to fit required contexts
for sc in required:
token_count = self._get_valid_token_count(sc.context)
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
budget.allocate(context_type, token_count)
selected.append(sc)
total_tokens += token_count
else:
# Force-fit CRITICAL contexts if needed, but check total budget first
if total_tokens + token_count > usable_budget:
                    # Even CRITICAL contexts cannot exceed the usable budget
                    # (total minus response reserve and buffer)
raise BudgetExceededError(
message=(
f"CRITICAL contexts exceed total budget. "
f"Context '{sc.context.source}' ({token_count} tokens) "
f"would exceed usable budget of {usable_budget} tokens."
),
allocated=usable_budget,
requested=total_tokens + token_count,
context_type="CRITICAL_OVERFLOW",
)
budget.allocate(context_type, token_count, force=True)
selected.append(sc)
total_tokens += token_count
logger.warning(
f"Force-fitted CRITICAL context: {sc.context.source} "
f"({token_count} tokens)"
)
# Then, greedily add optional contexts
for sc in optional:
token_count = self._get_valid_token_count(sc.context)
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
budget.allocate(context_type, token_count)
selected.append(sc)
total_tokens += token_count
else:
excluded.append(sc)
# Build stats
stats = {
"total_contexts": len(contexts),
"required_count": len(required),
"selected_count": len(selected),
"excluded_count": len(excluded),
"total_tokens": total_tokens,
"by_type": self._count_by_type(selected),
}
return RankingResult(
selected=selected,
excluded=excluded,
total_tokens=total_tokens,
selection_stats=stats,
)
async def rank_simple(
self,
contexts: list[BaseContext],
query: str,
max_tokens: int,
model: str | None = None,
**kwargs: Any,
) -> list[BaseContext]:
"""
        Simple ranking without per-type budgets.
        Selects contexts in score order, skipping any whose token
        count would push the total past max_tokens.
Args:
contexts: Contexts to rank
query: Query to rank against
max_tokens: Maximum total tokens
model: Model for token counting
**kwargs: Additional scoring parameters
Returns:
Selected contexts (in score order)
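        Example (illustrative only):
            top = await ranker.rank_simple(contexts, "user query", max_tokens=2000)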
"""
if not contexts:
return []
# Ensure token counts
await self._ensure_token_counts(contexts, model)
# Score all contexts
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
# Sort by score (highest first)
scored_contexts.sort(reverse=True)
# Greedy selection
selected: list[BaseContext] = []
total_tokens = 0
for sc in scored_contexts:
token_count = self._get_valid_token_count(sc.context)
if total_tokens + token_count <= max_tokens:
selected.append(sc.context)
total_tokens += token_count
return selected
def _get_valid_token_count(self, context: BaseContext) -> int:
"""
Get validated token count from a context.
Ensures token_count is set (not None) and non-negative to prevent
budget bypass attacks where:
- None would be treated as 0 (allowing huge contexts to slip through)
- Negative values would corrupt budget tracking
Args:
context: Context to get token count from
Returns:
Valid non-negative token count
Raises:
ValueError: If token_count is None or negative
"""
if context.token_count is None:
raise ValueError(
f"Context '{context.source}' has no token count. "
"Ensure _ensure_token_counts() is called before ranking."
)
if context.token_count < 0:
raise ValueError(
f"Context '{context.source}' has invalid negative token count: "
f"{context.token_count}"
)
return context.token_count
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
model: str | None = None,
) -> None:
"""
Ensure all contexts have token counts.
Counts tokens in parallel for contexts that don't have counts.
Args:
contexts: Contexts to check
model: Model for token counting
"""
# Find contexts needing counts
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
if not contexts_needing_counts:
return
# Count all in parallel
tasks = [
self._calculator.count_tokens(ctx.content, model)
for ctx in contexts_needing_counts
]
counts = await asyncio.gather(*tasks)
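        # gather() returns results in task order, so each count
        # lines up one-to-one with contexts_needing_counts below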
# Assign counts back
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
ctx.token_count = count
def _count_by_type(
self, scored_contexts: list[ScoredContext]
) -> dict[str, dict[str, int]]:
"""Count selected contexts by type."""
by_type: dict[str, dict[str, int]] = {}
for sc in scored_contexts:
type_name = sc.context.get_type().value
if type_name not in by_type:
by_type[type_name] = {"count": 0, "tokens": 0}
by_type[type_name]["count"] += 1
            # Counts were validated during ranking; `or 0` is defensive for direct calls
by_type[type_name]["tokens"] += sc.context.token_count or 0
return by_type
async def rerank_for_diversity(
self,
scored_contexts: list[ScoredContext],
max_per_source: int | None = None,
) -> list[ScoredContext]:
"""
        Rerank to ensure source diversity.
        Caps the number of items per source, deferring any excess to
        the end of the list rather than dropping it.
Args:
scored_contexts: Already scored contexts
max_per_source: Maximum items per source (uses settings if None)
Returns:
Reranked contexts
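        Example (illustrative only):
            diverse = await ranker.rerank_for_diversity(scored, max_per_source=2)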
"""
# Use provided value or fall back to settings
effective_max = (
max_per_source
if max_per_source is not None
else self._settings.diversity_max_per_source
)
source_counts: dict[str, int] = {}
result: list[ScoredContext] = []
deferred: list[ScoredContext] = []
for sc in scored_contexts:
source = sc.context.source
current_count = source_counts.get(source, 0)
if current_count < effective_max:
result.append(sc)
source_counts[source] = current_count + 1
else:
deferred.append(sc)
# Add deferred items at the end
result.extend(deferred)
return result
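

# Illustrative usage sketch (not part of the module's API): it assumes a
# TokenBudget(total=..., response_reserve=..., buffer=...) constructor and
# caller-built BaseContext objects; check the real signatures in ..budget
# and ..types before relying on it.
if __name__ == "__main__":  # pragma: no cover
    async def _demo() -> None:
        ranker = ContextRanker()
        # Hypothetical constructor kwargs, mirroring the attributes read in rank()
        budget = TokenBudget(total=8000, response_reserve=1500, buffer=500)
        result = await ranker.rank([], query="example query", budget=budget)
        print(result.selection_stats)

    asyncio.run(_demo())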