""" Composite Scorer for Context Management. Combines multiple scoring strategies with configurable weights. """ import asyncio import logging import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any from ..config import ContextSettings, get_context_settings from ..types import BaseContext from .priority import PriorityScorer from .recency import RecencyScorer from .relevance import RelevanceScorer if TYPE_CHECKING: from app.services.mcp.client_manager import MCPClientManager logger = logging.getLogger(__name__) @dataclass class ScoredContext: """Context with computed scores.""" context: BaseContext composite_score: float relevance_score: float = 0.0 recency_score: float = 0.0 priority_score: float = 0.0 def __lt__(self, other: "ScoredContext") -> bool: """Enable sorting by composite score.""" return self.composite_score < other.composite_score def __gt__(self, other: "ScoredContext") -> bool: """Enable sorting by composite score.""" return self.composite_score > other.composite_score class CompositeScorer: """ Combines multiple scoring strategies. Weights: - relevance: How well content matches the query - recency: How recent the content is - priority: Explicit priority assignments """ def __init__( self, mcp_manager: "MCPClientManager | None" = None, settings: ContextSettings | None = None, relevance_weight: float | None = None, recency_weight: float | None = None, priority_weight: float | None = None, ) -> None: """ Initialize composite scorer. Args: mcp_manager: MCP manager for semantic scoring settings: Context settings (uses default if None) relevance_weight: Override relevance weight recency_weight: Override recency weight priority_weight: Override priority weight """ self._settings = settings or get_context_settings() weights = self._settings.get_scoring_weights() self._relevance_weight = ( relevance_weight if relevance_weight is not None else weights["relevance"] ) self._recency_weight = ( recency_weight if recency_weight is not None else weights["recency"] ) self._priority_weight = ( priority_weight if priority_weight is not None else weights["priority"] ) # Initialize scorers self._relevance_scorer = RelevanceScorer( mcp_manager=mcp_manager, weight=self._relevance_weight, ) self._recency_scorer = RecencyScorer(weight=self._recency_weight) self._priority_scorer = PriorityScorer(weight=self._priority_weight) # Per-context locks to prevent race conditions during parallel scoring # Uses dict with (lock, last_used_time) tuples for cleanup self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {} self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access self._max_locks = 1000 # Maximum locks to keep (prevent memory growth) self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: """Set MCP manager for semantic scoring.""" self._relevance_scorer.set_mcp_manager(mcp_manager) @property def weights(self) -> dict[str, float]: """Get current scoring weights.""" return { "relevance": self._relevance_weight, "recency": self._recency_weight, "priority": self._priority_weight, } def update_weights( self, relevance: float | None = None, recency: float | None = None, priority: float | None = None, ) -> None: """ Update scoring weights. 

        Args:
            relevance: New relevance weight
            recency: New recency weight
            priority: New priority weight
        """
        if relevance is not None:
            self._relevance_weight = max(0.0, min(1.0, relevance))
            self._relevance_scorer.weight = self._relevance_weight
        if recency is not None:
            self._recency_weight = max(0.0, min(1.0, recency))
            self._recency_scorer.weight = self._recency_weight
        if priority is not None:
            self._priority_weight = max(0.0, min(1.0, priority))
            self._priority_scorer.weight = self._priority_weight

    async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
        """
        Get or create a lock for a specific context.

        Coroutine-safe access to per-context locks prevents race conditions
        when the same context is scored concurrently. Includes automatic
        cleanup of old locks to prevent memory growth.

        Args:
            context_id: The context ID to get a lock for

        Returns:
            asyncio.Lock for the context
        """
        now = time.time()

        # Fast path: check if lock exists without acquiring main lock
        # NOTE: We only READ here - no writes to avoid race conditions
        # with cleanup. The timestamp will be updated in the slow path
        # if the lock is still valid.
        lock_entry = self._context_locks.get(context_id)
        if lock_entry is not None:
            lock, _ = lock_entry
            # Return the lock but defer timestamp update to avoid a race.
            # The lock is still valid; the timestamp update is best-effort.
            return lock

        # Slow path: create lock or update timestamp while holding main lock
        async with self._locks_lock:
            # Double-check after acquiring lock - entry may have been
            # created by another coroutine or deleted by cleanup
            lock_entry = self._context_locks.get(context_id)
            if lock_entry is not None:
                lock, _ = lock_entry
                # Safe to update timestamp here since we hold the lock
                self._context_locks[context_id] = (lock, now)
                return lock

            # Cleanup old locks if we have too many
            if len(self._context_locks) >= self._max_locks:
                self._cleanup_old_locks(now)

            # Create new lock
            new_lock = asyncio.Lock()
            self._context_locks[context_id] = (new_lock, now)
            return new_lock

    def _cleanup_old_locks(self, now: float) -> None:
        """
        Remove old locks that haven't been used recently.

        Called while holding _locks_lock. Removes locks older than
        _lock_ttl, but only if they're not currently held.

        Args:
            now: Current timestamp for age calculation
        """
        cutoff = now - self._lock_ttl
        to_remove = []

        for context_id, (lock, last_used) in self._context_locks.items():
            # Only remove if old AND not currently held
            if last_used < cutoff and not lock.locked():
                to_remove.append(context_id)

        # Remove oldest 50% if still over limit after TTL filtering
        if len(self._context_locks) - len(to_remove) >= self._max_locks:
            # Sort by last used time and mark oldest for removal
            sorted_entries = sorted(
                self._context_locks.items(),
                key=lambda x: x[1][1],  # Sort by last_used time
            )
            # Remove oldest 50% that aren't locked
            target_remove = len(self._context_locks) // 2
            for context_id, (lock, _) in sorted_entries:
                if len(to_remove) >= target_remove:
                    break
                if context_id not in to_remove and not lock.locked():
                    to_remove.append(context_id)

        for context_id in to_remove:
            del self._context_locks[context_id]

        if to_remove:
            logger.debug(f"Cleaned up {len(to_remove)} context locks")

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Compute composite score for a context.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Composite score between 0.0 and 1.0
        """
        scored = await self.score_with_details(context, query, **kwargs)
        return scored.composite_score

    async def score_with_details(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> ScoredContext:
        """
        Compute composite score with individual scores.

        Uses per-context locking to prevent race conditions when the same
        context is scored concurrently in parallel scoring operations.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            ScoredContext with all scores
        """
        # Get lock for this specific context to prevent race conditions
        # within concurrent scoring operations for the same query
        context_lock = await self._get_context_lock(context.id)

        async with context_lock:
            # Compute individual scores in parallel.
            # Note: We do NOT cache scores on the context because scores are
            # query-dependent. Caching without considering the query would
            # return incorrect scores for different queries.
            relevance_task = self._relevance_scorer.score(context, query, **kwargs)
            recency_task = self._recency_scorer.score(context, query, **kwargs)
            priority_task = self._priority_scorer.score(context, query, **kwargs)

            relevance_score, recency_score, priority_score = await asyncio.gather(
                relevance_task, recency_task, priority_task
            )

            # Compute weighted composite
            total_weight = (
                self._relevance_weight + self._recency_weight + self._priority_weight
            )
            if total_weight > 0:
                composite = (
                    relevance_score * self._relevance_weight
                    + recency_score * self._recency_weight
                    + priority_score * self._priority_weight
                ) / total_weight
            else:
                composite = 0.0

            return ScoredContext(
                context=context,
                composite_score=composite,
                relevance_score=relevance_score,
                recency_score=recency_score,
                priority_score=priority_score,
            )

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        parallel: bool = True,
        **kwargs: Any,
    ) -> list[ScoredContext]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query to score against
            parallel: Whether to score in parallel
            **kwargs: Additional scoring parameters

        Returns:
            List of ScoredContext (same order as input)
        """
        if parallel:
            tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
            return await asyncio.gather(*tasks)
        else:
            results = []
            for ctx in contexts:
                scored = await self.score_with_details(ctx, query, **kwargs)
                results.append(scored)
            return results

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        limit: int | None = None,
        min_score: float = 0.0,
        **kwargs: Any,
    ) -> list[ScoredContext]:
        """
        Score and rank contexts.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            limit: Maximum number of results
            min_score: Minimum score threshold
            **kwargs: Additional scoring parameters

        Returns:
            Sorted list of ScoredContext (highest first)
        """
        # Score all contexts
        scored = await self.score_batch(contexts, query, **kwargs)

        # Filter by minimum score
        if min_score > 0:
            scored = [s for s in scored if s.composite_score >= min_score]

        # Sort by score (highest first)
        scored.sort(reverse=True)

        # Apply limit
        if limit is not None:
            scored = scored[:limit]

        return scored
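

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; kept as comments so the module stays
# import-only). It shows how the composite ranking above would typically be
# driven: override the weights, then rank a batch of contexts for a query.
# It assumes the default ContextSettings provide usable weights and that you
# already hold BaseContext instances; `load_contexts()` is a hypothetical
# stand-in for however your application gathers them, not part of this module.
# If semantic relevance via MCP is wanted, call scorer.set_mcp_manager(...)
# before ranking.
#
#   import asyncio
#
#   async def demo() -> None:
#       scorer = CompositeScorer()  # weights come from get_context_settings()
#       scorer.update_weights(relevance=0.6, recency=0.3, priority=0.1)
#
#       contexts = load_contexts()  # hypothetical: returns list[BaseContext]
#       ranked = await scorer.rank(
#           contexts,
#           query="database connection pooling",
#           limit=5,
#           min_score=0.2,
#       )
#       for item in ranked:
#           print(f"{item.composite_score:.3f}  {item.context.id}")
#
#   asyncio.run(demo())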