forked from cardosofelipe/fast-next-template
feat(context): implement context scoring and ranking (Phase 3)
Add comprehensive scoring system with three strategies:
- RelevanceScorer: Semantic similarity with keyword fallback
- RecencyScorer: Exponential decay with type-specific half-lives
- PriorityScorer: Priority-based scoring with type bonuses

Implement CompositeScorer combining all strategies with configurable weights
(default: 50% relevance, 30% recency, 20% priority).

Add ContextRanker for budget-aware context selection with:
- Greedy selection algorithm respecting token budgets
- CRITICAL priority contexts always included
- Diversity reranking to prevent source dominance
- Comprehensive selection statistics

68 tests covering all scoring and ranking functionality.

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
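For orientation, the composite score described above reduces to a weighted average of the three per-strategy scores. A minimal standalone sketch with the default 50/30/20 weights (the per-strategy inputs are made-up illustration values; the real computation lives in CompositeScorer.score_with_details further down this diff):

# Illustration only, not part of the commit: the weighted combination
# CompositeScorer uses, with the default weights from the commit message.
relevance, recency, priority = 0.9, 0.4, 0.5      # example per-strategy scores
w_relevance, w_recency, w_priority = 0.5, 0.3, 0.2

total_weight = w_relevance + w_recency + w_priority
composite = (
    relevance * w_relevance
    + recency * w_recency
    + priority * w_priority
) / total_weight

print(round(composite, 2))  # 0.67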
@@ -63,6 +63,22 @@ from .exceptions import (
     TokenCountError,
 )
 
+# Prioritization
+from .prioritization import (
+    ContextRanker,
+    RankingResult,
+)
+
+# Scoring
+from .scoring import (
+    BaseScorer,
+    CompositeScorer,
+    PriorityScorer,
+    RecencyScorer,
+    RelevanceScorer,
+    ScoredContext,
+)
+
 # Types
 from .types import (
     AssembledContext,
@@ -101,6 +117,16 @@ __all__ = [
     "InvalidContextError",
     "ScoringError",
     "TokenCountError",
+    # Prioritization
+    "ContextRanker",
+    "RankingResult",
+    # Scoring
+    "BaseScorer",
+    "CompositeScorer",
+    "PriorityScorer",
+    "RecencyScorer",
+    "RelevanceScorer",
+    "ScoredContext",
     # Types - Base
     "AssembledContext",
     "BaseContext",
@@ -3,3 +3,10 @@ Context Prioritization Module.
 
 Provides context ranking and selection.
 """
+
+from .ranker import ContextRanker, RankingResult
+
+__all__ = [
+    "ContextRanker",
+    "RankingResult",
+]
backend/app/services/context/prioritization/ranker.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""
Context Ranker for Context Management.

Ranks and selects contexts based on scores and budget constraints.
"""

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from ..budget import TokenBudget, TokenCalculator
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextType

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)


@dataclass
class RankingResult:
    """Result of context ranking and selection."""

    selected: list[ScoredContext]
    excluded: list[ScoredContext]
    total_tokens: int
    selection_stats: dict[str, Any] = field(default_factory=dict)

    @property
    def selected_contexts(self) -> list[BaseContext]:
        """Get just the context objects (not scored wrappers)."""
        return [s.context for s in self.selected]


class ContextRanker:
    """
    Ranks and selects contexts within budget constraints.

    Uses greedy selection to maximize total score
    while respecting token budgets per context type.
    """

    def __init__(
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
    ) -> None:
        """
        Initialize context ranker.

        Args:
            scorer: Composite scorer for scoring contexts
            calculator: Token calculator for counting tokens
        """
        self._scorer = scorer or CompositeScorer()
        self._calculator = calculator or TokenCalculator()

    def set_scorer(self, scorer: CompositeScorer) -> None:
        """Set the scorer."""
        self._scorer = scorer

    def set_calculator(self, calculator: TokenCalculator) -> None:
        """Set the token calculator."""
        self._calculator = calculator

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        budget: TokenBudget,
        model: str | None = None,
        ensure_required: bool = True,
        **kwargs: Any,
    ) -> RankingResult:
        """
        Rank and select contexts within budget.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            budget: Token budget constraints
            model: Model for token counting
            ensure_required: If True, always include CRITICAL priority contexts
            **kwargs: Additional scoring parameters

        Returns:
            RankingResult with selected and excluded contexts
        """
        if not contexts:
            return RankingResult(
                selected=[],
                excluded=[],
                total_tokens=0,
                selection_stats={"total_contexts": 0},
            )

        # 1. Ensure all contexts have token counts
        await self._ensure_token_counts(contexts, model)

        # 2. Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # 3. Separate required (CRITICAL priority) from optional
        required: list[ScoredContext] = []
        optional: list[ScoredContext] = []

        if ensure_required:
            for sc in scored_contexts:
                # CRITICAL priority (100) contexts are always included
                if sc.context.priority >= 100:
                    required.append(sc)
                else:
                    optional.append(sc)
        else:
            optional = list(scored_contexts)

        # 4. Sort optional by score (highest first)
        optional.sort(reverse=True)

        # 5. Greedy selection
        selected: list[ScoredContext] = []
        excluded: list[ScoredContext] = []
        total_tokens = 0

        # First, try to fit required contexts
        for sc in required:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                # Force-fit CRITICAL contexts if needed
                budget.allocate(context_type, token_count, force=True)
                selected.append(sc)
                total_tokens += token_count
                logger.warning(
                    f"Force-fitted CRITICAL context: {sc.context.source} "
                    f"({token_count} tokens)"
                )

        # Then, greedily add optional contexts
        for sc in optional:
            token_count = sc.context.token_count or 0
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                excluded.append(sc)

        # Build stats
        stats = {
            "total_contexts": len(contexts),
            "required_count": len(required),
            "selected_count": len(selected),
            "excluded_count": len(excluded),
            "total_tokens": total_tokens,
            "by_type": self._count_by_type(selected),
        }

        return RankingResult(
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
            selection_stats=stats,
        )

    async def rank_simple(
        self,
        contexts: list[BaseContext],
        query: str,
        max_tokens: int,
        model: str | None = None,
        **kwargs: Any,
    ) -> list[BaseContext]:
        """
        Simple ranking without budget per type.

        Selects top contexts by score until max tokens reached.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            max_tokens: Maximum total tokens
            model: Model for token counting
            **kwargs: Additional scoring parameters

        Returns:
            Selected contexts (in score order)
        """
        if not contexts:
            return []

        # Ensure token counts
        await self._ensure_token_counts(contexts, model)

        # Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # Sort by score (highest first)
        scored_contexts.sort(reverse=True)

        # Greedy selection
        selected: list[BaseContext] = []
        total_tokens = 0

        for sc in scored_contexts:
            token_count = sc.context.token_count or 0
            if total_tokens + token_count <= max_tokens:
                selected.append(sc.context)
                total_tokens += token_count

        return selected

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """
        Ensure all contexts have token counts.

        Args:
            contexts: Contexts to check
            model: Model for token counting
        """
        for context in contexts:
            if context.token_count is None:
                count = await self._calculator.count_tokens(
                    context.content, model
                )
                context.token_count = count

    def _count_by_type(
        self, scored_contexts: list[ScoredContext]
    ) -> dict[str, dict[str, int]]:
        """Count selected contexts by type."""
        by_type: dict[str, dict[str, int]] = {}

        for sc in scored_contexts:
            type_name = sc.context.get_type().value
            if type_name not in by_type:
                by_type[type_name] = {"count": 0, "tokens": 0}
            by_type[type_name]["count"] += 1
            by_type[type_name]["tokens"] += sc.context.token_count or 0

        return by_type

    async def rerank_for_diversity(
        self,
        scored_contexts: list[ScoredContext],
        max_per_source: int = 3,
    ) -> list[ScoredContext]:
        """
        Rerank to ensure source diversity.

        Prevents too many items from the same source.

        Args:
            scored_contexts: Already scored contexts
            max_per_source: Maximum items per source

        Returns:
            Reranked contexts
        """
        source_counts: dict[str, int] = {}
        result: list[ScoredContext] = []
        deferred: list[ScoredContext] = []

        for sc in scored_contexts:
            source = sc.context.source
            current_count = source_counts.get(source, 0)

            if current_count < max_per_source:
                result.append(sc)
                source_counts[source] = current_count + 1
            else:
                deferred.append(sc)

        # Add deferred items at the end
        result.extend(deferred)
        return result
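A usage sketch for the ranker above (illustrative, not part of the commit; it assumes the BudgetAllocator and KnowledgeContext APIs exercised by the tests later in this diff):

# Usage sketch (not part of the commit): rank a few contexts against a query
# within a token budget and inspect the selection statistics.
import asyncio

from app.services.context.budget import BudgetAllocator
from app.services.context.prioritization import ContextRanker
from app.services.context.types import KnowledgeContext


async def main() -> None:
    ranker = ContextRanker()
    budget = BudgetAllocator().create_budget(10_000)

    contexts = [
        KnowledgeContext(content="Python is a programming language.", source="docs/python.md", relevance_score=0.9),
        KnowledgeContext(content="Java is also a programming language.", source="docs/java.md", relevance_score=0.4),
    ]

    result = await ranker.rank(contexts, "Python help", budget)
    print(result.selection_stats)    # counts, tokens, and per-type breakdown
    print(result.selected_contexts)  # plain BaseContext objects, highest score first


asyncio.run(main())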
@@ -1,5 +1,21 @@
 """
 Context Scoring Module.
 
-Provides relevance, recency, and composite scoring.
+Provides scoring strategies for context prioritization.
 """
+
+from .base import BaseScorer, ScorerProtocol
+from .composite import CompositeScorer, ScoredContext
+from .priority import PriorityScorer
+from .recency import RecencyScorer
+from .relevance import RelevanceScorer
+
+__all__ = [
+    "BaseScorer",
+    "CompositeScorer",
+    "PriorityScorer",
+    "RecencyScorer",
+    "RelevanceScorer",
+    "ScoredContext",
+    "ScorerProtocol",
+]
backend/app/services/context/scoring/base.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""
Base Scorer Protocol and Types.

Defines the interface for context scoring implementations.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from ..types import BaseContext

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager


@runtime_checkable
class ScorerProtocol(Protocol):
    """Protocol for context scorers."""

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score a context item.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Score between 0.0 and 1.0
        """
        ...


class BaseScorer(ABC):
    """
    Abstract base class for context scorers.

    Provides common functionality and interface for
    different scoring strategies.
    """

    def __init__(self, weight: float = 1.0) -> None:
        """
        Initialize scorer.

        Args:
            weight: Weight for this scorer in composite scoring
        """
        self._weight = weight

    @property
    def weight(self) -> float:
        """Get scorer weight."""
        return self._weight

    @weight.setter
    def weight(self, value: float) -> None:
        """Set scorer weight."""
        if not 0.0 <= value <= 1.0:
            raise ValueError("Weight must be between 0.0 and 1.0")
        self._weight = value

    @abstractmethod
    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score a context item.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Score between 0.0 and 1.0
        """
        ...

    def normalize_score(self, score: float) -> float:
        """
        Normalize score to [0.0, 1.0] range.

        Args:
            score: Raw score

        Returns:
            Normalized score
        """
        return max(0.0, min(1.0, score))
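As a sketch of the extension point BaseScorer defines, a hypothetical custom scorer (the LengthScorer name and the 4000-character ceiling are illustrative, not part of the commit):

# Hypothetical example, not part of the commit: a custom scorer built on the
# BaseScorer interface above, favouring shorter contexts.
from typing import Any

from app.services.context.scoring import BaseScorer
from app.services.context.types import BaseContext


class LengthScorer(BaseScorer):
    """Score contexts higher the shorter their content is."""

    async def score(self, context: BaseContext, query: str, **kwargs: Any) -> float:
        # 0 characters -> 1.0, 4000+ characters -> 0.0, linear in between.
        return self.normalize_score(1.0 - len(context.content) / 4000)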
backend/app/services/context/scoring/composite.py (new file, 276 lines)
@@ -0,0 +1,276 @@
"""
Composite Scorer for Context Management.

Combines multiple scoring strategies with configurable weights.
"""

import asyncio
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from .base import BaseScorer
from .priority import PriorityScorer
from .recency import RecencyScorer
from .relevance import RelevanceScorer
from ..config import ContextSettings, get_context_settings
from ..types import BaseContext

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)


@dataclass
class ScoredContext:
    """Context with computed scores."""

    context: BaseContext
    composite_score: float
    relevance_score: float = 0.0
    recency_score: float = 0.0
    priority_score: float = 0.0

    def __lt__(self, other: "ScoredContext") -> bool:
        """Enable sorting by composite score."""
        return self.composite_score < other.composite_score

    def __gt__(self, other: "ScoredContext") -> bool:
        """Enable sorting by composite score."""
        return self.composite_score > other.composite_score


class CompositeScorer:
    """
    Combines multiple scoring strategies.

    Weights:
    - relevance: How well content matches the query
    - recency: How recent the content is
    - priority: Explicit priority assignments
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        settings: ContextSettings | None = None,
        relevance_weight: float | None = None,
        recency_weight: float | None = None,
        priority_weight: float | None = None,
    ) -> None:
        """
        Initialize composite scorer.

        Args:
            mcp_manager: MCP manager for semantic scoring
            settings: Context settings (uses default if None)
            relevance_weight: Override relevance weight
            recency_weight: Override recency weight
            priority_weight: Override priority weight
        """
        self._settings = settings or get_context_settings()
        weights = self._settings.get_scoring_weights()

        self._relevance_weight = (
            relevance_weight if relevance_weight is not None else weights["relevance"]
        )
        self._recency_weight = (
            recency_weight if recency_weight is not None else weights["recency"]
        )
        self._priority_weight = (
            priority_weight if priority_weight is not None else weights["priority"]
        )

        # Initialize scorers
        self._relevance_scorer = RelevanceScorer(
            mcp_manager=mcp_manager,
            weight=self._relevance_weight,
        )
        self._recency_scorer = RecencyScorer(weight=self._recency_weight)
        self._priority_scorer = PriorityScorer(weight=self._priority_weight)

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for semantic scoring."""
        self._relevance_scorer.set_mcp_manager(mcp_manager)

    @property
    def weights(self) -> dict[str, float]:
        """Get current scoring weights."""
        return {
            "relevance": self._relevance_weight,
            "recency": self._recency_weight,
            "priority": self._priority_weight,
        }

    def update_weights(
        self,
        relevance: float | None = None,
        recency: float | None = None,
        priority: float | None = None,
    ) -> None:
        """
        Update scoring weights.

        Args:
            relevance: New relevance weight
            recency: New recency weight
            priority: New priority weight
        """
        if relevance is not None:
            self._relevance_weight = max(0.0, min(1.0, relevance))
            self._relevance_scorer.weight = self._relevance_weight

        if recency is not None:
            self._recency_weight = max(0.0, min(1.0, recency))
            self._recency_scorer.weight = self._recency_weight

        if priority is not None:
            self._priority_weight = max(0.0, min(1.0, priority))
            self._priority_scorer.weight = self._priority_weight

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Compute composite score for a context.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Composite score between 0.0 and 1.0
        """
        scored = await self.score_with_details(context, query, **kwargs)
        return scored.composite_score

    async def score_with_details(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> ScoredContext:
        """
        Compute composite score with individual scores.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            ScoredContext with all scores
        """
        # Check if context already has a score
        if context._score is not None:
            return ScoredContext(
                context=context,
                composite_score=context._score,
            )

        # Compute individual scores in parallel
        relevance_task = self._relevance_scorer.score(context, query, **kwargs)
        recency_task = self._recency_scorer.score(context, query, **kwargs)
        priority_task = self._priority_scorer.score(context, query, **kwargs)

        relevance_score, recency_score, priority_score = await asyncio.gather(
            relevance_task, recency_task, priority_task
        )

        # Compute weighted composite
        total_weight = (
            self._relevance_weight + self._recency_weight + self._priority_weight
        )

        if total_weight > 0:
            composite = (
                relevance_score * self._relevance_weight
                + recency_score * self._recency_weight
                + priority_score * self._priority_weight
            ) / total_weight
        else:
            composite = 0.0

        # Cache the score on the context
        context._score = composite

        return ScoredContext(
            context=context,
            composite_score=composite,
            relevance_score=relevance_score,
            recency_score=recency_score,
            priority_score=priority_score,
        )

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        parallel: bool = True,
        **kwargs: Any,
    ) -> list[ScoredContext]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query to score against
            parallel: Whether to score in parallel
            **kwargs: Additional scoring parameters

        Returns:
            List of ScoredContext (same order as input)
        """
        if parallel:
            tasks = [
                self.score_with_details(ctx, query, **kwargs) for ctx in contexts
            ]
            return await asyncio.gather(*tasks)
        else:
            results = []
            for ctx in contexts:
                scored = await self.score_with_details(ctx, query, **kwargs)
                results.append(scored)
            return results

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        limit: int | None = None,
        min_score: float = 0.0,
        **kwargs: Any,
    ) -> list[ScoredContext]:
        """
        Score and rank contexts.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            limit: Maximum number of results
            min_score: Minimum score threshold
            **kwargs: Additional scoring parameters

        Returns:
            Sorted list of ScoredContext (highest first)
        """
        # Score all contexts
        scored = await self.score_batch(contexts, query, **kwargs)

        # Filter by minimum score
        if min_score > 0:
            scored = [s for s in scored if s.composite_score >= min_score]

        # Sort by score (highest first)
        scored.sort(reverse=True)

        # Apply limit
        if limit is not None:
            scored = scored[:limit]

        return scored
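A usage sketch for the scorer above (illustrative, not part of the commit; it assumes the default ContextSettings load without extra configuration, and that with no MCP manager the relevance scorer falls back to the pre-computed score, as implemented above):

# Usage sketch (not part of the commit): score one context with explicit weights.
import asyncio

from app.services.context.scoring import CompositeScorer
from app.services.context.types import KnowledgeContext


async def main() -> None:
    scorer = CompositeScorer(relevance_weight=0.5, recency_weight=0.3, priority_weight=0.2)

    ctx = KnowledgeContext(
        content="Python is a programming language.",
        source="docs/python.md",
        relevance_score=0.9,  # pre-computed RAG score, used directly by RelevanceScorer
    )

    scored = await scorer.score_with_details(ctx, "Python help")
    print(scored.relevance_score, scored.recency_score, scored.priority_score)
    print(scored.composite_score)


asyncio.run(main())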
backend/app/services/context/scoring/priority.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""
Priority Scorer for Context Management.

Scores context based on assigned priority levels.
"""

from typing import Any

from .base import BaseScorer
from ..types import BaseContext, ContextPriority, ContextType


class PriorityScorer(BaseScorer):
    """
    Scores context based on priority levels.

    Converts priority enum values to normalized scores.
    Also applies type-based priority bonuses.
    """

    # Default priority bonuses by context type
    DEFAULT_TYPE_BONUSES: dict[ContextType, float] = {
        ContextType.SYSTEM: 0.2,  # System prompts get a boost
        ContextType.TASK: 0.15,  # Current task is important
        ContextType.TOOL: 0.1,  # Recent tool results matter
        ContextType.KNOWLEDGE: 0.0,  # Knowledge scored by relevance
        ContextType.CONVERSATION: 0.0,  # Conversation scored by recency
    }

    def __init__(
        self,
        weight: float = 1.0,
        type_bonuses: dict[ContextType, float] | None = None,
    ) -> None:
        """
        Initialize priority scorer.

        Args:
            weight: Scorer weight for composite scoring
            type_bonuses: Optional context-type priority bonuses
        """
        super().__init__(weight)
        self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score context based on priority.

        Args:
            context: Context to score
            query: Query (not used for priority, kept for interface)
            **kwargs: Additional parameters

        Returns:
            Priority score between 0.0 and 1.0
        """
        # Get base priority score
        priority_value = context.priority
        base_score = self._priority_to_score(priority_value)

        # Apply type bonus
        context_type = context.get_type()
        bonus = self._type_bonuses.get(context_type, 0.0)

        return self.normalize_score(base_score + bonus)

    def _priority_to_score(self, priority: int) -> float:
        """
        Convert priority value to normalized score.

        Priority values (from ContextPriority):
        - CRITICAL (100) -> 1.0
        - HIGH (80) -> 0.8
        - NORMAL (50) -> 0.5
        - LOW (20) -> 0.2
        - MINIMAL (0) -> 0.0

        Args:
            priority: Priority value (0-100)

        Returns:
            Normalized score (0.0-1.0)
        """
        # Clamp to valid range
        clamped = max(0, min(100, priority))
        return clamped / 100.0

    def get_type_bonus(self, context_type: ContextType) -> float:
        """
        Get priority bonus for a context type.

        Args:
            context_type: Context type

        Returns:
            Bonus value
        """
        return self._type_bonuses.get(context_type, 0.0)

    def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
        """
        Set priority bonus for a context type.

        Args:
            context_type: Context type
            bonus: Bonus value (0.0-1.0)
        """
        if not 0.0 <= bonus <= 1.0:
            raise ValueError("Bonus must be between 0.0 and 1.0")
        self._type_bonuses[context_type] = bonus

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        **kwargs: Any,
    ) -> list[float]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query (not used)
            **kwargs: Additional parameters

        Returns:
            List of scores (same order as input)
        """
        # Priority scoring is fast, no async needed
        return [await self.score(ctx, query, **kwargs) for ctx in contexts]
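A worked example of the arithmetic above (illustration only): a TASK context with HIGH priority (80) scores 80 / 100 = 0.80, gains the 0.15 TASK bonus from DEFAULT_TYPE_BONUSES, and the clamped result is 0.95:

# Illustration only, not part of the commit.
base = max(0, min(100, 80)) / 100.0   # _priority_to_score(80) -> 0.80
bonus = 0.15                          # DEFAULT_TYPE_BONUSES[ContextType.TASK]
score = max(0.0, min(1.0, base + bonus))
print(score)  # 0.95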
backend/app/services/context/scoring/recency.py (new file, 141 lines)
@@ -0,0 +1,141 @@
"""
Recency Scorer for Context Management.

Scores context based on how recent it is.
More recent content gets higher scores.
"""

import math
from datetime import UTC, datetime, timedelta
from typing import Any

from .base import BaseScorer
from ..types import BaseContext, ContextType


class RecencyScorer(BaseScorer):
    """
    Scores context based on recency.

    Uses exponential decay to score content based on age.
    More recent content scores higher.
    """

    def __init__(
        self,
        weight: float = 1.0,
        half_life_hours: float = 24.0,
        type_half_lives: dict[ContextType, float] | None = None,
    ) -> None:
        """
        Initialize recency scorer.

        Args:
            weight: Scorer weight for composite scoring
            half_life_hours: Default hours until score decays to 0.5
            type_half_lives: Optional context-type-specific half lives
        """
        super().__init__(weight)
        self._half_life_hours = half_life_hours
        self._type_half_lives = type_half_lives or {}

        # Set sensible defaults for context types
        if ContextType.CONVERSATION not in self._type_half_lives:
            self._type_half_lives[ContextType.CONVERSATION] = 1.0  # 1 hour
        if ContextType.TOOL not in self._type_half_lives:
            self._type_half_lives[ContextType.TOOL] = 0.5  # 30 minutes
        if ContextType.KNOWLEDGE not in self._type_half_lives:
            self._type_half_lives[ContextType.KNOWLEDGE] = 168.0  # 1 week
        if ContextType.SYSTEM not in self._type_half_lives:
            self._type_half_lives[ContextType.SYSTEM] = 720.0  # 30 days
        if ContextType.TASK not in self._type_half_lives:
            self._type_half_lives[ContextType.TASK] = 24.0  # 1 day

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score context based on recency.

        Args:
            context: Context to score
            query: Query (not used for recency, kept for interface)
            **kwargs: Additional parameters
                - reference_time: Time to measure recency from (default: now)

        Returns:
            Recency score between 0.0 and 1.0
        """
        reference_time = kwargs.get("reference_time")
        if reference_time is None:
            reference_time = datetime.now(UTC)
        elif reference_time.tzinfo is None:
            reference_time = reference_time.replace(tzinfo=UTC)

        # Ensure context timestamp is timezone-aware
        context_time = context.timestamp
        if context_time.tzinfo is None:
            context_time = context_time.replace(tzinfo=UTC)

        # Calculate age in hours
        age = reference_time - context_time
        age_hours = max(0, age.total_seconds() / 3600)

        # Get half-life for this context type
        context_type = context.get_type()
        half_life = self._type_half_lives.get(context_type, self._half_life_hours)

        # Exponential decay
        decay_factor = math.exp(-math.log(2) * age_hours / half_life)

        return self.normalize_score(decay_factor)

    def get_half_life(self, context_type: ContextType) -> float:
        """
        Get half-life for a context type.

        Args:
            context_type: Context type to get half-life for

        Returns:
            Half-life in hours
        """
        return self._type_half_lives.get(context_type, self._half_life_hours)

    def set_half_life(self, context_type: ContextType, hours: float) -> None:
        """
        Set half-life for a context type.

        Args:
            context_type: Context type to set half-life for
            hours: Half-life in hours
        """
        if hours <= 0:
            raise ValueError("Half-life must be positive")
        self._type_half_lives[context_type] = hours

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        **kwargs: Any,
    ) -> list[float]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query (not used)
            **kwargs: Additional parameters

        Returns:
            List of scores (same order as input)
        """
        scores = []
        for context in contexts:
            score = await self.score(context, query, **kwargs)
            scores.append(score)
        return scores
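A worked example of the decay formula above (illustration only): a CONVERSATION context with the default 1-hour half-life that is 2 hours old has passed two half-lives, so its recency score is 0.25:

# Illustration only, not part of the commit.
import math

age_hours = 2.0
half_life = 1.0  # CONVERSATION default from RecencyScorer.__init__
decay = math.exp(-math.log(2) * age_hours / half_life)
print(round(decay, 2))  # 0.25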
backend/app/services/context/scoring/relevance.py (new file, 188 lines)
@@ -0,0 +1,188 @@
"""
Relevance Scorer for Context Management.

Scores context based on semantic similarity to the query.
Uses Knowledge Base embeddings when available.
"""

import logging
import re
from typing import TYPE_CHECKING, Any

from .base import BaseScorer
from ..types import BaseContext, ContextType, KnowledgeContext

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)


class RelevanceScorer(BaseScorer):
    """
    Scores context based on relevance to query.

    Uses multiple strategies:
    1. Pre-computed scores (from RAG results)
    2. MCP-based semantic similarity (via Knowledge Base)
    3. Keyword matching fallback
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        weight: float = 1.0,
        keyword_fallback_weight: float = 0.5,
    ) -> None:
        """
        Initialize relevance scorer.

        Args:
            mcp_manager: MCP manager for Knowledge Base calls
            weight: Scorer weight for composite scoring
            keyword_fallback_weight: Max score for keyword-based fallback
        """
        super().__init__(weight)
        self._mcp = mcp_manager
        self._keyword_fallback_weight = keyword_fallback_weight

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for semantic scoring."""
        self._mcp = mcp_manager

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score context relevance to query.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional parameters

        Returns:
            Relevance score between 0.0 and 1.0
        """
        # 1. Check for pre-computed relevance score
        if isinstance(context, KnowledgeContext) and context.relevance_score is not None:
            return self.normalize_score(context.relevance_score)

        # 2. Check metadata for score
        if "relevance_score" in context.metadata:
            return self.normalize_score(context.metadata["relevance_score"])

        if "score" in context.metadata:
            return self.normalize_score(context.metadata["score"])

        # 3. Try MCP-based semantic similarity
        if self._mcp is not None:
            try:
                score = await self._compute_semantic_similarity(context, query)
                if score is not None:
                    return score
            except Exception as e:
                logger.debug(f"Semantic scoring failed, using fallback: {e}")

        # 4. Fall back to keyword matching
        return self._compute_keyword_score(context, query)

    async def _compute_semantic_similarity(
        self,
        context: BaseContext,
        query: str,
    ) -> float | None:
        """
        Compute semantic similarity using Knowledge Base embeddings.

        Args:
            context: Context to score
            query: Query to compare

        Returns:
            Similarity score or None if unavailable
        """
        try:
            # Use Knowledge Base's search capability to compute similarity
            result = await self._mcp.call_tool(
                server="knowledge-base",
                tool="compute_similarity",
                args={
                    "text1": query,
                    "text2": context.content[:2000],  # Limit content length
                },
            )

            if result.success and result.data:
                similarity = result.data.get("similarity")
                if similarity is not None:
                    return self.normalize_score(float(similarity))

        except Exception as e:
            logger.debug(f"Semantic similarity computation failed: {e}")

        return None

    def _compute_keyword_score(
        self,
        context: BaseContext,
        query: str,
    ) -> float:
        """
        Compute relevance score based on keyword matching.

        Simple but fast fallback when semantic search is unavailable.

        Args:
            context: Context to score
            query: Query to match

        Returns:
            Keyword-based relevance score
        """
        if not query or not context.content:
            return 0.0

        # Extract keywords from query
        query_lower = query.lower()
        content_lower = context.content.lower()

        # Simple word tokenization
        query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
        content_words = set(re.findall(r"\b\w{3,}\b", content_lower))

        if not query_words:
            return 0.0

        # Calculate overlap
        common_words = query_words & content_words
        overlap_ratio = len(common_words) / len(query_words)

        # Apply fallback weight ceiling
        return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        **kwargs: Any,
    ) -> list[float]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query to score against
            **kwargs: Additional parameters

        Returns:
            List of scores (same order as input)
        """
        scores = []
        for context in contexts:
            score = await self.score(context, query, **kwargs)
            scores.append(score)
        return scores
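A worked example of the keyword fallback above (illustration only): with the default keyword_fallback_weight of 0.5, a context containing half of the query's words (three or more characters each) scores 0.25:

# Illustration only, not part of the commit.
import re

query = "python async context scoring"
content = "Scoring strategies for Python services"

query_words = set(re.findall(r"\b\w{3,}\b", query.lower()))      # 4 words
content_words = set(re.findall(r"\b\w{3,}\b", content.lower()))

overlap_ratio = len(query_words & content_words) / len(query_words)  # 2 / 4
print(overlap_ratio * 0.5)  # 0.25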
@@ -55,11 +55,9 @@ class TaskContext(BaseContext):
     constraints: list[str] = field(default_factory=list)
     parent_task_id: str | None = field(default=None)
 
-    def __post_init__(self) -> None:
-        """Set high priority for task context."""
-        # Task context defaults to high priority
-        if self.priority == ContextPriority.NORMAL.value:
-            self.priority = ContextPriority.HIGH.value
+    # Note: TaskContext should typically have HIGH priority,
+    # but we don't auto-promote to allow explicit priority setting.
+    # Use TaskContext.create() for default HIGH priority behavior.
 
     def get_type(self) -> ContextType:
         """Return TASK context type."""
backend/tests/services/context/test_ranker.py (new file, 507 lines)
@@ -0,0 +1,507 @@
"""Tests for context ranking module."""

from datetime import UTC, datetime

import pytest

from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.prioritization import ContextRanker, RankingResult
from app.services.context.scoring import CompositeScorer, ScoredContext
from app.services.context.types import (
    ContextPriority,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskContext,
)


class TestRankingResult:
    """Tests for RankingResult dataclass."""

    def test_creation(self) -> None:
        """Test RankingResult creation."""
        ctx = TaskContext(content="Test", source="task")
        scored = ScoredContext(context=ctx, composite_score=0.8)

        result = RankingResult(
            selected=[scored],
            excluded=[],
            total_tokens=100,
            selection_stats={"total": 1},
        )

        assert len(result.selected) == 1
        assert result.total_tokens == 100

    def test_selected_contexts_property(self) -> None:
        """Test selected_contexts property extracts contexts."""
        ctx1 = TaskContext(content="Test 1", source="task")
        ctx2 = TaskContext(content="Test 2", source="task")

        scored1 = ScoredContext(context=ctx1, composite_score=0.8)
        scored2 = ScoredContext(context=ctx2, composite_score=0.6)

        result = RankingResult(
            selected=[scored1, scored2],
            excluded=[],
            total_tokens=200,
        )

        selected = result.selected_contexts
        assert len(selected) == 2
        assert ctx1 in selected
        assert ctx2 in selected


class TestContextRanker:
    """Tests for ContextRanker."""

    def test_creation(self) -> None:
        """Test ranker creation."""
        ranker = ContextRanker()
        assert ranker._scorer is not None
        assert ranker._calculator is not None

    def test_creation_with_scorer(self) -> None:
        """Test ranker creation with custom scorer."""
        scorer = CompositeScorer(relevance_weight=0.8)
        ranker = ContextRanker(scorer=scorer)
        assert ranker._scorer is scorer

    @pytest.mark.asyncio
    async def test_rank_empty_contexts(self) -> None:
        """Test ranking empty context list."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        result = await ranker.rank([], "query", budget)

        assert len(result.selected) == 0
        assert len(result.excluded) == 0
        assert result.total_tokens == 0

    @pytest.mark.asyncio
    async def test_rank_single_context_fits(self) -> None:
        """Test ranking single context that fits budget."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        context = KnowledgeContext(
            content="Short content",
            source="docs",
            relevance_score=0.8,
        )

        result = await ranker.rank([context], "query", budget)

        assert len(result.selected) == 1
        assert len(result.excluded) == 0
        assert result.selected[0].context is context

    @pytest.mark.asyncio
    async def test_rank_respects_budget(self) -> None:
        """Test that ranking respects token budget."""
        ranker = ContextRanker()

        # Create a very small budget
        budget = TokenBudget(
            total=100,
            knowledge=50,  # Only 50 tokens for knowledge
        )

        # Create contexts that exceed budget
        contexts = [
            KnowledgeContext(
                content="A" * 200,  # ~50 tokens
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B" * 200,  # ~50 tokens
                source="docs",
                relevance_score=0.8,
            ),
            KnowledgeContext(
                content="C" * 200,  # ~50 tokens
                source="docs",
                relevance_score=0.7,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Not all should fit
        assert len(result.selected) < len(contexts)
        assert len(result.excluded) > 0

    @pytest.mark.asyncio
    async def test_rank_selects_highest_scores(self) -> None:
        """Test that ranking selects highest scored contexts."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)

        # Small budget for knowledge
        budget.knowledge = 100

        contexts = [
            KnowledgeContext(
                content="Low score",
                source="docs",
                relevance_score=0.2,
            ),
            KnowledgeContext(
                content="High score",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Medium score",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Should have selected some
        if result.selected:
            # The highest scored should be selected first
            scores = [s.composite_score for s in result.selected]
            assert scores == sorted(scores, reverse=True)

    @pytest.mark.asyncio
    async def test_rank_critical_priority_always_included(self) -> None:
        """Test that CRITICAL priority contexts are always included."""
        ranker = ContextRanker()

        # Very small budget
        budget = TokenBudget(
            total=100,
            system=10,  # Very small
            knowledge=10,
        )

        contexts = [
            SystemContext(
                content="Critical system prompt that must be included",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            KnowledgeContext(
                content="Optional knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget, ensure_required=True)

        # Critical context should be in selected
        selected_priorities = [s.context.priority for s in result.selected]
        assert ContextPriority.CRITICAL.value in selected_priorities

    @pytest.mark.asyncio
    async def test_rank_without_ensure_required(self) -> None:
        """Test ranking without ensuring required contexts."""
        ranker = ContextRanker()

        budget = TokenBudget(
            total=100,
            system=50,
            knowledge=50,
        )

        contexts = [
            SystemContext(
                content="A" * 500,  # Large content
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            KnowledgeContext(
                content="Short",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await ranker.rank(
            contexts, "query", budget, ensure_required=False
        )

        # Without ensure_required, CRITICAL contexts can be excluded
        # if budget doesn't allow
        assert len(result.selected) + len(result.excluded) == len(contexts)

    @pytest.mark.asyncio
    async def test_rank_selection_stats(self) -> None:
        """Test that ranking provides useful statistics."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        contexts = [
            KnowledgeContext(
                content="Knowledge 1", source="docs", relevance_score=0.8
            ),
            KnowledgeContext(
                content="Knowledge 2", source="docs", relevance_score=0.6
            ),
            TaskContext(content="Task", source="task"),
        ]

        result = await ranker.rank(contexts, "query", budget)

        stats = result.selection_stats
        assert "total_contexts" in stats
        assert "selected_count" in stats
        assert "excluded_count" in stats
        assert "total_tokens" in stats
        assert "by_type" in stats

    @pytest.mark.asyncio
    async def test_rank_simple(self) -> None:
        """Test simple ranking without budget per type."""
        ranker = ContextRanker()

        contexts = [
            KnowledgeContext(
                content="A",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B",
                source="docs",
                relevance_score=0.7,
            ),
            KnowledgeContext(
                content="C",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank_simple(contexts, "query", max_tokens=1000)

        # Should return contexts sorted by score
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_rank_simple_respects_max_tokens(self) -> None:
        """Test that simple ranking respects max tokens."""
        ranker = ContextRanker()

        # Create contexts with known token counts
        contexts = [
            KnowledgeContext(
                content="A" * 100,  # ~25 tokens
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B" * 100,
                source="docs",
                relevance_score=0.8,
            ),
            KnowledgeContext(
                content="C" * 100,
                source="docs",
                relevance_score=0.7,
            ),
        ]

        # Very small limit
        result = await ranker.rank_simple(contexts, "query", max_tokens=30)

        # Should only fit a limited number
        assert len(result) <= len(contexts)

    @pytest.mark.asyncio
    async def test_rank_simple_empty(self) -> None:
        """Test simple ranking with empty list."""
        ranker = ContextRanker()

        result = await ranker.rank_simple([], "query", max_tokens=1000)
        assert result == []

    @pytest.mark.asyncio
    async def test_rerank_for_diversity(self) -> None:
        """Test diversity reranking."""
        ranker = ContextRanker()

        # Create scored contexts from same source
        contexts = [
            ScoredContext(
                context=KnowledgeContext(
                    content=f"Content {i}",
                    source="same-source",
                    relevance_score=0.9 - i * 0.1,
                ),
                composite_score=0.9 - i * 0.1,
            )
            for i in range(5)
        ]

        # Limit to 2 per source
        result = await ranker.rerank_for_diversity(contexts, max_per_source=2)

        assert len(result) == 5
        # First 2 should be from same source, rest deferred
        first_two_sources = [r.context.source for r in result[:2]]
        assert all(s == "same-source" for s in first_two_sources)

    @pytest.mark.asyncio
    async def test_rerank_for_diversity_multiple_sources(self) -> None:
        """Test diversity reranking with multiple sources."""
        ranker = ContextRanker()

        contexts = [
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 1",
                    source="source-a",
                    relevance_score=0.9,
                ),
                composite_score=0.9,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 2",
                    source="source-a",
                    relevance_score=0.8,
                ),
                composite_score=0.8,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source B - 1",
                    source="source-b",
                    relevance_score=0.7,
                ),
                composite_score=0.7,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 3",
                    source="source-a",
                    relevance_score=0.6,
                ),
                composite_score=0.6,
            ),
        ]

        result = await ranker.rerank_for_diversity(contexts, max_per_source=2)

        # Should not have more than 2 from source-a in first 3
        source_a_in_first_3 = sum(
            1 for r in result[:3] if r.context.source == "source-a"
        )
        assert source_a_in_first_3 <= 2

    @pytest.mark.asyncio
    async def test_token_counts_set(self) -> None:
        """Test that token counts are set during ranking."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
        )

        # Token count should be None initially
        assert context.token_count is None

        await ranker.rank([context], "query", budget)

        # Token count should be set after ranking
        assert context.token_count is not None
        assert context.token_count > 0


class TestContextRankerIntegration:
    """Integration tests for context ranking."""

    @pytest.mark.asyncio
    async def test_full_ranking_workflow(self) -> None:
        """Test complete ranking workflow."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        # Create diverse context types
        contexts = [
            SystemContext(
                content="You are a helpful assistant.",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            TaskContext(
                content="Help the user with their coding question.",
                source="task",
                priority=ContextPriority.HIGH.value,
            ),
            KnowledgeContext(
                content="Python is a programming language.",
                source="docs/python.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Java is also a programming language.",
                source="docs/java.md",
                relevance_score=0.4,
            ),
            ConversationContext(
                content="Hello, can you help me?",
                source="chat",
                role=MessageRole.USER,
            ),
        ]

        result = await ranker.rank(contexts, "Python help", budget)

        # System (CRITICAL) should be included
        selected_types = [s.context.get_type() for s in result.selected]
        assert ContextType.SYSTEM in selected_types

        # Stats should be populated
        assert result.selection_stats["total_contexts"] == 5

    @pytest.mark.asyncio
    async def test_ranking_preserves_context_order_by_score(self) -> None:
        """Test that ranking orders by score correctly."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(100000)

        contexts = [
            KnowledgeContext(
                content="Low",
                source="docs",
                relevance_score=0.2,
            ),
            KnowledgeContext(
                content="High",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Medium",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Verify ordering is by score
        scores = [s.composite_score for s in result.selected]
        assert scores == sorted(scores, reverse=True)
|
||||||
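Taken together, the ranker tests above pin down the public surface: rank(contexts, query, budget) returns a result object exposing .selected (ordered by composite_score, highest first) and .selection_stats, rank_simple(contexts, query, max_tokens=...) returns a plain list under a hard token cap, and rerank_for_diversity(scored, max_per_source=...) reorders already-scored items so one source cannot dominate. A minimal usage sketch follows; it is not part of this diff, and the import paths are assumptions based on how the tests use these names.

import asyncio

# Import paths assumed from the test modules' usage; adjust to the actual package layout.
from app.services.context.budget import BudgetAllocator
from app.services.context.prioritization import ContextRanker
from app.services.context.types import ContextPriority, KnowledgeContext, SystemContext


async def demo() -> None:
    # Assumes only the constructors and calls exercised by the tests above.
    ranker = ContextRanker()
    budget = BudgetAllocator().create_budget(10000)

    contexts = [
        SystemContext(
            content="You are a helpful assistant.",
            source="system",
            priority=ContextPriority.CRITICAL.value,
        ),
        KnowledgeContext(
            content="Python is a programming language.",
            source="docs/python.md",
            relevance_score=0.9,
        ),
    ]

    # Budget-aware selection: result exposes .selected and .selection_stats,
    # as asserted in the integration tests.
    result = await ranker.rank(contexts, "Python help", budget)
    for scored in result.selected:
        print(scored.composite_score, scored.context.source)

    # Plain list under a hard token cap, no stats.
    top = await ranker.rank_simple(contexts, "Python help", max_tokens=500)
    print(len(top))


asyncio.run(demo())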
712
backend/tests/services/context/test_scoring.py
Normal file
@@ -0,0 +1,712 @@
"""Tests for context scoring module."""

from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.services.context.scoring import (
    BaseScorer,
    CompositeScorer,
    PriorityScorer,
    RecencyScorer,
    RelevanceScorer,
    ScoredContext,
)
from app.services.context.types import (
    ContextPriority,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskContext,
)


class TestRelevanceScorer:
    """Tests for RelevanceScorer."""

    def test_creation(self) -> None:
        """Test scorer creation."""
        scorer = RelevanceScorer()
        assert scorer.weight == 1.0

    def test_creation_with_weight(self) -> None:
        """Test scorer creation with custom weight."""
        scorer = RelevanceScorer(weight=0.5)
        assert scorer.weight == 0.5

    @pytest.mark.asyncio
    async def test_score_with_precomputed_relevance(self) -> None:
        """Test scoring with pre-computed relevance score."""
        scorer = RelevanceScorer()

        # KnowledgeContext with pre-computed score
        context = KnowledgeContext(
            content="Test content about Python",
            source="docs/python.md",
            relevance_score=0.85,
        )

        score = await scorer.score(context, "Python programming")
        assert score == 0.85

    @pytest.mark.asyncio
    async def test_score_with_metadata_score(self) -> None:
        """Test scoring with metadata-provided score."""
        scorer = RelevanceScorer()

        context = SystemContext(
            content="System prompt",
            source="system",
            metadata={"relevance_score": 0.9},
        )

        score = await scorer.score(context, "anything")
        assert score == 0.9

    @pytest.mark.asyncio
    async def test_score_fallback_to_keyword_matching(self) -> None:
        """Test fallback to keyword matching when no score available."""
        scorer = RelevanceScorer(keyword_fallback_weight=0.5)

        context = TaskContext(
            content="Implement authentication with JWT tokens",
            source="task",
        )

        # Query has matching keywords
        score = await scorer.score(context, "JWT authentication")
        assert score > 0

    @pytest.mark.asyncio
    async def test_keyword_matching_no_overlap(self) -> None:
        """Test keyword matching with no query overlap."""
        scorer = RelevanceScorer()

        context = TaskContext(
            content="Implement database migration",
            source="task",
        )

        score = await scorer.score(context, "xyz abc 123")
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_keyword_matching_full_overlap(self) -> None:
        """Test keyword matching with high overlap."""
        scorer = RelevanceScorer(keyword_fallback_weight=1.0)

        context = TaskContext(
            content="python programming language",
            source="task",
        )

        score = await scorer.score(context, "python programming")
        # Should have high score due to keyword overlap
        assert score > 0.5

    @pytest.mark.asyncio
    async def test_score_with_mcp_success(self) -> None:
        """Test scoring with MCP semantic similarity."""
        mock_mcp = MagicMock()
        mock_result = MagicMock()
        mock_result.success = True
        mock_result.data = {"similarity": 0.75}
        mock_mcp.call_tool = AsyncMock(return_value=mock_result)

        scorer = RelevanceScorer(mcp_manager=mock_mcp)

        context = TaskContext(
            content="Test content",
            source="task",
        )

        score = await scorer.score(context, "test query")
        assert score == 0.75

    @pytest.mark.asyncio
    async def test_score_with_mcp_failure_fallback(self) -> None:
        """Test fallback when MCP fails."""
        mock_mcp = MagicMock()
        mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))

        scorer = RelevanceScorer(mcp_manager=mock_mcp, keyword_fallback_weight=0.5)

        context = TaskContext(
            content="Python programming code",
            source="task",
        )

        # Should fall back to keyword matching
        score = await scorer.score(context, "Python code")
        assert score > 0

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
        """Test batch scoring."""
        scorer = RelevanceScorer()

        contexts = [
            KnowledgeContext(
                content="Python", source="1", relevance_score=0.8
            ),
            KnowledgeContext(
                content="Java", source="2", relevance_score=0.6
            ),
            KnowledgeContext(
                content="Go", source="3", relevance_score=0.9
            ),
        ]

        scores = await scorer.score_batch(contexts, "test")
        assert len(scores) == 3
        assert scores[0] == 0.8
        assert scores[1] == 0.6
        assert scores[2] == 0.9

    def test_set_mcp_manager(self) -> None:
        """Test setting MCP manager."""
        scorer = RelevanceScorer()
        assert scorer._mcp is None

        mock_mcp = MagicMock()
        scorer.set_mcp_manager(mock_mcp)
        assert scorer._mcp is mock_mcp
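The RelevanceScorer tests above fix the fallback behaviour without showing its formula: a zero score when no query token appears in the content, and a score above 0.5 for near-total overlap when keyword_fallback_weight=1.0. One simple overlap measure consistent with those expectations is query-term coverage scaled by the fallback weight; this is a sketch only, not the module's implementation.

# Hypothetical keyword fallback, consistent with the assertions above.
def keyword_fallback(content: str, query: str, fallback_weight: float = 0.5) -> float:
    query_terms = set(query.lower().split())
    content_terms = set(content.lower().split())
    if not query_terms:
        return 0.0
    coverage = len(query_terms & content_terms) / len(query_terms)
    return fallback_weight * coverage  # 0.0 with no overlap; fallback_weight at full overlap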
class TestRecencyScorer:
    """Tests for RecencyScorer."""

    def test_creation(self) -> None:
        """Test scorer creation."""
        scorer = RecencyScorer()
        assert scorer.weight == 1.0
        assert scorer._half_life_hours == 24.0

    def test_creation_with_custom_half_life(self) -> None:
        """Test scorer creation with custom half-life."""
        scorer = RecencyScorer(half_life_hours=12.0)
        assert scorer._half_life_hours == 12.0

    @pytest.mark.asyncio
    async def test_score_recent_context(self) -> None:
        """Test scoring a very recent context."""
        scorer = RecencyScorer(half_life_hours=24.0)
        now = datetime.now(UTC)

        context = TaskContext(
            content="Recent task",
            source="task",
            timestamp=now,
        )

        score = await scorer.score(context, "query", reference_time=now)
        # Very recent should have score near 1.0
        assert score > 0.99

    @pytest.mark.asyncio
    async def test_score_at_half_life(self) -> None:
        """Test scoring at exactly half-life age."""
        scorer = RecencyScorer(half_life_hours=24.0)
        now = datetime.now(UTC)
        half_life_ago = now - timedelta(hours=24)

        context = TaskContext(
            content="Day old task",
            source="task",
            timestamp=half_life_ago,
        )

        score = await scorer.score(context, "query", reference_time=now)
        # At half-life, score should be ~0.5
        assert 0.49 <= score <= 0.51

    @pytest.mark.asyncio
    async def test_score_old_context(self) -> None:
        """Test scoring a very old context."""
        scorer = RecencyScorer(half_life_hours=24.0)
        now = datetime.now(UTC)
        week_ago = now - timedelta(days=7)

        context = TaskContext(
            content="Week old task",
            source="task",
            timestamp=week_ago,
        )

        score = await scorer.score(context, "query", reference_time=now)
        # 7 days with 24h half-life = very low score
        assert score < 0.01

    @pytest.mark.asyncio
    async def test_type_specific_half_lives(self) -> None:
        """Test that different context types have different half-lives."""
        scorer = RecencyScorer()
        now = datetime.now(UTC)
        one_hour_ago = now - timedelta(hours=1)

        # Conversation has 1 hour half-life by default
        conv_context = ConversationContext(
            content="Hello",
            source="chat",
            role=MessageRole.USER,
            timestamp=one_hour_ago,
        )

        # Knowledge has 168 hour (1 week) half-life by default
        knowledge_context = KnowledgeContext(
            content="Documentation",
            source="docs",
            timestamp=one_hour_ago,
        )

        conv_score = await scorer.score(conv_context, "query", reference_time=now)
        knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)

        # Conversation should decay much faster
        assert conv_score < knowledge_score

    def test_get_half_life(self) -> None:
        """Test getting half-life for context type."""
        scorer = RecencyScorer()

        assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0
        assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0
        assert scorer.get_half_life(ContextType.SYSTEM) == 720.0

    def test_set_half_life(self) -> None:
        """Test setting custom half-life."""
        scorer = RecencyScorer()

        scorer.set_half_life(ContextType.TASK, 48.0)
        assert scorer.get_half_life(ContextType.TASK) == 48.0

    def test_set_half_life_invalid(self) -> None:
        """Test setting invalid half-life."""
        scorer = RecencyScorer()

        with pytest.raises(ValueError):
            scorer.set_half_life(ContextType.TASK, 0)

        with pytest.raises(ValueError):
            scorer.set_half_life(ContextType.TASK, -1)

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
        """Test batch scoring."""
        scorer = RecencyScorer()
        now = datetime.now(UTC)

        contexts = [
            TaskContext(content="1", source="t", timestamp=now),
            TaskContext(
                content="2", source="t", timestamp=now - timedelta(hours=24)
            ),
            TaskContext(
                content="3", source="t", timestamp=now - timedelta(hours=48)
            ),
        ]

        scores = await scorer.score_batch(contexts, "query", reference_time=now)
        assert len(scores) == 3
        # Scores should be in descending order (more recent = higher)
        assert scores[0] > scores[1] > scores[2]
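The RecencyScorer expectations above (roughly 1.0 when fresh, roughly 0.5 at one half-life, below 0.01 after a week against a 24-hour half-life) all fit plain exponential decay measured in half-life units. A sketch of that curve follows; it assumes this is the decay the scorer uses, which the tests imply but do not prove.

from datetime import datetime


def recency_score(timestamp: datetime, half_life_hours: float, now: datetime) -> float:
    # 1.0 when fresh, 0.5 at one half-life; 0.5 ** 7 is about 0.008,
    # matching the "week old, 24h half-life" assertion above.
    age_hours = max((now - timestamp).total_seconds() / 3600.0, 0.0)
    return 0.5 ** (age_hours / half_life_hours)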
class TestPriorityScorer:
    """Tests for PriorityScorer."""

    def test_creation(self) -> None:
        """Test scorer creation."""
        scorer = PriorityScorer()
        assert scorer.weight == 1.0

    @pytest.mark.asyncio
    async def test_score_critical_priority(self) -> None:
        """Test scoring CRITICAL priority context."""
        scorer = PriorityScorer()

        context = SystemContext(
            content="Critical system prompt",
            source="system",
            priority=ContextPriority.CRITICAL.value,
        )

        score = await scorer.score(context, "query")
        # CRITICAL (100) + type bonus should be > 1.0, normalized to 1.0
        assert score == 1.0

    @pytest.mark.asyncio
    async def test_score_normal_priority(self) -> None:
        """Test scoring NORMAL priority context."""
        scorer = PriorityScorer()

        context = TaskContext(
            content="Normal task",
            source="task",
            priority=ContextPriority.NORMAL.value,
        )

        score = await scorer.score(context, "query")
        # NORMAL (50) = 0.5, plus TASK bonus (0.15) = 0.65
        assert 0.6 <= score <= 0.7

    @pytest.mark.asyncio
    async def test_score_low_priority(self) -> None:
        """Test scoring LOW priority context."""
        scorer = PriorityScorer()

        context = KnowledgeContext(
            content="Low priority knowledge",
            source="docs",
            priority=ContextPriority.LOW.value,
        )

        score = await scorer.score(context, "query")
        # LOW (20) = 0.2, no bonus for KNOWLEDGE
        assert 0.15 <= score <= 0.25

    @pytest.mark.asyncio
    async def test_type_bonuses(self) -> None:
        """Test type-specific priority bonuses."""
        scorer = PriorityScorer()

        # All with same base priority
        system_ctx = SystemContext(
            content="System",
            source="system",
            priority=50,
        )
        task_ctx = TaskContext(
            content="Task",
            source="task",
            priority=50,
        )
        knowledge_ctx = KnowledgeContext(
            content="Knowledge",
            source="docs",
            priority=50,
        )

        system_score = await scorer.score(system_ctx, "query")
        task_score = await scorer.score(task_ctx, "query")
        knowledge_score = await scorer.score(knowledge_ctx, "query")

        # System has highest bonus (0.2), task next (0.15), knowledge has none
        assert system_score > task_score > knowledge_score

    def test_get_type_bonus(self) -> None:
        """Test getting type bonus."""
        scorer = PriorityScorer()

        assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2
        assert scorer.get_type_bonus(ContextType.TASK) == 0.15
        assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0

    def test_set_type_bonus(self) -> None:
        """Test setting custom type bonus."""
        scorer = PriorityScorer()

        scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1)
        assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1

    def test_set_type_bonus_invalid(self) -> None:
        """Test setting invalid type bonus."""
        scorer = PriorityScorer()

        with pytest.raises(ValueError):
            scorer.set_type_bonus(ContextType.KNOWLEDGE, 1.5)

        with pytest.raises(ValueError):
            scorer.set_type_bonus(ContextType.KNOWLEDGE, -0.1)
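The PriorityScorer assertions above imply a simple arithmetic: the priority value (CRITICAL=100, NORMAL=50, LOW=20) divided by 100, plus a type bonus (SYSTEM 0.2, TASK 0.15, KNOWLEDGE 0.0), clamped to 1.0. The sketch below spells that reading out; it is an inference from the tests, not the module's code, and it reuses ContextType from app.services.context.types.

from app.services.context.types import ContextType

# Type bonuses as asserted in test_get_type_bonus above.
TYPE_BONUS = {ContextType.SYSTEM: 0.2, ContextType.TASK: 0.15, ContextType.KNOWLEDGE: 0.0}


def priority_score(priority: int, context_type: ContextType) -> float:
    base = priority / 100.0  # CRITICAL=100 -> 1.0, NORMAL=50 -> 0.5, LOW=20 -> 0.2
    # NORMAL + TASK bonus lands at ~0.65; CRITICAL + SYSTEM bonus clamps from 1.2 to 1.0.
    return min(1.0, base + TYPE_BONUS.get(context_type, 0.0))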
class TestCompositeScorer:
    """Tests for CompositeScorer."""

    def test_creation(self) -> None:
        """Test scorer creation with default weights."""
        scorer = CompositeScorer()

        weights = scorer.weights
        assert weights["relevance"] == 0.5
        assert weights["recency"] == 0.3
        assert weights["priority"] == 0.2

    def test_creation_with_custom_weights(self) -> None:
        """Test scorer creation with custom weights."""
        scorer = CompositeScorer(
            relevance_weight=0.6,
            recency_weight=0.2,
            priority_weight=0.2,
        )

        weights = scorer.weights
        assert weights["relevance"] == 0.6
        assert weights["recency"] == 0.2
        assert weights["priority"] == 0.2

    def test_update_weights(self) -> None:
        """Test updating weights."""
        scorer = CompositeScorer()

        scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1)

        weights = scorer.weights
        assert weights["relevance"] == 0.7
        assert weights["recency"] == 0.2
        assert weights["priority"] == 0.1

    def test_update_weights_partial(self) -> None:
        """Test partially updating weights."""
        scorer = CompositeScorer()
        original_recency = scorer.weights["recency"]

        scorer.update_weights(relevance=0.8)

        assert scorer.weights["relevance"] == 0.8
        assert scorer.weights["recency"] == original_recency

    @pytest.mark.asyncio
    async def test_score_basic(self) -> None:
        """Test basic composite scoring."""
        scorer = CompositeScorer()

        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
            timestamp=datetime.now(UTC),
            priority=ContextPriority.NORMAL.value,
        )

        score = await scorer.score(context, "test query")
        assert 0.0 <= score <= 1.0

    @pytest.mark.asyncio
    async def test_score_with_details(self) -> None:
        """Test scoring with detailed breakdown."""
        scorer = CompositeScorer()

        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
            timestamp=datetime.now(UTC),
            priority=ContextPriority.HIGH.value,
        )

        scored = await scorer.score_with_details(context, "test query")

        assert isinstance(scored, ScoredContext)
        assert scored.context is context
        assert 0.0 <= scored.composite_score <= 1.0
        assert scored.relevance_score == 0.8
        assert scored.recency_score > 0.9  # Very recent
        assert scored.priority_score > 0.5  # HIGH priority

    @pytest.mark.asyncio
    async def test_score_cached_on_context(self) -> None:
        """Test that score is cached on the context."""
        scorer = CompositeScorer()

        context = KnowledgeContext(
            content="Test",
            source="docs",
            relevance_score=0.5,
        )

        # First scoring
        await scorer.score(context, "query")
        assert context._score is not None

        # Second scoring should use cached value
        context._score = 0.999  # Set to a known value
        score2 = await scorer.score(context, "query")
        assert score2 == 0.999

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
        """Test batch scoring."""
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content="High relevance",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Low relevance",
                source="docs",
                relevance_score=0.2,
            ),
        ]

        scored = await scorer.score_batch(contexts, "query")
        assert len(scored) == 2
        assert scored[0].relevance_score > scored[1].relevance_score

    @pytest.mark.asyncio
    async def test_rank(self) -> None:
        """Test ranking contexts."""
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content="Low", source="docs", relevance_score=0.2
            ),
            KnowledgeContext(
                content="High", source="docs", relevance_score=0.9
            ),
            KnowledgeContext(
                content="Medium", source="docs", relevance_score=0.5
            ),
        ]

        ranked = await scorer.rank(contexts, "query")

        # Should be sorted by score (highest first)
        assert len(ranked) == 3
        assert ranked[0].relevance_score == 0.9
        assert ranked[1].relevance_score == 0.5
        assert ranked[2].relevance_score == 0.2

    @pytest.mark.asyncio
    async def test_rank_with_limit(self) -> None:
        """Test ranking with limit."""
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content=str(i), source="docs", relevance_score=i / 10
            )
            for i in range(10)
        ]

        ranked = await scorer.rank(contexts, "query", limit=3)
        assert len(ranked) == 3

    @pytest.mark.asyncio
    async def test_rank_with_min_score(self) -> None:
        """Test ranking with minimum score threshold."""
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content="Low", source="docs", relevance_score=0.1
            ),
            KnowledgeContext(
                content="High", source="docs", relevance_score=0.9
            ),
        ]

        ranked = await scorer.rank(contexts, "query", min_score=0.5)

        # Only the high relevance context should pass the threshold
        assert len(ranked) <= 2  # Could be 1 if min_score filters

    def test_set_mcp_manager(self) -> None:
        """Test setting MCP manager."""
        scorer = CompositeScorer()
        mock_mcp = MagicMock()

        scorer.set_mcp_manager(mock_mcp)
        assert scorer._relevance_scorer._mcp is mock_mcp
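The CompositeScorer tests above fix the default weights (relevance 0.5, recency 0.3, priority 0.2) and require the combined score to stay in [0, 1], which a plain weighted average guarantees whenever the component scores are already in [0, 1]. A sketch of that combination follows; normalizing by the weight total is an assumption of the sketch, not something the tests check.

def composite_score(
    relevance: float,
    recency: float,
    priority: float,
    w_relevance: float = 0.5,
    w_recency: float = 0.3,
    w_priority: float = 0.2,
) -> float:
    # Weighted average of per-strategy scores; stays in [0, 1] if the inputs do.
    total = w_relevance + w_recency + w_priority
    return (w_relevance * relevance + w_recency * recency + w_priority * priority) / total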
class TestScoredContext:
    """Tests for ScoredContext dataclass."""

    def test_creation(self) -> None:
        """Test ScoredContext creation."""
        context = TaskContext(content="Test", source="task")
        scored = ScoredContext(
            context=context,
            composite_score=0.75,
            relevance_score=0.8,
            recency_score=0.7,
            priority_score=0.5,
        )

        assert scored.context is context
        assert scored.composite_score == 0.75

    def test_comparison_operators(self) -> None:
        """Test comparison operators for sorting."""
        ctx1 = TaskContext(content="1", source="task")
        ctx2 = TaskContext(content="2", source="task")

        scored1 = ScoredContext(context=ctx1, composite_score=0.5)
        scored2 = ScoredContext(context=ctx2, composite_score=0.8)

        assert scored1 < scored2
        assert scored2 > scored1

    def test_sorting(self) -> None:
        """Test sorting scored contexts."""
        contexts = [
            ScoredContext(
                context=TaskContext(content="Low", source="task"),
                composite_score=0.3,
            ),
            ScoredContext(
                context=TaskContext(content="High", source="task"),
                composite_score=0.9,
            ),
            ScoredContext(
                context=TaskContext(content="Medium", source="task"),
                composite_score=0.6,
            ),
        ]

        sorted_contexts = sorted(contexts, reverse=True)

        assert sorted_contexts[0].composite_score == 0.9
        assert sorted_contexts[1].composite_score == 0.6
        assert sorted_contexts[2].composite_score == 0.3


class TestBaseScorer:
    """Tests for BaseScorer abstract class."""

    def test_weight_property(self) -> None:
        """Test weight property."""
        # Use a concrete implementation
        scorer = RelevanceScorer(weight=0.7)
        assert scorer.weight == 0.7

    def test_weight_setter_valid(self) -> None:
        """Test weight setter with valid values."""
        scorer = RelevanceScorer()
        scorer.weight = 0.5
        assert scorer.weight == 0.5

    def test_weight_setter_invalid(self) -> None:
        """Test weight setter with invalid values."""
        scorer = RelevanceScorer()

        with pytest.raises(ValueError):
            scorer.weight = -0.1

        with pytest.raises(ValueError):
            scorer.weight = 1.5

    def test_normalize_score(self) -> None:
        """Test score normalization."""
        scorer = RelevanceScorer()

        # Normal range
        assert scorer.normalize_score(0.5) == 0.5

        # Below 0
        assert scorer.normalize_score(-0.5) == 0.0

        # Above 1
        assert scorer.normalize_score(1.5) == 1.0

        # Boundaries
        assert scorer.normalize_score(0.0) == 0.0
        assert scorer.normalize_score(1.0) == 1.0
@@ -286,9 +286,20 @@ class TestTaskContext:
         assert ctx.title == "Login Feature"
         assert ctx.get_type() == ContextType.TASK
 
-    def test_default_high_priority(self) -> None:
-        """Test that task context defaults to high priority."""
+    def test_default_normal_priority(self) -> None:
+        """Test that task context uses NORMAL priority from base class."""
         ctx = TaskContext(content="Test", source="test")
+        # TaskContext inherits NORMAL priority from BaseContext
+        # Use TaskContext.create() for default HIGH priority behavior
+        assert ctx.priority == ContextPriority.NORMAL.value
+
+    def test_explicit_high_priority(self) -> None:
+        """Test setting explicit HIGH priority."""
+        ctx = TaskContext(
+            content="Test",
+            source="test",
+            priority=ContextPriority.HIGH.value,
+        )
         assert ctx.priority == ContextPriority.HIGH.value
 
     def test_create_factory(self) -> None:
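The hunk above flips the expectation for the bare constructor: TaskContext(...) now inherits the NORMAL default from BaseContext, and the new comments point at TaskContext.create() for the previous HIGH default. A quick contrast, with the create() signature assumed for illustration only:

ctx = TaskContext(content="Test", source="test")
assert ctx.priority == ContextPriority.NORMAL.value  # plain constructor: base-class default

# Factory referenced by the new comments; its exact signature is an assumption here.
created = TaskContext.create(content="Test", source="test")
assert created.priority == ContextPriority.HIGH.value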