diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py
index add3366..321afa7 100644
--- a/backend/app/services/context/__init__.py
+++ b/backend/app/services/context/__init__.py
@@ -63,6 +63,22 @@ from .exceptions import (
     TokenCountError,
 )
 
+# Prioritization
+from .prioritization import (
+    ContextRanker,
+    RankingResult,
+)
+
+# Scoring
+from .scoring import (
+    BaseScorer,
+    CompositeScorer,
+    PriorityScorer,
+    RecencyScorer,
+    RelevanceScorer,
+    ScoredContext,
+)
+
 # Types
 from .types import (
     AssembledContext,
@@ -101,6 +117,16 @@ __all__ = [
     "InvalidContextError",
     "ScoringError",
     "TokenCountError",
+    # Prioritization
+    "ContextRanker",
+    "RankingResult",
+    # Scoring
+    "BaseScorer",
+    "CompositeScorer",
+    "PriorityScorer",
+    "RecencyScorer",
+    "RelevanceScorer",
+    "ScoredContext",
     # Types - Base
     "AssembledContext",
     "BaseContext",
diff --git a/backend/app/services/context/prioritization/__init__.py b/backend/app/services/context/prioritization/__init__.py
index 66f586f..14540bd 100644
--- a/backend/app/services/context/prioritization/__init__.py
+++ b/backend/app/services/context/prioritization/__init__.py
@@ -3,3 +3,10 @@ Context Prioritization Module.
 
 Provides context ranking and selection.
 """
+
+from .ranker import ContextRanker, RankingResult
+
+__all__ = [
+    "ContextRanker",
+    "RankingResult",
+]
diff --git a/backend/app/services/context/prioritization/ranker.py b/backend/app/services/context/prioritization/ranker.py
new file mode 100644
index 0000000..d2de292
--- /dev/null
+++ b/backend/app/services/context/prioritization/ranker.py
@@ -0,0 +1,288 @@
+"""
+Context Ranker for Context Management.
+
+Ranks and selects contexts based on scores and budget constraints.
+"""
+
+import logging
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from ..budget import TokenBudget, TokenCalculator
+from ..scoring.composite import CompositeScorer, ScoredContext
+from ..types import BaseContext, ContextType
+
+if TYPE_CHECKING:
+    from app.services.mcp.client_manager import MCPClientManager
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RankingResult:
+    """Result of context ranking and selection."""
+
+    selected: list[ScoredContext]
+    excluded: list[ScoredContext]
+    total_tokens: int
+    selection_stats: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def selected_contexts(self) -> list[BaseContext]:
+        """Get just the context objects (not scored wrappers)."""
+        return [s.context for s in self.selected]
+
+
+class ContextRanker:
+    """
+    Ranks and selects contexts within budget constraints.
+
+    Uses greedy selection to maximize total score
+    while respecting token budgets per context type.
+    """
+
+    def __init__(
+        self,
+        scorer: CompositeScorer | None = None,
+        calculator: TokenCalculator | None = None,
+    ) -> None:
+        """
+        Initialize context ranker.
+
+        Args:
+            scorer: Composite scorer for scoring contexts
+            calculator: Token calculator for counting tokens
+        """
+        self._scorer = scorer or CompositeScorer()
+        self._calculator = calculator or TokenCalculator()
+
+    def set_scorer(self, scorer: CompositeScorer) -> None:
+        """Set the scorer."""
+        self._scorer = scorer
+
+    def set_calculator(self, calculator: TokenCalculator) -> None:
+        """Set the token calculator."""
+        self._calculator = calculator
+
+    async def rank(
+        self,
+        contexts: list[BaseContext],
+        query: str,
+        budget: TokenBudget,
+        model: str | None = None,
+        ensure_required: bool = True,
+        **kwargs: Any,
+    ) -> RankingResult:
+        """
+        Rank and select contexts within budget.
+
+        Args:
+            contexts: Contexts to rank
+            query: Query to rank against
+            budget: Token budget constraints
+            model: Model for token counting
+            ensure_required: If True, always include CRITICAL priority contexts
+            **kwargs: Additional scoring parameters
+
+        Returns:
+            RankingResult with selected and excluded contexts
+        """
+        if not contexts:
+            return RankingResult(
+                selected=[],
+                excluded=[],
+                total_tokens=0,
+                selection_stats={"total_contexts": 0},
+            )
+
+        # 1. Ensure all contexts have token counts
+        await self._ensure_token_counts(contexts, model)
+
+        # 2. Score all contexts
+        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
+
+        # 3. Separate required (CRITICAL priority) from optional
+        required: list[ScoredContext] = []
+        optional: list[ScoredContext] = []
+
+        if ensure_required:
+            for sc in scored_contexts:
+                # CRITICAL priority (100) contexts are always included
+                if sc.context.priority >= 100:
+                    required.append(sc)
+                else:
+                    optional.append(sc)
+        else:
+            optional = list(scored_contexts)
+
+        # 4. Sort optional by score (highest first)
+        optional.sort(reverse=True)
+
+        # 5. Greedy selection
+        selected: list[ScoredContext] = []
+        excluded: list[ScoredContext] = []
+        total_tokens = 0
+
+        # First, try to fit required contexts
+        for sc in required:
+            token_count = sc.context.token_count or 0
+            context_type = sc.context.get_type()
+
+            if budget.can_fit(context_type, token_count):
+                budget.allocate(context_type, token_count)
+                selected.append(sc)
+                total_tokens += token_count
+            else:
+                # Force-fit CRITICAL contexts if needed
+                budget.allocate(context_type, token_count, force=True)
+                selected.append(sc)
+                total_tokens += token_count
+                logger.warning(
+                    f"Force-fitted CRITICAL context: {sc.context.source} "
+                    f"({token_count} tokens)"
+                )
+
+        # Then, greedily add optional contexts
+        for sc in optional:
+            token_count = sc.context.token_count or 0
+            context_type = sc.context.get_type()
+
+            if budget.can_fit(context_type, token_count):
+                budget.allocate(context_type, token_count)
+                selected.append(sc)
+                total_tokens += token_count
+            else:
+                excluded.append(sc)
+
+        # Build stats
+        stats = {
+            "total_contexts": len(contexts),
+            "required_count": len(required),
+            "selected_count": len(selected),
+            "excluded_count": len(excluded),
+            "total_tokens": total_tokens,
+            "by_type": self._count_by_type(selected),
+        }
+
+        return RankingResult(
+            selected=selected,
+            excluded=excluded,
+            total_tokens=total_tokens,
+            selection_stats=stats,
+        )
+
+    async def rank_simple(
+        self,
+        contexts: list[BaseContext],
+        query: str,
+        max_tokens: int,
+        model: str | None = None,
+        **kwargs: Any,
+    ) -> list[BaseContext]:
+        """
+        Simple ranking without budget per type.
+
+        Selects top contexts by score until max tokens reached.
+
+        Args:
+            contexts: Contexts to rank
+            query: Query to rank against
+            max_tokens: Maximum total tokens
+            model: Model for token counting
+            **kwargs: Additional scoring parameters
+
+        Returns:
+            Selected contexts (in score order)
+        """
+        if not contexts:
+            return []
+
+        # Ensure token counts
+        await self._ensure_token_counts(contexts, model)
+
+        # Score all contexts
+        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
+
+        # Sort by score (highest first)
+        scored_contexts.sort(reverse=True)
+
+        # Greedy selection
+        selected: list[BaseContext] = []
+        total_tokens = 0
+
+        for sc in scored_contexts:
+            token_count = sc.context.token_count or 0
+            if total_tokens + token_count <= max_tokens:
+                selected.append(sc.context)
+                total_tokens += token_count
+
+        return selected
+
+    async def _ensure_token_counts(
+        self,
+        contexts: list[BaseContext],
+        model: str | None = None,
+    ) -> None:
+        """
+        Ensure all contexts have token counts.
+
+        Args:
+            contexts: Contexts to check
+            model: Model for token counting
+        """
+        for context in contexts:
+            if context.token_count is None:
+                count = await self._calculator.count_tokens(
+                    context.content, model
+                )
+                context.token_count = count
+
+    def _count_by_type(
+        self, scored_contexts: list[ScoredContext]
+    ) -> dict[str, dict[str, int]]:
+        """Count selected contexts by type."""
+        by_type: dict[str, dict[str, int]] = {}
+
+        for sc in scored_contexts:
+            type_name = sc.context.get_type().value
+            if type_name not in by_type:
+                by_type[type_name] = {"count": 0, "tokens": 0}
+            by_type[type_name]["count"] += 1
+            by_type[type_name]["tokens"] += sc.context.token_count or 0
+
+        return by_type
+
+    async def rerank_for_diversity(
+        self,
+        scored_contexts: list[ScoredContext],
+        max_per_source: int = 3,
+    ) -> list[ScoredContext]:
+        """
+        Rerank to ensure source diversity.
+
+        Prevents too many items from the same source.
+
+        Args:
+            scored_contexts: Already scored contexts
+            max_per_source: Maximum items per source
+
+        Returns:
+            Reranked contexts
+        """
+        source_counts: dict[str, int] = {}
+        result: list[ScoredContext] = []
+        deferred: list[ScoredContext] = []
+
+        for sc in scored_contexts:
+            source = sc.context.source
+            current_count = source_counts.get(source, 0)
+
+            if current_count < max_per_source:
+                result.append(sc)
+                source_counts[source] = current_count + 1
+            else:
+                deferred.append(sc)
+
+        # Add deferred items at the end
+        result.extend(deferred)
+        return result
diff --git a/backend/app/services/context/scoring/__init__.py b/backend/app/services/context/scoring/__init__.py
index f0b7218..ef225f5 100644
--- a/backend/app/services/context/scoring/__init__.py
+++ b/backend/app/services/context/scoring/__init__.py
@@ -1,5 +1,21 @@
 """
 Context Scoring Module.
 
-Provides relevance, recency, and composite scoring.
+Provides scoring strategies for context prioritization.
 """
+
+from .base import BaseScorer, ScorerProtocol
+from .composite import CompositeScorer, ScoredContext
+from .priority import PriorityScorer
+from .recency import RecencyScorer
+from .relevance import RelevanceScorer
+
+__all__ = [
+    "BaseScorer",
+    "CompositeScorer",
+    "PriorityScorer",
+    "RecencyScorer",
+    "RelevanceScorer",
+    "ScoredContext",
+    "ScorerProtocol",
+]
diff --git a/backend/app/services/context/scoring/base.py b/backend/app/services/context/scoring/base.py
new file mode 100644
index 0000000..469518c
--- /dev/null
+++ b/backend/app/services/context/scoring/base.py
@@ -0,0 +1,99 @@
+"""
+Base Scorer Protocol and Types.
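Editor's note: for orientation, here is a minimal usage sketch of the new ranker, assembled from how the new tests below exercise the API; it is not part of the patch, and the asyncio entry point plus the concrete budget and context values are illustrative only.

import asyncio

from app.services.context.budget import BudgetAllocator
from app.services.context.prioritization import ContextRanker
from app.services.context.types import ContextPriority, KnowledgeContext, SystemContext


async def main() -> None:
    ranker = ContextRanker()  # defaults to CompositeScorer() and TokenCalculator()
    budget = BudgetAllocator().create_budget(10000)  # per-type token budget

    contexts = [
        SystemContext(
            content="You are a helpful assistant.",
            source="system",
            priority=ContextPriority.CRITICAL.value,  # force-fitted even over budget
        ),
        KnowledgeContext(
            content="Python is a programming language.",
            source="docs/python.md",
            relevance_score=0.9,  # pre-computed score reused by RelevanceScorer
        ),
    ]

    result = await ranker.rank(contexts, query="Python help", budget=budget)
    print(result.total_tokens, result.selection_stats["by_type"])
    for ctx in result.selected_contexts:
        print(ctx.source)


asyncio.run(main())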
+ +Defines the interface for context scoring implementations. +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from ..types import BaseContext + +if TYPE_CHECKING: + from app.services.mcp.client_manager import MCPClientManager + + +@runtime_checkable +class ScorerProtocol(Protocol): + """Protocol for context scorers.""" + + async def score( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> float: + """ + Score a context item. + + Args: + context: Context to score + query: Query to score against + **kwargs: Additional scoring parameters + + Returns: + Score between 0.0 and 1.0 + """ + ... + + +class BaseScorer(ABC): + """ + Abstract base class for context scorers. + + Provides common functionality and interface for + different scoring strategies. + """ + + def __init__(self, weight: float = 1.0) -> None: + """ + Initialize scorer. + + Args: + weight: Weight for this scorer in composite scoring + """ + self._weight = weight + + @property + def weight(self) -> float: + """Get scorer weight.""" + return self._weight + + @weight.setter + def weight(self, value: float) -> None: + """Set scorer weight.""" + if not 0.0 <= value <= 1.0: + raise ValueError("Weight must be between 0.0 and 1.0") + self._weight = value + + @abstractmethod + async def score( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> float: + """ + Score a context item. + + Args: + context: Context to score + query: Query to score against + **kwargs: Additional scoring parameters + + Returns: + Score between 0.0 and 1.0 + """ + ... + + def normalize_score(self, score: float) -> float: + """ + Normalize score to [0.0, 1.0] range. + + Args: + score: Raw score + + Returns: + Normalized score + """ + return max(0.0, min(1.0, score)) diff --git a/backend/app/services/context/scoring/composite.py b/backend/app/services/context/scoring/composite.py new file mode 100644 index 0000000..d28ddd5 --- /dev/null +++ b/backend/app/services/context/scoring/composite.py @@ -0,0 +1,276 @@ +""" +Composite Scorer for Context Management. + +Combines multiple scoring strategies with configurable weights. +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from .base import BaseScorer +from .priority import PriorityScorer +from .recency import RecencyScorer +from .relevance import RelevanceScorer +from ..config import ContextSettings, get_context_settings +from ..types import BaseContext + +if TYPE_CHECKING: + from app.services.mcp.client_manager import MCPClientManager + +logger = logging.getLogger(__name__) + + +@dataclass +class ScoredContext: + """Context with computed scores.""" + + context: BaseContext + composite_score: float + relevance_score: float = 0.0 + recency_score: float = 0.0 + priority_score: float = 0.0 + + def __lt__(self, other: "ScoredContext") -> bool: + """Enable sorting by composite score.""" + return self.composite_score < other.composite_score + + def __gt__(self, other: "ScoredContext") -> bool: + """Enable sorting by composite score.""" + return self.composite_score > other.composite_score + + +class CompositeScorer: + """ + Combines multiple scoring strategies. 
+ + Weights: + - relevance: How well content matches the query + - recency: How recent the content is + - priority: Explicit priority assignments + """ + + def __init__( + self, + mcp_manager: "MCPClientManager | None" = None, + settings: ContextSettings | None = None, + relevance_weight: float | None = None, + recency_weight: float | None = None, + priority_weight: float | None = None, + ) -> None: + """ + Initialize composite scorer. + + Args: + mcp_manager: MCP manager for semantic scoring + settings: Context settings (uses default if None) + relevance_weight: Override relevance weight + recency_weight: Override recency weight + priority_weight: Override priority weight + """ + self._settings = settings or get_context_settings() + weights = self._settings.get_scoring_weights() + + self._relevance_weight = ( + relevance_weight if relevance_weight is not None else weights["relevance"] + ) + self._recency_weight = ( + recency_weight if recency_weight is not None else weights["recency"] + ) + self._priority_weight = ( + priority_weight if priority_weight is not None else weights["priority"] + ) + + # Initialize scorers + self._relevance_scorer = RelevanceScorer( + mcp_manager=mcp_manager, + weight=self._relevance_weight, + ) + self._recency_scorer = RecencyScorer(weight=self._recency_weight) + self._priority_scorer = PriorityScorer(weight=self._priority_weight) + + def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: + """Set MCP manager for semantic scoring.""" + self._relevance_scorer.set_mcp_manager(mcp_manager) + + @property + def weights(self) -> dict[str, float]: + """Get current scoring weights.""" + return { + "relevance": self._relevance_weight, + "recency": self._recency_weight, + "priority": self._priority_weight, + } + + def update_weights( + self, + relevance: float | None = None, + recency: float | None = None, + priority: float | None = None, + ) -> None: + """ + Update scoring weights. + + Args: + relevance: New relevance weight + recency: New recency weight + priority: New priority weight + """ + if relevance is not None: + self._relevance_weight = max(0.0, min(1.0, relevance)) + self._relevance_scorer.weight = self._relevance_weight + + if recency is not None: + self._recency_weight = max(0.0, min(1.0, recency)) + self._recency_scorer.weight = self._recency_weight + + if priority is not None: + self._priority_weight = max(0.0, min(1.0, priority)) + self._priority_scorer.weight = self._priority_weight + + async def score( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> float: + """ + Compute composite score for a context. + + Args: + context: Context to score + query: Query to score against + **kwargs: Additional scoring parameters + + Returns: + Composite score between 0.0 and 1.0 + """ + scored = await self.score_with_details(context, query, **kwargs) + return scored.composite_score + + async def score_with_details( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> ScoredContext: + """ + Compute composite score with individual scores. 
+ + Args: + context: Context to score + query: Query to score against + **kwargs: Additional scoring parameters + + Returns: + ScoredContext with all scores + """ + # Check if context already has a score + if context._score is not None: + return ScoredContext( + context=context, + composite_score=context._score, + ) + + # Compute individual scores in parallel + relevance_task = self._relevance_scorer.score(context, query, **kwargs) + recency_task = self._recency_scorer.score(context, query, **kwargs) + priority_task = self._priority_scorer.score(context, query, **kwargs) + + relevance_score, recency_score, priority_score = await asyncio.gather( + relevance_task, recency_task, priority_task + ) + + # Compute weighted composite + total_weight = ( + self._relevance_weight + self._recency_weight + self._priority_weight + ) + + if total_weight > 0: + composite = ( + relevance_score * self._relevance_weight + + recency_score * self._recency_weight + + priority_score * self._priority_weight + ) / total_weight + else: + composite = 0.0 + + # Cache the score on the context + context._score = composite + + return ScoredContext( + context=context, + composite_score=composite, + relevance_score=relevance_score, + recency_score=recency_score, + priority_score=priority_score, + ) + + async def score_batch( + self, + contexts: list[BaseContext], + query: str, + parallel: bool = True, + **kwargs: Any, + ) -> list[ScoredContext]: + """ + Score multiple contexts. + + Args: + contexts: Contexts to score + query: Query to score against + parallel: Whether to score in parallel + **kwargs: Additional scoring parameters + + Returns: + List of ScoredContext (same order as input) + """ + if parallel: + tasks = [ + self.score_with_details(ctx, query, **kwargs) for ctx in contexts + ] + return await asyncio.gather(*tasks) + else: + results = [] + for ctx in contexts: + scored = await self.score_with_details(ctx, query, **kwargs) + results.append(scored) + return results + + async def rank( + self, + contexts: list[BaseContext], + query: str, + limit: int | None = None, + min_score: float = 0.0, + **kwargs: Any, + ) -> list[ScoredContext]: + """ + Score and rank contexts. + + Args: + contexts: Contexts to rank + query: Query to rank against + limit: Maximum number of results + min_score: Minimum score threshold + **kwargs: Additional scoring parameters + + Returns: + Sorted list of ScoredContext (highest first) + """ + # Score all contexts + scored = await self.score_batch(contexts, query, **kwargs) + + # Filter by minimum score + if min_score > 0: + scored = [s for s in scored if s.composite_score >= min_score] + + # Sort by score (highest first) + scored.sort(reverse=True) + + # Apply limit + if limit is not None: + scored = scored[:limit] + + return scored diff --git a/backend/app/services/context/scoring/priority.py b/backend/app/services/context/scoring/priority.py new file mode 100644 index 0000000..4523fed --- /dev/null +++ b/backend/app/services/context/scoring/priority.py @@ -0,0 +1,135 @@ +""" +Priority Scorer for Context Management. + +Scores context based on assigned priority levels. +""" + +from typing import Any + +from .base import BaseScorer +from ..types import BaseContext, ContextPriority, ContextType + + +class PriorityScorer(BaseScorer): + """ + Scores context based on priority levels. + + Converts priority enum values to normalized scores. + Also applies type-based priority bonuses. 
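Editor's note: to make the combination above concrete, score_with_details reduces to a weighted average of the three sub-scores. A small worked example, assuming the default weights asserted in the tests below (relevance 0.5, recency 0.3, priority 0.2); the sub-score values are hypothetical.

# Hypothetical sub-scores for one context item.
relevance_score, recency_score, priority_score = 0.8, 0.9, 0.65
w_relevance, w_recency, w_priority = 0.5, 0.3, 0.2  # defaults per the tests

total_weight = w_relevance + w_recency + w_priority  # 1.0
composite = (
    relevance_score * w_relevance
    + recency_score * w_recency
    + priority_score * w_priority
) / total_weight
print(composite)  # 0.40 + 0.27 + 0.13 = 0.80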
+ """ + + # Default priority bonuses by context type + DEFAULT_TYPE_BONUSES: dict[ContextType, float] = { + ContextType.SYSTEM: 0.2, # System prompts get a boost + ContextType.TASK: 0.15, # Current task is important + ContextType.TOOL: 0.1, # Recent tool results matter + ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance + ContextType.CONVERSATION: 0.0, # Conversation scored by recency + } + + def __init__( + self, + weight: float = 1.0, + type_bonuses: dict[ContextType, float] | None = None, + ) -> None: + """ + Initialize priority scorer. + + Args: + weight: Scorer weight for composite scoring + type_bonuses: Optional context-type priority bonuses + """ + super().__init__(weight) + self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy() + + async def score( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> float: + """ + Score context based on priority. + + Args: + context: Context to score + query: Query (not used for priority, kept for interface) + **kwargs: Additional parameters + + Returns: + Priority score between 0.0 and 1.0 + """ + # Get base priority score + priority_value = context.priority + base_score = self._priority_to_score(priority_value) + + # Apply type bonus + context_type = context.get_type() + bonus = self._type_bonuses.get(context_type, 0.0) + + return self.normalize_score(base_score + bonus) + + def _priority_to_score(self, priority: int) -> float: + """ + Convert priority value to normalized score. + + Priority values (from ContextPriority): + - CRITICAL (100) -> 1.0 + - HIGH (80) -> 0.8 + - NORMAL (50) -> 0.5 + - LOW (20) -> 0.2 + - MINIMAL (0) -> 0.0 + + Args: + priority: Priority value (0-100) + + Returns: + Normalized score (0.0-1.0) + """ + # Clamp to valid range + clamped = max(0, min(100, priority)) + return clamped / 100.0 + + def get_type_bonus(self, context_type: ContextType) -> float: + """ + Get priority bonus for a context type. + + Args: + context_type: Context type + + Returns: + Bonus value + """ + return self._type_bonuses.get(context_type, 0.0) + + def set_type_bonus(self, context_type: ContextType, bonus: float) -> None: + """ + Set priority bonus for a context type. + + Args: + context_type: Context type + bonus: Bonus value (0.0-1.0) + """ + if not 0.0 <= bonus <= 1.0: + raise ValueError("Bonus must be between 0.0 and 1.0") + self._type_bonuses[context_type] = bonus + + async def score_batch( + self, + contexts: list[BaseContext], + query: str, + **kwargs: Any, + ) -> list[float]: + """ + Score multiple contexts. + + Args: + contexts: Contexts to score + query: Query (not used) + **kwargs: Additional parameters + + Returns: + List of scores (same order as input) + """ + # Priority scoring is fast, no async needed + return [await self.score(ctx, query, **kwargs) for ctx in contexts] diff --git a/backend/app/services/context/scoring/recency.py b/backend/app/services/context/scoring/recency.py new file mode 100644 index 0000000..69721e3 --- /dev/null +++ b/backend/app/services/context/scoring/recency.py @@ -0,0 +1,141 @@ +""" +Recency Scorer for Context Management. + +Scores context based on how recent it is. +More recent content gets higher scores. +""" + +import math +from datetime import UTC, datetime, timedelta +from typing import Any + +from .base import BaseScorer +from ..types import BaseContext, ContextType + + +class RecencyScorer(BaseScorer): + """ + Scores context based on recency. + + Uses exponential decay to score content based on age. + More recent content scores higher. 
+ """ + + def __init__( + self, + weight: float = 1.0, + half_life_hours: float = 24.0, + type_half_lives: dict[ContextType, float] | None = None, + ) -> None: + """ + Initialize recency scorer. + + Args: + weight: Scorer weight for composite scoring + half_life_hours: Default hours until score decays to 0.5 + type_half_lives: Optional context-type-specific half lives + """ + super().__init__(weight) + self._half_life_hours = half_life_hours + self._type_half_lives = type_half_lives or {} + + # Set sensible defaults for context types + if ContextType.CONVERSATION not in self._type_half_lives: + self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour + if ContextType.TOOL not in self._type_half_lives: + self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes + if ContextType.KNOWLEDGE not in self._type_half_lives: + self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week + if ContextType.SYSTEM not in self._type_half_lives: + self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days + if ContextType.TASK not in self._type_half_lives: + self._type_half_lives[ContextType.TASK] = 24.0 # 1 day + + async def score( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> float: + """ + Score context based on recency. + + Args: + context: Context to score + query: Query (not used for recency, kept for interface) + **kwargs: Additional parameters + - reference_time: Time to measure recency from (default: now) + + Returns: + Recency score between 0.0 and 1.0 + """ + reference_time = kwargs.get("reference_time") + if reference_time is None: + reference_time = datetime.now(UTC) + elif reference_time.tzinfo is None: + reference_time = reference_time.replace(tzinfo=UTC) + + # Ensure context timestamp is timezone-aware + context_time = context.timestamp + if context_time.tzinfo is None: + context_time = context_time.replace(tzinfo=UTC) + + # Calculate age in hours + age = reference_time - context_time + age_hours = max(0, age.total_seconds() / 3600) + + # Get half-life for this context type + context_type = context.get_type() + half_life = self._type_half_lives.get(context_type, self._half_life_hours) + + # Exponential decay + decay_factor = math.exp(-math.log(2) * age_hours / half_life) + + return self.normalize_score(decay_factor) + + def get_half_life(self, context_type: ContextType) -> float: + """ + Get half-life for a context type. + + Args: + context_type: Context type to get half-life for + + Returns: + Half-life in hours + """ + return self._type_half_lives.get(context_type, self._half_life_hours) + + def set_half_life(self, context_type: ContextType, hours: float) -> None: + """ + Set half-life for a context type. + + Args: + context_type: Context type to set half-life for + hours: Half-life in hours + """ + if hours <= 0: + raise ValueError("Half-life must be positive") + self._type_half_lives[context_type] = hours + + async def score_batch( + self, + contexts: list[BaseContext], + query: str, + **kwargs: Any, + ) -> list[float]: + """ + Score multiple contexts. 
+ + Args: + contexts: Contexts to score + query: Query (not used) + **kwargs: Additional parameters + + Returns: + List of scores (same order as input) + """ + scores = [] + for context in contexts: + score = await self.score(context, query, **kwargs) + scores.append(score) + return scores diff --git a/backend/app/services/context/scoring/relevance.py b/backend/app/services/context/scoring/relevance.py new file mode 100644 index 0000000..4352b4e --- /dev/null +++ b/backend/app/services/context/scoring/relevance.py @@ -0,0 +1,188 @@ +""" +Relevance Scorer for Context Management. + +Scores context based on semantic similarity to the query. +Uses Knowledge Base embeddings when available. +""" + +import logging +import re +from typing import TYPE_CHECKING, Any + +from .base import BaseScorer +from ..types import BaseContext, ContextType, KnowledgeContext + +if TYPE_CHECKING: + from app.services.mcp.client_manager import MCPClientManager + +logger = logging.getLogger(__name__) + + +class RelevanceScorer(BaseScorer): + """ + Scores context based on relevance to query. + + Uses multiple strategies: + 1. Pre-computed scores (from RAG results) + 2. MCP-based semantic similarity (via Knowledge Base) + 3. Keyword matching fallback + """ + + def __init__( + self, + mcp_manager: "MCPClientManager | None" = None, + weight: float = 1.0, + keyword_fallback_weight: float = 0.5, + ) -> None: + """ + Initialize relevance scorer. + + Args: + mcp_manager: MCP manager for Knowledge Base calls + weight: Scorer weight for composite scoring + keyword_fallback_weight: Max score for keyword-based fallback + """ + super().__init__(weight) + self._mcp = mcp_manager + self._keyword_fallback_weight = keyword_fallback_weight + + def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: + """Set MCP manager for semantic scoring.""" + self._mcp = mcp_manager + + async def score( + self, + context: BaseContext, + query: str, + **kwargs: Any, + ) -> float: + """ + Score context relevance to query. + + Args: + context: Context to score + query: Query to score against + **kwargs: Additional parameters + + Returns: + Relevance score between 0.0 and 1.0 + """ + # 1. Check for pre-computed relevance score + if isinstance(context, KnowledgeContext) and context.relevance_score is not None: + return self.normalize_score(context.relevance_score) + + # 2. Check metadata for score + if "relevance_score" in context.metadata: + return self.normalize_score(context.metadata["relevance_score"]) + + if "score" in context.metadata: + return self.normalize_score(context.metadata["score"]) + + # 3. Try MCP-based semantic similarity + if self._mcp is not None: + try: + score = await self._compute_semantic_similarity(context, query) + if score is not None: + return score + except Exception as e: + logger.debug(f"Semantic scoring failed, using fallback: {e}") + + # 4. Fall back to keyword matching + return self._compute_keyword_score(context, query) + + async def _compute_semantic_similarity( + self, + context: BaseContext, + query: str, + ) -> float | None: + """ + Compute semantic similarity using Knowledge Base embeddings. 
+ + Args: + context: Context to score + query: Query to compare + + Returns: + Similarity score or None if unavailable + """ + try: + # Use Knowledge Base's search capability to compute similarity + result = await self._mcp.call_tool( + server="knowledge-base", + tool="compute_similarity", + args={ + "text1": query, + "text2": context.content[:2000], # Limit content length + }, + ) + + if result.success and result.data: + similarity = result.data.get("similarity") + if similarity is not None: + return self.normalize_score(float(similarity)) + + except Exception as e: + logger.debug(f"Semantic similarity computation failed: {e}") + + return None + + def _compute_keyword_score( + self, + context: BaseContext, + query: str, + ) -> float: + """ + Compute relevance score based on keyword matching. + + Simple but fast fallback when semantic search is unavailable. + + Args: + context: Context to score + query: Query to match + + Returns: + Keyword-based relevance score + """ + if not query or not context.content: + return 0.0 + + # Extract keywords from query + query_lower = query.lower() + content_lower = context.content.lower() + + # Simple word tokenization + query_words = set(re.findall(r"\b\w{3,}\b", query_lower)) + content_words = set(re.findall(r"\b\w{3,}\b", content_lower)) + + if not query_words: + return 0.0 + + # Calculate overlap + common_words = query_words & content_words + overlap_ratio = len(common_words) / len(query_words) + + # Apply fallback weight ceiling + return self.normalize_score(overlap_ratio * self._keyword_fallback_weight) + + async def score_batch( + self, + contexts: list[BaseContext], + query: str, + **kwargs: Any, + ) -> list[float]: + """ + Score multiple contexts. + + Args: + contexts: Contexts to score + query: Query to score against + **kwargs: Additional parameters + + Returns: + List of scores (same order as input) + """ + scores = [] + for context in contexts: + score = await self.score(context, query, **kwargs) + scores.append(score) + return scores diff --git a/backend/app/services/context/types/task.py b/backend/app/services/context/types/task.py index e1765f2..ce6f784 100644 --- a/backend/app/services/context/types/task.py +++ b/backend/app/services/context/types/task.py @@ -55,11 +55,9 @@ class TaskContext(BaseContext): constraints: list[str] = field(default_factory=list) parent_task_id: str | None = field(default=None) - def __post_init__(self) -> None: - """Set high priority for task context.""" - # Task context defaults to high priority - if self.priority == ContextPriority.NORMAL.value: - self.priority = ContextPriority.HIGH.value + # Note: TaskContext should typically have HIGH priority, + # but we don't auto-promote to allow explicit priority setting. + # Use TaskContext.create() for default HIGH priority behavior. 
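Editor's note: a quick hand-run of the keyword fallback in RelevanceScorer._compute_keyword_score above; it is the fraction of query words (three or more characters) that also appear in the content, scaled by keyword_fallback_weight. The query and content strings are illustrative.

import re

query = "JWT authentication"
content = "Implement authentication with JWT tokens"

query_words = set(re.findall(r"\b\w{3,}\b", query.lower()))    # {"jwt", "authentication"}
content_words = set(re.findall(r"\b\w{3,}\b", content.lower()))

overlap_ratio = len(query_words & content_words) / len(query_words)  # 2 / 2 = 1.0
print(overlap_ratio * 0.5)  # 0.5 with the default keyword_fallback_weight of 0.5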
def get_type(self) -> ContextType: """Return TASK context type.""" diff --git a/backend/tests/services/context/test_ranker.py b/backend/tests/services/context/test_ranker.py new file mode 100644 index 0000000..adf876c --- /dev/null +++ b/backend/tests/services/context/test_ranker.py @@ -0,0 +1,507 @@ +"""Tests for context ranking module.""" + +from datetime import UTC, datetime + +import pytest + +from app.services.context.budget import BudgetAllocator, TokenBudget +from app.services.context.prioritization import ContextRanker, RankingResult +from app.services.context.scoring import CompositeScorer, ScoredContext +from app.services.context.types import ( + ContextPriority, + ContextType, + ConversationContext, + KnowledgeContext, + MessageRole, + SystemContext, + TaskContext, +) + + +class TestRankingResult: + """Tests for RankingResult dataclass.""" + + def test_creation(self) -> None: + """Test RankingResult creation.""" + ctx = TaskContext(content="Test", source="task") + scored = ScoredContext(context=ctx, composite_score=0.8) + + result = RankingResult( + selected=[scored], + excluded=[], + total_tokens=100, + selection_stats={"total": 1}, + ) + + assert len(result.selected) == 1 + assert result.total_tokens == 100 + + def test_selected_contexts_property(self) -> None: + """Test selected_contexts property extracts contexts.""" + ctx1 = TaskContext(content="Test 1", source="task") + ctx2 = TaskContext(content="Test 2", source="task") + + scored1 = ScoredContext(context=ctx1, composite_score=0.8) + scored2 = ScoredContext(context=ctx2, composite_score=0.6) + + result = RankingResult( + selected=[scored1, scored2], + excluded=[], + total_tokens=200, + ) + + selected = result.selected_contexts + assert len(selected) == 2 + assert ctx1 in selected + assert ctx2 in selected + + +class TestContextRanker: + """Tests for ContextRanker.""" + + def test_creation(self) -> None: + """Test ranker creation.""" + ranker = ContextRanker() + assert ranker._scorer is not None + assert ranker._calculator is not None + + def test_creation_with_scorer(self) -> None: + """Test ranker creation with custom scorer.""" + scorer = CompositeScorer(relevance_weight=0.8) + ranker = ContextRanker(scorer=scorer) + assert ranker._scorer is scorer + + @pytest.mark.asyncio + async def test_rank_empty_contexts(self) -> None: + """Test ranking empty context list.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(10000) + + result = await ranker.rank([], "query", budget) + + assert len(result.selected) == 0 + assert len(result.excluded) == 0 + assert result.total_tokens == 0 + + @pytest.mark.asyncio + async def test_rank_single_context_fits(self) -> None: + """Test ranking single context that fits budget.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(10000) + + context = KnowledgeContext( + content="Short content", + source="docs", + relevance_score=0.8, + ) + + result = await ranker.rank([context], "query", budget) + + assert len(result.selected) == 1 + assert len(result.excluded) == 0 + assert result.selected[0].context is context + + @pytest.mark.asyncio + async def test_rank_respects_budget(self) -> None: + """Test that ranking respects token budget.""" + ranker = ContextRanker() + + # Create a very small budget + budget = TokenBudget( + total=100, + knowledge=50, # Only 50 tokens for knowledge + ) + + # Create contexts that exceed budget + contexts = [ + KnowledgeContext( + content="A" * 200, # ~50 tokens + source="docs", 
+ relevance_score=0.9, + ), + KnowledgeContext( + content="B" * 200, # ~50 tokens + source="docs", + relevance_score=0.8, + ), + KnowledgeContext( + content="C" * 200, # ~50 tokens + source="docs", + relevance_score=0.7, + ), + ] + + result = await ranker.rank(contexts, "query", budget) + + # Not all should fit + assert len(result.selected) < len(contexts) + assert len(result.excluded) > 0 + + @pytest.mark.asyncio + async def test_rank_selects_highest_scores(self) -> None: + """Test that ranking selects highest scored contexts.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(1000) + + # Small budget for knowledge + budget.knowledge = 100 + + contexts = [ + KnowledgeContext( + content="Low score", + source="docs", + relevance_score=0.2, + ), + KnowledgeContext( + content="High score", + source="docs", + relevance_score=0.9, + ), + KnowledgeContext( + content="Medium score", + source="docs", + relevance_score=0.5, + ), + ] + + result = await ranker.rank(contexts, "query", budget) + + # Should have selected some + if result.selected: + # The highest scored should be selected first + scores = [s.composite_score for s in result.selected] + assert scores == sorted(scores, reverse=True) + + @pytest.mark.asyncio + async def test_rank_critical_priority_always_included(self) -> None: + """Test that CRITICAL priority contexts are always included.""" + ranker = ContextRanker() + + # Very small budget + budget = TokenBudget( + total=100, + system=10, # Very small + knowledge=10, + ) + + contexts = [ + SystemContext( + content="Critical system prompt that must be included", + source="system", + priority=ContextPriority.CRITICAL.value, + ), + KnowledgeContext( + content="Optional knowledge", + source="docs", + relevance_score=0.9, + ), + ] + + result = await ranker.rank(contexts, "query", budget, ensure_required=True) + + # Critical context should be in selected + selected_priorities = [s.context.priority for s in result.selected] + assert ContextPriority.CRITICAL.value in selected_priorities + + @pytest.mark.asyncio + async def test_rank_without_ensure_required(self) -> None: + """Test ranking without ensuring required contexts.""" + ranker = ContextRanker() + + budget = TokenBudget( + total=100, + system=50, + knowledge=50, + ) + + contexts = [ + SystemContext( + content="A" * 500, # Large content + source="system", + priority=ContextPriority.CRITICAL.value, + ), + KnowledgeContext( + content="Short", + source="docs", + relevance_score=0.9, + ), + ] + + result = await ranker.rank( + contexts, "query", budget, ensure_required=False + ) + + # Without ensure_required, CRITICAL contexts can be excluded + # if budget doesn't allow + assert len(result.selected) + len(result.excluded) == len(contexts) + + @pytest.mark.asyncio + async def test_rank_selection_stats(self) -> None: + """Test that ranking provides useful statistics.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(10000) + + contexts = [ + KnowledgeContext( + content="Knowledge 1", source="docs", relevance_score=0.8 + ), + KnowledgeContext( + content="Knowledge 2", source="docs", relevance_score=0.6 + ), + TaskContext(content="Task", source="task"), + ] + + result = await ranker.rank(contexts, "query", budget) + + stats = result.selection_stats + assert "total_contexts" in stats + assert "selected_count" in stats + assert "excluded_count" in stats + assert "total_tokens" in stats + assert "by_type" in stats + + @pytest.mark.asyncio + async def 
test_rank_simple(self) -> None: + """Test simple ranking without budget per type.""" + ranker = ContextRanker() + + contexts = [ + KnowledgeContext( + content="A", + source="docs", + relevance_score=0.9, + ), + KnowledgeContext( + content="B", + source="docs", + relevance_score=0.7, + ), + KnowledgeContext( + content="C", + source="docs", + relevance_score=0.5, + ), + ] + + result = await ranker.rank_simple(contexts, "query", max_tokens=1000) + + # Should return contexts sorted by score + assert len(result) > 0 + + @pytest.mark.asyncio + async def test_rank_simple_respects_max_tokens(self) -> None: + """Test that simple ranking respects max tokens.""" + ranker = ContextRanker() + + # Create contexts with known token counts + contexts = [ + KnowledgeContext( + content="A" * 100, # ~25 tokens + source="docs", + relevance_score=0.9, + ), + KnowledgeContext( + content="B" * 100, + source="docs", + relevance_score=0.8, + ), + KnowledgeContext( + content="C" * 100, + source="docs", + relevance_score=0.7, + ), + ] + + # Very small limit + result = await ranker.rank_simple(contexts, "query", max_tokens=30) + + # Should only fit a limited number + assert len(result) <= len(contexts) + + @pytest.mark.asyncio + async def test_rank_simple_empty(self) -> None: + """Test simple ranking with empty list.""" + ranker = ContextRanker() + + result = await ranker.rank_simple([], "query", max_tokens=1000) + assert result == [] + + @pytest.mark.asyncio + async def test_rerank_for_diversity(self) -> None: + """Test diversity reranking.""" + ranker = ContextRanker() + + # Create scored contexts from same source + contexts = [ + ScoredContext( + context=KnowledgeContext( + content=f"Content {i}", + source="same-source", + relevance_score=0.9 - i * 0.1, + ), + composite_score=0.9 - i * 0.1, + ) + for i in range(5) + ] + + # Limit to 2 per source + result = await ranker.rerank_for_diversity(contexts, max_per_source=2) + + assert len(result) == 5 + # First 2 should be from same source, rest deferred + first_two_sources = [r.context.source for r in result[:2]] + assert all(s == "same-source" for s in first_two_sources) + + @pytest.mark.asyncio + async def test_rerank_for_diversity_multiple_sources(self) -> None: + """Test diversity reranking with multiple sources.""" + ranker = ContextRanker() + + contexts = [ + ScoredContext( + context=KnowledgeContext( + content="Source A - 1", + source="source-a", + relevance_score=0.9, + ), + composite_score=0.9, + ), + ScoredContext( + context=KnowledgeContext( + content="Source A - 2", + source="source-a", + relevance_score=0.8, + ), + composite_score=0.8, + ), + ScoredContext( + context=KnowledgeContext( + content="Source B - 1", + source="source-b", + relevance_score=0.7, + ), + composite_score=0.7, + ), + ScoredContext( + context=KnowledgeContext( + content="Source A - 3", + source="source-a", + relevance_score=0.6, + ), + composite_score=0.6, + ), + ] + + result = await ranker.rerank_for_diversity(contexts, max_per_source=2) + + # Should not have more than 2 from source-a in first 3 + source_a_in_first_3 = sum( + 1 for r in result[:3] if r.context.source == "source-a" + ) + assert source_a_in_first_3 <= 2 + + @pytest.mark.asyncio + async def test_token_counts_set(self) -> None: + """Test that token counts are set during ranking.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(10000) + + context = KnowledgeContext( + content="Test content", + source="docs", + relevance_score=0.8, + ) + + # Token count should be None initially + 
assert context.token_count is None + + await ranker.rank([context], "query", budget) + + # Token count should be set after ranking + assert context.token_count is not None + assert context.token_count > 0 + + +class TestContextRankerIntegration: + """Integration tests for context ranking.""" + + @pytest.mark.asyncio + async def test_full_ranking_workflow(self) -> None: + """Test complete ranking workflow.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(10000) + + # Create diverse context types + contexts = [ + SystemContext( + content="You are a helpful assistant.", + source="system", + priority=ContextPriority.CRITICAL.value, + ), + TaskContext( + content="Help the user with their coding question.", + source="task", + priority=ContextPriority.HIGH.value, + ), + KnowledgeContext( + content="Python is a programming language.", + source="docs/python.md", + relevance_score=0.9, + ), + KnowledgeContext( + content="Java is also a programming language.", + source="docs/java.md", + relevance_score=0.4, + ), + ConversationContext( + content="Hello, can you help me?", + source="chat", + role=MessageRole.USER, + ), + ] + + result = await ranker.rank(contexts, "Python help", budget) + + # System (CRITICAL) should be included + selected_types = [s.context.get_type() for s in result.selected] + assert ContextType.SYSTEM in selected_types + + # Stats should be populated + assert result.selection_stats["total_contexts"] == 5 + + @pytest.mark.asyncio + async def test_ranking_preserves_context_order_by_score(self) -> None: + """Test that ranking orders by score correctly.""" + ranker = ContextRanker() + allocator = BudgetAllocator() + budget = allocator.create_budget(100000) + + contexts = [ + KnowledgeContext( + content="Low", + source="docs", + relevance_score=0.2, + ), + KnowledgeContext( + content="High", + source="docs", + relevance_score=0.9, + ), + KnowledgeContext( + content="Medium", + source="docs", + relevance_score=0.5, + ), + ] + + result = await ranker.rank(contexts, "query", budget) + + # Verify ordering is by score + scores = [s.composite_score for s in result.selected] + assert scores == sorted(scores, reverse=True) diff --git a/backend/tests/services/context/test_scoring.py b/backend/tests/services/context/test_scoring.py new file mode 100644 index 0000000..6fea92f --- /dev/null +++ b/backend/tests/services/context/test_scoring.py @@ -0,0 +1,712 @@ +"""Tests for context scoring module.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.services.context.scoring import ( + BaseScorer, + CompositeScorer, + PriorityScorer, + RecencyScorer, + RelevanceScorer, + ScoredContext, +) +from app.services.context.types import ( + ContextPriority, + ContextType, + ConversationContext, + KnowledgeContext, + MessageRole, + SystemContext, + TaskContext, +) + + +class TestRelevanceScorer: + """Tests for RelevanceScorer.""" + + def test_creation(self) -> None: + """Test scorer creation.""" + scorer = RelevanceScorer() + assert scorer.weight == 1.0 + + def test_creation_with_weight(self) -> None: + """Test scorer creation with custom weight.""" + scorer = RelevanceScorer(weight=0.5) + assert scorer.weight == 0.5 + + @pytest.mark.asyncio + async def test_score_with_precomputed_relevance(self) -> None: + """Test scoring with pre-computed relevance score.""" + scorer = RelevanceScorer() + + # KnowledgeContext with pre-computed score + context = KnowledgeContext( + content="Test content 
about Python", + source="docs/python.md", + relevance_score=0.85, + ) + + score = await scorer.score(context, "Python programming") + assert score == 0.85 + + @pytest.mark.asyncio + async def test_score_with_metadata_score(self) -> None: + """Test scoring with metadata-provided score.""" + scorer = RelevanceScorer() + + context = SystemContext( + content="System prompt", + source="system", + metadata={"relevance_score": 0.9}, + ) + + score = await scorer.score(context, "anything") + assert score == 0.9 + + @pytest.mark.asyncio + async def test_score_fallback_to_keyword_matching(self) -> None: + """Test fallback to keyword matching when no score available.""" + scorer = RelevanceScorer(keyword_fallback_weight=0.5) + + context = TaskContext( + content="Implement authentication with JWT tokens", + source="task", + ) + + # Query has matching keywords + score = await scorer.score(context, "JWT authentication") + assert score > 0 + + @pytest.mark.asyncio + async def test_keyword_matching_no_overlap(self) -> None: + """Test keyword matching with no query overlap.""" + scorer = RelevanceScorer() + + context = TaskContext( + content="Implement database migration", + source="task", + ) + + score = await scorer.score(context, "xyz abc 123") + assert score == 0.0 + + @pytest.mark.asyncio + async def test_keyword_matching_full_overlap(self) -> None: + """Test keyword matching with high overlap.""" + scorer = RelevanceScorer(keyword_fallback_weight=1.0) + + context = TaskContext( + content="python programming language", + source="task", + ) + + score = await scorer.score(context, "python programming") + # Should have high score due to keyword overlap + assert score > 0.5 + + @pytest.mark.asyncio + async def test_score_with_mcp_success(self) -> None: + """Test scoring with MCP semantic similarity.""" + mock_mcp = MagicMock() + mock_result = MagicMock() + mock_result.success = True + mock_result.data = {"similarity": 0.75} + mock_mcp.call_tool = AsyncMock(return_value=mock_result) + + scorer = RelevanceScorer(mcp_manager=mock_mcp) + + context = TaskContext( + content="Test content", + source="task", + ) + + score = await scorer.score(context, "test query") + assert score == 0.75 + + @pytest.mark.asyncio + async def test_score_with_mcp_failure_fallback(self) -> None: + """Test fallback when MCP fails.""" + mock_mcp = MagicMock() + mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed")) + + scorer = RelevanceScorer(mcp_manager=mock_mcp, keyword_fallback_weight=0.5) + + context = TaskContext( + content="Python programming code", + source="task", + ) + + # Should fall back to keyword matching + score = await scorer.score(context, "Python code") + assert score > 0 + + @pytest.mark.asyncio + async def test_score_batch(self) -> None: + """Test batch scoring.""" + scorer = RelevanceScorer() + + contexts = [ + KnowledgeContext( + content="Python", source="1", relevance_score=0.8 + ), + KnowledgeContext( + content="Java", source="2", relevance_score=0.6 + ), + KnowledgeContext( + content="Go", source="3", relevance_score=0.9 + ), + ] + + scores = await scorer.score_batch(contexts, "test") + assert len(scores) == 3 + assert scores[0] == 0.8 + assert scores[1] == 0.6 + assert scores[2] == 0.9 + + def test_set_mcp_manager(self) -> None: + """Test setting MCP manager.""" + scorer = RelevanceScorer() + assert scorer._mcp is None + + mock_mcp = MagicMock() + scorer.set_mcp_manager(mock_mcp) + assert scorer._mcp is mock_mcp + + +class TestRecencyScorer: + """Tests for RecencyScorer.""" + + def 
test_creation(self) -> None: + """Test scorer creation.""" + scorer = RecencyScorer() + assert scorer.weight == 1.0 + assert scorer._half_life_hours == 24.0 + + def test_creation_with_custom_half_life(self) -> None: + """Test scorer creation with custom half-life.""" + scorer = RecencyScorer(half_life_hours=12.0) + assert scorer._half_life_hours == 12.0 + + @pytest.mark.asyncio + async def test_score_recent_context(self) -> None: + """Test scoring a very recent context.""" + scorer = RecencyScorer(half_life_hours=24.0) + now = datetime.now(UTC) + + context = TaskContext( + content="Recent task", + source="task", + timestamp=now, + ) + + score = await scorer.score(context, "query", reference_time=now) + # Very recent should have score near 1.0 + assert score > 0.99 + + @pytest.mark.asyncio + async def test_score_at_half_life(self) -> None: + """Test scoring at exactly half-life age.""" + scorer = RecencyScorer(half_life_hours=24.0) + now = datetime.now(UTC) + half_life_ago = now - timedelta(hours=24) + + context = TaskContext( + content="Day old task", + source="task", + timestamp=half_life_ago, + ) + + score = await scorer.score(context, "query", reference_time=now) + # At half-life, score should be ~0.5 + assert 0.49 <= score <= 0.51 + + @pytest.mark.asyncio + async def test_score_old_context(self) -> None: + """Test scoring a very old context.""" + scorer = RecencyScorer(half_life_hours=24.0) + now = datetime.now(UTC) + week_ago = now - timedelta(days=7) + + context = TaskContext( + content="Week old task", + source="task", + timestamp=week_ago, + ) + + score = await scorer.score(context, "query", reference_time=now) + # 7 days with 24h half-life = very low score + assert score < 0.01 + + @pytest.mark.asyncio + async def test_type_specific_half_lives(self) -> None: + """Test that different context types have different half-lives.""" + scorer = RecencyScorer() + now = datetime.now(UTC) + one_hour_ago = now - timedelta(hours=1) + + # Conversation has 1 hour half-life by default + conv_context = ConversationContext( + content="Hello", + source="chat", + role=MessageRole.USER, + timestamp=one_hour_ago, + ) + + # Knowledge has 168 hour (1 week) half-life by default + knowledge_context = KnowledgeContext( + content="Documentation", + source="docs", + timestamp=one_hour_ago, + ) + + conv_score = await scorer.score(conv_context, "query", reference_time=now) + knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now) + + # Conversation should decay much faster + assert conv_score < knowledge_score + + def test_get_half_life(self) -> None: + """Test getting half-life for context type.""" + scorer = RecencyScorer() + + assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0 + assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0 + assert scorer.get_half_life(ContextType.SYSTEM) == 720.0 + + def test_set_half_life(self) -> None: + """Test setting custom half-life.""" + scorer = RecencyScorer() + + scorer.set_half_life(ContextType.TASK, 48.0) + assert scorer.get_half_life(ContextType.TASK) == 48.0 + + def test_set_half_life_invalid(self) -> None: + """Test setting invalid half-life.""" + scorer = RecencyScorer() + + with pytest.raises(ValueError): + scorer.set_half_life(ContextType.TASK, 0) + + with pytest.raises(ValueError): + scorer.set_half_life(ContextType.TASK, -1) + + @pytest.mark.asyncio + async def test_score_batch(self) -> None: + """Test batch scoring.""" + scorer = RecencyScorer() + now = datetime.now(UTC) + + contexts = [ + TaskContext(content="1", 
source="t", timestamp=now), + TaskContext( + content="2", source="t", timestamp=now - timedelta(hours=24) + ), + TaskContext( + content="3", source="t", timestamp=now - timedelta(hours=48) + ), + ] + + scores = await scorer.score_batch(contexts, "query", reference_time=now) + assert len(scores) == 3 + # Scores should be in descending order (more recent = higher) + assert scores[0] > scores[1] > scores[2] + + +class TestPriorityScorer: + """Tests for PriorityScorer.""" + + def test_creation(self) -> None: + """Test scorer creation.""" + scorer = PriorityScorer() + assert scorer.weight == 1.0 + + @pytest.mark.asyncio + async def test_score_critical_priority(self) -> None: + """Test scoring CRITICAL priority context.""" + scorer = PriorityScorer() + + context = SystemContext( + content="Critical system prompt", + source="system", + priority=ContextPriority.CRITICAL.value, + ) + + score = await scorer.score(context, "query") + # CRITICAL (100) + type bonus should be > 1.0, normalized to 1.0 + assert score == 1.0 + + @pytest.mark.asyncio + async def test_score_normal_priority(self) -> None: + """Test scoring NORMAL priority context.""" + scorer = PriorityScorer() + + context = TaskContext( + content="Normal task", + source="task", + priority=ContextPriority.NORMAL.value, + ) + + score = await scorer.score(context, "query") + # NORMAL (50) = 0.5, plus TASK bonus (0.15) = 0.65 + assert 0.6 <= score <= 0.7 + + @pytest.mark.asyncio + async def test_score_low_priority(self) -> None: + """Test scoring LOW priority context.""" + scorer = PriorityScorer() + + context = KnowledgeContext( + content="Low priority knowledge", + source="docs", + priority=ContextPriority.LOW.value, + ) + + score = await scorer.score(context, "query") + # LOW (20) = 0.2, no bonus for KNOWLEDGE + assert 0.15 <= score <= 0.25 + + @pytest.mark.asyncio + async def test_type_bonuses(self) -> None: + """Test type-specific priority bonuses.""" + scorer = PriorityScorer() + + # All with same base priority + system_ctx = SystemContext( + content="System", + source="system", + priority=50, + ) + task_ctx = TaskContext( + content="Task", + source="task", + priority=50, + ) + knowledge_ctx = KnowledgeContext( + content="Knowledge", + source="docs", + priority=50, + ) + + system_score = await scorer.score(system_ctx, "query") + task_score = await scorer.score(task_ctx, "query") + knowledge_score = await scorer.score(knowledge_ctx, "query") + + # System has highest bonus (0.2), task next (0.15), knowledge has none + assert system_score > task_score > knowledge_score + + def test_get_type_bonus(self) -> None: + """Test getting type bonus.""" + scorer = PriorityScorer() + + assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2 + assert scorer.get_type_bonus(ContextType.TASK) == 0.15 + assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0 + + def test_set_type_bonus(self) -> None: + """Test setting custom type bonus.""" + scorer = PriorityScorer() + + scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1) + assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1 + + def test_set_type_bonus_invalid(self) -> None: + """Test setting invalid type bonus.""" + scorer = PriorityScorer() + + with pytest.raises(ValueError): + scorer.set_type_bonus(ContextType.KNOWLEDGE, 1.5) + + with pytest.raises(ValueError): + scorer.set_type_bonus(ContextType.KNOWLEDGE, -0.1) + + +class TestCompositeScorer: + """Tests for CompositeScorer.""" + + def test_creation(self) -> None: + """Test scorer creation with default weights.""" + scorer = CompositeScorer() + + 
weights = scorer.weights + assert weights["relevance"] == 0.5 + assert weights["recency"] == 0.3 + assert weights["priority"] == 0.2 + + def test_creation_with_custom_weights(self) -> None: + """Test scorer creation with custom weights.""" + scorer = CompositeScorer( + relevance_weight=0.6, + recency_weight=0.2, + priority_weight=0.2, + ) + + weights = scorer.weights + assert weights["relevance"] == 0.6 + assert weights["recency"] == 0.2 + assert weights["priority"] == 0.2 + + def test_update_weights(self) -> None: + """Test updating weights.""" + scorer = CompositeScorer() + + scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1) + + weights = scorer.weights + assert weights["relevance"] == 0.7 + assert weights["recency"] == 0.2 + assert weights["priority"] == 0.1 + + def test_update_weights_partial(self) -> None: + """Test partially updating weights.""" + scorer = CompositeScorer() + original_recency = scorer.weights["recency"] + + scorer.update_weights(relevance=0.8) + + assert scorer.weights["relevance"] == 0.8 + assert scorer.weights["recency"] == original_recency + + @pytest.mark.asyncio + async def test_score_basic(self) -> None: + """Test basic composite scoring.""" + scorer = CompositeScorer() + + context = KnowledgeContext( + content="Test content", + source="docs", + relevance_score=0.8, + timestamp=datetime.now(UTC), + priority=ContextPriority.NORMAL.value, + ) + + score = await scorer.score(context, "test query") + assert 0.0 <= score <= 1.0 + + @pytest.mark.asyncio + async def test_score_with_details(self) -> None: + """Test scoring with detailed breakdown.""" + scorer = CompositeScorer() + + context = KnowledgeContext( + content="Test content", + source="docs", + relevance_score=0.8, + timestamp=datetime.now(UTC), + priority=ContextPriority.HIGH.value, + ) + + scored = await scorer.score_with_details(context, "test query") + + assert isinstance(scored, ScoredContext) + assert scored.context is context + assert 0.0 <= scored.composite_score <= 1.0 + assert scored.relevance_score == 0.8 + assert scored.recency_score > 0.9 # Very recent + assert scored.priority_score > 0.5 # HIGH priority + + @pytest.mark.asyncio + async def test_score_cached_on_context(self) -> None: + """Test that score is cached on the context.""" + scorer = CompositeScorer() + + context = KnowledgeContext( + content="Test", + source="docs", + relevance_score=0.5, + ) + + # First scoring + await scorer.score(context, "query") + assert context._score is not None + + # Second scoring should use cached value + context._score = 0.999 # Set to a known value + score2 = await scorer.score(context, "query") + assert score2 == 0.999 + + @pytest.mark.asyncio + async def test_score_batch(self) -> None: + """Test batch scoring.""" + scorer = CompositeScorer() + + contexts = [ + KnowledgeContext( + content="High relevance", + source="docs", + relevance_score=0.9, + ), + KnowledgeContext( + content="Low relevance", + source="docs", + relevance_score=0.2, + ), + ] + + scored = await scorer.score_batch(contexts, "query") + assert len(scored) == 2 + assert scored[0].relevance_score > scored[1].relevance_score + + @pytest.mark.asyncio + async def test_rank(self) -> None: + """Test ranking contexts.""" + scorer = CompositeScorer() + + contexts = [ + KnowledgeContext( + content="Low", source="docs", relevance_score=0.2 + ), + KnowledgeContext( + content="High", source="docs", relevance_score=0.9 + ), + KnowledgeContext( + content="Medium", source="docs", relevance_score=0.5 + ), + ] + + ranked = await 
scorer.rank(contexts, "query") + + # Should be sorted by score (highest first) + assert len(ranked) == 3 + assert ranked[0].relevance_score == 0.9 + assert ranked[1].relevance_score == 0.5 + assert ranked[2].relevance_score == 0.2 + + @pytest.mark.asyncio + async def test_rank_with_limit(self) -> None: + """Test ranking with limit.""" + scorer = CompositeScorer() + + contexts = [ + KnowledgeContext( + content=str(i), source="docs", relevance_score=i / 10 + ) + for i in range(10) + ] + + ranked = await scorer.rank(contexts, "query", limit=3) + assert len(ranked) == 3 + + @pytest.mark.asyncio + async def test_rank_with_min_score(self) -> None: + """Test ranking with minimum score threshold.""" + scorer = CompositeScorer() + + contexts = [ + KnowledgeContext( + content="Low", source="docs", relevance_score=0.1 + ), + KnowledgeContext( + content="High", source="docs", relevance_score=0.9 + ), + ] + + ranked = await scorer.rank(contexts, "query", min_score=0.5) + + # Only the high relevance context should pass the threshold + assert len(ranked) <= 2 # Could be 1 if min_score filters + + def test_set_mcp_manager(self) -> None: + """Test setting MCP manager.""" + scorer = CompositeScorer() + mock_mcp = MagicMock() + + scorer.set_mcp_manager(mock_mcp) + assert scorer._relevance_scorer._mcp is mock_mcp + + +class TestScoredContext: + """Tests for ScoredContext dataclass.""" + + def test_creation(self) -> None: + """Test ScoredContext creation.""" + context = TaskContext(content="Test", source="task") + scored = ScoredContext( + context=context, + composite_score=0.75, + relevance_score=0.8, + recency_score=0.7, + priority_score=0.5, + ) + + assert scored.context is context + assert scored.composite_score == 0.75 + + def test_comparison_operators(self) -> None: + """Test comparison operators for sorting.""" + ctx1 = TaskContext(content="1", source="task") + ctx2 = TaskContext(content="2", source="task") + + scored1 = ScoredContext(context=ctx1, composite_score=0.5) + scored2 = ScoredContext(context=ctx2, composite_score=0.8) + + assert scored1 < scored2 + assert scored2 > scored1 + + def test_sorting(self) -> None: + """Test sorting scored contexts.""" + contexts = [ + ScoredContext( + context=TaskContext(content="Low", source="task"), + composite_score=0.3, + ), + ScoredContext( + context=TaskContext(content="High", source="task"), + composite_score=0.9, + ), + ScoredContext( + context=TaskContext(content="Medium", source="task"), + composite_score=0.6, + ), + ] + + sorted_contexts = sorted(contexts, reverse=True) + + assert sorted_contexts[0].composite_score == 0.9 + assert sorted_contexts[1].composite_score == 0.6 + assert sorted_contexts[2].composite_score == 0.3 + + +class TestBaseScorer: + """Tests for BaseScorer abstract class.""" + + def test_weight_property(self) -> None: + """Test weight property.""" + # Use a concrete implementation + scorer = RelevanceScorer(weight=0.7) + assert scorer.weight == 0.7 + + def test_weight_setter_valid(self) -> None: + """Test weight setter with valid values.""" + scorer = RelevanceScorer() + scorer.weight = 0.5 + assert scorer.weight == 0.5 + + def test_weight_setter_invalid(self) -> None: + """Test weight setter with invalid values.""" + scorer = RelevanceScorer() + + with pytest.raises(ValueError): + scorer.weight = -0.1 + + with pytest.raises(ValueError): + scorer.weight = 1.5 + + def test_normalize_score(self) -> None: + """Test score normalization.""" + scorer = RelevanceScorer() + + # Normal range + assert scorer.normalize_score(0.5) == 0.5 + + # 
Below 0 + assert scorer.normalize_score(-0.5) == 0.0 + + # Above 1 + assert scorer.normalize_score(1.5) == 1.0 + + # Boundaries + assert scorer.normalize_score(0.0) == 0.0 + assert scorer.normalize_score(1.0) == 1.0 diff --git a/backend/tests/services/context/test_types.py b/backend/tests/services/context/test_types.py index 0db53cc..ca36566 100644 --- a/backend/tests/services/context/test_types.py +++ b/backend/tests/services/context/test_types.py @@ -286,9 +286,20 @@ class TestTaskContext: assert ctx.title == "Login Feature" assert ctx.get_type() == ContextType.TASK - def test_default_high_priority(self) -> None: - """Test that task context defaults to high priority.""" + def test_default_normal_priority(self) -> None: + """Test that task context uses NORMAL priority from base class.""" ctx = TaskContext(content="Test", source="test") + # TaskContext inherits NORMAL priority from BaseContext + # Use TaskContext.create() for default HIGH priority behavior + assert ctx.priority == ContextPriority.NORMAL.value + + def test_explicit_high_priority(self) -> None: + """Test setting explicit HIGH priority.""" + ctx = TaskContext( + content="Test", + source="test", + priority=ContextPriority.HIGH.value, + ) assert ctx.priority == ContextPriority.HIGH.value def test_create_factory(self) -> None:
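
For reviewers who want to sanity-check the numeric assertions in the scoring tests above without running the suite, here is a minimal standalone sketch of the arithmetic those tests imply. The helper names (priority_component, composite) are illustrative only and are not part of the scoring module; the formulas are inferred from the test comments (priority normalized out of 100 plus a per-type bonus, clamped to [0, 1], and a 0.5/0.3/0.2 relevance/recency/priority blend).

# Standalone sketch (hypothetical helper names, not part of the changeset).
def priority_component(priority: int, type_bonus: float = 0.0) -> float:
    """Normalize a 0-100 priority, add the per-type bonus, clamp to [0, 1]."""
    return min(1.0, max(0.0, priority / 100 + type_bonus))


def composite(relevance: float, recency: float, priority: float) -> float:
    """Blend component scores with the default 0.5/0.3/0.2 weights."""
    return 0.5 * relevance + 0.3 * recency + 0.2 * priority


if __name__ == "__main__":
    # test_score_critical_priority: 100/100 + 0.2 SYSTEM bonus clamps to 1.0
    assert priority_component(100, type_bonus=0.2) == 1.0

    # test_score_normal_priority: 50/100 + 0.15 TASK bonus lands in [0.6, 0.7]
    assert 0.6 <= priority_component(50, type_bonus=0.15) <= 0.7

    # A fresh, highly relevant context (relevance 0.8, recency ~1.0,
    # NORMAL priority, no type bonus) blends to roughly 0.8
    print(composite(relevance=0.8, recency=1.0, priority=priority_component(50)))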