"""Tests for context assembly pipeline.""" from datetime import UTC, datetime import pytest from app.services.context.assembly import ContextPipeline, PipelineMetrics from app.services.context.budget import BudgetAllocator, TokenBudget from app.services.context.types import ( AssembledContext, ContextType, ConversationContext, KnowledgeContext, MessageRole, SystemContext, TaskContext, ToolContext, ) class TestPipelineMetrics: """Tests for PipelineMetrics dataclass.""" def test_creation(self) -> None: """Test metrics creation.""" metrics = PipelineMetrics() assert metrics.total_contexts == 0 assert metrics.selected_contexts == 0 assert metrics.assembly_time_ms == 0.0 def test_to_dict(self) -> None: """Test conversion to dictionary.""" metrics = PipelineMetrics( total_contexts=10, selected_contexts=8, excluded_contexts=2, total_tokens=500, assembly_time_ms=25.5, ) metrics.end_time = datetime.now(UTC) data = metrics.to_dict() assert data["total_contexts"] == 10 assert data["selected_contexts"] == 8 assert data["excluded_contexts"] == 2 assert data["total_tokens"] == 500 assert data["assembly_time_ms"] == 25.5 assert "start_time" in data assert "end_time" in data class TestContextPipeline: """Tests for ContextPipeline.""" def test_creation(self) -> None: """Test pipeline creation.""" pipeline = ContextPipeline() assert pipeline._calculator is not None assert pipeline._scorer is not None assert pipeline._ranker is not None assert pipeline._compressor is not None assert pipeline._allocator is not None @pytest.mark.asyncio async def test_assemble_empty_contexts(self) -> None: """Test assembling empty context list.""" pipeline = ContextPipeline() result = await pipeline.assemble( contexts=[], query="test query", model="claude-3-sonnet", ) assert isinstance(result, AssembledContext) assert result.context_count == 0 assert result.total_tokens == 0 @pytest.mark.asyncio async def test_assemble_single_context(self) -> None: """Test assembling single context.""" pipeline = ContextPipeline() contexts = [ SystemContext( content="You are a helpful assistant.", source="system", ) ] result = await pipeline.assemble( contexts=contexts, query="help me", model="claude-3-sonnet", ) assert result.context_count == 1 assert result.total_tokens > 0 assert "helpful assistant" in result.content @pytest.mark.asyncio async def test_assemble_multiple_types(self) -> None: """Test assembling multiple context types.""" pipeline = ContextPipeline() contexts = [ SystemContext( content="You are a coding assistant.", source="system", ), TaskContext( content="Implement a login feature.", source="task", ), KnowledgeContext( content="Authentication best practices include...", source="docs/auth.md", relevance_score=0.8, ), ] result = await pipeline.assemble( contexts=contexts, query="implement login", model="claude-3-sonnet", ) assert result.context_count >= 1 assert result.total_tokens > 0 @pytest.mark.asyncio async def test_assemble_with_custom_budget(self) -> None: """Test assembling with custom budget.""" pipeline = ContextPipeline() budget = TokenBudget( total=1000, system=200, task=200, knowledge=400, conversation=100, tools=50, response_reserve=50, ) contexts = [ SystemContext(content="System prompt", source="system"), TaskContext(content="Task description", source="task"), ] result = await pipeline.assemble( contexts=contexts, query="test", model="gpt-4", custom_budget=budget, ) assert result.context_count >= 1 @pytest.mark.asyncio async def test_assemble_with_max_tokens(self) -> None: """Test assembling with max_tokens 
limit.""" pipeline = ContextPipeline() contexts = [ SystemContext(content="System prompt", source="system"), ] result = await pipeline.assemble( contexts=contexts, query="test", model="gpt-4", max_tokens=5000, ) assert "budget" in result.metadata assert result.metadata["budget"]["total"] == 5000 @pytest.mark.asyncio async def test_assemble_format_output(self) -> None: """Test formatted vs unformatted output.""" pipeline = ContextPipeline() contexts = [ SystemContext(content="System prompt", source="system"), ] # Formatted (default) result_formatted = await pipeline.assemble( contexts=contexts, query="test", model="claude-3-sonnet", format_output=True, ) # Unformatted result_raw = await pipeline.assemble( contexts=contexts, query="test", model="claude-3-sonnet", format_output=False, ) # Formatted should have XML tags for Claude assert "" in result_formatted.content # Raw should not assert "" not in result_raw.content @pytest.mark.asyncio async def test_assemble_metrics(self) -> None: """Test that metrics are populated.""" pipeline = ContextPipeline() contexts = [ SystemContext(content="System", source="system"), TaskContext(content="Task", source="task"), KnowledgeContext( content="Knowledge", source="docs", relevance_score=0.9, ), ] result = await pipeline.assemble( contexts=contexts, query="test", model="claude-3-sonnet", ) assert "metrics" in result.metadata metrics = result.metadata["metrics"] assert metrics["total_contexts"] == 3 assert metrics["assembly_time_ms"] > 0 assert "scoring_time_ms" in metrics assert "formatting_time_ms" in metrics @pytest.mark.asyncio async def test_assemble_with_compression_disabled(self) -> None: """Test assembling with compression disabled.""" pipeline = ContextPipeline() contexts = [ KnowledgeContext(content="A" * 1000, source="docs"), ] result = await pipeline.assemble( contexts=contexts, query="test", model="gpt-4", compress=False, ) # Should still work, just no compression applied assert result.context_count >= 0 class TestContextPipelineFormatting: """Tests for context formatting.""" @pytest.mark.asyncio async def test_format_claude_uses_xml(self) -> None: """Test that Claude models use XML formatting.""" pipeline = ContextPipeline() contexts = [ SystemContext(content="System prompt", source="system"), TaskContext(content="Task", source="task"), KnowledgeContext( content="Knowledge", source="docs", relevance_score=0.9, ), ] result = await pipeline.assemble( contexts=contexts, query="test", model="claude-3-sonnet", ) # Claude should have XML tags assert "" in result.content or result.context_count == 0 @pytest.mark.asyncio async def test_format_openai_uses_markdown(self) -> None: """Test that OpenAI models use markdown formatting.""" pipeline = ContextPipeline() contexts = [ TaskContext(content="Task description", source="task"), ] result = await pipeline.assemble( contexts=contexts, query="test", model="gpt-4", ) # OpenAI should have markdown headers if result.context_count > 0 and "Task" in result.content: assert "## Current Task" in result.content @pytest.mark.asyncio async def test_format_knowledge_claude(self) -> None: """Test knowledge formatting for Claude.""" pipeline = ContextPipeline() contexts = [ KnowledgeContext( content="Document content here", source="docs/file.md", relevance_score=0.9, ), ] result = await pipeline.assemble( contexts=contexts, query="test", model="claude-3-sonnet", ) if result.context_count > 0: assert "" in result.content assert " None: """Test conversation formatting.""" pipeline = ContextPipeline() contexts = [ 
            ConversationContext(
                content="Hello, how are you?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            ConversationContext(
                content="I'm doing great!",
                source="chat",
                role=MessageRole.ASSISTANT,
                metadata={"role": "assistant"},
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )

        if result.context_count > 0:
            assert "<conversation_history>" in result.content
            assert (
                '<message role="user">' in result.content
                or 'role="user"' in result.content
            )

    @pytest.mark.asyncio
    async def test_format_tool_results(self) -> None:
        """Test tool result formatting."""
        pipeline = ContextPipeline()
        contexts = [
            ToolContext(
                content="Tool output here",
                source="tool",
                metadata={"tool_name": "search"},
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )

        if result.context_count > 0:
            assert "<tool_results>" in result.content


class TestContextPipelineIntegration:
    """Integration tests for full pipeline."""

    @pytest.mark.asyncio
    async def test_full_pipeline_workflow(self) -> None:
        """Test complete pipeline workflow."""
        pipeline = ContextPipeline()

        # Create realistic context mix
        contexts = [
            SystemContext(
                content="You are an expert Python developer.",
                source="system",
            ),
            TaskContext(
                content="Implement a user authentication system.",
                source="task:AUTH-123",
            ),
            KnowledgeContext(
                content="JWT tokens provide stateless authentication...",
                source="docs/auth/jwt.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="OAuth 2.0 is an authorization framework...",
                source="docs/auth/oauth.md",
                relevance_score=0.7,
            ),
            ConversationContext(
                content="Can you help me implement JWT auth?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="implement JWT authentication",
            model="claude-3-sonnet",
        )

        # Verify result
        assert isinstance(result, AssembledContext)
        assert result.context_count > 0
        assert result.total_tokens > 0
        assert result.assembly_time_ms > 0
        assert result.model == "claude-3-sonnet"
        assert len(result.content) > 0

        # Verify metrics
        assert "metrics" in result.metadata
        assert "query" in result.metadata
        assert "budget" in result.metadata

    @pytest.mark.asyncio
    async def test_context_type_ordering(self) -> None:
        """Test that contexts are ordered by type correctly."""
        pipeline = ContextPipeline()

        # Add in random order
        contexts = [
            KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
            ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
            SystemContext(content="System", source="system"),
            ConversationContext(
                content="Chat",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            TaskContext(content="Task", source="task"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )

        # For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
        content = result.content
        if result.context_count > 0:
            # Find positions (if they exist)
            system_pos = content.find("system_instructions")
            task_pos = content.find("current_task")
            knowledge_pos = content.find("reference_documents")
            conversation_pos = content.find("conversation_history")
            tool_pos = content.find("tool_results")

            # Verify ordering (only check if both exist)
            if system_pos >= 0 and task_pos >= 0:
                assert system_pos < task_pos
            if task_pos >= 0 and knowledge_pos >= 0:
                assert task_pos < knowledge_pos

    @pytest.mark.asyncio
    async def test_excluded_contexts_tracked(self) -> None:
        """Test that excluded contexts are tracked in result."""
        pipeline = ContextPipeline()

        # Create many contexts to force some exclusions
        contexts = [
            KnowledgeContext(
                content="A" * 500,  # Large content
                source=f"docs/{i}",
                relevance_score=0.1 + (i * 0.05),
            )
            for i in range(10)
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",  # Smaller context window
            max_tokens=1000,  # Limited budget
        )

        # Should have excluded some
        assert result.excluded_count >= 0
        assert result.context_count + result.excluded_count <= len(contexts)
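

# The classes above exercise one model at a time. The sketch below is an
# illustrative cross-model smoke test added as an assumption: it relies only
# on the assemble() signature and AssembledContext fields already used in the
# tests above, and asserts only weak invariants so it does not presume any
# additional pipeline behavior.
class TestContextPipelineSmoke:
    """Cross-model smoke tests (sketch; assumes only the API used above)."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("model", ["claude-3-sonnet", "gpt-4"])
    async def test_assemble_smoke_per_model(self, model: str) -> None:
        """Assembling the same contexts should succeed for each model."""
        pipeline = ContextPipeline()
        contexts = [
            SystemContext(content="System prompt", source="system"),
            TaskContext(content="Task description", source="task"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model=model,
        )

        # Weak invariants only: a result object comes back with sane counts.
        assert isinstance(result, AssembledContext)
        assert result.context_count >= 0
        assert result.total_tokens >= 0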