""" Context Assembly Pipeline. Orchestrates the full context assembly workflow: Gather → Count → Score → Rank → Compress → Format """ import asyncio import logging import time from dataclasses import dataclass, field from datetime import UTC, datetime from typing import TYPE_CHECKING, Any from ..budget import BudgetAllocator, TokenBudget, TokenCalculator from ..compression.truncation import ContextCompressor from ..config import ContextSettings, get_context_settings from ..exceptions import AssemblyTimeoutError from ..prioritization import ContextRanker from ..scoring import CompositeScorer from ..types import AssembledContext, BaseContext, ContextType if TYPE_CHECKING: from app.services.mcp.client_manager import MCPClientManager logger = logging.getLogger(__name__) @dataclass class PipelineMetrics: """Metrics from pipeline execution.""" start_time: datetime = field(default_factory=lambda: datetime.now(UTC)) end_time: datetime | None = None total_contexts: int = 0 selected_contexts: int = 0 excluded_contexts: int = 0 compressed_contexts: int = 0 total_tokens: int = 0 assembly_time_ms: float = 0.0 scoring_time_ms: float = 0.0 compression_time_ms: float = 0.0 formatting_time_ms: float = 0.0 def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return { "start_time": self.start_time.isoformat(), "end_time": self.end_time.isoformat() if self.end_time else None, "total_contexts": self.total_contexts, "selected_contexts": self.selected_contexts, "excluded_contexts": self.excluded_contexts, "compressed_contexts": self.compressed_contexts, "total_tokens": self.total_tokens, "assembly_time_ms": round(self.assembly_time_ms, 2), "scoring_time_ms": round(self.scoring_time_ms, 2), "compression_time_ms": round(self.compression_time_ms, 2), "formatting_time_ms": round(self.formatting_time_ms, 2), } class ContextPipeline: """ Context assembly pipeline. Orchestrates the full workflow of context assembly: 1. Validate and count tokens for all contexts 2. Score contexts based on relevance, recency, and priority 3. Rank and select contexts within budget 4. Compress if needed to fit remaining budget 5. Format for the target model """ def __init__( self, mcp_manager: "MCPClientManager | None" = None, settings: ContextSettings | None = None, calculator: TokenCalculator | None = None, scorer: CompositeScorer | None = None, ranker: ContextRanker | None = None, compressor: ContextCompressor | None = None, ) -> None: """ Initialize the context pipeline. Args: mcp_manager: MCP client manager for LLM Gateway integration settings: Context settings calculator: Token calculator scorer: Context scorer ranker: Context ranker compressor: Context compressor """ self._settings = settings or get_context_settings() self._mcp = mcp_manager # Initialize components self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager) self._scorer = scorer or CompositeScorer( mcp_manager=mcp_manager, settings=self._settings ) self._ranker = ranker or ContextRanker( scorer=self._scorer, calculator=self._calculator ) self._compressor = compressor or ContextCompressor( calculator=self._calculator ) self._allocator = BudgetAllocator(self._settings) def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: """Set MCP manager for all components.""" self._mcp = mcp_manager self._calculator.set_mcp_manager(mcp_manager) self._scorer.set_mcp_manager(mcp_manager) async def assemble( self, contexts: list[BaseContext], query: str, model: str, max_tokens: int | None = None, custom_budget: TokenBudget | None = None, compress: bool = True, format_output: bool = True, timeout_ms: int | None = None, ) -> AssembledContext: """ Assemble context for an LLM request. This is the main entry point for context assembly. Args: contexts: List of contexts to assemble query: Query to optimize for model: Target model name max_tokens: Maximum total tokens (uses model default if None) custom_budget: Optional pre-configured budget compress: Whether to compress oversized contexts format_output: Whether to format the final output timeout_ms: Maximum assembly time in milliseconds Returns: AssembledContext with optimized content Raises: AssemblyTimeoutError: If assembly exceeds timeout """ timeout = timeout_ms or self._settings.max_assembly_time_ms start = time.perf_counter() metrics = PipelineMetrics(total_contexts=len(contexts)) try: # Create or use budget if custom_budget: budget = custom_budget elif max_tokens: budget = self._allocator.create_budget(max_tokens) else: budget = self._allocator.create_budget_for_model(model) # 1. Count tokens for all contexts await self._ensure_token_counts(contexts, model) # Check timeout self._check_timeout(start, timeout, "token counting") # 2. Score and rank contexts scoring_start = time.perf_counter() ranking_result = await self._ranker.rank( contexts=contexts, query=query, budget=budget, model=model, ) metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000 selected_contexts = ranking_result.selected_contexts metrics.selected_contexts = len(selected_contexts) metrics.excluded_contexts = len(ranking_result.excluded) # Check timeout self._check_timeout(start, timeout, "scoring") # 3. Compress if needed and enabled if compress and self._needs_compression(selected_contexts, budget): compression_start = time.perf_counter() selected_contexts = await self._compressor.compress_contexts( selected_contexts, budget, model ) metrics.compression_time_ms = ( time.perf_counter() - compression_start ) * 1000 metrics.compressed_contexts = sum( 1 for c in selected_contexts if c.metadata.get("truncated", False) ) # Check timeout self._check_timeout(start, timeout, "compression") # 4. Format output formatting_start = time.perf_counter() if format_output: formatted_content = self._format_contexts(selected_contexts, model) else: formatted_content = "\n\n".join(c.content for c in selected_contexts) metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000 # Calculate final metrics total_tokens = sum(c.token_count or 0 for c in selected_contexts) metrics.total_tokens = total_tokens metrics.assembly_time_ms = (time.perf_counter() - start) * 1000 metrics.end_time = datetime.now(UTC) return AssembledContext( content=formatted_content, total_tokens=total_tokens, context_count=len(selected_contexts), assembly_time_ms=metrics.assembly_time_ms, model=model, contexts=selected_contexts, excluded_count=metrics.excluded_contexts, metadata={ "metrics": metrics.to_dict(), "query": query, "budget": budget.to_dict(), }, ) except AssemblyTimeoutError: raise except Exception as e: logger.error(f"Context assembly failed: {e}", exc_info=True) raise async def _ensure_token_counts( self, contexts: list[BaseContext], model: str | None = None, ) -> None: """Ensure all contexts have token counts.""" tasks = [] for context in contexts: if context.token_count is None: tasks.append(self._count_and_set(context, model)) if tasks: await asyncio.gather(*tasks) async def _count_and_set( self, context: BaseContext, model: str | None = None, ) -> None: """Count tokens and set on context.""" count = await self._calculator.count_tokens(context.content, model) context.token_count = count def _needs_compression( self, contexts: list[BaseContext], budget: TokenBudget, ) -> bool: """Check if any contexts exceed their type budget.""" # Group by type and check totals by_type: dict[ContextType, int] = {} for context in contexts: ct = context.get_type() by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0) for ct, total in by_type.items(): if total > budget.get_allocation(ct): return True # Also check if utilization exceeds threshold return budget.utilization() > self._settings.compression_threshold def _format_contexts( self, contexts: list[BaseContext], model: str, ) -> str: """ Format contexts for the target model. Groups contexts by type and applies model-specific formatting. """ # Group by type by_type: dict[ContextType, list[BaseContext]] = {} for context in contexts: ct = context.get_type() if ct not in by_type: by_type[ct] = [] by_type[ct].append(context) # Order types: System -> Task -> Knowledge -> Conversation -> Tool type_order = [ ContextType.SYSTEM, ContextType.TASK, ContextType.KNOWLEDGE, ContextType.CONVERSATION, ContextType.TOOL, ] parts: list[str] = [] for ct in type_order: if ct in by_type: formatted = self._format_type(by_type[ct], ct, model) if formatted: parts.append(formatted) return "\n\n".join(parts) def _format_type( self, contexts: list[BaseContext], context_type: ContextType, model: str, ) -> str: """Format contexts of a specific type.""" if not contexts: return "" # Check if model prefers XML tags (Claude) use_xml = "claude" in model.lower() if context_type == ContextType.SYSTEM: return self._format_system(contexts, use_xml) elif context_type == ContextType.TASK: return self._format_task(contexts, use_xml) elif context_type == ContextType.KNOWLEDGE: return self._format_knowledge(contexts, use_xml) elif context_type == ContextType.CONVERSATION: return self._format_conversation(contexts, use_xml) elif context_type == ContextType.TOOL: return self._format_tool(contexts, use_xml) return "\n".join(c.content for c in contexts) def _format_system( self, contexts: list[BaseContext], use_xml: bool ) -> str: """Format system contexts.""" content = "\n\n".join(c.content for c in contexts) if use_xml: return f"\n{content}\n" return content def _format_task( self, contexts: list[BaseContext], use_xml: bool ) -> str: """Format task contexts.""" content = "\n\n".join(c.content for c in contexts) if use_xml: return f"\n{content}\n" return f"## Current Task\n\n{content}" def _format_knowledge( self, contexts: list[BaseContext], use_xml: bool ) -> str: """Format knowledge contexts.""" if use_xml: parts = [""] for ctx in contexts: parts.append(f'') parts.append(ctx.content) parts.append("") parts.append("") return "\n".join(parts) else: parts = ["## Reference Documents\n"] for ctx in contexts: parts.append(f"### Source: {ctx.source}\n") parts.append(ctx.content) parts.append("") return "\n".join(parts) def _format_conversation( self, contexts: list[BaseContext], use_xml: bool ) -> str: """Format conversation contexts.""" if use_xml: parts = [""] for ctx in contexts: role = ctx.metadata.get("role", "user") parts.append(f'') parts.append(ctx.content) parts.append("") parts.append("") return "\n".join(parts) else: parts = [] for ctx in contexts: role = ctx.metadata.get("role", "user") parts.append(f"**{role.upper()}**: {ctx.content}") return "\n\n".join(parts) def _format_tool( self, contexts: list[BaseContext], use_xml: bool ) -> str: """Format tool contexts.""" if use_xml: parts = [""] for ctx in contexts: tool_name = ctx.metadata.get("tool_name", "unknown") parts.append(f'') parts.append(ctx.content) parts.append("") parts.append("") return "\n".join(parts) else: parts = ["## Recent Tool Results\n"] for ctx in contexts: tool_name = ctx.metadata.get("tool_name", "unknown") parts.append(f"### Tool: {tool_name}\n") parts.append(f"```\n{ctx.content}\n```") parts.append("") return "\n".join(parts) def _check_timeout( self, start: float, timeout_ms: int, phase: str, ) -> None: """Check if timeout exceeded and raise if so.""" elapsed_ms = (time.perf_counter() - start) * 1000 if elapsed_ms > timeout_ms: raise AssemblyTimeoutError( message=f"Context assembly timed out during {phase}", elapsed_ms=elapsed_ms, timeout_ms=timeout_ms, )