""" Context Ranker for Context Management. Ranks and selects contexts based on scores and budget constraints. """ import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from ..budget import TokenBudget, TokenCalculator from ..config import ContextSettings, get_context_settings from ..exceptions import BudgetExceededError from ..scoring.composite import CompositeScorer, ScoredContext from ..types import BaseContext, ContextPriority if TYPE_CHECKING: pass logger = logging.getLogger(__name__) @dataclass class RankingResult: """Result of context ranking and selection.""" selected: list[ScoredContext] excluded: list[ScoredContext] total_tokens: int selection_stats: dict[str, Any] = field(default_factory=dict) @property def selected_contexts(self) -> list[BaseContext]: """Get just the context objects (not scored wrappers).""" return [s.context for s in self.selected] class ContextRanker: """ Ranks and selects contexts within budget constraints. Uses greedy selection to maximize total score while respecting token budgets per context type. """ def __init__( self, scorer: CompositeScorer | None = None, calculator: TokenCalculator | None = None, settings: ContextSettings | None = None, ) -> None: """ Initialize context ranker. Args: scorer: Composite scorer for scoring contexts calculator: Token calculator for counting tokens settings: Context settings (uses global if None) """ self._settings = settings or get_context_settings() self._scorer = scorer or CompositeScorer() self._calculator = calculator or TokenCalculator() def set_scorer(self, scorer: CompositeScorer) -> None: """Set the scorer.""" self._scorer = scorer def set_calculator(self, calculator: TokenCalculator) -> None: """Set the token calculator.""" self._calculator = calculator async def rank( self, contexts: list[BaseContext], query: str, budget: TokenBudget, model: str | None = None, ensure_required: bool = True, **kwargs: Any, ) -> RankingResult: """ Rank and select contexts within budget. Args: contexts: Contexts to rank query: Query to rank against budget: Token budget constraints model: Model for token counting ensure_required: If True, always include CRITICAL priority contexts **kwargs: Additional scoring parameters Returns: RankingResult with selected and excluded contexts """ if not contexts: return RankingResult( selected=[], excluded=[], total_tokens=0, selection_stats={"total_contexts": 0}, ) # 1. Ensure all contexts have token counts await self._ensure_token_counts(contexts, model) # 2. Score all contexts scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs) # 3. Separate required (CRITICAL priority) from optional required: list[ScoredContext] = [] optional: list[ScoredContext] = [] if ensure_required: for sc in scored_contexts: # CRITICAL priority (150) contexts are always included if sc.context.priority >= ContextPriority.CRITICAL.value: required.append(sc) else: optional.append(sc) else: optional = list(scored_contexts) # 4. Sort optional by score (highest first) optional.sort(reverse=True) # 5. Greedy selection selected: list[ScoredContext] = [] excluded: list[ScoredContext] = [] total_tokens = 0 # Calculate the usable budget (total minus reserved portions) usable_budget = budget.total - budget.response_reserve - budget.buffer # Guard against invalid budget configuration if usable_budget <= 0: raise BudgetExceededError( message=( f"Invalid budget configuration: no usable tokens available. 
" f"total={budget.total}, response_reserve={budget.response_reserve}, " f"buffer={budget.buffer}" ), allocated=budget.total, requested=0, context_type="CONFIGURATION_ERROR", ) # First, try to fit required contexts for sc in required: token_count = self._get_valid_token_count(sc.context) context_type = sc.context.get_type() if budget.can_fit(context_type, token_count): budget.allocate(context_type, token_count) selected.append(sc) total_tokens += token_count else: # Force-fit CRITICAL contexts if needed, but check total budget first if total_tokens + token_count > usable_budget: # Even CRITICAL contexts cannot exceed total model context window raise BudgetExceededError( message=( f"CRITICAL contexts exceed total budget. " f"Context '{sc.context.source}' ({token_count} tokens) " f"would exceed usable budget of {usable_budget} tokens." ), allocated=usable_budget, requested=total_tokens + token_count, context_type="CRITICAL_OVERFLOW", ) budget.allocate(context_type, token_count, force=True) selected.append(sc) total_tokens += token_count logger.warning( f"Force-fitted CRITICAL context: {sc.context.source} " f"({token_count} tokens)" ) # Then, greedily add optional contexts for sc in optional: token_count = self._get_valid_token_count(sc.context) context_type = sc.context.get_type() if budget.can_fit(context_type, token_count): budget.allocate(context_type, token_count) selected.append(sc) total_tokens += token_count else: excluded.append(sc) # Build stats stats = { "total_contexts": len(contexts), "required_count": len(required), "selected_count": len(selected), "excluded_count": len(excluded), "total_tokens": total_tokens, "by_type": self._count_by_type(selected), } return RankingResult( selected=selected, excluded=excluded, total_tokens=total_tokens, selection_stats=stats, ) async def rank_simple( self, contexts: list[BaseContext], query: str, max_tokens: int, model: str | None = None, **kwargs: Any, ) -> list[BaseContext]: """ Simple ranking without budget per type. Selects top contexts by score until max tokens reached. Args: contexts: Contexts to rank query: Query to rank against max_tokens: Maximum total tokens model: Model for token counting **kwargs: Additional scoring parameters Returns: Selected contexts (in score order) """ if not contexts: return [] # Ensure token counts await self._ensure_token_counts(contexts, model) # Score all contexts scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs) # Sort by score (highest first) scored_contexts.sort(reverse=True) # Greedy selection selected: list[BaseContext] = [] total_tokens = 0 for sc in scored_contexts: token_count = self._get_valid_token_count(sc.context) if total_tokens + token_count <= max_tokens: selected.append(sc.context) total_tokens += token_count return selected def _get_valid_token_count(self, context: BaseContext) -> int: """ Get validated token count from a context. Ensures token_count is set (not None) and non-negative to prevent budget bypass attacks where: - None would be treated as 0 (allowing huge contexts to slip through) - Negative values would corrupt budget tracking Args: context: Context to get token count from Returns: Valid non-negative token count Raises: ValueError: If token_count is None or negative """ if context.token_count is None: raise ValueError( f"Context '{context.source}' has no token count. " "Ensure _ensure_token_counts() is called before ranking." 

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """
        Ensure all contexts have token counts.

        Counts tokens in parallel for contexts that don't have counts.

        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
        # Find contexts needing counts
        contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
        if not contexts_needing_counts:
            return

        # Count all in parallel
        tasks = [
            self._calculator.count_tokens(ctx.content, model)
            for ctx in contexts_needing_counts
        ]
        counts = await asyncio.gather(*tasks)

        # Assign counts back
        for ctx, count in zip(contexts_needing_counts, counts, strict=True):
            ctx.token_count = count

    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
    ) -> dict[str, dict[str, int]]:
        """Count selected contexts by type."""
        by_type: dict[str, dict[str, int]] = {}

        for sc in scored_contexts:
            type_name = sc.context.get_type().value
            if type_name not in by_type:
                by_type[type_name] = {"count": 0, "tokens": 0}
            by_type[type_name]["count"] += 1
            # Use validated token count (already validated during ranking)
            by_type[type_name]["tokens"] += sc.context.token_count or 0

        return by_type

    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
        max_per_source: int | None = None,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.

        Prevents too many items from the same source.

        Args:
            scored_contexts: Already scored contexts
            max_per_source: Maximum items per source (uses settings if None)

        Returns:
            Reranked contexts
        """
        # Use the provided value or fall back to settings
        effective_max = (
            max_per_source
            if max_per_source is not None
            else self._settings.diversity_max_per_source
        )

        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []

        for sc in scored_contexts:
            source = sc.context.source
            current_count = source_counts.get(source, 0)

            if current_count < effective_max:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
                deferred.append(sc)

        # Add deferred items at the end
        result.extend(deferred)
        return result
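

# A worked example of ``rerank_for_diversity`` (illustrative only). With
# ``max_per_source=1`` and three contexts in descending score order, where
# ``a1`` and ``a2`` share the source "a" and ``b1`` comes from source "b":
#
#     input:  [a1, a2, b1]
#     output: [a1, b1, a2]   # a2 is deferred behind the first "b" item
#
# Relative order is preserved within both the kept and deferred groups.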