test(context): add edge case tests for truncation and scoring concurrency
- Add tests for truncation edge cases, including zero tokens, short content, and marker handling. - Add concurrency tests for scoring to verify per-context locking and handling of multiple contexts.
This commit is contained in:
@@ -616,6 +616,66 @@ class TestCompositeScorer:
|
||||
scorer.set_mcp_manager(mock_mcp)
|
||||
assert scorer._relevance_scorer._mcp is mock_mcp
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_scoring_same_context_no_race(self) -> None:
|
||||
"""Test that concurrent scoring of the same context doesn't cause race conditions.
|
||||
|
||||
This verifies that the per-context locking mechanism prevents the same context
|
||||
from being scored multiple times when scored concurrently.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
scorer = CompositeScorer()
|
||||
|
||||
# Create a single context that will be scored multiple times concurrently
|
||||
context = KnowledgeContext(
|
||||
content="Test content for race condition test",
|
||||
source="docs",
|
||||
relevance_score=0.75,
|
||||
)
|
||||
|
||||
# Score the same context many times in parallel
|
||||
num_concurrent = 50
|
||||
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
|
||||
scores = await asyncio.gather(*tasks)
|
||||
|
||||
# All scores should be identical (the same context scored the same way)
|
||||
assert all(s == scores[0] for s in scores)
|
||||
|
||||
# The context should have its _score cached
|
||||
assert context._score is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_scoring_different_contexts(self) -> None:
|
||||
"""Test that concurrent scoring of different contexts works correctly.
|
||||
|
||||
Different contexts should not interfere with each other during parallel scoring.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
scorer = CompositeScorer()
|
||||
|
||||
# Create many different contexts
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content=f"Test content {i}",
|
||||
source="docs",
|
||||
relevance_score=i / 10,
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
# Score all contexts concurrently
|
||||
tasks = [scorer.score(ctx, "test query") for ctx in contexts]
|
||||
scores = await asyncio.gather(*tasks)
|
||||
|
||||
# Each context should have a different score based on its relevance
|
||||
assert len(set(scores)) > 1 # Not all the same
|
||||
|
||||
# All contexts should have cached scores
|
||||
for ctx in contexts:
|
||||
assert ctx._score is not None
|
||||
|
||||
|
||||
class TestScoredContext:
|
||||
"""Tests for ScoredContext dataclass."""
|
||||
|
||||
Reference in New Issue
Block a user