Files
syndarix/mcp-servers/llm-gateway/tests/test_cost_tracking.py
Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00

406 lines
13 KiB
Python

"""
Tests for cost_tracking module.
"""
import asyncio
import fakeredis.aioredis
import pytest
from config import Settings
from cost_tracking import (
CostTracker,
calculate_cost,
close_cost_tracker,
get_cost_tracker,
reset_cost_tracker,
)
@pytest.fixture
def tracker_settings() -> Settings:
    """Build the ``Settings`` object shared by the cost-tracker tests.

    Cost tracking is enabled, the alert threshold is 100.0 and the default
    budget limit is 1000.0; ``redis_prefix`` is ``test_llm_gateway``.
    """
    config = Settings(
        redis_url="redis://localhost:6379/0",
        redis_prefix="test_llm_gateway",
        cost_tracking_enabled=True,
        cost_alert_threshold=100.0,
        default_budget_limit=1000.0,
    )
    return config
@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
    """Provide a fakeredis async client created with ``decode_responses=True``."""
    client = fakeredis.aioredis.FakeRedis(decode_responses=True)
    return client
@pytest.fixture
def tracker(
    fake_redis: fakeredis.aioredis.FakeRedis,
    tracker_settings: Settings,
) -> CostTracker:
    """Wire a ``CostTracker`` to the ``fake_redis`` and ``tracker_settings`` fixtures."""
    cost_tracker = CostTracker(
        redis_client=fake_redis,
        settings=tracker_settings,
    )
    return cost_tracker
class TestCalculateCost:
    """Exercise the ``calculate_cost`` function."""

    def test_calculate_cost_known_model(self) -> None:
        """A known model is priced from its input and output token counts."""
        # Pricing this test relies on for claude-opus-4:
        # $15 per 1M input tokens, $75 per 1M output tokens.
        # 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        expected = pytest.approx(0.0525, rel=0.001)

        actual = calculate_cost(
            model="claude-opus-4",
            prompt_tokens=1000,
            completion_tokens=500,
        )

        assert actual == expected

    def test_calculate_cost_unknown_model(self) -> None:
        """A model name that is not recognised costs exactly zero."""
        actual = calculate_cost(
            model="unknown-model",
            prompt_tokens=1000,
            completion_tokens=500,
        )

        assert actual == 0.0

    def test_calculate_cost_zero_tokens(self) -> None:
        """No tokens in either direction costs exactly zero."""
        actual = calculate_cost(
            model="claude-opus-4",
            prompt_tokens=0,
            completion_tokens=0,
        )

        assert actual == 0.0

    def test_calculate_cost_large_token_counts(self) -> None:
        """Token counts in the millions are priced at the same rates."""
        # 1M * 15/1M + 500K * 75/1M = 15 + 37.5 = 52.5
        expected = pytest.approx(52.5, rel=0.001)

        actual = calculate_cost(
            model="claude-opus-4",
            prompt_tokens=1_000_000,
            completion_tokens=500_000,
        )

        assert actual == expected
