feat(context): implement context scoring and ranking (Phase 3)
Add comprehensive scoring system with three strategies:

- RelevanceScorer: Semantic similarity with keyword fallback
- RecencyScorer: Exponential decay with type-specific half-lives
- PriorityScorer: Priority-based scoring with type bonuses

Implement CompositeScorer combining all strategies with configurable
weights (default: 50% relevance, 30% recency, 20% priority).

Add ContextRanker for budget-aware context selection with:

- Greedy selection algorithm respecting token budgets
- CRITICAL priority contexts always included
- Diversity reranking to prevent source dominance
- Comprehensive selection statistics

68 tests covering all scoring and ranking functionality.

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
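For orientation, a minimal usage sketch of the flow this commit wires together. Module paths follow the diff below; `contexts`, the `TokenBudget` value, and the query string are placeholders, since none of them are defined in this diff:

```python
from app.services.context.budget import TokenBudget
from app.services.context.prioritization import ContextRanker

async def build_prompt_contexts(contexts, budget: TokenBudget):
    # Defaults: CompositeScorer (50% relevance / 30% recency / 20% priority)
    # and a TokenCalculator for token counting.
    ranker = ContextRanker()
    result = await ranker.rank(contexts, query="deployment errors", budget=budget)
    print(result.selection_stats)        # counts, tokens, per-type breakdown
    return result.selected_contexts      # unwrapped BaseContext objects
```

`rank_simple()` offers the same flow with a single `max_tokens` cap instead of a per-type budget.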
@@ -3,3 +3,10 @@ Context Prioritization Module.

Provides context ranking and selection.
"""

from .ranker import ContextRanker, RankingResult

__all__ = [
    "ContextRanker",
    "RankingResult",
]
288 backend/app/services/context/prioritization/ranker.py Normal file
@@ -0,0 +1,288 @@
"""
Context Ranker for Context Management.

Ranks and selects contexts based on scores and budget constraints.
"""

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from ..budget import TokenBudget, TokenCalculator
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextType

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)

@dataclass
class RankingResult:
    """Result of context ranking and selection."""

    selected: list[ScoredContext]
    excluded: list[ScoredContext]
    total_tokens: int
    selection_stats: dict[str, Any] = field(default_factory=dict)

    @property
    def selected_contexts(self) -> list[BaseContext]:
        """Get just the context objects (not scored wrappers)."""
        return [s.context for s in self.selected]

class ContextRanker:
    """
    Ranks and selects contexts within budget constraints.

    Uses greedy selection to maximize total score
    while respecting token budgets per context type.
    """

    def __init__(
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
    ) -> None:
        """
        Initialize context ranker.

        Args:
            scorer: Composite scorer for scoring contexts
            calculator: Token calculator for counting tokens
        """
        self._scorer = scorer or CompositeScorer()
        self._calculator = calculator or TokenCalculator()

    def set_scorer(self, scorer: CompositeScorer) -> None:
        """Set the scorer."""
        self._scorer = scorer

    def set_calculator(self, calculator: TokenCalculator) -> None:
        """Set the token calculator."""
        self._calculator = calculator

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        budget: TokenBudget,
        model: str | None = None,
        ensure_required: bool = True,
        **kwargs: Any,
    ) -> RankingResult:
        """
        Rank and select contexts within budget.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            budget: Token budget constraints
            model: Model for token counting
            ensure_required: If True, always include CRITICAL priority contexts
            **kwargs: Additional scoring parameters

        Returns:
            RankingResult with selected and excluded contexts
        """
        if not contexts:
            return RankingResult(
                selected=[],
                excluded=[],
                total_tokens=0,
                selection_stats={"total_contexts": 0},
            )

        # 1. Ensure all contexts have token counts
        await self._ensure_token_counts(contexts, model)

        # 2. Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # 3. Separate required (CRITICAL priority) from optional
        required: list[ScoredContext] = []
        optional: list[ScoredContext] = []

        if ensure_required:
            for sc in scored_contexts:
                # CRITICAL priority (100) contexts are always included
                if sc.context.priority >= 100:
                    required.append(sc)
                else:
                    optional.append(sc)
        else:
            optional = list(scored_contexts)

        # 4. Sort optional by score (highest first)
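        # (assumes ScoredContext defines rich comparison by its composite
        # score, so reverse=True puts the strongest candidates first)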
        optional.sort(reverse=True)

        # 5. Greedy selection
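        # Greedy-by-score is a knapsack approximation: each context is taken
        # iff it still fits its type's budget, with no backtracking, so one
        # large high-scoring item can crowd out several smaller ones.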
        selected: list[ScoredContext] = []
        excluded: list[ScoredContext] = []
        total_tokens = 0

        # First, try to fit required contexts
        for sc in required:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                # Force-fit CRITICAL contexts if needed
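                # allocate(..., force=True) over-commits the type's budget
                # instead of dropping a CRITICAL context; the warning below
                # makes the overrun visible in logs.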
                budget.allocate(context_type, token_count, force=True)
                selected.append(sc)
                total_tokens += token_count
                logger.warning(
                    f"Force-fitted CRITICAL context: {sc.context.source} "
                    f"({token_count} tokens)"
                )

        # Then, greedily add optional contexts
        for sc in optional:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                excluded.append(sc)

        # Build stats
        stats = {
            "total_contexts": len(contexts),
            "required_count": len(required),
            "selected_count": len(selected),
            "excluded_count": len(excluded),
            "total_tokens": total_tokens,
            "by_type": self._count_by_type(selected),
        }

        return RankingResult(
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
            selection_stats=stats,
        )

    async def rank_simple(
        self,
        contexts: list[BaseContext],
        query: str,
        max_tokens: int,
        model: str | None = None,
        **kwargs: Any,
    ) -> list[BaseContext]:
        """
        Simple ranking without budget per type.

        Selects top contexts by score until max tokens reached.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            max_tokens: Maximum total tokens
            model: Model for token counting
            **kwargs: Additional scoring parameters

        Returns:
            Selected contexts (in score order)
        """
        if not contexts:
            return []

        # Ensure token counts
        await self._ensure_token_counts(contexts, model)

        # Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # Sort by score (highest first)
        scored_contexts.sort(reverse=True)

        # Greedy selection
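        # Items that don't fit are skipped rather than ending the scan, so a
        # smaller, lower-scored context can still use the remaining budget.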
        selected: list[BaseContext] = []
        total_tokens = 0

        for sc in scored_contexts:
            token_count = sc.context.token_count or 0
            if total_tokens + token_count <= max_tokens:
                selected.append(sc.context)
                total_tokens += token_count

        return selected

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """
        Ensure all contexts have token counts.

        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
        for context in contexts:
            if context.token_count is None:
                count = await self._calculator.count_tokens(
                    context.content, model
                )
                context.token_count = count

    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
    ) -> dict[str, dict[str, int]]:
        """Count selected contexts by type."""
        by_type: dict[str, dict[str, int]] = {}

        for sc in scored_contexts:
            type_name = sc.context.get_type().value
            if type_name not in by_type:
                by_type[type_name] = {"count": 0, "tokens": 0}
            by_type[type_name]["count"] += 1
            by_type[type_name]["tokens"] += sc.context.token_count or 0

        return by_type
    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
        max_per_source: int = 3,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.

        Prevents too many items from the same source.

        Args:
            scored_contexts: Already scored contexts
            max_per_source: Maximum items per source

        Returns:
            Reranked contexts
        """
        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []

        for sc in scored_contexts:
            source = sc.context.source
            current_count = source_counts.get(source, 0)

            if current_count < max_per_source:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
                deferred.append(sc)

        # Add deferred items at the end
        result.extend(deferred)
        return result