feat(context): enhance timeout handling, tenant isolation, and budget management

- Added timeout enforcement for token counting, scoring, and compression with detailed error handling.
- Introduced tenant isolation in context caching using project and agent identifiers.
- Enhanced budget management with stricter checks for critical context overspending and buffer limitations.
- Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments.
- Updated default assembly timeout settings for improved performance and reliability.
- Improved XML escaping in Claude adapter for safety against injection attacks.
- Standardized token estimation using model-specific ratios.
This commit is contained in:
2026-01-04 15:52:50 +01:00
parent 2bea057fb1
commit 1628eacf2b
10 changed files with 271 additions and 175 deletions

View File

@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from ..adapters import get_adapter
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings
@@ -156,20 +157,42 @@ class ContextPipeline:
else:
budget = self._allocator.create_budget_for_model(model)
# 1. Count tokens for all contexts
await self._ensure_token_counts(contexts, model)
# 1. Count tokens for all contexts (with timeout enforcement)
try:
await asyncio.wait_for(
self._ensure_token_counts(contexts, model),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during token counting",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
# Check timeout
# Check timeout (handles edge case where operation finished just at limit)
self._check_timeout(start, timeout, "token counting")
# 2. Score and rank contexts
# 2. Score and rank contexts (with timeout enforcement)
scoring_start = time.perf_counter()
ranking_result = await self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
)
try:
ranking_result = await asyncio.wait_for(
self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during scoring/ranking",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
selected_contexts = ranking_result.selected_contexts
@@ -179,12 +202,23 @@ class ContextPipeline:
# Check timeout
self._check_timeout(start, timeout, "scoring")
# 3. Compress if needed and enabled
# 3. Compress if needed and enabled (with timeout enforcement)
if compress and self._needs_compression(selected_contexts, budget):
compression_start = time.perf_counter()
selected_contexts = await self._compressor.compress_contexts(
selected_contexts, budget, model
)
try:
selected_contexts = await asyncio.wait_for(
self._compressor.compress_contexts(
selected_contexts, budget, model
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during compression",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.compression_time_ms = (
time.perf_counter() - compression_start
) * 1000
@@ -280,129 +314,18 @@ class ContextPipeline:
"""
Format contexts for the target model.
Groups contexts by type and applies model-specific formatting.
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
to format contexts optimally for each model family.
Args:
contexts: Contexts to format
model: Target model name
Returns:
Formatted context string
"""
# Group by type
by_type: dict[ContextType, list[BaseContext]] = {}
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
type_order = [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
parts: list[str] = []
for ct in type_order:
if ct in by_type:
formatted = self._format_type(by_type[ct], ct, model)
if formatted:
parts.append(formatted)
return "\n\n".join(parts)
def _format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
model: str,
) -> str:
"""Format contexts of a specific type."""
if not contexts:
return ""
# Check if model prefers XML tags (Claude)
use_xml = "claude" in model.lower()
if context_type == ContextType.SYSTEM:
return self._format_system(contexts, use_xml)
elif context_type == ContextType.TASK:
return self._format_task(contexts, use_xml)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts, use_xml)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts, use_xml)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts, use_xml)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<system_instructions>\n{content}\n</system_instructions>"
return content
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<current_task>\n{content}\n</current_task>"
return f"## Current Task\n\n{content}"
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = ["<reference_documents>"]
for ctx in contexts:
parts.append(f'<document source="{ctx.source}">')
parts.append(ctx.content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
else:
parts = ["## Reference Documents\n"]
for ctx in contexts:
parts.append(f"### Source: {ctx.source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format conversation contexts."""
if use_xml:
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
else:
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format tool contexts."""
if use_xml:
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
else:
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)
adapter = get_adapter(model)
return adapter.format(contexts)
def _check_timeout(
self,
@@ -412,9 +335,28 @@ class ContextPipeline:
) -> None:
"""Check if timeout exceeded and raise if so."""
elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms > timeout_ms:
if elapsed_ms >= timeout_ms:
raise AssemblyTimeoutError(
message=f"Context assembly timed out during {phase}",
elapsed_ms=elapsed_ms,
timeout_ms=timeout_ms,
)
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
"""
Calculate remaining timeout in seconds for asyncio.wait_for.
Returns at least a small positive value to avoid immediate timeout
edge cases with wait_for.
Args:
start: Start time from time.perf_counter()
timeout_ms: Total timeout in milliseconds
Returns:
Remaining timeout in seconds (minimum 0.001)
"""
elapsed_ms = (time.perf_counter() - start) * 1000
remaining_ms = timeout_ms - elapsed_ms
# Return at least 1ms to avoid zero/negative timeout edge cases
return max(remaining_ms / 1000.0, 0.001)