forked from cardosofelipe/fast-next-template
feat(context): implement assembly pipeline and compression (#82)
Phase 4 of Context Management Engine - Assembly Pipeline: - Add TruncationStrategy with end/middle/sentence-aware truncation - Add TruncationResult dataclass for tracking compression metrics - Add ContextCompressor for type-specific compression - Add ContextPipeline orchestrating full assembly workflow: - Token counting for all contexts - Scoring and ranking via ContextRanker - Optional compression when budget threshold exceeded - Model-specific formatting (XML for Claude, markdown for OpenAI) - Add PipelineMetrics for performance tracking - Update AssembledContext with new fields (model, contexts, metadata) - Add backward compatibility aliases for renamed fields Tests: 34 new tests, 223 total context tests passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -63,6 +63,19 @@ from .exceptions import (
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
# Assembly
|
||||
from .assembly import (
|
||||
ContextPipeline,
|
||||
PipelineMetrics,
|
||||
)
|
||||
|
||||
# Compression
|
||||
from .compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
# Prioritization
|
||||
from .prioritization import (
|
||||
ContextRanker,
|
||||
@@ -97,10 +110,17 @@ from .types import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Assembly
|
||||
"ContextPipeline",
|
||||
"PipelineMetrics",
|
||||
# Budget Management
|
||||
"BudgetAllocator",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
# Compression
|
||||
"ContextCompressor",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
# Configuration
|
||||
"ContextSettings",
|
||||
"get_context_settings",
|
||||
|
||||
@@ -3,3 +3,10 @@ Context Assembly Module.
|
||||
|
||||
Provides the assembly pipeline and formatting.
|
||||
"""
|
||||
|
||||
from .pipeline import ContextPipeline, PipelineMetrics
|
||||
|
||||
__all__ = [
|
||||
"ContextPipeline",
|
||||
"PipelineMetrics",
|
||||
]
|
||||
|
||||
432
backend/app/services/context/assembly/pipeline.py
Normal file
432
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""
|
||||
Context Assembly Pipeline.
|
||||
|
||||
Orchestrates the full context assembly workflow:
|
||||
Gather → Count → Score → Rank → Compress → Format
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import AssemblyTimeoutError
|
||||
from ..prioritization import ContextRanker
|
||||
from ..scoring import CompositeScorer
|
||||
from ..types import AssembledContext, BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineMetrics:
|
||||
"""Metrics from pipeline execution."""
|
||||
|
||||
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
end_time: datetime | None = None
|
||||
total_contexts: int = 0
|
||||
selected_contexts: int = 0
|
||||
excluded_contexts: int = 0
|
||||
compressed_contexts: int = 0
|
||||
total_tokens: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
scoring_time_ms: float = 0.0
|
||||
compression_time_ms: float = 0.0
|
||||
formatting_time_ms: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
"total_contexts": self.total_contexts,
|
||||
"selected_contexts": self.selected_contexts,
|
||||
"excluded_contexts": self.excluded_contexts,
|
||||
"compressed_contexts": self.compressed_contexts,
|
||||
"total_tokens": self.total_tokens,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||
}
|
||||
|
||||
|
||||
class ContextPipeline:
|
||||
"""
|
||||
Context assembly pipeline.
|
||||
|
||||
Orchestrates the full workflow of context assembly:
|
||||
1. Validate and count tokens for all contexts
|
||||
2. Score contexts based on relevance, recency, and priority
|
||||
3. Rank and select contexts within budget
|
||||
4. Compress if needed to fit remaining budget
|
||||
5. Format for the target model
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    mcp_manager: "MCPClientManager | None" = None,
    settings: ContextSettings | None = None,
    calculator: TokenCalculator | None = None,
    scorer: CompositeScorer | None = None,
    ranker: ContextRanker | None = None,
    compressor: ContextCompressor | None = None,
) -> None:
    """
    Wire up the pipeline and its collaborators.

    Any collaborator that is not supplied is built here with defaults, in
    dependency order: calculator, scorer, ranker, compressor.

    Args:
        mcp_manager: MCP client manager handed to the default calculator and scorer
        settings: Context settings (falls back to get_context_settings())
        calculator: Token calculator
        scorer: Context scorer
        ranker: Context ranker
        compressor: Context compressor
    """
    if not settings:
        settings = get_context_settings()
    self._settings = settings
    self._mcp = mcp_manager

    # Build missing components; later ones reuse the earlier instances.
    if not calculator:
        calculator = TokenCalculator(mcp_manager=mcp_manager)
    self._calculator = calculator

    if not scorer:
        scorer = CompositeScorer(mcp_manager=mcp_manager, settings=settings)
    self._scorer = scorer

    if not ranker:
        ranker = ContextRanker(scorer=scorer, calculator=calculator)
    self._ranker = ranker

    if not compressor:
        compressor = ContextCompressor(calculator=calculator)
    self._compressor = compressor

    self._allocator = BudgetAllocator(settings)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
    """Store the MCP manager and forward it to the calculator and scorer."""
    self._mcp = mcp_manager
    for component in (self._calculator, self._scorer):
        component.set_mcp_manager(mcp_manager)
|
||||
|
||||
async def assemble(
    self,
    contexts: list[BaseContext],
    query: str,
    model: str,
    max_tokens: int | None = None,
    custom_budget: TokenBudget | None = None,
    compress: bool = True,
    format_output: bool = True,
    timeout_ms: int | None = None,
) -> AssembledContext:
    """
    Run the full assembly workflow for one LLM request.

    Phases: token counting, scoring/ranking, optional compression, then
    formatting. The elapsed time is compared with the timeout after each
    of the first three phases.

    Args:
        contexts: List of contexts to assemble
        query: Query to optimize for
        model: Target model name
        max_tokens: Maximum total tokens (a falsy value selects the model default)
        custom_budget: Optional pre-configured budget; takes precedence over max_tokens
        compress: Whether to compress oversized contexts
        format_output: Whether to apply model-specific formatting
        timeout_ms: Maximum assembly time in milliseconds (a falsy value
            selects the configured max_assembly_time_ms)

    Returns:
        AssembledContext with the rendered content, the selected contexts
        and a metadata dict holding metrics, query and budget

    Raises:
        AssemblyTimeoutError: If a phase check finds the timeout exceeded
    """
    timeout = timeout_ms or self._settings.max_assembly_time_ms
    start = time.perf_counter()
    metrics = PipelineMetrics(total_contexts=len(contexts))

    try:
        # Budget precedence: explicit budget, then max_tokens, then model default.
        if custom_budget:
            budget = custom_budget
        else:
            budget = (
                self._allocator.create_budget(max_tokens)
                if max_tokens
                else self._allocator.create_budget_for_model(model)
            )

        # Phase 1: make sure every context carries a token count.
        await self._ensure_token_counts(contexts, model)
        self._check_timeout(start, timeout, "token counting")

        # Phase 2: score and rank against the budget.
        phase_started = time.perf_counter()
        ranking = await self._ranker.rank(
            contexts=contexts,
            query=query,
            budget=budget,
            model=model,
        )
        metrics.scoring_time_ms = (time.perf_counter() - phase_started) * 1000

        chosen = ranking.selected_contexts
        metrics.selected_contexts = len(chosen)
        metrics.excluded_contexts = len(ranking.excluded)
        self._check_timeout(start, timeout, "scoring")

        # Phase 3: compress only when enabled and actually required.
        if compress and self._needs_compression(chosen, budget):
            phase_started = time.perf_counter()
            chosen = await self._compressor.compress_contexts(chosen, budget, model)
            metrics.compression_time_ms = (
                time.perf_counter() - phase_started
            ) * 1000
            metrics.compressed_contexts = len(
                [c for c in chosen if c.metadata.get("truncated", False)]
            )
        self._check_timeout(start, timeout, "compression")

        # Phase 4: render the final text.
        phase_started = time.perf_counter()
        if format_output:
            rendered = self._format_contexts(chosen, model)
        else:
            rendered = "\n\n".join(c.content for c in chosen)
        metrics.formatting_time_ms = (time.perf_counter() - phase_started) * 1000

        # Final bookkeeping.
        token_sum = sum(c.token_count or 0 for c in chosen)
        metrics.total_tokens = token_sum
        metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
        metrics.end_time = datetime.now(UTC)

        return AssembledContext(
            content=rendered,
            total_tokens=token_sum,
            context_count=len(chosen),
            assembly_time_ms=metrics.assembly_time_ms,
            model=model,
            contexts=chosen,
            excluded_count=metrics.excluded_contexts,
            metadata={
                "metrics": metrics.to_dict(),
                "query": query,
                "budget": budget.to_dict(),
            },
        )

    except Exception as e:
        # Timeouts propagate silently; anything else is logged first.
        if not isinstance(e, AssemblyTimeoutError):
            logger.error(f"Context assembly failed: {e}", exc_info=True)
        raise
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""Ensure all contexts have token counts."""
|
||||
tasks = []
|
||||
for context in contexts:
|
||||
if context.token_count is None:
|
||||
tasks.append(self._count_and_set(context, model))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _count_and_set(
|
||||
self,
|
||||
context: BaseContext,
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""Count tokens and set on context."""
|
||||
count = await self._calculator.count_tokens(context.content, model)
|
||||
context.token_count = count
|
||||
|
||||
def _needs_compression(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
budget: TokenBudget,
|
||||
) -> bool:
|
||||
"""Check if any contexts exceed their type budget."""
|
||||
# Group by type and check totals
|
||||
by_type: dict[ContextType, int] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
|
||||
|
||||
for ct, total in by_type.items():
|
||||
if total > budget.get_allocation(ct):
|
||||
return True
|
||||
|
||||
# Also check if utilization exceeds threshold
|
||||
return budget.utilization() > self._settings.compression_threshold
|
||||
|
||||
def _format_contexts(
    self,
    contexts: list[BaseContext],
    model: str,
) -> str:
    """
    Render all contexts into one string for the target model.

    Contexts are bucketed by type, each bucket is rendered through
    _format_type, and the non-empty sections are joined with blank lines
    in the fixed order System, Task, Knowledge, Conversation, Tool.
    """
    grouped: dict[ContextType, list[BaseContext]] = {}
    for item in contexts:
        grouped.setdefault(item.get_type(), []).append(item)

    section_order = (
        ContextType.SYSTEM,
        ContextType.TASK,
        ContextType.KNOWLEDGE,
        ContextType.CONVERSATION,
        ContextType.TOOL,
    )

    sections: list[str] = []
    for kind in section_order:
        members = grouped.get(kind)
        if members is None:
            continue
        rendered = self._format_type(members, kind, model)
        if rendered:
            sections.append(rendered)

    return "\n\n".join(sections)
|
||||
|
||||
def _format_type(
    self,
    contexts: list[BaseContext],
    context_type: ContextType,
    model: str,
) -> str:
    """
    Render the contexts of one type with the matching formatter.

    XML-style tags are used when the model name contains "claude"
    (case-insensitive). Types without a dedicated formatter are joined
    with newlines.
    """
    if not contexts:
        return ""

    use_xml = "claude" in model.lower()

    formatters = {
        ContextType.SYSTEM: self._format_system,
        ContextType.TASK: self._format_task,
        ContextType.KNOWLEDGE: self._format_knowledge,
        ContextType.CONVERSATION: self._format_conversation,
        ContextType.TOOL: self._format_tool,
    }
    formatter = formatters.get(context_type)
    if formatter is None:
        return "\n".join(c.content for c in contexts)
    return formatter(contexts, use_xml)
|
||||
|
||||
def _format_system(
|
||||
self, contexts: list[BaseContext], use_xml: bool
|
||||
) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
return content
|
||||
|
||||
def _format_task(
|
||||
self, contexts: list[BaseContext], use_xml: bool
|
||||
) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
return f"## Current Task\n\n{content}"
|
||||
|
||||
def _format_knowledge(
|
||||
self, contexts: list[BaseContext], use_xml: bool
|
||||
) -> str:
|
||||
"""Format knowledge contexts."""
|
||||
if use_xml:
|
||||
parts = ["<reference_documents>"]
|
||||
for ctx in contexts:
|
||||
parts.append(f'<document source="{ctx.source}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</document>")
|
||||
parts.append("</reference_documents>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Reference Documents\n"]
|
||||
for ctx in contexts:
|
||||
parts.append(f"### Source: {ctx.source}\n")
|
||||
parts.append(ctx.content)
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(
|
||||
self, contexts: list[BaseContext], use_xml: bool
|
||||
) -> str:
|
||||
"""Format conversation contexts."""
|
||||
if use_xml:
|
||||
parts = ["<conversation_history>"]
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</message>")
|
||||
parts.append("</conversation_history>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = []
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f"**{role.upper()}**: {ctx.content}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_tool(
|
||||
self, contexts: list[BaseContext], use_xml: bool
|
||||
) -> str:
|
||||
"""Format tool contexts."""
|
||||
if use_xml:
|
||||
parts = ["<tool_results>"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</tool_result>")
|
||||
parts.append("</tool_results>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Recent Tool Results\n"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f"### Tool: {tool_name}\n")
|
||||
parts.append(f"```\n{ctx.content}\n```")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _check_timeout(
|
||||
self,
|
||||
start: float,
|
||||
timeout_ms: int,
|
||||
phase: str,
|
||||
) -> None:
|
||||
"""Check if timeout exceeded and raise if so."""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
if elapsed_ms > timeout_ms:
|
||||
raise AssemblyTimeoutError(
|
||||
message=f"Context assembly timed out during {phase}",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
@@ -3,3 +3,11 @@ Context Compression Module.
|
||||
|
||||
Provides truncation and compression strategies.
|
||||
"""
|
||||
|
||||
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||
|
||||
__all__ = [
|
||||
"ContextCompressor",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
]
|
||||
|
||||
391
backend/app/services/context/compression/truncation.py
Normal file
391
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Smart Truncation for Context Compression.
|
||||
|
||||
Provides intelligent truncation strategies to reduce context size
|
||||
while preserving the most important information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TruncationResult:
    """Outcome of one truncation attempt, with before/after token counts."""

    original_tokens: int  # token count of the input
    truncated_tokens: int  # token count of `content` as returned
    content: str  # resulting text
    truncated: bool  # whether any truncation was applied
    truncation_ratio: float  # 0.0 = no truncation, 1.0 = completely removed

    @property
    def tokens_saved(self) -> int:
        """Difference between the original and the truncated token counts."""
        before, after = self.original_tokens, self.truncated_tokens
        return before - after
|
||||
|
||||
|
||||
class TruncationStrategy:
|
||||
"""
|
||||
Smart truncation strategies for context compression.
|
||||
|
||||
Strategies:
|
||||
1. End truncation: Cut from end (for knowledge/docs)
|
||||
2. Middle truncation: Keep start and end (for code)
|
||||
3. Sentence-aware: Truncate at sentence boundaries
|
||||
4. Semantic chunking: Keep most relevant chunks
|
||||
"""
|
||||
|
||||
# Default truncation marker
|
||||
TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculator: "TokenCalculator | None" = None,
|
||||
preserve_ratio_start: float = 0.7, # Keep 70% from start by default
|
||||
min_content_length: int = 100, # Minimum characters to keep
|
||||
) -> None:
|
||||
"""
|
||||
Initialize truncation strategy.
|
||||
|
||||
Args:
|
||||
calculator: Token calculator for accurate counting
|
||||
preserve_ratio_start: Ratio of content to keep from start
|
||||
min_content_length: Minimum characters to preserve
|
||||
"""
|
||||
self._calculator = calculator
|
||||
self._preserve_ratio_start = preserve_ratio_start
|
||||
self._min_content_length = min_content_length
|
||||
|
||||
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||
"""Set token calculator."""
|
||||
self._calculator = calculator
|
||||
|
||||
async def truncate_to_tokens(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
strategy: str = "end",
|
||||
model: str | None = None,
|
||||
) -> TruncationResult:
|
||||
"""
|
||||
Truncate content to fit within token limit.
|
||||
|
||||
Args:
|
||||
content: Content to truncate
|
||||
max_tokens: Maximum tokens allowed
|
||||
strategy: Truncation strategy ('end', 'middle', 'sentence')
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
TruncationResult with truncated content
|
||||
"""
|
||||
if not content:
|
||||
return TruncationResult(
|
||||
original_tokens=0,
|
||||
truncated_tokens=0,
|
||||
content="",
|
||||
truncated=False,
|
||||
truncation_ratio=0.0,
|
||||
)
|
||||
|
||||
# Get original token count
|
||||
original_tokens = await self._count_tokens(content, model)
|
||||
|
||||
if original_tokens <= max_tokens:
|
||||
return TruncationResult(
|
||||
original_tokens=original_tokens,
|
||||
truncated_tokens=original_tokens,
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncation_ratio=0.0,
|
||||
)
|
||||
|
||||
# Apply truncation strategy
|
||||
if strategy == "middle":
|
||||
truncated = await self._truncate_middle(content, max_tokens, model)
|
||||
elif strategy == "sentence":
|
||||
truncated = await self._truncate_sentence(content, max_tokens, model)
|
||||
else: # "end"
|
||||
truncated = await self._truncate_end(content, max_tokens, model)
|
||||
|
||||
truncated_tokens = await self._count_tokens(truncated, model)
|
||||
|
||||
return TruncationResult(
|
||||
original_tokens=original_tokens,
|
||||
truncated_tokens=truncated_tokens,
|
||||
content=truncated,
|
||||
truncated=True,
|
||||
truncation_ratio=1 - (truncated_tokens / original_tokens),
|
||||
)
|
||||
|
||||
async def _truncate_end(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate from end of content.
|
||||
|
||||
Simple but effective for most content types.
|
||||
"""
|
||||
# Binary search for optimal truncation point
|
||||
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||
available_tokens = max_tokens - marker_tokens
|
||||
|
||||
# Estimate characters per token
|
||||
chars_per_token = len(content) / await self._count_tokens(content, model)
|
||||
|
||||
# Start with estimated position
|
||||
estimated_chars = int(available_tokens * chars_per_token)
|
||||
truncated = content[:estimated_chars]
|
||||
|
||||
# Refine with binary search
|
||||
low, high = len(truncated) // 2, len(truncated)
|
||||
best = truncated
|
||||
|
||||
for _ in range(5): # Max 5 iterations
|
||||
mid = (low + high) // 2
|
||||
candidate = content[:mid]
|
||||
tokens = await self._count_tokens(candidate, model)
|
||||
|
||||
if tokens <= available_tokens:
|
||||
best = candidate
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid - 1
|
||||
|
||||
return best + self.TRUNCATION_MARKER
|
||||
|
||||
async def _truncate_middle(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate from middle, keeping start and end.
|
||||
|
||||
Good for code or content where context at boundaries matters.
|
||||
"""
|
||||
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||
available_tokens = max_tokens - marker_tokens
|
||||
|
||||
# Split between start and end
|
||||
start_tokens = int(available_tokens * self._preserve_ratio_start)
|
||||
end_tokens = available_tokens - start_tokens
|
||||
|
||||
# Get start portion
|
||||
start_content = await self._get_content_for_tokens(
|
||||
content, start_tokens, from_start=True, model=model
|
||||
)
|
||||
|
||||
# Get end portion
|
||||
end_content = await self._get_content_for_tokens(
|
||||
content, end_tokens, from_start=False, model=model
|
||||
)
|
||||
|
||||
return start_content + self.TRUNCATION_MARKER + end_content
|
||||
|
||||
async def _truncate_sentence(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate at sentence boundaries.
|
||||
|
||||
Produces cleaner output by not cutting mid-sentence.
|
||||
"""
|
||||
# Split into sentences
|
||||
sentences = re.split(r"(?<=[.!?])\s+", content)
|
||||
|
||||
result: list[str] = []
|
||||
total_tokens = 0
|
||||
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||
available = max_tokens - marker_tokens
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_tokens = await self._count_tokens(sentence, model)
|
||||
if total_tokens + sentence_tokens <= available:
|
||||
result.append(sentence)
|
||||
total_tokens += sentence_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
if len(result) < len(sentences):
|
||||
return " ".join(result) + self.TRUNCATION_MARKER
|
||||
return " ".join(result)
|
||||
|
||||
async def _get_content_for_tokens(
|
||||
self,
|
||||
content: str,
|
||||
target_tokens: int,
|
||||
from_start: bool = True,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Get portion of content fitting within token limit."""
|
||||
if target_tokens <= 0:
|
||||
return ""
|
||||
|
||||
current_tokens = await self._count_tokens(content, model)
|
||||
if current_tokens <= target_tokens:
|
||||
return content
|
||||
|
||||
# Estimate characters
|
||||
chars_per_token = len(content) / current_tokens
|
||||
estimated_chars = int(target_tokens * chars_per_token)
|
||||
|
||||
if from_start:
|
||||
return content[:estimated_chars]
|
||||
else:
|
||||
return content[-estimated_chars:]
|
||||
|
||||
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
|
||||
# Fallback estimation
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
|
||||
class ContextCompressor:
    """
    Shrinks contexts so they fit a token budget.

    Delegates the actual text cutting to a TruncationStrategy and picks
    the strategy name from the context type.
    """

    def __init__(
        self,
        truncation: TruncationStrategy | None = None,
        calculator: "TokenCalculator | None" = None,
    ) -> None:
        """
        Initialize context compressor.

        Args:
            truncation: Truncation strategy to use (a default one is built if omitted)
            calculator: Token calculator; when given it is also pushed into
                the truncation strategy
        """
        if not truncation:
            truncation = TruncationStrategy(calculator)
        self._truncation = truncation
        self._calculator = calculator

        if calculator:
            truncation.set_calculator(calculator)

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Replace the token calculator here and on the truncation strategy."""
        self._calculator = calculator
        self._truncation.set_calculator(calculator)

    async def compress_context(
        self,
        context: BaseContext,
        max_tokens: int,
        model: str | None = None,
    ) -> BaseContext:
        """
        Compress a single context to fit token limit.

        Args:
            context: Context to compress (modified in place when truncated)
            max_tokens: Maximum tokens allowed
            model: Model for token counting

        Returns:
            The same context object; when truncated, its content and
            token_count are replaced and metadata gains "truncated" and
            "original_tokens".
        """
        size = context.token_count
        if not size:
            size = await self._count_tokens(context.content, model)

        if size <= max_tokens:
            return context

        # Strategy depends on what kind of context this is.
        chosen_strategy = self._get_strategy_for_type(context.get_type())
        outcome = await self._truncation.truncate_to_tokens(
            content=context.content,
            max_tokens=max_tokens,
            strategy=chosen_strategy,
            model=model,
        )

        context.content = outcome.content
        context.token_count = outcome.truncated_tokens
        context.metadata["truncated"] = True
        context.metadata["original_tokens"] = outcome.original_tokens
        return context

    async def compress_contexts(
        self,
        contexts: list[BaseContext],
        budget: "TokenBudget",
        model: str | None = None,
    ) -> list[BaseContext]:
        """
        Compress multiple contexts to fit within budget.

        Each context is compared with ``budget.remaining()`` for its type
        and compressed to that size when it is larger.

        Args:
            contexts: Contexts to potentially compress
            budget: Token budget constraints
            model: Model for token counting

        Returns:
            List of contexts in the original order (compressed as needed)
        """
        processed: list[BaseContext] = []

        for item in contexts:
            kind = item.get_type()
            headroom = budget.remaining(kind)
            size = item.token_count or await self._count_tokens(item.content, model)

            if size <= headroom:
                processed.append(item)
                continue

            squeezed = await self.compress_context(item, headroom, model)
            processed.append(squeezed)
            logger.debug(
                f"Compressed {kind.value} context from "
                f"{size} to {squeezed.token_count} tokens"
            )

        return processed

    def _get_strategy_for_type(self, context_type: ContextType) -> str:
        """Return the truncation strategy name for a context type ("end" unless overridden)."""
        overrides = {
            ContextType.KNOWLEDGE: "sentence",  # Clean sentence boundaries
            ContextType.TOOL: "middle",  # Keep command and result summary
        }
        # SYSTEM, TASK, CONVERSATION and unknown types all use "end".
        return overrides.get(context_type, "end")

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens with the calculator, or estimate 4 characters per token (minimum 1)."""
        if self._calculator is None:
            return max(1, len(text) // 4)
        return await self._calculator.count_tokens(text, model)
|
||||
@@ -253,12 +253,19 @@ class AssembledContext:
|
||||
|
||||
# Main content
|
||||
content: str
|
||||
token_count: int
|
||||
total_tokens: int
|
||||
|
||||
# Assembly metadata
|
||||
contexts_included: int
|
||||
contexts_excluded: int = 0
|
||||
context_count: int
|
||||
excluded_count: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
# Included contexts (optional - for inspection)
|
||||
contexts: list["BaseContext"] = field(default_factory=list)
|
||||
|
||||
# Additional metadata from assembly
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Budget tracking
|
||||
budget_total: int = 0
|
||||
@@ -271,6 +278,22 @@ class AssembledContext:
|
||||
cache_hit: bool = False
|
||||
cache_key: str | None = None
|
||||
|
||||
# Aliases for backward compatibility
|
||||
@property
|
||||
def token_count(self) -> int:
|
||||
"""Alias for total_tokens."""
|
||||
return self.total_tokens
|
||||
|
||||
@property
|
||||
def contexts_included(self) -> int:
|
||||
"""Alias for context_count."""
|
||||
return self.context_count
|
||||
|
||||
@property
|
||||
def contexts_excluded(self) -> int:
|
||||
"""Alias for excluded_count."""
|
||||
return self.excluded_count
|
||||
|
||||
@property
|
||||
def budget_utilization(self) -> float:
|
||||
"""Get budget utilization percentage."""
|
||||
@@ -282,10 +305,12 @@ class AssembledContext:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"content": self.content,
|
||||
"token_count": self.token_count,
|
||||
"contexts_included": self.contexts_included,
|
||||
"contexts_excluded": self.contexts_excluded,
|
||||
"total_tokens": self.total_tokens,
|
||||
"context_count": self.context_count,
|
||||
"excluded_count": self.excluded_count,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"model": self.model,
|
||||
"metadata": self.metadata,
|
||||
"budget_total": self.budget_total,
|
||||
"budget_used": self.budget_used,
|
||||
"budget_utilization": round(self.budget_utilization, 3),
|
||||
@@ -308,10 +333,12 @@ class AssembledContext:
|
||||
data = json.loads(json_str)
|
||||
return cls(
|
||||
content=data["content"],
|
||||
token_count=data["token_count"],
|
||||
contexts_included=data["contexts_included"],
|
||||
contexts_excluded=data.get("contexts_excluded", 0),
|
||||
total_tokens=data["total_tokens"],
|
||||
context_count=data["context_count"],
|
||||
excluded_count=data.get("excluded_count", 0),
|
||||
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||
model=data.get("model", ""),
|
||||
metadata=data.get("metadata", {}),
|
||||
budget_total=data.get("budget_total", 0),
|
||||
budget_used=data.get("budget_used", 0),
|
||||
by_type=data.get("by_type", {}),
|
||||
|
||||
502
backend/tests/services/context/test_assembly.py
Normal file
502
backend/tests/services/context/test_assembly.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""Tests for context assembly pipeline."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.assembly import ContextPipeline, PipelineMetrics
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineMetrics:
    """Checks for the PipelineMetrics dataclass."""

    def test_creation(self) -> None:
        """A default-constructed instance reports zero for the checked fields."""
        fresh = PipelineMetrics()

        assert fresh.total_contexts == 0
        assert fresh.selected_contexts == 0
        assert fresh.assembly_time_ms == 0.0

    def test_to_dict(self) -> None:
        """to_dict() echoes the constructor values and carries both timestamps."""
        expected = {
            "total_contexts": 10,
            "selected_contexts": 8,
            "excluded_contexts": 2,
            "total_tokens": 500,
            "assembly_time_ms": 25.5,
        }
        recorded = PipelineMetrics(**expected)
        recorded.end_time = datetime.now(UTC)

        serialized = recorded.to_dict()

        for key, value in expected.items():
            assert serialized[key] == value
        for stamp_key in ("start_time", "end_time"):
            assert stamp_key in serialized
|
||||
|
||||
|
||||
class TestContextPipeline:
    """Tests for ContextPipeline construction and ContextPipeline.assemble()."""

    def test_creation(self) -> None:
        """A default-constructed pipeline has every collaborator attribute set."""
        pipeline = ContextPipeline()

        assert pipeline._calculator is not None
        assert pipeline._scorer is not None
        assert pipeline._ranker is not None
        assert pipeline._compressor is not None
        assert pipeline._allocator is not None

    @pytest.mark.asyncio
    async def test_assemble_empty_contexts(self) -> None:
        """Assembling an empty list yields an AssembledContext with zero counts."""
        pipeline = ContextPipeline()

        result = await pipeline.assemble(
            contexts=[],
            query="test query",
            model="claude-3-sonnet",
        )

        assert isinstance(result, AssembledContext)
        assert result.context_count == 0
        assert result.total_tokens == 0

    @pytest.mark.asyncio
    async def test_assemble_single_context(self) -> None:
        """A single system context is included and its text appears in the output."""
        pipeline = ContextPipeline()

        contexts = [
            SystemContext(
                content="You are a helpful assistant.",
                source="system",
            )
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="help me",
            model="claude-3-sonnet",
        )

        assert result.context_count == 1
        assert result.total_tokens > 0
        assert "helpful assistant" in result.content

    @pytest.mark.asyncio
    async def test_assemble_multiple_types(self) -> None:
        """A mix of system, task and knowledge contexts assembles to non-empty output."""
        pipeline = ContextPipeline()

        contexts = [
            SystemContext(
                content="You are a coding assistant.",
                source="system",
            ),
            TaskContext(
                content="Implement a login feature.",
                source="task",
            ),
            KnowledgeContext(
                content="Authentication best practices include...",
                source="docs/auth.md",
                relevance_score=0.8,
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="implement login",
            model="claude-3-sonnet",
        )

        assert result.context_count >= 1
        assert result.total_tokens > 0

    @pytest.mark.asyncio
    async def test_assemble_with_custom_budget(self) -> None:
        """assemble() accepts an explicit TokenBudget via ``custom_budget``."""
        pipeline = ContextPipeline()
        budget = TokenBudget(
            total=1000,
            system=200,
            task=200,
            knowledge=400,
            conversation=100,
            tools=50,
            response_reserve=50,
        )

        contexts = [
            SystemContext(content="System prompt", source="system"),
            TaskContext(content="Task description", source="task"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            custom_budget=budget,
        )

        assert result.context_count >= 1

    @pytest.mark.asyncio
    async def test_assemble_with_max_tokens(self) -> None:
        """``max_tokens`` is reported as the budget total in the result metadata."""
        pipeline = ContextPipeline()

        contexts = [
            SystemContext(content="System prompt", source="system"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            max_tokens=5000,
        )

        assert "budget" in result.metadata
        assert result.metadata["budget"]["total"] == 5000

    @pytest.mark.asyncio
    async def test_assemble_format_output(self) -> None:
        """``format_output`` toggles the Claude XML wrapper around system content."""
        pipeline = ContextPipeline()

        contexts = [
            SystemContext(content="System prompt", source="system"),
        ]

        # Formatted (default)
        result_formatted = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
            format_output=True,
        )

        # Unformatted
        result_raw = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
            format_output=False,
        )

        # Formatted should have XML tags for Claude
        assert "<system_instructions>" in result_formatted.content
        # Raw should not
        assert "<system_instructions>" not in result_raw.content

    @pytest.mark.asyncio
    async def test_assemble_metrics(self) -> None:
        """The result metadata carries a populated ``metrics`` dictionary."""
        pipeline = ContextPipeline()

        contexts = [
            SystemContext(content="System", source="system"),
            TaskContext(content="Task", source="task"),
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )

        assert "metrics" in result.metadata
        metrics = result.metadata["metrics"]

        assert metrics["total_contexts"] == 3
        assert metrics["assembly_time_ms"] > 0
        assert "scoring_time_ms" in metrics
        assert "formatting_time_ms" in metrics

    @pytest.mark.asyncio
    async def test_assemble_with_compression_disabled(self) -> None:
        """assemble() still returns a well-formed result when ``compress=False``."""
        pipeline = ContextPipeline()

        contexts = [
            KnowledgeContext(content="A" * 1000, source="docs"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            compress=False,
        )

        # Should still work, just no compression applied. A bare ``>= 0`` check
        # can never fail, so verify the result type and bound the included
        # count by the number of inputs instead.
        assert isinstance(result, AssembledContext)
        assert 0 <= result.context_count <= len(contexts)
|
||||
|
||||
|
||||
class TestContextPipelineFormatting:
    """Tests for the model-specific formatting of assembled content."""

    @staticmethod
    async def _assemble(
        contexts: list,
        model: str = "claude-3-sonnet",
    ) -> AssembledContext:
        """Run a fresh ContextPipeline over *contexts* with the query "test".

        Every test in this class uses the same pipeline construction and query,
        so the shared call lives here and each test only states its inputs and
        expectations.
        """
        pipeline = ContextPipeline()
        return await pipeline.assemble(
            contexts=contexts,
            query="test",
            model=model,
        )

    @pytest.mark.asyncio
    async def test_format_claude_uses_xml(self) -> None:
        """Claude models get XML tags unless nothing was included."""
        contexts = [
            SystemContext(content="System prompt", source="system"),
            TaskContext(content="Task", source="task"),
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await self._assemble(contexts)

        # Claude should have XML tags
        assert "<system_instructions>" in result.content or result.context_count == 0

    @pytest.mark.asyncio
    async def test_format_openai_uses_markdown(self) -> None:
        """OpenAI models get a markdown header for the task section."""
        contexts = [
            TaskContext(content="Task description", source="task"),
        ]

        result = await self._assemble(contexts, model="gpt-4")

        # OpenAI should have markdown headers
        if result.context_count > 0 and "Task" in result.content:
            assert "## Current Task" in result.content

    @pytest.mark.asyncio
    async def test_format_knowledge_claude(self) -> None:
        """Knowledge contexts for Claude are wrapped in document tags."""
        contexts = [
            KnowledgeContext(
                content="Document content here",
                source="docs/file.md",
                relevance_score=0.9,
            ),
        ]

        result = await self._assemble(contexts)

        if result.context_count > 0:
            assert "<reference_documents>" in result.content
            assert "<document" in result.content

    @pytest.mark.asyncio
    async def test_format_conversation(self) -> None:
        """Conversation contexts for Claude carry a history tag and the user role."""
        contexts = [
            ConversationContext(
                content="Hello, how are you?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            ConversationContext(
                content="I'm doing great!",
                source="chat",
                role=MessageRole.ASSISTANT,
                metadata={"role": "assistant"},
            ),
        ]

        result = await self._assemble(contexts)

        if result.context_count > 0:
            assert "<conversation_history>" in result.content
            # A full '<message role="user">' tag necessarily contains this
            # attribute text, so the attribute check alone covers both cases.
            assert 'role="user"' in result.content

    @pytest.mark.asyncio
    async def test_format_tool_results(self) -> None:
        """Tool contexts for Claude are wrapped in a tool-results tag."""
        contexts = [
            ToolContext(
                content="Tool output here",
                source="tool",
                metadata={"tool_name": "search"},
            ),
        ]

        result = await self._assemble(contexts)

        if result.context_count > 0:
            assert "<tool_results>" in result.content
|
||||
|
||||
|
||||
class TestContextPipelineIntegration:
    """Integration tests that drive the full assembly pipeline end to end."""

    @pytest.mark.asyncio
    async def test_full_pipeline_workflow(self) -> None:
        """A realistic context mix assembles into a populated AssembledContext."""
        pipeline = ContextPipeline()

        # Create realistic context mix
        contexts = [
            SystemContext(
                content="You are an expert Python developer.",
                source="system",
            ),
            TaskContext(
                content="Implement a user authentication system.",
                source="task:AUTH-123",
            ),
            KnowledgeContext(
                content="JWT tokens provide stateless authentication...",
                source="docs/auth/jwt.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="OAuth 2.0 is an authorization framework...",
                source="docs/auth/oauth.md",
                relevance_score=0.7,
            ),
            ConversationContext(
                content="Can you help me implement JWT auth?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="implement JWT authentication",
            model="claude-3-sonnet",
        )

        # Verify result
        assert isinstance(result, AssembledContext)
        assert result.context_count > 0
        assert result.total_tokens > 0
        assert result.assembly_time_ms > 0
        assert result.model == "claude-3-sonnet"
        assert len(result.content) > 0

        # Verify metrics
        assert "metrics" in result.metadata
        assert "query" in result.metadata
        assert "budget" in result.metadata

    @pytest.mark.asyncio
    async def test_context_type_ordering(self) -> None:
        """Sections present in Claude output appear in the expected type order."""
        pipeline = ContextPipeline()

        # Add in random order
        contexts = [
            KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
            ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
            SystemContext(content="System", source="system"),
            ConversationContext(
                content="Chat",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            TaskContext(content="Task", source="task"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )

        # For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
        expected_order = (
            "system_instructions",
            "current_task",
            "reference_documents",
            "conversation_history",
            "tool_results",
        )
        content = result.content
        if result.context_count > 0:
            # Sections that were dropped (find() == -1) are ignored; the ones
            # that remain must appear in the documented order. Checking every
            # present section covers the conversation and tool sections as
            # well as the system/task/knowledge pairs.
            positions = [content.find(tag) for tag in expected_order]
            present = [pos for pos in positions if pos >= 0]
            assert present == sorted(present)

    @pytest.mark.asyncio
    async def test_excluded_contexts_tracked(self) -> None:
        """Included and excluded counts stay within the number of inputs."""
        pipeline = ContextPipeline()

        # Create many contexts to force some exclusions
        contexts = [
            KnowledgeContext(
                content="A" * 500,  # Large content
                source=f"docs/{i}",
                relevance_score=0.1 + (i * 0.05),
            )
            for i in range(10)
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",  # Smaller context window
            max_tokens=1000,  # Limited budget
        )

        # Whether anything is excluded depends on the tokenizer and budget
        # split, so only bounds that must always hold are asserted. A bare
        # ``>= 0`` check can never fail, hence the upper bound.
        assert 0 <= result.excluded_count <= len(contexts)
        assert result.context_count + result.excluded_count <= len(contexts)
|
||||
214
backend/tests/services/context/test_compression.py
Normal file
214
backend/tests/services/context/test_compression.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Tests for context compression module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.types import (
|
||||
ContextType,
|
||||
KnowledgeContext,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestTruncationResult:
    """Checks for the TruncationResult dataclass."""

    def test_creation(self) -> None:
        """Constructor arguments are readable back from the instance."""
        fields = {
            "original_tokens": 100,
            "truncated_tokens": 50,
            "content": "Truncated content",
            "truncated": True,
            "truncation_ratio": 0.5,
        }

        outcome = TruncationResult(**fields)

        assert outcome.original_tokens == 100
        assert outcome.truncated_tokens == 50
        assert outcome.truncated is True
        assert outcome.truncation_ratio == 0.5

    def test_tokens_saved(self) -> None:
        """tokens_saved is 60 for 100 original and 40 truncated tokens."""
        fields = {
            "original_tokens": 100,
            "truncated_tokens": 40,
            "content": "Test",
            "truncated": True,
            "truncation_ratio": 0.6,
        }

        outcome = TruncationResult(**fields)

        assert outcome.tokens_saved == 60

    def test_no_truncation(self) -> None:
        """With equal token counts nothing is saved and the flag stays False."""
        fields = {
            "original_tokens": 50,
            "truncated_tokens": 50,
            "content": "Full content",
            "truncated": False,
            "truncation_ratio": 0.0,
        }

        outcome = TruncationResult(**fields)

        assert outcome.tokens_saved == 0
        assert outcome.truncated is False
|
||||
|
||||
|
||||
class TestTruncationStrategy:
    """Checks TruncationStrategy construction and its truncation modes."""

    def test_creation(self) -> None:
        """Defaults are a 0.7 start-preserve ratio and min content length 100."""
        truncator = TruncationStrategy()

        assert truncator._preserve_ratio_start == 0.7
        assert truncator._min_content_length == 100

    def test_creation_with_params(self) -> None:
        """Explicit constructor arguments override the defaults."""
        overrides = {"preserve_ratio_start": 0.5, "min_content_length": 50}

        truncator = TruncationStrategy(**overrides)

        assert truncator._preserve_ratio_start == 0.5
        assert truncator._min_content_length == 50

    @pytest.mark.asyncio
    async def test_truncate_empty_content(self) -> None:
        """An empty string comes back empty, with zero counts and no truncation."""
        truncator = TruncationStrategy()

        outcome = await truncator.truncate_to_tokens("", max_tokens=100)

        assert outcome.original_tokens == 0
        assert outcome.truncated_tokens == 0
        assert outcome.content == ""
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_content_within_limit(self) -> None:
        """Content that fits the limit is returned unchanged."""
        truncator = TruncationStrategy()
        short_text = "Short content"

        outcome = await truncator.truncate_to_tokens(short_text, max_tokens=100)

        assert outcome.content == short_text
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_end_strategy(self) -> None:
        """The "end" strategy shortens long content and inserts the marker."""
        truncator = TruncationStrategy()
        long_text = "A" * 1000  # Long content

        outcome = await truncator.truncate_to_tokens(
            long_text, max_tokens=50, strategy="end"
        )

        assert outcome.truncated is True
        assert len(outcome.content) < len(long_text)
        assert truncator.TRUNCATION_MARKER in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_middle_strategy(self) -> None:
        """The "middle" strategy truncates and inserts the marker."""
        truncator = TruncationStrategy(preserve_ratio_start=0.6)
        framed_text = "START " + "A" * 500 + " END"

        outcome = await truncator.truncate_to_tokens(
            framed_text, max_tokens=50, strategy="middle"
        )

        assert outcome.truncated is True
        assert truncator.TRUNCATION_MARKER in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_sentence_strategy(self) -> None:
        """The "sentence" strategy ends on a period or includes the marker."""
        truncator = TruncationStrategy()
        sentences = "First sentence. Second sentence. Third sentence. Fourth sentence."

        outcome = await truncator.truncate_to_tokens(
            sentences, max_tokens=10, strategy="sentence"
        )

        assert outcome.truncated is True
        # Should cut at sentence boundary
        ends_on_sentence = outcome.content.endswith(".")
        assert ends_on_sentence or truncator.TRUNCATION_MARKER in outcome.content
|
||||
|
||||
|
||||
class TestContextCompressor:
    """Checks for ContextCompressor."""

    def test_creation(self) -> None:
        """A default compressor has its truncation helper set."""
        squeezer = ContextCompressor()

        assert squeezer._truncation is not None

    @pytest.mark.asyncio
    async def test_compress_context_within_limit(self) -> None:
        """A context already under the limit keeps its content and is not flagged."""
        squeezer = ContextCompressor()
        small = KnowledgeContext(
            content="Short content",
            source="docs",
        )
        small.token_count = 5

        compressed = await squeezer.compress_context(small, max_tokens=100)

        # Should return same context unmodified
        assert compressed.content == "Short content"
        assert compressed.metadata.get("truncated") is not True

    @pytest.mark.asyncio
    async def test_compress_context_exceeds_limit(self) -> None:
        """An oversized context is shortened and flagged in its metadata."""
        squeezer = ContextCompressor()
        oversized = KnowledgeContext(
            content="A" * 500,
            source="docs",
        )
        oversized.token_count = 125  # Approximately 500/4

        compressed = await squeezer.compress_context(oversized, max_tokens=20)

        details = compressed.metadata
        assert details.get("truncated") is True
        assert details.get("original_tokens") == 125
        assert len(compressed.content) < 500

    @pytest.mark.asyncio
    async def test_compress_contexts_batch(self) -> None:
        """Batch compression returns one result per input context."""
        squeezer = ContextCompressor()
        budget = BudgetAllocator().create_budget(1000)
        batch = [
            KnowledgeContext(content="A" * 200, source="docs"),
            KnowledgeContext(content="B" * 200, source="docs"),
            TaskContext(content="C" * 200, source="task"),
        ]

        compressed = await squeezer.compress_contexts(batch, budget)

        assert len(compressed) == 3

    @pytest.mark.asyncio
    async def test_strategy_selection_by_type(self) -> None:
        """Each context type maps to its expected truncation strategy name."""
        squeezer = ContextCompressor()
        expected_by_type = (
            (ContextType.SYSTEM, "end"),
            (ContextType.TASK, "end"),
            (ContextType.KNOWLEDGE, "sentence"),
            (ContextType.CONVERSATION, "end"),
            (ContextType.TOOL, "middle"),
        )

        for context_type, strategy_name in expected_by_type:
            assert squeezer._get_strategy_for_type(context_type) == strategy_name
|
||||
@@ -426,11 +426,14 @@ class TestAssembledContext:
|
||||
"""Test basic creation."""
|
||||
ctx = AssembledContext(
|
||||
content="Assembled content here",
|
||||
token_count=500,
|
||||
contexts_included=5,
|
||||
total_tokens=500,
|
||||
context_count=5,
|
||||
)
|
||||
|
||||
assert ctx.content == "Assembled content here"
|
||||
assert ctx.total_tokens == 500
|
||||
assert ctx.context_count == 5
|
||||
# Test backward compatibility aliases
|
||||
assert ctx.token_count == 500
|
||||
assert ctx.contexts_included == 5
|
||||
|
||||
@@ -438,8 +441,8 @@ class TestAssembledContext:
|
||||
"""Test budget_utilization property."""
|
||||
ctx = AssembledContext(
|
||||
content="test",
|
||||
token_count=800,
|
||||
contexts_included=5,
|
||||
total_tokens=800,
|
||||
context_count=5,
|
||||
budget_total=1000,
|
||||
budget_used=800,
|
||||
)
|
||||
@@ -450,8 +453,8 @@ class TestAssembledContext:
|
||||
"""Test budget_utilization with zero budget."""
|
||||
ctx = AssembledContext(
|
||||
content="test",
|
||||
token_count=0,
|
||||
contexts_included=0,
|
||||
total_tokens=0,
|
||||
context_count=0,
|
||||
budget_total=0,
|
||||
budget_used=0,
|
||||
)
|
||||
@@ -462,24 +465,26 @@ class TestAssembledContext:
|
||||
"""Test to_dict method."""
|
||||
ctx = AssembledContext(
|
||||
content="test",
|
||||
token_count=100,
|
||||
contexts_included=2,
|
||||
total_tokens=100,
|
||||
context_count=2,
|
||||
assembly_time_ms=50.123,
|
||||
)
|
||||
|
||||
data = ctx.to_dict()
|
||||
assert data["content"] == "test"
|
||||
assert data["token_count"] == 100
|
||||
assert data["total_tokens"] == 100
|
||||
assert data["context_count"] == 2
|
||||
assert data["assembly_time_ms"] == 50.12 # Rounded
|
||||
|
||||
def test_to_json_and_from_json(self) -> None:
|
||||
"""Test JSON serialization round-trip."""
|
||||
original = AssembledContext(
|
||||
content="Test content",
|
||||
token_count=100,
|
||||
contexts_included=3,
|
||||
contexts_excluded=2,
|
||||
total_tokens=100,
|
||||
context_count=3,
|
||||
excluded_count=2,
|
||||
assembly_time_ms=45.5,
|
||||
model="claude-3-sonnet",
|
||||
budget_total=1000,
|
||||
budget_used=100,
|
||||
by_type={"system": 20, "knowledge": 80},
|
||||
@@ -491,8 +496,10 @@ class TestAssembledContext:
|
||||
restored = AssembledContext.from_json(json_str)
|
||||
|
||||
assert restored.content == original.content
|
||||
assert restored.token_count == original.token_count
|
||||
assert restored.contexts_included == original.contexts_included
|
||||
assert restored.total_tokens == original.total_tokens
|
||||
assert restored.context_count == original.context_count
|
||||
assert restored.excluded_count == original.excluded_count
|
||||
assert restored.model == original.model
|
||||
assert restored.cache_hit == original.cache_hit
|
||||
assert restored.cache_key == original.cache_key
|
||||
|
||||
|
||||
Reference in New Issue
Block a user