diff --git a/backend/tests/services/context/test_compression.py b/backend/tests/services/context/test_compression.py
index b2fde38..c37ca10 100644
--- a/backend/tests/services/context/test_compression.py
+++ b/backend/tests/services/context/test_compression.py
@@ -212,3 +212,80 @@ class TestContextCompressor:
         assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
         assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
         assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"
+
+
+class TestTruncationEdgeCases:
+    """Tests for edge cases in truncation to prevent regressions."""
+
+    @pytest.mark.asyncio
+    async def test_truncation_ratio_with_zero_original_tokens(self) -> None:
+        """Test that truncation ratio handles zero original tokens without division by zero."""
+        strategy = TruncationStrategy()
+
+        # Empty content should not raise ZeroDivisionError
+        result = await strategy.truncate_to_tokens("", max_tokens=100)
+
+        assert result.truncation_ratio == 0.0
+        assert result.original_tokens == 0
+
+    @pytest.mark.asyncio
+    async def test_truncate_end_with_zero_available_tokens(self) -> None:
+        """Test truncation when marker tokens exceed max_tokens."""
+        strategy = TruncationStrategy()
+        content = "Some content to truncate"
+
+        # max_tokens less than marker tokens should return just marker
+        result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
+
+        # Should handle gracefully without crashing
+        assert strategy.TRUNCATION_MARKER in result.content or result.content == content
+
+    @pytest.mark.asyncio
+    async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
+        """Test truncation when content estimates to zero tokens."""
+        strategy = TruncationStrategy()
+
+        # Very short content that might estimate to 0 tokens
+        result = await strategy.truncate_to_tokens("a", max_tokens=100)
+
+        # Should not raise ZeroDivisionError
+        assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
+
+    @pytest.mark.asyncio
+    async def test_get_content_for_tokens_zero_target(self) -> None:
+        """Test _get_content_for_tokens with zero target tokens."""
+        strategy = TruncationStrategy()
+
+        result = await strategy._get_content_for_tokens(
+            content="Some content",
+            target_tokens=0,
+            from_start=True,
+        )
+
+        assert result == ""
+
+    @pytest.mark.asyncio
+    async def test_sentence_truncation_with_no_sentences(self) -> None:
+        """Test sentence truncation with content that has no sentence boundaries."""
+        strategy = TruncationStrategy()
+        content = "this is content without any sentence ending punctuation"
+
+        result = await strategy.truncate_to_tokens(
+            content, max_tokens=5, strategy="sentence"
+        )
+
+        # Should handle gracefully
+        assert result is not None
+
+    @pytest.mark.asyncio
+    async def test_middle_truncation_very_short_content(self) -> None:
+        """Test middle truncation with content shorter than preserved portions."""
+        strategy = TruncationStrategy(preserve_ratio_start=0.7)
+        content = "ab"  # Very short
+
+        result = await strategy.truncate_to_tokens(
+            content, max_tokens=1, strategy="middle"
+        )
+
+        # Should handle gracefully without negative indices
+        assert result is not None
diff --git a/backend/tests/services/context/test_scoring.py b/backend/tests/services/context/test_scoring.py
index 6fea92f..1feeea6 100644
--- a/backend/tests/services/context/test_scoring.py
+++ b/backend/tests/services/context/test_scoring.py
@@ -616,6 +616,66 @@ class TestCompositeScorer:
         scorer.set_mcp_manager(mock_mcp)
         assert scorer._relevance_scorer._mcp is mock_mcp
 
+    @pytest.mark.asyncio
+    async def test_concurrent_scoring_same_context_no_race(self) -> None:
+        """Test that concurrent scoring of the same context doesn't cause race conditions.
+
+        This verifies that the per-context locking mechanism prevents the same context
+        from being scored multiple times when scored concurrently.
+        """
+        import asyncio
+
+        scorer = CompositeScorer()
+
+        # Create a single context that will be scored multiple times concurrently
+        context = KnowledgeContext(
+            content="Test content for race condition test",
+            source="docs",
+            relevance_score=0.75,
+        )
+
+        # Score the same context many times in parallel
+        num_concurrent = 50
+        tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
+        scores = await asyncio.gather(*tasks)
+
+        # All scores should be identical (the same context scored the same way)
+        assert all(s == scores[0] for s in scores)
+
+        # The context should have its _score cached
+        assert context._score is not None
+
+    @pytest.mark.asyncio
+    async def test_concurrent_scoring_different_contexts(self) -> None:
+        """Test that concurrent scoring of different contexts works correctly.
+
+        Different contexts should not interfere with each other during parallel scoring.
+        """
+        import asyncio
+
+        scorer = CompositeScorer()
+
+        # Create many different contexts
+        contexts = [
+            KnowledgeContext(
+                content=f"Test content {i}",
+                source="docs",
+                relevance_score=i / 10,
+            )
+            for i in range(10)
+        ]
+
+        # Score all contexts concurrently
+        tasks = [scorer.score(ctx, "test query") for ctx in contexts]
+        scores = await asyncio.gather(*tasks)
+
+        # Each context should have a different score based on its relevance
+        assert len(set(scores)) > 1  # Not all the same
+
+        # All contexts should have cached scores
+        for ctx in contexts:
+            assert ctx._score is not None
+
 
 class TestScoredContext:
     """Tests for ScoredContext dataclass."""