Files
fast-next-template/backend/app/services/context/assembly/pipeline.py
Felipe Cardoso 6b07e62f00 feat(context): implement assembly pipeline and compression (#82)
Phase 4 of Context Management Engine - Assembly Pipeline:

- Add TruncationStrategy with end/middle/sentence-aware truncation
- Add TruncationResult dataclass for tracking compression metrics
- Add ContextCompressor for type-specific compression
- Add ContextPipeline orchestrating full assembly workflow:
  - Token counting for all contexts
  - Scoring and ranking via ContextRanker
  - Optional compression when budget threshold exceeded
  - Model-specific formatting (XML for Claude, markdown for OpenAI)
- Add PipelineMetrics for performance tracking
- Update AssembledContext with new fields (model, contexts, metadata)
- Add backward compatibility aliases for renamed fields

Tests: 34 new tests, 223 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:32:25 +01:00

433 lines
15 KiB
Python

"""
Context Assembly Pipeline.
Orchestrates the full context assembly workflow:
Gather → Count → Score → Rank → Compress → Format
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from html import escape
from typing import TYPE_CHECKING, Any

from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings
from ..exceptions import AssemblyTimeoutError
from ..prioritization import ContextRanker
from ..scoring import CompositeScorer
from ..types import AssembledContext, BaseContext, ContextType

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class PipelineMetrics:
"""Metrics from pipeline execution."""
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
end_time: datetime | None = None
total_contexts: int = 0
selected_contexts: int = 0
excluded_contexts: int = 0
compressed_contexts: int = 0
total_tokens: int = 0
assembly_time_ms: float = 0.0
scoring_time_ms: float = 0.0
compression_time_ms: float = 0.0
formatting_time_ms: float = 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"start_time": self.start_time.isoformat(),
"end_time": self.end_time.isoformat() if self.end_time else None,
"total_contexts": self.total_contexts,
"selected_contexts": self.selected_contexts,
"excluded_contexts": self.excluded_contexts,
"compressed_contexts": self.compressed_contexts,
"total_tokens": self.total_tokens,
"assembly_time_ms": round(self.assembly_time_ms, 2),
"scoring_time_ms": round(self.scoring_time_ms, 2),
"compression_time_ms": round(self.compression_time_ms, 2),
"formatting_time_ms": round(self.formatting_time_ms, 2),
}
class ContextPipeline:
"""
Context assembly pipeline.
Orchestrates the full workflow of context assembly:
1. Validate and count tokens for all contexts
2. Score contexts based on relevance, recency, and priority
3. Rank and select contexts within budget
4. Compress if needed to fit remaining budget
5. Format for the target model
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
settings: ContextSettings | None = None,
calculator: TokenCalculator | None = None,
scorer: CompositeScorer | None = None,
ranker: ContextRanker | None = None,
compressor: ContextCompressor | None = None,
) -> None:
"""
Initialize the context pipeline.
Args:
mcp_manager: MCP client manager for LLM Gateway integration
settings: Context settings
calculator: Token calculator
scorer: Context scorer
ranker: Context ranker
compressor: Context compressor
"""
self._settings = settings or get_context_settings()
self._mcp = mcp_manager
# Initialize components
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
self._scorer = scorer or CompositeScorer(
mcp_manager=mcp_manager, settings=self._settings
)
self._ranker = ranker or ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
self._compressor = compressor or ContextCompressor(
calculator=self._calculator
)
self._allocator = BudgetAllocator(self._settings)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for all components."""
self._mcp = mcp_manager
self._calculator.set_mcp_manager(mcp_manager)
self._scorer.set_mcp_manager(mcp_manager)
async def assemble(
self,
contexts: list[BaseContext],
query: str,
model: str,
max_tokens: int | None = None,
custom_budget: TokenBudget | None = None,
compress: bool = True,
format_output: bool = True,
timeout_ms: int | None = None,
) -> AssembledContext:
"""
Assemble context for an LLM request.
This is the main entry point for context assembly.
Args:
contexts: List of contexts to assemble
query: Query to optimize for
model: Target model name
max_tokens: Maximum total tokens (uses model default if None)
custom_budget: Optional pre-configured budget
compress: Whether to compress oversized contexts
format_output: Whether to format the final output
timeout_ms: Maximum assembly time in milliseconds
Returns:
AssembledContext with optimized content
Raises:
AssemblyTimeoutError: If assembly exceeds timeout
"""
timeout = timeout_ms or self._settings.max_assembly_time_ms
start = time.perf_counter()
metrics = PipelineMetrics(total_contexts=len(contexts))
try:
# Create or use budget
if custom_budget:
budget = custom_budget
elif max_tokens:
budget = self._allocator.create_budget(max_tokens)
else:
budget = self._allocator.create_budget_for_model(model)
# 1. Count tokens for all contexts
await self._ensure_token_counts(contexts, model)
# Check timeout
self._check_timeout(start, timeout, "token counting")
# 2. Score and rank contexts
scoring_start = time.perf_counter()
ranking_result = await self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
)
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
selected_contexts = ranking_result.selected_contexts
metrics.selected_contexts = len(selected_contexts)
metrics.excluded_contexts = len(ranking_result.excluded)
# Check timeout
self._check_timeout(start, timeout, "scoring")
# 3. Compress if needed and enabled
if compress and self._needs_compression(selected_contexts, budget):
compression_start = time.perf_counter()
selected_contexts = await self._compressor.compress_contexts(
selected_contexts, budget, model
)
metrics.compression_time_ms = (
time.perf_counter() - compression_start
) * 1000
metrics.compressed_contexts = sum(
1 for c in selected_contexts if c.metadata.get("truncated", False)
)
# Check timeout
self._check_timeout(start, timeout, "compression")
# 4. Format output
formatting_start = time.perf_counter()
if format_output:
formatted_content = self._format_contexts(selected_contexts, model)
else:
formatted_content = "\n\n".join(c.content for c in selected_contexts)
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
# Calculate final metrics
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
metrics.total_tokens = total_tokens
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
metrics.end_time = datetime.now(UTC)
return AssembledContext(
content=formatted_content,
total_tokens=total_tokens,
context_count=len(selected_contexts),
assembly_time_ms=metrics.assembly_time_ms,
model=model,
contexts=selected_contexts,
excluded_count=metrics.excluded_contexts,
metadata={
"metrics": metrics.to_dict(),
"query": query,
"budget": budget.to_dict(),
},
)
except AssemblyTimeoutError:
raise
except Exception as e:
logger.error(f"Context assembly failed: {e}", exc_info=True)
raise
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
model: str | None = None,
) -> None:
"""Ensure all contexts have token counts."""
tasks = []
for context in contexts:
if context.token_count is None:
tasks.append(self._count_and_set(context, model))
if tasks:
await asyncio.gather(*tasks)
async def _count_and_set(
self,
context: BaseContext,
model: str | None = None,
) -> None:
"""Count tokens and set on context."""
count = await self._calculator.count_tokens(context.content, model)
context.token_count = count
def _needs_compression(
self,
contexts: list[BaseContext],
budget: TokenBudget,
) -> bool:
"""Check if any contexts exceed their type budget."""
# Group by type and check totals
by_type: dict[ContextType, int] = {}
for context in contexts:
ct = context.get_type()
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
for ct, total in by_type.items():
if total > budget.get_allocation(ct):
return True
# Also check if utilization exceeds threshold
return budget.utilization() > self._settings.compression_threshold
def _format_contexts(
self,
contexts: list[BaseContext],
model: str,
) -> str:
"""
Format contexts for the target model.
Groups contexts by type and applies model-specific formatting.
"""
# Group by type
by_type: dict[ContextType, list[BaseContext]] = {}
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
type_order = [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
parts: list[str] = []
for ct in type_order:
if ct in by_type:
formatted = self._format_type(by_type[ct], ct, model)
if formatted:
parts.append(formatted)
return "\n\n".join(parts)
def _format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
model: str,
) -> str:
"""Format contexts of a specific type."""
if not contexts:
return ""
# Check if model prefers XML tags (Claude)
use_xml = "claude" in model.lower()
if context_type == ContextType.SYSTEM:
return self._format_system(contexts, use_xml)
elif context_type == ContextType.TASK:
return self._format_task(contexts, use_xml)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts, use_xml)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts, use_xml)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts, use_xml)
return "\n".join(c.content for c in contexts)
def _format_system(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<system_instructions>\n{content}\n</system_instructions>"
return content
def _format_task(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<current_task>\n{content}\n</current_task>"
return f"## Current Task\n\n{content}"
def _format_knowledge(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = ["<reference_documents>"]
for ctx in contexts:
parts.append(f'<document source="{ctx.source}">')
parts.append(ctx.content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
else:
parts = ["## Reference Documents\n"]
for ctx in contexts:
parts.append(f"### Source: {ctx.source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
"""Format conversation contexts."""
if use_xml:
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
else:
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
"""Format tool contexts."""
if use_xml:
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
else:
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)
def _check_timeout(
self,
start: float,
timeout_ms: int,
phase: str,
) -> None:
"""Check if timeout exceeded and raise if so."""
elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms > timeout_ms:
raise AssemblyTimeoutError(
message=f"Context assembly timed out during {phase}",
elapsed_ms=elapsed_ms,
timeout_ms=timeout_ms,
)