- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
295 lines
9.8 KiB
Python
295 lines
9.8 KiB
Python
"""Tests for context compression module."""
|
|
|
|
import pytest
|
|
|
|
from app.services.context.budget import BudgetAllocator
|
|
from app.services.context.compression import (
|
|
ContextCompressor,
|
|
TruncationResult,
|
|
TruncationStrategy,
|
|
)
|
|
from app.services.context.types import (
|
|
ContextType,
|
|
KnowledgeContext,
|
|
TaskContext,
|
|
)
|
|
|
|
|
|
class TestTruncationResult:
    """Exercise the TruncationResult value object."""

    def test_creation(self) -> None:
        """Values handed to the constructor are readable back unchanged."""
        outcome = TruncationResult(
            content="Truncated content",
            original_tokens=100,
            truncated_tokens=50,
            truncation_ratio=0.5,
            truncated=True,
        )

        # Token counts first, then the truncation flag and ratio.
        assert outcome.original_tokens == 100
        assert outcome.truncated_tokens == 50
        assert outcome.truncated is True
        assert outcome.truncation_ratio == 0.5

    def test_tokens_saved(self) -> None:
        """tokens_saved reports 60 for 100 original and 40 truncated tokens."""
        outcome = TruncationResult(
            content="Test",
            original_tokens=100,
            truncated_tokens=40,
            truncation_ratio=0.6,
            truncated=True,
        )

        assert outcome.tokens_saved == 60

    def test_no_truncation(self) -> None:
        """An untruncated result saves nothing and keeps truncated False."""
        outcome = TruncationResult(
            content="Full content",
            original_tokens=50,
            truncated_tokens=50,
            truncation_ratio=0.0,
            truncated=False,
        )

        assert outcome.tokens_saved == 0
        assert outcome.truncated is False
|
|
|
|
|
|
class TestTruncationStrategy:
    """Cover TruncationStrategy construction and each truncation mode."""

    def test_creation(self) -> None:
        """A default-constructed strategy carries the expected defaults."""
        truncator = TruncationStrategy()

        assert truncator._preserve_ratio_start == 0.7
        assert truncator._min_content_length == 100

    def test_creation_with_params(self) -> None:
        """Constructor keyword arguments override the defaults."""
        truncator = TruncationStrategy(
            min_content_length=50,
            preserve_ratio_start=0.5,
        )

        assert truncator._preserve_ratio_start == 0.5
        assert truncator._min_content_length == 50

    @pytest.mark.asyncio
    async def test_truncate_empty_content(self) -> None:
        """An empty string comes back empty, with zero counts, untruncated."""
        truncator = TruncationStrategy()

        outcome = await truncator.truncate_to_tokens("", max_tokens=100)

        assert outcome.original_tokens == 0
        assert outcome.truncated_tokens == 0
        assert outcome.content == ""
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_content_within_limit(self) -> None:
        """Content already under the token limit is returned as-is."""
        text = "Short content"
        truncator = TruncationStrategy()

        outcome = await truncator.truncate_to_tokens(text, max_tokens=100)

        assert outcome.content == text
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_end_strategy(self) -> None:
        """The "end" mode shortens long content and inserts the marker."""
        truncator = TruncationStrategy()
        text = "A" * 1000  # far longer than the 50-token limit used below

        outcome = await truncator.truncate_to_tokens(
            text,
            max_tokens=50,
            strategy="end",
        )

        assert outcome.truncated is True
        assert len(outcome.content) < len(text)
        assert truncator.truncation_marker in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_middle_strategy(self) -> None:
        """The "middle" mode truncates and inserts the marker."""
        truncator = TruncationStrategy(preserve_ratio_start=0.6)
        text = "START " + "A" * 500 + " END"

        outcome = await truncator.truncate_to_tokens(
            text,
            max_tokens=50,
            strategy="middle",
        )

        assert outcome.truncated is True
        assert truncator.truncation_marker in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_sentence_strategy(self) -> None:
        """The "sentence" mode truncates multi-sentence content."""
        truncator = TruncationStrategy()
        text = "First sentence. Second sentence. Third sentence. Fourth sentence."

        outcome = await truncator.truncate_to_tokens(
            text,
            max_tokens=10,
            strategy="sentence",
        )

        assert outcome.truncated is True
        # Accept either a cut on a sentence boundary or an explicit marker.
        assert (
            outcome.content.endswith(".")
            or truncator.truncation_marker in outcome.content
        )
