feat(context): implement context scoring and ranking (Phase 3)
Add comprehensive scoring system with three strategies: - RelevanceScorer: Semantic similarity with keyword fallback - RecencyScorer: Exponential decay with type-specific half-lives - PriorityScorer: Priority-based scoring with type bonuses Implement CompositeScorer combining all strategies with configurable weights (default: 50% relevance, 30% recency, 20% priority). Add ContextRanker for budget-aware context selection with: - Greedy selection algorithm respecting token budgets - CRITICAL priority contexts always included - Diversity reranking to prevent source dominance - Comprehensive selection statistics 68 tests covering all scoring and ranking functionality. Part of #61 - Context Management Engine 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
507
backend/tests/services/context/test_ranker.py
Normal file
507
backend/tests/services/context/test_ranker.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""Tests for context ranking module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.prioritization import ContextRanker, RankingResult
|
||||
from app.services.context.scoring import CompositeScorer, ScoredContext
|
||||
from app.services.context.types import (
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestRankingResult:
|
||||
"""Tests for RankingResult dataclass."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test RankingResult creation."""
|
||||
ctx = TaskContext(content="Test", source="task")
|
||||
scored = ScoredContext(context=ctx, composite_score=0.8)
|
||||
|
||||
result = RankingResult(
|
||||
selected=[scored],
|
||||
excluded=[],
|
||||
total_tokens=100,
|
||||
selection_stats={"total": 1},
|
||||
)
|
||||
|
||||
assert len(result.selected) == 1
|
||||
assert result.total_tokens == 100
|
||||
|
||||
def test_selected_contexts_property(self) -> None:
|
||||
"""Test selected_contexts property extracts contexts."""
|
||||
ctx1 = TaskContext(content="Test 1", source="task")
|
||||
ctx2 = TaskContext(content="Test 2", source="task")
|
||||
|
||||
scored1 = ScoredContext(context=ctx1, composite_score=0.8)
|
||||
scored2 = ScoredContext(context=ctx2, composite_score=0.6)
|
||||
|
||||
result = RankingResult(
|
||||
selected=[scored1, scored2],
|
||||
excluded=[],
|
||||
total_tokens=200,
|
||||
)
|
||||
|
||||
selected = result.selected_contexts
|
||||
assert len(selected) == 2
|
||||
assert ctx1 in selected
|
||||
assert ctx2 in selected
|
||||
|
||||
|
||||
class TestContextRanker:
|
||||
"""Tests for ContextRanker."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test ranker creation."""
|
||||
ranker = ContextRanker()
|
||||
assert ranker._scorer is not None
|
||||
assert ranker._calculator is not None
|
||||
|
||||
def test_creation_with_scorer(self) -> None:
|
||||
"""Test ranker creation with custom scorer."""
|
||||
scorer = CompositeScorer(relevance_weight=0.8)
|
||||
ranker = ContextRanker(scorer=scorer)
|
||||
assert ranker._scorer is scorer
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_empty_contexts(self) -> None:
|
||||
"""Test ranking empty context list."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
result = await ranker.rank([], "query", budget)
|
||||
|
||||
assert len(result.selected) == 0
|
||||
assert len(result.excluded) == 0
|
||||
assert result.total_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_single_context_fits(self) -> None:
|
||||
"""Test ranking single context that fits budget."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Short content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
)
|
||||
|
||||
result = await ranker.rank([context], "query", budget)
|
||||
|
||||
assert len(result.selected) == 1
|
||||
assert len(result.excluded) == 0
|
||||
assert result.selected[0].context is context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_respects_budget(self) -> None:
|
||||
"""Test that ranking respects token budget."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Create a very small budget
|
||||
budget = TokenBudget(
|
||||
total=100,
|
||||
knowledge=50, # Only 50 tokens for knowledge
|
||||
)
|
||||
|
||||
# Create contexts that exceed budget
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="A" * 200, # ~50 tokens
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="B" * 200, # ~50 tokens
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="C" * 200, # ~50 tokens
|
||||
source="docs",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
# Not all should fit
|
||||
assert len(result.selected) < len(contexts)
|
||||
assert len(result.excluded) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_selects_highest_scores(self) -> None:
|
||||
"""Test that ranking selects highest scored contexts."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(1000)
|
||||
|
||||
# Small budget for knowledge
|
||||
budget.knowledge = 100
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low score",
|
||||
source="docs",
|
||||
relevance_score=0.2,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High score",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium score",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
# Should have selected some
|
||||
if result.selected:
|
||||
# The highest scored should be selected first
|
||||
scores = [s.composite_score for s in result.selected]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_critical_priority_always_included(self) -> None:
|
||||
"""Test that CRITICAL priority contexts are always included."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Very small budget
|
||||
budget = TokenBudget(
|
||||
total=100,
|
||||
system=10, # Very small
|
||||
knowledge=10,
|
||||
)
|
||||
|
||||
contexts = [
|
||||
SystemContext(
|
||||
content="Critical system prompt that must be included",
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Optional knowledge",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget, ensure_required=True)
|
||||
|
||||
# Critical context should be in selected
|
||||
selected_priorities = [s.context.priority for s in result.selected]
|
||||
assert ContextPriority.CRITICAL.value in selected_priorities
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_without_ensure_required(self) -> None:
|
||||
"""Test ranking without ensuring required contexts."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
budget = TokenBudget(
|
||||
total=100,
|
||||
system=50,
|
||||
knowledge=50,
|
||||
)
|
||||
|
||||
contexts = [
|
||||
SystemContext(
|
||||
content="A" * 500, # Large content
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Short",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(
|
||||
contexts, "query", budget, ensure_required=False
|
||||
)
|
||||
|
||||
# Without ensure_required, CRITICAL contexts can be excluded
|
||||
# if budget doesn't allow
|
||||
assert len(result.selected) + len(result.excluded) == len(contexts)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_selection_stats(self) -> None:
|
||||
"""Test that ranking provides useful statistics."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Knowledge 1", source="docs", relevance_score=0.8
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Knowledge 2", source="docs", relevance_score=0.6
|
||||
),
|
||||
TaskContext(content="Task", source="task"),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
stats = result.selection_stats
|
||||
assert "total_contexts" in stats
|
||||
assert "selected_count" in stats
|
||||
assert "excluded_count" in stats
|
||||
assert "total_tokens" in stats
|
||||
assert "by_type" in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_simple(self) -> None:
|
||||
"""Test simple ranking without budget per type."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="A",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="B",
|
||||
source="docs",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="C",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank_simple(contexts, "query", max_tokens=1000)
|
||||
|
||||
# Should return contexts sorted by score
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_simple_respects_max_tokens(self) -> None:
|
||||
"""Test that simple ranking respects max tokens."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Create contexts with known token counts
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="A" * 100, # ~25 tokens
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="B" * 100,
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="C" * 100,
|
||||
source="docs",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
]
|
||||
|
||||
# Very small limit
|
||||
result = await ranker.rank_simple(contexts, "query", max_tokens=30)
|
||||
|
||||
# Should only fit a limited number
|
||||
assert len(result) <= len(contexts)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_simple_empty(self) -> None:
|
||||
"""Test simple ranking with empty list."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
result = await ranker.rank_simple([], "query", max_tokens=1000)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_for_diversity(self) -> None:
|
||||
"""Test diversity reranking."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
# Create scored contexts from same source
|
||||
contexts = [
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content=f"Content {i}",
|
||||
source="same-source",
|
||||
relevance_score=0.9 - i * 0.1,
|
||||
),
|
||||
composite_score=0.9 - i * 0.1,
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Limit to 2 per source
|
||||
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
|
||||
|
||||
assert len(result) == 5
|
||||
# First 2 should be from same source, rest deferred
|
||||
first_two_sources = [r.context.source for r in result[:2]]
|
||||
assert all(s == "same-source" for s in first_two_sources)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_for_diversity_multiple_sources(self) -> None:
|
||||
"""Test diversity reranking with multiple sources."""
|
||||
ranker = ContextRanker()
|
||||
|
||||
contexts = [
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source A - 1",
|
||||
source="source-a",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
composite_score=0.9,
|
||||
),
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source A - 2",
|
||||
source="source-a",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
composite_score=0.8,
|
||||
),
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source B - 1",
|
||||
source="source-b",
|
||||
relevance_score=0.7,
|
||||
),
|
||||
composite_score=0.7,
|
||||
),
|
||||
ScoredContext(
|
||||
context=KnowledgeContext(
|
||||
content="Source A - 3",
|
||||
source="source-a",
|
||||
relevance_score=0.6,
|
||||
),
|
||||
composite_score=0.6,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
|
||||
|
||||
# Should not have more than 2 from source-a in first 3
|
||||
source_a_in_first_3 = sum(
|
||||
1 for r in result[:3] if r.context.source == "source-a"
|
||||
)
|
||||
assert source_a_in_first_3 <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counts_set(self) -> None:
|
||||
"""Test that token counts are set during ranking."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
)
|
||||
|
||||
# Token count should be None initially
|
||||
assert context.token_count is None
|
||||
|
||||
await ranker.rank([context], "query", budget)
|
||||
|
||||
# Token count should be set after ranking
|
||||
assert context.token_count is not None
|
||||
assert context.token_count > 0
|
||||
|
||||
|
||||
class TestContextRankerIntegration:
|
||||
"""Integration tests for context ranking."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_ranking_workflow(self) -> None:
|
||||
"""Test complete ranking workflow."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
# Create diverse context types
|
||||
contexts = [
|
||||
SystemContext(
|
||||
content="You are a helpful assistant.",
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
),
|
||||
TaskContext(
|
||||
content="Help the user with their coding question.",
|
||||
source="task",
|
||||
priority=ContextPriority.HIGH.value,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Python is a programming language.",
|
||||
source="docs/python.md",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Java is also a programming language.",
|
||||
source="docs/java.md",
|
||||
relevance_score=0.4,
|
||||
),
|
||||
ConversationContext(
|
||||
content="Hello, can you help me?",
|
||||
source="chat",
|
||||
role=MessageRole.USER,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "Python help", budget)
|
||||
|
||||
# System (CRITICAL) should be included
|
||||
selected_types = [s.context.get_type() for s in result.selected]
|
||||
assert ContextType.SYSTEM in selected_types
|
||||
|
||||
# Stats should be populated
|
||||
assert result.selection_stats["total_contexts"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ranking_preserves_context_order_by_score(self) -> None:
|
||||
"""Test that ranking orders by score correctly."""
|
||||
ranker = ContextRanker()
|
||||
allocator = BudgetAllocator()
|
||||
budget = allocator.create_budget(100000)
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low",
|
||||
source="docs",
|
||||
relevance_score=0.2,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(contexts, "query", budget)
|
||||
|
||||
# Verify ordering is by score
|
||||
scores = [s.composite_score for s in result.selected]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
712
backend/tests/services/context/test_scoring.py
Normal file
712
backend/tests/services/context/test_scoring.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""Tests for context scoring module."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
from app.services.context.types import (
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestRelevanceScorer:
|
||||
"""Tests for RelevanceScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation."""
|
||||
scorer = RelevanceScorer()
|
||||
assert scorer.weight == 1.0
|
||||
|
||||
def test_creation_with_weight(self) -> None:
|
||||
"""Test scorer creation with custom weight."""
|
||||
scorer = RelevanceScorer(weight=0.5)
|
||||
assert scorer.weight == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_precomputed_relevance(self) -> None:
|
||||
"""Test scoring with pre-computed relevance score."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
# KnowledgeContext with pre-computed score
|
||||
context = KnowledgeContext(
|
||||
content="Test content about Python",
|
||||
source="docs/python.md",
|
||||
relevance_score=0.85,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "Python programming")
|
||||
assert score == 0.85
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_metadata_score(self) -> None:
|
||||
"""Test scoring with metadata-provided score."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
context = SystemContext(
|
||||
content="System prompt",
|
||||
source="system",
|
||||
metadata={"relevance_score": 0.9},
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "anything")
|
||||
assert score == 0.9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_fallback_to_keyword_matching(self) -> None:
|
||||
"""Test fallback to keyword matching when no score available."""
|
||||
scorer = RelevanceScorer(keyword_fallback_weight=0.5)
|
||||
|
||||
context = TaskContext(
|
||||
content="Implement authentication with JWT tokens",
|
||||
source="task",
|
||||
)
|
||||
|
||||
# Query has matching keywords
|
||||
score = await scorer.score(context, "JWT authentication")
|
||||
assert score > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_matching_no_overlap(self) -> None:
|
||||
"""Test keyword matching with no query overlap."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
context = TaskContext(
|
||||
content="Implement database migration",
|
||||
source="task",
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "xyz abc 123")
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_matching_full_overlap(self) -> None:
|
||||
"""Test keyword matching with high overlap."""
|
||||
scorer = RelevanceScorer(keyword_fallback_weight=1.0)
|
||||
|
||||
context = TaskContext(
|
||||
content="python programming language",
|
||||
source="task",
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "python programming")
|
||||
# Should have high score due to keyword overlap
|
||||
assert score > 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_mcp_success(self) -> None:
|
||||
"""Test scoring with MCP semantic similarity."""
|
||||
mock_mcp = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.success = True
|
||||
mock_result.data = {"similarity": 0.75}
|
||||
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
scorer = RelevanceScorer(mcp_manager=mock_mcp)
|
||||
|
||||
context = TaskContext(
|
||||
content="Test content",
|
||||
source="task",
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "test query")
|
||||
assert score == 0.75
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_mcp_failure_fallback(self) -> None:
|
||||
"""Test fallback when MCP fails."""
|
||||
mock_mcp = MagicMock()
|
||||
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
scorer = RelevanceScorer(mcp_manager=mock_mcp, keyword_fallback_weight=0.5)
|
||||
|
||||
context = TaskContext(
|
||||
content="Python programming code",
|
||||
source="task",
|
||||
)
|
||||
|
||||
# Should fall back to keyword matching
|
||||
score = await scorer.score(context, "Python code")
|
||||
assert score > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
"""Test batch scoring."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Python", source="1", relevance_score=0.8
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Java", source="2", relevance_score=0.6
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Go", source="3", relevance_score=0.9
|
||||
),
|
||||
]
|
||||
|
||||
scores = await scorer.score_batch(contexts, "test")
|
||||
assert len(scores) == 3
|
||||
assert scores[0] == 0.8
|
||||
assert scores[1] == 0.6
|
||||
assert scores[2] == 0.9
|
||||
|
||||
def test_set_mcp_manager(self) -> None:
|
||||
"""Test setting MCP manager."""
|
||||
scorer = RelevanceScorer()
|
||||
assert scorer._mcp is None
|
||||
|
||||
mock_mcp = MagicMock()
|
||||
scorer.set_mcp_manager(mock_mcp)
|
||||
assert scorer._mcp is mock_mcp
|
||||
|
||||
|
||||
class TestRecencyScorer:
|
||||
"""Tests for RecencyScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation."""
|
||||
scorer = RecencyScorer()
|
||||
assert scorer.weight == 1.0
|
||||
assert scorer._half_life_hours == 24.0
|
||||
|
||||
def test_creation_with_custom_half_life(self) -> None:
|
||||
"""Test scorer creation with custom half-life."""
|
||||
scorer = RecencyScorer(half_life_hours=12.0)
|
||||
assert scorer._half_life_hours == 12.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_recent_context(self) -> None:
|
||||
"""Test scoring a very recent context."""
|
||||
scorer = RecencyScorer(half_life_hours=24.0)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
context = TaskContext(
|
||||
content="Recent task",
|
||||
source="task",
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query", reference_time=now)
|
||||
# Very recent should have score near 1.0
|
||||
assert score > 0.99
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_at_half_life(self) -> None:
|
||||
"""Test scoring at exactly half-life age."""
|
||||
scorer = RecencyScorer(half_life_hours=24.0)
|
||||
now = datetime.now(UTC)
|
||||
half_life_ago = now - timedelta(hours=24)
|
||||
|
||||
context = TaskContext(
|
||||
content="Day old task",
|
||||
source="task",
|
||||
timestamp=half_life_ago,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query", reference_time=now)
|
||||
# At half-life, score should be ~0.5
|
||||
assert 0.49 <= score <= 0.51
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_old_context(self) -> None:
|
||||
"""Test scoring a very old context."""
|
||||
scorer = RecencyScorer(half_life_hours=24.0)
|
||||
now = datetime.now(UTC)
|
||||
week_ago = now - timedelta(days=7)
|
||||
|
||||
context = TaskContext(
|
||||
content="Week old task",
|
||||
source="task",
|
||||
timestamp=week_ago,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query", reference_time=now)
|
||||
# 7 days with 24h half-life = very low score
|
||||
assert score < 0.01
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_specific_half_lives(self) -> None:
|
||||
"""Test that different context types have different half-lives."""
|
||||
scorer = RecencyScorer()
|
||||
now = datetime.now(UTC)
|
||||
one_hour_ago = now - timedelta(hours=1)
|
||||
|
||||
# Conversation has 1 hour half-life by default
|
||||
conv_context = ConversationContext(
|
||||
content="Hello",
|
||||
source="chat",
|
||||
role=MessageRole.USER,
|
||||
timestamp=one_hour_ago,
|
||||
)
|
||||
|
||||
# Knowledge has 168 hour (1 week) half-life by default
|
||||
knowledge_context = KnowledgeContext(
|
||||
content="Documentation",
|
||||
source="docs",
|
||||
timestamp=one_hour_ago,
|
||||
)
|
||||
|
||||
conv_score = await scorer.score(conv_context, "query", reference_time=now)
|
||||
knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
|
||||
|
||||
# Conversation should decay much faster
|
||||
assert conv_score < knowledge_score
|
||||
|
||||
def test_get_half_life(self) -> None:
|
||||
"""Test getting half-life for context type."""
|
||||
scorer = RecencyScorer()
|
||||
|
||||
assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0
|
||||
assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0
|
||||
assert scorer.get_half_life(ContextType.SYSTEM) == 720.0
|
||||
|
||||
def test_set_half_life(self) -> None:
|
||||
"""Test setting custom half-life."""
|
||||
scorer = RecencyScorer()
|
||||
|
||||
scorer.set_half_life(ContextType.TASK, 48.0)
|
||||
assert scorer.get_half_life(ContextType.TASK) == 48.0
|
||||
|
||||
def test_set_half_life_invalid(self) -> None:
|
||||
"""Test setting invalid half-life."""
|
||||
scorer = RecencyScorer()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_half_life(ContextType.TASK, 0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_half_life(ContextType.TASK, -1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
"""Test batch scoring."""
|
||||
scorer = RecencyScorer()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
contexts = [
|
||||
TaskContext(content="1", source="t", timestamp=now),
|
||||
TaskContext(
|
||||
content="2", source="t", timestamp=now - timedelta(hours=24)
|
||||
),
|
||||
TaskContext(
|
||||
content="3", source="t", timestamp=now - timedelta(hours=48)
|
||||
),
|
||||
]
|
||||
|
||||
scores = await scorer.score_batch(contexts, "query", reference_time=now)
|
||||
assert len(scores) == 3
|
||||
# Scores should be in descending order (more recent = higher)
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
|
||||
class TestPriorityScorer:
|
||||
"""Tests for PriorityScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation."""
|
||||
scorer = PriorityScorer()
|
||||
assert scorer.weight == 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_critical_priority(self) -> None:
|
||||
"""Test scoring CRITICAL priority context."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
context = SystemContext(
|
||||
content="Critical system prompt",
|
||||
source="system",
|
||||
priority=ContextPriority.CRITICAL.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query")
|
||||
# CRITICAL (100) + type bonus should be > 1.0, normalized to 1.0
|
||||
assert score == 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_normal_priority(self) -> None:
|
||||
"""Test scoring NORMAL priority context."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
context = TaskContext(
|
||||
content="Normal task",
|
||||
source="task",
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query")
|
||||
# NORMAL (50) = 0.5, plus TASK bonus (0.15) = 0.65
|
||||
assert 0.6 <= score <= 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_low_priority(self) -> None:
|
||||
"""Test scoring LOW priority context."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Low priority knowledge",
|
||||
source="docs",
|
||||
priority=ContextPriority.LOW.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "query")
|
||||
# LOW (20) = 0.2, no bonus for KNOWLEDGE
|
||||
assert 0.15 <= score <= 0.25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_bonuses(self) -> None:
|
||||
"""Test type-specific priority bonuses."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
# All with same base priority
|
||||
system_ctx = SystemContext(
|
||||
content="System",
|
||||
source="system",
|
||||
priority=50,
|
||||
)
|
||||
task_ctx = TaskContext(
|
||||
content="Task",
|
||||
source="task",
|
||||
priority=50,
|
||||
)
|
||||
knowledge_ctx = KnowledgeContext(
|
||||
content="Knowledge",
|
||||
source="docs",
|
||||
priority=50,
|
||||
)
|
||||
|
||||
system_score = await scorer.score(system_ctx, "query")
|
||||
task_score = await scorer.score(task_ctx, "query")
|
||||
knowledge_score = await scorer.score(knowledge_ctx, "query")
|
||||
|
||||
# System has highest bonus (0.2), task next (0.15), knowledge has none
|
||||
assert system_score > task_score > knowledge_score
|
||||
|
||||
def test_get_type_bonus(self) -> None:
|
||||
"""Test getting type bonus."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2
|
||||
assert scorer.get_type_bonus(ContextType.TASK) == 0.15
|
||||
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0
|
||||
|
||||
def test_set_type_bonus(self) -> None:
|
||||
"""Test setting custom type bonus."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1)
|
||||
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1
|
||||
|
||||
def test_set_type_bonus_invalid(self) -> None:
|
||||
"""Test setting invalid type bonus."""
|
||||
scorer = PriorityScorer()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_type_bonus(ContextType.KNOWLEDGE, 1.5)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.set_type_bonus(ContextType.KNOWLEDGE, -0.1)
|
||||
|
||||
|
||||
class TestCompositeScorer:
|
||||
"""Tests for CompositeScorer."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test scorer creation with default weights."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
weights = scorer.weights
|
||||
assert weights["relevance"] == 0.5
|
||||
assert weights["recency"] == 0.3
|
||||
assert weights["priority"] == 0.2
|
||||
|
||||
def test_creation_with_custom_weights(self) -> None:
|
||||
"""Test scorer creation with custom weights."""
|
||||
scorer = CompositeScorer(
|
||||
relevance_weight=0.6,
|
||||
recency_weight=0.2,
|
||||
priority_weight=0.2,
|
||||
)
|
||||
|
||||
weights = scorer.weights
|
||||
assert weights["relevance"] == 0.6
|
||||
assert weights["recency"] == 0.2
|
||||
assert weights["priority"] == 0.2
|
||||
|
||||
def test_update_weights(self) -> None:
|
||||
"""Test updating weights."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1)
|
||||
|
||||
weights = scorer.weights
|
||||
assert weights["relevance"] == 0.7
|
||||
assert weights["recency"] == 0.2
|
||||
assert weights["priority"] == 0.1
|
||||
|
||||
def test_update_weights_partial(self) -> None:
|
||||
"""Test partially updating weights."""
|
||||
scorer = CompositeScorer()
|
||||
original_recency = scorer.weights["recency"]
|
||||
|
||||
scorer.update_weights(relevance=0.8)
|
||||
|
||||
assert scorer.weights["relevance"] == 0.8
|
||||
assert scorer.weights["recency"] == original_recency
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_basic(self) -> None:
|
||||
"""Test basic composite scoring."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
timestamp=datetime.now(UTC),
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
score = await scorer.score(context, "test query")
|
||||
assert 0.0 <= score <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_details(self) -> None:
|
||||
"""Test scoring with detailed breakdown."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
timestamp=datetime.now(UTC),
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
|
||||
scored = await scorer.score_with_details(context, "test query")
|
||||
|
||||
assert isinstance(scored, ScoredContext)
|
||||
assert scored.context is context
|
||||
assert 0.0 <= scored.composite_score <= 1.0
|
||||
assert scored.relevance_score == 0.8
|
||||
assert scored.recency_score > 0.9 # Very recent
|
||||
assert scored.priority_score > 0.5 # HIGH priority
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_cached_on_context(self) -> None:
|
||||
"""Test that score is cached on the context."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
)
|
||||
|
||||
# First scoring
|
||||
await scorer.score(context, "query")
|
||||
assert context._score is not None
|
||||
|
||||
# Second scoring should use cached value
|
||||
context._score = 0.999 # Set to a known value
|
||||
score2 = await scorer.score(context, "query")
|
||||
assert score2 == 0.999
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
"""Test batch scoring."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="High relevance",
|
||||
source="docs",
|
||||
relevance_score=0.9,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Low relevance",
|
||||
source="docs",
|
||||
relevance_score=0.2,
|
||||
),
|
||||
]
|
||||
|
||||
scored = await scorer.score_batch(contexts, "query")
|
||||
assert len(scored) == 2
|
||||
assert scored[0].relevance_score > scored[1].relevance_score
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank(self) -> None:
|
||||
"""Test ranking contexts."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low", source="docs", relevance_score=0.2
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High", source="docs", relevance_score=0.9
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium", source="docs", relevance_score=0.5
|
||||
),
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query")
|
||||
|
||||
# Should be sorted by score (highest first)
|
||||
assert len(ranked) == 3
|
||||
assert ranked[0].relevance_score == 0.9
|
||||
assert ranked[1].relevance_score == 0.5
|
||||
assert ranked[2].relevance_score == 0.2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_with_limit(self) -> None:
|
||||
"""Test ranking with limit."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content=str(i), source="docs", relevance_score=i / 10
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query", limit=3)
|
||||
assert len(ranked) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_with_min_score(self) -> None:
|
||||
"""Test ranking with minimum score threshold."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low", source="docs", relevance_score=0.1
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High", source="docs", relevance_score=0.9
|
||||
),
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query", min_score=0.5)
|
||||
|
||||
# Only the high relevance context should pass the threshold
|
||||
assert len(ranked) <= 2 # Could be 1 if min_score filters
|
||||
|
||||
def test_set_mcp_manager(self) -> None:
|
||||
"""Test setting MCP manager."""
|
||||
scorer = CompositeScorer()
|
||||
mock_mcp = MagicMock()
|
||||
|
||||
scorer.set_mcp_manager(mock_mcp)
|
||||
assert scorer._relevance_scorer._mcp is mock_mcp
|
||||
|
||||
|
||||
class TestScoredContext:
|
||||
"""Tests for ScoredContext dataclass."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test ScoredContext creation."""
|
||||
context = TaskContext(content="Test", source="task")
|
||||
scored = ScoredContext(
|
||||
context=context,
|
||||
composite_score=0.75,
|
||||
relevance_score=0.8,
|
||||
recency_score=0.7,
|
||||
priority_score=0.5,
|
||||
)
|
||||
|
||||
assert scored.context is context
|
||||
assert scored.composite_score == 0.75
|
||||
|
||||
def test_comparison_operators(self) -> None:
|
||||
"""Test comparison operators for sorting."""
|
||||
ctx1 = TaskContext(content="1", source="task")
|
||||
ctx2 = TaskContext(content="2", source="task")
|
||||
|
||||
scored1 = ScoredContext(context=ctx1, composite_score=0.5)
|
||||
scored2 = ScoredContext(context=ctx2, composite_score=0.8)
|
||||
|
||||
assert scored1 < scored2
|
||||
assert scored2 > scored1
|
||||
|
||||
def test_sorting(self) -> None:
|
||||
"""Test sorting scored contexts."""
|
||||
contexts = [
|
||||
ScoredContext(
|
||||
context=TaskContext(content="Low", source="task"),
|
||||
composite_score=0.3,
|
||||
),
|
||||
ScoredContext(
|
||||
context=TaskContext(content="High", source="task"),
|
||||
composite_score=0.9,
|
||||
),
|
||||
ScoredContext(
|
||||
context=TaskContext(content="Medium", source="task"),
|
||||
composite_score=0.6,
|
||||
),
|
||||
]
|
||||
|
||||
sorted_contexts = sorted(contexts, reverse=True)
|
||||
|
||||
assert sorted_contexts[0].composite_score == 0.9
|
||||
assert sorted_contexts[1].composite_score == 0.6
|
||||
assert sorted_contexts[2].composite_score == 0.3
|
||||
|
||||
|
||||
class TestBaseScorer:
|
||||
"""Tests for BaseScorer abstract class."""
|
||||
|
||||
def test_weight_property(self) -> None:
|
||||
"""Test weight property."""
|
||||
# Use a concrete implementation
|
||||
scorer = RelevanceScorer(weight=0.7)
|
||||
assert scorer.weight == 0.7
|
||||
|
||||
def test_weight_setter_valid(self) -> None:
|
||||
"""Test weight setter with valid values."""
|
||||
scorer = RelevanceScorer()
|
||||
scorer.weight = 0.5
|
||||
assert scorer.weight == 0.5
|
||||
|
||||
def test_weight_setter_invalid(self) -> None:
|
||||
"""Test weight setter with invalid values."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.weight = -0.1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
scorer.weight = 1.5
|
||||
|
||||
def test_normalize_score(self) -> None:
|
||||
"""Test score normalization."""
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
# Normal range
|
||||
assert scorer.normalize_score(0.5) == 0.5
|
||||
|
||||
# Below 0
|
||||
assert scorer.normalize_score(-0.5) == 0.0
|
||||
|
||||
# Above 1
|
||||
assert scorer.normalize_score(1.5) == 1.0
|
||||
|
||||
# Boundaries
|
||||
assert scorer.normalize_score(0.0) == 0.0
|
||||
assert scorer.normalize_score(1.0) == 1.0
|
||||
@@ -286,9 +286,20 @@ class TestTaskContext:
|
||||
assert ctx.title == "Login Feature"
|
||||
assert ctx.get_type() == ContextType.TASK
|
||||
|
||||
def test_default_high_priority(self) -> None:
|
||||
"""Test that task context defaults to high priority."""
|
||||
def test_default_normal_priority(self) -> None:
|
||||
"""Test that task context uses NORMAL priority from base class."""
|
||||
ctx = TaskContext(content="Test", source="test")
|
||||
# TaskContext inherits NORMAL priority from BaseContext
|
||||
# Use TaskContext.create() for default HIGH priority behavior
|
||||
assert ctx.priority == ContextPriority.NORMAL.value
|
||||
|
||||
def test_explicit_high_priority(self) -> None:
|
||||
"""Test setting explicit HIGH priority."""
|
||||
ctx = TaskContext(
|
||||
content="Test",
|
||||
source="test",
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
assert ctx.priority == ContextPriority.HIGH.value
|
||||
|
||||
def test_create_factory(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user