- Added timeout enforcement for token counting, scoring, and compression with detailed error handling.
- Introduced tenant isolation in context caching using project and agent identifiers.
- Enhanced budget management with stricter checks for critical context overspending and buffer limitations.
- Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments.
- Updated default assembly timeout settings for improved performance and reliability.
- Improved XML escaping in Claude adapter for safety against injection attacks.
- Standardized token estimation using model-specific ratios.
"""
|
|
Context Assembly Pipeline.
|
|
|
|
Orchestrates the full context assembly workflow:
|
|
Gather → Count → Score → Rank → Compress → Format
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from ..adapters import get_adapter
|
|
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
|
from ..compression.truncation import ContextCompressor
|
|
from ..config import ContextSettings, get_context_settings
|
|
from ..exceptions import AssemblyTimeoutError
|
|
from ..prioritization import ContextRanker
|
|
from ..scoring import CompositeScorer
|
|
from ..types import AssembledContext, BaseContext, ContextType
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.mcp.client_manager import MCPClientManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class PipelineMetrics:
|
|
"""Metrics from pipeline execution."""
|
|
|
|
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
end_time: datetime | None = None
|
|
total_contexts: int = 0
|
|
selected_contexts: int = 0
|
|
excluded_contexts: int = 0
|
|
compressed_contexts: int = 0
|
|
total_tokens: int = 0
|
|
assembly_time_ms: float = 0.0
|
|
scoring_time_ms: float = 0.0
|
|
compression_time_ms: float = 0.0
|
|
formatting_time_ms: float = 0.0
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Convert to dictionary."""
|
|
return {
|
|
"start_time": self.start_time.isoformat(),
|
|
"end_time": self.end_time.isoformat() if self.end_time else None,
|
|
"total_contexts": self.total_contexts,
|
|
"selected_contexts": self.selected_contexts,
|
|
"excluded_contexts": self.excluded_contexts,
|
|
"compressed_contexts": self.compressed_contexts,
|
|
"total_tokens": self.total_tokens,
|
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
|
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
|
"compression_time_ms": round(self.compression_time_ms, 2),
|
|
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
|
}
|
|
|
|
|
|
class ContextPipeline:
    """
    Context assembly pipeline.

    Orchestrates the full workflow of context assembly:
    1. Validate and count tokens for all contexts
    2. Score contexts based on relevance, recency, and priority
    3. Rank and select contexts within budget
    4. Compress if needed to fit remaining budget
    5. Format for the target model
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        settings: ContextSettings | None = None,
        calculator: TokenCalculator | None = None,
        scorer: CompositeScorer | None = None,
        ranker: ContextRanker | None = None,
        compressor: ContextCompressor | None = None,
    ) -> None:
        """
        Initialize the context pipeline.

        Args:
            mcp_manager: MCP client manager for LLM Gateway integration
            settings: Context settings
            calculator: Token calculator
            scorer: Context scorer
            ranker: Context ranker
            compressor: Context compressor
        """
        self._settings = settings or get_context_settings()
        self._mcp = mcp_manager

        # Default components are wired to the same calculator/scorer instances
        # so token counts and scores stay consistent across stages.
        self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = scorer or CompositeScorer(
            mcp_manager=mcp_manager, settings=self._settings
        )
        self._ranker = ranker or ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._compressor = compressor or ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for all components.

        The ranker and compressor are not updated directly; when built with
        the defaults they share this pipeline's calculator/scorer instances.
        """
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)

    async def assemble(
        self,
        contexts: list[BaseContext],
        query: str,
        model: str,
        max_tokens: int | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        timeout_ms: int | None = None,
    ) -> AssembledContext:
        """
        Assemble context for an LLM request.

        This is the main entry point for context assembly.

        Args:
            contexts: List of contexts to assemble
            query: Query to optimize for
            model: Target model name
            max_tokens: Maximum total tokens (uses model default if None)
            custom_budget: Optional pre-configured budget
            compress: Whether to compress oversized contexts
            format_output: Whether to format the final output
            timeout_ms: Maximum assembly time in milliseconds

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
        """
        timeout = timeout_ms or self._settings.max_assembly_time_ms
        start = time.perf_counter()
        metrics = PipelineMetrics(total_contexts=len(contexts))

        try:
            budget = self._resolve_budget(model, max_tokens, custom_budget)

            # 1. Count tokens for all contexts (with timeout enforcement)
            try:
                await asyncio.wait_for(
                    self._ensure_token_counts(contexts, model),
                    timeout=self._remaining_timeout(start, timeout),
                )
            except TimeoutError as exc:
                raise self._assembly_timeout("token counting", start, timeout) from exc

            # Check timeout (handles edge case where operation finished just at limit)
            self._check_timeout(start, timeout, "token counting")

            # 2. Score and rank contexts (with timeout enforcement)
            scoring_start = time.perf_counter()
            try:
                ranking_result = await asyncio.wait_for(
                    self._ranker.rank(
                        contexts=contexts,
                        query=query,
                        budget=budget,
                        model=model,
                    ),
                    timeout=self._remaining_timeout(start, timeout),
                )
            except TimeoutError as exc:
                raise self._assembly_timeout("scoring/ranking", start, timeout) from exc
            metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000

            selected_contexts = ranking_result.selected_contexts
            metrics.selected_contexts = len(selected_contexts)
            metrics.excluded_contexts = len(ranking_result.excluded)

            # Check timeout
            self._check_timeout(start, timeout, "scoring")

            # 3. Compress if needed and enabled (with timeout enforcement)
            if compress and self._needs_compression(selected_contexts, budget):
                compression_start = time.perf_counter()
                try:
                    selected_contexts = await asyncio.wait_for(
                        self._compressor.compress_contexts(
                            selected_contexts, budget, model
                        ),
                        timeout=self._remaining_timeout(start, timeout),
                    )
                except TimeoutError as exc:
                    raise self._assembly_timeout(
                        "compression", start, timeout
                    ) from exc
                metrics.compression_time_ms = (
                    time.perf_counter() - compression_start
                ) * 1000
                # Compressors mark truncated contexts via metadata["truncated"].
                metrics.compressed_contexts = sum(
                    1 for c in selected_contexts if c.metadata.get("truncated", False)
                )

            # Check timeout
            self._check_timeout(start, timeout, "compression")

            # 4. Format output
            formatting_start = time.perf_counter()
            if format_output:
                formatted_content = self._format_contexts(selected_contexts, model)
            else:
                formatted_content = "\n\n".join(c.content for c in selected_contexts)
            metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000

            # Calculate final metrics
            total_tokens = sum(c.token_count or 0 for c in selected_contexts)
            metrics.total_tokens = total_tokens
            metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
            metrics.end_time = datetime.now(UTC)

            return AssembledContext(
                content=formatted_content,
                total_tokens=total_tokens,
                context_count=len(selected_contexts),
                assembly_time_ms=metrics.assembly_time_ms,
                model=model,
                contexts=selected_contexts,
                excluded_count=metrics.excluded_contexts,
                metadata={
                    "metrics": metrics.to_dict(),
                    "query": query,
                    "budget": budget.to_dict(),
                },
            )

        except AssemblyTimeoutError:
            raise
        except Exception as e:
            # Lazy %-formatting avoids building the message unless emitted.
            logger.error("Context assembly failed: %s", e, exc_info=True)
            raise

    def _resolve_budget(
        self,
        model: str,
        max_tokens: int | None,
        custom_budget: TokenBudget | None,
    ) -> TokenBudget:
        """Return the budget to use: explicit > max_tokens-derived > model default."""
        if custom_budget:
            return custom_budget
        if max_tokens:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    def _assembly_timeout(
        self,
        phase: str,
        start: float,
        timeout_ms: int,
    ) -> AssemblyTimeoutError:
        """
        Build an AssemblyTimeoutError for the given pipeline phase.

        Centralizes the elapsed-time computation and message format that was
        previously duplicated at every timeout site.

        Args:
            phase: Human-readable name of the phase that timed out
            start: Start time from time.perf_counter()
            timeout_ms: Total timeout in milliseconds

        Returns:
            A ready-to-raise AssemblyTimeoutError
        """
        elapsed_ms = (time.perf_counter() - start) * 1000
        return AssemblyTimeoutError(
            message=f"Context assembly timed out during {phase}",
            elapsed_ms=elapsed_ms,
            timeout_ms=timeout_ms,
        )

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """Ensure all contexts have token counts (counted concurrently)."""
        tasks = [
            self._count_and_set(context, model)
            for context in contexts
            if context.token_count is None
        ]
        if tasks:
            await asyncio.gather(*tasks)

    async def _count_and_set(
        self,
        context: BaseContext,
        model: str | None = None,
    ) -> None:
        """Count tokens and set on context."""
        count = await self._calculator.count_tokens(context.content, model)
        context.token_count = count

    def _needs_compression(
        self,
        contexts: list[BaseContext],
        budget: TokenBudget,
    ) -> bool:
        """Check if any contexts exceed their type budget."""
        # Group token totals by context type and compare to each allocation.
        by_type: dict[ContextType, int] = {}
        for context in contexts:
            ct = context.get_type()
            by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)

        for ct, total in by_type.items():
            if total > budget.get_allocation(ct):
                return True

        # Also check if overall utilization exceeds the configured threshold.
        return budget.utilization() > self._settings.compression_threshold

    def _format_contexts(
        self,
        contexts: list[BaseContext],
        model: str,
    ) -> str:
        """
        Format contexts for the target model.

        Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
        to format contexts optimally for each model family.

        Args:
            contexts: Contexts to format
            model: Target model name

        Returns:
            Formatted context string
        """
        adapter = get_adapter(model)
        return adapter.format(contexts)

    def _check_timeout(
        self,
        start: float,
        timeout_ms: int,
        phase: str,
    ) -> None:
        """Check if timeout exceeded and raise if so."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        if elapsed_ms >= timeout_ms:
            raise self._assembly_timeout(phase, start, timeout_ms)

    def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
        """
        Calculate remaining timeout in seconds for asyncio.wait_for.

        Returns at least a small positive value to avoid immediate timeout
        edge cases with wait_for.

        Args:
            start: Start time from time.perf_counter()
            timeout_ms: Total timeout in milliseconds

        Returns:
            Remaining timeout in seconds (minimum 0.001)
        """
        elapsed_ms = (time.perf_counter() - start) * 1000
        remaining_ms = timeout_ms - elapsed_ms
        # Return at least 1ms to avoid zero/negative timeout edge cases
        return max(remaining_ms / 1000.0, 0.001)