Files
fast-next-template/backend/app/services/context/prioritization/ranker.py
Felipe Cardoso 0d2005ddcb feat(context): implement context scoring and ranking (Phase 3)
Add comprehensive scoring system with three strategies:
- RelevanceScorer: Semantic similarity with keyword fallback
- RecencyScorer: Exponential decay with type-specific half-lives
- PriorityScorer: Priority-based scoring with type bonuses

Implement CompositeScorer combining all strategies with configurable
weights (default: 50% relevance, 30% recency, 20% priority).

Add ContextRanker for budget-aware context selection with:
- Greedy selection algorithm respecting token budgets
- CRITICAL priority contexts always included
- Diversity reranking to prevent source dominance
- Comprehensive selection statistics

68 tests covering all scoring and ranking functionality.

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:24:06 +01:00

289 lines
8.6 KiB
Python

"""
Context Ranker for Context Management.
Ranks and selects contexts based on scores and budget constraints.
"""
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextType
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
# Module-level logger named after this module; ContextRanker.rank uses it to
# emit a warning whenever a CRITICAL context is force-fitted into the budget.
logger = logging.getLogger(__name__)
@dataclass
class RankingResult:
    """
    Outcome of one ranking/selection pass.

    Bundles the scored contexts that were selected, the ones that were
    left out, the token total of the selection, and a free-form
    statistics mapping.
    """

    # Scored wrappers of the contexts that were selected.
    selected: list[ScoredContext]
    # Scored wrappers of the contexts that were not selected.
    excluded: list[ScoredContext]
    # Sum of the token counts of the selected contexts.
    total_tokens: int
    # Summary statistics about the selection; a fresh empty dict per
    # instance unless the caller supplies one.
    selection_stats: dict[str, Any] = field(default_factory=dict)

    @property
    def selected_contexts(self) -> list[BaseContext]:
        """Return the bare context objects of ``selected``, in the same order."""
        return [scored.context for scored in self.selected]
