feat(context): implement assembly pipeline and compression (#82)

Phase 4 of Context Management Engine - Assembly Pipeline:

- Add TruncationStrategy with end/middle/sentence-aware truncation
- Add TruncationResult dataclass for tracking compression metrics
- Add ContextCompressor for type-specific compression
- Add ContextPipeline orchestrating full assembly workflow:
  - Token counting for all contexts
  - Scoring and ranking via ContextRanker
  - Optional compression when budget threshold exceeded
  - Model-specific formatting (XML for Claude, markdown for OpenAI)
- Add PipelineMetrics for performance tracking
- Update AssembledContext with new fields (model, contexts, metadata)
- Add backward compatibility aliases for renamed fields

Tests: 34 new tests, 223 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:32:25 +01:00
parent 0d2005ddcb
commit 6b07e62f00
9 changed files with 1631 additions and 23 deletions
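
A minimal usage sketch of the new pipeline, pieced together from the tests in this commit. The import paths, keyword arguments, and result fields all appear in the test files below; the asyncio scaffolding and the max_tokens value are illustrative assumptions, not part of the commit.

import asyncio

from app.services.context.assembly import ContextPipeline
from app.services.context.types import KnowledgeContext, SystemContext, TaskContext


async def main() -> None:
    pipeline = ContextPipeline()
    contexts = [
        SystemContext(content="You are an expert Python developer.", source="system"),
        TaskContext(content="Implement a user authentication system.", source="task"),
        KnowledgeContext(
            content="JWT tokens provide stateless authentication...",
            source="docs/auth/jwt.md",
            relevance_score=0.9,
        ),
    ]
    # assemble() counts tokens, scores and ranks contexts, compresses them
    # when the budget threshold is exceeded, and formats the result for the
    # target model (XML tags for Claude, markdown headers for OpenAI).
    result = await pipeline.assemble(
        contexts=contexts,
        query="implement JWT authentication",
        model="claude-3-sonnet",
        max_tokens=4000,  # illustrative overall budget
    )
    print(result.context_count, result.total_tokens)
    print(result.content)  # for Claude: wrapped in <system_instructions> etc.


asyncio.run(main())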


@@ -0,0 +1,502 @@
"""Tests for context assembly pipeline."""

from datetime import UTC, datetime

import pytest

from app.services.context.assembly import ContextPipeline, PipelineMetrics
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
    AssembledContext,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskContext,
    ToolContext,
)


class TestPipelineMetrics:
    """Tests for PipelineMetrics dataclass."""

    def test_creation(self) -> None:
        """Test metrics creation."""
        metrics = PipelineMetrics()
        assert metrics.total_contexts == 0
        assert metrics.selected_contexts == 0
        assert metrics.assembly_time_ms == 0.0

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        metrics = PipelineMetrics(
            total_contexts=10,
            selected_contexts=8,
            excluded_contexts=2,
            total_tokens=500,
            assembly_time_ms=25.5,
        )
        metrics.end_time = datetime.now(UTC)
        data = metrics.to_dict()
        assert data["total_contexts"] == 10
        assert data["selected_contexts"] == 8
        assert data["excluded_contexts"] == 2
        assert data["total_tokens"] == 500
        assert data["assembly_time_ms"] == 25.5
        assert "start_time" in data
        assert "end_time" in data


class TestContextPipeline:
    """Tests for ContextPipeline."""

    def test_creation(self) -> None:
        """Test pipeline creation."""
        pipeline = ContextPipeline()
        assert pipeline._calculator is not None
        assert pipeline._scorer is not None
        assert pipeline._ranker is not None
        assert pipeline._compressor is not None
        assert pipeline._allocator is not None

    @pytest.mark.asyncio
    async def test_assemble_empty_contexts(self) -> None:
        """Test assembling empty context list."""
        pipeline = ContextPipeline()
        result = await pipeline.assemble(
            contexts=[],
            query="test query",
            model="claude-3-sonnet",
        )
        assert isinstance(result, AssembledContext)
        assert result.context_count == 0
        assert result.total_tokens == 0

    @pytest.mark.asyncio
    async def test_assemble_single_context(self) -> None:
        """Test assembling single context."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(
                content="You are a helpful assistant.",
                source="system",
            )
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="help me",
            model="claude-3-sonnet",
        )
        assert result.context_count == 1
        assert result.total_tokens > 0
        assert "helpful assistant" in result.content

    @pytest.mark.asyncio
    async def test_assemble_multiple_types(self) -> None:
        """Test assembling multiple context types."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(
                content="You are a coding assistant.",
                source="system",
            ),
            TaskContext(
                content="Implement a login feature.",
                source="task",
            ),
            KnowledgeContext(
                content="Authentication best practices include...",
                source="docs/auth.md",
                relevance_score=0.8,
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="implement login",
            model="claude-3-sonnet",
        )
        assert result.context_count >= 1
        assert result.total_tokens > 0

    @pytest.mark.asyncio
    async def test_assemble_with_custom_budget(self) -> None:
        """Test assembling with custom budget."""
        pipeline = ContextPipeline()
        budget = TokenBudget(
            total=1000,
            system=200,
            task=200,
            knowledge=400,
            conversation=100,
            tools=50,
            response_reserve=50,
        )
        contexts = [
            SystemContext(content="System prompt", source="system"),
            TaskContext(content="Task description", source="task"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            custom_budget=budget,
        )
        assert result.context_count >= 1

    @pytest.mark.asyncio
    async def test_assemble_with_max_tokens(self) -> None:
        """Test assembling with max_tokens limit."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(content="System prompt", source="system"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            max_tokens=5000,
        )
        assert "budget" in result.metadata
        assert result.metadata["budget"]["total"] == 5000

    @pytest.mark.asyncio
    async def test_assemble_format_output(self) -> None:
        """Test formatted vs unformatted output."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(content="System prompt", source="system"),
        ]
        # Formatted (default)
        result_formatted = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
            format_output=True,
        )
        # Unformatted
        result_raw = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
            format_output=False,
        )
        # Formatted should have XML tags for Claude
        assert "<system_instructions>" in result_formatted.content
        # Raw should not
        assert "<system_instructions>" not in result_raw.content

    @pytest.mark.asyncio
    async def test_assemble_metrics(self) -> None:
        """Test that metrics are populated."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(content="System", source="system"),
            TaskContext(content="Task", source="task"),
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )
        assert "metrics" in result.metadata
        metrics = result.metadata["metrics"]
        assert metrics["total_contexts"] == 3
        assert metrics["assembly_time_ms"] > 0
        assert "scoring_time_ms" in metrics
        assert "formatting_time_ms" in metrics

    @pytest.mark.asyncio
    async def test_assemble_with_compression_disabled(self) -> None:
        """Test assembling with compression disabled."""
        pipeline = ContextPipeline()
        contexts = [
            KnowledgeContext(content="A" * 1000, source="docs"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            compress=False,
        )
        # Should still work, just no compression applied
        assert result.context_count >= 0


class TestContextPipelineFormatting:
    """Tests for context formatting."""

    @pytest.mark.asyncio
    async def test_format_claude_uses_xml(self) -> None:
        """Test that Claude models use XML formatting."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(content="System prompt", source="system"),
            TaskContext(content="Task", source="task"),
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )
        # Claude should have XML tags
        assert "<system_instructions>" in result.content or result.context_count == 0

    @pytest.mark.asyncio
    async def test_format_openai_uses_markdown(self) -> None:
        """Test that OpenAI models use markdown formatting."""
        pipeline = ContextPipeline()
        contexts = [
            TaskContext(content="Task description", source="task"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
        )
        # OpenAI should have markdown headers
        if result.context_count > 0 and "Task" in result.content:
            assert "## Current Task" in result.content

    @pytest.mark.asyncio
    async def test_format_knowledge_claude(self) -> None:
        """Test knowledge formatting for Claude."""
        pipeline = ContextPipeline()
        contexts = [
            KnowledgeContext(
                content="Document content here",
                source="docs/file.md",
                relevance_score=0.9,
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )
        if result.context_count > 0:
            assert "<reference_documents>" in result.content
            assert "<document" in result.content

    @pytest.mark.asyncio
    async def test_format_conversation(self) -> None:
        """Test conversation formatting."""
        pipeline = ContextPipeline()
        contexts = [
            ConversationContext(
                content="Hello, how are you?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            ConversationContext(
                content="I'm doing great!",
                source="chat",
                role=MessageRole.ASSISTANT,
                metadata={"role": "assistant"},
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )
        if result.context_count > 0:
            assert "<conversation_history>" in result.content
            assert '<message role="user">' in result.content or 'role="user"' in result.content

    @pytest.mark.asyncio
    async def test_format_tool_results(self) -> None:
        """Test tool result formatting."""
        pipeline = ContextPipeline()
        contexts = [
            ToolContext(
                content="Tool output here",
                source="tool",
                metadata={"tool_name": "search"},
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )
        if result.context_count > 0:
            assert "<tool_results>" in result.content


class TestContextPipelineIntegration:
    """Integration tests for full pipeline."""

    @pytest.mark.asyncio
    async def test_full_pipeline_workflow(self) -> None:
        """Test complete pipeline workflow."""
        pipeline = ContextPipeline()
        # Create realistic context mix
        contexts = [
            SystemContext(
                content="You are an expert Python developer.",
                source="system",
            ),
            TaskContext(
                content="Implement a user authentication system.",
                source="task:AUTH-123",
            ),
            KnowledgeContext(
                content="JWT tokens provide stateless authentication...",
                source="docs/auth/jwt.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="OAuth 2.0 is an authorization framework...",
                source="docs/auth/oauth.md",
                relevance_score=0.7,
            ),
            ConversationContext(
                content="Can you help me implement JWT auth?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="implement JWT authentication",
            model="claude-3-sonnet",
        )
        # Verify result
        assert isinstance(result, AssembledContext)
        assert result.context_count > 0
        assert result.total_tokens > 0
        assert result.assembly_time_ms > 0
        assert result.model == "claude-3-sonnet"
        assert len(result.content) > 0
        # Verify metrics
        assert "metrics" in result.metadata
        assert "query" in result.metadata
        assert "budget" in result.metadata

    @pytest.mark.asyncio
    async def test_context_type_ordering(self) -> None:
        """Test that contexts are ordered by type correctly."""
        pipeline = ContextPipeline()
        # Add in random order
        contexts = [
            KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
            ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
            SystemContext(content="System", source="system"),
            ConversationContext(
                content="Chat",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            TaskContext(content="Task", source="task"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )
        # For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
        content = result.content
        if result.context_count > 0:
            # Find positions (if they exist)
            system_pos = content.find("system_instructions")
            task_pos = content.find("current_task")
            knowledge_pos = content.find("reference_documents")
            conversation_pos = content.find("conversation_history")
            tool_pos = content.find("tool_results")
            # Verify ordering (only check if both exist)
            if system_pos >= 0 and task_pos >= 0:
                assert system_pos < task_pos
            if task_pos >= 0 and knowledge_pos >= 0:
                assert task_pos < knowledge_pos

    @pytest.mark.asyncio
    async def test_excluded_contexts_tracked(self) -> None:
        """Test that excluded contexts are tracked in result."""
        pipeline = ContextPipeline()
        # Create many contexts to force some exclusions
        contexts = [
            KnowledgeContext(
                content="A" * 500,  # Large content
                source=f"docs/{i}",
                relevance_score=0.1 + (i * 0.05),
            )
            for i in range(10)
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",  # Smaller context window
            max_tokens=1000,  # Limited budget
        )
        # Should have excluded some
        assert result.excluded_count >= 0
        assert result.context_count + result.excluded_count <= len(contexts)


@@ -0,0 +1,214 @@
"""Tests for context compression module."""
import pytest
from app.services.context.compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
ContextType,
KnowledgeContext,
SystemContext,
TaskContext,
)


class TestTruncationResult:
    """Tests for TruncationResult dataclass."""

    def test_creation(self) -> None:
        """Test basic creation."""
        result = TruncationResult(
            original_tokens=100,
            truncated_tokens=50,
            content="Truncated content",
            truncated=True,
            truncation_ratio=0.5,
        )
        assert result.original_tokens == 100
        assert result.truncated_tokens == 50
        assert result.truncated is True
        assert result.truncation_ratio == 0.5

    def test_tokens_saved(self) -> None:
        """Test tokens_saved property."""
        result = TruncationResult(
            original_tokens=100,
            truncated_tokens=40,
            content="Test",
            truncated=True,
            truncation_ratio=0.6,
        )
        assert result.tokens_saved == 60

    def test_no_truncation(self) -> None:
        """Test when no truncation needed."""
        result = TruncationResult(
            original_tokens=50,
            truncated_tokens=50,
            content="Full content",
            truncated=False,
            truncation_ratio=0.0,
        )
        assert result.tokens_saved == 0
        assert result.truncated is False


class TestTruncationStrategy:
    """Tests for TruncationStrategy."""

    def test_creation(self) -> None:
        """Test strategy creation."""
        strategy = TruncationStrategy()
        assert strategy._preserve_ratio_start == 0.7
        assert strategy._min_content_length == 100

    def test_creation_with_params(self) -> None:
        """Test strategy creation with custom params."""
        strategy = TruncationStrategy(
            preserve_ratio_start=0.5,
            min_content_length=50,
        )
        assert strategy._preserve_ratio_start == 0.5
        assert strategy._min_content_length == 50

    @pytest.mark.asyncio
    async def test_truncate_empty_content(self) -> None:
        """Test truncating empty content."""
        strategy = TruncationStrategy()
        result = await strategy.truncate_to_tokens("", max_tokens=100)
        assert result.original_tokens == 0
        assert result.truncated_tokens == 0
        assert result.content == ""
        assert result.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_content_within_limit(self) -> None:
        """Test content that fits within limit."""
        strategy = TruncationStrategy()
        content = "Short content"
        result = await strategy.truncate_to_tokens(content, max_tokens=100)
        assert result.content == content
        assert result.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_end_strategy(self) -> None:
        """Test end truncation strategy."""
        strategy = TruncationStrategy()
        content = "A" * 1000  # Long content
        result = await strategy.truncate_to_tokens(
            content, max_tokens=50, strategy="end"
        )
        assert result.truncated is True
        assert len(result.content) < len(content)
        assert strategy.TRUNCATION_MARKER in result.content

    @pytest.mark.asyncio
    async def test_truncate_middle_strategy(self) -> None:
        """Test middle truncation strategy."""
        strategy = TruncationStrategy(preserve_ratio_start=0.6)
        content = "START " + "A" * 500 + " END"
        result = await strategy.truncate_to_tokens(
            content, max_tokens=50, strategy="middle"
        )
        assert result.truncated is True
        assert strategy.TRUNCATION_MARKER in result.content

    @pytest.mark.asyncio
    async def test_truncate_sentence_strategy(self) -> None:
        """Test sentence-aware truncation strategy."""
        strategy = TruncationStrategy()
        content = "First sentence. Second sentence. Third sentence. Fourth sentence."
        result = await strategy.truncate_to_tokens(
            content, max_tokens=10, strategy="sentence"
        )
        assert result.truncated is True
        # Should cut at sentence boundary
        assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content


class TestContextCompressor:
    """Tests for ContextCompressor."""

    def test_creation(self) -> None:
        """Test compressor creation."""
        compressor = ContextCompressor()
        assert compressor._truncation is not None

    @pytest.mark.asyncio
    async def test_compress_context_within_limit(self) -> None:
        """Test compressing context that already fits."""
        compressor = ContextCompressor()
        context = KnowledgeContext(
            content="Short content",
            source="docs",
        )
        context.token_count = 5
        result = await compressor.compress_context(context, max_tokens=100)
        # Should return same context unmodified
        assert result.content == "Short content"
        assert result.metadata.get("truncated") is not True

    @pytest.mark.asyncio
    async def test_compress_context_exceeds_limit(self) -> None:
        """Test compressing context that exceeds limit."""
        compressor = ContextCompressor()
        context = KnowledgeContext(
            content="A" * 500,
            source="docs",
        )
        context.token_count = 125  # Approximately 500/4
        result = await compressor.compress_context(context, max_tokens=20)
        assert result.metadata.get("truncated") is True
        assert result.metadata.get("original_tokens") == 125
        assert len(result.content) < 500

    @pytest.mark.asyncio
    async def test_compress_contexts_batch(self) -> None:
        """Test compressing multiple contexts."""
        compressor = ContextCompressor()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)
        contexts = [
            KnowledgeContext(content="A" * 200, source="docs"),
            KnowledgeContext(content="B" * 200, source="docs"),
            TaskContext(content="C" * 200, source="task"),
        ]
        result = await compressor.compress_contexts(contexts, budget)
        assert len(result) == 3

    @pytest.mark.asyncio
    async def test_strategy_selection_by_type(self) -> None:
        """Test that correct strategy is selected for each type."""
        compressor = ContextCompressor()
        assert compressor._get_strategy_for_type(ContextType.SYSTEM) == "end"
        assert compressor._get_strategy_for_type(ContextType.TASK) == "end"
        assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
        assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
        assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"


@@ -426,11 +426,14 @@ class TestAssembledContext:
         """Test basic creation."""
         ctx = AssembledContext(
             content="Assembled content here",
-            token_count=500,
-            contexts_included=5,
+            total_tokens=500,
+            context_count=5,
         )
         assert ctx.content == "Assembled content here"
+        assert ctx.total_tokens == 500
+        assert ctx.context_count == 5
+        # Test backward compatibility aliases
+        assert ctx.token_count == 500
+        assert ctx.contexts_included == 5
@@ -438,8 +441,8 @@ class TestAssembledContext:
         """Test budget_utilization property."""
         ctx = AssembledContext(
             content="test",
-            token_count=800,
-            contexts_included=5,
+            total_tokens=800,
+            context_count=5,
             budget_total=1000,
             budget_used=800,
         )
@@ -450,8 +453,8 @@ class TestAssembledContext:
         """Test budget_utilization with zero budget."""
        ctx = AssembledContext(
             content="test",
-            token_count=0,
-            contexts_included=0,
+            total_tokens=0,
+            context_count=0,
             budget_total=0,
             budget_used=0,
         )
@@ -462,24 +465,26 @@ class TestAssembledContext:
         """Test to_dict method."""
         ctx = AssembledContext(
             content="test",
-            token_count=100,
-            contexts_included=2,
+            total_tokens=100,
+            context_count=2,
             assembly_time_ms=50.123,
         )
         data = ctx.to_dict()
         assert data["content"] == "test"
-        assert data["token_count"] == 100
+        assert data["total_tokens"] == 100
+        assert data["context_count"] == 2
         assert data["assembly_time_ms"] == 50.12  # Rounded

     def test_to_json_and_from_json(self) -> None:
         """Test JSON serialization round-trip."""
         original = AssembledContext(
             content="Test content",
-            token_count=100,
-            contexts_included=3,
-            contexts_excluded=2,
+            total_tokens=100,
+            context_count=3,
+            excluded_count=2,
             assembly_time_ms=45.5,
+            model="claude-3-sonnet",
             budget_total=1000,
             budget_used=100,
             by_type={"system": 20, "knowledge": 80},
@@ -491,8 +496,10 @@ class TestAssembledContext:
         restored = AssembledContext.from_json(json_str)
         assert restored.content == original.content
-        assert restored.token_count == original.token_count
-        assert restored.contexts_included == original.contexts_included
+        assert restored.total_tokens == original.total_tokens
+        assert restored.context_count == original.context_count
+        assert restored.excluded_count == original.excluded_count
+        assert restored.model == original.model
         assert restored.cache_hit == original.cache_hit
         assert restored.cache_key == original.cache_key