feat(context): enhance timeout handling, tenant isolation, and budget management
- Added timeout enforcement for token counting, scoring, and compression with detailed error handling. - Introduced tenant isolation in context caching using project and agent identifiers. - Enhanced budget management with stricter checks for critical context overspending and buffer limitations. - Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments. - Updated default assembly timeout settings for improved performance and reliability. - Improved XML escaping in Claude adapter for safety against injection attacks. - Standardized token estimation using model-specific ratios.
This commit is contained in:
@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..adapters import get_adapter
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
@@ -156,20 +157,42 @@ class ContextPipeline:
|
||||
else:
|
||||
budget = self._allocator.create_budget_for_model(model)
|
||||
|
||||
# 1. Count tokens for all contexts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._ensure_token_counts(contexts, model),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during token counting",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
|
||||
# Check timeout
|
||||
# Check timeout (handles edge case where operation finished just at limit)
|
||||
self._check_timeout(start, timeout, "token counting")
|
||||
|
||||
# 2. Score and rank contexts
|
||||
# 2. Score and rank contexts (with timeout enforcement)
|
||||
scoring_start = time.perf_counter()
|
||||
ranking_result = await self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
)
|
||||
try:
|
||||
ranking_result = await asyncio.wait_for(
|
||||
self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during scoring/ranking",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||
|
||||
selected_contexts = ranking_result.selected_contexts
|
||||
@@ -179,12 +202,23 @@ class ContextPipeline:
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "scoring")
|
||||
|
||||
# 3. Compress if needed and enabled
|
||||
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||
if compress and self._needs_compression(selected_contexts, budget):
|
||||
compression_start = time.perf_counter()
|
||||
selected_contexts = await self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
)
|
||||
try:
|
||||
selected_contexts = await asyncio.wait_for(
|
||||
self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during compression",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.compression_time_ms = (
|
||||
time.perf_counter() - compression_start
|
||||
) * 1000
|
||||
@@ -280,129 +314,18 @@ class ContextPipeline:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Groups contexts by type and applies model-specific formatting.
|
||||
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||
to format contexts optimally for each model family.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to format
|
||||
model: Target model name
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
# Group by type
|
||||
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
if ct not in by_type:
|
||||
by_type[ct] = []
|
||||
by_type[ct].append(context)
|
||||
|
||||
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
|
||||
type_order = [
|
||||
ContextType.SYSTEM,
|
||||
ContextType.TASK,
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.CONVERSATION,
|
||||
ContextType.TOOL,
|
||||
]
|
||||
|
||||
parts: list[str] = []
|
||||
for ct in type_order:
|
||||
if ct in by_type:
|
||||
formatted = self._format_type(by_type[ct], ct, model)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
model: str,
|
||||
) -> str:
|
||||
"""Format contexts of a specific type."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# Check if model prefers XML tags (Claude)
|
||||
use_xml = "claude" in model.lower()
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts, use_xml)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts, use_xml)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts, use_xml)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts, use_xml)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts, use_xml)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
return content
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
return f"## Current Task\n\n{content}"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format knowledge contexts."""
|
||||
if use_xml:
|
||||
parts = ["<reference_documents>"]
|
||||
for ctx in contexts:
|
||||
parts.append(f'<document source="{ctx.source}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</document>")
|
||||
parts.append("</reference_documents>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Reference Documents\n"]
|
||||
for ctx in contexts:
|
||||
parts.append(f"### Source: {ctx.source}\n")
|
||||
parts.append(ctx.content)
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format conversation contexts."""
|
||||
if use_xml:
|
||||
parts = ["<conversation_history>"]
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</message>")
|
||||
parts.append("</conversation_history>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = []
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f"**{role.upper()}**: {ctx.content}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format tool contexts."""
|
||||
if use_xml:
|
||||
parts = ["<tool_results>"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</tool_result>")
|
||||
parts.append("</tool_results>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Recent Tool Results\n"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f"### Tool: {tool_name}\n")
|
||||
parts.append(f"```\n{ctx.content}\n```")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
adapter = get_adapter(model)
|
||||
return adapter.format(contexts)
|
||||
|
||||
def _check_timeout(
|
||||
self,
|
||||
@@ -412,9 +335,28 @@ class ContextPipeline:
|
||||
) -> None:
|
||||
"""Check if timeout exceeded and raise if so."""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
if elapsed_ms > timeout_ms:
|
||||
if elapsed_ms >= timeout_ms:
|
||||
raise AssemblyTimeoutError(
|
||||
message=f"Context assembly timed out during {phase}",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||
"""
|
||||
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||
|
||||
Returns at least a small positive value to avoid immediate timeout
|
||||
edge cases with wait_for.
|
||||
|
||||
Args:
|
||||
start: Start time from time.perf_counter()
|
||||
timeout_ms: Total timeout in milliseconds
|
||||
|
||||
Returns:
|
||||
Remaining timeout in seconds (minimum 0.001)
|
||||
"""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
remaining_ms = timeout_ms - elapsed_ms
|
||||
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||
return max(remaining_ms / 1000.0, 0.001)
|
||||
|
||||
Reference in New Issue
Block a user