forked from cardosofelipe/fast-next-template
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation. - Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming. - Refactored `CostController`'s session and daily budget alert handling for improved clarity. - Implemented test suites for `CostController` and `SafetyGuardian` to validate changes. - Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
437 lines
13 KiB
Python
437 lines
13 KiB
Python
"""Tests for cost controller module."""
|
|
|
|
import pytest
|
|
|
|
from app.services.safety.costs.controller import (
|
|
BudgetTracker,
|
|
CostController,
|
|
)
|
|
from app.services.safety.exceptions import BudgetExceededError
|
|
from app.services.safety.models import (
|
|
ActionMetadata,
|
|
ActionRequest,
|
|
ActionType,
|
|
BudgetScope,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def budget_tracker() -> BudgetTracker:
    """Provide a fresh session-scoped BudgetTracker for each test.

    The tracker is limited to 1000 tokens and 10.0 USD, with the warning
    threshold set to 0.8.
    """
    tracker_settings = {
        "scope": BudgetScope.SESSION,
        "scope_id": "test-session",
        "tokens_limit": 1000,
        "cost_limit_usd": 10.0,
        "warning_threshold": 0.8,
    }
    return BudgetTracker(**tracker_settings)
|
|
|
|
|
|
@pytest.fixture
def cost_controller() -> CostController:
    """Provide a fresh CostController for each test.

    Default budgets: 1000 tokens / 10.0 USD per session and
    5000 tokens / 50.0 USD per day.
    """
    controller = CostController(
        default_session_tokens=1000,
        default_session_cost_usd=10.0,
        default_daily_tokens=5000,
        default_daily_cost_usd=50.0,
    )
    return controller
|
|
|
|
|
|
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide ActionMetadata for agent "test-agent" in session "test-session"."""
    metadata = ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
    )
    return metadata
|
|
|
|
|
|
def create_action(
    metadata: ActionMetadata,
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an LLM-call ActionRequest carrying the given cost estimates.

    Args:
        metadata: Metadata attached to the request.
        estimated_tokens: Value passed as ``estimated_cost_tokens``.
        estimated_cost: Value passed as ``estimated_cost_usd``.

    Returns:
        An ActionRequest for tool "test_tool" on "test-resource" with
        empty arguments.
    """
    request = ActionRequest(
        action_type=ActionType.LLM_CALL,
        tool_name="test_tool",
        resource="test-resource",
        arguments={},
        metadata=metadata,
        estimated_cost_tokens=estimated_tokens,
        estimated_cost_usd=estimated_cost,
    )
    return request
|
|
|
|
|
|
class TestBudgetTracker:
    """Unit tests exercising a single BudgetTracker in isolation."""

    @pytest.mark.asyncio
    async def test_initial_status(self, budget_tracker: BudgetTracker) -> None:
        """A new tracker reports zero usage and the full budget remaining."""
        snapshot = await budget_tracker.get_status()

        # Nothing consumed yet.
        assert snapshot.tokens_used == 0
        assert snapshot.cost_used_usd == 0.0
        # The whole fixture budget (1000 tokens / 10.0 USD) is still free.
        assert snapshot.tokens_remaining == 1000
        assert snapshot.cost_remaining_usd == 10.0
        # Neither flag is raised.
        assert snapshot.is_warning is False
        assert snapshot.is_exceeded is False

    @pytest.mark.asyncio
    async def test_add_usage(self, budget_tracker: BudgetTracker) -> None:
        """Recorded usage shows up in both the used and remaining counters."""
        await budget_tracker.add_usage(tokens=100, cost_usd=1.0)

        snapshot = await budget_tracker.get_status()
        assert snapshot.tokens_used == 100
        assert snapshot.cost_used_usd == 1.0
        assert snapshot.tokens_remaining == 900
        assert snapshot.cost_remaining_usd == 9.0

    @pytest.mark.asyncio
    async def test_warning_threshold(self, budget_tracker: BudgetTracker) -> None:
        """Hitting the warning threshold raises the warning flag only."""
        # 800 of 1000 tokens is exactly the 0.8 threshold from the fixture.
        await budget_tracker.add_usage(tokens=800, cost_usd=1.0)

        snapshot = await budget_tracker.get_status()
        assert snapshot.is_warning is True
        assert snapshot.is_exceeded is False

    @pytest.mark.asyncio
    async def test_budget_exceeded(self, budget_tracker: BudgetTracker) -> None:
        """Usage beyond the token limit is reported as exceeded."""
        # 1100 tokens is over the 1000-token limit; cost stays well within.
        await budget_tracker.add_usage(tokens=1100, cost_usd=1.0)

        snapshot = await budget_tracker.get_status()
        assert snapshot.is_exceeded is True

    @pytest.mark.asyncio
    async def test_check_budget_allows(self, budget_tracker: BudgetTracker) -> None:
        """An estimate that fits within the untouched budget is allowed."""
        allowed = await budget_tracker.check_budget(
            estimated_tokens=500,
            estimated_cost_usd=5.0,
        )
        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_budget_denies(self, budget_tracker: BudgetTracker) -> None:
        """An estimate that would push usage past the limit is denied."""
        # Consume 800 tokens / 8.0 USD up front.
        await budget_tracker.add_usage(tokens=800, cost_usd=8.0)

        # A further 300 tokens / 3.0 USD would total 1100 / 11.0.
        allowed = await budget_tracker.check_budget(
            estimated_tokens=300,
            estimated_cost_usd=3.0,
        )
        assert allowed is False

    @pytest.mark.asyncio
    async def test_reset(self, budget_tracker: BudgetTracker) -> None:
        """A manual reset returns the usage counters to zero."""
        await budget_tracker.add_usage(tokens=500, cost_usd=5.0)
        await budget_tracker.reset()

        snapshot = await budget_tracker.get_status()
        assert snapshot.tokens_used == 0
        assert snapshot.cost_used_usd == 0.0
|
|
|
|
|
|
class TestCostController:
    """Tests for CostController: budget checks, usage recording and alerts.

    All tests use the ``cost_controller`` fixture, whose defaults are
    1000 tokens / 10.0 USD per session and 5000 tokens / 50.0 USD per day.
    """

    @pytest.mark.asyncio
    async def test_check_budget_success(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test budget check passes with available budget."""
        result = await cost_controller.check_budget(
            agent_id="test-agent",
            session_id="test-session",
            estimated_tokens=100,
            estimated_cost_usd=1.0,
        )
        assert result is True

    @pytest.mark.asyncio
    async def test_check_budget_session_exceeded(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test budget check fails when session budget would be exceeded."""
        # Use most of the session budget (900 of 1000 tokens).
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=900,
            cost_usd=9.0,
        )

        # A further 200 tokens / 2.0 USD would total 1100 / 11.0.
        result = await cost_controller.check_budget(
            agent_id="test-agent",
            session_id="test-session",
            estimated_tokens=200,
            estimated_cost_usd=2.0,
        )
        assert result is False

    @pytest.mark.asyncio
    async def test_check_budget_daily_exceeded(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test budget check fails when daily budget would be exceeded."""
        # Use most of the daily budget (4900 of 5000 tokens), recorded
        # without a session.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id=None,
            tokens=4900,
            cost_usd=49.0,
        )

        # A different session is used for the check, so the request fits the
        # session defaults and only the daily total (5100 / 51.0) is over.
        result = await cost_controller.check_budget(
            agent_id="test-agent",
            session_id="new-session",
            estimated_tokens=200,
            estimated_cost_usd=2.0,
        )
        assert result is False

    @pytest.mark.asyncio
    async def test_check_action(
        self,
        cost_controller: CostController,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test checking the budget for an ActionRequest."""
        action = create_action(
            sample_metadata,
            estimated_tokens=100,
            estimated_cost=0.01,
        )

        result = await cost_controller.check_action(action)
        assert result is True

    @pytest.mark.asyncio
    async def test_require_budget_success(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test require_budget passes when budget is available."""
        # Should not raise; there is nothing further to assert.
        await cost_controller.require_budget(
            agent_id="test-agent",
            session_id="test-session",
            estimated_tokens=100,
            estimated_cost_usd=1.0,
        )

    @pytest.mark.asyncio
    async def test_require_budget_raises(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test require_budget raises when the budget is exhausted."""
        # Use all of the session budget.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=1000,
            cost_usd=10.0,
        )

        with pytest.raises(BudgetExceededError) as exc_info:
            await cost_controller.require_budget(
                agent_id="test-agent",
                session_id="test-session",
                estimated_tokens=100,
                estimated_cost_usd=1.0,
            )

        # The error should name the session budget as the one that failed.
        assert "session" in exc_info.value.budget_type.lower()

    @pytest.mark.asyncio
    async def test_record_usage(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test recording usage updates both session and daily trackers."""
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=100,
            cost_usd=1.0,
        )

        # Check session budget was updated (looked up by session id).
        session_status = await cost_controller.get_status(
            BudgetScope.SESSION, "test-session"
        )
        assert session_status is not None
        assert session_status.tokens_used == 100

        # Check daily budget was updated (looked up by agent id).
        daily_status = await cost_controller.get_status(BudgetScope.DAILY, "test-agent")
        assert daily_status is not None
        assert daily_status.tokens_used == 100

    @pytest.mark.asyncio
    async def test_get_all_statuses(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test getting all budget statuses."""
        # Record usage for two distinct agent/session pairs.
        await cost_controller.record_usage(
            agent_id="agent-1",
            session_id="session-1",
            tokens=100,
            cost_usd=1.0,
        )
        await cost_controller.record_usage(
            agent_id="agent-2",
            session_id="session-2",
            tokens=200,
            cost_usd=2.0,
        )

        statuses = await cost_controller.get_all_statuses()
        assert len(statuses) >= 2

    @pytest.mark.asyncio
    async def test_set_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test setting a custom budget overrides the default limits."""
        await cost_controller.set_budget(
            scope=BudgetScope.SESSION,
            scope_id="custom-session",
            tokens_limit=5000,
            cost_limit_usd=50.0,
        )

        status = await cost_controller.get_status(BudgetScope.SESSION, "custom-session")
        assert status is not None
        assert status.tokens_limit == 5000
        assert status.cost_limit_usd == 50.0

    @pytest.mark.asyncio
    async def test_reset_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test resetting an existing budget clears its usage."""
        # Record usage so there is something to reset.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=500,
            cost_usd=5.0,
        )

        # Reset session budget
        result = await cost_controller.reset_budget(BudgetScope.SESSION, "test-session")
        assert result is True

        # Verify reset
        status = await cost_controller.get_status(BudgetScope.SESSION, "test-session")
        assert status is not None
        assert status.tokens_used == 0

    @pytest.mark.asyncio
    async def test_reset_nonexistent_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test resetting non-existent budget returns False."""
        result = await cost_controller.reset_budget(BudgetScope.SESSION, "nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_alert_handler(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test alert handler is called at warning threshold."""
        alerts_received = []

        def alert_handler(alert_type: str, message: str, status) -> None:
            alerts_received.append((alert_type, message))

        cost_controller.add_alert_handler(alert_handler)

        # Record usage to reach warning threshold (80%)
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=850,  # 85% of 1000
            cost_usd=0.0,
        )

        assert len(alerts_received) > 0
        assert alerts_received[0][0] == "warning"

    @pytest.mark.asyncio
    async def test_remove_alert_handler(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test a removed alert handler is no longer called."""
        alerts_received = []

        def alert_handler(alert_type: str, message: str, status) -> None:
            alerts_received.append((alert_type, message))

        cost_controller.add_alert_handler(alert_handler)
        cost_controller.remove_alert_handler(alert_handler)

        # Record usage to reach warning threshold
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=850,
            cost_usd=0.0,
        )

        assert len(alerts_received) == 0

    @pytest.mark.asyncio
    async def test_alert_deduplication(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test alerts are only sent once per budget (no spam)."""
        alerts_received = []

        def alert_handler(alert_type: str, message: str, status) -> None:
            alerts_received.append((alert_type, message))

        cost_controller.add_alert_handler(alert_handler)

        # Climb to the warning level in small steps.
        # Session budget is 1000 with 80% threshold = 800 tokens.
        # The running total is 765 after nine calls and 850 after the tenth,
        # so only the final call of this loop is at warning level.
        for _ in range(10):
            await cost_controller.record_usage(
                agent_id="test-agent",
                session_id="test-session",
                tokens=85,  # Each call adds 85 tokens
                cost_usd=0.0,
            )

        # Keep recording while already above the threshold (850 -> 900,
        # still below the 1000-token limit). Without these calls the test
        # would pass even if every warning-level call emitted an alert.
        for _ in range(5):
            await cost_controller.record_usage(
                agent_id="test-agent",
                session_id="test-session",
                tokens=10,
                cost_usd=0.0,
            )

        # Should only receive ONE session warning (daily budget of 5000
        # isn't near its threshold). The key point is that the six
        # warning-level calls produce a single alert, not six.
        assert len(alerts_received) == 1
        assert alerts_received[0][0] == "warning"
        assert "Session" in alerts_received[0][1]
|