- Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations. - Introduced `_get_valid_token_count()` helper to validate and safeguard token counts. - Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content.
375 lines
12 KiB
Python
375 lines
12 KiB
Python
"""
Context Ranker for Context Management.

Ranks and selects contexts based on scores and budget constraints.
"""
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from ..budget import TokenBudget, TokenCalculator
|
|
from ..config import ContextSettings, get_context_settings
|
|
from ..exceptions import BudgetExceededError
|
|
from ..scoring.composite import CompositeScorer, ScoredContext
|
|
from ..types import BaseContext, ContextPriority
|
|
|
|
if TYPE_CHECKING:
    # Currently empty: no type-checker-only imports are needed. Nothing in
    # this branch runs at runtime because TYPE_CHECKING is False then.
    ...

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RankingResult:
    """Outcome of a ranking pass: the contexts kept and the contexts dropped."""

    # Scored contexts chosen for inclusion.
    selected: list[ScoredContext]
    # Scored contexts that were considered but not chosen.
    excluded: list[ScoredContext]
    # Token total reported for the selected contexts.
    total_tokens: int
    # Free-form diagnostics about the selection; empty dict by default.
    selection_stats: dict[str, Any] = field(default_factory=dict)

    @property
    def selected_contexts(self) -> list[BaseContext]:
        """Unwrap ``selected`` into the bare context objects, preserving order."""
        return [item.context for item in self.selected]
