forked from cardosofelipe/fast-next-template
feat(context): implement assembly pipeline and compression (#82)
Phase 4 of Context Management Engine - Assembly Pipeline: - Add TruncationStrategy with end/middle/sentence-aware truncation - Add TruncationResult dataclass for tracking compression metrics - Add ContextCompressor for type-specific compression - Add ContextPipeline orchestrating full assembly workflow: - Token counting for all contexts - Scoring and ranking via ContextRanker - Optional compression when budget threshold exceeded - Model-specific formatting (XML for Claude, markdown for OpenAI) - Add PipelineMetrics for performance tracking - Update AssembledContext with new fields (model, contexts, metadata) - Add backward compatibility aliases for renamed fields Tests: 34 new tests, 223 total context tests passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -63,6 +63,19 @@ from .exceptions import (
|
|||||||
TokenCountError,
|
TokenCountError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assembly
|
||||||
|
from .assembly import (
|
||||||
|
ContextPipeline,
|
||||||
|
PipelineMetrics,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compression
|
||||||
|
from .compression import (
|
||||||
|
ContextCompressor,
|
||||||
|
TruncationResult,
|
||||||
|
TruncationStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
# Prioritization
|
# Prioritization
|
||||||
from .prioritization import (
|
from .prioritization import (
|
||||||
ContextRanker,
|
ContextRanker,
|
||||||
@@ -97,10 +110,17 @@ from .types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Assembly
|
||||||
|
"ContextPipeline",
|
||||||
|
"PipelineMetrics",
|
||||||
# Budget Management
|
# Budget Management
|
||||||
"BudgetAllocator",
|
"BudgetAllocator",
|
||||||
"TokenBudget",
|
"TokenBudget",
|
||||||
"TokenCalculator",
|
"TokenCalculator",
|
||||||
|
# Compression
|
||||||
|
"ContextCompressor",
|
||||||
|
"TruncationResult",
|
||||||
|
"TruncationStrategy",
|
||||||
# Configuration
|
# Configuration
|
||||||
"ContextSettings",
|
"ContextSettings",
|
||||||
"get_context_settings",
|
"get_context_settings",
|
||||||
|
|||||||
@@ -3,3 +3,10 @@ Context Assembly Module.
|
|||||||
|
|
||||||
Provides the assembly pipeline and formatting.
|
Provides the assembly pipeline and formatting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .pipeline import ContextPipeline, PipelineMetrics
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextPipeline",
|
||||||
|
"PipelineMetrics",
|
||||||
|
]
|
||||||
|
|||||||
432
backend/app/services/context/assembly/pipeline.py
Normal file
432
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
"""
|
||||||
|
Context Assembly Pipeline.
|
||||||
|
|
||||||
|
Orchestrates the full context assembly workflow:
|
||||||
|
Gather → Count → Score → Rank → Compress → Format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||||
|
from ..compression.truncation import ContextCompressor
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import AssemblyTimeoutError
|
||||||
|
from ..prioritization import ContextRanker
|
||||||
|
from ..scoring import CompositeScorer
|
||||||
|
from ..types import AssembledContext, BaseContext, ContextType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineMetrics:
|
||||||
|
"""Metrics from pipeline execution."""
|
||||||
|
|
||||||
|
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
end_time: datetime | None = None
|
||||||
|
total_contexts: int = 0
|
||||||
|
selected_contexts: int = 0
|
||||||
|
excluded_contexts: int = 0
|
||||||
|
compressed_contexts: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
assembly_time_ms: float = 0.0
|
||||||
|
scoring_time_ms: float = 0.0
|
||||||
|
compression_time_ms: float = 0.0
|
||||||
|
formatting_time_ms: float = 0.0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"start_time": self.start_time.isoformat(),
|
||||||
|
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||||
|
"total_contexts": self.total_contexts,
|
||||||
|
"selected_contexts": self.selected_contexts,
|
||||||
|
"excluded_contexts": self.excluded_contexts,
|
||||||
|
"compressed_contexts": self.compressed_contexts,
|
||||||
|
"total_tokens": self.total_tokens,
|
||||||
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||||
|
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||||
|
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||||
|
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPipeline:
|
||||||
|
"""
|
||||||
|
Context assembly pipeline.
|
||||||
|
|
||||||
|
Orchestrates the full workflow of context assembly:
|
||||||
|
1. Validate and count tokens for all contexts
|
||||||
|
2. Score contexts based on relevance, recency, and priority
|
||||||
|
3. Rank and select contexts within budget
|
||||||
|
4. Compress if needed to fit remaining budget
|
||||||
|
5. Format for the target model
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
calculator: TokenCalculator | None = None,
|
||||||
|
scorer: CompositeScorer | None = None,
|
||||||
|
ranker: ContextRanker | None = None,
|
||||||
|
compressor: ContextCompressor | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway integration
|
||||||
|
settings: Context settings
|
||||||
|
calculator: Token calculator
|
||||||
|
scorer: Context scorer
|
||||||
|
ranker: Context ranker
|
||||||
|
compressor: Context compressor
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
|
||||||
|
self._scorer = scorer or CompositeScorer(
|
||||||
|
mcp_manager=mcp_manager, settings=self._settings
|
||||||
|
)
|
||||||
|
self._ranker = ranker or ContextRanker(
|
||||||
|
scorer=self._scorer, calculator=self._calculator
|
||||||
|
)
|
||||||
|
self._compressor = compressor or ContextCompressor(
|
||||||
|
calculator=self._calculator
|
||||||
|
)
|
||||||
|
self._allocator = BudgetAllocator(self._settings)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for all components."""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._calculator.set_mcp_manager(mcp_manager)
|
||||||
|
self._scorer.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
async def assemble(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
custom_budget: TokenBudget | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
format_output: bool = True,
|
||||||
|
timeout_ms: int | None = None,
|
||||||
|
) -> AssembledContext:
|
||||||
|
"""
|
||||||
|
Assemble context for an LLM request.
|
||||||
|
|
||||||
|
This is the main entry point for context assembly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to assemble
|
||||||
|
query: Query to optimize for
|
||||||
|
model: Target model name
|
||||||
|
max_tokens: Maximum total tokens (uses model default if None)
|
||||||
|
custom_budget: Optional pre-configured budget
|
||||||
|
compress: Whether to compress oversized contexts
|
||||||
|
format_output: Whether to format the final output
|
||||||
|
timeout_ms: Maximum assembly time in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssembledContext with optimized content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssemblyTimeoutError: If assembly exceeds timeout
|
||||||
|
"""
|
||||||
|
timeout = timeout_ms or self._settings.max_assembly_time_ms
|
||||||
|
start = time.perf_counter()
|
||||||
|
metrics = PipelineMetrics(total_contexts=len(contexts))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create or use budget
|
||||||
|
if custom_budget:
|
||||||
|
budget = custom_budget
|
||||||
|
elif max_tokens:
|
||||||
|
budget = self._allocator.create_budget(max_tokens)
|
||||||
|
else:
|
||||||
|
budget = self._allocator.create_budget_for_model(model)
|
||||||
|
|
||||||
|
# 1. Count tokens for all contexts
|
||||||
|
await self._ensure_token_counts(contexts, model)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "token counting")
|
||||||
|
|
||||||
|
# 2. Score and rank contexts
|
||||||
|
scoring_start = time.perf_counter()
|
||||||
|
ranking_result = await self._ranker.rank(
|
||||||
|
contexts=contexts,
|
||||||
|
query=query,
|
||||||
|
budget=budget,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||||
|
|
||||||
|
selected_contexts = ranking_result.selected_contexts
|
||||||
|
metrics.selected_contexts = len(selected_contexts)
|
||||||
|
metrics.excluded_contexts = len(ranking_result.excluded)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "scoring")
|
||||||
|
|
||||||
|
# 3. Compress if needed and enabled
|
||||||
|
if compress and self._needs_compression(selected_contexts, budget):
|
||||||
|
compression_start = time.perf_counter()
|
||||||
|
selected_contexts = await self._compressor.compress_contexts(
|
||||||
|
selected_contexts, budget, model
|
||||||
|
)
|
||||||
|
metrics.compression_time_ms = (
|
||||||
|
time.perf_counter() - compression_start
|
||||||
|
) * 1000
|
||||||
|
metrics.compressed_contexts = sum(
|
||||||
|
1 for c in selected_contexts if c.metadata.get("truncated", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "compression")
|
||||||
|
|
||||||
|
# 4. Format output
|
||||||
|
formatting_start = time.perf_counter()
|
||||||
|
if format_output:
|
||||||
|
formatted_content = self._format_contexts(selected_contexts, model)
|
||||||
|
else:
|
||||||
|
formatted_content = "\n\n".join(c.content for c in selected_contexts)
|
||||||
|
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
|
||||||
|
|
||||||
|
# Calculate final metrics
|
||||||
|
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
|
||||||
|
metrics.total_tokens = total_tokens
|
||||||
|
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
|
||||||
|
metrics.end_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
return AssembledContext(
|
||||||
|
content=formatted_content,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
context_count=len(selected_contexts),
|
||||||
|
assembly_time_ms=metrics.assembly_time_ms,
|
||||||
|
model=model,
|
||||||
|
contexts=selected_contexts,
|
||||||
|
excluded_count=metrics.excluded_contexts,
|
||||||
|
metadata={
|
||||||
|
"metrics": metrics.to_dict(),
|
||||||
|
"query": query,
|
||||||
|
"budget": budget.to_dict(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except AssemblyTimeoutError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Context assembly failed: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _ensure_token_counts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure all contexts have token counts."""
|
||||||
|
tasks = []
|
||||||
|
for context in contexts:
|
||||||
|
if context.token_count is None:
|
||||||
|
tasks.append(self._count_and_set(context, model))
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
async def _count_and_set(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Count tokens and set on context."""
|
||||||
|
count = await self._calculator.count_tokens(context.content, model)
|
||||||
|
context.token_count = count
|
||||||
|
|
||||||
|
def _needs_compression(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
budget: TokenBudget,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if any contexts exceed their type budget."""
|
||||||
|
# Group by type and check totals
|
||||||
|
by_type: dict[ContextType, int] = {}
|
||||||
|
for context in contexts:
|
||||||
|
ct = context.get_type()
|
||||||
|
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
|
||||||
|
|
||||||
|
for ct, total in by_type.items():
|
||||||
|
if total > budget.get_allocation(ct):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Also check if utilization exceeds threshold
|
||||||
|
return budget.utilization() > self._settings.compression_threshold
|
||||||
|
|
||||||
|
def _format_contexts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for the target model.
|
||||||
|
|
||||||
|
Groups contexts by type and applies model-specific formatting.
|
||||||
|
"""
|
||||||
|
# Group by type
|
||||||
|
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||||
|
for context in contexts:
|
||||||
|
ct = context.get_type()
|
||||||
|
if ct not in by_type:
|
||||||
|
by_type[ct] = []
|
||||||
|
by_type[ct].append(context)
|
||||||
|
|
||||||
|
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
|
||||||
|
type_order = [
|
||||||
|
ContextType.SYSTEM,
|
||||||
|
ContextType.TASK,
|
||||||
|
ContextType.KNOWLEDGE,
|
||||||
|
ContextType.CONVERSATION,
|
||||||
|
ContextType.TOOL,
|
||||||
|
]
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
for ct in type_order:
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self._format_type(by_type[ct], ct, model)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def _format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
model: str,
|
||||||
|
) -> str:
|
||||||
|
"""Format contexts of a specific type."""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Check if model prefers XML tags (Claude)
|
||||||
|
use_xml = "claude" in model.lower()
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return self._format_system(contexts, use_xml)
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return self._format_task(contexts, use_xml)
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return self._format_knowledge(contexts, use_xml)
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return self._format_conversation(contexts, use_xml)
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return self._format_tool(contexts, use_xml)
|
||||||
|
|
||||||
|
return "\n".join(c.content for c in contexts)
|
||||||
|
|
||||||
|
def _format_system(
|
||||||
|
self, contexts: list[BaseContext], use_xml: bool
|
||||||
|
) -> str:
|
||||||
|
"""Format system contexts."""
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
if use_xml:
|
||||||
|
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _format_task(
|
||||||
|
self, contexts: list[BaseContext], use_xml: bool
|
||||||
|
) -> str:
|
||||||
|
"""Format task contexts."""
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
if use_xml:
|
||||||
|
return f"<current_task>\n{content}\n</current_task>"
|
||||||
|
return f"## Current Task\n\n{content}"
|
||||||
|
|
||||||
|
def _format_knowledge(
|
||||||
|
self, contexts: list[BaseContext], use_xml: bool
|
||||||
|
) -> str:
|
||||||
|
"""Format knowledge contexts."""
|
||||||
|
if use_xml:
|
||||||
|
parts = ["<reference_documents>"]
|
||||||
|
for ctx in contexts:
|
||||||
|
parts.append(f'<document source="{ctx.source}">')
|
||||||
|
parts.append(ctx.content)
|
||||||
|
parts.append("</document>")
|
||||||
|
parts.append("</reference_documents>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
else:
|
||||||
|
parts = ["## Reference Documents\n"]
|
||||||
|
for ctx in contexts:
|
||||||
|
parts.append(f"### Source: {ctx.source}\n")
|
||||||
|
parts.append(ctx.content)
|
||||||
|
parts.append("")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_conversation(
|
||||||
|
self, contexts: list[BaseContext], use_xml: bool
|
||||||
|
) -> str:
|
||||||
|
"""Format conversation contexts."""
|
||||||
|
if use_xml:
|
||||||
|
parts = ["<conversation_history>"]
|
||||||
|
for ctx in contexts:
|
||||||
|
role = ctx.metadata.get("role", "user")
|
||||||
|
parts.append(f'<message role="{role}">')
|
||||||
|
parts.append(ctx.content)
|
||||||
|
parts.append("</message>")
|
||||||
|
parts.append("</conversation_history>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
else:
|
||||||
|
parts = []
|
||||||
|
for ctx in contexts:
|
||||||
|
role = ctx.metadata.get("role", "user")
|
||||||
|
parts.append(f"**{role.upper()}**: {ctx.content}")
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def _format_tool(
|
||||||
|
self, contexts: list[BaseContext], use_xml: bool
|
||||||
|
) -> str:
|
||||||
|
"""Format tool contexts."""
|
||||||
|
if use_xml:
|
||||||
|
parts = ["<tool_results>"]
|
||||||
|
for ctx in contexts:
|
||||||
|
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||||
|
parts.append(f'<tool_result name="{tool_name}">')
|
||||||
|
parts.append(ctx.content)
|
||||||
|
parts.append("</tool_result>")
|
||||||
|
parts.append("</tool_results>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
else:
|
||||||
|
parts = ["## Recent Tool Results\n"]
|
||||||
|
for ctx in contexts:
|
||||||
|
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||||
|
parts.append(f"### Tool: {tool_name}\n")
|
||||||
|
parts.append(f"```\n{ctx.content}\n```")
|
||||||
|
parts.append("")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _check_timeout(
|
||||||
|
self,
|
||||||
|
start: float,
|
||||||
|
timeout_ms: int,
|
||||||
|
phase: str,
|
||||||
|
) -> None:
|
||||||
|
"""Check if timeout exceeded and raise if so."""
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
if elapsed_ms > timeout_ms:
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message=f"Context assembly timed out during {phase}",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout_ms,
|
||||||
|
)
|
||||||
@@ -3,3 +3,11 @@ Context Compression Module.
|
|||||||
|
|
||||||
Provides truncation and compression strategies.
|
Provides truncation and compression strategies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextCompressor",
|
||||||
|
"TruncationResult",
|
||||||
|
"TruncationStrategy",
|
||||||
|
]
|
||||||
|
|||||||
391
backend/app/services/context/compression/truncation.py
Normal file
391
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
"""
|
||||||
|
Smart Truncation for Context Compression.
|
||||||
|
|
||||||
|
Provides intelligent truncation strategies to reduce context size
|
||||||
|
while preserving the most important information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..budget import TokenBudget, TokenCalculator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TruncationResult:
|
||||||
|
"""Result of truncation operation."""
|
||||||
|
|
||||||
|
original_tokens: int
|
||||||
|
truncated_tokens: int
|
||||||
|
content: str
|
||||||
|
truncated: bool
|
||||||
|
truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokens_saved(self) -> int:
|
||||||
|
"""Calculate tokens saved by truncation."""
|
||||||
|
return self.original_tokens - self.truncated_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TruncationStrategy:
|
||||||
|
"""
|
||||||
|
Smart truncation strategies for context compression.
|
||||||
|
|
||||||
|
Strategies:
|
||||||
|
1. End truncation: Cut from end (for knowledge/docs)
|
||||||
|
2. Middle truncation: Keep start and end (for code)
|
||||||
|
3. Sentence-aware: Truncate at sentence boundaries
|
||||||
|
4. Semantic chunking: Keep most relevant chunks
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default truncation marker
|
||||||
|
TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
calculator: "TokenCalculator | None" = None,
|
||||||
|
preserve_ratio_start: float = 0.7, # Keep 70% from start by default
|
||||||
|
min_content_length: int = 100, # Minimum characters to keep
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize truncation strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calculator: Token calculator for accurate counting
|
||||||
|
preserve_ratio_start: Ratio of content to keep from start
|
||||||
|
min_content_length: Minimum characters to preserve
|
||||||
|
"""
|
||||||
|
self._calculator = calculator
|
||||||
|
self._preserve_ratio_start = preserve_ratio_start
|
||||||
|
self._min_content_length = min_content_length
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||||
|
"""Set token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
async def truncate_to_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
strategy: str = "end",
|
||||||
|
model: str | None = None,
|
||||||
|
) -> TruncationResult:
|
||||||
|
"""
|
||||||
|
Truncate content to fit within token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to truncate
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
strategy: Truncation strategy ('end', 'middle', 'sentence')
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TruncationResult with truncated content
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=0,
|
||||||
|
truncated_tokens=0,
|
||||||
|
content="",
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get original token count
|
||||||
|
original_tokens = await self._count_tokens(content, model)
|
||||||
|
|
||||||
|
if original_tokens <= max_tokens:
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=original_tokens,
|
||||||
|
truncated_tokens=original_tokens,
|
||||||
|
content=content,
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply truncation strategy
|
||||||
|
if strategy == "middle":
|
||||||
|
truncated = await self._truncate_middle(content, max_tokens, model)
|
||||||
|
elif strategy == "sentence":
|
||||||
|
truncated = await self._truncate_sentence(content, max_tokens, model)
|
||||||
|
else: # "end"
|
||||||
|
truncated = await self._truncate_end(content, max_tokens, model)
|
||||||
|
|
||||||
|
truncated_tokens = await self._count_tokens(truncated, model)
|
||||||
|
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=original_tokens,
|
||||||
|
truncated_tokens=truncated_tokens,
|
||||||
|
content=truncated,
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=1 - (truncated_tokens / original_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _truncate_end(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate from end of content.
|
||||||
|
|
||||||
|
Simple but effective for most content types.
|
||||||
|
"""
|
||||||
|
# Binary search for optimal truncation point
|
||||||
|
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||||
|
available_tokens = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
# Estimate characters per token
|
||||||
|
chars_per_token = len(content) / await self._count_tokens(content, model)
|
||||||
|
|
||||||
|
# Start with estimated position
|
||||||
|
estimated_chars = int(available_tokens * chars_per_token)
|
||||||
|
truncated = content[:estimated_chars]
|
||||||
|
|
||||||
|
# Refine with binary search
|
||||||
|
low, high = len(truncated) // 2, len(truncated)
|
||||||
|
best = truncated
|
||||||
|
|
||||||
|
for _ in range(5): # Max 5 iterations
|
||||||
|
mid = (low + high) // 2
|
||||||
|
candidate = content[:mid]
|
||||||
|
tokens = await self._count_tokens(candidate, model)
|
||||||
|
|
||||||
|
if tokens <= available_tokens:
|
||||||
|
best = candidate
|
||||||
|
low = mid + 1
|
||||||
|
else:
|
||||||
|
high = mid - 1
|
||||||
|
|
||||||
|
return best + self.TRUNCATION_MARKER
|
||||||
|
|
||||||
|
async def _truncate_middle(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate from middle, keeping start and end.
|
||||||
|
|
||||||
|
Good for code or content where context at boundaries matters.
|
||||||
|
"""
|
||||||
|
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||||
|
available_tokens = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
# Split between start and end
|
||||||
|
start_tokens = int(available_tokens * self._preserve_ratio_start)
|
||||||
|
end_tokens = available_tokens - start_tokens
|
||||||
|
|
||||||
|
# Get start portion
|
||||||
|
start_content = await self._get_content_for_tokens(
|
||||||
|
content, start_tokens, from_start=True, model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get end portion
|
||||||
|
end_content = await self._get_content_for_tokens(
|
||||||
|
content, end_tokens, from_start=False, model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
return start_content + self.TRUNCATION_MARKER + end_content
|
||||||
|
|
||||||
|
async def _truncate_sentence(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate at sentence boundaries.
|
||||||
|
|
||||||
|
Produces cleaner output by not cutting mid-sentence.
|
||||||
|
"""
|
||||||
|
# Split into sentences
|
||||||
|
sentences = re.split(r"(?<=[.!?])\s+", content)
|
||||||
|
|
||||||
|
result: list[str] = []
|
||||||
|
total_tokens = 0
|
||||||
|
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
|
||||||
|
available = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
sentence_tokens = await self._count_tokens(sentence, model)
|
||||||
|
if total_tokens + sentence_tokens <= available:
|
||||||
|
result.append(sentence)
|
||||||
|
total_tokens += sentence_tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(result) < len(sentences):
|
||||||
|
return " ".join(result) + self.TRUNCATION_MARKER
|
||||||
|
return " ".join(result)
|
||||||
|
|
||||||
|
async def _get_content_for_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
target_tokens: int,
|
||||||
|
from_start: bool = True,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Get portion of content fitting within token limit."""
|
||||||
|
if target_tokens <= 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
current_tokens = await self._count_tokens(content, model)
|
||||||
|
if current_tokens <= target_tokens:
|
||||||
|
return content
|
||||||
|
|
||||||
|
# Estimate characters
|
||||||
|
chars_per_token = len(content) / current_tokens
|
||||||
|
estimated_chars = int(target_tokens * chars_per_token)
|
||||||
|
|
||||||
|
if from_start:
|
||||||
|
return content[:estimated_chars]
|
||||||
|
else:
|
||||||
|
return content[-estimated_chars:]
|
||||||
|
|
||||||
|
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""Count tokens using calculator or estimation."""
|
||||||
|
if self._calculator is not None:
|
||||||
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
|
||||||
|
# Fallback estimation
|
||||||
|
return max(1, len(text) // 4)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCompressor:
|
||||||
|
"""
|
||||||
|
Compresses contexts to fit within budget constraints.
|
||||||
|
|
||||||
|
Uses truncation strategies to reduce context size while
|
||||||
|
preserving the most important information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
truncation: TruncationStrategy | None = None,
|
||||||
|
calculator: "TokenCalculator | None" = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context compressor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncation: Truncation strategy to use
|
||||||
|
calculator: Token calculator for counting
|
||||||
|
"""
|
||||||
|
self._truncation = truncation or TruncationStrategy(calculator)
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
if calculator:
|
||||||
|
self._truncation.set_calculator(calculator)
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||||
|
"""Set token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
self._truncation.set_calculator(calculator)
|
||||||
|
|
||||||
|
async def compress_context(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> BaseContext:
|
||||||
|
"""
|
||||||
|
Compress a single context to fit token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to compress
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed context (may be same object if no compression needed)
|
||||||
|
"""
|
||||||
|
current_tokens = context.token_count or await self._count_tokens(
|
||||||
|
context.content, model
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_tokens <= max_tokens:
|
||||||
|
return context
|
||||||
|
|
||||||
|
# Choose strategy based on context type
|
||||||
|
strategy = self._get_strategy_for_type(context.get_type())
|
||||||
|
|
||||||
|
result = await self._truncation.truncate_to_tokens(
|
||||||
|
content=context.content,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
strategy=strategy,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update context with truncated content
|
||||||
|
context.content = result.content
|
||||||
|
context.token_count = result.truncated_tokens
|
||||||
|
context.metadata["truncated"] = True
|
||||||
|
context.metadata["original_tokens"] = result.original_tokens
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def compress_contexts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
budget: "TokenBudget",
|
||||||
|
model: str | None = None,
|
||||||
|
) -> list[BaseContext]:
|
||||||
|
"""
|
||||||
|
Compress multiple contexts to fit within budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to potentially compress
|
||||||
|
budget: Token budget constraints
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of contexts (compressed as needed)
|
||||||
|
"""
|
||||||
|
result: list[BaseContext] = []
|
||||||
|
|
||||||
|
for context in contexts:
|
||||||
|
context_type = context.get_type()
|
||||||
|
remaining = budget.remaining(context_type)
|
||||||
|
current_tokens = context.token_count or await self._count_tokens(
|
||||||
|
context.content, model
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_tokens > remaining:
|
||||||
|
# Need to compress
|
||||||
|
compressed = await self.compress_context(context, remaining, model)
|
||||||
|
result.append(compressed)
|
||||||
|
logger.debug(
|
||||||
|
f"Compressed {context_type.value} context from "
|
||||||
|
f"{current_tokens} to {compressed.token_count} tokens"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result.append(context)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_strategy_for_type(self, context_type: ContextType) -> str:
|
||||||
|
"""Get optimal truncation strategy for context type."""
|
||||||
|
strategies = {
|
||||||
|
ContextType.SYSTEM: "end", # Keep instructions at start
|
||||||
|
ContextType.TASK: "end", # Keep task description start
|
||||||
|
ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
|
||||||
|
ContextType.CONVERSATION: "end", # Keep recent conversation
|
||||||
|
ContextType.TOOL: "middle", # Keep command and result summary
|
||||||
|
}
|
||||||
|
return strategies.get(context_type, "end")
|
||||||
|
|
||||||
|
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""Count tokens using calculator or estimation."""
|
||||||
|
if self._calculator is not None:
|
||||||
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
return max(1, len(text) // 4)
|
||||||
@@ -253,12 +253,19 @@ class AssembledContext:
|
|||||||
|
|
||||||
# Main content
|
# Main content
|
||||||
content: str
|
content: str
|
||||||
token_count: int
|
total_tokens: int
|
||||||
|
|
||||||
# Assembly metadata
|
# Assembly metadata
|
||||||
contexts_included: int
|
context_count: int
|
||||||
contexts_excluded: int = 0
|
excluded_count: int = 0
|
||||||
assembly_time_ms: float = 0.0
|
assembly_time_ms: float = 0.0
|
||||||
|
model: str = ""
|
||||||
|
|
||||||
|
# Included contexts (optional - for inspection)
|
||||||
|
contexts: list["BaseContext"] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Additional metadata from assembly
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
# Budget tracking
|
# Budget tracking
|
||||||
budget_total: int = 0
|
budget_total: int = 0
|
||||||
@@ -271,6 +278,22 @@ class AssembledContext:
|
|||||||
cache_hit: bool = False
|
cache_hit: bool = False
|
||||||
cache_key: str | None = None
|
cache_key: str | None = None
|
||||||
|
|
||||||
|
# Aliases for backward compatibility
|
||||||
|
@property
|
||||||
|
def token_count(self) -> int:
|
||||||
|
"""Alias for total_tokens."""
|
||||||
|
return self.total_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def contexts_included(self) -> int:
|
||||||
|
"""Alias for context_count."""
|
||||||
|
return self.context_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def contexts_excluded(self) -> int:
|
||||||
|
"""Alias for excluded_count."""
|
||||||
|
return self.excluded_count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def budget_utilization(self) -> float:
|
def budget_utilization(self) -> float:
|
||||||
"""Get budget utilization percentage."""
|
"""Get budget utilization percentage."""
|
||||||
@@ -282,10 +305,12 @@ class AssembledContext:
|
|||||||
"""Convert to dictionary."""
|
"""Convert to dictionary."""
|
||||||
return {
|
return {
|
||||||
"content": self.content,
|
"content": self.content,
|
||||||
"token_count": self.token_count,
|
"total_tokens": self.total_tokens,
|
||||||
"contexts_included": self.contexts_included,
|
"context_count": self.context_count,
|
||||||
"contexts_excluded": self.contexts_excluded,
|
"excluded_count": self.excluded_count,
|
||||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||||
|
"model": self.model,
|
||||||
|
"metadata": self.metadata,
|
||||||
"budget_total": self.budget_total,
|
"budget_total": self.budget_total,
|
||||||
"budget_used": self.budget_used,
|
"budget_used": self.budget_used,
|
||||||
"budget_utilization": round(self.budget_utilization, 3),
|
"budget_utilization": round(self.budget_utilization, 3),
|
||||||
@@ -308,10 +333,12 @@ class AssembledContext:
|
|||||||
data = json.loads(json_str)
|
data = json.loads(json_str)
|
||||||
return cls(
|
return cls(
|
||||||
content=data["content"],
|
content=data["content"],
|
||||||
token_count=data["token_count"],
|
total_tokens=data["total_tokens"],
|
||||||
contexts_included=data["contexts_included"],
|
context_count=data["context_count"],
|
||||||
contexts_excluded=data.get("contexts_excluded", 0),
|
excluded_count=data.get("excluded_count", 0),
|
||||||
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||||
|
model=data.get("model", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
budget_total=data.get("budget_total", 0),
|
budget_total=data.get("budget_total", 0),
|
||||||
budget_used=data.get("budget_used", 0),
|
budget_used=data.get("budget_used", 0),
|
||||||
by_type=data.get("by_type", {}),
|
by_type=data.get("by_type", {}),
|
||||||
|
|||||||
502
backend/tests/services/context/test_assembly.py
Normal file
502
backend/tests/services/context/test_assembly.py
Normal file
@@ -0,0 +1,502 @@
|
|||||||
|
"""Tests for context assembly pipeline."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.assembly import ContextPipeline, PipelineMetrics
|
||||||
|
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||||
|
from app.services.context.types import (
|
||||||
|
AssembledContext,
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineMetrics:
|
||||||
|
"""Tests for PipelineMetrics dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test metrics creation."""
|
||||||
|
metrics = PipelineMetrics()
|
||||||
|
|
||||||
|
assert metrics.total_contexts == 0
|
||||||
|
assert metrics.selected_contexts == 0
|
||||||
|
assert metrics.assembly_time_ms == 0.0
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test conversion to dictionary."""
|
||||||
|
metrics = PipelineMetrics(
|
||||||
|
total_contexts=10,
|
||||||
|
selected_contexts=8,
|
||||||
|
excluded_contexts=2,
|
||||||
|
total_tokens=500,
|
||||||
|
assembly_time_ms=25.5,
|
||||||
|
)
|
||||||
|
metrics.end_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
data = metrics.to_dict()
|
||||||
|
|
||||||
|
assert data["total_contexts"] == 10
|
||||||
|
assert data["selected_contexts"] == 8
|
||||||
|
assert data["excluded_contexts"] == 2
|
||||||
|
assert data["total_tokens"] == 500
|
||||||
|
assert data["assembly_time_ms"] == 25.5
|
||||||
|
assert "start_time" in data
|
||||||
|
assert "end_time" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPipeline:
|
||||||
|
"""Tests for ContextPipeline."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test pipeline creation."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
assert pipeline._calculator is not None
|
||||||
|
assert pipeline._scorer is not None
|
||||||
|
assert pipeline._ranker is not None
|
||||||
|
assert pipeline._compressor is not None
|
||||||
|
assert pipeline._allocator is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_empty_contexts(self) -> None:
|
||||||
|
"""Test assembling empty context list."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=[],
|
||||||
|
query="test query",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, AssembledContext)
|
||||||
|
assert result.context_count == 0
|
||||||
|
assert result.total_tokens == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_single_context(self) -> None:
|
||||||
|
"""Test assembling single context."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are a helpful assistant.",
|
||||||
|
source="system",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="help me",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 1
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
assert "helpful assistant" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_multiple_types(self) -> None:
|
||||||
|
"""Test assembling multiple context types."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are a coding assistant.",
|
||||||
|
source="system",
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Implement a login feature.",
|
||||||
|
source="task",
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Authentication best practices include...",
|
||||||
|
source="docs/auth.md",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="implement login",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count >= 1
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_custom_budget(self) -> None:
|
||||||
|
"""Test assembling with custom budget."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=1000,
|
||||||
|
system=200,
|
||||||
|
task=200,
|
||||||
|
knowledge=400,
|
||||||
|
conversation=100,
|
||||||
|
tools=50,
|
||||||
|
response_reserve=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
TaskContext(content="Task description", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
custom_budget=budget,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_max_tokens(self) -> None:
|
||||||
|
"""Test assembling with max_tokens limit."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
max_tokens=5000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "budget" in result.metadata
|
||||||
|
assert result.metadata["budget"]["total"] == 5000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_format_output(self) -> None:
|
||||||
|
"""Test formatted vs unformatted output."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Formatted (default)
|
||||||
|
result_formatted = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
format_output=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unformatted
|
||||||
|
result_raw = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
format_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Formatted should have XML tags for Claude
|
||||||
|
assert "<system_instructions>" in result_formatted.content
|
||||||
|
# Raw should not
|
||||||
|
assert "<system_instructions>" not in result_raw.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_metrics(self) -> None:
|
||||||
|
"""Test that metrics are populated."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "metrics" in result.metadata
|
||||||
|
metrics = result.metadata["metrics"]
|
||||||
|
|
||||||
|
assert metrics["total_contexts"] == 3
|
||||||
|
assert metrics["assembly_time_ms"] > 0
|
||||||
|
assert "scoring_time_ms" in metrics
|
||||||
|
assert "formatting_time_ms" in metrics
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_compression_disabled(self) -> None:
|
||||||
|
"""Test assembling with compression disabled."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="A" * 1000, source="docs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
compress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still work, just no compression applied
|
||||||
|
assert result.context_count >= 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPipelineFormatting:
|
||||||
|
"""Tests for context formatting."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_claude_uses_xml(self) -> None:
|
||||||
|
"""Test that Claude models use XML formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claude should have XML tags
|
||||||
|
assert "<system_instructions>" in result.content or result.context_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_openai_uses_markdown(self) -> None:
|
||||||
|
"""Test that OpenAI models use markdown formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
TaskContext(content="Task description", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI should have markdown headers
|
||||||
|
if result.context_count > 0 and "Task" in result.content:
|
||||||
|
assert "## Current Task" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_knowledge_claude(self) -> None:
|
||||||
|
"""Test knowledge formatting for Claude."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Document content here",
|
||||||
|
source="docs/file.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.context_count > 0:
|
||||||
|
assert "<reference_documents>" in result.content
|
||||||
|
assert "<document" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_conversation(self) -> None:
|
||||||
|
"""Test conversation formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
ConversationContext(
|
||||||
|
content="Hello, how are you?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="I'm doing great!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
metadata={"role": "assistant"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.context_count > 0:
|
||||||
|
assert "<conversation_history>" in result.content
|
||||||
|
assert '<message role="user">' in result.content or 'role="user"' in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_tool_results(self) -> None:
|
||||||
|
"""Test tool result formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
ToolContext(
|
||||||
|
content="Tool output here",
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "search"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.context_count > 0:
|
||||||
|
assert "<tool_results>" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPipelineIntegration:
|
||||||
|
"""Integration tests for full pipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_pipeline_workflow(self) -> None:
|
||||||
|
"""Test complete pipeline workflow."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
# Create realistic context mix
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are an expert Python developer.",
|
||||||
|
source="system",
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Implement a user authentication system.",
|
||||||
|
source="task:AUTH-123",
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="JWT tokens provide stateless authentication...",
|
||||||
|
source="docs/auth/jwt.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="OAuth 2.0 is an authorization framework...",
|
||||||
|
source="docs/auth/oauth.md",
|
||||||
|
relevance_score=0.7,
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Can you help me implement JWT auth?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="implement JWT authentication",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert isinstance(result, AssembledContext)
|
||||||
|
assert result.context_count > 0
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
assert result.assembly_time_ms > 0
|
||||||
|
assert result.model == "claude-3-sonnet"
|
||||||
|
assert len(result.content) > 0
|
||||||
|
|
||||||
|
# Verify metrics
|
||||||
|
assert "metrics" in result.metadata
|
||||||
|
assert "query" in result.metadata
|
||||||
|
assert "budget" in result.metadata
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_type_ordering(self) -> None:
|
||||||
|
"""Test that contexts are ordered by type correctly."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
# Add in random order
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
|
||||||
|
ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
ConversationContext(
|
||||||
|
content="Chat",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
# For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
|
||||||
|
content = result.content
|
||||||
|
if result.context_count > 0:
|
||||||
|
# Find positions (if they exist)
|
||||||
|
system_pos = content.find("system_instructions")
|
||||||
|
task_pos = content.find("current_task")
|
||||||
|
knowledge_pos = content.find("reference_documents")
|
||||||
|
conversation_pos = content.find("conversation_history")
|
||||||
|
tool_pos = content.find("tool_results")
|
||||||
|
|
||||||
|
# Verify ordering (only check if both exist)
|
||||||
|
if system_pos >= 0 and task_pos >= 0:
|
||||||
|
assert system_pos < task_pos
|
||||||
|
if task_pos >= 0 and knowledge_pos >= 0:
|
||||||
|
assert task_pos < knowledge_pos
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_excluded_contexts_tracked(self) -> None:
|
||||||
|
"""Test that excluded contexts are tracked in result."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
# Create many contexts to force some exclusions
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="A" * 500, # Large content
|
||||||
|
source=f"docs/{i}",
|
||||||
|
relevance_score=0.1 + (i * 0.05),
|
||||||
|
)
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4", # Smaller context window
|
||||||
|
max_tokens=1000, # Limited budget
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have excluded some
|
||||||
|
assert result.excluded_count >= 0
|
||||||
|
assert result.context_count + result.excluded_count <= len(contexts)
|
||||||
214
backend/tests/services/context/test_compression.py
Normal file
214
backend/tests/services/context/test_compression.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Tests for context compression module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.compression import (
|
||||||
|
ContextCompressor,
|
||||||
|
TruncationResult,
|
||||||
|
TruncationStrategy,
|
||||||
|
)
|
||||||
|
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||||
|
from app.services.context.types import (
|
||||||
|
ContextType,
|
||||||
|
KnowledgeContext,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncationResult:
|
||||||
|
"""Tests for TruncationResult dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
result = TruncationResult(
|
||||||
|
original_tokens=100,
|
||||||
|
truncated_tokens=50,
|
||||||
|
content="Truncated content",
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.original_tokens == 100
|
||||||
|
assert result.truncated_tokens == 50
|
||||||
|
assert result.truncated is True
|
||||||
|
assert result.truncation_ratio == 0.5
|
||||||
|
|
||||||
|
def test_tokens_saved(self) -> None:
|
||||||
|
"""Test tokens_saved property."""
|
||||||
|
result = TruncationResult(
|
||||||
|
original_tokens=100,
|
||||||
|
truncated_tokens=40,
|
||||||
|
content="Test",
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=0.6,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.tokens_saved == 60
|
||||||
|
|
||||||
|
def test_no_truncation(self) -> None:
|
||||||
|
"""Test when no truncation needed."""
|
||||||
|
result = TruncationResult(
|
||||||
|
original_tokens=50,
|
||||||
|
truncated_tokens=50,
|
||||||
|
content="Full content",
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.tokens_saved == 0
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncationStrategy:
|
||||||
|
"""Tests for TruncationStrategy."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test strategy creation."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
assert strategy._preserve_ratio_start == 0.7
|
||||||
|
assert strategy._min_content_length == 100
|
||||||
|
|
||||||
|
def test_creation_with_params(self) -> None:
|
||||||
|
"""Test strategy creation with custom params."""
|
||||||
|
strategy = TruncationStrategy(
|
||||||
|
preserve_ratio_start=0.5,
|
||||||
|
min_content_length=50,
|
||||||
|
)
|
||||||
|
assert strategy._preserve_ratio_start == 0.5
|
||||||
|
assert strategy._min_content_length == 50
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_empty_content(self) -> None:
|
||||||
|
"""Test truncating empty content."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens("", max_tokens=100)
|
||||||
|
|
||||||
|
assert result.original_tokens == 0
|
||||||
|
assert result.truncated_tokens == 0
|
||||||
|
assert result.content == ""
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_content_within_limit(self) -> None:
|
||||||
|
"""Test content that fits within limit."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "Short content"
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(content, max_tokens=100)
|
||||||
|
|
||||||
|
assert result.content == content
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_end_strategy(self) -> None:
|
||||||
|
"""Test end truncation strategy."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "A" * 1000 # Long content
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=50, strategy="end"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.truncated is True
|
||||||
|
assert len(result.content) < len(content)
|
||||||
|
assert strategy.TRUNCATION_MARKER in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_middle_strategy(self) -> None:
|
||||||
|
"""Test middle truncation strategy."""
|
||||||
|
strategy = TruncationStrategy(preserve_ratio_start=0.6)
|
||||||
|
content = "START " + "A" * 500 + " END"
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=50, strategy="middle"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.truncated is True
|
||||||
|
assert strategy.TRUNCATION_MARKER in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_sentence_strategy(self) -> None:
|
||||||
|
"""Test sentence-aware truncation strategy."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "First sentence. Second sentence. Third sentence. Fourth sentence."
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=10, strategy="sentence"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.truncated is True
|
||||||
|
# Should cut at sentence boundary
|
||||||
|
assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextCompressor:
|
||||||
|
"""Tests for ContextCompressor."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test compressor creation."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
assert compressor._truncation is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compress_context_within_limit(self) -> None:
|
||||||
|
"""Test compressing context that already fits."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Short content",
|
||||||
|
source="docs",
|
||||||
|
)
|
||||||
|
context.token_count = 5
|
||||||
|
|
||||||
|
result = await compressor.compress_context(context, max_tokens=100)
|
||||||
|
|
||||||
|
# Should return same context unmodified
|
||||||
|
assert result.content == "Short content"
|
||||||
|
assert result.metadata.get("truncated") is not True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compress_context_exceeds_limit(self) -> None:
|
||||||
|
"""Test compressing context that exceeds limit."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="A" * 500,
|
||||||
|
source="docs",
|
||||||
|
)
|
||||||
|
context.token_count = 125 # Approximately 500/4
|
||||||
|
|
||||||
|
result = await compressor.compress_context(context, max_tokens=20)
|
||||||
|
|
||||||
|
assert result.metadata.get("truncated") is True
|
||||||
|
assert result.metadata.get("original_tokens") == 125
|
||||||
|
assert len(result.content) < 500
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compress_contexts_batch(self) -> None:
|
||||||
|
"""Test compressing multiple contexts."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(1000)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="A" * 200, source="docs"),
|
||||||
|
KnowledgeContext(content="B" * 200, source="docs"),
|
||||||
|
TaskContext(content="C" * 200, source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await compressor.compress_contexts(contexts, budget)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strategy_selection_by_type(self) -> None:
|
||||||
|
"""Test that correct strategy is selected for each type."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.SYSTEM) == "end"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.TASK) == "end"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"
|
||||||
@@ -426,11 +426,14 @@ class TestAssembledContext:
|
|||||||
"""Test basic creation."""
|
"""Test basic creation."""
|
||||||
ctx = AssembledContext(
|
ctx = AssembledContext(
|
||||||
content="Assembled content here",
|
content="Assembled content here",
|
||||||
token_count=500,
|
total_tokens=500,
|
||||||
contexts_included=5,
|
context_count=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert ctx.content == "Assembled content here"
|
assert ctx.content == "Assembled content here"
|
||||||
|
assert ctx.total_tokens == 500
|
||||||
|
assert ctx.context_count == 5
|
||||||
|
# Test backward compatibility aliases
|
||||||
assert ctx.token_count == 500
|
assert ctx.token_count == 500
|
||||||
assert ctx.contexts_included == 5
|
assert ctx.contexts_included == 5
|
||||||
|
|
||||||
@@ -438,8 +441,8 @@ class TestAssembledContext:
|
|||||||
"""Test budget_utilization property."""
|
"""Test budget_utilization property."""
|
||||||
ctx = AssembledContext(
|
ctx = AssembledContext(
|
||||||
content="test",
|
content="test",
|
||||||
token_count=800,
|
total_tokens=800,
|
||||||
contexts_included=5,
|
context_count=5,
|
||||||
budget_total=1000,
|
budget_total=1000,
|
||||||
budget_used=800,
|
budget_used=800,
|
||||||
)
|
)
|
||||||
@@ -450,8 +453,8 @@ class TestAssembledContext:
|
|||||||
"""Test budget_utilization with zero budget."""
|
"""Test budget_utilization with zero budget."""
|
||||||
ctx = AssembledContext(
|
ctx = AssembledContext(
|
||||||
content="test",
|
content="test",
|
||||||
token_count=0,
|
total_tokens=0,
|
||||||
contexts_included=0,
|
context_count=0,
|
||||||
budget_total=0,
|
budget_total=0,
|
||||||
budget_used=0,
|
budget_used=0,
|
||||||
)
|
)
|
||||||
@@ -462,24 +465,26 @@ class TestAssembledContext:
|
|||||||
"""Test to_dict method."""
|
"""Test to_dict method."""
|
||||||
ctx = AssembledContext(
|
ctx = AssembledContext(
|
||||||
content="test",
|
content="test",
|
||||||
token_count=100,
|
total_tokens=100,
|
||||||
contexts_included=2,
|
context_count=2,
|
||||||
assembly_time_ms=50.123,
|
assembly_time_ms=50.123,
|
||||||
)
|
)
|
||||||
|
|
||||||
data = ctx.to_dict()
|
data = ctx.to_dict()
|
||||||
assert data["content"] == "test"
|
assert data["content"] == "test"
|
||||||
assert data["token_count"] == 100
|
assert data["total_tokens"] == 100
|
||||||
|
assert data["context_count"] == 2
|
||||||
assert data["assembly_time_ms"] == 50.12 # Rounded
|
assert data["assembly_time_ms"] == 50.12 # Rounded
|
||||||
|
|
||||||
def test_to_json_and_from_json(self) -> None:
|
def test_to_json_and_from_json(self) -> None:
|
||||||
"""Test JSON serialization round-trip."""
|
"""Test JSON serialization round-trip."""
|
||||||
original = AssembledContext(
|
original = AssembledContext(
|
||||||
content="Test content",
|
content="Test content",
|
||||||
token_count=100,
|
total_tokens=100,
|
||||||
contexts_included=3,
|
context_count=3,
|
||||||
contexts_excluded=2,
|
excluded_count=2,
|
||||||
assembly_time_ms=45.5,
|
assembly_time_ms=45.5,
|
||||||
|
model="claude-3-sonnet",
|
||||||
budget_total=1000,
|
budget_total=1000,
|
||||||
budget_used=100,
|
budget_used=100,
|
||||||
by_type={"system": 20, "knowledge": 80},
|
by_type={"system": 20, "knowledge": 80},
|
||||||
@@ -491,8 +496,10 @@ class TestAssembledContext:
|
|||||||
restored = AssembledContext.from_json(json_str)
|
restored = AssembledContext.from_json(json_str)
|
||||||
|
|
||||||
assert restored.content == original.content
|
assert restored.content == original.content
|
||||||
assert restored.token_count == original.token_count
|
assert restored.total_tokens == original.total_tokens
|
||||||
assert restored.contexts_included == original.contexts_included
|
assert restored.context_count == original.context_count
|
||||||
|
assert restored.excluded_count == original.excluded_count
|
||||||
|
assert restored.model == original.model
|
||||||
assert restored.cache_hit == original.cache_hit
|
assert restored.cache_hit == original.cache_hit
|
||||||
assert restored.cache_key == original.cache_key
|
assert restored.cache_key == original.cache_key
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user