Add comprehensive scoring system with three strategies:

- RelevanceScorer: Semantic similarity with keyword fallback
- RecencyScorer: Exponential decay with type-specific half-lives
- PriorityScorer: Priority-based scoring with type bonuses

Implement CompositeScorer combining all strategies with configurable
weights (default: 50% relevance, 30% recency, 20% priority).

Add ContextRanker for budget-aware context selection with:

- Greedy selection algorithm respecting token budgets
- CRITICAL priority contexts always included
- Diversity reranking to prevent source dominance
- Comprehensive selection statistics

68 tests covering all scoring and ranking functionality.

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
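A minimal usage sketch of the flow described above, assuming only the API exercised by the tests below (`BudgetAllocator.create_budget`, `ContextRanker.rank`, `RankingResult.selected`); an illustration, not the canonical interface:

    import asyncio

    async def main() -> None:
        # Allocate an overall token budget, split per context type internally.
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        contexts = [
            SystemContext(content="You are a helpful assistant.", source="system"),
            KnowledgeContext(content="Python is a language.", source="docs", relevance_score=0.9),
        ]

        # Defaults apply; pass scorer=CompositeScorer(relevance_weight=0.8) to customize.
        ranker = ContextRanker()
        result = await ranker.rank(contexts, "Python help", budget, ensure_required=True)
        for scored in result.selected:
            print(scored.composite_score, scored.context.source)

    asyncio.run(main())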
"""Tests for context ranking module."""
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
import pytest
|
|
|
|
from app.services.context.budget import BudgetAllocator, TokenBudget
|
|
from app.services.context.prioritization import ContextRanker, RankingResult
|
|
from app.services.context.scoring import CompositeScorer, ScoredContext
|
|
from app.services.context.types import (
|
|
ContextPriority,
|
|
ContextType,
|
|
ConversationContext,
|
|
KnowledgeContext,
|
|
MessageRole,
|
|
SystemContext,
|
|
TaskContext,
|
|
)
|
|
|
|
|
|
class TestRankingResult:
|
|
"""Tests for RankingResult dataclass."""
|
|
|
|
def test_creation(self) -> None:
|
|
"""Test RankingResult creation."""
|
|
ctx = TaskContext(content="Test", source="task")
|
|
scored = ScoredContext(context=ctx, composite_score=0.8)
|
|
|
|
result = RankingResult(
|
|
selected=[scored],
|
|
excluded=[],
|
|
total_tokens=100,
|
|
selection_stats={"total": 1},
|
|
)
|
|
|
|
assert len(result.selected) == 1
|
|
assert result.total_tokens == 100
|
|
|
|
def test_selected_contexts_property(self) -> None:
|
|
"""Test selected_contexts property extracts contexts."""
|
|
ctx1 = TaskContext(content="Test 1", source="task")
|
|
ctx2 = TaskContext(content="Test 2", source="task")
|
|
|
|
scored1 = ScoredContext(context=ctx1, composite_score=0.8)
|
|
scored2 = ScoredContext(context=ctx2, composite_score=0.6)
|
|
|
|
result = RankingResult(
|
|
selected=[scored1, scored2],
|
|
excluded=[],
|
|
total_tokens=200,
|
|
)
|
|
|
|
selected = result.selected_contexts
|
|
assert len(selected) == 2
|
|
assert ctx1 in selected
|
|
assert ctx2 in selected
|
|
|
|
|
|
class TestContextRanker:
|
|
"""Tests for ContextRanker."""
|
|
|
|
def test_creation(self) -> None:
|
|
"""Test ranker creation."""
|
|
ranker = ContextRanker()
|
|
assert ranker._scorer is not None
|
|
assert ranker._calculator is not None
|
|
|
|
def test_creation_with_scorer(self) -> None:
|
|
"""Test ranker creation with custom scorer."""
        scorer = CompositeScorer(relevance_weight=0.8)
        ranker = ContextRanker(scorer=scorer)
        assert ranker._scorer is scorer

    @pytest.mark.asyncio
    async def test_rank_empty_contexts(self) -> None:
        """Test ranking empty context list."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        result = await ranker.rank([], "query", budget)

        assert len(result.selected) == 0
        assert len(result.excluded) == 0
        assert result.total_tokens == 0

    @pytest.mark.asyncio
    async def test_rank_single_context_fits(self) -> None:
        """Test ranking single context that fits budget."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        context = KnowledgeContext(
            content="Short content",
            source="docs",
            relevance_score=0.8,
        )

        result = await ranker.rank([context], "query", budget)

        assert len(result.selected) == 1
        assert len(result.excluded) == 0
        assert result.selected[0].context is context

    @pytest.mark.asyncio
    async def test_rank_respects_budget(self) -> None:
        """Test that ranking respects token budget."""
        ranker = ContextRanker()

        # Create a very small budget
        budget = TokenBudget(
            total=100,
            knowledge=50,  # Only 50 tokens for knowledge
        )

        # Create contexts that exceed budget
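        # The "~50 tokens" figures assume the calculator's rough estimate of
        # about 4 characters per token (200 chars -> ~50 tokens).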
        contexts = [
            KnowledgeContext(
                content="A" * 200,  # ~50 tokens
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B" * 200,  # ~50 tokens
                source="docs",
                relevance_score=0.8,
            ),
            KnowledgeContext(
                content="C" * 200,  # ~50 tokens
                source="docs",
                relevance_score=0.7,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Not all should fit
        assert len(result.selected) < len(contexts)
        assert len(result.excluded) > 0

    @pytest.mark.asyncio
    async def test_rank_selects_highest_scores(self) -> None:
        """Test that ranking selects highest scored contexts."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)

        # Small budget for knowledge
        budget.knowledge = 100

        contexts = [
            KnowledgeContext(
                content="Low score",
                source="docs",
                relevance_score=0.2,
            ),
            KnowledgeContext(
                content="High score",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Medium score",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Should have selected some
        if result.selected:
            # The highest scored should be selected first
            scores = [s.composite_score for s in result.selected]
            assert scores == sorted(scores, reverse=True)

    @pytest.mark.asyncio
    async def test_rank_critical_priority_always_included(self) -> None:
        """Test that CRITICAL priority contexts are always included."""
        ranker = ContextRanker()

        # Very small budget
        budget = TokenBudget(
            total=100,
            system=10,  # Very small
            knowledge=10,
        )

        contexts = [
            SystemContext(
                content="Critical system prompt that must be included",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            KnowledgeContext(
                content="Optional knowledge",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget, ensure_required=True)
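        # With ensure_required=True, the CRITICAL system prompt is expected
        # to be kept even though it overflows the 10-token system budget.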

        # Critical context should be in selected
        selected_priorities = [s.context.priority for s in result.selected]
        assert ContextPriority.CRITICAL.value in selected_priorities

    @pytest.mark.asyncio
    async def test_rank_without_ensure_required(self) -> None:
        """Test ranking without ensuring required contexts."""
        ranker = ContextRanker()

        budget = TokenBudget(
            total=100,
            system=50,
            knowledge=50,
        )

        contexts = [
            SystemContext(
                content="A" * 500,  # Large content
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            KnowledgeContext(
                content="Short",
                source="docs",
                relevance_score=0.9,
            ),
        ]

        result = await ranker.rank(
            contexts, "query", budget, ensure_required=False
        )

        # Without ensure_required, CRITICAL contexts can be excluded
        # if budget doesn't allow
        assert len(result.selected) + len(result.excluded) == len(contexts)

    @pytest.mark.asyncio
    async def test_rank_selection_stats(self) -> None:
        """Test that ranking provides useful statistics."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        contexts = [
            KnowledgeContext(
                content="Knowledge 1", source="docs", relevance_score=0.8
            ),
            KnowledgeContext(
                content="Knowledge 2", source="docs", relevance_score=0.6
            ),
            TaskContext(content="Task", source="task"),
        ]

        result = await ranker.rank(contexts, "query", budget)

        stats = result.selection_stats
        assert "total_contexts" in stats
        assert "selected_count" in stats
        assert "excluded_count" in stats
        assert "total_tokens" in stats
        assert "by_type" in stats

    @pytest.mark.asyncio
    async def test_rank_simple(self) -> None:
        """Test simple ranking without budget per type."""
        ranker = ContextRanker()

        contexts = [
            KnowledgeContext(
                content="A",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B",
                source="docs",
                relevance_score=0.7,
            ),
            KnowledgeContext(
                content="C",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank_simple(contexts, "query", max_tokens=1000)

        # Should return contexts sorted by score
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_rank_simple_respects_max_tokens(self) -> None:
        """Test that simple ranking respects max tokens."""
        ranker = ContextRanker()

        # Create contexts with known token counts
        contexts = [
            KnowledgeContext(
                content="A" * 100,  # ~25 tokens
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="B" * 100,
                source="docs",
                relevance_score=0.8,
            ),
            KnowledgeContext(
                content="C" * 100,
                source="docs",
                relevance_score=0.7,
            ),
        ]

        # Very small limit
        result = await ranker.rank_simple(contexts, "query", max_tokens=30)
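        # Exact selection counts depend on the token calculator's estimates,
        # so the assertion below is deliberately loose.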

        # Should only fit a limited number
        assert len(result) <= len(contexts)

    @pytest.mark.asyncio
    async def test_rank_simple_empty(self) -> None:
        """Test simple ranking with empty list."""
        ranker = ContextRanker()

        result = await ranker.rank_simple([], "query", max_tokens=1000)
        assert result == []

    @pytest.mark.asyncio
    async def test_rerank_for_diversity(self) -> None:
        """Test diversity reranking."""
        ranker = ContextRanker()

        # Create scored contexts from same source
        contexts = [
            ScoredContext(
                context=KnowledgeContext(
                    content=f"Content {i}",
                    source="same-source",
                    relevance_score=0.9 - i * 0.1,
                ),
                composite_score=0.9 - i * 0.1,
            )
            for i in range(5)
        ]

        # Limit to 2 per source
        result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
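        # Diversity reranking drops nothing; items beyond max_per_source are
        # deferred to later positions rather than removed.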

        assert len(result) == 5
        # First 2 should be from same source, rest deferred
        first_two_sources = [r.context.source for r in result[:2]]
        assert all(s == "same-source" for s in first_two_sources)

    @pytest.mark.asyncio
    async def test_rerank_for_diversity_multiple_sources(self) -> None:
        """Test diversity reranking with multiple sources."""
        ranker = ContextRanker()

        contexts = [
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 1",
                    source="source-a",
                    relevance_score=0.9,
                ),
                composite_score=0.9,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 2",
                    source="source-a",
                    relevance_score=0.8,
                ),
                composite_score=0.8,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source B - 1",
                    source="source-b",
                    relevance_score=0.7,
                ),
                composite_score=0.7,
            ),
            ScoredContext(
                context=KnowledgeContext(
                    content="Source A - 3",
                    source="source-a",
                    relevance_score=0.6,
                ),
                composite_score=0.6,
            ),
        ]

        result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
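        # source-a already has two higher-scored items, so its third item
        # should be deferred behind "Source B - 1".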

        # Should not have more than 2 from source-a in first 3
        source_a_in_first_3 = sum(
            1 for r in result[:3] if r.context.source == "source-a"
        )
        assert source_a_in_first_3 <= 2

    @pytest.mark.asyncio
    async def test_token_counts_set(self) -> None:
        """Test that token counts are set during ranking."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
        )

        # Token count should be None initially
        assert context.token_count is None

        await ranker.rank([context], "query", budget)

        # Token count should be set after ranking
        assert context.token_count is not None
        assert context.token_count > 0


class TestContextRankerIntegration:
    """Integration tests for context ranking."""

    @pytest.mark.asyncio
    async def test_full_ranking_workflow(self) -> None:
        """Test complete ranking workflow."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        # Create diverse context types
        contexts = [
            SystemContext(
                content="You are a helpful assistant.",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            TaskContext(
                content="Help the user with their coding question.",
                source="task",
                priority=ContextPriority.HIGH.value,
            ),
            KnowledgeContext(
                content="Python is a programming language.",
                source="docs/python.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Java is also a programming language.",
                source="docs/java.md",
                relevance_score=0.4,
            ),
            ConversationContext(
                content="Hello, can you help me?",
                source="chat",
                role=MessageRole.USER,
            ),
        ]

        result = await ranker.rank(contexts, "Python help", budget)

        # System (CRITICAL) should be included
        selected_types = [s.context.get_type() for s in result.selected]
        assert ContextType.SYSTEM in selected_types

        # Stats should be populated
        assert result.selection_stats["total_contexts"] == 5

    @pytest.mark.asyncio
    async def test_ranking_preserves_context_order_by_score(self) -> None:
        """Test that ranking orders by score correctly."""
        ranker = ContextRanker()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(100000)

        contexts = [
            KnowledgeContext(
                content="Low",
                source="docs",
                relevance_score=0.2,
            ),
            KnowledgeContext(
                content="High",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Medium",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        result = await ranker.rank(contexts, "query", budget)

        # Verify ordering is by score
        scores = [s.composite_score for s in result.selected]
        assert scores == sorted(scores, reverse=True)