forked from cardosofelipe/fast-next-template
- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
421 lines · 15 KiB · Python
"""
|
|
Context Assembly Pipeline.
|
|
|
|
Orchestrates the full context assembly workflow:
|
|
Gather → Count → Score → Rank → Compress → Format
|
|
"""

import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any

from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings
from ..exceptions import AssemblyTimeoutError
from ..prioritization import ContextRanker
from ..scoring import CompositeScorer
from ..types import AssembledContext, BaseContext, ContextType

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)


@dataclass
class PipelineMetrics:
    """Metrics from pipeline execution."""

    start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
    end_time: datetime | None = None
    total_contexts: int = 0
    selected_contexts: int = 0
    excluded_contexts: int = 0
    compressed_contexts: int = 0
    total_tokens: int = 0
    assembly_time_ms: float = 0.0
    scoring_time_ms: float = 0.0
    compression_time_ms: float = 0.0
    formatting_time_ms: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "total_contexts": self.total_contexts,
            "selected_contexts": self.selected_contexts,
            "excluded_contexts": self.excluded_contexts,
            "compressed_contexts": self.compressed_contexts,
            "total_tokens": self.total_tokens,
            "assembly_time_ms": round(self.assembly_time_ms, 2),
            "scoring_time_ms": round(self.scoring_time_ms, 2),
            "compression_time_ms": round(self.compression_time_ms, 2),
            "formatting_time_ms": round(self.formatting_time_ms, 2),
        }


class ContextPipeline:
    """
    Context assembly pipeline.

    Orchestrates the full workflow of context assembly:
    1. Validate and count tokens for all contexts
    2. Score contexts based on relevance, recency, and priority
    3. Rank and select contexts within budget
    4. Compress if needed to fit remaining budget
    5. Format for the target model
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        settings: ContextSettings | None = None,
        calculator: TokenCalculator | None = None,
        scorer: CompositeScorer | None = None,
        ranker: ContextRanker | None = None,
        compressor: ContextCompressor | None = None,
    ) -> None:
        """
        Initialize the context pipeline.

        Args:
            mcp_manager: MCP client manager for LLM Gateway integration
            settings: Context settings
            calculator: Token calculator
            scorer: Context scorer
            ranker: Context ranker
            compressor: Context compressor
        """
        self._settings = settings or get_context_settings()
        self._mcp = mcp_manager

        # Initialize components
        self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = scorer or CompositeScorer(
            mcp_manager=mcp_manager, settings=self._settings
        )
        self._ranker = ranker or ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._compressor = compressor or ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)
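        # The default ranker and compressor share this pipeline's
        # TokenCalculator, so token counts stay consistent across
        # ranking and compression.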

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for all components."""
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)

    async def assemble(
        self,
        contexts: list[BaseContext],
        query: str,
        model: str,
        max_tokens: int | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        timeout_ms: int | None = None,
    ) -> AssembledContext:
        """
        Assemble context for an LLM request.

        This is the main entry point for context assembly.

        Args:
            contexts: List of contexts to assemble
            query: Query to optimize for
            model: Target model name
            max_tokens: Maximum total tokens (uses model default if None)
            custom_budget: Optional pre-configured budget
            compress: Whether to compress oversized contexts
            format_output: Whether to format the final output
            timeout_ms: Maximum assembly time in milliseconds

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
        """
        timeout = timeout_ms or self._settings.max_assembly_time_ms
        start = time.perf_counter()
        metrics = PipelineMetrics(total_contexts=len(contexts))

        try:
            # Create or use budget
            if custom_budget:
                budget = custom_budget
            elif max_tokens:
                budget = self._allocator.create_budget(max_tokens)
            else:
                budget = self._allocator.create_budget_for_model(model)

            # 1. Count tokens for all contexts
            await self._ensure_token_counts(contexts, model)

            # Check timeout
            self._check_timeout(start, timeout, "token counting")

            # 2. Score and rank contexts
            scoring_start = time.perf_counter()
            ranking_result = await self._ranker.rank(
                contexts=contexts,
                query=query,
                budget=budget,
                model=model,
            )
            metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000

            selected_contexts = ranking_result.selected_contexts
            metrics.selected_contexts = len(selected_contexts)
            metrics.excluded_contexts = len(ranking_result.excluded)

            # Check timeout
            self._check_timeout(start, timeout, "scoring")

            # 3. Compress if needed and enabled
            if compress and self._needs_compression(selected_contexts, budget):
                compression_start = time.perf_counter()
                selected_contexts = await self._compressor.compress_contexts(
                    selected_contexts, budget, model
                )
                metrics.compression_time_ms = (
                    time.perf_counter() - compression_start
                ) * 1000
                metrics.compressed_contexts = sum(
                    1 for c in selected_contexts if c.metadata.get("truncated", False)
                )

            # Check timeout
            self._check_timeout(start, timeout, "compression")

            # 4. Format output
            formatting_start = time.perf_counter()
            if format_output:
                formatted_content = self._format_contexts(selected_contexts, model)
            else:
                formatted_content = "\n\n".join(c.content for c in selected_contexts)
            metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000

            # Calculate final metrics
            total_tokens = sum(c.token_count or 0 for c in selected_contexts)
            metrics.total_tokens = total_tokens
            metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
            metrics.end_time = datetime.now(UTC)

            return AssembledContext(
                content=formatted_content,
                total_tokens=total_tokens,
                context_count=len(selected_contexts),
                assembly_time_ms=metrics.assembly_time_ms,
                model=model,
                contexts=selected_contexts,
                excluded_count=metrics.excluded_contexts,
                metadata={
                    "metrics": metrics.to_dict(),
                    "query": query,
                    "budget": budget.to_dict(),
                },
            )

        except AssemblyTimeoutError:
            raise
        except Exception as e:
            logger.error("Context assembly failed: %s", e, exc_info=True)
            raise

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """Ensure all contexts have token counts."""
        tasks = []
        for context in contexts:
            if context.token_count is None:
                tasks.append(self._count_and_set(context, model))

        if tasks:
            await asyncio.gather(*tasks)

    async def _count_and_set(
        self,
        context: BaseContext,
        model: str | None = None,
    ) -> None:
        """Count tokens and set on context."""
        count = await self._calculator.count_tokens(context.content, model)
        context.token_count = count

    def _needs_compression(
        self,
        contexts: list[BaseContext],
        budget: TokenBudget,
    ) -> bool:
        """Check if any contexts exceed their type budget."""
        # Group by type and check totals
        by_type: dict[ContextType, int] = {}
        for context in contexts:
            ct = context.get_type()
            by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
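
        # Worked example (hypothetical numbers): if KNOWLEDGE contexts total
        # 5,000 tokens against a 4,000-token KNOWLEDGE allocation, the
        # per-type check below reports that compression is needed.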
        for ct, total in by_type.items():
            if total > budget.get_allocation(ct):
                return True

        # Also check if utilization exceeds threshold
        return budget.utilization() > self._settings.compression_threshold

    def _format_contexts(
        self,
        contexts: list[BaseContext],
        model: str,
    ) -> str:
        """
        Format contexts for the target model.

        Groups contexts by type and applies model-specific formatting.
        """
        # Group by type
        by_type: dict[ContextType, list[BaseContext]] = {}
        for context in contexts:
            ct = context.get_type()
            by_type.setdefault(ct, []).append(context)

        # Order types: System -> Task -> Knowledge -> Conversation -> Tool
        type_order = [
            ContextType.SYSTEM,
            ContextType.TASK,
            ContextType.KNOWLEDGE,
            ContextType.CONVERSATION,
            ContextType.TOOL,
        ]

        parts: list[str] = []
        for ct in type_order:
            if ct in by_type:
                formatted = self._format_type(by_type[ct], ct, model)
                if formatted:
                    parts.append(formatted)

        return "\n\n".join(parts)

    def _format_type(
        self,
        contexts: list[BaseContext],
        context_type: ContextType,
        model: str,
    ) -> str:
        """Format contexts of a specific type."""
        if not contexts:
            return ""

        # Check if model prefers XML tags (Claude)
        use_xml = "claude" in model.lower()

        if context_type == ContextType.SYSTEM:
            return self._format_system(contexts, use_xml)
        elif context_type == ContextType.TASK:
            return self._format_task(contexts, use_xml)
        elif context_type == ContextType.KNOWLEDGE:
            return self._format_knowledge(contexts, use_xml)
        elif context_type == ContextType.CONVERSATION:
            return self._format_conversation(contexts, use_xml)
        elif context_type == ContextType.TOOL:
            return self._format_tool(contexts, use_xml)

        # Fallback for unrecognized types: plain concatenation
        return "\n".join(c.content for c in contexts)

    def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format system contexts."""
        content = "\n\n".join(c.content for c in contexts)
        if use_xml:
            return f"<system_instructions>\n{content}\n</system_instructions>"
        return content

    def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format task contexts."""
        content = "\n\n".join(c.content for c in contexts)
        if use_xml:
            return f"<current_task>\n{content}\n</current_task>"
        return f"## Current Task\n\n{content}"

    def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format knowledge contexts."""
        if use_xml:
            parts = ["<reference_documents>"]
            for ctx in contexts:
                parts.append(f'<document source="{ctx.source}">')
                parts.append(ctx.content)
                parts.append("</document>")
            parts.append("</reference_documents>")
            return "\n".join(parts)
        else:
            parts = ["## Reference Documents\n"]
            for ctx in contexts:
                parts.append(f"### Source: {ctx.source}\n")
                parts.append(ctx.content)
                parts.append("")
            return "\n".join(parts)

    def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format conversation contexts."""
        if use_xml:
            parts = ["<conversation_history>"]
            for ctx in contexts:
                role = ctx.metadata.get("role", "user")
                parts.append(f'<message role="{role}">')
                parts.append(ctx.content)
                parts.append("</message>")
            parts.append("</conversation_history>")
            return "\n".join(parts)
        else:
            parts = []
            for ctx in contexts:
                role = ctx.metadata.get("role", "user")
                parts.append(f"**{role.upper()}**: {ctx.content}")
            return "\n\n".join(parts)

    def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format tool contexts."""
        if use_xml:
            parts = ["<tool_results>"]
            for ctx in contexts:
                tool_name = ctx.metadata.get("tool_name", "unknown")
                parts.append(f'<tool_result name="{tool_name}">')
                parts.append(ctx.content)
                parts.append("</tool_result>")
            parts.append("</tool_results>")
            return "\n".join(parts)
        else:
            parts = ["## Recent Tool Results\n"]
            for ctx in contexts:
                tool_name = ctx.metadata.get("tool_name", "unknown")
                parts.append(f"### Tool: {tool_name}\n")
                parts.append(f"```\n{ctx.content}\n```")
                parts.append("")
            return "\n".join(parts)

    def _check_timeout(
        self,
        start: float,
        timeout_ms: int,
        phase: str,
    ) -> None:
        """Check if timeout exceeded and raise if so."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        if elapsed_ms > timeout_ms:
            raise AssemblyTimeoutError(
                message=f"Context assembly timed out during {phase}",
                elapsed_ms=elapsed_ms,
                timeout_ms=timeout_ms,
            )
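

# Illustrative timeout handling (a sketch; assumes a `pipeline` built as in
# the usage example at the top of this file):
#
#     try:
#         assembled = await pipeline.assemble(contexts, query, model, timeout_ms=500)
#     except AssemblyTimeoutError:
#         logger.warning("Context assembly timed out; falling back to raw contexts")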