feat(context): implement assembly pipeline and compression (#82)
Phase 4 of Context Management Engine - Assembly Pipeline:

- Add TruncationStrategy with end/middle/sentence-aware truncation
- Add TruncationResult dataclass for tracking compression metrics
- Add ContextCompressor for type-specific compression
- Add ContextPipeline orchestrating the full assembly workflow:
  - Token counting for all contexts
  - Scoring and ranking via ContextRanker
  - Optional compression when the budget threshold is exceeded
  - Model-specific formatting (XML for Claude, markdown for OpenAI)
- Add PipelineMetrics for performance tracking
- Update AssembledContext with new fields (model, contexts, metadata)
- Add backward-compatibility aliases for renamed fields

Tests: 34 new tests, 223 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
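For a feel of the new truncation API, here is a minimal sketch that drives the strategy layer directly. The import path is assumed from the file added in the diff below, and no TokenCalculator is configured, so it relies on the module's built-in ~4 characters/token fallback estimator:

```python
import asyncio

# Assumed import path, matching the new file in this diff.
from app.services.context.compression.truncation import TruncationStrategy


async def main() -> None:
    # No TokenCalculator: _count_tokens falls back to max(1, len(text) // 4).
    strategy = TruncationStrategy()

    doc = "First sentence. " * 50 + "Last sentence."
    result = await strategy.truncate_to_tokens(
        doc, max_tokens=50, strategy="sentence"
    )

    # TruncationResult carries the compression metrics added in this commit.
    print(result.truncated, result.original_tokens, result.truncated_tokens)
    print(f"saved {result.tokens_saved} tokens ({result.truncation_ratio:.0%})")


asyncio.run(main())
```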
backend/app/services/context/compression/truncation.py (new file, 391 lines)
@@ -0,0 +1,391 @@
"""
Smart Truncation for Context Compression.

Provides intelligent truncation strategies to reduce context size
while preserving the most important information.
"""

import logging
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING

from ..types import BaseContext, ContextType

if TYPE_CHECKING:
    from ..budget import TokenBudget, TokenCalculator

logger = logging.getLogger(__name__)


@dataclass
class TruncationResult:
    """Result of truncation operation."""

    original_tokens: int
    truncated_tokens: int
    content: str
    truncated: bool
    truncation_ratio: float  # 0.0 = no truncation, 1.0 = completely removed

    @property
    def tokens_saved(self) -> int:
        """Calculate tokens saved by truncation."""
        return self.original_tokens - self.truncated_tokens


class TruncationStrategy:
    """
    Smart truncation strategies for context compression.

    Strategies:
    1. End truncation: Cut from end (for knowledge/docs)
    2. Middle truncation: Keep start and end (for code)
    3. Sentence-aware: Truncate at sentence boundaries
    4. Semantic chunking: Keep most relevant chunks (planned; not yet implemented)
    """

    # Default truncation marker
    TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"

    def __init__(
        self,
        calculator: "TokenCalculator | None" = None,
        preserve_ratio_start: float = 0.7,  # Keep 70% from start by default
        min_content_length: int = 100,  # Minimum characters to keep
    ) -> None:
        """
        Initialize truncation strategy.

        Args:
            calculator: Token calculator for accurate counting
            preserve_ratio_start: Ratio of content to keep from start
            min_content_length: Minimum characters to preserve
        """
        self._calculator = calculator
        self._preserve_ratio_start = preserve_ratio_start
        self._min_content_length = min_content_length

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Set token calculator."""
        self._calculator = calculator

    async def truncate_to_tokens(
        self,
        content: str,
        max_tokens: int,
        strategy: str = "end",
        model: str | None = None,
    ) -> TruncationResult:
        """
        Truncate content to fit within token limit.

        Args:
            content: Content to truncate
            max_tokens: Maximum tokens allowed
            strategy: Truncation strategy ('end', 'middle', 'sentence')
            model: Model for token counting

        Returns:
            TruncationResult with truncated content
        """
        if not content:
            return TruncationResult(
                original_tokens=0,
                truncated_tokens=0,
                content="",
                truncated=False,
                truncation_ratio=0.0,
            )

        # Get original token count
        original_tokens = await self._count_tokens(content, model)

        if original_tokens <= max_tokens:
            return TruncationResult(
                original_tokens=original_tokens,
                truncated_tokens=original_tokens,
                content=content,
                truncated=False,
                truncation_ratio=0.0,
            )

        # Apply truncation strategy
        if strategy == "middle":
            truncated = await self._truncate_middle(content, max_tokens, model)
        elif strategy == "sentence":
            truncated = await self._truncate_sentence(content, max_tokens, model)
        else:  # "end"
            truncated = await self._truncate_end(content, max_tokens, model)

        truncated_tokens = await self._count_tokens(truncated, model)

        return TruncationResult(
            original_tokens=original_tokens,
            truncated_tokens=truncated_tokens,
            content=truncated,
            truncated=True,
            truncation_ratio=1 - (truncated_tokens / original_tokens),
        )

    async def _truncate_end(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Truncate from end of content.

        Simple but effective for most content types.
        """
        # Binary search for optimal truncation point
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
        # Guard against budgets smaller than the marker itself
        available_tokens = max(0, max_tokens - marker_tokens)

        # Estimate characters per token
        chars_per_token = len(content) / await self._count_tokens(content, model)

        # Start with estimated position
        estimated_chars = int(available_tokens * chars_per_token)
        truncated = content[:estimated_chars]

        # Refine with binary search. Start from the empty string so the
        # result is guaranteed to fit even when the estimate overshoots;
        # the old initialization returned the unverified estimate.
        low, high = 0, len(truncated)
        best = ""

        for _ in range(5):  # Max 5 iterations
            mid = (low + high) // 2
            candidate = content[:mid]
            tokens = await self._count_tokens(candidate, model)

            if tokens <= available_tokens:
                best = candidate
                low = mid + 1
            else:
                high = mid - 1

        return best + self.TRUNCATION_MARKER

    async def _truncate_middle(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Truncate from middle, keeping start and end.

        Good for code or content where context at boundaries matters.
        """
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
        available_tokens = max_tokens - marker_tokens

        # Split between start and end
        start_tokens = int(available_tokens * self._preserve_ratio_start)
        end_tokens = available_tokens - start_tokens

        # Get start portion
        start_content = await self._get_content_for_tokens(
            content, start_tokens, from_start=True, model=model
        )

        # Get end portion
        end_content = await self._get_content_for_tokens(
            content, end_tokens, from_start=False, model=model
        )

        return start_content + self.TRUNCATION_MARKER + end_content

    async def _truncate_sentence(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Truncate at sentence boundaries.

        Produces cleaner output by not cutting mid-sentence.
        """
        # Split into sentences
        sentences = re.split(r"(?<=[.!?])\s+", content)

        result: list[str] = []
        total_tokens = 0
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
        available = max_tokens - marker_tokens

        for sentence in sentences:
            sentence_tokens = await self._count_tokens(sentence, model)
            if total_tokens + sentence_tokens <= available:
                result.append(sentence)
                total_tokens += sentence_tokens
            else:
                break

        if len(result) < len(sentences):
            return " ".join(result) + self.TRUNCATION_MARKER
        return " ".join(result)

    async def _get_content_for_tokens(
        self,
        content: str,
        target_tokens: int,
        from_start: bool = True,
        model: str | None = None,
    ) -> str:
        """Get portion of content fitting within token limit."""
        if target_tokens <= 0:
            return ""

        current_tokens = await self._count_tokens(content, model)
        if current_tokens <= target_tokens:
            return content

        # Estimate characters
        chars_per_token = len(content) / current_tokens
        estimated_chars = int(target_tokens * chars_per_token)

        if from_start:
            return content[:estimated_chars]
        else:
            return content[-estimated_chars:]

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens using calculator or estimation."""
        if self._calculator is not None:
            return await self._calculator.count_tokens(text, model)

        # Fallback estimation
        return max(1, len(text) // 4)


class ContextCompressor:
    """
    Compresses contexts to fit within budget constraints.

    Uses truncation strategies to reduce context size while
    preserving the most important information.
    """

    def __init__(
        self,
        truncation: TruncationStrategy | None = None,
        calculator: "TokenCalculator | None" = None,
    ) -> None:
        """
        Initialize context compressor.

        Args:
            truncation: Truncation strategy to use
            calculator: Token calculator for counting
        """
        self._truncation = truncation or TruncationStrategy(calculator)
        self._calculator = calculator

        if calculator:
            self._truncation.set_calculator(calculator)

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Set token calculator."""
        self._calculator = calculator
        self._truncation.set_calculator(calculator)

    async def compress_context(
        self,
        context: BaseContext,
        max_tokens: int,
        model: str | None = None,
    ) -> BaseContext:
        """
        Compress a single context to fit token limit.

        Args:
            context: Context to compress
            max_tokens: Maximum tokens allowed
            model: Model for token counting

        Returns:
            The same context object, with its content truncated in place
            when compression was needed
        """
        current_tokens = context.token_count or await self._count_tokens(
            context.content, model
        )

        if current_tokens <= max_tokens:
            return context

        # Choose strategy based on context type
        strategy = self._get_strategy_for_type(context.get_type())

        result = await self._truncation.truncate_to_tokens(
            content=context.content,
            max_tokens=max_tokens,
            strategy=strategy,
            model=model,
        )

        # Update context in place with truncated content
        context.content = result.content
        context.token_count = result.truncated_tokens
        context.metadata["truncated"] = True
        context.metadata["original_tokens"] = result.original_tokens

        return context

    async def compress_contexts(
        self,
        contexts: list[BaseContext],
        budget: "TokenBudget",
        model: str | None = None,
    ) -> list[BaseContext]:
        """
        Compress multiple contexts to fit within budget.

        Args:
            contexts: Contexts to potentially compress
            budget: Token budget constraints
            model: Model for token counting

        Returns:
            List of contexts (compressed as needed)
        """
        result: list[BaseContext] = []

        for context in contexts:
            context_type = context.get_type()
            remaining = budget.remaining(context_type)
            current_tokens = context.token_count or await self._count_tokens(
                context.content, model
            )

            if current_tokens > remaining:
                # Need to compress
                compressed = await self.compress_context(context, remaining, model)
                result.append(compressed)
                logger.debug(
                    f"Compressed {context_type.value} context from "
                    f"{current_tokens} to {compressed.token_count} tokens"
                )
            else:
                result.append(context)

        return result

    def _get_strategy_for_type(self, context_type: ContextType) -> str:
        """Get optimal truncation strategy for context type."""
        strategies = {
            ContextType.SYSTEM: "end",  # Keep instructions at start
            ContextType.TASK: "end",  # Keep task description start
            ContextType.KNOWLEDGE: "sentence",  # Clean sentence boundaries
            ContextType.CONVERSATION: "end",  # Keep recent conversation
            ContextType.TOOL: "middle",  # Keep command and result summary
        }
        return strategies.get(context_type, "end")

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens using calculator or estimation."""
        if self._calculator is not None:
            return await self._calculator.count_tokens(text, model)
        return max(1, len(text) // 4)
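To see the type-specific strategy selection end-to-end, here is a rough sketch of the ContextCompressor path. BaseContext construction lives elsewhere in the context package, so a duck-typed stand-in with just the attributes compress_context() touches is used here; the import paths are likewise assumptions based on this diff:

```python
import asyncio
from typing import Any

# Assumed import paths, matching the module layout in this diff.
from app.services.context.compression.truncation import ContextCompressor
from app.services.context.types import ContextType


class StubContext:
    """Illustrative stand-in exposing only what compress_context() uses."""

    def __init__(self, content: str) -> None:
        self.content = content
        self.token_count: int | None = None
        self.metadata: dict[str, Any] = {}

    def get_type(self) -> ContextType:
        return ContextType.TOOL  # routes to the "middle" strategy


async def main() -> None:
    compressor = ContextCompressor()  # no TokenCalculator: fallback estimator
    ctx = StubContext(
        "$ pytest -q\n" + "... noisy line of output ...\n" * 200 + "223 passed"
    )

    await compressor.compress_context(ctx, max_tokens=80)

    # Middle truncation keeps the command and the final summary; the bulk
    # in between is replaced by the truncation marker.
    assert ctx.metadata["truncated"] is True
    print(ctx.content)


asyncio.run(main())
```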