"""Tests for context ranking module.""" from datetime import UTC, datetime import pytest from app.services.context.budget import BudgetAllocator, TokenBudget from app.services.context.prioritization import ContextRanker, RankingResult from app.services.context.scoring import CompositeScorer, ScoredContext from app.services.context.types import ( ContextPriority, ContextType, ConversationContext, KnowledgeContext, MessageRole, SystemContext, TaskContext, ) class TestRankingResult: """Tests for RankingResult dataclass.""" def test_creation(self) -> None: """Test RankingResult creation.""" ctx = TaskContext(content="Test", source="task") scored = ScoredContext(context=ctx, composite_score=0.8) result = RankingResult( selected=[scored], excluded=[], total_tokens=100, selection_stats={"total": 1}, ) assert len(result.selected) == 1 assert result.total_tokens == 100 def test_selected_contexts_property(self) -> None: """Test selected_contexts property extracts contexts.""" ctx1 = TaskContext(content="Test 1", source="task") ctx2 = TaskContext(content="Test 2", source="task") scored1 = ScoredContext(context=ctx1, composite_score=0.8) scored2 = ScoredContext(context=ctx2, composite_score=0.6) result = RankingResult( selected=[scored1, scored2], excluded=[], total_tokens=200, ) selected = result.selected_contexts assert len(selected) == 2 assert ctx1 in selected assert ctx2 in selected class TestContextRanker: """Tests for ContextRanker.""" def test_creation(self) -> None: """Test ranker creation.""" ranker = ContextRanker() assert ranker._scorer is not None assert ranker._calculator is not None def test_creation_with_scorer(self) -> None: """Test ranker creation with custom scorer.""" scorer = CompositeScorer(relevance_weight=0.8) ranker = ContextRanker(scorer=scorer) assert ranker._scorer is scorer @pytest.mark.asyncio async def test_rank_empty_contexts(self) -> None: """Test ranking empty context list.""" ranker = ContextRanker() allocator = BudgetAllocator() budget = allocator.create_budget(10000) result = await ranker.rank([], "query", budget) assert len(result.selected) == 0 assert len(result.excluded) == 0 assert result.total_tokens == 0 @pytest.mark.asyncio async def test_rank_single_context_fits(self) -> None: """Test ranking single context that fits budget.""" ranker = ContextRanker() allocator = BudgetAllocator() budget = allocator.create_budget(10000) context = KnowledgeContext( content="Short content", source="docs", relevance_score=0.8, ) result = await ranker.rank([context], "query", budget) assert len(result.selected) == 1 assert len(result.excluded) == 0 assert result.selected[0].context is context @pytest.mark.asyncio async def test_rank_respects_budget(self) -> None: """Test that ranking respects token budget.""" ranker = ContextRanker() # Create a very small budget budget = TokenBudget( total=100, knowledge=50, # Only 50 tokens for knowledge ) # Create contexts that exceed budget contexts = [ KnowledgeContext( content="A" * 200, # ~50 tokens source="docs", relevance_score=0.9, ), KnowledgeContext( content="B" * 200, # ~50 tokens source="docs", relevance_score=0.8, ), KnowledgeContext( content="C" * 200, # ~50 tokens source="docs", relevance_score=0.7, ), ] result = await ranker.rank(contexts, "query", budget) # Not all should fit assert len(result.selected) < len(contexts) assert len(result.excluded) > 0 @pytest.mark.asyncio async def test_rank_selects_highest_scores(self) -> None: """Test that ranking selects highest scored contexts.""" ranker = ContextRanker() allocator = 

    @pytest.mark.asyncio
    async def test_rank_selects_highest_scores(self) -> None:
        """Test that ranking selects highest scored contexts."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)
        # Small budget for knowledge
        budget.knowledge = 100
        contexts = [
            KnowledgeContext(
                content="Low score",
                source="docs",
                relevance_score=0.2,
            ),
            KnowledgeContext(
                content="High score",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Medium score",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Should have selected some
        if result.selected:
            # The highest scored should be selected first
            scores = [s.composite_score for s in result.selected]
            assert scores == sorted(scores, reverse=True)

    @pytest.mark.asyncio
    async def test_rank_critical_priority_always_included(self) -> None:
        """Test that CRITICAL priority contexts are always included."""
        ranker = ContextRanker()
        # Very small budget
        budget = TokenBudget(
            total=100,
            system=10,  # Very small
            knowledge=10,
        )
        contexts = [
            SystemContext(
                content="Critical system prompt that must be included",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            KnowledgeContext(
                content="Optional knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget, ensure_required=True)

        # Critical context should be in selected
        selected_priorities = [s.context.priority for s in result.selected]
        assert ContextPriority.CRITICAL.value in selected_priorities

    @pytest.mark.asyncio
    async def test_rank_without_ensure_required(self) -> None:
        """Test ranking without ensuring required contexts."""
        ranker = ContextRanker()
        budget = TokenBudget(
            total=100,
            system=50,
            knowledge=50,
        )
        contexts = [
            SystemContext(
                content="A" * 500,  # Large content
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            KnowledgeContext(
                content="Short",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await ranker.rank(
            contexts, "query", budget, ensure_required=False
        )

        # Without ensure_required, CRITICAL contexts can be excluded
        # if the budget doesn't allow them
        assert len(result.selected) + len(result.excluded) == len(contexts)

    @pytest.mark.asyncio
    async def test_rank_selection_stats(self) -> None:
        """Test that ranking provides useful statistics."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)
        contexts = [
            KnowledgeContext(
                content="Knowledge 1", source="docs", relevance_score=0.8
            ),
            KnowledgeContext(
                content="Knowledge 2", source="docs", relevance_score=0.6
            ),
            TaskContext(content="Task", source="task"),
        ]

        result = await ranker.rank(contexts, "query", budget)

        stats = result.selection_stats
        assert "total_contexts" in stats
        assert "selected_count" in stats
        assert "excluded_count" in stats
        assert "total_tokens" in stats
        assert "by_type" in stats

    @pytest.mark.asyncio
    async def test_rank_simple(self) -> None:
        """Test simple ranking without budget per type."""
        ranker = ContextRanker()
        contexts = [
            KnowledgeContext(
                content="A",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B",
                source="docs",
                relevance_score=0.7,
            ),
            KnowledgeContext(
                content="C",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank_simple(contexts, "query", max_tokens=1000)

        # Should return contexts sorted by score
        assert len(result) > 0
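
    # In the next test, each context is ~25 tokens (100 characters at the
    # assumed ~4 chars/token ratio), so a 30-token cap should admit at most
    # one of them; the assertion is kept loose in case the estimator behaves
    # differently.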

    @pytest.mark.asyncio
    async def test_rank_simple_respects_max_tokens(self) -> None:
        """Test that simple ranking respects max tokens."""
        ranker = ContextRanker()
        # Create contexts with known token counts
        contexts = [
            KnowledgeContext(
                content="A" * 100,  # ~25 tokens
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B" * 100,
                source="docs",
                relevance_score=0.8,
            ),
            KnowledgeContext(
                content="C" * 100,
                source="docs",
                relevance_score=0.7,
            ),
        ]

        # Very small limit
        result = await ranker.rank_simple(contexts, "query", max_tokens=30)

        # Should only fit a limited number
        assert len(result) <= len(contexts)

    @pytest.mark.asyncio
    async def test_rank_simple_empty(self) -> None:
        """Test simple ranking with an empty list."""
        ranker = ContextRanker()

        result = await ranker.rank_simple([], "query", max_tokens=1000)

        assert result == []

    @pytest.mark.asyncio
    async def test_rerank_for_diversity(self) -> None:
        """Test diversity reranking."""
        ranker = ContextRanker()
        # Create scored contexts from the same source
        contexts = [
            ScoredContext(
                context=KnowledgeContext(
                    content=f"Content {i}",
                    source="same-source",
                    relevance_score=0.9 - i * 0.1,
                ),
                composite_score=0.9 - i * 0.1,
            )
            for i in range(5)
        ]

        # Limit to 2 per source
        result = await ranker.rerank_for_diversity(contexts, max_per_source=2)

        assert len(result) == 5
        # First 2 should be from the same source, rest deferred
        first_two_sources = [r.context.source for r in result[:2]]
        assert all(s == "same-source" for s in first_two_sources)

    @pytest.mark.asyncio
    async def test_rerank_for_diversity_multiple_sources(self) -> None:
        """Test diversity reranking with multiple sources."""
        ranker = ContextRanker()
        contexts = [
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 1",
                    source="source-a",
                    relevance_score=0.9,
                ),
                composite_score=0.9,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 2",
                    source="source-a",
                    relevance_score=0.8,
                ),
                composite_score=0.8,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source B - 1",
                    source="source-b",
                    relevance_score=0.7,
                ),
                composite_score=0.7,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 3",
                    source="source-a",
                    relevance_score=0.6,
                ),
                composite_score=0.6,
            ),
        ]

        result = await ranker.rerank_for_diversity(contexts, max_per_source=2)

        # Should not have more than 2 from source-a in the first 3
        source_a_in_first_3 = sum(
            1 for r in result[:3] if r.context.source == "source-a"
        )
        assert source_a_in_first_3 <= 2

    @pytest.mark.asyncio
    async def test_token_counts_set(self) -> None:
        """Test that token counts are set during ranking."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)
        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
        )

        # Token count should be None initially
        assert context.token_count is None

        await ranker.rank([context], "query", budget)

        # Token count should be set after ranking
        assert context.token_count is not None
        assert context.token_count > 0
"Python help", budget) # System (CRITICAL) should be included selected_types = [s.context.get_type() for s in result.selected] assert ContextType.SYSTEM in selected_types # Stats should be populated assert result.selection_stats["total_contexts"] == 5 @pytest.mark.asyncio async def test_ranking_preserves_context_order_by_score(self) -> None: """Test that ranking orders by score correctly.""" ranker = ContextRanker() allocator = BudgetAllocator() budget = allocator.create_budget(100000) contexts = [ KnowledgeContext( content="Low", source="docs", relevance_score=0.2, ), KnowledgeContext( content="High", source="docs", relevance_score=0.9, ), KnowledgeContext( content="Medium", source="docs", relevance_score=0.5, ), ] result = await ranker.rank(contexts, "query", budget) # Verify ordering is by score scores = [s.composite_score for s in result.selected] assert scores == sorted(scores, reverse=True)