"""Tests for token budget management."""

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.services.context.budget import (
    BudgetAllocator,
    TokenBudget,
    TokenCalculator,
)
from app.services.context.config import ContextSettings
from app.services.context.exceptions import BudgetExceededError
from app.services.context.types import ContextType


class TestTokenBudget:
    """Tests for TokenBudget dataclass."""

    def test_creation(self) -> None:
        """Test basic budget creation."""
        budget = TokenBudget(total=10000)

        assert budget.total == 10000
        assert budget.system == 0
        assert budget.total_used() == 0

    def test_creation_with_allocations(self) -> None:
        """Test budget creation with allocations."""
        budget = TokenBudget(
            total=10000,
            system=500,
            task=1000,
            knowledge=4000,
            conversation=2000,
            tools=500,
            response_reserve=1500,
            buffer=500,
        )

        assert budget.system == 500
        assert budget.knowledge == 4000
        assert budget.response_reserve == 1500

    def test_get_allocation(self) -> None:
        """Test getting allocation for a type."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
        )

        assert budget.get_allocation(ContextType.SYSTEM) == 500
        assert budget.get_allocation(ContextType.KNOWLEDGE) == 4000
        # Plain string keys are accepted as well as enum members.
        assert budget.get_allocation("system") == 500

    def test_remaining(self) -> None:
        """Test remaining budget calculation."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
        )

        # Initially full
        assert budget.remaining(ContextType.SYSTEM) == 500
        assert budget.remaining(ContextType.KNOWLEDGE) == 4000

        # After allocation
        budget.allocate(ContextType.SYSTEM, 200)
        assert budget.remaining(ContextType.SYSTEM) == 300

    def test_can_fit(self) -> None:
        """Test can_fit check."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
        )

        # Boundary: exactly the allocation fits, one more does not.
        assert budget.can_fit(ContextType.SYSTEM, 500) is True
        assert budget.can_fit(ContextType.SYSTEM, 501) is False
        assert budget.can_fit(ContextType.KNOWLEDGE, 4000) is True

    def test_allocate_success(self) -> None:
        """Test successful allocation."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        result = budget.allocate(ContextType.SYSTEM, 200)

        assert result is True
        assert budget.get_used(ContextType.SYSTEM) == 200
        assert budget.remaining(ContextType.SYSTEM) == 300

    def test_allocate_exceeds_budget(self) -> None:
        """Test allocation exceeding budget."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        with pytest.raises(BudgetExceededError) as exc_info:
            budget.allocate(ContextType.SYSTEM, 600)

        assert exc_info.value.allocated == 500
        assert exc_info.value.requested == 600

    def test_allocate_force(self) -> None:
        """Test forced allocation exceeding budget."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        # Force should allow exceeding
        result = budget.allocate(ContextType.SYSTEM, 600, force=True)

        assert result is True
        assert budget.get_used(ContextType.SYSTEM) == 600

    def test_deallocate(self) -> None:
        """Test deallocation."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        budget.allocate(ContextType.SYSTEM, 300)
        assert budget.get_used(ContextType.SYSTEM) == 300

        budget.deallocate(ContextType.SYSTEM, 100)
        assert budget.get_used(ContextType.SYSTEM) == 200

    def test_deallocate_below_zero(self) -> None:
        """Test deallocation doesn't go below zero."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        budget.allocate(ContextType.SYSTEM, 100)
        budget.deallocate(ContextType.SYSTEM, 200)
        assert budget.get_used(ContextType.SYSTEM) == 0

    def test_total_remaining(self) -> None:
        """Test total remaining calculation."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
            response_reserve=1500,
            buffer=500,
        )

        # Usable = total - response_reserve - buffer = 10000 - 1500 - 500 = 8000
        assert budget.total_remaining() == 8000

        # After allocation
        budget.allocate(ContextType.SYSTEM, 200)
        assert budget.total_remaining() == 7800

    def test_utilization(self) -> None:
        """Test utilization calculation."""
        budget = TokenBudget(
            total=10000,
            system=500,
            response_reserve=1500,
            buffer=500,
        )

        # No usage = 0%
        assert budget.utilization(ContextType.SYSTEM) == 0.0

        # Half used = 50%
        budget.allocate(ContextType.SYSTEM, 250)
        assert budget.utilization(ContextType.SYSTEM) == 0.5

        # Total utilization: 250 / (10000 - 1500 - 500)
        assert budget.utilization() == 250 / 8000

    def test_reset(self) -> None:
        """Test reset clears usage."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        budget.allocate(ContextType.SYSTEM, 300)
        assert budget.get_used(ContextType.SYSTEM) == 300

        budget.reset()
        assert budget.get_used(ContextType.SYSTEM) == 0
        assert budget.total_used() == 0

    def test_to_dict(self) -> None:
        """Test to_dict conversion."""
        budget = TokenBudget(
            total=10000,
            system=500,
            task=1000,
            knowledge=4000,
        )
        budget.allocate(ContextType.SYSTEM, 200)

        data = budget.to_dict()

        assert data["total"] == 10000
        assert data["allocations"]["system"] == 500
        assert data["used"]["system"] == 200
        assert data["remaining"]["system"] == 300