class TestCostTracker:
    """Tests for CostTracker class."""

    @staticmethod
    def _record(
        tracker: CostTracker,
        *,
        model: str = "claude-opus-4",
        prompt_tokens: int = 100,
        completion_tokens: int = 50,
        cost_usd: float = 0.01,
    ) -> None:
        """Record one usage entry for project ``proj-123`` / agent ``agent-456``.

        Shared by the tests below so each one only spells out the values it
        actually cares about. No ``session_id`` is passed; the session test
        calls ``record_usage`` directly instead.

        Args:
            tracker: Tracker instance to record against.
            model: Model name stored with the entry.
            prompt_tokens: Prompt token count stored with the entry.
            completion_tokens: Completion token count stored with the entry.
            cost_usd: Cost in USD stored with the entry.
        """
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                cost_usd=cost_usd,
            )
        )

    def test_record_usage(self, tracker: CostTracker) -> None:
        """Test recording usage."""
        self._record(tracker)

        # Verify by getting usage report
        report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
        assert report.total_requests == 1
        assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)

    def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
        """Test recording is skipped when disabled."""
        # Same settings as the fixture, with only the tracking flag flipped off.
        settings = Settings(
            **{
                **tracker_settings.model_dump(),
                "cost_tracking_enabled": False,
            }
        )
        fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
        disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)

        # This should not raise and should not record
        self._record(disabled_tracker)

        # Usage should be empty
        report = asyncio.run(
            disabled_tracker.get_project_usage("proj-123", period="day")
        )
        assert report.total_requests == 0

    def test_record_usage_with_session(self, tracker: CostTracker) -> None:
        """Test recording usage with session ID."""
        asyncio.run(
            tracker.record_usage(
                project_id="proj-123",
                agent_id="agent-456",
                model="claude-opus-4",
                prompt_tokens=100,
                completion_tokens=50,
                cost_usd=0.01,
                session_id="session-789",
            )
        )

        # Verify session usage
        session_usage = asyncio.run(tracker.get_session_usage("session-789"))
        assert session_usage["session_id"] == "session-789"
        assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)

    def test_get_project_usage_empty(self, tracker: CostTracker) -> None:
        """Test getting usage for project with no data."""
        report = asyncio.run(
            tracker.get_project_usage("nonexistent-project", period="day")
        )

        assert report.entity_id == "nonexistent-project"
        assert report.entity_type == "project"
        assert report.total_requests == 0
        assert report.total_cost_usd == 0.0

    def test_get_project_usage_multiple_models(self, tracker: CostTracker) -> None:
        """Test usage tracking across multiple models."""
        # Record usage for different models
        self._record(tracker)
        self._record(
            tracker,
            model="gpt-4.1",
            prompt_tokens=200,
            completion_tokens=100,
            cost_usd=0.02,
        )

        report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
        assert report.total_requests == 2
        assert len(report.by_model) == 2
        assert "claude-opus-4" in report.by_model
        assert "gpt-4.1" in report.by_model

    def test_get_agent_usage(self, tracker: CostTracker) -> None:
        """Test getting agent usage."""
        self._record(tracker)

        report = asyncio.run(tracker.get_agent_usage("agent-456", period="day"))
        assert report.entity_id == "agent-456"
        assert report.entity_type == "agent"
        assert report.total_requests == 1

    def test_usage_periods(self, tracker: CostTracker) -> None:
        """Test different usage periods."""
        self._record(tracker)

        # Check different periods
        hour_report = asyncio.run(tracker.get_project_usage("proj-123", period="hour"))
        day_report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
        month_report = asyncio.run(
            tracker.get_project_usage("proj-123", period="month")
        )

        assert hour_report.period == "hour"
        assert day_report.period == "day"
        assert month_report.period == "month"

        # All should have the same data
        assert hour_report.total_requests == 1
        assert day_report.total_requests == 1
        assert month_report.total_requests == 1

    def test_check_budget_within(self, tracker: CostTracker) -> None:
        """Test budget check when within limit."""
        # Record some usage
        self._record(
            tracker,
            prompt_tokens=1000,
            completion_tokens=500,
            cost_usd=50.0,
        )

        within, current, limit = asyncio.run(
            tracker.check_budget("proj-123", budget_limit=100.0)
        )
        assert within is True
        assert current == pytest.approx(50.0, rel=0.01)
        assert limit == 100.0

    def test_check_budget_exceeded(self, tracker: CostTracker) -> None:
        """Test budget check when exceeded."""
        self._record(
            tracker,
            prompt_tokens=1000,
            completion_tokens=500,
            cost_usd=150.0,
        )

        within, current, limit = asyncio.run(
            tracker.check_budget("proj-123", budget_limit=100.0)
        )
        assert within is False
        assert current >= limit

    def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
        """Test budget check with default limit."""
        # Only the limit matters here; the other two tuple members are ignored.
        _, _, limit = asyncio.run(tracker.check_budget("proj-123"))

        assert limit == 1000.0  # Default from settings

    def test_estimate_request_cost_known_model(self, tracker: CostTracker) -> None:
        """Test estimating cost for known model."""
        cost = asyncio.run(
            tracker.estimate_request_cost(
                model="claude-opus-4",
                prompt_tokens=1000,
                max_completion_tokens=500,
            )
        )

        # 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        assert cost == pytest.approx(0.0525, rel=0.01)

    def test_estimate_request_cost_unknown_model(self, tracker: CostTracker) -> None:
        """Test estimating cost for unknown model."""
        cost = asyncio.run(
            tracker.estimate_request_cost(
                model="unknown-model",
                prompt_tokens=1000,
                max_completion_tokens=500,
            )
        )

        # Uses fallback estimate
        assert cost > 0

    def test_should_alert_below_threshold(self, tracker: CostTracker) -> None:
        """Test alert check when below threshold."""
        self._record(
            tracker,
            prompt_tokens=1000,
            completion_tokens=500,
            cost_usd=50.0,
        )

        should_alert, current = asyncio.run(
            tracker.should_alert("proj-123", threshold=100.0)
        )
        assert should_alert is False
        assert current == pytest.approx(50.0, rel=0.01)

    def test_should_alert_above_threshold(self, tracker: CostTracker) -> None:
        """Test alert check when above threshold."""
        self._record(
            tracker,
            prompt_tokens=1000,
            completion_tokens=500,
            cost_usd=150.0,
        )

        # The current-cost half of the result is not asserted in this test.
        should_alert, _ = asyncio.run(
            tracker.should_alert("proj-123", threshold=100.0)
        )
        assert should_alert is True

    def test_close(self, tracker: CostTracker) -> None:
        """Test closing tracker."""
        asyncio.run(tracker.close())

        # Reaches into the private attribute to confirm the client was dropped.
        assert tracker._redis is None
class TestGlobalTracker:
    """Tests for global tracker functions."""

    def setup_method(self) -> None:
        """Drop any existing global tracker so each test starts clean."""
        reset_cost_tracker()

    def teardown_method(self) -> None:
        """Drop the global tracker again so it cannot leak into later tests."""
        reset_cost_tracker()

    def test_get_cost_tracker(self) -> None:
        """Test getting global tracker."""
        tracker = get_cost_tracker()

        assert isinstance(tracker, CostTracker)

    def test_get_cost_tracker_singleton(self) -> None:
        """Test tracker is singleton."""
        tracker1 = get_cost_tracker()
        tracker2 = get_cost_tracker()

        assert tracker1 is tracker2

    def test_reset_cost_tracker(self) -> None:
        """Test resetting global tracker."""
        tracker1 = get_cost_tracker()
        reset_cost_tracker()
        tracker2 = get_cost_tracker()

        assert tracker1 is not tracker2

    def test_close_cost_tracker(self) -> None:
        """Test closing global tracker."""
        _ = get_cost_tracker()
        asyncio.run(close_cost_tracker())

        # Getting again should create a new one
        # NOTE(review): this only checks that a tracker is returned, not that
        # it is a different instance from the closed one - confirm intent.
        tracker2 = get_cost_tracker()
        assert tracker2 is not None