"""
|
|
Smart Truncation for Context Compression.
|
|
|
|
Provides intelligent truncation strategies to reduce context size
|
|
while preserving the most important information.
|
|
"""
|
|
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING
|
|
|
|
from ..config import ContextSettings, get_context_settings
|
|
from ..types import BaseContext, ContextType
|
|
|
|
if TYPE_CHECKING:
|
|
from ..budget import TokenBudget, TokenCalculator
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
|
"""
|
|
Estimate token count using model-specific character ratios.
|
|
|
|
Module-level function for reuse across classes. Uses the same ratios
|
|
as TokenCalculator for consistency.
|
|
|
|
Args:
|
|
text: Text to estimate tokens for
|
|
model: Optional model name for model-specific ratios
|
|
|
|
Returns:
|
|
Estimated token count (minimum 1)
|
|
"""
|
|
# Model-specific character ratios (chars per token)
|
|
model_ratios = {
|
|
"claude": 3.5,
|
|
"gpt-4": 4.0,
|
|
"gpt-3.5": 4.0,
|
|
"gemini": 4.0,
|
|
}
|
|
default_ratio = 4.0
|
|
|
|
ratio = default_ratio
|
|
if model:
|
|
model_lower = model.lower()
|
|
for model_prefix, model_ratio in model_ratios.items():
|
|
if model_prefix in model_lower:
|
|
ratio = model_ratio
|
|
break
|
|
|
|
return max(1, int(len(text) / ratio))
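
# Worked example of the ratios above (hypothetical inputs): a 35-character
# string under a "claude"-family model estimates to int(35 / 3.5) == 10
# tokens, while an unrecognized model falls back to the 4.0 default,
# giving int(35 / 4.0) == 8.
#
#     _estimate_tokens("x" * 35, model="claude-3-opus")  # -> 10
#     _estimate_tokens("x" * 35, model="mystery-model")  # -> 8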


@dataclass
class TruncationResult:
    """Result of truncation operation."""

    original_tokens: int
    truncated_tokens: int
    content: str
    truncated: bool
    truncation_ratio: float  # 0.0 = no truncation, 1.0 = completely removed

    @property
    def tokens_saved(self) -> int:
        """Calculate tokens saved by truncation."""
        return self.original_tokens - self.truncated_tokens
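
# Worked example: content cut from 1200 to 800 tokens yields
# tokens_saved == 400 and truncation_ratio == 1 - 800 / 1200, roughly 0.33
# (per the convention above: 0.0 means untouched, 1.0 means fully removed).
#
#     result = TruncationResult(
#         original_tokens=1200,
#         truncated_tokens=800,
#         content="...",
#         truncated=True,
#         truncation_ratio=1 - 800 / 1200,
#     )
#     result.tokens_saved  # -> 400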


class TruncationStrategy:
    """
    Smart truncation strategies for context compression.

    Strategies:
    1. End truncation: Cut from end (for knowledge/docs)
    2. Middle truncation: Keep start and end (for code)
    3. Sentence-aware: Truncate at sentence boundaries
    4. Semantic chunking: Keep most relevant chunks (not implemented here;
       truncate_to_tokens supports only 'end', 'middle', and 'sentence')
    """

    def __init__(
        self,
        calculator: "TokenCalculator | None" = None,
        preserve_ratio_start: float | None = None,
        min_content_length: int | None = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize truncation strategy.

        Args:
            calculator: Token calculator for accurate counting
            preserve_ratio_start: Ratio of content to keep from start (overrides settings)
            min_content_length: Minimum characters to preserve (overrides settings)
            settings: Context settings (uses global settings if None)
        """
        self._settings = settings or get_context_settings()
        self._calculator = calculator

        # Use provided values or fall back to settings
        self._preserve_ratio_start = (
            preserve_ratio_start
            if preserve_ratio_start is not None
            else self._settings.truncation_preserve_ratio
        )
        self._min_content_length = (
            min_content_length
            if min_content_length is not None
            else self._settings.truncation_min_content_length
        )

    @property
    def truncation_marker(self) -> str:
        """Get truncation marker from settings."""
        return self._settings.truncation_marker

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Set token calculator."""
        self._calculator = calculator

    async def truncate_to_tokens(
        self,
        content: str,
        max_tokens: int,
        strategy: str = "end",
        model: str | None = None,
    ) -> TruncationResult:
        """
        Truncate content to fit within a token limit.

        Args:
            content: Content to truncate
            max_tokens: Maximum tokens allowed
            strategy: Truncation strategy ('end', 'middle', 'sentence')
            model: Model for token counting

        Returns:
            TruncationResult with truncated content
        """
        if not content:
            return TruncationResult(
                original_tokens=0,
                truncated_tokens=0,
                content="",
                truncated=False,
                truncation_ratio=0.0,
            )

        # Get original token count
        original_tokens = await self._count_tokens(content, model)

        if original_tokens <= max_tokens:
            return TruncationResult(
                original_tokens=original_tokens,
                truncated_tokens=original_tokens,
                content=content,
                truncated=False,
                truncation_ratio=0.0,
            )

        # Apply truncation strategy
        if strategy == "middle":
            truncated = await self._truncate_middle(content, max_tokens, model)
        elif strategy == "sentence":
            truncated = await self._truncate_sentence(content, max_tokens, model)
        else:  # "end"
            truncated = await self._truncate_end(content, max_tokens, model)

        truncated_tokens = await self._count_tokens(truncated, model)

        return TruncationResult(
            original_tokens=original_tokens,
            truncated_tokens=truncated_tokens,
            content=truncated,
            truncated=True,
            truncation_ratio=0.0
            if original_tokens == 0
            else 1 - (truncated_tokens / original_tokens),
        )
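
    # Usage sketch (inside an async caller; the names are illustrative):
    #
    #     strategy = TruncationStrategy(calculator=calculator)
    #     result = await strategy.truncate_to_tokens(
    #         document_text, max_tokens=500, strategy="sentence", model="gpt-4"
    #     )
    #     if result.truncated:
    #         logger.debug("dropped %d tokens", result.tokens_saved)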

    async def _truncate_end(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Truncate from the end of content.

        Simple but effective for most content types.
        """
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available_tokens = max(0, max_tokens - marker_tokens)

        # Edge case: if no tokens are available for content, return just the marker
        if available_tokens <= 0:
            return self.truncation_marker

        # Estimate characters per token (guard against division by zero)
        content_tokens = await self._count_tokens(content, model)
        if content_tokens == 0:
            return content + self.truncation_marker
        chars_per_token = len(content) / content_tokens

        # Start with an estimated cut position
        estimated_chars = int(available_tokens * chars_per_token)
        truncated = content[:estimated_chars]

        # Refine with binary search. Start from an empty "best" so we never
        # return a candidate that was not verified to fit: with a real
        # tokenizer the initial estimate may overshoot available_tokens.
        low, high = 0, len(truncated)
        best = ""

        for _ in range(5):  # Cap iterations to bound token-counting calls
            mid = (low + high) // 2
            candidate = content[:mid]
            tokens = await self._count_tokens(candidate, model)

            if tokens <= available_tokens:
                best = candidate
                low = mid + 1
            else:
                high = mid - 1

        return best + self.truncation_marker

    async def _truncate_middle(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Truncate from the middle, keeping the start and end.

        Good for code or content where context at the boundaries matters.
        """
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        # Guard against a marker larger than the budget (mirrors _truncate_end)
        available_tokens = max(0, max_tokens - marker_tokens)

        # Split the remaining budget between start and end
        start_tokens = int(available_tokens * self._preserve_ratio_start)
        end_tokens = available_tokens - start_tokens

        # Get the start portion
        start_content = await self._get_content_for_tokens(
            content, start_tokens, from_start=True, model=model
        )

        # Get the end portion
        end_content = await self._get_content_for_tokens(
            content, end_tokens, from_start=False, model=model
        )

        return start_content + self.truncation_marker + end_content
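
    # Worked example (hypothetical numbers): with max_tokens == 100, a
    # 2-token marker, and preserve_ratio_start == 0.3, available_tokens == 98,
    # so roughly the first int(98 * 0.3) == 29 tokens and the last 69 tokens
    # survive, joined by the truncation marker.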

    async def _truncate_sentence(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Truncate at sentence boundaries.

        Produces cleaner output by not cutting mid-sentence.
        """
        # Split into sentences
        sentences = re.split(r"(?<=[.!?])\s+", content)

        result: list[str] = []
        total_tokens = 0
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available = max_tokens - marker_tokens

        for sentence in sentences:
            sentence_tokens = await self._count_tokens(sentence, model)
            if total_tokens + sentence_tokens <= available:
                result.append(sentence)
                total_tokens += sentence_tokens
            else:
                break

        if len(result) < len(sentences):
            return " ".join(result) + self.truncation_marker
        return " ".join(result)

    async def _get_content_for_tokens(
        self,
        content: str,
        target_tokens: int,
        from_start: bool = True,
        model: str | None = None,
    ) -> str:
        """Get the portion of content fitting within a token limit."""
        if target_tokens <= 0:
            return ""

        current_tokens = await self._count_tokens(content, model)
        if current_tokens <= target_tokens:
            return content

        # current_tokens > target_tokens >= 1 here, so division is safe
        chars_per_token = len(content) / current_tokens
        # Keep at least one character: content[-0:] would return the
        # whole string instead of an empty slice
        estimated_chars = max(1, int(target_tokens * chars_per_token))

        if from_start:
            return content[:estimated_chars]
        else:
            return content[-estimated_chars:]

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens using the calculator or estimation."""
        if self._calculator is not None:
            return await self._calculator.count_tokens(text, model)

        # Fallback estimation with model-specific ratios
        return _estimate_tokens(text, model)


class ContextCompressor:
    """
    Compresses contexts to fit within budget constraints.

    Uses truncation strategies to reduce context size while
    preserving the most important information.
    """

    def __init__(
        self,
        truncation: TruncationStrategy | None = None,
        calculator: "TokenCalculator | None" = None,
    ) -> None:
        """
        Initialize context compressor.

        Args:
            truncation: Truncation strategy to use
            calculator: Token calculator for counting
        """
        self._truncation = truncation or TruncationStrategy(calculator)
        self._calculator = calculator

        if calculator:
            self._truncation.set_calculator(calculator)

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Set token calculator."""
        self._calculator = calculator
        self._truncation.set_calculator(calculator)

    async def compress_context(
        self,
        context: BaseContext,
        max_tokens: int,
        model: str | None = None,
    ) -> BaseContext:
        """
        Compress a single context to fit a token limit.

        Args:
            context: Context to compress
            max_tokens: Maximum tokens allowed
            model: Model for token counting

        Returns:
            The same context object; its content, token_count, and metadata
            are updated in place when compression is needed
        """
        current_tokens = context.token_count or await self._count_tokens(
            context.content, model
        )

        if current_tokens <= max_tokens:
            return context

        # Choose a strategy based on the context type
        strategy = self._get_strategy_for_type(context.get_type())

        result = await self._truncation.truncate_to_tokens(
            content=context.content,
            max_tokens=max_tokens,
            strategy=strategy,
            model=model,
        )

        # Update the context with truncated content
        context.content = result.content
        context.token_count = result.truncated_tokens
        context.metadata["truncated"] = True
        context.metadata["original_tokens"] = result.original_tokens

        return context

    async def compress_contexts(
        self,
        contexts: list[BaseContext],
        budget: "TokenBudget",
        model: str | None = None,
    ) -> list[BaseContext]:
        """
        Compress multiple contexts to fit within a budget.

        Args:
            contexts: Contexts to potentially compress
            budget: Token budget constraints
            model: Model for token counting

        Returns:
            List of contexts (compressed as needed)
        """
        result: list[BaseContext] = []

        for context in contexts:
            context_type = context.get_type()
            remaining = budget.remaining(context_type)
            current_tokens = context.token_count or await self._count_tokens(
                context.content, model
            )

            if current_tokens > remaining:
                # Needs compression
                compressed = await self.compress_context(context, remaining, model)
                result.append(compressed)
                logger.debug(
                    f"Compressed {context_type.value} context from "
                    f"{current_tokens} to {compressed.token_count} tokens"
                )
            else:
                result.append(context)

        return result
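
    # Usage sketch (hypothetical budget and contexts; names are illustrative):
    #
    #     compressor = ContextCompressor(calculator=calculator)
    #     fitted = await compressor.compress_contexts(contexts, budget, model="gpt-4")
    #     # Contexts over their type's remaining budget come back truncated,
    #     # with metadata["truncated"] set and the original token count recorded.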

    def _get_strategy_for_type(self, context_type: ContextType) -> str:
        """Get the optimal truncation strategy for a context type."""
        strategies = {
            ContextType.SYSTEM: "end",  # Keep instructions at start
            ContextType.TASK: "end",  # Keep task description start
            ContextType.KNOWLEDGE: "sentence",  # Clean sentence boundaries
            ContextType.CONVERSATION: "end",  # Keep recent conversation
            ContextType.TOOL: "middle",  # Keep command and result summary
        }
        return strategies.get(context_type, "end")

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens using the calculator or estimation."""
        if self._calculator is not None:
            return await self._calculator.count_tokens(text, model)
        # Use model-specific estimation for consistency
        return _estimate_tokens(text, model)