"""Tests for cost controller module.""" import pytest from app.services.safety.costs.controller import ( BudgetTracker, CostController, ) from app.services.safety.exceptions import BudgetExceededError from app.services.safety.models import ( ActionMetadata, ActionRequest, ActionType, BudgetScope, ) @pytest.fixture def budget_tracker() -> BudgetTracker: """Create a budget tracker for testing.""" return BudgetTracker( scope=BudgetScope.SESSION, scope_id="test-session", tokens_limit=1000, cost_limit_usd=10.0, warning_threshold=0.8, ) @pytest.fixture def cost_controller() -> CostController: """Create a cost controller for testing.""" return CostController( default_session_tokens=1000, default_session_cost_usd=10.0, default_daily_tokens=5000, default_daily_cost_usd=50.0, ) @pytest.fixture def sample_metadata() -> ActionMetadata: """Create sample action metadata.""" return ActionMetadata( agent_id="test-agent", session_id="test-session", ) def create_action( metadata: ActionMetadata, estimated_tokens: int = 100, estimated_cost: float = 0.01, ) -> ActionRequest: """Helper to create test actions.""" return ActionRequest( action_type=ActionType.LLM_CALL, tool_name="test_tool", resource="test-resource", arguments={}, metadata=metadata, estimated_cost_tokens=estimated_tokens, estimated_cost_usd=estimated_cost, ) class TestBudgetTracker: """Tests for BudgetTracker class.""" @pytest.mark.asyncio async def test_initial_status(self, budget_tracker: BudgetTracker) -> None: """Test initial budget status is clean.""" status = await budget_tracker.get_status() assert status.tokens_used == 0 assert status.cost_used_usd == 0.0 assert status.tokens_remaining == 1000 assert status.cost_remaining_usd == 10.0 assert status.is_warning is False assert status.is_exceeded is False @pytest.mark.asyncio async def test_add_usage(self, budget_tracker: BudgetTracker) -> None: """Test adding usage updates counters.""" await budget_tracker.add_usage(tokens=100, cost_usd=1.0) status = await budget_tracker.get_status() assert status.tokens_used == 100 assert status.cost_used_usd == 1.0 assert status.tokens_remaining == 900 assert status.cost_remaining_usd == 9.0 @pytest.mark.asyncio async def test_warning_threshold(self, budget_tracker: BudgetTracker) -> None: """Test warning is triggered at threshold.""" # Add usage to reach 80% of tokens await budget_tracker.add_usage(tokens=800, cost_usd=1.0) status = await budget_tracker.get_status() assert status.is_warning is True assert status.is_exceeded is False @pytest.mark.asyncio async def test_budget_exceeded(self, budget_tracker: BudgetTracker) -> None: """Test budget exceeded detection.""" # Exceed token limit await budget_tracker.add_usage(tokens=1100, cost_usd=1.0) status = await budget_tracker.get_status() assert status.is_exceeded is True @pytest.mark.asyncio async def test_check_budget_allows(self, budget_tracker: BudgetTracker) -> None: """Test check_budget allows within budget.""" result = await budget_tracker.check_budget( estimated_tokens=500, estimated_cost_usd=5.0, ) assert result is True @pytest.mark.asyncio async def test_check_budget_denies(self, budget_tracker: BudgetTracker) -> None: """Test check_budget denies when would exceed.""" # Use most of the budget await budget_tracker.add_usage(tokens=800, cost_usd=8.0) # Check would exceed result = await budget_tracker.check_budget( estimated_tokens=300, estimated_cost_usd=3.0, ) assert result is False @pytest.mark.asyncio async def test_reset(self, budget_tracker: BudgetTracker) -> None: """Test manual reset clears counters.""" await budget_tracker.add_usage(tokens=500, cost_usd=5.0) await budget_tracker.reset() status = await budget_tracker.get_status() assert status.tokens_used == 0 assert status.cost_used_usd == 0.0 class TestCostController: """Tests for CostController class.""" @pytest.mark.asyncio async def test_check_budget_success( self, cost_controller: CostController, ) -> None: """Test budget check passes with available budget.""" result = await cost_controller.check_budget( agent_id="test-agent", session_id="test-session", estimated_tokens=100, estimated_cost_usd=1.0, ) assert result is True @pytest.mark.asyncio async def test_check_budget_session_exceeded( self, cost_controller: CostController, ) -> None: """Test budget check fails when session budget exceeded.""" # Use most of session budget await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=900, cost_usd=9.0, ) # Check would exceed result = await cost_controller.check_budget( agent_id="test-agent", session_id="test-session", estimated_tokens=200, estimated_cost_usd=2.0, ) assert result is False @pytest.mark.asyncio async def test_check_budget_daily_exceeded( self, cost_controller: CostController, ) -> None: """Test budget check fails when daily budget exceeded.""" # Use most of daily budget await cost_controller.record_usage( agent_id="test-agent", session_id=None, tokens=4900, cost_usd=49.0, ) # Check would exceed daily result = await cost_controller.check_budget( agent_id="test-agent", session_id="new-session", estimated_tokens=200, estimated_cost_usd=2.0, ) assert result is False @pytest.mark.asyncio async def test_check_action( self, cost_controller: CostController, sample_metadata: ActionMetadata, ) -> None: """Test checking action budget.""" action = create_action( sample_metadata, estimated_tokens=100, estimated_cost=0.01, ) result = await cost_controller.check_action(action) assert result is True @pytest.mark.asyncio async def test_require_budget_success( self, cost_controller: CostController, ) -> None: """Test require_budget passes when budget available.""" # Should not raise await cost_controller.require_budget( agent_id="test-agent", session_id="test-session", estimated_tokens=100, estimated_cost_usd=1.0, ) @pytest.mark.asyncio async def test_require_budget_raises( self, cost_controller: CostController, ) -> None: """Test require_budget raises when budget exceeded.""" # Use all session budget await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=1000, cost_usd=10.0, ) with pytest.raises(BudgetExceededError) as exc_info: await cost_controller.require_budget( agent_id="test-agent", session_id="test-session", estimated_tokens=100, estimated_cost_usd=1.0, ) assert "session" in exc_info.value.budget_type.lower() @pytest.mark.asyncio async def test_record_usage( self, cost_controller: CostController, ) -> None: """Test recording usage updates trackers.""" await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=100, cost_usd=1.0, ) # Check session budget was updated session_status = await cost_controller.get_status( BudgetScope.SESSION, "test-session" ) assert session_status is not None assert session_status.tokens_used == 100 # Check daily budget was updated daily_status = await cost_controller.get_status(BudgetScope.DAILY, "test-agent") assert daily_status is not None assert daily_status.tokens_used == 100 @pytest.mark.asyncio async def test_get_all_statuses( self, cost_controller: CostController, ) -> None: """Test getting all budget statuses.""" # Record some usage await cost_controller.record_usage( agent_id="agent-1", session_id="session-1", tokens=100, cost_usd=1.0, ) await cost_controller.record_usage( agent_id="agent-2", session_id="session-2", tokens=200, cost_usd=2.0, ) statuses = await cost_controller.get_all_statuses() assert len(statuses) >= 2 @pytest.mark.asyncio async def test_set_budget( self, cost_controller: CostController, ) -> None: """Test setting custom budget.""" await cost_controller.set_budget( scope=BudgetScope.SESSION, scope_id="custom-session", tokens_limit=5000, cost_limit_usd=50.0, ) status = await cost_controller.get_status(BudgetScope.SESSION, "custom-session") assert status is not None assert status.tokens_limit == 5000 assert status.cost_limit_usd == 50.0 @pytest.mark.asyncio async def test_reset_budget( self, cost_controller: CostController, ) -> None: """Test resetting budget.""" # Record usage await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=500, cost_usd=5.0, ) # Reset session budget result = await cost_controller.reset_budget(BudgetScope.SESSION, "test-session") assert result is True # Verify reset status = await cost_controller.get_status(BudgetScope.SESSION, "test-session") assert status is not None assert status.tokens_used == 0 @pytest.mark.asyncio async def test_reset_nonexistent_budget( self, cost_controller: CostController, ) -> None: """Test resetting non-existent budget returns False.""" result = await cost_controller.reset_budget(BudgetScope.SESSION, "nonexistent") assert result is False @pytest.mark.asyncio async def test_alert_handler( self, cost_controller: CostController, ) -> None: """Test alert handler is called at warning threshold.""" alerts_received = [] def alert_handler(alert_type: str, message: str, status): alerts_received.append((alert_type, message)) cost_controller.add_alert_handler(alert_handler) # Record usage to reach warning threshold (80%) await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=850, # 85% of 1000 cost_usd=0.0, ) assert len(alerts_received) > 0 assert alerts_received[0][0] == "warning" @pytest.mark.asyncio async def test_remove_alert_handler( self, cost_controller: CostController, ) -> None: """Test removing alert handler.""" alerts_received = [] def alert_handler(alert_type: str, message: str, status): alerts_received.append((alert_type, message)) cost_controller.add_alert_handler(alert_handler) cost_controller.remove_alert_handler(alert_handler) # Record usage to reach warning threshold await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=850, cost_usd=0.0, ) assert len(alerts_received) == 0 @pytest.mark.asyncio async def test_alert_deduplication( self, cost_controller: CostController, ) -> None: """Test alerts are only sent once per budget (no spam).""" alerts_received = [] def alert_handler(alert_type: str, message: str, status): alerts_received.append((alert_type, message)) cost_controller.add_alert_handler(alert_handler) # Record usage multiple times at warning level # Session budget is 1000 with 80% threshold = 800 tokens # 10 * 85 = 850 tokens triggers session warning once for _ in range(10): await cost_controller.record_usage( agent_id="test-agent", session_id="test-session", tokens=85, # Each call adds 85 tokens cost_usd=0.0, ) # Should only receive ONE session warning (daily budget of 5000 # isn't reached yet). The key point is we don't get 10 alerts! assert len(alerts_received) == 1 assert alerts_received[0][0] == "warning" assert "Session" in alerts_received[0][1]