Files
fast-next-template/backend/tests/services/context/test_ranker.py
Felipe Cardoso 2bea057fb1 chore(context): refactor for consistency, optimize formatting, and simplify logic
- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
2026-01-04 15:23:14 +01:00

500 lines
16 KiB
Python

"""Tests for context ranking module."""
import pytest
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.prioritization import ContextRanker, RankingResult
from app.services.context.scoring import CompositeScorer, ScoredContext
from app.services.context.types import (
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
)
class TestRankingResult:
"""Tests for RankingResult dataclass."""
def test_creation(self) -> None:
"""Test RankingResult creation."""
ctx = TaskContext(content="Test", source="task")
scored = ScoredContext(context=ctx, composite_score=0.8)
result = RankingResult(
selected=[scored],
excluded=[],
total_tokens=100,
selection_stats={"total": 1},
)
assert len(result.selected) == 1
assert result.total_tokens == 100
def test_selected_contexts_property(self) -> None:
"""Test selected_contexts property extracts contexts."""
ctx1 = TaskContext(content="Test 1", source="task")
ctx2 = TaskContext(content="Test 2", source="task")
scored1 = ScoredContext(context=ctx1, composite_score=0.8)
scored2 = ScoredContext(context=ctx2, composite_score=0.6)
result = RankingResult(
selected=[scored1, scored2],
excluded=[],
total_tokens=200,
)
selected = result.selected_contexts
assert len(selected) == 2
assert ctx1 in selected
assert ctx2 in selected
class TestContextRanker:
"""Tests for ContextRanker."""
def test_creation(self) -> None:
"""Test ranker creation."""
ranker = ContextRanker()
assert ranker._scorer is not None
assert ranker._calculator is not None
def test_creation_with_scorer(self) -> None:
"""Test ranker creation with custom scorer."""
scorer = CompositeScorer(relevance_weight=0.8)
ranker = ContextRanker(scorer=scorer)
assert ranker._scorer is scorer
@pytest.mark.asyncio
async def test_rank_empty_contexts(self) -> None:
"""Test ranking empty context list."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
result = await ranker.rank([], "query", budget)
assert len(result.selected) == 0
assert len(result.excluded) == 0
assert result.total_tokens == 0
@pytest.mark.asyncio
async def test_rank_single_context_fits(self) -> None:
"""Test ranking single context that fits budget."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
context = KnowledgeContext(
content="Short content",
source="docs",
relevance_score=0.8,
)
result = await ranker.rank([context], "query", budget)
assert len(result.selected) == 1
assert len(result.excluded) == 0
assert result.selected[0].context is context
@pytest.mark.asyncio
async def test_rank_respects_budget(self) -> None:
"""Test that ranking respects token budget."""
ranker = ContextRanker()
# Create a very small budget
budget = TokenBudget(
total=100,
knowledge=50, # Only 50 tokens for knowledge
)
# Create contexts that exceed budget
contexts = [
KnowledgeContext(
content="A" * 200, # ~50 tokens
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="B" * 200, # ~50 tokens
source="docs",
relevance_score=0.8,
),
KnowledgeContext(
content="C" * 200, # ~50 tokens
source="docs",
relevance_score=0.7,
),
]
result = await ranker.rank(contexts, "query", budget)
# Not all should fit
assert len(result.selected) < len(contexts)
assert len(result.excluded) > 0
@pytest.mark.asyncio
async def test_rank_selects_highest_scores(self) -> None:
"""Test that ranking selects highest scored contexts."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(1000)
# Small budget for knowledge
budget.knowledge = 100
contexts = [
KnowledgeContext(
content="Low score",
source="docs",
relevance_score=0.2,
),
KnowledgeContext(
content="High score",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="Medium score",
source="docs",
relevance_score=0.5,
),
]
result = await ranker.rank(contexts, "query", budget)
# Should have selected some
if result.selected:
# The highest scored should be selected first
scores = [s.composite_score for s in result.selected]
assert scores == sorted(scores, reverse=True)
@pytest.mark.asyncio
async def test_rank_critical_priority_always_included(self) -> None:
"""Test that CRITICAL priority contexts are always included."""
ranker = ContextRanker()
# Very small budget
budget = TokenBudget(
total=100,
system=10, # Very small
knowledge=10,
)
contexts = [
SystemContext(
content="Critical system prompt that must be included",
source="system",
priority=ContextPriority.CRITICAL.value,
),
KnowledgeContext(
content="Optional knowledge",
source="docs",
relevance_score=0.9,
),
]
result = await ranker.rank(contexts, "query", budget, ensure_required=True)
# Critical context should be in selected
selected_priorities = [s.context.priority for s in result.selected]
assert ContextPriority.CRITICAL.value in selected_priorities
@pytest.mark.asyncio
async def test_rank_without_ensure_required(self) -> None:
"""Test ranking without ensuring required contexts."""
ranker = ContextRanker()
budget = TokenBudget(
total=100,
system=50,
knowledge=50,
)
contexts = [
SystemContext(
content="A" * 500, # Large content
source="system",
priority=ContextPriority.CRITICAL.value,
),
KnowledgeContext(
content="Short",
source="docs",
relevance_score=0.9,
),
]
result = await ranker.rank(contexts, "query", budget, ensure_required=False)
# Without ensure_required, CRITICAL contexts can be excluded
# if budget doesn't allow
assert len(result.selected) + len(result.excluded) == len(contexts)
@pytest.mark.asyncio
async def test_rank_selection_stats(self) -> None:
"""Test that ranking provides useful statistics."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
contexts = [
KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
TaskContext(content="Task", source="task"),
]
result = await ranker.rank(contexts, "query", budget)
stats = result.selection_stats
assert "total_contexts" in stats
assert "selected_count" in stats
assert "excluded_count" in stats
assert "total_tokens" in stats
assert "by_type" in stats
@pytest.mark.asyncio
async def test_rank_simple(self) -> None:
"""Test simple ranking without budget per type."""
ranker = ContextRanker()
contexts = [
KnowledgeContext(
content="A",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="B",
source="docs",
relevance_score=0.7,
),
KnowledgeContext(
content="C",
source="docs",
relevance_score=0.5,
),
]
result = await ranker.rank_simple(contexts, "query", max_tokens=1000)
# Should return contexts sorted by score
assert len(result) > 0
@pytest.mark.asyncio
async def test_rank_simple_respects_max_tokens(self) -> None:
"""Test that simple ranking respects max tokens."""
ranker = ContextRanker()
# Create contexts with known token counts
contexts = [
KnowledgeContext(
content="A" * 100, # ~25 tokens
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="B" * 100,
source="docs",
relevance_score=0.8,
),
KnowledgeContext(
content="C" * 100,
source="docs",
relevance_score=0.7,
),
]
# Very small limit
result = await ranker.rank_simple(contexts, "query", max_tokens=30)
# Should only fit a limited number
assert len(result) <= len(contexts)
@pytest.mark.asyncio
async def test_rank_simple_empty(self) -> None:
"""Test simple ranking with empty list."""
ranker = ContextRanker()
result = await ranker.rank_simple([], "query", max_tokens=1000)
assert result == []
@pytest.mark.asyncio
async def test_rerank_for_diversity(self) -> None:
"""Test diversity reranking."""
ranker = ContextRanker()
# Create scored contexts from same source
contexts = [
ScoredContext(
context=KnowledgeContext(
content=f"Content {i}",
source="same-source",
relevance_score=0.9 - i * 0.1,
),
composite_score=0.9 - i * 0.1,
)
for i in range(5)
]
# Limit to 2 per source
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
assert len(result) == 5
# First 2 should be from same source, rest deferred
first_two_sources = [r.context.source for r in result[:2]]
assert all(s == "same-source" for s in first_two_sources)
@pytest.mark.asyncio
async def test_rerank_for_diversity_multiple_sources(self) -> None:
"""Test diversity reranking with multiple sources."""
ranker = ContextRanker()
contexts = [
ScoredContext(
context=KnowledgeContext(
content="Source A - 1",
source="source-a",
relevance_score=0.9,
),
composite_score=0.9,
),
ScoredContext(
context=KnowledgeContext(
content="Source A - 2",
source="source-a",
relevance_score=0.8,
),
composite_score=0.8,
),
ScoredContext(
context=KnowledgeContext(
content="Source B - 1",
source="source-b",
relevance_score=0.7,
),
composite_score=0.7,
),
ScoredContext(
context=KnowledgeContext(
content="Source A - 3",
source="source-a",
relevance_score=0.6,
),
composite_score=0.6,
),
]
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
# Should not have more than 2 from source-a in first 3
source_a_in_first_3 = sum(
1 for r in result[:3] if r.context.source == "source-a"
)
assert source_a_in_first_3 <= 2
@pytest.mark.asyncio
async def test_token_counts_set(self) -> None:
"""Test that token counts are set during ranking."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
context = KnowledgeContext(
content="Test content",
source="docs",
relevance_score=0.8,
)
# Token count should be None initially
assert context.token_count is None
await ranker.rank([context], "query", budget)
# Token count should be set after ranking
assert context.token_count is not None
assert context.token_count > 0
class TestContextRankerIntegration:
"""Integration tests for context ranking."""
@pytest.mark.asyncio
async def test_full_ranking_workflow(self) -> None:
"""Test complete ranking workflow."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
# Create diverse context types
contexts = [
SystemContext(
content="You are a helpful assistant.",
source="system",
priority=ContextPriority.CRITICAL.value,
),
TaskContext(
content="Help the user with their coding question.",
source="task",
priority=ContextPriority.HIGH.value,
),
KnowledgeContext(
content="Python is a programming language.",
source="docs/python.md",
relevance_score=0.9,
),
KnowledgeContext(
content="Java is also a programming language.",
source="docs/java.md",
relevance_score=0.4,
),
ConversationContext(
content="Hello, can you help me?",
source="chat",
role=MessageRole.USER,
),
]
result = await ranker.rank(contexts, "Python help", budget)
# System (CRITICAL) should be included
selected_types = [s.context.get_type() for s in result.selected]
assert ContextType.SYSTEM in selected_types
# Stats should be populated
assert result.selection_stats["total_contexts"] == 5
@pytest.mark.asyncio
async def test_ranking_preserves_context_order_by_score(self) -> None:
"""Test that ranking orders by score correctly."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(100000)
contexts = [
KnowledgeContext(
content="Low",
source="docs",
relevance_score=0.2,
),
KnowledgeContext(
content="High",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="Medium",
source="docs",
relevance_score=0.5,
),
]
result = await ranker.rank(contexts, "query", budget)
# Verify ordering is by score
scores = [s.composite_score for s in result.selected]
assert scores == sorted(scores, reverse=True)