test(context): add edge case tests for truncation and scoring concurrency
- Add tests for truncation edge cases, including zero tokens, short content, and marker handling.
- Add concurrency tests for scoring to verify per-context locking and handling of multiple contexts.
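The "per-context locking" referenced above lives in the scorer itself, which is outside this diff. As orientation, here is a minimal sketch of the pattern the new concurrency tests appear to verify, assuming one lazily created asyncio.Lock per context plus a double-checked _score cache; SketchScorer and _compute are illustrative names, not the real CompositeScorer API.

import asyncio
from collections import defaultdict
from typing import Any


class SketchScorer:
    """Illustrative stand-in for the real scorer; not this repository's class."""

    def __init__(self) -> None:
        # One lock per context identity, created lazily on first use.
        self._locks: defaultdict[int, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def score(self, context: Any, query: str) -> float:
        async with self._locks[id(context)]:
            # Double-checked caching: only the first concurrent caller computes;
            # later callers observe the cached value, so all results agree.
            if getattr(context, "_score", None) is None:
                context._score = await self._compute(context, query)
        return context._score

    async def _compute(self, context: Any, query: str) -> float:
        await asyncio.sleep(0)  # stand-in for real async scoring work
        return float(getattr(context, "relevance_score", 0.0))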
@@ -212,3 +212,80 @@ class TestContextCompressor:
         assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
         assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
         assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"
+
+
+class TestTruncationEdgeCases:
+    """Tests for edge cases in truncation to prevent regressions."""
+
+    @pytest.mark.asyncio
+    async def test_truncation_ratio_with_zero_original_tokens(self) -> None:
+        """Test that truncation ratio handles zero original tokens without division by zero."""
+        strategy = TruncationStrategy()
+
+        # Empty content should not raise ZeroDivisionError
+        result = await strategy.truncate_to_tokens("", max_tokens=100)
+
+        assert result.truncation_ratio == 0.0
+        assert result.original_tokens == 0
+
+    @pytest.mark.asyncio
+    async def test_truncate_end_with_zero_available_tokens(self) -> None:
+        """Test truncation when marker tokens exceed max_tokens."""
+        strategy = TruncationStrategy()
+        content = "Some content to truncate"
+
+        # max_tokens less than marker tokens should return just the marker
+        result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
+
+        # Should handle gracefully without crashing
+        assert strategy.TRUNCATION_MARKER in result.content or result.content == content
+
+    @pytest.mark.asyncio
+    async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
+        """Test truncation when content estimates to zero tokens."""
+        strategy = TruncationStrategy()
+
+        # Very short content that might estimate to 0 tokens
+        result = await strategy.truncate_to_tokens("a", max_tokens=100)
+
+        # Should not raise ZeroDivisionError
+        assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
+
+    @pytest.mark.asyncio
+    async def test_get_content_for_tokens_zero_target(self) -> None:
+        """Test _get_content_for_tokens with zero target tokens."""
+        strategy = TruncationStrategy()
+
+        result = await strategy._get_content_for_tokens(
+            content="Some content",
+            target_tokens=0,
+            from_start=True,
+        )
+
+        assert result == ""
+
+    @pytest.mark.asyncio
+    async def test_sentence_truncation_with_no_sentences(self) -> None:
+        """Test sentence truncation with content that has no sentence boundaries."""
+        strategy = TruncationStrategy()
+        content = "this is content without any sentence ending punctuation"
+
+        result = await strategy.truncate_to_tokens(
+            content, max_tokens=5, strategy="sentence"
+        )
+
+        # Should handle gracefully
+        assert result is not None
+
+    @pytest.mark.asyncio
+    async def test_middle_truncation_very_short_content(self) -> None:
+        """Test middle truncation with content shorter than preserved portions."""
+        strategy = TruncationStrategy(preserve_ratio_start=0.7)
+        content = "ab"  # Very short
+
+        result = await strategy.truncate_to_tokens(
+            content, max_tokens=1, strategy="middle"
+        )
+
+        # Should handle gracefully without negative indices
+        assert result is not None
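The truncation internals are likewise not part of this change, so as a reading aid here is a minimal sketch of the guards the edge-case tests above pin down. The ratio definition and both helper names are assumptions, not the actual TruncationStrategy code.

def truncation_ratio(original_tokens: int, final_tokens: int) -> float:
    # Hypothetical definition: empty input reports 0.0 rather than dividing
    # by zero (cf. test_truncation_ratio_with_zero_original_tokens).
    if original_tokens == 0:
        return 0.0
    return 1.0 - (final_tokens / original_tokens)


def available_content_tokens(max_tokens: int, marker_tokens: int) -> int:
    # Hypothetical guard: when the marker alone exceeds the budget, clamp to
    # zero so slicing never sees a negative count
    # (cf. test_truncate_end_with_zero_available_tokens).
    return max(0, max_tokens - marker_tokens)


assert truncation_ratio(0, 0) == 0.0        # empty content: ratio is 0.0
assert available_content_tokens(1, 3) == 0  # marker larger than the budget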
@@ -616,6 +616,66 @@ class TestCompositeScorer:
         scorer.set_mcp_manager(mock_mcp)
         assert scorer._relevance_scorer._mcp is mock_mcp
+
+    @pytest.mark.asyncio
+    async def test_concurrent_scoring_same_context_no_race(self) -> None:
+        """Test that concurrent scoring of the same context doesn't cause race conditions.
+
+        This verifies that the per-context locking mechanism prevents the same context
+        from being scored multiple times when scored concurrently.
+        """
+        import asyncio
+
+        scorer = CompositeScorer()
+
+        # Create a single context that will be scored multiple times concurrently
+        context = KnowledgeContext(
+            content="Test content for race condition test",
+            source="docs",
+            relevance_score=0.75,
+        )
+
+        # Score the same context many times in parallel
+        num_concurrent = 50
+        tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
+        scores = await asyncio.gather(*tasks)
+
+        # All scores should be identical (the same context scored the same way)
+        assert all(s == scores[0] for s in scores)
+
+        # The context should have its _score cached
+        assert context._score is not None
+
+    @pytest.mark.asyncio
+    async def test_concurrent_scoring_different_contexts(self) -> None:
+        """Test that concurrent scoring of different contexts works correctly.
+
+        Different contexts should not interfere with each other during parallel scoring.
+        """
+        import asyncio
+
+        scorer = CompositeScorer()
+
+        # Create many different contexts
+        contexts = [
+            KnowledgeContext(
+                content=f"Test content {i}",
+                source="docs",
+                relevance_score=i / 10,
+            )
+            for i in range(10)
+        ]
+
+        # Score all contexts concurrently
+        tasks = [scorer.score(ctx, "test query") for ctx in contexts]
+        scores = await asyncio.gather(*tasks)
+
+        # Each context should have a different score based on its relevance
+        assert len(set(scores)) > 1  # Not all the same
+
+        # All contexts should have cached scores
+        for ctx in contexts:
+            assert ctx._score is not None
 
 
 class TestScoredContext:
     """Tests for ScoredContext dataclass."""
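A possible tightening of the same-context test, sketched under the same assumptions: asserting that all fifty results agree would also pass if the work simply ran fifty times deterministically. Counting invocations of the underlying computation distinguishes "locked and cached" from merely "deterministic". The snippet below is self-contained and uses only illustrative names.

import asyncio


async def demo_single_computation() -> None:
    calls = 0
    lock = asyncio.Lock()
    cached = None

    async def score() -> float:
        nonlocal calls, cached
        async with lock:
            if cached is None:
                calls += 1
                await asyncio.sleep(0)  # yield, widening the race window
                cached = 0.75           # stand-in for the real composite score
        return cached

    scores = await asyncio.gather(*(score() for _ in range(50)))
    # With correct locking the computation ran exactly once and all callers agree.
    assert calls == 1 and set(scores) == {0.75}


asyncio.run(demo_single_computation())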