"""Unit tests for the context compression module.

Covers the ``TruncationResult`` value object, the ``TruncationStrategy``
truncation modes, and ``ContextCompressor``, which applies truncation to
typed context objects.
"""

import pytest

from app.services.context.compression import (
    ContextCompressor,
    TruncationResult,
    TruncationStrategy,
)
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
    ContextType,
    KnowledgeContext,
    SystemContext,
    TaskContext,
)


class TestTruncationResult:
    """Behaviour of the ``TruncationResult`` dataclass."""

    def test_creation(self) -> None:
        """Constructor arguments are stored on the instance unchanged."""
        halved = TruncationResult(
            original_tokens=100,
            truncated_tokens=50,
            content="Truncated content",
            truncated=True,
            truncation_ratio=0.5,
        )

        assert halved.original_tokens == 100
        assert halved.truncated_tokens == 50
        assert halved.truncated is True
        assert halved.truncation_ratio == 0.5

    def test_tokens_saved(self) -> None:
        """``tokens_saved`` reports 60 when 100 tokens shrink to 40."""
        shrunk = TruncationResult(
            original_tokens=100,
            truncated_tokens=40,
            content="Test",
            truncated=True,
            truncation_ratio=0.6,
        )

        assert shrunk.tokens_saved == 60

    def test_no_truncation(self) -> None:
        """An untouched result saves zero tokens and is not flagged."""
        untouched = TruncationResult(
            original_tokens=50,
            truncated_tokens=50,
            content="Full content",
            truncated=False,
            truncation_ratio=0.0,
        )

        assert untouched.tokens_saved == 0
        assert untouched.truncated is False


class TestTruncationStrategy:
    """Behaviour of ``TruncationStrategy`` and its truncation modes."""

    def test_creation(self) -> None:
        """A default-constructed strategy carries the expected settings."""
        default = TruncationStrategy()

        assert default._preserve_ratio_start == 0.7
        assert default._min_content_length == 100

    def test_creation_with_params(self) -> None:
        """Explicit constructor arguments override the default settings."""
        custom = TruncationStrategy(
            preserve_ratio_start=0.5,
            min_content_length=50,
        )

        assert custom._preserve_ratio_start == 0.5
        assert custom._min_content_length == 50

    @pytest.mark.asyncio
    async def test_truncate_empty_content(self) -> None:
        """Empty input produces an empty, non-truncated result."""
        outcome = await TruncationStrategy().truncate_to_tokens("", max_tokens=100)

        assert outcome.original_tokens == 0
        assert outcome.truncated_tokens == 0
        assert outcome.content == ""
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_content_within_limit(self) -> None:
        """Content already under the token limit comes back verbatim."""
        text = "Short content"

        outcome = await TruncationStrategy().truncate_to_tokens(text, max_tokens=100)

        assert outcome.content == text
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_end_strategy(self) -> None:
        """The "end" mode shortens long content and inserts the marker."""
        strategy = TruncationStrategy()
        long_text = "A" * 1000

        outcome = await strategy.truncate_to_tokens(
            long_text, max_tokens=50, strategy="end"
        )

        assert outcome.truncated is True
        assert len(outcome.content) < len(long_text)
        assert strategy.TRUNCATION_MARKER in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_middle_strategy(self) -> None:
        """The "middle" mode truncates and inserts the marker."""
        strategy = TruncationStrategy(preserve_ratio_start=0.6)
        head, filler, tail = "START ", "A" * 500, " END"
        text = head + filler + tail

        outcome = await strategy.truncate_to_tokens(
            text, max_tokens=50, strategy="middle"
        )

        assert outcome.truncated is True
        assert strategy.TRUNCATION_MARKER in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_sentence_strategy(self) -> None:
        """The "sentence" mode truncates at a sentence-aware position."""
        strategy = TruncationStrategy()
        sentences = "First sentence. Second sentence. Third sentence. Fourth sentence."

        outcome = await strategy.truncate_to_tokens(
            sentences, max_tokens=10, strategy="sentence"
        )

        assert outcome.truncated is True
        # Accept either a cut on a full stop or output carrying the marker.
        ends_on_full_stop = outcome.content.endswith(".")
        assert ends_on_full_stop or strategy.TRUNCATION_MARKER in outcome.content


class TestContextCompressor:
    """Behaviour of ``ContextCompressor``."""

    def test_creation(self) -> None:
        """A freshly built compressor has its truncation component set."""
        assert ContextCompressor()._truncation is not None

    @pytest.mark.asyncio
    async def test_compress_context_within_limit(self) -> None:
        """A context that already fits is passed through without truncation."""
        compressor = ContextCompressor()
        small = KnowledgeContext(content="Short content", source="docs")
        small.token_count = 5

        compressed = await compressor.compress_context(small, max_tokens=100)

        # Content must be unchanged and no truncation flag may be set.
        assert compressed.content == "Short content"
        assert compressed.metadata.get("truncated") is not True

    @pytest.mark.asyncio
    async def test_compress_context_exceeds_limit(self) -> None:
        """An oversized context is truncated and records its original size."""
        compressor = ContextCompressor()
        oversized = KnowledgeContext(content="A" * 500, source="docs")
        # Hand-set count, assuming roughly 4 characters per token (500 / 4).
        oversized.token_count = 125

        compressed = await compressor.compress_context(oversized, max_tokens=20)

        meta = compressed.metadata
        assert meta.get("truncated") is True
        assert meta.get("original_tokens") == 125
        assert len(compressed.content) < 500

    @pytest.mark.asyncio
    async def test_compress_contexts_batch(self) -> None:
        """``compress_contexts`` returns one entry per input context."""
        compressor = ContextCompressor()
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)
        batch = [
            KnowledgeContext(content="A" * 200, source="docs"),
            KnowledgeContext(content="B" * 200, source="docs"),
            TaskContext(content="C" * 200, source="task"),
        ]

        compressed = await compressor.compress_contexts(batch, budget)

        assert len(compressed) == 3

    def test_strategy_selection_by_type(self) -> None:
        """Each context type maps to its expected truncation mode name."""
        compressor = ContextCompressor()
        expectations = (
            (ContextType.SYSTEM, "end"),
            (ContextType.TASK, "end"),
            (ContextType.KNOWLEDGE, "sentence"),
            (ContextType.CONVERSATION, "end"),
            (ContextType.TOOL, "middle"),
        )

        for context_type, expected_name in expectations:
            assert compressor._get_strategy_for_type(context_type) == expected_name