|
|
|
|
|
|
class TestContextCompressor:
    """Tests for ContextCompressor."""

    def test_creation(self) -> None:
        """A freshly built compressor has a truncation helper attached."""
        compressor = ContextCompressor()

        assert compressor._truncation is not None

    @pytest.mark.asyncio
    async def test_compress_context_within_limit(self) -> None:
        """A context already under the limit is returned with content intact."""
        compressor = ContextCompressor()

        context = KnowledgeContext(
            content="Short content",
            source="docs",
        )
        # Set the token count explicitly so the test does not depend on
        # however the context would otherwise obtain it.
        context.token_count = 5

        result = await compressor.compress_context(context, max_tokens=100)

        # The content is unchanged and the metadata is not flagged as truncated.
        assert result.content == "Short content"
        assert result.metadata.get("truncated") is not True

    @pytest.mark.asyncio
    async def test_compress_context_exceeds_limit(self) -> None:
        """An oversized context is shortened and annotated in its metadata."""
        compressor = ContextCompressor()

        context = KnowledgeContext(
            content="A" * 500,
            source="docs",
        )
        context.token_count = 125  # Approximately 500/4

        result = await compressor.compress_context(context, max_tokens=20)

        assert result.metadata.get("truncated") is True
        assert result.metadata.get("original_tokens") == 125
        assert len(result.content) < 500

    @pytest.mark.asyncio
    async def test_compress_contexts_batch(self) -> None:
        """Batch compression returns one result per input context."""
        compressor = ContextCompressor()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)

        contexts = [
            KnowledgeContext(content="A" * 200, source="docs"),
            KnowledgeContext(content="B" * 200, source="docs"),
            TaskContext(content="C" * 200, source="task"),
        ]

        result = await compressor.compress_contexts(contexts, budget)

        assert len(result) == 3

    def test_strategy_selection_by_type(self) -> None:
        """Each context type maps to its expected truncation strategy name.

        This is a plain synchronous test: ``_get_strategy_for_type`` is called
        without ``await`` and compared directly to a string, so no event loop
        or asyncio marker is needed.
        """
        compressor = ContextCompressor()

        expected_strategies = (
            (ContextType.SYSTEM, "end"),
            (ContextType.TASK, "end"),
            (ContextType.KNOWLEDGE, "sentence"),
            (ContextType.CONVERSATION, "end"),
            (ContextType.TOOL, "middle"),
        )

        for context_type, strategy_name in expected_strategies:
            assert compressor._get_strategy_for_type(context_type) == strategy_name
|
|
|
|
|
|
class TestTruncationEdgeCases:
    """Regression guards for boundary conditions in truncation."""

    @pytest.mark.asyncio
    async def test_truncation_ratio_with_zero_original_tokens(self) -> None:
        """Empty input reports a 0.0 ratio and zero original tokens."""
        truncator = TruncationStrategy()

        # Reaching the assertions at all means no ZeroDivisionError was raised.
        outcome = await truncator.truncate_to_tokens("", max_tokens=100)

        assert outcome.truncation_ratio == 0.0
        assert outcome.original_tokens == 0

    @pytest.mark.asyncio
    async def test_truncate_end_with_zero_available_tokens(self) -> None:
        """A one-token budget with the "end" mode does not crash."""
        truncator = TruncationStrategy()
        text = "Some content to truncate"

        outcome = await truncator.truncate_to_tokens(
            text,
            max_tokens=1,
            strategy="end",
        )

        # Both outcomes are accepted: the marker appears in the output, or the
        # content is handed back untouched.
        assert (
            truncator.truncation_marker in outcome.content
            or outcome.content == text
        )

    @pytest.mark.asyncio
    async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
        """A single-character input is handled without error."""
        truncator = TruncationStrategy()

        # One character may be estimated as zero tokens; the call must still
        # complete rather than raise ZeroDivisionError.
        outcome = await truncator.truncate_to_tokens("a", max_tokens=100)

        assert outcome.content in ("a", "a" + truncator.truncation_marker)

    @pytest.mark.asyncio
    async def test_get_content_for_tokens_zero_target(self) -> None:
        """_get_content_for_tokens yields an empty string for a zero target."""
        truncator = TruncationStrategy()

        extracted = await truncator._get_content_for_tokens(
            content="Some content",
            target_tokens=0,
            from_start=True,
        )

        assert extracted == ""

    @pytest.mark.asyncio
    async def test_sentence_truncation_with_no_sentences(self) -> None:
        """The "sentence" mode copes with text lacking end punctuation."""
        truncator = TruncationStrategy()
        text = "this is content without any sentence ending punctuation"

        outcome = await truncator.truncate_to_tokens(
            text,
            max_tokens=5,
            strategy="sentence",
        )

        # Only checks that the call completes and returns something.
        assert outcome is not None

    @pytest.mark.asyncio
    async def test_middle_truncation_very_short_content(self) -> None:
        """The "middle" mode copes with a two-character input."""
        truncator = TruncationStrategy(preserve_ratio_start=0.7)
        text = "ab"  # shorter than any preserved head/tail portion

        outcome = await truncator.truncate_to_tokens(
            text,
            max_tokens=1,
            strategy="middle",
        )

        # Only checks that the call completes and returns something.
        assert outcome is not None
|