forked from cardosofelipe/fast-next-template
feat(context): implement context scoring and ranking (Phase 3)
Add comprehensive scoring system with three strategies: - RelevanceScorer: Semantic similarity with keyword fallback - RecencyScorer: Exponential decay with type-specific half-lives - PriorityScorer: Priority-based scoring with type bonuses Implement CompositeScorer combining all strategies with configurable weights (default: 50% relevance, 30% recency, 20% priority). Add ContextRanker for budget-aware context selection with: - Greedy selection algorithm respecting token budgets - CRITICAL priority contexts always included - Diversity reranking to prevent source dominance - Comprehensive selection statistics 68 tests covering all scoring and ranking functionality. Part of #61 - Context Management Engine 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -63,6 +63,22 @@ from .exceptions import (
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
# Prioritization
|
||||
from .prioritization import (
|
||||
ContextRanker,
|
||||
RankingResult,
|
||||
)
|
||||
|
||||
# Scoring
|
||||
from .scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
@@ -101,6 +117,16 @@ __all__ = [
|
||||
"InvalidContextError",
|
||||
"ScoringError",
|
||||
"TokenCountError",
|
||||
# Prioritization
|
||||
"ContextRanker",
|
||||
"RankingResult",
|
||||
# Scoring
|
||||
"BaseScorer",
|
||||
"CompositeScorer",
|
||||
"PriorityScorer",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
# Types - Base
|
||||
"AssembledContext",
|
||||
"BaseContext",
|
||||
|
||||
@@ -3,3 +3,10 @@ Context Prioritization Module.
|
||||
|
||||
Provides context ranking and selection.
|
||||
"""
|
||||
|
||||
from .ranker import ContextRanker, RankingResult
|
||||
|
||||
__all__ = [
|
||||
"ContextRanker",
|
||||
"RankingResult",
|
||||
]
|
||||
|
||||
288
backend/app/services/context/prioritization/ranker.py
Normal file
288
backend/app/services/context/prioritization/ranker.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Context Ranker for Context Management.
|
||||
|
||||
Ranks and selects contexts based on scores and budget constraints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankingResult:
|
||||
"""Result of context ranking and selection."""
|
||||
|
||||
selected: list[ScoredContext]
|
||||
excluded: list[ScoredContext]
|
||||
total_tokens: int
|
||||
selection_stats: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def selected_contexts(self) -> list[BaseContext]:
|
||||
"""Get just the context objects (not scored wrappers)."""
|
||||
return [s.context for s in self.selected]
|
||||
|
||||
|
||||
class ContextRanker:
|
||||
"""
|
||||
Ranks and selects contexts within budget constraints.
|
||||
|
||||
Uses greedy selection to maximize total score
|
||||
while respecting token budgets per context type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorer: CompositeScorer | None = None,
|
||||
calculator: TokenCalculator | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context ranker.
|
||||
|
||||
Args:
|
||||
scorer: Composite scorer for scoring contexts
|
||||
calculator: Token calculator for counting tokens
|
||||
"""
|
||||
self._scorer = scorer or CompositeScorer()
|
||||
self._calculator = calculator or TokenCalculator()
|
||||
|
||||
def set_scorer(self, scorer: CompositeScorer) -> None:
|
||||
"""Set the scorer."""
|
||||
self._scorer = scorer
|
||||
|
||||
def set_calculator(self, calculator: TokenCalculator) -> None:
|
||||
"""Set the token calculator."""
|
||||
self._calculator = calculator
|
||||
|
||||
async def rank(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
budget: TokenBudget,
|
||||
model: str | None = None,
|
||||
ensure_required: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> RankingResult:
|
||||
"""
|
||||
Rank and select contexts within budget.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
budget: Token budget constraints
|
||||
model: Model for token counting
|
||||
ensure_required: If True, always include CRITICAL priority contexts
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
RankingResult with selected and excluded contexts
|
||||
"""
|
||||
if not contexts:
|
||||
return RankingResult(
|
||||
selected=[],
|
||||
excluded=[],
|
||||
total_tokens=0,
|
||||
selection_stats={"total_contexts": 0},
|
||||
)
|
||||
|
||||
# 1. Ensure all contexts have token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# 2. Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# 3. Separate required (CRITICAL priority) from optional
|
||||
required: list[ScoredContext] = []
|
||||
optional: list[ScoredContext] = []
|
||||
|
||||
if ensure_required:
|
||||
for sc in scored_contexts:
|
||||
# CRITICAL priority (100) contexts are always included
|
||||
if sc.context.priority >= 100:
|
||||
required.append(sc)
|
||||
else:
|
||||
optional.append(sc)
|
||||
else:
|
||||
optional = list(scored_contexts)
|
||||
|
||||
# 4. Sort optional by score (highest first)
|
||||
optional.sort(reverse=True)
|
||||
|
||||
# 5. Greedy selection
|
||||
selected: list[ScoredContext] = []
|
||||
excluded: list[ScoredContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
# First, try to fit required contexts
|
||||
for sc in required:
|
||||
token_count = sc.context.token_count or 0
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
budget.allocate(context_type, token_count)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
# Force-fit CRITICAL contexts if needed
|
||||
budget.allocate(context_type, token_count, force=True)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
logger.warning(
|
||||
f"Force-fitted CRITICAL context: {sc.context.source} "
|
||||
f"({token_count} tokens)"
|
||||
)
|
||||
|
||||
# Then, greedily add optional contexts
|
||||
for sc in optional:
|
||||
token_count = sc.context.token_count or 0
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
budget.allocate(context_type, token_count)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
excluded.append(sc)
|
||||
|
||||
# Build stats
|
||||
stats = {
|
||||
"total_contexts": len(contexts),
|
||||
"required_count": len(required),
|
||||
"selected_count": len(selected),
|
||||
"excluded_count": len(excluded),
|
||||
"total_tokens": total_tokens,
|
||||
"by_type": self._count_by_type(selected),
|
||||
}
|
||||
|
||||
return RankingResult(
|
||||
selected=selected,
|
||||
excluded=excluded,
|
||||
total_tokens=total_tokens,
|
||||
selection_stats=stats,
|
||||
)
|
||||
|
||||
async def rank_simple(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseContext]:
|
||||
"""
|
||||
Simple ranking without budget per type.
|
||||
|
||||
Selects top contexts by score until max tokens reached.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
max_tokens: Maximum total tokens
|
||||
model: Model for token counting
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Selected contexts (in score order)
|
||||
"""
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
# Ensure token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored_contexts.sort(reverse=True)
|
||||
|
||||
# Greedy selection
|
||||
selected: list[BaseContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
for sc in scored_contexts:
|
||||
token_count = sc.context.token_count or 0
|
||||
if total_tokens + token_count <= max_tokens:
|
||||
selected.append(sc.context)
|
||||
total_tokens += token_count
|
||||
|
||||
return selected
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Ensure all contexts have token counts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to check
|
||||
model: Model for token counting
|
||||
"""
|
||||
for context in contexts:
|
||||
if context.token_count is None:
|
||||
count = await self._calculator.count_tokens(
|
||||
context.content, model
|
||||
)
|
||||
context.token_count = count
|
||||
|
||||
def _count_by_type(
|
||||
self, scored_contexts: list[ScoredContext]
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""Count selected contexts by type."""
|
||||
by_type: dict[str, dict[str, int]] = {}
|
||||
|
||||
for sc in scored_contexts:
|
||||
type_name = sc.context.get_type().value
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||
by_type[type_name]["count"] += 1
|
||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||
|
||||
return by_type
|
||||
|
||||
async def rerank_for_diversity(
|
||||
self,
|
||||
scored_contexts: list[ScoredContext],
|
||||
max_per_source: int = 3,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Rerank to ensure source diversity.
|
||||
|
||||
Prevents too many items from the same source.
|
||||
|
||||
Args:
|
||||
scored_contexts: Already scored contexts
|
||||
max_per_source: Maximum items per source
|
||||
|
||||
Returns:
|
||||
Reranked contexts
|
||||
"""
|
||||
source_counts: dict[str, int] = {}
|
||||
result: list[ScoredContext] = []
|
||||
deferred: list[ScoredContext] = []
|
||||
|
||||
for sc in scored_contexts:
|
||||
source = sc.context.source
|
||||
current_count = source_counts.get(source, 0)
|
||||
|
||||
if current_count < max_per_source:
|
||||
result.append(sc)
|
||||
source_counts[source] = current_count + 1
|
||||
else:
|
||||
deferred.append(sc)
|
||||
|
||||
# Add deferred items at the end
|
||||
result.extend(deferred)
|
||||
return result
|
||||
@@ -1,5 +1,21 @@
|
||||
"""
|
||||
Context Scoring Module.
|
||||
|
||||
Provides relevance, recency, and composite scoring.
|
||||
Provides scoring strategies for context prioritization.
|
||||
"""
|
||||
|
||||
from .base import BaseScorer, ScorerProtocol
|
||||
from .composite import CompositeScorer, ScoredContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
__all__ = [
|
||||
"BaseScorer",
|
||||
"CompositeScorer",
|
||||
"PriorityScorer",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScorerProtocol",
|
||||
]
|
||||
|
||||
99
backend/app/services/context/scoring/base.py
Normal file
99
backend/app/services/context/scoring/base.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Base Scorer Protocol and Types.
|
||||
|
||||
Defines the interface for context scoring implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from ..types import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScorerProtocol(Protocol):
|
||||
"""Protocol for context scorers."""
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score a context item.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BaseScorer(ABC):
|
||||
"""
|
||||
Abstract base class for context scorers.
|
||||
|
||||
Provides common functionality and interface for
|
||||
different scoring strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, weight: float = 1.0) -> None:
|
||||
"""
|
||||
Initialize scorer.
|
||||
|
||||
Args:
|
||||
weight: Weight for this scorer in composite scoring
|
||||
"""
|
||||
self._weight = weight
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
"""Get scorer weight."""
|
||||
return self._weight
|
||||
|
||||
@weight.setter
|
||||
def weight(self, value: float) -> None:
|
||||
"""Set scorer weight."""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
raise ValueError("Weight must be between 0.0 and 1.0")
|
||||
self._weight = value
|
||||
|
||||
@abstractmethod
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score a context item.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
...
|
||||
|
||||
def normalize_score(self, score: float) -> float:
|
||||
"""
|
||||
Normalize score to [0.0, 1.0] range.
|
||||
|
||||
Args:
|
||||
score: Raw score
|
||||
|
||||
Returns:
|
||||
Normalized score
|
||||
"""
|
||||
return max(0.0, min(1.0, score))
|
||||
276
backend/app/services/context/scoring/composite.py
Normal file
276
backend/app/services/context/scoring/composite.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Composite Scorer for Context Management.
|
||||
|
||||
Combines multiple scoring strategies with configurable weights.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .base import BaseScorer
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredContext:
|
||||
"""Context with computed scores."""
|
||||
|
||||
context: BaseContext
|
||||
composite_score: float
|
||||
relevance_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
priority_score: float = 0.0
|
||||
|
||||
def __lt__(self, other: "ScoredContext") -> bool:
|
||||
"""Enable sorting by composite score."""
|
||||
return self.composite_score < other.composite_score
|
||||
|
||||
def __gt__(self, other: "ScoredContext") -> bool:
|
||||
"""Enable sorting by composite score."""
|
||||
return self.composite_score > other.composite_score
|
||||
|
||||
|
||||
class CompositeScorer:
|
||||
"""
|
||||
Combines multiple scoring strategies.
|
||||
|
||||
Weights:
|
||||
- relevance: How well content matches the query
|
||||
- recency: How recent the content is
|
||||
- priority: Explicit priority assignments
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
relevance_weight: float | None = None,
|
||||
recency_weight: float | None = None,
|
||||
priority_weight: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize composite scorer.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP manager for semantic scoring
|
||||
settings: Context settings (uses default if None)
|
||||
relevance_weight: Override relevance weight
|
||||
recency_weight: Override recency weight
|
||||
priority_weight: Override priority weight
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
weights = self._settings.get_scoring_weights()
|
||||
|
||||
self._relevance_weight = (
|
||||
relevance_weight if relevance_weight is not None else weights["relevance"]
|
||||
)
|
||||
self._recency_weight = (
|
||||
recency_weight if recency_weight is not None else weights["recency"]
|
||||
)
|
||||
self._priority_weight = (
|
||||
priority_weight if priority_weight is not None else weights["priority"]
|
||||
)
|
||||
|
||||
# Initialize scorers
|
||||
self._relevance_scorer = RelevanceScorer(
|
||||
mcp_manager=mcp_manager,
|
||||
weight=self._relevance_weight,
|
||||
)
|
||||
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
|
||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._relevance_scorer.set_mcp_manager(mcp_manager)
|
||||
|
||||
@property
|
||||
def weights(self) -> dict[str, float]:
|
||||
"""Get current scoring weights."""
|
||||
return {
|
||||
"relevance": self._relevance_weight,
|
||||
"recency": self._recency_weight,
|
||||
"priority": self._priority_weight,
|
||||
}
|
||||
|
||||
def update_weights(
|
||||
self,
|
||||
relevance: float | None = None,
|
||||
recency: float | None = None,
|
||||
priority: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update scoring weights.
|
||||
|
||||
Args:
|
||||
relevance: New relevance weight
|
||||
recency: New recency weight
|
||||
priority: New priority weight
|
||||
"""
|
||||
if relevance is not None:
|
||||
self._relevance_weight = max(0.0, min(1.0, relevance))
|
||||
self._relevance_scorer.weight = self._relevance_weight
|
||||
|
||||
if recency is not None:
|
||||
self._recency_weight = max(0.0, min(1.0, recency))
|
||||
self._recency_scorer.weight = self._recency_weight
|
||||
|
||||
if priority is not None:
|
||||
self._priority_weight = max(0.0, min(1.0, priority))
|
||||
self._priority_scorer.weight = self._priority_weight
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score for a context.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Composite score between 0.0 and 1.0
|
||||
"""
|
||||
scored = await self.score_with_details(context, query, **kwargs)
|
||||
return scored.composite_score
|
||||
|
||||
async def score_with_details(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> ScoredContext:
|
||||
"""
|
||||
Compute composite score with individual scores.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
ScoredContext with all scores
|
||||
"""
|
||||
# Check if context already has a score
|
||||
if context._score is not None:
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=context._score,
|
||||
)
|
||||
|
||||
# Compute individual scores in parallel
|
||||
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||
|
||||
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||
relevance_task, recency_task, priority_task
|
||||
)
|
||||
|
||||
# Compute weighted composite
|
||||
total_weight = (
|
||||
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||
)
|
||||
|
||||
if total_weight > 0:
|
||||
composite = (
|
||||
relevance_score * self._relevance_weight
|
||||
+ recency_score * self._recency_weight
|
||||
+ priority_score * self._priority_weight
|
||||
) / total_weight
|
||||
else:
|
||||
composite = 0.0
|
||||
|
||||
# Cache the score on the context
|
||||
context._score = composite
|
||||
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=composite,
|
||||
relevance_score=relevance_score,
|
||||
recency_score=recency_score,
|
||||
priority_score=priority_score,
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
parallel: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query to score against
|
||||
parallel: Whether to score in parallel
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
List of ScoredContext (same order as input)
|
||||
"""
|
||||
if parallel:
|
||||
tasks = [
|
||||
self.score_with_details(ctx, query, **kwargs) for ctx in contexts
|
||||
]
|
||||
return await asyncio.gather(*tasks)
|
||||
else:
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
scored = await self.score_with_details(ctx, query, **kwargs)
|
||||
results.append(scored)
|
||||
return results
|
||||
|
||||
async def rank(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
min_score: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Score and rank contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
limit: Maximum number of results
|
||||
min_score: Minimum score threshold
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Sorted list of ScoredContext (highest first)
|
||||
"""
|
||||
# Score all contexts
|
||||
scored = await self.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Filter by minimum score
|
||||
if min_score > 0:
|
||||
scored = [s for s in scored if s.composite_score >= min_score]
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored.sort(reverse=True)
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
scored = scored[:limit]
|
||||
|
||||
return scored
|
||||
135
backend/app/services/context/scoring/priority.py
Normal file
135
backend/app/services/context/scoring/priority.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Priority Scorer for Context Management.
|
||||
|
||||
Scores context based on assigned priority levels.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseScorer
|
||||
from ..types import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class PriorityScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on priority levels.
|
||||
|
||||
Converts priority enum values to normalized scores.
|
||||
Also applies type-based priority bonuses.
|
||||
"""
|
||||
|
||||
# Default priority bonuses by context type
|
||||
DEFAULT_TYPE_BONUSES: dict[ContextType, float] = {
|
||||
ContextType.SYSTEM: 0.2, # System prompts get a boost
|
||||
ContextType.TASK: 0.15, # Current task is important
|
||||
ContextType.TOOL: 0.1, # Recent tool results matter
|
||||
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
|
||||
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: float = 1.0,
|
||||
type_bonuses: dict[ContextType, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize priority scorer.
|
||||
|
||||
Args:
|
||||
weight: Scorer weight for composite scoring
|
||||
type_bonuses: Optional context-type priority bonuses
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context based on priority.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query (not used for priority, kept for interface)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Priority score between 0.0 and 1.0
|
||||
"""
|
||||
# Get base priority score
|
||||
priority_value = context.priority
|
||||
base_score = self._priority_to_score(priority_value)
|
||||
|
||||
# Apply type bonus
|
||||
context_type = context.get_type()
|
||||
bonus = self._type_bonuses.get(context_type, 0.0)
|
||||
|
||||
return self.normalize_score(base_score + bonus)
|
||||
|
||||
def _priority_to_score(self, priority: int) -> float:
|
||||
"""
|
||||
Convert priority value to normalized score.
|
||||
|
||||
Priority values (from ContextPriority):
|
||||
- CRITICAL (100) -> 1.0
|
||||
- HIGH (80) -> 0.8
|
||||
- NORMAL (50) -> 0.5
|
||||
- LOW (20) -> 0.2
|
||||
- MINIMAL (0) -> 0.0
|
||||
|
||||
Args:
|
||||
priority: Priority value (0-100)
|
||||
|
||||
Returns:
|
||||
Normalized score (0.0-1.0)
|
||||
"""
|
||||
# Clamp to valid range
|
||||
clamped = max(0, min(100, priority))
|
||||
return clamped / 100.0
|
||||
|
||||
def get_type_bonus(self, context_type: ContextType) -> float:
|
||||
"""
|
||||
Get priority bonus for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type
|
||||
|
||||
Returns:
|
||||
Bonus value
|
||||
"""
|
||||
return self._type_bonuses.get(context_type, 0.0)
|
||||
|
||||
def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
|
||||
"""
|
||||
Set priority bonus for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type
|
||||
bonus: Bonus value (0.0-1.0)
|
||||
"""
|
||||
if not 0.0 <= bonus <= 1.0:
|
||||
raise ValueError("Bonus must be between 0.0 and 1.0")
|
||||
self._type_bonuses[context_type] = bonus
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query (not used)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
# Priority scoring is fast, no async needed
|
||||
return [await self.score(ctx, query, **kwargs) for ctx in contexts]
|
||||
141
backend/app/services/context/scoring/recency.py
Normal file
141
backend/app/services/context/scoring/recency.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Recency Scorer for Context Management.
|
||||
|
||||
Scores context based on how recent it is.
|
||||
More recent content gets higher scores.
|
||||
"""
|
||||
|
||||
import math
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseScorer
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
|
||||
class RecencyScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on recency.
|
||||
|
||||
Uses exponential decay to score content based on age.
|
||||
More recent content scores higher.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: float = 1.0,
|
||||
half_life_hours: float = 24.0,
|
||||
type_half_lives: dict[ContextType, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize recency scorer.
|
||||
|
||||
Args:
|
||||
weight: Scorer weight for composite scoring
|
||||
half_life_hours: Default hours until score decays to 0.5
|
||||
type_half_lives: Optional context-type-specific half lives
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._half_life_hours = half_life_hours
|
||||
self._type_half_lives = type_half_lives or {}
|
||||
|
||||
# Set sensible defaults for context types
|
||||
if ContextType.CONVERSATION not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour
|
||||
if ContextType.TOOL not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes
|
||||
if ContextType.KNOWLEDGE not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week
|
||||
if ContextType.SYSTEM not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days
|
||||
if ContextType.TASK not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.TASK] = 24.0 # 1 day
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context based on recency.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query (not used for recency, kept for interface)
|
||||
**kwargs: Additional parameters
|
||||
- reference_time: Time to measure recency from (default: now)
|
||||
|
||||
Returns:
|
||||
Recency score between 0.0 and 1.0
|
||||
"""
|
||||
reference_time = kwargs.get("reference_time")
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(UTC)
|
||||
elif reference_time.tzinfo is None:
|
||||
reference_time = reference_time.replace(tzinfo=UTC)
|
||||
|
||||
# Ensure context timestamp is timezone-aware
|
||||
context_time = context.timestamp
|
||||
if context_time.tzinfo is None:
|
||||
context_time = context_time.replace(tzinfo=UTC)
|
||||
|
||||
# Calculate age in hours
|
||||
age = reference_time - context_time
|
||||
age_hours = max(0, age.total_seconds() / 3600)
|
||||
|
||||
# Get half-life for this context type
|
||||
context_type = context.get_type()
|
||||
half_life = self._type_half_lives.get(context_type, self._half_life_hours)
|
||||
|
||||
# Exponential decay
|
||||
decay_factor = math.exp(-math.log(2) * age_hours / half_life)
|
||||
|
||||
return self.normalize_score(decay_factor)
|
||||
|
||||
def get_half_life(self, context_type: ContextType) -> float:
|
||||
"""
|
||||
Get half-life for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to get half-life for
|
||||
|
||||
Returns:
|
||||
Half-life in hours
|
||||
"""
|
||||
return self._type_half_lives.get(context_type, self._half_life_hours)
|
||||
|
||||
def set_half_life(self, context_type: ContextType, hours: float) -> None:
|
||||
"""
|
||||
Set half-life for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to set half-life for
|
||||
hours: Half-life in hours
|
||||
"""
|
||||
if hours <= 0:
|
||||
raise ValueError("Half-life must be positive")
|
||||
self._type_half_lives[context_type] = hours
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query (not used)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
scores = []
|
||||
for context in contexts:
|
||||
score = await self.score(context, query, **kwargs)
|
||||
scores.append(score)
|
||||
return scores
|
||||
188
backend/app/services/context/scoring/relevance.py
Normal file
188
backend/app/services/context/scoring/relevance.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Relevance Scorer for Context Management.
|
||||
|
||||
Scores context based on semantic similarity to the query.
|
||||
Uses Knowledge Base embeddings when available.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .base import BaseScorer
|
||||
from ..types import BaseContext, ContextType, KnowledgeContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelevanceScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on relevance to query.
|
||||
|
||||
Uses multiple strategies:
|
||||
1. Pre-computed scores (from RAG results)
|
||||
2. MCP-based semantic similarity (via Knowledge Base)
|
||||
3. Keyword matching fallback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
weight: float = 1.0,
|
||||
keyword_fallback_weight: float = 0.5,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize relevance scorer.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP manager for Knowledge Base calls
|
||||
weight: Scorer weight for composite scoring
|
||||
keyword_fallback_weight: Max score for keyword-based fallback
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._mcp = mcp_manager
|
||||
self._keyword_fallback_weight = keyword_fallback_weight
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._mcp = mcp_manager
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context relevance to query.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Relevance score between 0.0 and 1.0
|
||||
"""
|
||||
# 1. Check for pre-computed relevance score
|
||||
if isinstance(context, KnowledgeContext) and context.relevance_score is not None:
|
||||
return self.normalize_score(context.relevance_score)
|
||||
|
||||
# 2. Check metadata for score
|
||||
if "relevance_score" in context.metadata:
|
||||
return self.normalize_score(context.metadata["relevance_score"])
|
||||
|
||||
if "score" in context.metadata:
|
||||
return self.normalize_score(context.metadata["score"])
|
||||
|
||||
# 3. Try MCP-based semantic similarity
|
||||
if self._mcp is not None:
|
||||
try:
|
||||
score = await self._compute_semantic_similarity(context, query)
|
||||
if score is not None:
|
||||
return score
|
||||
except Exception as e:
|
||||
logger.debug(f"Semantic scoring failed, using fallback: {e}")
|
||||
|
||||
# 4. Fall back to keyword matching
|
||||
return self._compute_keyword_score(context, query)
|
||||
|
||||
async def _compute_semantic_similarity(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
) -> float | None:
|
||||
"""
|
||||
Compute semantic similarity using Knowledge Base embeddings.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to compare
|
||||
|
||||
Returns:
|
||||
Similarity score or None if unavailable
|
||||
"""
|
||||
try:
|
||||
# Use Knowledge Base's search capability to compute similarity
|
||||
result = await self._mcp.call_tool(
|
||||
server="knowledge-base",
|
||||
tool="compute_similarity",
|
||||
args={
|
||||
"text1": query,
|
||||
"text2": context.content[:2000], # Limit content length
|
||||
},
|
||||
)
|
||||
|
||||
if result.success and result.data:
|
||||
similarity = result.data.get("similarity")
|
||||
if similarity is not None:
|
||||
return self.normalize_score(float(similarity))
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Semantic similarity computation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _compute_keyword_score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
) -> float:
|
||||
"""
|
||||
Compute relevance score based on keyword matching.
|
||||
|
||||
Simple but fast fallback when semantic search is unavailable.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to match
|
||||
|
||||
Returns:
|
||||
Keyword-based relevance score
|
||||
"""
|
||||
if not query or not context.content:
|
||||
return 0.0
|
||||
|
||||
# Extract keywords from query
|
||||
query_lower = query.lower()
|
||||
content_lower = context.content.lower()
|
||||
|
||||
# Simple word tokenization
|
||||
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
|
||||
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
|
||||
|
||||
if not query_words:
|
||||
return 0.0
|
||||
|
||||
# Calculate overlap
|
||||
common_words = query_words & content_words
|
||||
overlap_ratio = len(common_words) / len(query_words)
|
||||
|
||||
# Apply fallback weight ceiling
|
||||
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
scores = []
|
||||
for context in contexts:
|
||||
score = await self.score(context, query, **kwargs)
|
||||
scores.append(score)
|
||||
return scores
|
||||
@@ -55,11 +55,9 @@ class TaskContext(BaseContext):
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
parent_task_id: str | None = field(default=None)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set high priority for task context."""
|
||||
# Task context defaults to high priority
|
||||
if self.priority == ContextPriority.NORMAL.value:
|
||||
self.priority = ContextPriority.HIGH.value
|
||||
# Note: TaskContext should typically have HIGH priority,
|
||||
# but we don't auto-promote to allow explicit priority setting.
|
||||
# Use TaskContext.create() for default HIGH priority behavior.
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return TASK context type."""
|
||||
|
||||
507
backend/tests/services/context/test_ranker.py
Normal file
507
backend/tests/services/context/test_ranker.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""Tests for context ranking module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.prioritization import ContextRanker, RankingResult
|
||||
from app.services.context.scoring import CompositeScorer, ScoredContext
|
||||
from app.services.context.types import (
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestRankingResult:
|
||||
"""Tests for RankingResult dataclass."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test RankingResult creation."""
|
||||
ctx = TaskContext(content="Test", source="task")
|
||||
scored = ScoredContext(context=ctx, composite_score=0.8)
|
||||
|
||||
result = RankingResult(
|
||||
selected=[scored],
|
||||
excluded=[],
|
||||
total_tokens=100,
|
||||
selection_stats={"total": 1},
|
||||
)
|
||||
|
||||
assert len(result.selected) == 1
|
||||
assert result.total_tokens == 100
|
||||
|
||||
def test_selected_contexts_property(self) -> None:
|
||||
"""Test selected_contexts property extracts contexts."""
|
||||
ctx1 = TaskContext(content="Test 1", source="task")
|
||||
ctx2 = TaskContext(content="Test 2", source="task")
|
||||
|
||||
scored1 = ScoredContext(context=ctx1, composite_score=0.8)
|
||||
scored2 = ScoredContext(context=ctx2, composite_score=0.6)
|
||||
|
||||
result = RankingResult(
|
||||
selected=[scored1, scored2],
|
||||
excluded=[],
|
||||
total_tokens=200,
|
||||
)
|
||||
|
||||
selected = result.selected_contexts
|
||||
assert len(selected) == 2
|
||||
assert ctx1 in selected
|
||||
assert ctx2 in selected
|
||||
|
||||
|
||||
class TestContextRanker:
|
||||
"""Tests for ContextRanker."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test ranker creation."""
|
||||
ranker = ContextRanker()
|
||||
assert ranker._scorer is not None
|
||||
assert ranker._calculator is not None
|
||||
|
||||
def test_creation_with_scorer(self) -> None:
|
||||
"""Test ranker creation with custom scorer."""
|
||||
scorer = CompositeScorer(relevance_weight=0.8)
|
||||
ranker = ContextRanker(scorer=scorer)
|
||||
assert ranker._scorer is scorer
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_empty_contexts(self) -> None:
|
||||
"""Test ranking empty context list."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
result = await ranker.rank([], "query", budget)
|
||||
|
||||
assert len(result.selected) == 0
|
||||
assert len(result.excluded) == 0
|
||||
assert result.total_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_single_context_fits(self) -> None:
|
||||
"""Test ranking single context that fits budget."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Short content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
)
|
||||
|
||||
result = await ranker.rank([context], "query", budget)
|
||||
|
||||
assert len(result.selected) == 1
|
||||
assert len(result.excluded) == 0
|
||||
assert result.selected[0].context is context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_respects_budget(self) -> None:
|
||||
"""Test that ranking respects token budget."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Create a very small budget
|
||||
budget = TokenBudget(
|
||||
total=100,
|
||||
knowledge=50, # Only 50 tokens for knowledge
|
||||
)
|
||||
|
||||
# Create contexts that exceed budget
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="A" * 200, # ~50 tokens
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="B" * 200, # ~50 tokens
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="C" * 200, # ~50 tokens
|
||||
source="docs",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
# Not all should fit
|
||||
assert len(result.selected) < len(contexts)
|
||||
assert len(result.excluded) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_selects_highest_scores(self) -> None:
|
||||
"""Test that ranking selects highest scored contexts."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(1000)
|
||||
|
||||
# Small budget for knowledge
|
||||
budget.knowledge = 100
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low score",
|
||||
source="docs",
|
||||
relevance_score=0.2,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High score",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium score",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
# Should have selected some
|
||||
if result.selected:
|
||||
# The highest scored should be selected first
|
||||
scores = [s.composite_score for s in result.selected]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_critical_priority_always_included(self) -> None:
|
||||
"""Test that CRITICAL priority contexts are always included."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Very small budget
|
||||
budget = TokenBudget(
|
||||
total=100,
|
||||
system=10, # Very small
|
||||
knowledge=10,
|
||||
)
|
||||
|
||||
contexts = [
|
||||
SystemContext(
|
||||
content="Critical system prompt that must be included",
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Optional knowledge",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget, ensure_required=True)
|
||||
|
||||
# Critical context should be in selected
|
||||
selected_priorities = [s.context.priority for s in result.selected]
|
||||
assert ContextPriority.CRITICAL.value in selected_priorities
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_without_ensure_required(self) -> None:
|
||||
"""Test ranking without ensuring required contexts."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
budget = TokenBudget(
|
||||
total=100,
|
||||
system=50,
|
||||
knowledge=50,
|
||||
)
|
||||
|
||||
contexts = [
|
||||
SystemContext(
|
||||
content="A" * 500, # Large content
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Short",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(
|
||||
contexts, "query", budget, ensure_required=False
|
||||
)
|
||||
|
||||
# Without ensure_required, CRITICAL contexts can be excluded
|
||||
# if budget doesn't allow
|
||||
assert len(result.selected) + len(result.excluded) == len(contexts)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_selection_stats(self) -> None:
|
||||
"""Test that ranking provides useful statistics."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Knowledge 1", source="docs", relevance_score=0.8
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Knowledge 2", source="docs", relevance_score=0.6
|
||||
),
|
||||
TaskContext(content="Task", source="task"),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
stats = result.selection_stats
|
||||
assert "total_contexts" in stats
|
||||
assert "selected_count" in stats
|
||||
assert "excluded_count" in stats
|
||||
assert "total_tokens" in stats
|
||||
assert "by_type" in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_simple(self) -> None:
|
||||
"""Test simple ranking without budget per type."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="A",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="B",
|
||||
source="docs",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="C",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank_simple(contexts, "query", max_tokens=1000)
|
||||
|
||||
# Should return contexts sorted by score
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_simple_respects_max_tokens(self) -> None:
|
||||
"""Test that simple ranking respects max tokens."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Create contexts with known token counts
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="A" * 100, # ~25 tokens
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="B" * 100,
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="C" * 100,
|
||||
source="docs",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
]
|
||||
|
||||
# Very small limit
|
||||
result = await ranker.rank_simple(contexts, "query", max_tokens=30)
|
||||
|
||||
# Should only fit a limited number
|
||||
assert len(result) <= len(contexts)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_simple_empty(self) -> None:
|
||||
"""Test simple ranking with empty list."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
result = await ranker.rank_simple([], "query", max_tokens=1000)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_for_diversity(self) -> None:
|
||||
"""Test diversity reranking."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Create scored contexts from same source
|
||||
contexts = [
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content=f"Content {i}",
|
||||
source="same-source",
|
||||
relevance_score=0.9 - i * 0.1,
|
||||
),
|
||||
composite_score=0.9 - i * 0.1,
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Limit to 2 per source
|
||||
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
|
||||
|
||||
assert len(result) == 5
|
||||
# First 2 should be from same source, rest deferred
|
||||
first_two_sources = [r.context.source for r in result[:2]]
|
||||
assert all(s == "same-source" for s in first_two_sources)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_for_diversity_multiple_sources(self) -> None:
|
||||
"""Test diversity reranking with multiple sources."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
contexts = [
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source A - 1",
|
||||
source="source-a",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
composite_score=0.9,
|
||||
),
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source A - 2",
|
||||
source="source-a",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
composite_score=0.8,
|
||||
),
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source B - 1",
|
||||
source="source-b",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
composite_score=0.7,
|
||||
),
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source A - 3",
|
||||
source="source-a",
|
||||
relevance_score=0.6,
|
||||
),
|
||||
composite_score=0.6,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
|
||||
|
||||
# Should not have more than 2 from source-a in first 3
|
||||
source_a_in_first_3 = sum(
|
||||
1 for r in result[:3] if r.context.source == "source-a"
|
||||
)
|
||||
assert source_a_in_first_3 <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counts_set(self) -> None:
|
||||
"""Test that token counts are set during ranking."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
)
|
||||
|
||||
# Token count should be None initially
|
||||
assert context.token_count is None
|
||||
|
||||
await ranker.rank([context], "query", budget)
|
||||
|
||||
# Token count should be set after ranking
|
||||
assert context.token_count is not None
|
||||
assert context.token_count > 0
|
||||
|
||||
|
||||
class TestContextRankerIntegration:
|
||||
"""Integration tests for context ranking."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_ranking_workflow(self) -> None:
|
||||
"""Test complete ranking workflow."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
# Create diverse context types
|
||||
contexts = [
|
||||
SystemContext(
|
||||
content="You are a helpful assistant.",
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
),
|
||||
TaskContext(
|
||||
content="Help the user with their coding question.",
|
||||
source="task",
|
||||
priority=ContextPriority.HIGH.value,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Python is a programming language.",
|
||||
source="docs/python.md",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Java is also a programming language.",
|
||||
source="docs/java.md",
|
||||
relevance_score=0.4,
|
||||
),
|
||||
ConversationContext(
|
||||
content="Hello, can you help me?",
|
||||
source="chat",
|
||||
role=MessageRole.USER,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "Python help", budget)
|
||||
|
||||
# System (CRITICAL) should be included
|
||||
selected_types = [s.context.get_type() for s in result.selected]
|
||||
assert ContextType.SYSTEM in selected_types
|
||||
|
||||
# Stats should be populated
|
||||
assert result.selection_stats["total_contexts"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ranking_preserves_context_order_by_score(self) -> None:
|
||||
"""Test that ranking orders by score correctly."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(100000)
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low",
|
||||
source="docs",
|
||||
relevance_score=0.2,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
# Verify ordering is by score
|
||||
scores = [s.composite_score for s in result.selected]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
712
backend/tests/services/context/test_scoring.py
Normal file
712
backend/tests/services/context/test_scoring.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""Tests for context scoring module."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
from app.services.context.types import (
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestRelevanceScorer:
|
||||
"""Tests for RelevanceScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation."""
|
||||
scorer = RelevanceScorer()
|
||||
assert scorer.weight == 1.0
|
||||
|
||||
def test_creation_with_weight(self) -> None:
|
||||
"""Test scorer creation with custom weight."""
|
||||
scorer = RelevanceScorer(weight=0.5)
|
||||
assert scorer.weight == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_precomputed_relevance(self) -> None:
|
||||
"""Test scoring with pre-computed relevance score."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
# KnowledgeContext with pre-computed score
|
||||
context = KnowledgeContext(
|
||||
content="Test content about Python",
|
||||
source="docs/python.md",
|
||||
relevance_score=0.85,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "Python programming")
|
||||
assert score == 0.85
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_metadata_score(self) -> None:
|
||||
"""Test scoring with metadata-provided score."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
context = SystemContext(
|
||||
content="System prompt",
|
||||
source="system",
|
||||
metadata={"relevance_score": 0.9},
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "anything")
|
||||
assert score == 0.9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_fallback_to_keyword_matching(self) -> None:
|
||||
"""Test fallback to keyword matching when no score available."""
|
||||
scorer = RelevanceScorer(keyword_fallback_weight=0.5)
|
||||
|
||||
context = TaskContext(
|
||||
content="Implement authentication with JWT tokens",
|
||||
source="task",
|
||||
)
|
||||
|
||||
# Query has matching keywords
|
||||
score = await scorer.score(context, "JWT authentication")
|
||||
assert score > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_matching_no_overlap(self) -> None:
|
||||
"""Test keyword matching with no query overlap."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
context = TaskContext(
|
||||
content="Implement database migration",
|
||||
source="task",
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "xyz abc 123")
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_matching_full_overlap(self) -> None:
|
||||
"""Test keyword matching with high overlap."""
|
||||
scorer = RelevanceScorer(keyword_fallback_weight=1.0)
|
||||
|
||||
context = TaskContext(
|
||||
content="python programming language",
|
||||
source="task",
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "python programming")
|
||||
# Should have high score due to keyword overlap
|
||||
assert score > 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_mcp_success(self) -> None:
|
||||
"""Test scoring with MCP semantic similarity."""
|
||||
mock_mcp = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.success = True
|
||||
mock_result.data = {"similarity": 0.75}
|
||||
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
scorer = RelevanceScorer(mcp_manager=mock_mcp)
|
||||
|
||||
context = TaskContext(
|
||||
content="Test content",
|
||||
source="task",
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "test query")
|
||||
assert score == 0.75
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_mcp_failure_fallback(self) -> None:
|
||||
"""Test fallback when MCP fails."""
|
||||
mock_mcp = MagicMock()
|
||||
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
scorer = RelevanceScorer(mcp_manager=mock_mcp, keyword_fallback_weight=0.5)
|
||||
|
||||
context = TaskContext(
|
||||
content="Python programming code",
|
||||
source="task",
|
||||
)
|
||||
|
||||
# Should fall back to keyword matching
|
||||
score = await scorer.score(context, "Python code")
|
||||
assert score > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
"""Test batch scoring."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Python", source="1", relevance_score=0.8
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Java", source="2", relevance_score=0.6
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Go", source="3", relevance_score=0.9
|
||||
),
|
||||
]
|
||||
|
||||
scores = await scorer.score_batch(contexts, "test")
|
||||
assert len(scores) == 3
|
||||
assert scores[0] == 0.8
|
||||
assert scores[1] == 0.6
|
||||
assert scores[2] == 0.9
|
||||
|
||||
def test_set_mcp_manager(self) -> None:
|
||||
"""Test setting MCP manager."""
|
||||
scorer = RelevanceScorer()
|
||||
assert scorer._mcp is None
|
||||
|
||||
mock_mcp = MagicMock()
|
||||
scorer.set_mcp_manager(mock_mcp)
|
||||
assert scorer._mcp is mock_mcp
|
||||
|
||||
|
||||
class TestRecencyScorer:
|
||||
"""Tests for RecencyScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation."""
|
||||
scorer = RecencyScorer()
|
||||
assert scorer.weight == 1.0
|
||||
assert scorer._half_life_hours == 24.0
|
||||
|
||||
def test_creation_with_custom_half_life(self) -> None:
|
||||
"""Test scorer creation with custom half-life."""
|
||||
scorer = RecencyScorer(half_life_hours=12.0)
|
||||
assert scorer._half_life_hours == 12.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_recent_context(self) -> None:
|
||||
"""Test scoring a very recent context."""
|
||||
scorer = RecencyScorer(half_life_hours=24.0)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
context = TaskContext(
|
||||
content="Recent task",
|
||||
source="task",
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query", reference_time=now)
|
||||
# Very recent should have score near 1.0
|
||||
assert score > 0.99
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_at_half_life(self) -> None:
|
||||
"""Test scoring at exactly half-life age."""
|
||||
scorer = RecencyScorer(half_life_hours=24.0)
|
||||
now = datetime.now(UTC)
|
||||
half_life_ago = now - timedelta(hours=24)
|
||||
|
||||
context = TaskContext(
|
||||
content="Day old task",
|
||||
source="task",
|
||||
timestamp=half_life_ago,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query", reference_time=now)
|
||||
# At half-life, score should be ~0.5
|
||||
assert 0.49 <= score <= 0.51
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_old_context(self) -> None:
|
||||
"""Test scoring a very old context."""
|
||||
scorer = RecencyScorer(half_life_hours=24.0)
|
||||
now = datetime.now(UTC)
|
||||
week_ago = now - timedelta(days=7)
|
||||
|
||||
context = TaskContext(
|
||||
content="Week old task",
|
||||
source="task",
|
||||
timestamp=week_ago,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query", reference_time=now)
|
||||
# 7 days with 24h half-life = very low score
|
||||
assert score < 0.01
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_specific_half_lives(self) -> None:
|
||||
"""Test that different context types have different half-lives."""
|
||||
scorer = RecencyScorer()
|
||||
now = datetime.now(UTC)
|
||||
one_hour_ago = now - timedelta(hours=1)
|
||||
|
||||
# Conversation has 1 hour half-life by default
|
||||
conv_context = ConversationContext(
|
||||
content="Hello",
|
||||
source="chat",
|
||||
role=MessageRole.USER,
|
||||
timestamp=one_hour_ago,
|
||||
)
|
||||
|
||||
# Knowledge has 168 hour (1 week) half-life by default
|
||||
knowledge_context = KnowledgeContext(
|
||||
content="Documentation",
|
||||
source="docs",
|
||||
timestamp=one_hour_ago,
|
||||
)
|
||||
|
||||
conv_score = await scorer.score(conv_context, "query", reference_time=now)
|
||||
knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
|
||||
|
||||
# Conversation should decay much faster
|
||||
assert conv_score < knowledge_score
|
||||
|
||||
def test_get_half_life(self) -> None:
|
||||
"""Test getting half-life for context type."""
|
||||
scorer = RecencyScorer()
|
||||
|
||||
assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0
|
||||
assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0
|
||||
assert scorer.get_half_life(ContextType.SYSTEM) == 720.0
|
||||
|
||||
def test_set_half_life(self) -> None:
|
||||
"""Test setting custom half-life."""
|
||||
scorer = RecencyScorer()
|
||||
|
||||
scorer.set_half_life(ContextType.TASK, 48.0)
|
||||
assert scorer.get_half_life(ContextType.TASK) == 48.0
|
||||
|
||||
def test_set_half_life_invalid(self) -> None:
|
||||
"""Test setting invalid half-life."""
|
||||
scorer = RecencyScorer()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_half_life(ContextType.TASK, 0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_half_life(ContextType.TASK, -1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
"""Test batch scoring."""
|
||||
scorer = RecencyScorer()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
contexts = [
|
||||
TaskContext(content="1", source="t", timestamp=now),
|
||||
TaskContext(
|
||||
content="2", source="t", timestamp=now - timedelta(hours=24)
|
||||
),
|
||||
TaskContext(
|
||||
content="3", source="t", timestamp=now - timedelta(hours=48)
|
||||
),
|
||||
]
|
||||
|
||||
scores = await scorer.score_batch(contexts, "query", reference_time=now)
|
||||
assert len(scores) == 3
|
||||
# Scores should be in descending order (more recent = higher)
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
|
||||
class TestPriorityScorer:
|
||||
"""Tests for PriorityScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation."""
|
||||
scorer = PriorityScorer()
|
||||
assert scorer.weight == 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_critical_priority(self) -> None:
|
||||
"""Test scoring CRITICAL priority context."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
context = SystemContext(
|
||||
content="Critical system prompt",
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query")
|
||||
# CRITICAL (100) + type bonus should be > 1.0, normalized to 1.0
|
||||
assert score == 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_normal_priority(self) -> None:
|
||||
"""Test scoring NORMAL priority context."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
context = TaskContext(
|
||||
content="Normal task",
|
||||
source="task",
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query")
|
||||
# NORMAL (50) = 0.5, plus TASK bonus (0.15) = 0.65
|
||||
assert 0.6 <= score <= 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_low_priority(self) -> None:
|
||||
"""Test scoring LOW priority context."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Low priority knowledge",
|
||||
source="docs",
|
||||
priority=ContextPriority.LOW.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query")
|
||||
# LOW (20) = 0.2, no bonus for KNOWLEDGE
|
||||
assert 0.15 <= score <= 0.25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_bonuses(self) -> None:
|
||||
"""Test type-specific priority bonuses."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
# All with same base priority
|
||||
system_ctx = SystemContext(
|
||||
content="System",
|
||||
source="system",
|
||||
priority=50,
|
||||
)
|
||||
task_ctx = TaskContext(
|
||||
content="Task",
|
||||
source="task",
|
||||
priority=50,
|
||||
)
|
||||
knowledge_ctx = KnowledgeContext(
|
||||
content="Knowledge",
|
||||
source="docs",
|
||||
priority=50,
|
||||
)
|
||||
|
||||
system_score = await scorer.score(system_ctx, "query")
|
||||
task_score = await scorer.score(task_ctx, "query")
|
||||
knowledge_score = await scorer.score(knowledge_ctx, "query")
|
||||
|
||||
# System has highest bonus (0.2), task next (0.15), knowledge has none
|
||||
assert system_score > task_score > knowledge_score
|
||||
|
||||
def test_get_type_bonus(self) -> None:
|
||||
"""Test getting type bonus."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2
|
||||
assert scorer.get_type_bonus(ContextType.TASK) == 0.15
|
||||
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0
|
||||
|
||||
def test_set_type_bonus(self) -> None:
|
||||
"""Test setting custom type bonus."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1)
|
||||
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1
|
||||
|
||||
def test_set_type_bonus_invalid(self) -> None:
|
||||
"""Test setting invalid type bonus."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_type_bonus(ContextType.KNOWLEDGE, 1.5)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_type_bonus(ContextType.KNOWLEDGE, -0.1)
|
||||
|
||||
|
||||
class TestCompositeScorer:
|
||||
"""Tests for CompositeScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation with default weights."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
weights = scorer.weights
|
||||
assert weights["relevance"] == 0.5
|
||||
assert weights["recency"] == 0.3
|
||||
assert weights["priority"] == 0.2
|
||||
|
||||
def test_creation_with_custom_weights(self) -> None:
|
||||
"""Test scorer creation with custom weights."""
|
||||
scorer = CompositeScorer(
|
||||
relevance_weight=0.6,
|
||||
recency_weight=0.2,
|
||||
priority_weight=0.2,
|
||||
)
|
||||
|
||||
weights = scorer.weights
|
||||
assert weights["relevance"] == 0.6
|
||||
assert weights["recency"] == 0.2
|
||||
assert weights["priority"] == 0.2
|
||||
|
||||
def test_update_weights(self) -> None:
|
||||
"""Test updating weights."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1)
|
||||
|
||||
weights = scorer.weights
|
||||
assert weights["relevance"] == 0.7
|
||||
assert weights["recency"] == 0.2
|
||||
assert weights["priority"] == 0.1
|
||||
|
||||
def test_update_weights_partial(self) -> None:
|
||||
"""Test partially updating weights."""
|
||||
scorer = CompositeScorer()
|
||||
original_recency = scorer.weights["recency"]
|
||||
|
||||
scorer.update_weights(relevance=0.8)
|
||||
|
||||
assert scorer.weights["relevance"] == 0.8
|
||||
assert scorer.weights["recency"] == original_recency
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_basic(self) -> None:
|
||||
"""Test basic composite scoring."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
timestamp=datetime.now(UTC),
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "test query")
|
||||
assert 0.0 <= score <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_details(self) -> None:
|
||||
"""Test scoring with detailed breakdown."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
timestamp=datetime.now(UTC),
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
|
||||
scored = await scorer.score_with_details(context, "test query")
|
||||
|
||||
assert isinstance(scored, ScoredContext)
|
||||
assert scored.context is context
|
||||
assert 0.0 <= scored.composite_score <= 1.0
|
||||
assert scored.relevance_score == 0.8
|
||||
assert scored.recency_score > 0.9 # Very recent
|
||||
assert scored.priority_score > 0.5 # HIGH priority
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_cached_on_context(self) -> None:
|
||||
"""Test that score is cached on the context."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
)
|
||||
|
||||
# First scoring
|
||||
await scorer.score(context, "query")
|
||||
assert context._score is not None
|
||||
|
||||
# Second scoring should use cached value
|
||||
context._score = 0.999 # Set to a known value
|
||||
score2 = await scorer.score(context, "query")
|
||||
assert score2 == 0.999
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
"""Test batch scoring."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="High relevance",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Low relevance",
|
||||
source="docs",
|
||||
relevance_score=0.2,
|
||||
),
|
||||
]
|
||||
|
||||
scored = await scorer.score_batch(contexts, "query")
|
||||
assert len(scored) == 2
|
||||
assert scored[0].relevance_score > scored[1].relevance_score
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank(self) -> None:
|
||||
"""Test ranking contexts."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low", source="docs", relevance_score=0.2
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High", source="docs", relevance_score=0.9
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium", source="docs", relevance_score=0.5
|
||||
),
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query")
|
||||
|
||||
# Should be sorted by score (highest first)
|
||||
assert len(ranked) == 3
|
||||
assert ranked[0].relevance_score == 0.9
|
||||
assert ranked[1].relevance_score == 0.5
|
||||
assert ranked[2].relevance_score == 0.2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_with_limit(self) -> None:
|
||||
"""Test ranking with limit."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content=str(i), source="docs", relevance_score=i / 10
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query", limit=3)
|
||||
assert len(ranked) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_with_min_score(self) -> None:
|
||||
"""Test ranking with minimum score threshold."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low", source="docs", relevance_score=0.1
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High", source="docs", relevance_score=0.9
|
||||
),
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query", min_score=0.5)
|
||||
|
||||
# Only the high relevance context should pass the threshold
|
||||
assert len(ranked) <= 2 # Could be 1 if min_score filters
|
||||
|
||||
def test_set_mcp_manager(self) -> None:
|
||||
"""Test setting MCP manager."""
|
||||
scorer = CompositeScorer()
|
||||
mock_mcp = MagicMock()
|
||||
|
||||
scorer.set_mcp_manager(mock_mcp)
|
||||
assert scorer._relevance_scorer._mcp is mock_mcp
|
||||
|
||||
|
||||
class TestScoredContext:
|
||||
"""Tests for ScoredContext dataclass."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test ScoredContext creation."""
|
||||
context = TaskContext(content="Test", source="task")
|
||||
scored = ScoredContext(
|
||||
context=context,
|
||||
composite_score=0.75,
|
||||
relevance_score=0.8,
|
||||
recency_score=0.7,
|
||||
priority_score=0.5,
|
||||
)
|
||||
|
||||
assert scored.context is context
|
||||
assert scored.composite_score == 0.75
|
||||
|
||||
def test_comparison_operators(self) -> None:
|
||||
"""Test comparison operators for sorting."""
|
||||
ctx1 = TaskContext(content="1", source="task")
|
||||
ctx2 = TaskContext(content="2", source="task")
|
||||
|
||||
scored1 = ScoredContext(context=ctx1, composite_score=0.5)
|
||||
scored2 = ScoredContext(context=ctx2, composite_score=0.8)
|
||||
|
||||
assert scored1 < scored2
|
||||
assert scored2 > scored1
|
||||
|
||||
def test_sorting(self) -> None:
|
||||
"""Test sorting scored contexts."""
|
||||
contexts = [
|
||||
ScoredContext(
|
||||
context=TaskContext(content="Low", source="task"),
|
||||
composite_score=0.3,
|
||||
),
|
||||
ScoredContext(
|
||||
context=TaskContext(content="High", source="task"),
|
||||
composite_score=0.9,
|
||||
),
|
||||
ScoredContext(
|
||||
context=TaskContext(content="Medium", source="task"),
|
||||
composite_score=0.6,
|
||||
),
|
||||
]
|
||||
|
||||
sorted_contexts = sorted(contexts, reverse=True)
|
||||
|
||||
assert sorted_contexts[0].composite_score == 0.9
|
||||
assert sorted_contexts[1].composite_score == 0.6
|
||||
assert sorted_contexts[2].composite_score == 0.3
|
||||
|
||||
|
||||
class TestBaseScorer:
|
||||
"""Tests for BaseScorer abstract class."""
|
||||
|
||||
def test_weight_property(self) -> None:
|
||||
"""Test weight property."""
|
||||
# Use a concrete implementation
|
||||
scorer = RelevanceScorer(weight=0.7)
|
||||
assert scorer.weight == 0.7
|
||||
|
||||
def test_weight_setter_valid(self) -> None:
|
||||
"""Test weight setter with valid values."""
|
||||
scorer = RelevanceScorer()
|
||||
scorer.weight = 0.5
|
||||
assert scorer.weight == 0.5
|
||||
|
||||
def test_weight_setter_invalid(self) -> None:
|
||||
"""Test weight setter with invalid values."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.weight = -0.1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.weight = 1.5
|
||||
|
||||
def test_normalize_score(self) -> None:
|
||||
"""Test score normalization."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
# Normal range
|
||||
assert scorer.normalize_score(0.5) == 0.5
|
||||
|
||||
# Below 0
|
||||
assert scorer.normalize_score(-0.5) == 0.0
|
||||
|
||||
# Above 1
|
||||
assert scorer.normalize_score(1.5) == 1.0
|
||||
|
||||
# Boundaries
|
||||
assert scorer.normalize_score(0.0) == 0.0
|
||||
assert scorer.normalize_score(1.0) == 1.0
|
||||
@@ -286,9 +286,20 @@ class TestTaskContext:
|
||||
assert ctx.title == "Login Feature"
|
||||
assert ctx.get_type() == ContextType.TASK
|
||||
|
||||
def test_default_high_priority(self) -> None:
|
||||
"""Test that task context defaults to high priority."""
|
||||
def test_default_normal_priority(self) -> None:
|
||||
"""Test that task context uses NORMAL priority from base class."""
|
||||
ctx = TaskContext(content="Test", source="test")
|
||||
# TaskContext inherits NORMAL priority from BaseContext
|
||||
# Use TaskContext.create() for default HIGH priority behavior
|
||||
assert ctx.priority == ContextPriority.NORMAL.value
|
||||
|
||||
def test_explicit_high_priority(self) -> None:
|
||||
"""Test setting explicit HIGH priority."""
|
||||
ctx = TaskContext(
|
||||
content="Test",
|
||||
source="test",
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
assert ctx.priority == ContextPriority.HIGH.value
|
||||
|
||||
def test_create_factory(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user