diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py
index 321afa7..5ad719c 100644
--- a/backend/app/services/context/__init__.py
+++ b/backend/app/services/context/__init__.py
@@ -63,6 +63,19 @@ from .exceptions import (
TokenCountError,
)
+# Assembly
+from .assembly import (
+ ContextPipeline,
+ PipelineMetrics,
+)
+
+# Compression
+from .compression import (
+ ContextCompressor,
+ TruncationResult,
+ TruncationStrategy,
+)
+
# Prioritization
from .prioritization import (
ContextRanker,
@@ -97,10 +110,17 @@ from .types import (
)
__all__ = [
+ # Assembly
+ "ContextPipeline",
+ "PipelineMetrics",
# Budget Management
"BudgetAllocator",
"TokenBudget",
"TokenCalculator",
+ # Compression
+ "ContextCompressor",
+ "TruncationResult",
+ "TruncationStrategy",
# Configuration
"ContextSettings",
"get_context_settings",
diff --git a/backend/app/services/context/assembly/__init__.py b/backend/app/services/context/assembly/__init__.py
index ae869ea..3f3a668 100644
--- a/backend/app/services/context/assembly/__init__.py
+++ b/backend/app/services/context/assembly/__init__.py
@@ -3,3 +3,10 @@ Context Assembly Module.
Provides the assembly pipeline and formatting.
"""
+
+from .pipeline import ContextPipeline, PipelineMetrics
+
+__all__ = [
+ "ContextPipeline",
+ "PipelineMetrics",
+]
diff --git a/backend/app/services/context/assembly/pipeline.py b/backend/app/services/context/assembly/pipeline.py
new file mode 100644
index 0000000..2003cec
--- /dev/null
+++ b/backend/app/services/context/assembly/pipeline.py
@@ -0,0 +1,432 @@
+"""
+Context Assembly Pipeline.
+
+Orchestrates the full context assembly workflow:
+Gather → Count → Score → Rank → Compress → Format
+"""
+
+import asyncio
+import logging
+import time
+from dataclasses import dataclass, field
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING, Any
+
+from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
+from ..compression.truncation import ContextCompressor
+from ..config import ContextSettings, get_context_settings
+from ..exceptions import AssemblyTimeoutError
+from ..prioritization import ContextRanker
+from ..scoring import CompositeScorer
+from ..types import AssembledContext, BaseContext, ContextType
+
+if TYPE_CHECKING:
+ from app.services.mcp.client_manager import MCPClientManager
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PipelineMetrics:
+ """Metrics from pipeline execution."""
+
+ start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
+ end_time: datetime | None = None
+ total_contexts: int = 0
+ selected_contexts: int = 0
+ excluded_contexts: int = 0
+ compressed_contexts: int = 0
+ total_tokens: int = 0
+ assembly_time_ms: float = 0.0
+ scoring_time_ms: float = 0.0
+ compression_time_ms: float = 0.0
+ formatting_time_ms: float = 0.0
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary."""
+ return {
+ "start_time": self.start_time.isoformat(),
+ "end_time": self.end_time.isoformat() if self.end_time else None,
+ "total_contexts": self.total_contexts,
+ "selected_contexts": self.selected_contexts,
+ "excluded_contexts": self.excluded_contexts,
+ "compressed_contexts": self.compressed_contexts,
+ "total_tokens": self.total_tokens,
+ "assembly_time_ms": round(self.assembly_time_ms, 2),
+ "scoring_time_ms": round(self.scoring_time_ms, 2),
+ "compression_time_ms": round(self.compression_time_ms, 2),
+ "formatting_time_ms": round(self.formatting_time_ms, 2),
+ }
+
+
+class ContextPipeline:
+ """
+ Context assembly pipeline.
+
+ Orchestrates the full workflow of context assembly:
+ 1. Validate and count tokens for all contexts
+ 2. Score contexts based on relevance, recency, and priority
+ 3. Rank and select contexts within budget
+ 4. Compress if needed to fit remaining budget
+ 5. Format for the target model
+ """
+
+ def __init__(
+ self,
+ mcp_manager: "MCPClientManager | None" = None,
+ settings: ContextSettings | None = None,
+ calculator: TokenCalculator | None = None,
+ scorer: CompositeScorer | None = None,
+ ranker: ContextRanker | None = None,
+ compressor: ContextCompressor | None = None,
+ ) -> None:
+ """
+ Initialize the context pipeline.
+
+ Args:
+ mcp_manager: MCP client manager for LLM Gateway integration
+ settings: Context settings
+ calculator: Token calculator
+ scorer: Context scorer
+ ranker: Context ranker
+ compressor: Context compressor
+ """
+ self._settings = settings or get_context_settings()
+ self._mcp = mcp_manager
+
+ # Initialize components
+ self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
+ self._scorer = scorer or CompositeScorer(
+ mcp_manager=mcp_manager, settings=self._settings
+ )
+ self._ranker = ranker or ContextRanker(
+ scorer=self._scorer, calculator=self._calculator
+ )
+ self._compressor = compressor or ContextCompressor(
+ calculator=self._calculator
+ )
+ self._allocator = BudgetAllocator(self._settings)
+
+ def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
+ """Set MCP manager for all components."""
+ self._mcp = mcp_manager
+ self._calculator.set_mcp_manager(mcp_manager)
+ self._scorer.set_mcp_manager(mcp_manager)
+
+ async def assemble(
+ self,
+ contexts: list[BaseContext],
+ query: str,
+ model: str,
+ max_tokens: int | None = None,
+ custom_budget: TokenBudget | None = None,
+ compress: bool = True,
+ format_output: bool = True,
+ timeout_ms: int | None = None,
+ ) -> AssembledContext:
+ """
+ Assemble context for an LLM request.
+
+ This is the main entry point for context assembly.
+
+ Args:
+ contexts: List of contexts to assemble
+ query: Query to optimize for
+ model: Target model name
+ max_tokens: Maximum total tokens (uses model default if None)
+ custom_budget: Optional pre-configured budget
+ compress: Whether to compress oversized contexts
+ format_output: Whether to format the final output
+ timeout_ms: Maximum assembly time in milliseconds
+
+ Returns:
+ AssembledContext with optimized content
+
+ Raises:
+ AssemblyTimeoutError: If assembly exceeds timeout
+ """
+ timeout = timeout_ms or self._settings.max_assembly_time_ms
+ start = time.perf_counter()
+ metrics = PipelineMetrics(total_contexts=len(contexts))
+
+ try:
+ # Create or use budget
+ if custom_budget:
+ budget = custom_budget
+ elif max_tokens:
+ budget = self._allocator.create_budget(max_tokens)
+ else:
+ budget = self._allocator.create_budget_for_model(model)
+
+ # 1. Count tokens for all contexts
+ await self._ensure_token_counts(contexts, model)
+
+ # Check timeout
+ self._check_timeout(start, timeout, "token counting")
+
+ # 2. Score and rank contexts
+ scoring_start = time.perf_counter()
+ ranking_result = await self._ranker.rank(
+ contexts=contexts,
+ query=query,
+ budget=budget,
+ model=model,
+ )
+ metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
+
+ selected_contexts = ranking_result.selected_contexts
+ metrics.selected_contexts = len(selected_contexts)
+ metrics.excluded_contexts = len(ranking_result.excluded)
+
+ # Check timeout
+ self._check_timeout(start, timeout, "scoring")
+
+ # 3. Compress if needed and enabled
+ if compress and self._needs_compression(selected_contexts, budget):
+ compression_start = time.perf_counter()
+ selected_contexts = await self._compressor.compress_contexts(
+ selected_contexts, budget, model
+ )
+ metrics.compression_time_ms = (
+ time.perf_counter() - compression_start
+ ) * 1000
+ metrics.compressed_contexts = sum(
+ 1 for c in selected_contexts if c.metadata.get("truncated", False)
+ )
+
+ # Check timeout
+ self._check_timeout(start, timeout, "compression")
+
+ # 4. Format output
+ formatting_start = time.perf_counter()
+ if format_output:
+ formatted_content = self._format_contexts(selected_contexts, model)
+ else:
+ formatted_content = "\n\n".join(c.content for c in selected_contexts)
+ metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
+
+ # Calculate final metrics
+ total_tokens = sum(c.token_count or 0 for c in selected_contexts)
+ metrics.total_tokens = total_tokens
+ metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
+ metrics.end_time = datetime.now(UTC)
+
+ return AssembledContext(
+ content=formatted_content,
+ total_tokens=total_tokens,
+ context_count=len(selected_contexts),
+ assembly_time_ms=metrics.assembly_time_ms,
+ model=model,
+ contexts=selected_contexts,
+ excluded_count=metrics.excluded_contexts,
+ metadata={
+ "metrics": metrics.to_dict(),
+ "query": query,
+ "budget": budget.to_dict(),
+ },
+ )
+
+ except AssemblyTimeoutError:
+ raise
+ except Exception as e:
+ logger.error(f"Context assembly failed: {e}", exc_info=True)
+ raise
+
+ async def _ensure_token_counts(
+ self,
+ contexts: list[BaseContext],
+ model: str | None = None,
+ ) -> None:
+ """Ensure all contexts have token counts."""
+ tasks = []
+ for context in contexts:
+ if context.token_count is None:
+ tasks.append(self._count_and_set(context, model))
+
+ if tasks:
+ await asyncio.gather(*tasks)
+
+ async def _count_and_set(
+ self,
+ context: BaseContext,
+ model: str | None = None,
+ ) -> None:
+ """Count tokens and set on context."""
+ count = await self._calculator.count_tokens(context.content, model)
+ context.token_count = count
+
+ def _needs_compression(
+ self,
+ contexts: list[BaseContext],
+ budget: TokenBudget,
+ ) -> bool:
+ """Check if any contexts exceed their type budget."""
+ # Group by type and check totals
+ by_type: dict[ContextType, int] = {}
+ for context in contexts:
+ ct = context.get_type()
+ by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
+
+ for ct, total in by_type.items():
+ if total > budget.get_allocation(ct):
+ return True
+
+ # Also check if utilization exceeds threshold
+ return budget.utilization() > self._settings.compression_threshold
+
+ def _format_contexts(
+ self,
+ contexts: list[BaseContext],
+ model: str,
+ ) -> str:
+ """
+ Format contexts for the target model.
+
+ Groups contexts by type and applies model-specific formatting.
+ """
+ # Group by type
+ by_type: dict[ContextType, list[BaseContext]] = {}
+ for context in contexts:
+ ct = context.get_type()
+ if ct not in by_type:
+ by_type[ct] = []
+ by_type[ct].append(context)
+
+ # Order types: System -> Task -> Knowledge -> Conversation -> Tool
+ type_order = [
+ ContextType.SYSTEM,
+ ContextType.TASK,
+ ContextType.KNOWLEDGE,
+ ContextType.CONVERSATION,
+ ContextType.TOOL,
+ ]
+
+ parts: list[str] = []
+ for ct in type_order:
+ if ct in by_type:
+ formatted = self._format_type(by_type[ct], ct, model)
+ if formatted:
+ parts.append(formatted)
+
+ return "\n\n".join(parts)
+
+ def _format_type(
+ self,
+ contexts: list[BaseContext],
+ context_type: ContextType,
+ model: str,
+ ) -> str:
+ """Format contexts of a specific type."""
+ if not contexts:
+ return ""
+
+ # Check if model prefers XML tags (Claude)
+ use_xml = "claude" in model.lower()
+
+ if context_type == ContextType.SYSTEM:
+ return self._format_system(contexts, use_xml)
+ elif context_type == ContextType.TASK:
+ return self._format_task(contexts, use_xml)
+ elif context_type == ContextType.KNOWLEDGE:
+ return self._format_knowledge(contexts, use_xml)
+ elif context_type == ContextType.CONVERSATION:
+ return self._format_conversation(contexts, use_xml)
+ elif context_type == ContextType.TOOL:
+ return self._format_tool(contexts, use_xml)
+
+ return "\n".join(c.content for c in contexts)
+
+ def _format_system(
+ self, contexts: list[BaseContext], use_xml: bool
+ ) -> str:
+ """Format system contexts."""
+ content = "\n\n".join(c.content for c in contexts)
+ if use_xml:
+ return f"\n{content}\n"
+ return content
+
+ def _format_task(
+ self, contexts: list[BaseContext], use_xml: bool
+ ) -> str:
+ """Format task contexts."""
+ content = "\n\n".join(c.content for c in contexts)
+ if use_xml:
+ return f"\n{content}\n"
+ return f"## Current Task\n\n{content}"
+
+ def _format_knowledge(
+ self, contexts: list[BaseContext], use_xml: bool
+ ) -> str:
+ """Format knowledge contexts."""
+ if use_xml:
+ parts = [""]
+ for ctx in contexts:
+ parts.append(f'')
+ parts.append(ctx.content)
+ parts.append("")
+ parts.append("")
+ return "\n".join(parts)
+ else:
+ parts = ["## Reference Documents\n"]
+ for ctx in contexts:
+ parts.append(f"### Source: {ctx.source}\n")
+ parts.append(ctx.content)
+ parts.append("")
+ return "\n".join(parts)
+
+ def _format_conversation(
+ self, contexts: list[BaseContext], use_xml: bool
+ ) -> str:
+ """Format conversation contexts."""
+ if use_xml:
+ parts = [""]
+ for ctx in contexts:
+ role = ctx.metadata.get("role", "user")
+ parts.append(f'')
+ parts.append(ctx.content)
+ parts.append("")
+ parts.append("")
+ return "\n".join(parts)
+ else:
+ parts = []
+ for ctx in contexts:
+ role = ctx.metadata.get("role", "user")
+ parts.append(f"**{role.upper()}**: {ctx.content}")
+ return "\n\n".join(parts)
+
+ def _format_tool(
+ self, contexts: list[BaseContext], use_xml: bool
+ ) -> str:
+ """Format tool contexts."""
+ if use_xml:
+ parts = [""]
+ for ctx in contexts:
+ tool_name = ctx.metadata.get("tool_name", "unknown")
+ parts.append(f'')
+ parts.append(ctx.content)
+ parts.append("")
+ parts.append("")
+ return "\n".join(parts)
+ else:
+ parts = ["## Recent Tool Results\n"]
+ for ctx in contexts:
+ tool_name = ctx.metadata.get("tool_name", "unknown")
+ parts.append(f"### Tool: {tool_name}\n")
+ parts.append(f"```\n{ctx.content}\n```")
+ parts.append("")
+ return "\n".join(parts)
+
+ def _check_timeout(
+ self,
+ start: float,
+ timeout_ms: int,
+ phase: str,
+ ) -> None:
+ """Check if timeout exceeded and raise if so."""
+ elapsed_ms = (time.perf_counter() - start) * 1000
+ if elapsed_ms > timeout_ms:
+ raise AssemblyTimeoutError(
+ message=f"Context assembly timed out during {phase}",
+ elapsed_ms=elapsed_ms,
+ timeout_ms=timeout_ms,
+ )
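A minimal usage sketch for the new pipeline (not part of the diff). It leans only on classes this changeset adds or already exports; the model name and context values are illustrative, and with no MCP manager wired in, token counting falls back to estimation:

```python
import asyncio

from app.services.context.assembly import ContextPipeline
from app.services.context.types import KnowledgeContext, SystemContext


async def main() -> None:
    # No MCP manager: the TokenCalculator falls back to estimation.
    pipeline = ContextPipeline()
    contexts = [
        SystemContext(content="You are a helpful assistant.", source="system"),
        KnowledgeContext(
            content="JWT tokens provide stateless authentication...",
            source="docs/auth/jwt.md",
            relevance_score=0.9,
        ),
    ]
    result = await pipeline.assemble(
        contexts=contexts,
        query="implement JWT authentication",
        model="claude-3-sonnet",  # "claude" in the name selects XML formatting
        max_tokens=4000,          # overrides the model-default budget
    )
    print(result.total_tokens, result.context_count)
    print(result.metadata["metrics"])


asyncio.run(main())
```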
diff --git a/backend/app/services/context/compression/__init__.py b/backend/app/services/context/compression/__init__.py
index 28cb5e9..02d0dbf 100644
--- a/backend/app/services/context/compression/__init__.py
+++ b/backend/app/services/context/compression/__init__.py
@@ -3,3 +3,11 @@ Context Compression Module.
Provides truncation and compression strategies.
"""
+
+from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
+
+__all__ = [
+ "ContextCompressor",
+ "TruncationResult",
+ "TruncationStrategy",
+]
diff --git a/backend/app/services/context/compression/truncation.py b/backend/app/services/context/compression/truncation.py
new file mode 100644
index 0000000..4c8cf7b
--- /dev/null
+++ b/backend/app/services/context/compression/truncation.py
@@ -0,0 +1,391 @@
+"""
+Smart Truncation for Context Compression.
+
+Provides intelligent truncation strategies to reduce context size
+while preserving the most important information.
+"""
+
+import logging
+import re
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from ..types import BaseContext, ContextType
+
+if TYPE_CHECKING:
+ from ..budget import TokenBudget, TokenCalculator
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TruncationResult:
+ """Result of truncation operation."""
+
+ original_tokens: int
+ truncated_tokens: int
+ content: str
+ truncated: bool
+ truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
+
+ @property
+ def tokens_saved(self) -> int:
+ """Calculate tokens saved by truncation."""
+ return self.original_tokens - self.truncated_tokens
+
+
+class TruncationStrategy:
+ """
+ Smart truncation strategies for context compression.
+
+ Strategies:
+ 1. End truncation: Cut from end (for knowledge/docs)
+ 2. Middle truncation: Keep start and end (for code)
+ 3. Sentence-aware: Truncate at sentence boundaries
+    4. Semantic chunking: Keep most relevant chunks (planned; not implemented here)
+ """
+
+ # Default truncation marker
+ TRUNCATION_MARKER = "\n\n[...content truncated...]\n\n"
+
+ def __init__(
+ self,
+ calculator: "TokenCalculator | None" = None,
+ preserve_ratio_start: float = 0.7, # Keep 70% from start by default
+ min_content_length: int = 100, # Minimum characters to keep
+ ) -> None:
+ """
+ Initialize truncation strategy.
+
+ Args:
+ calculator: Token calculator for accurate counting
+ preserve_ratio_start: Ratio of content to keep from start
+ min_content_length: Minimum characters to preserve
+ """
+ self._calculator = calculator
+ self._preserve_ratio_start = preserve_ratio_start
+ self._min_content_length = min_content_length
+
+ def set_calculator(self, calculator: "TokenCalculator") -> None:
+ """Set token calculator."""
+ self._calculator = calculator
+
+ async def truncate_to_tokens(
+ self,
+ content: str,
+ max_tokens: int,
+ strategy: str = "end",
+ model: str | None = None,
+ ) -> TruncationResult:
+ """
+ Truncate content to fit within token limit.
+
+ Args:
+ content: Content to truncate
+ max_tokens: Maximum tokens allowed
+ strategy: Truncation strategy ('end', 'middle', 'sentence')
+ model: Model for token counting
+
+ Returns:
+ TruncationResult with truncated content
+ """
+ if not content:
+ return TruncationResult(
+ original_tokens=0,
+ truncated_tokens=0,
+ content="",
+ truncated=False,
+ truncation_ratio=0.0,
+ )
+
+ # Get original token count
+ original_tokens = await self._count_tokens(content, model)
+
+ if original_tokens <= max_tokens:
+ return TruncationResult(
+ original_tokens=original_tokens,
+ truncated_tokens=original_tokens,
+ content=content,
+ truncated=False,
+ truncation_ratio=0.0,
+ )
+
+ # Apply truncation strategy
+ if strategy == "middle":
+ truncated = await self._truncate_middle(content, max_tokens, model)
+ elif strategy == "sentence":
+ truncated = await self._truncate_sentence(content, max_tokens, model)
+ else: # "end"
+ truncated = await self._truncate_end(content, max_tokens, model)
+
+ truncated_tokens = await self._count_tokens(truncated, model)
+
+ return TruncationResult(
+ original_tokens=original_tokens,
+ truncated_tokens=truncated_tokens,
+ content=truncated,
+ truncated=True,
+ truncation_ratio=1 - (truncated_tokens / original_tokens),
+ )
+
+ async def _truncate_end(
+ self,
+ content: str,
+ max_tokens: int,
+ model: str | None = None,
+ ) -> str:
+ """
+ Truncate from end of content.
+
+ Simple but effective for most content types.
+ """
+ # Binary search for optimal truncation point
+ marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
+ available_tokens = max_tokens - marker_tokens
+
+ # Estimate characters per token
+ chars_per_token = len(content) / await self._count_tokens(content, model)
+
+ # Start with estimated position
+ estimated_chars = int(available_tokens * chars_per_token)
+ truncated = content[:estimated_chars]
+
+ # Refine with binary search
+        low, high = 0, len(truncated)
+        best = ""  # never fall back to an over-budget estimate
+
+ for _ in range(5): # Max 5 iterations
+ mid = (low + high) // 2
+ candidate = content[:mid]
+ tokens = await self._count_tokens(candidate, model)
+
+ if tokens <= available_tokens:
+ best = candidate
+ low = mid + 1
+ else:
+ high = mid - 1
+
+ return best + self.TRUNCATION_MARKER
+
+ async def _truncate_middle(
+ self,
+ content: str,
+ max_tokens: int,
+ model: str | None = None,
+ ) -> str:
+ """
+ Truncate from middle, keeping start and end.
+
+ Good for code or content where context at boundaries matters.
+ """
+ marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
+ available_tokens = max_tokens - marker_tokens
+
+ # Split between start and end
+ start_tokens = int(available_tokens * self._preserve_ratio_start)
+ end_tokens = available_tokens - start_tokens
+
+ # Get start portion
+ start_content = await self._get_content_for_tokens(
+ content, start_tokens, from_start=True, model=model
+ )
+
+ # Get end portion
+ end_content = await self._get_content_for_tokens(
+ content, end_tokens, from_start=False, model=model
+ )
+
+ return start_content + self.TRUNCATION_MARKER + end_content
+
+ async def _truncate_sentence(
+ self,
+ content: str,
+ max_tokens: int,
+ model: str | None = None,
+ ) -> str:
+ """
+ Truncate at sentence boundaries.
+
+ Produces cleaner output by not cutting mid-sentence.
+ """
+ # Split into sentences
+ sentences = re.split(r"(?<=[.!?])\s+", content)
+
+ result: list[str] = []
+ total_tokens = 0
+ marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
+ available = max_tokens - marker_tokens
+
+ for sentence in sentences:
+ sentence_tokens = await self._count_tokens(sentence, model)
+ if total_tokens + sentence_tokens <= available:
+ result.append(sentence)
+ total_tokens += sentence_tokens
+ else:
+ break
+
+ if len(result) < len(sentences):
+ return " ".join(result) + self.TRUNCATION_MARKER
+ return " ".join(result)
+
+ async def _get_content_for_tokens(
+ self,
+ content: str,
+ target_tokens: int,
+ from_start: bool = True,
+ model: str | None = None,
+ ) -> str:
+ """Get portion of content fitting within token limit."""
+ if target_tokens <= 0:
+ return ""
+
+ current_tokens = await self._count_tokens(content, model)
+ if current_tokens <= target_tokens:
+ return content
+
+ # Estimate characters
+ chars_per_token = len(content) / current_tokens
+ estimated_chars = int(target_tokens * chars_per_token)
+
+ if from_start:
+ return content[:estimated_chars]
+ else:
+ return content[-estimated_chars:]
+
+ async def _count_tokens(self, text: str, model: str | None = None) -> int:
+ """Count tokens using calculator or estimation."""
+ if self._calculator is not None:
+ return await self._calculator.count_tokens(text, model)
+
+ # Fallback estimation
+ return max(1, len(text) // 4)
+
+
+class ContextCompressor:
+ """
+ Compresses contexts to fit within budget constraints.
+
+ Uses truncation strategies to reduce context size while
+ preserving the most important information.
+ """
+
+ def __init__(
+ self,
+ truncation: TruncationStrategy | None = None,
+ calculator: "TokenCalculator | None" = None,
+ ) -> None:
+ """
+ Initialize context compressor.
+
+ Args:
+ truncation: Truncation strategy to use
+ calculator: Token calculator for counting
+ """
+ self._truncation = truncation or TruncationStrategy(calculator)
+ self._calculator = calculator
+
+ if calculator:
+ self._truncation.set_calculator(calculator)
+
+ def set_calculator(self, calculator: "TokenCalculator") -> None:
+ """Set token calculator."""
+ self._calculator = calculator
+ self._truncation.set_calculator(calculator)
+
+ async def compress_context(
+ self,
+ context: BaseContext,
+ max_tokens: int,
+ model: str | None = None,
+ ) -> BaseContext:
+ """
+ Compress a single context to fit token limit.
+
+ Args:
+ context: Context to compress
+ max_tokens: Maximum tokens allowed
+ model: Model for token counting
+
+ Returns:
+ Compressed context (may be same object if no compression needed)
+ """
+ current_tokens = context.token_count or await self._count_tokens(
+ context.content, model
+ )
+
+ if current_tokens <= max_tokens:
+ return context
+
+ # Choose strategy based on context type
+ strategy = self._get_strategy_for_type(context.get_type())
+
+ result = await self._truncation.truncate_to_tokens(
+ content=context.content,
+ max_tokens=max_tokens,
+ strategy=strategy,
+ model=model,
+ )
+
+ # Update context with truncated content
+ context.content = result.content
+ context.token_count = result.truncated_tokens
+ context.metadata["truncated"] = True
+ context.metadata["original_tokens"] = result.original_tokens
+
+ return context
+
+ async def compress_contexts(
+ self,
+ contexts: list[BaseContext],
+ budget: "TokenBudget",
+ model: str | None = None,
+ ) -> list[BaseContext]:
+ """
+ Compress multiple contexts to fit within budget.
+
+ Args:
+ contexts: Contexts to potentially compress
+ budget: Token budget constraints
+ model: Model for token counting
+
+ Returns:
+ List of contexts (compressed as needed)
+ """
+ result: list[BaseContext] = []
+
+ for context in contexts:
+ context_type = context.get_type()
+ remaining = budget.remaining(context_type)
+ current_tokens = context.token_count or await self._count_tokens(
+ context.content, model
+ )
+
+ if current_tokens > remaining:
+ # Need to compress
+ compressed = await self.compress_context(context, remaining, model)
+ result.append(compressed)
+ logger.debug(
+ f"Compressed {context_type.value} context from "
+ f"{current_tokens} to {compressed.token_count} tokens"
+ )
+ else:
+ result.append(context)
+
+ return result
+
+ def _get_strategy_for_type(self, context_type: ContextType) -> str:
+ """Get optimal truncation strategy for context type."""
+ strategies = {
+ ContextType.SYSTEM: "end", # Keep instructions at start
+ ContextType.TASK: "end", # Keep task description start
+ ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
+ ContextType.CONVERSATION: "end", # Keep recent conversation
+ ContextType.TOOL: "middle", # Keep command and result summary
+ }
+ return strategies.get(context_type, "end")
+
+ async def _count_tokens(self, text: str, model: str | None = None) -> int:
+ """Count tokens using calculator or estimation."""
+ if self._calculator is not None:
+ return await self._calculator.count_tokens(text, model)
+ return max(1, len(text) // 4)
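To see the strategies in isolation, a small sketch driving `TruncationStrategy` directly (not part of the diff). Without a calculator attached it uses the `len(text) // 4` fallback above, so the exact cut points are estimates:

```python
import asyncio

from app.services.context.compression import TruncationStrategy


async def demo() -> None:
    strategy = TruncationStrategy(preserve_ratio_start=0.7)

    # "sentence" cuts at sentence boundaries and appends the marker.
    doc = "First sentence. " * 200
    clean = await strategy.truncate_to_tokens(doc, max_tokens=40, strategy="sentence")
    assert clean.truncated
    assert strategy.TRUNCATION_MARKER in clean.content

    # "middle" keeps ~70% of the budget from the start, the rest from the end.
    code = "def start(): ...\n" + "# filler\n" * 300 + "def end(): ...\n"
    mid = await strategy.truncate_to_tokens(code, max_tokens=60, strategy="middle")
    assert strategy.TRUNCATION_MARKER in mid.content
    print(f"middle saved {mid.tokens_saved} tokens ({mid.truncation_ratio:.0%})")


asyncio.run(demo())
```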
diff --git a/backend/app/services/context/types/base.py b/backend/app/services/context/types/base.py
index 4b01ed2..6eef658 100644
--- a/backend/app/services/context/types/base.py
+++ b/backend/app/services/context/types/base.py
@@ -253,12 +253,19 @@ class AssembledContext:
# Main content
content: str
- token_count: int
+ total_tokens: int
# Assembly metadata
- contexts_included: int
- contexts_excluded: int = 0
+ context_count: int
+ excluded_count: int = 0
assembly_time_ms: float = 0.0
+ model: str = ""
+
+ # Included contexts (optional - for inspection)
+ contexts: list["BaseContext"] = field(default_factory=list)
+
+ # Additional metadata from assembly
+ metadata: dict[str, Any] = field(default_factory=dict)
# Budget tracking
budget_total: int = 0
@@ -271,6 +278,22 @@ class AssembledContext:
cache_hit: bool = False
cache_key: str | None = None
+ # Aliases for backward compatibility
+ @property
+ def token_count(self) -> int:
+ """Alias for total_tokens."""
+ return self.total_tokens
+
+ @property
+ def contexts_included(self) -> int:
+ """Alias for context_count."""
+ return self.context_count
+
+ @property
+ def contexts_excluded(self) -> int:
+ """Alias for excluded_count."""
+ return self.excluded_count
+
@property
def budget_utilization(self) -> float:
"""Get budget utilization percentage."""
@@ -282,10 +305,12 @@ class AssembledContext:
"""Convert to dictionary."""
return {
"content": self.content,
- "token_count": self.token_count,
- "contexts_included": self.contexts_included,
- "contexts_excluded": self.contexts_excluded,
+ "total_tokens": self.total_tokens,
+ "context_count": self.context_count,
+ "excluded_count": self.excluded_count,
"assembly_time_ms": round(self.assembly_time_ms, 2),
+ "model": self.model,
+ "metadata": self.metadata,
"budget_total": self.budget_total,
"budget_used": self.budget_used,
"budget_utilization": round(self.budget_utilization, 3),
@@ -308,10 +333,12 @@ class AssembledContext:
data = json.loads(json_str)
return cls(
content=data["content"],
- token_count=data["token_count"],
- contexts_included=data["contexts_included"],
- contexts_excluded=data.get("contexts_excluded", 0),
+ total_tokens=data["total_tokens"],
+ context_count=data["context_count"],
+ excluded_count=data.get("excluded_count", 0),
assembly_time_ms=data.get("assembly_time_ms", 0.0),
+ model=data.get("model", ""),
+ metadata=data.get("metadata", {}),
budget_total=data.get("budget_total", 0),
budget_used=data.get("budget_used", 0),
by_type=data.get("by_type", {}),
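The renames above stay read-compatible through the alias properties; a quick sketch of what callers can rely on (field names exactly as in this diff):

```python
from app.services.context.types import AssembledContext

ctx = AssembledContext(content="...", total_tokens=500, context_count=5)

# Old read paths keep working via the aliases...
assert ctx.token_count == ctx.total_tokens == 500
assert ctx.contexts_included == ctx.context_count == 5
assert ctx.contexts_excluded == ctx.excluded_count == 0

# ...but the aliases are read-only properties: `ctx.token_count = 1` would
# raise AttributeError, so writers must use the new field names.
```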
diff --git a/backend/tests/services/context/test_assembly.py b/backend/tests/services/context/test_assembly.py
new file mode 100644
index 0000000..92f9c7e
--- /dev/null
+++ b/backend/tests/services/context/test_assembly.py
@@ -0,0 +1,502 @@
+"""Tests for context assembly pipeline."""
+
+from datetime import UTC, datetime
+
+import pytest
+
+from app.services.context.assembly import ContextPipeline, PipelineMetrics
+from app.services.context.budget import BudgetAllocator, TokenBudget
+from app.services.context.types import (
+ AssembledContext,
+ ContextType,
+ ConversationContext,
+ KnowledgeContext,
+ MessageRole,
+ SystemContext,
+ TaskContext,
+ ToolContext,
+)
+
+
+class TestPipelineMetrics:
+ """Tests for PipelineMetrics dataclass."""
+
+ def test_creation(self) -> None:
+ """Test metrics creation."""
+ metrics = PipelineMetrics()
+
+ assert metrics.total_contexts == 0
+ assert metrics.selected_contexts == 0
+ assert metrics.assembly_time_ms == 0.0
+
+ def test_to_dict(self) -> None:
+ """Test conversion to dictionary."""
+ metrics = PipelineMetrics(
+ total_contexts=10,
+ selected_contexts=8,
+ excluded_contexts=2,
+ total_tokens=500,
+ assembly_time_ms=25.5,
+ )
+ metrics.end_time = datetime.now(UTC)
+
+ data = metrics.to_dict()
+
+ assert data["total_contexts"] == 10
+ assert data["selected_contexts"] == 8
+ assert data["excluded_contexts"] == 2
+ assert data["total_tokens"] == 500
+ assert data["assembly_time_ms"] == 25.5
+ assert "start_time" in data
+ assert "end_time" in data
+
+
+class TestContextPipeline:
+ """Tests for ContextPipeline."""
+
+ def test_creation(self) -> None:
+ """Test pipeline creation."""
+ pipeline = ContextPipeline()
+
+ assert pipeline._calculator is not None
+ assert pipeline._scorer is not None
+ assert pipeline._ranker is not None
+ assert pipeline._compressor is not None
+ assert pipeline._allocator is not None
+
+ @pytest.mark.asyncio
+ async def test_assemble_empty_contexts(self) -> None:
+ """Test assembling empty context list."""
+ pipeline = ContextPipeline()
+
+ result = await pipeline.assemble(
+ contexts=[],
+ query="test query",
+ model="claude-3-sonnet",
+ )
+
+ assert isinstance(result, AssembledContext)
+ assert result.context_count == 0
+ assert result.total_tokens == 0
+
+ @pytest.mark.asyncio
+ async def test_assemble_single_context(self) -> None:
+ """Test assembling single context."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ SystemContext(
+ content="You are a helpful assistant.",
+ source="system",
+ )
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="help me",
+ model="claude-3-sonnet",
+ )
+
+ assert result.context_count == 1
+ assert result.total_tokens > 0
+ assert "helpful assistant" in result.content
+
+ @pytest.mark.asyncio
+ async def test_assemble_multiple_types(self) -> None:
+ """Test assembling multiple context types."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ SystemContext(
+ content="You are a coding assistant.",
+ source="system",
+ ),
+ TaskContext(
+ content="Implement a login feature.",
+ source="task",
+ ),
+ KnowledgeContext(
+ content="Authentication best practices include...",
+ source="docs/auth.md",
+ relevance_score=0.8,
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="implement login",
+ model="claude-3-sonnet",
+ )
+
+ assert result.context_count >= 1
+ assert result.total_tokens > 0
+
+ @pytest.mark.asyncio
+ async def test_assemble_with_custom_budget(self) -> None:
+ """Test assembling with custom budget."""
+ pipeline = ContextPipeline()
+ budget = TokenBudget(
+ total=1000,
+ system=200,
+ task=200,
+ knowledge=400,
+ conversation=100,
+ tools=50,
+ response_reserve=50,
+ )
+
+ contexts = [
+ SystemContext(content="System prompt", source="system"),
+ TaskContext(content="Task description", source="task"),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="gpt-4",
+ custom_budget=budget,
+ )
+
+ assert result.context_count >= 1
+
+ @pytest.mark.asyncio
+ async def test_assemble_with_max_tokens(self) -> None:
+ """Test assembling with max_tokens limit."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ SystemContext(content="System prompt", source="system"),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="gpt-4",
+ max_tokens=5000,
+ )
+
+ assert "budget" in result.metadata
+ assert result.metadata["budget"]["total"] == 5000
+
+ @pytest.mark.asyncio
+ async def test_assemble_format_output(self) -> None:
+ """Test formatted vs unformatted output."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ SystemContext(content="System prompt", source="system"),
+ ]
+
+ # Formatted (default)
+ result_formatted = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ format_output=True,
+ )
+
+ # Unformatted
+ result_raw = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ format_output=False,
+ )
+
+ # Formatted should have XML tags for Claude
+ assert "" in result_formatted.content
+ # Raw should not
+ assert "" not in result_raw.content
+
+ @pytest.mark.asyncio
+ async def test_assemble_metrics(self) -> None:
+ """Test that metrics are populated."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ SystemContext(content="System", source="system"),
+ TaskContext(content="Task", source="task"),
+ KnowledgeContext(
+ content="Knowledge",
+ source="docs",
+ relevance_score=0.9,
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ )
+
+ assert "metrics" in result.metadata
+ metrics = result.metadata["metrics"]
+
+ assert metrics["total_contexts"] == 3
+ assert metrics["assembly_time_ms"] > 0
+ assert "scoring_time_ms" in metrics
+ assert "formatting_time_ms" in metrics
+
+ @pytest.mark.asyncio
+ async def test_assemble_with_compression_disabled(self) -> None:
+ """Test assembling with compression disabled."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ KnowledgeContext(content="A" * 1000, source="docs"),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="gpt-4",
+ compress=False,
+ )
+
+ # Should still work, just no compression applied
+ assert result.context_count >= 0
+
+
+class TestContextPipelineFormatting:
+ """Tests for context formatting."""
+
+ @pytest.mark.asyncio
+ async def test_format_claude_uses_xml(self) -> None:
+ """Test that Claude models use XML formatting."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ SystemContext(content="System prompt", source="system"),
+ TaskContext(content="Task", source="task"),
+ KnowledgeContext(
+ content="Knowledge",
+ source="docs",
+ relevance_score=0.9,
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ )
+
+ # Claude should have XML tags
+ assert "" in result.content or result.context_count == 0
+
+ @pytest.mark.asyncio
+ async def test_format_openai_uses_markdown(self) -> None:
+ """Test that OpenAI models use markdown formatting."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ TaskContext(content="Task description", source="task"),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="gpt-4",
+ )
+
+ # OpenAI should have markdown headers
+ if result.context_count > 0 and "Task" in result.content:
+ assert "## Current Task" in result.content
+
+ @pytest.mark.asyncio
+ async def test_format_knowledge_claude(self) -> None:
+ """Test knowledge formatting for Claude."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ KnowledgeContext(
+ content="Document content here",
+ source="docs/file.md",
+ relevance_score=0.9,
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ )
+
+ if result.context_count > 0:
+ assert "" in result.content
+ assert " None:
+ """Test conversation formatting."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ ConversationContext(
+ content="Hello, how are you?",
+ source="chat",
+ role=MessageRole.USER,
+ metadata={"role": "user"},
+ ),
+ ConversationContext(
+ content="I'm doing great!",
+ source="chat",
+ role=MessageRole.ASSISTANT,
+ metadata={"role": "assistant"},
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ )
+
+ if result.context_count > 0:
+ assert "" in result.content
+ assert '' in result.content or 'role="user"' in result.content
+
+ @pytest.mark.asyncio
+ async def test_format_tool_results(self) -> None:
+ """Test tool result formatting."""
+ pipeline = ContextPipeline()
+
+ contexts = [
+ ToolContext(
+ content="Tool output here",
+ source="tool",
+ metadata={"tool_name": "search"},
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ )
+
+ if result.context_count > 0:
+ assert "" in result.content
+
+
+class TestContextPipelineIntegration:
+ """Integration tests for full pipeline."""
+
+ @pytest.mark.asyncio
+ async def test_full_pipeline_workflow(self) -> None:
+ """Test complete pipeline workflow."""
+ pipeline = ContextPipeline()
+
+ # Create realistic context mix
+ contexts = [
+ SystemContext(
+ content="You are an expert Python developer.",
+ source="system",
+ ),
+ TaskContext(
+ content="Implement a user authentication system.",
+ source="task:AUTH-123",
+ ),
+ KnowledgeContext(
+ content="JWT tokens provide stateless authentication...",
+ source="docs/auth/jwt.md",
+ relevance_score=0.9,
+ ),
+ KnowledgeContext(
+ content="OAuth 2.0 is an authorization framework...",
+ source="docs/auth/oauth.md",
+ relevance_score=0.7,
+ ),
+ ConversationContext(
+ content="Can you help me implement JWT auth?",
+ source="chat",
+ role=MessageRole.USER,
+ metadata={"role": "user"},
+ ),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="implement JWT authentication",
+ model="claude-3-sonnet",
+ )
+
+ # Verify result
+ assert isinstance(result, AssembledContext)
+ assert result.context_count > 0
+ assert result.total_tokens > 0
+ assert result.assembly_time_ms > 0
+ assert result.model == "claude-3-sonnet"
+ assert len(result.content) > 0
+
+ # Verify metrics
+ assert "metrics" in result.metadata
+ assert "query" in result.metadata
+ assert "budget" in result.metadata
+
+ @pytest.mark.asyncio
+ async def test_context_type_ordering(self) -> None:
+ """Test that contexts are ordered by type correctly."""
+ pipeline = ContextPipeline()
+
+ # Add in random order
+ contexts = [
+ KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
+ ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
+ SystemContext(content="System", source="system"),
+ ConversationContext(
+ content="Chat",
+ source="chat",
+ role=MessageRole.USER,
+ metadata={"role": "user"},
+ ),
+ TaskContext(content="Task", source="task"),
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="claude-3-sonnet",
+ )
+
+ # For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
+ content = result.content
+ if result.context_count > 0:
+ # Find positions (if they exist)
+ system_pos = content.find("system_instructions")
+ task_pos = content.find("current_task")
+ knowledge_pos = content.find("reference_documents")
+ conversation_pos = content.find("conversation_history")
+ tool_pos = content.find("tool_results")
+
+ # Verify ordering (only check if both exist)
+ if system_pos >= 0 and task_pos >= 0:
+ assert system_pos < task_pos
+ if task_pos >= 0 and knowledge_pos >= 0:
+ assert task_pos < knowledge_pos
+
+ @pytest.mark.asyncio
+ async def test_excluded_contexts_tracked(self) -> None:
+ """Test that excluded contexts are tracked in result."""
+ pipeline = ContextPipeline()
+
+ # Create many contexts to force some exclusions
+ contexts = [
+ KnowledgeContext(
+ content="A" * 500, # Large content
+ source=f"docs/{i}",
+ relevance_score=0.1 + (i * 0.05),
+ )
+ for i in range(10)
+ ]
+
+ result = await pipeline.assemble(
+ contexts=contexts,
+ query="test",
+ model="gpt-4", # Smaller context window
+ max_tokens=1000, # Limited budget
+ )
+
+        # Exclusion counts must be tracked (may be zero if everything fit)
+ assert result.excluded_count >= 0
+ assert result.context_count + result.excluded_count <= len(contexts)
diff --git a/backend/tests/services/context/test_compression.py b/backend/tests/services/context/test_compression.py
new file mode 100644
index 0000000..b2fde38
--- /dev/null
+++ b/backend/tests/services/context/test_compression.py
@@ -0,0 +1,214 @@
+"""Tests for context compression module."""
+
+import pytest
+
+from app.services.context.compression import (
+ ContextCompressor,
+ TruncationResult,
+ TruncationStrategy,
+)
+from app.services.context.budget import BudgetAllocator, TokenBudget
+from app.services.context.types import (
+ ContextType,
+ KnowledgeContext,
+ SystemContext,
+ TaskContext,
+)
+
+
+class TestTruncationResult:
+ """Tests for TruncationResult dataclass."""
+
+ def test_creation(self) -> None:
+ """Test basic creation."""
+ result = TruncationResult(
+ original_tokens=100,
+ truncated_tokens=50,
+ content="Truncated content",
+ truncated=True,
+ truncation_ratio=0.5,
+ )
+
+ assert result.original_tokens == 100
+ assert result.truncated_tokens == 50
+ assert result.truncated is True
+ assert result.truncation_ratio == 0.5
+
+ def test_tokens_saved(self) -> None:
+ """Test tokens_saved property."""
+ result = TruncationResult(
+ original_tokens=100,
+ truncated_tokens=40,
+ content="Test",
+ truncated=True,
+ truncation_ratio=0.6,
+ )
+
+ assert result.tokens_saved == 60
+
+ def test_no_truncation(self) -> None:
+ """Test when no truncation needed."""
+ result = TruncationResult(
+ original_tokens=50,
+ truncated_tokens=50,
+ content="Full content",
+ truncated=False,
+ truncation_ratio=0.0,
+ )
+
+ assert result.tokens_saved == 0
+ assert result.truncated is False
+
+
+class TestTruncationStrategy:
+ """Tests for TruncationStrategy."""
+
+ def test_creation(self) -> None:
+ """Test strategy creation."""
+ strategy = TruncationStrategy()
+ assert strategy._preserve_ratio_start == 0.7
+ assert strategy._min_content_length == 100
+
+ def test_creation_with_params(self) -> None:
+ """Test strategy creation with custom params."""
+ strategy = TruncationStrategy(
+ preserve_ratio_start=0.5,
+ min_content_length=50,
+ )
+ assert strategy._preserve_ratio_start == 0.5
+ assert strategy._min_content_length == 50
+
+ @pytest.mark.asyncio
+ async def test_truncate_empty_content(self) -> None:
+ """Test truncating empty content."""
+ strategy = TruncationStrategy()
+
+ result = await strategy.truncate_to_tokens("", max_tokens=100)
+
+ assert result.original_tokens == 0
+ assert result.truncated_tokens == 0
+ assert result.content == ""
+ assert result.truncated is False
+
+ @pytest.mark.asyncio
+ async def test_truncate_content_within_limit(self) -> None:
+ """Test content that fits within limit."""
+ strategy = TruncationStrategy()
+ content = "Short content"
+
+ result = await strategy.truncate_to_tokens(content, max_tokens=100)
+
+ assert result.content == content
+ assert result.truncated is False
+
+ @pytest.mark.asyncio
+ async def test_truncate_end_strategy(self) -> None:
+ """Test end truncation strategy."""
+ strategy = TruncationStrategy()
+ content = "A" * 1000 # Long content
+
+ result = await strategy.truncate_to_tokens(
+ content, max_tokens=50, strategy="end"
+ )
+
+ assert result.truncated is True
+ assert len(result.content) < len(content)
+ assert strategy.TRUNCATION_MARKER in result.content
+
+ @pytest.mark.asyncio
+ async def test_truncate_middle_strategy(self) -> None:
+ """Test middle truncation strategy."""
+ strategy = TruncationStrategy(preserve_ratio_start=0.6)
+ content = "START " + "A" * 500 + " END"
+
+ result = await strategy.truncate_to_tokens(
+ content, max_tokens=50, strategy="middle"
+ )
+
+ assert result.truncated is True
+ assert strategy.TRUNCATION_MARKER in result.content
+
+ @pytest.mark.asyncio
+ async def test_truncate_sentence_strategy(self) -> None:
+ """Test sentence-aware truncation strategy."""
+ strategy = TruncationStrategy()
+ content = "First sentence. Second sentence. Third sentence. Fourth sentence."
+
+ result = await strategy.truncate_to_tokens(
+ content, max_tokens=10, strategy="sentence"
+ )
+
+ assert result.truncated is True
+ # Should cut at sentence boundary
+ assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
+
+
+class TestContextCompressor:
+ """Tests for ContextCompressor."""
+
+ def test_creation(self) -> None:
+ """Test compressor creation."""
+ compressor = ContextCompressor()
+ assert compressor._truncation is not None
+
+ @pytest.mark.asyncio
+ async def test_compress_context_within_limit(self) -> None:
+ """Test compressing context that already fits."""
+ compressor = ContextCompressor()
+
+ context = KnowledgeContext(
+ content="Short content",
+ source="docs",
+ )
+ context.token_count = 5
+
+ result = await compressor.compress_context(context, max_tokens=100)
+
+ # Should return same context unmodified
+ assert result.content == "Short content"
+ assert result.metadata.get("truncated") is not True
+
+ @pytest.mark.asyncio
+ async def test_compress_context_exceeds_limit(self) -> None:
+ """Test compressing context that exceeds limit."""
+ compressor = ContextCompressor()
+
+ context = KnowledgeContext(
+ content="A" * 500,
+ source="docs",
+ )
+ context.token_count = 125 # Approximately 500/4
+
+ result = await compressor.compress_context(context, max_tokens=20)
+
+ assert result.metadata.get("truncated") is True
+ assert result.metadata.get("original_tokens") == 125
+ assert len(result.content) < 500
+
+ @pytest.mark.asyncio
+ async def test_compress_contexts_batch(self) -> None:
+ """Test compressing multiple contexts."""
+ compressor = ContextCompressor()
+ allocator = BudgetAllocator()
+ budget = allocator.create_budget(1000)
+
+ contexts = [
+ KnowledgeContext(content="A" * 200, source="docs"),
+ KnowledgeContext(content="B" * 200, source="docs"),
+ TaskContext(content="C" * 200, source="task"),
+ ]
+
+ result = await compressor.compress_contexts(contexts, budget)
+
+ assert len(result) == 3
+
+ @pytest.mark.asyncio
+ async def test_strategy_selection_by_type(self) -> None:
+ """Test that correct strategy is selected for each type."""
+ compressor = ContextCompressor()
+
+ assert compressor._get_strategy_for_type(ContextType.SYSTEM) == "end"
+ assert compressor._get_strategy_for_type(ContextType.TASK) == "end"
+ assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
+ assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
+ assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"
diff --git a/backend/tests/services/context/test_types.py b/backend/tests/services/context/test_types.py
index ca36566..2a5743e 100644
--- a/backend/tests/services/context/test_types.py
+++ b/backend/tests/services/context/test_types.py
@@ -426,11 +426,14 @@ class TestAssembledContext:
"""Test basic creation."""
ctx = AssembledContext(
content="Assembled content here",
- token_count=500,
- contexts_included=5,
+ total_tokens=500,
+ context_count=5,
)
assert ctx.content == "Assembled content here"
+ assert ctx.total_tokens == 500
+ assert ctx.context_count == 5
+ # Test backward compatibility aliases
assert ctx.token_count == 500
assert ctx.contexts_included == 5
@@ -438,8 +441,8 @@ class TestAssembledContext:
"""Test budget_utilization property."""
ctx = AssembledContext(
content="test",
- token_count=800,
- contexts_included=5,
+ total_tokens=800,
+ context_count=5,
budget_total=1000,
budget_used=800,
)
@@ -450,8 +453,8 @@ class TestAssembledContext:
"""Test budget_utilization with zero budget."""
ctx = AssembledContext(
content="test",
- token_count=0,
- contexts_included=0,
+ total_tokens=0,
+ context_count=0,
budget_total=0,
budget_used=0,
)
@@ -462,24 +465,26 @@ class TestAssembledContext:
"""Test to_dict method."""
ctx = AssembledContext(
content="test",
- token_count=100,
- contexts_included=2,
+ total_tokens=100,
+ context_count=2,
assembly_time_ms=50.123,
)
data = ctx.to_dict()
assert data["content"] == "test"
- assert data["token_count"] == 100
+ assert data["total_tokens"] == 100
+ assert data["context_count"] == 2
assert data["assembly_time_ms"] == 50.12 # Rounded
def test_to_json_and_from_json(self) -> None:
"""Test JSON serialization round-trip."""
original = AssembledContext(
content="Test content",
- token_count=100,
- contexts_included=3,
- contexts_excluded=2,
+ total_tokens=100,
+ context_count=3,
+ excluded_count=2,
assembly_time_ms=45.5,
+ model="claude-3-sonnet",
budget_total=1000,
budget_used=100,
by_type={"system": 20, "knowledge": 80},
@@ -491,8 +496,10 @@ class TestAssembledContext:
restored = AssembledContext.from_json(json_str)
assert restored.content == original.content
- assert restored.token_count == original.token_count
- assert restored.contexts_included == original.contexts_included
+ assert restored.total_tokens == original.total_tokens
+ assert restored.context_count == original.context_count
+ assert restored.excluded_count == original.excluded_count
+ assert restored.model == original.model
assert restored.cache_hit == original.cache_hit
assert restored.cache_key == original.cache_key