|
|
|
|
|
|
class ContextRanker:
    """
    Ranks and selects contexts within budget constraints.

    Uses greedy selection to maximize total score
    while respecting token budgets per context type.
    """

    def __init__(
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize context ranker.

        Args:
            scorer: Composite scorer for scoring contexts
                (a default ``CompositeScorer()`` is created if None)
            calculator: Token calculator for counting tokens
                (a default ``TokenCalculator()`` is created if None)
            settings: Context settings (uses global if None)
        """
        # Explicit None checks so that a caller-supplied object is never
        # replaced just because it happens to be falsy.
        self._settings = settings if settings is not None else get_context_settings()
        self._scorer = scorer if scorer is not None else CompositeScorer()
        self._calculator = calculator if calculator is not None else TokenCalculator()

    def set_scorer(self, scorer: CompositeScorer) -> None:
        """Set the scorer."""
        self._scorer = scorer

    def set_calculator(self, calculator: TokenCalculator) -> None:
        """Set the token calculator."""
        self._calculator = calculator

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        budget: TokenBudget,
        model: str | None = None,
        ensure_required: bool = True,
        **kwargs: Any,
    ) -> RankingResult:
        """
        Rank and select contexts within budget.

        CRITICAL-priority contexts (when ``ensure_required`` is True) are
        placed first, force-fitted past their type budget if necessary. The
        remaining contexts are then added greedily in descending score order
        whenever ``budget.can_fit`` accepts them.

        Note: ``budget`` is mutated -- ``budget.allocate`` is called for every
        selected context.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            budget: Token budget constraints
            model: Model for token counting
            ensure_required: If True, always include CRITICAL priority contexts
            **kwargs: Additional scoring parameters

        Returns:
            RankingResult with selected and excluded contexts

        Raises:
            BudgetExceededError: If the budget leaves no usable tokens
                (``total - response_reserve - buffer <= 0``), or if the
                CRITICAL contexts would exceed the usable budget.
            ValueError: If a context has a missing or negative token count.
        """
        if not contexts:
            return RankingResult(
                selected=[],
                excluded=[],
                total_tokens=0,
                selection_stats={"total_contexts": 0},
            )

        # Fail fast on an unusable budget before doing any token counting or
        # scoring work (and before mutating the contexts' token counts).
        usable_budget = self._get_usable_budget(budget)

        # 1. Ensure all contexts have token counts
        await self._ensure_token_counts(contexts, model)

        # 2. Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # 3. Separate required (CRITICAL priority) from optional
        required, optional = self._partition_required(scored_contexts, ensure_required)

        # 4. Sort optional by score (highest first)
        optional.sort(reverse=True)

        # 5. Validate every token count before touching the budget, so an
        #    invalid count cannot leave ``budget`` partially allocated.
        required_items = [
            (sc, self._get_valid_token_count(sc.context)) for sc in required
        ]
        optional_items = [
            (sc, self._get_valid_token_count(sc.context)) for sc in optional
        ]

        # 6. Greedy selection
        selected: list[ScoredContext] = []
        excluded: list[ScoredContext] = []
        total_tokens = 0

        # First, place every required context.
        for sc, token_count in required_items:
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
            else:
                # Force-fit CRITICAL contexts past their type budget, but even
                # CRITICAL contexts cannot exceed the usable budget.
                # NOTE(review): allocations made earlier in this loop are not
                # released before raising; confirm callers discard ``budget``
                # on this error.
                if total_tokens + token_count > usable_budget:
                    raise BudgetExceededError(
                        message=(
                            "CRITICAL contexts exceed total budget. "
                            f"Context '{sc.context.source}' ({token_count} tokens) "
                            f"would exceed usable budget of {usable_budget} tokens."
                        ),
                        allocated=usable_budget,
                        requested=total_tokens + token_count,
                        context_type="CRITICAL_OVERFLOW",
                    )

                budget.allocate(context_type, token_count, force=True)
                logger.warning(
                    "Force-fitted CRITICAL context: %s (%s tokens)",
                    sc.context.source,
                    token_count,
                )

            selected.append(sc)
            total_tokens += token_count

        # Then, greedily add optional contexts in score order.
        for sc, token_count in optional_items:
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                excluded.append(sc)

        # Build stats
        stats = {
            "total_contexts": len(contexts),
            "required_count": len(required),
            "selected_count": len(selected),
            "excluded_count": len(excluded),
            "total_tokens": total_tokens,
            "by_type": self._count_by_type(selected),
        }

        return RankingResult(
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
            selection_stats=stats,
        )

    @staticmethod
    def _get_usable_budget(budget: TokenBudget) -> int:
        """
        Return the tokens left after the response reserve and buffer.

        Raises:
            BudgetExceededError: If no usable tokens remain.
        """
        usable_budget = budget.total - budget.response_reserve - budget.buffer
        if usable_budget <= 0:
            raise BudgetExceededError(
                message=(
                    "Invalid budget configuration: no usable tokens available. "
                    f"total={budget.total}, response_reserve={budget.response_reserve}, "
                    f"buffer={budget.buffer}"
                ),
                allocated=budget.total,
                requested=0,
                context_type="CONFIGURATION_ERROR",
            )
        return usable_budget

    @staticmethod
    def _partition_required(
        scored_contexts: list[ScoredContext],
        ensure_required: bool,
    ) -> tuple[list[ScoredContext], list[ScoredContext]]:
        """
        Split scored contexts into ``(required, optional)`` lists.

        When ``ensure_required`` is False nothing is required. Both returned
        lists are fresh lists, preserving the input order.
        """
        if not ensure_required:
            return [], list(scored_contexts)

        required: list[ScoredContext] = []
        optional: list[ScoredContext] = []
        for sc in scored_contexts:
            # CRITICAL priority (150) contexts are always included
            if sc.context.priority >= ContextPriority.CRITICAL.value:
                required.append(sc)
            else:
                optional.append(sc)
        return required, optional

    async def rank_simple(
        self,
        contexts: list[BaseContext],
        query: str,
        max_tokens: int,
        model: str | None = None,
        **kwargs: Any,
    ) -> list[BaseContext]:
        """
        Simple ranking without budget per type.

        Selects top contexts by score until max tokens reached. Contexts that
        do not fit are skipped, and later (smaller) ones may still be added.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            max_tokens: Maximum total tokens
            model: Model for token counting
            **kwargs: Additional scoring parameters

        Returns:
            Selected contexts (in score order)

        Raises:
            ValueError: If a context has a missing or negative token count.
        """
        if not contexts:
            return []

        # Ensure token counts
        await self._ensure_token_counts(contexts, model)

        # Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # Sort by score (highest first); sorted() does not mutate the
        # scorer's returned collection and yields the same stable order.
        ranked = sorted(scored_contexts, reverse=True)

        # Greedy selection
        selected: list[BaseContext] = []
        total_tokens = 0

        for sc in ranked:
            token_count = self._get_valid_token_count(sc.context)
            if total_tokens + token_count <= max_tokens:
                selected.append(sc.context)
                total_tokens += token_count

        return selected

    def _get_valid_token_count(self, context: BaseContext) -> int:
        """
        Get validated token count from a context.

        Ensures token_count is set (not None) and non-negative to prevent
        budget bypass attacks where:
        - None would be treated as 0 (allowing huge contexts to slip through)
        - Negative values would corrupt budget tracking

        Args:
            context: Context to get token count from

        Returns:
            Valid non-negative token count

        Raises:
            ValueError: If token_count is None or negative
        """
        if context.token_count is None:
            raise ValueError(
                f"Context '{context.source}' has no token count. "
                "Ensure _ensure_token_counts() is called before ranking."
            )
        if context.token_count < 0:
            raise ValueError(
                f"Context '{context.source}' has invalid negative token count: "
                f"{context.token_count}"
            )
        return context.token_count

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """
        Ensure all contexts have token counts.

        Counts tokens concurrently (via ``asyncio.gather``) for contexts whose
        ``token_count`` is None, then writes the results back onto those
        context objects. Contexts that already have a count are untouched.

        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
        import asyncio

        # Find contexts needing counts
        contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]

        if not contexts_needing_counts:
            return

        # Count all concurrently
        tasks = [
            self._calculator.count_tokens(ctx.content, model)
            for ctx in contexts_needing_counts
        ]
        counts = await asyncio.gather(*tasks)

        # Assign counts back (gather preserves task order)
        for ctx, count in zip(contexts_needing_counts, counts, strict=True):
            ctx.token_count = count

    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
    ) -> dict[str, dict[str, int]]:
        """Aggregate ``count`` and ``tokens`` per context-type value."""
        by_type: dict[str, dict[str, int]] = {}

        for sc in scored_contexts:
            type_name = sc.context.get_type().value
            entry = by_type.setdefault(type_name, {"count": 0, "tokens": 0})
            entry["count"] += 1
            # Counts were validated during ranking; `or 0` only guards None.
            entry["tokens"] += sc.context.token_count or 0

        return by_type

    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
        max_per_source: int | None = None,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.

        Keeps the first ``max_per_source`` items from each source in their
        original positions. Any further items from that source are NOT
        dropped. They are moved, in their original relative order, to the
        end of the result. The input list is not modified, and a new list is
        returned.

        Args:
            scored_contexts: Already scored contexts
            max_per_source: Maximum items per source kept in place
                (uses ``settings.diversity_max_per_source`` if None)

        Returns:
            Reranked contexts (same items, reordered)
        """
        # Use provided value or fall back to settings
        effective_max = (
            max_per_source
            if max_per_source is not None
            else self._settings.diversity_max_per_source
        )

        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []

        for sc in scored_contexts:
            source = sc.context.source
            current_count = source_counts.get(source, 0)

            if current_count < effective_max:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
                deferred.append(sc)

        # Add deferred items at the end
        result.extend(deferred)
        return result
|