class TestBudgetAllocator:
    """Tests for BudgetAllocator."""

    def test_create_budget(self) -> None:
        """Test budget creation with default allocations."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(100000)

        assert budget.total == 100000
        assert budget.system == 5000  # 5%
        assert budget.task == 10000  # 10%
        assert budget.knowledge == 40000  # 40%
        assert budget.conversation == 20000  # 20%
        assert budget.tools == 5000  # 5%
        assert budget.response_reserve == 15000  # 15%
        assert budget.buffer == 5000  # 5%

    def test_create_budget_custom_allocations(self) -> None:
        """Test budget creation with custom allocations."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(
            100000,
            custom_allocations={
                "system": 0.10,
                "task": 0.10,
                "knowledge": 0.30,
                "conversation": 0.25,
                "tools": 0.05,
                "response": 0.15,
                "buffer": 0.05,
            },
        )

        assert budget.system == 10000  # 10%
        assert budget.knowledge == 30000  # 30%

    def test_create_budget_for_model(self) -> None:
        """Test budget creation for specific model."""
        allocator = BudgetAllocator()

        # Claude models have 200k context
        budget = allocator.create_budget_for_model("claude-3-sonnet")
        assert budget.total == 200000

        # GPT-4 has 8k context
        budget = allocator.create_budget_for_model("gpt-4")
        assert budget.total == 8192

        # GPT-4-turbo has 128k context
        budget = allocator.create_budget_for_model("gpt-4-turbo")
        assert budget.total == 128000

    def test_get_model_context_size(self) -> None:
        """Test model context size lookup."""
        allocator = BudgetAllocator()

        # Known models
        assert allocator.get_model_context_size("claude-3-opus") == 200000
        assert allocator.get_model_context_size("gpt-4") == 8192
        assert allocator.get_model_context_size("gemini-1.5-pro") == 2000000

        # Unknown model gets default
        assert allocator.get_model_context_size("unknown-model") == 8192

    def test_adjust_budget(self) -> None:
        """Test budget adjustment."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        original_system = budget.system
        original_buffer = budget.buffer

        # Increase system by taking from buffer
        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 200)

        assert budget.system == original_system + 200
        assert budget.buffer == original_buffer - 200

    def test_adjust_budget_limited_by_buffer(self) -> None:
        """Test that adjustment is limited by buffer size."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        # Snapshot both values BEFORE adjusting; comparing against the
        # post-adjustment system value would make the check a tautology.
        original_system = budget.system
        original_buffer = budget.buffer

        # Try to increase more than buffer allows
        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 10000)

        # Should only increase by the buffer amount, draining the buffer
        assert budget.buffer == 0
        assert budget.system == original_system + original_buffer

    def test_rebalance_budget(self) -> None:
        """Test budget rebalancing."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)
        original_knowledge = budget.knowledge

        # Use most of knowledge budget
        budget.allocate(ContextType.KNOWLEDGE, 3500)

        # Rebalance prioritizing knowledge
        budget = allocator.rebalance_budget(
            budget,
            prioritize=[ContextType.KNOWLEDGE],
        )

        # Exact redistribution is implementation-defined, but rebalancing
        # must not change the overall budget nor shrink a prioritized type.
        assert budget is not None
        assert budget.total == 10000
        assert budget.knowledge >= original_knowledge


class TestTokenCalculator:
    """Tests for TokenCalculator."""

    def test_estimate_tokens(self) -> None:
        """Test token estimation."""
        calc = TokenCalculator()

        # Empty string
        assert calc.estimate_tokens("") == 0

        # Short text (~4 chars per token)
        text = "This is a test message"
        estimate = calc.estimate_tokens(text)
        assert 4 <= estimate <= 8

    def test_estimate_tokens_model_specific(self) -> None:
        """Test model-specific estimation ratios."""
        calc = TokenCalculator()
        text = "a" * 100

        # Claude uses 3.5 chars per token
        claude_estimate = calc.estimate_tokens(text, "claude-3-sonnet")

        # GPT uses 4.0 chars per token
        gpt_estimate = calc.estimate_tokens(text, "gpt-4")

        # Claude should estimate more tokens (smaller ratio)
        assert claude_estimate >= gpt_estimate

    @pytest.mark.asyncio
    async def test_count_tokens_no_mcp(self) -> None:
        """Test token counting without MCP (fallback to estimation)."""
        calc = TokenCalculator()

        text = "This is a test"
        count = await calc.count_tokens(text)

        # Should use estimation
        assert count > 0

    @pytest.mark.asyncio
    async def test_count_tokens_with_mcp_success(self) -> None:
        """Test token counting with MCP integration."""
        # Mock MCP manager
        mock_mcp = MagicMock()
        mock_result = MagicMock()
        mock_result.success = True
        mock_result.data = {"token_count": 42}
        mock_mcp.call_tool = AsyncMock(return_value=mock_result)

        calc = TokenCalculator(mcp_manager=mock_mcp)
        count = await calc.count_tokens("test text")

        assert count == 42
        mock_mcp.call_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_count_tokens_with_mcp_failure(self) -> None:
        """Test fallback when MCP fails."""
        # Mock MCP manager that fails
        mock_mcp = MagicMock()
        mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))

        calc = TokenCalculator(mcp_manager=mock_mcp)
        count = await calc.count_tokens("test text")

        # Should fall back to estimation
        assert count > 0

    @pytest.mark.asyncio
    async def test_count_tokens_caching(self) -> None:
        """Test that token counts are cached."""
        mock_mcp = MagicMock()
        mock_result = MagicMock()
        mock_result.success = True
        mock_result.data = {"token_count": 42}
        mock_mcp.call_tool = AsyncMock(return_value=mock_result)

        calc = TokenCalculator(mcp_manager=mock_mcp)

        # First call
        count1 = await calc.count_tokens("test text")
        # Second call (should use cache)
        count2 = await calc.count_tokens("test text")

        assert count1 == count2 == 42
        # MCP should only be called once
        assert mock_mcp.call_tool.call_count == 1

    @pytest.mark.asyncio
    async def test_count_tokens_batch(self) -> None:
        """Test batch token counting."""
        calc = TokenCalculator()

        texts = ["Hello", "World", "Test message here"]
        counts = await calc.count_tokens_batch(texts)

        assert len(counts) == 3
        assert all(c > 0 for c in counts)

    def test_cache_stats(self) -> None:
        """Test cache statistics."""
        calc = TokenCalculator()

        stats = calc.get_cache_stats()

        assert stats["enabled"] is True
        assert stats["size"] == 0
        assert stats["hits"] == 0
        assert stats["misses"] == 0

    @pytest.mark.asyncio
    async def test_cache_hit_rate(self) -> None:
        """Test cache hit rate tracking."""
        calc = TokenCalculator()

        # Make some calls
        await calc.count_tokens("text1")
        await calc.count_tokens("text2")
        await calc.count_tokens("text1")  # Cache hit

        stats = calc.get_cache_stats()
        assert stats["hits"] == 1
        assert stats["misses"] == 2

    def test_clear_cache(self) -> None:
        """Test cache clearing."""
        calc = TokenCalculator()
        # Seed private state directly to simulate prior usage.
        calc._cache["test"] = 100
        calc._cache_hits = 5

        calc.clear_cache()

        assert len(calc._cache) == 0
        assert calc._cache_hits == 0

    def test_set_mcp_manager(self) -> None:
        """Test setting MCP manager after initialization."""
        calc = TokenCalculator()
        assert calc._mcp is None

        mock_mcp = MagicMock()
        calc.set_mcp_manager(mock_mcp)

        assert calc._mcp is mock_mcp

    def test_parse_token_count_formats(self) -> None:
        """Test parsing different token count response formats."""
        # _parse_token_count is synchronous, so this test needs no event loop.
        calc = TokenCalculator()

        # Dict with token_count
        assert calc._parse_token_count({"token_count": 42}) == 42

        # Dict with tokens
        assert calc._parse_token_count({"tokens": 42}) == 42

        # Dict with count
        assert calc._parse_token_count({"count": 42}) == 42

        # Direct int
        assert calc._parse_token_count(42) == 42

        # JSON string
        assert calc._parse_token_count('{"token_count": 42}') == 42

        # Invalid
        assert calc._parse_token_count("invalid") is None


class TestBudgetIntegration:
    """Integration tests for budget management."""

    @pytest.mark.asyncio
    async def test_full_budget_workflow(self) -> None:
        """Test complete budget allocation workflow."""
        # Create settings and allocator
        settings = ContextSettings()
        allocator = BudgetAllocator(settings)

        # Create budget for Claude
        budget = allocator.create_budget_for_model("claude-3-sonnet")
        assert budget.total == 200000

        # Create calculator (without MCP for test)
        calc = TokenCalculator()

        # Simulate context allocation
        system_text = "You are a helpful assistant." * 10
        system_tokens = await calc.count_tokens(system_text)

        # Allocate
        assert budget.can_fit(ContextType.SYSTEM, system_tokens)
        budget.allocate(ContextType.SYSTEM, system_tokens)

        # Check state
        assert budget.get_used(ContextType.SYSTEM) == system_tokens
        assert budget.remaining(ContextType.SYSTEM) == budget.system - system_tokens

    def test_budget_overflow_handling(self) -> None:
        """Test handling budget overflow."""
        # Nothing here awaits, so this runs as a plain synchronous test.
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)  # Small budget

        # Try to allocate more than available
        with pytest.raises(BudgetExceededError):
            budget.allocate(ContextType.KNOWLEDGE, 500)

        # Force allocation should work
        budget.allocate(ContextType.KNOWLEDGE, 500, force=True)
        assert budget.get_used(ContextType.KNOWLEDGE) == 500