fast-next-template/backend/tests/services/context/test_assembly.py
Felipe Cardoso 2bea057fb1 chore(context): refactor for consistency, optimize formatting, and simplify logic
- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`; see the sketch below).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
2026-01-04 15:23:14 +01:00
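
On the `ClassVar` bullet above: a minimal sketch of that kind of annotation, with hypothetical names standing in for the real ones (this is not the actual diff), might look like:

    from typing import ClassVar


    class TruncationStrategy:
        # Hypothetical class and attribute names, for illustration only.
        # Annotating a class-level constant as ClassVar lets type checkers
        # flag accidental reassignment through an instance.
        DEFAULT_MARKER: ClassVar[str] = "...[truncated]"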


"""Tests for context assembly pipeline."""
from datetime import UTC, datetime
import pytest
from app.services.context.assembly import ContextPipeline, PipelineMetrics
from app.services.context.budget import TokenBudget
from app.services.context.types import (
AssembledContext,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)


class TestPipelineMetrics:
    """Tests for PipelineMetrics dataclass."""

    def test_creation(self) -> None:
        """Test metrics creation."""
        metrics = PipelineMetrics()
        assert metrics.total_contexts == 0
        assert metrics.selected_contexts == 0
        assert metrics.assembly_time_ms == 0.0

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        metrics = PipelineMetrics(
            total_contexts=10,
            selected_contexts=8,
            excluded_contexts=2,
            total_tokens=500,
            assembly_time_ms=25.5,
        )
        metrics.end_time = datetime.now(UTC)
        data = metrics.to_dict()
        assert data["total_contexts"] == 10
        assert data["selected_contexts"] == 8
        assert data["excluded_contexts"] == 2
        assert data["total_tokens"] == 500
        assert data["assembly_time_ms"] == 25.5
        assert "start_time" in data
        assert "end_time" in data


class TestContextPipeline:
    """Tests for ContextPipeline."""

    def test_creation(self) -> None:
        """Test pipeline creation."""
        pipeline = ContextPipeline()
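        # The constructor is expected to wire up all five pipeline stages:
        # token calculation, scoring, ranking, compression, and allocation.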
        assert pipeline._calculator is not None
        assert pipeline._scorer is not None
        assert pipeline._ranker is not None
        assert pipeline._compressor is not None
        assert pipeline._allocator is not None

    @pytest.mark.asyncio
    async def test_assemble_empty_contexts(self) -> None:
        """Test assembling empty context list."""
        pipeline = ContextPipeline()
        result = await pipeline.assemble(
            contexts=[],
            query="test query",
            model="claude-3-sonnet",
        )
        assert isinstance(result, AssembledContext)
        assert result.context_count == 0
        assert result.total_tokens == 0

    @pytest.mark.asyncio
    async def test_assemble_single_context(self) -> None:
        """Test assembling single context."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(
                content="You are a helpful assistant.",
                source="system",
            )
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="help me",
            model="claude-3-sonnet",
        )
        assert result.context_count == 1
        assert result.total_tokens > 0
        assert "helpful assistant" in result.content

    @pytest.mark.asyncio
    async def test_assemble_multiple_types(self) -> None:
        """Test assembling multiple context types."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(
                content="You are a coding assistant.",
                source="system",
            ),
            TaskContext(
                content="Implement a login feature.",
                source="task",
            ),
            KnowledgeContext(
                content="Authentication best practices include...",
                source="docs/auth.md",
                relevance_score=0.8,
            ),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="implement login",
            model="claude-3-sonnet",
        )
        assert result.context_count >= 1
        assert result.total_tokens > 0

    @pytest.mark.asyncio
    async def test_assemble_with_custom_budget(self) -> None:
        """Test assembling with custom budget."""
        pipeline = ContextPipeline()
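        # Component budgets sum to the 1000-token total:
        # 200 + 200 + 400 + 100 + 50 + 50 = 1000.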
        budget = TokenBudget(
            total=1000,
            system=200,
            task=200,
            knowledge=400,
            conversation=100,
            tools=50,
            response_reserve=50,
        )
        contexts = [
            SystemContext(content="System prompt", source="system"),
            TaskContext(content="Task description", source="task"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            custom_budget=budget,
        )
        assert result.context_count >= 1

    @pytest.mark.asyncio
    async def test_assemble_with_max_tokens(self) -> None:
        """Test assembling with max_tokens limit."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(content="System prompt", source="system"),
        ]
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",
            max_tokens=5000,
        )
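        # An explicit max_tokens is expected to override the model's default
        # window and show up as the budget total in the result metadata.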
assert "budget" in result.metadata
assert result.metadata["budget"]["total"] == 5000
@pytest.mark.asyncio
async def test_assemble_format_output(self) -> None:
"""Test formatted vs unformatted output."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System prompt", source="system"),
]
# Formatted (default)
result_formatted = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
format_output=True,
)
# Unformatted
result_raw = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
format_output=False,
)
# Formatted should have XML tags for Claude
assert "<system_instructions>" in result_formatted.content
# Raw should not
assert "<system_instructions>" not in result_raw.content
@pytest.mark.asyncio
async def test_assemble_metrics(self) -> None:
"""Test that metrics are populated."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System", source="system"),
TaskContext(content="Task", source="task"),
KnowledgeContext(
content="Knowledge",
source="docs",
relevance_score=0.9,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
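        # Beyond raw counts, the pipeline is expected to record per-stage
        # timings (scoring, formatting) in its metrics.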
assert "metrics" in result.metadata
metrics = result.metadata["metrics"]
assert metrics["total_contexts"] == 3
assert metrics["assembly_time_ms"] > 0
assert "scoring_time_ms" in metrics
assert "formatting_time_ms" in metrics
@pytest.mark.asyncio
async def test_assemble_with_compression_disabled(self) -> None:
"""Test assembling with compression disabled."""
pipeline = ContextPipeline()
contexts = [
KnowledgeContext(content="A" * 1000, source="docs"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4",
compress=False,
)
# Should still work, just no compression applied
assert result.context_count >= 0
class TestContextPipelineFormatting:
"""Tests for context formatting."""
@pytest.mark.asyncio
async def test_format_claude_uses_xml(self) -> None:
"""Test that Claude models use XML formatting."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System prompt", source="system"),
TaskContext(content="Task", source="task"),
KnowledgeContext(
content="Knowledge",
source="docs",
relevance_score=0.9,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
# Claude should have XML tags
assert "<system_instructions>" in result.content or result.context_count == 0
@pytest.mark.asyncio
async def test_format_openai_uses_markdown(self) -> None:
"""Test that OpenAI models use markdown formatting."""
pipeline = ContextPipeline()
contexts = [
TaskContext(content="Task description", source="task"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4",
)
# OpenAI should have markdown headers
if result.context_count > 0 and "Task" in result.content:
assert "## Current Task" in result.content
@pytest.mark.asyncio
async def test_format_knowledge_claude(self) -> None:
"""Test knowledge formatting for Claude."""
pipeline = ContextPipeline()
contexts = [
KnowledgeContext(
content="Document content here",
source="docs/file.md",
relevance_score=0.9,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
if result.context_count > 0:
assert "<reference_documents>" in result.content
assert "<document" in result.content
@pytest.mark.asyncio
async def test_format_conversation(self) -> None:
"""Test conversation formatting."""
pipeline = ContextPipeline()
contexts = [
ConversationContext(
content="Hello, how are you?",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
ConversationContext(
content="I'm doing great!",
source="chat",
role=MessageRole.ASSISTANT,
metadata={"role": "assistant"},
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
if result.context_count > 0:
assert "<conversation_history>" in result.content
assert (
'<message role="user">' in result.content
or 'role="user"' in result.content
)
@pytest.mark.asyncio
async def test_format_tool_results(self) -> None:
"""Test tool result formatting."""
pipeline = ContextPipeline()
contexts = [
ToolContext(
content="Tool output here",
source="tool",
metadata={"tool_name": "search"},
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
if result.context_count > 0:
assert "<tool_results>" in result.content
class TestContextPipelineIntegration:
"""Integration tests for full pipeline."""
@pytest.mark.asyncio
async def test_full_pipeline_workflow(self) -> None:
"""Test complete pipeline workflow."""
pipeline = ContextPipeline()
# Create realistic context mix
contexts = [
SystemContext(
content="You are an expert Python developer.",
source="system",
),
TaskContext(
content="Implement a user authentication system.",
source="task:AUTH-123",
),
KnowledgeContext(
content="JWT tokens provide stateless authentication...",
source="docs/auth/jwt.md",
relevance_score=0.9,
),
KnowledgeContext(
content="OAuth 2.0 is an authorization framework...",
source="docs/auth/oauth.md",
relevance_score=0.7,
),
ConversationContext(
content="Can you help me implement JWT auth?",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
]
result = await pipeline.assemble(
contexts=contexts,
query="implement JWT authentication",
model="claude-3-sonnet",
)
# Verify result
assert isinstance(result, AssembledContext)
assert result.context_count > 0
assert result.total_tokens > 0
assert result.assembly_time_ms > 0
assert result.model == "claude-3-sonnet"
assert len(result.content) > 0
# Verify metrics
assert "metrics" in result.metadata
assert "query" in result.metadata
assert "budget" in result.metadata
@pytest.mark.asyncio
async def test_context_type_ordering(self) -> None:
"""Test that contexts are ordered by type correctly."""
pipeline = ContextPipeline()
# Add in random order
contexts = [
KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
SystemContext(content="System", source="system"),
ConversationContext(
content="Chat",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
TaskContext(content="Task", source="task"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
# For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
content = result.content
if result.context_count > 0:
# Find positions (if they exist)
system_pos = content.find("system_instructions")
task_pos = content.find("current_task")
knowledge_pos = content.find("reference_documents")
conversation_pos = content.find("conversation_history")
tool_pos = content.find("tool_results")
# Verify ordering (only check if both exist)
if system_pos >= 0 and task_pos >= 0:
assert system_pos < task_pos
if task_pos >= 0 and knowledge_pos >= 0:
assert task_pos < knowledge_pos
if knowledge_pos >= 0 and conversation_pos >= 0:
assert knowledge_pos < conversation_pos
if conversation_pos >= 0 and tool_pos >= 0:
assert conversation_pos < tool_pos
@pytest.mark.asyncio
async def test_excluded_contexts_tracked(self) -> None:
"""Test that excluded contexts are tracked in result."""
pipeline = ContextPipeline()
# Create many contexts to force some exclusions
contexts = [
KnowledgeContext(
content="A" * 500, # Large content
source=f"docs/{i}",
relevance_score=0.1 + (i * 0.05),
)
for i in range(10)
]
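        # Relevance scores ramp from 0.10 (i=0) to 0.55 (i=9), so the
        # lowest-scored entries are the likeliest exclusion candidates.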
        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",  # Smaller context window
            max_tokens=1000,  # Limited budget
        )
        # Exclusions are tracked (possibly zero), and the selected/excluded
        # split never exceeds the number of inputs.
        assert result.excluded_count >= 0
        assert result.context_count + result.excluded_count <= len(contexts)