- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
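
One of the notes above mentions replacing hard-coded values with type-safe `ClassVar` annotations. A minimal sketch of that pattern follows; the class and constant names here are hypothetical illustrations, not taken from the codebase:

```python
from typing import ClassVar


class ContextRanker:
    # ClassVar marks these as class-level constants rather than
    # per-instance fields, so type checkers treat them correctly.
    # Names and values below are illustrative only.
    DEFAULT_RELEVANCE_WEIGHT: ClassVar[float] = 0.6
    DEFAULT_RECENCY_WEIGHT: ClassVar[float] = 0.4

    def score(self, relevance: float, recency: float) -> float:
        """Combine relevance and recency into a single ranking score."""
        return (
            self.DEFAULT_RELEVANCE_WEIGHT * relevance
            + self.DEFAULT_RECENCY_WEIGHT * recency
        )
```

Beyond readability, the annotation lets a type checker flag accidental per-instance assignments to these constants (e.g., `self.DEFAULT_RELEVANCE_WEIGHT = ...`) as errors.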
"""Tests for context assembly pipeline."""
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
import pytest
|
|
|
|
from app.services.context.assembly import ContextPipeline, PipelineMetrics
|
|
from app.services.context.budget import TokenBudget
|
|
from app.services.context.types import (
|
|
AssembledContext,
|
|
ConversationContext,
|
|
KnowledgeContext,
|
|
MessageRole,
|
|
SystemContext,
|
|
TaskContext,
|
|
ToolContext,
|
|
)
|
|
|
|
|
|
class TestPipelineMetrics:
|
|
"""Tests for PipelineMetrics dataclass."""
|
|
|
|
def test_creation(self) -> None:
|
|
"""Test metrics creation."""
|
|
metrics = PipelineMetrics()
|
|
|
|
assert metrics.total_contexts == 0
|
|
assert metrics.selected_contexts == 0
|
|
assert metrics.assembly_time_ms == 0.0
|
|
|
|
def test_to_dict(self) -> None:
|
|
"""Test conversion to dictionary."""
|
|
metrics = PipelineMetrics(
|
|
total_contexts=10,
|
|
selected_contexts=8,
|
|
excluded_contexts=2,
|
|
total_tokens=500,
|
|
assembly_time_ms=25.5,
|
|
)
|
|
metrics.end_time = datetime.now(UTC)
|
|
|
|
data = metrics.to_dict()
|
|
|
|
assert data["total_contexts"] == 10
|
|
assert data["selected_contexts"] == 8
|
|
assert data["excluded_contexts"] == 2
|
|
assert data["total_tokens"] == 500
|
|
assert data["assembly_time_ms"] == 25.5
|
|
assert "start_time" in data
|
|
assert "end_time" in data
|
|
|
|
|
|
class TestContextPipeline:
|
|
"""Tests for ContextPipeline."""
|
|
|
|
def test_creation(self) -> None:
|
|
"""Test pipeline creation."""
|
|
pipeline = ContextPipeline()
|
|
|
|
assert pipeline._calculator is not None
|
|
assert pipeline._scorer is not None
|
|
assert pipeline._ranker is not None
|
|
assert pipeline._compressor is not None
|
|
assert pipeline._allocator is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_empty_contexts(self) -> None:
|
|
"""Test assembling empty context list."""
|
|
pipeline = ContextPipeline()
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=[],
|
|
query="test query",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
assert isinstance(result, AssembledContext)
|
|
assert result.context_count == 0
|
|
assert result.total_tokens == 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_single_context(self) -> None:
|
|
"""Test assembling single context."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
SystemContext(
|
|
content="You are a helpful assistant.",
|
|
source="system",
|
|
)
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="help me",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
assert result.context_count == 1
|
|
assert result.total_tokens > 0
|
|
assert "helpful assistant" in result.content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_multiple_types(self) -> None:
|
|
"""Test assembling multiple context types."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
SystemContext(
|
|
content="You are a coding assistant.",
|
|
source="system",
|
|
),
|
|
TaskContext(
|
|
content="Implement a login feature.",
|
|
source="task",
|
|
),
|
|
KnowledgeContext(
|
|
content="Authentication best practices include...",
|
|
source="docs/auth.md",
|
|
relevance_score=0.8,
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="implement login",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
assert result.context_count >= 1
|
|
assert result.total_tokens > 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_with_custom_budget(self) -> None:
|
|
"""Test assembling with custom budget."""
|
|
pipeline = ContextPipeline()
|
|
budget = TokenBudget(
|
|
total=1000,
|
|
system=200,
|
|
task=200,
|
|
knowledge=400,
|
|
conversation=100,
|
|
tools=50,
|
|
response_reserve=50,
|
|
)
|
|
|
|
contexts = [
|
|
SystemContext(content="System prompt", source="system"),
|
|
TaskContext(content="Task description", source="task"),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="gpt-4",
|
|
custom_budget=budget,
|
|
)
|
|
|
|
assert result.context_count >= 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_with_max_tokens(self) -> None:
|
|
"""Test assembling with max_tokens limit."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
SystemContext(content="System prompt", source="system"),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="gpt-4",
|
|
max_tokens=5000,
|
|
)
|
|
|
|
assert "budget" in result.metadata
|
|
assert result.metadata["budget"]["total"] == 5000
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_format_output(self) -> None:
|
|
"""Test formatted vs unformatted output."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
SystemContext(content="System prompt", source="system"),
|
|
]
|
|
|
|
# Formatted (default)
|
|
result_formatted = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
format_output=True,
|
|
)
|
|
|
|
# Unformatted
|
|
result_raw = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
format_output=False,
|
|
)
|
|
|
|
# Formatted should have XML tags for Claude
|
|
assert "<system_instructions>" in result_formatted.content
|
|
# Raw should not
|
|
assert "<system_instructions>" not in result_raw.content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_metrics(self) -> None:
|
|
"""Test that metrics are populated."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
SystemContext(content="System", source="system"),
|
|
TaskContext(content="Task", source="task"),
|
|
KnowledgeContext(
|
|
content="Knowledge",
|
|
source="docs",
|
|
relevance_score=0.9,
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
assert "metrics" in result.metadata
|
|
metrics = result.metadata["metrics"]
|
|
|
|
assert metrics["total_contexts"] == 3
|
|
assert metrics["assembly_time_ms"] > 0
|
|
assert "scoring_time_ms" in metrics
|
|
assert "formatting_time_ms" in metrics
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_assemble_with_compression_disabled(self) -> None:
|
|
"""Test assembling with compression disabled."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
KnowledgeContext(content="A" * 1000, source="docs"),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="gpt-4",
|
|
compress=False,
|
|
)
|
|
|
|
# Should still work, just no compression applied
|
|
assert result.context_count >= 0
|
|
|
|
|
|
class TestContextPipelineFormatting:
|
|
"""Tests for context formatting."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_format_claude_uses_xml(self) -> None:
|
|
"""Test that Claude models use XML formatting."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
SystemContext(content="System prompt", source="system"),
|
|
TaskContext(content="Task", source="task"),
|
|
KnowledgeContext(
|
|
content="Knowledge",
|
|
source="docs",
|
|
relevance_score=0.9,
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
# Claude should have XML tags
|
|
assert "<system_instructions>" in result.content or result.context_count == 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_format_openai_uses_markdown(self) -> None:
|
|
"""Test that OpenAI models use markdown formatting."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
TaskContext(content="Task description", source="task"),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="gpt-4",
|
|
)
|
|
|
|
# OpenAI should have markdown headers
|
|
if result.context_count > 0 and "Task" in result.content:
|
|
assert "## Current Task" in result.content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_format_knowledge_claude(self) -> None:
|
|
"""Test knowledge formatting for Claude."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
KnowledgeContext(
|
|
content="Document content here",
|
|
source="docs/file.md",
|
|
relevance_score=0.9,
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
if result.context_count > 0:
|
|
assert "<reference_documents>" in result.content
|
|
assert "<document" in result.content
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_format_conversation(self) -> None:
|
|
"""Test conversation formatting."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
ConversationContext(
|
|
content="Hello, how are you?",
|
|
source="chat",
|
|
role=MessageRole.USER,
|
|
metadata={"role": "user"},
|
|
),
|
|
ConversationContext(
|
|
content="I'm doing great!",
|
|
source="chat",
|
|
role=MessageRole.ASSISTANT,
|
|
metadata={"role": "assistant"},
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
if result.context_count > 0:
|
|
assert "<conversation_history>" in result.content
|
|
assert (
|
|
'<message role="user">' in result.content
|
|
or 'role="user"' in result.content
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_format_tool_results(self) -> None:
|
|
"""Test tool result formatting."""
|
|
pipeline = ContextPipeline()
|
|
|
|
contexts = [
|
|
ToolContext(
|
|
content="Tool output here",
|
|
source="tool",
|
|
metadata={"tool_name": "search"},
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
if result.context_count > 0:
|
|
assert "<tool_results>" in result.content
|
|
|
|
|
|
class TestContextPipelineIntegration:
|
|
"""Integration tests for full pipeline."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_full_pipeline_workflow(self) -> None:
|
|
"""Test complete pipeline workflow."""
|
|
pipeline = ContextPipeline()
|
|
|
|
# Create realistic context mix
|
|
contexts = [
|
|
SystemContext(
|
|
content="You are an expert Python developer.",
|
|
source="system",
|
|
),
|
|
TaskContext(
|
|
content="Implement a user authentication system.",
|
|
source="task:AUTH-123",
|
|
),
|
|
KnowledgeContext(
|
|
content="JWT tokens provide stateless authentication...",
|
|
source="docs/auth/jwt.md",
|
|
relevance_score=0.9,
|
|
),
|
|
KnowledgeContext(
|
|
content="OAuth 2.0 is an authorization framework...",
|
|
source="docs/auth/oauth.md",
|
|
relevance_score=0.7,
|
|
),
|
|
ConversationContext(
|
|
content="Can you help me implement JWT auth?",
|
|
source="chat",
|
|
role=MessageRole.USER,
|
|
metadata={"role": "user"},
|
|
),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="implement JWT authentication",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
# Verify result
|
|
assert isinstance(result, AssembledContext)
|
|
assert result.context_count > 0
|
|
assert result.total_tokens > 0
|
|
assert result.assembly_time_ms > 0
|
|
assert result.model == "claude-3-sonnet"
|
|
assert len(result.content) > 0
|
|
|
|
# Verify metrics
|
|
assert "metrics" in result.metadata
|
|
assert "query" in result.metadata
|
|
assert "budget" in result.metadata
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_context_type_ordering(self) -> None:
|
|
"""Test that contexts are ordered by type correctly."""
|
|
pipeline = ContextPipeline()
|
|
|
|
# Add in random order
|
|
contexts = [
|
|
KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
|
|
ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
|
|
SystemContext(content="System", source="system"),
|
|
ConversationContext(
|
|
content="Chat",
|
|
source="chat",
|
|
role=MessageRole.USER,
|
|
metadata={"role": "user"},
|
|
),
|
|
TaskContext(content="Task", source="task"),
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="claude-3-sonnet",
|
|
)
|
|
|
|
# For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
|
|
content = result.content
|
|
if result.context_count > 0:
|
|
# Find positions (if they exist)
|
|
system_pos = content.find("system_instructions")
|
|
task_pos = content.find("current_task")
|
|
knowledge_pos = content.find("reference_documents")
|
|
conversation_pos = content.find("conversation_history")
|
|
tool_pos = content.find("tool_results")
|
|
|
|
# Verify ordering (only check if both exist)
|
|
if system_pos >= 0 and task_pos >= 0:
|
|
assert system_pos < task_pos
|
|
if task_pos >= 0 and knowledge_pos >= 0:
|
|
assert task_pos < knowledge_pos
|
|
if knowledge_pos >= 0 and conversation_pos >= 0:
|
|
assert knowledge_pos < conversation_pos
|
|
if conversation_pos >= 0 and tool_pos >= 0:
|
|
assert conversation_pos < tool_pos
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_excluded_contexts_tracked(self) -> None:
|
|
"""Test that excluded contexts are tracked in result."""
|
|
pipeline = ContextPipeline()
|
|
|
|
# Create many contexts to force some exclusions
|
|
contexts = [
|
|
KnowledgeContext(
|
|
content="A" * 500, # Large content
|
|
source=f"docs/{i}",
|
|
relevance_score=0.1 + (i * 0.05),
|
|
)
|
|
for i in range(10)
|
|
]
|
|
|
|
result = await pipeline.assemble(
|
|
contexts=contexts,
|
|
query="test",
|
|
model="gpt-4", # Smaller context window
|
|
max_tokens=1000, # Limited budget
|
|
)
|
|
|
|
        # Exclusion count is tracked (it may be zero if everything fit)
        # and the accounting must stay consistent
        assert result.excluded_count >= 0
        assert result.context_count + result.excluded_count <= len(contexts)