class ContextRanker:
    """
    Ranks and selects contexts within budget constraints.

    Selection is greedy: scored contexts are visited highest-ranked first
    and taken whenever the token budget still has room for them. Contexts
    whose priority reaches the critical threshold can bypass the budget
    check (see ``rank``).
    """

    # Priority value at or above which a context is treated as CRITICAL.
    # NOTE(review): presumably mirrors the CRITICAL member of the project's
    # priority enum - confirm, and reference the enum directly if possible.
    _CRITICAL_PRIORITY = 100

    def __init__(
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
    ) -> None:
        """
        Initialize context ranker.

        Args:
            scorer: Composite scorer for scoring contexts. A default
                ``CompositeScorer()`` is created when omitted.
            calculator: Token calculator for counting tokens. A default
                ``TokenCalculator()`` is created when omitted.
        """
        # Compare against None explicitly so that a caller-supplied
        # collaborator which happens to be falsy (defines __bool__/__len__)
        # is not silently swapped for a default instance.
        self._scorer = scorer if scorer is not None else CompositeScorer()
        self._calculator = (
            calculator if calculator is not None else TokenCalculator()
        )

    def set_scorer(self, scorer: CompositeScorer) -> None:
        """Replace the scorer used for subsequent ranking calls."""
        self._scorer = scorer

    def set_calculator(self, calculator: TokenCalculator) -> None:
        """Replace the token calculator used for subsequent ranking calls."""
        self._calculator = calculator

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        budget: TokenBudget,
        model: str | None = None,
        ensure_required: bool = True,
        **kwargs: Any,
    ) -> RankingResult:
        """
        Rank and select contexts within budget.

        Side effects: missing ``token_count`` values are filled in on the
        given contexts, and every selected context is allocated on
        ``budget`` (the budget object is mutated).

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            budget: Token budget constraints
            model: Model for token counting
            ensure_required: If True, always include CRITICAL priority
                contexts, force-allocating them when they do not fit
            **kwargs: Additional scoring parameters

        Returns:
            RankingResult with selected and excluded contexts. ``selected``
            holds the required contexts first (in scorer output order),
            then the optional ones in sorted order. ``selection_stats``
            has the same keys for empty and non-empty input.
        """
        if not contexts:
            return RankingResult(
                selected=[],
                excluded=[],
                total_tokens=0,
                selection_stats=self._build_stats(0, 0, [], [], 0),
            )

        # 1. Ensure all contexts have token counts
        await self._ensure_token_counts(contexts, model)

        # 2. Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # 3. Separate required (CRITICAL priority) from optional
        required: list[ScoredContext] = []
        optional: list[ScoredContext] = []
        if ensure_required:
            for sc in scored_contexts:
                if sc.context.priority >= self._CRITICAL_PRIORITY:
                    required.append(sc)
                else:
                    optional.append(sc)
        else:
            optional = list(scored_contexts)

        # 4. Sort optional by score (highest first)
        optional.sort(reverse=True)

        # 5. Greedy selection
        selected: list[ScoredContext] = []
        excluded: list[ScoredContext] = []
        total_tokens = 0

        # Required contexts are always selected; when the budget has no
        # room they are force-allocated and the event is logged.
        for sc in required:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()
            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
            else:
                budget.allocate(context_type, token_count, force=True)
                logger.warning(
                    "Force-fitted CRITICAL context: %s (%s tokens)",
                    sc.context.source,
                    token_count,
                )
            selected.append(sc)
            total_tokens += token_count

        # Optional contexts are taken only while they still fit.
        for sc in optional:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()
            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                excluded.append(sc)

        return RankingResult(
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
            selection_stats=self._build_stats(
                len(contexts), len(required), selected, excluded, total_tokens
            ),
        )

    async def rank_simple(
        self,
        contexts: list[BaseContext],
        query: str,
        max_tokens: int,
        model: str | None = None,
        **kwargs: Any,
    ) -> list[BaseContext]:
        """
        Simple ranking without budget per type.

        Visits contexts highest-ranked first and keeps each one whose
        tokens still fit under ``max_tokens``; a context that does not fit
        is skipped and later (smaller) ones may still be taken. Missing
        ``token_count`` values are filled in on the given contexts.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            max_tokens: Maximum total tokens
            model: Model for token counting
            **kwargs: Additional scoring parameters

        Returns:
            Selected contexts (in score order)
        """
        if not contexts:
            return []

        await self._ensure_token_counts(contexts, model)
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # Sort by score (highest first)
        scored_contexts.sort(reverse=True)

        selected: list[BaseContext] = []
        total_tokens = 0
        for sc in scored_contexts:
            token_count = sc.context.token_count or 0
            if total_tokens + token_count <= max_tokens:
                selected.append(sc.context)
                total_tokens += token_count
        return selected

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """
        Ensure all contexts have token counts.

        Contexts whose ``token_count`` is None get it computed from their
        ``content`` by the calculator and stored back on the context.

        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
        for context in contexts:
            if context.token_count is None:
                context.token_count = await self._calculator.count_tokens(
                    context.content, model
                )

    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
    ) -> dict[str, dict[str, int]]:
        """
        Count contexts and tokens per context type.

        Returns:
            Mapping from type value to ``{"count": ..., "tokens": ...}``.
        """
        by_type: dict[str, dict[str, int]] = {}
        for sc in scored_contexts:
            bucket = by_type.setdefault(
                sc.context.get_type().value, {"count": 0, "tokens": 0}
            )
            bucket["count"] += 1
            bucket["tokens"] += sc.context.token_count or 0
        return by_type

    def _build_stats(
        self,
        total_contexts: int,
        required_count: int,
        selected: list[ScoredContext],
        excluded: list[ScoredContext],
        total_tokens: int,
    ) -> dict[str, Any]:
        """Build the selection statistics mapping stored on a RankingResult."""
        return {
            "total_contexts": total_contexts,
            "required_count": required_count,
            "selected_count": len(selected),
            "excluded_count": len(excluded),
            "total_tokens": total_tokens,
            "by_type": self._count_by_type(selected),
        }

    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
        max_per_source: int = 3,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.

        Keeps the first ``max_per_source`` items of each source in their
        original relative order and moves every further item of that
        source to the end. Nothing is dropped; the input list is not
        modified.

        Args:
            scored_contexts: Already scored contexts
            max_per_source: Maximum items per source before deferral

        Returns:
            Reranked contexts
        """
        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []

        for sc in scored_contexts:
            source = sc.context.source
            current_count = source_counts.get(source, 0)
            if current_count < max_per_source:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
                deferred.append(sc)

        # Add deferred items at the end
        result.extend(deferred)
        return result