feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking

- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
2026-01-03 17:55:34 +01:00
parent 520c06175e
commit caf283bed2
9 changed files with 1782 additions and 92 deletions

View File

@@ -0,0 +1,405 @@
"""Tests for rate limiter module."""
import pytest
from app.services.safety.exceptions import RateLimitExceededError
from app.services.safety.limits.limiter import (
RateLimiter,
SlidingWindowCounter,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
RateLimitConfig,
)
@pytest.fixture
def sliding_counter() -> SlidingWindowCounter:
"""Create a sliding window counter for testing."""
return SlidingWindowCounter(
limit=5,
window_seconds=60,
burst_limit=3,
)
@pytest.fixture
def rate_limiter() -> RateLimiter:
"""Create a rate limiter for testing."""
limiter = RateLimiter()
# Configure a test limit
limiter.configure(
RateLimitConfig(
name="test_limit",
limit=5,
window_seconds=60,
burst_limit=3,
)
)
return limiter
@pytest.fixture
def sample_metadata() -> ActionMetadata:
"""Create sample action metadata."""
return ActionMetadata(
agent_id="test-agent",
session_id="test-session",
)
def create_action(
metadata: ActionMetadata,
action_type: ActionType = ActionType.LLM_CALL,
) -> ActionRequest:
"""Helper to create test actions."""
return ActionRequest(
action_type=action_type,
tool_name="test_tool",
resource="test-resource",
arguments={},
metadata=metadata,
)
class TestSlidingWindowCounter:
"""Tests for SlidingWindowCounter class."""
@pytest.mark.asyncio
async def test_first_acquire_allowed(
self,
sliding_counter: SlidingWindowCounter,
) -> None:
"""Test first acquire is always allowed."""
allowed, retry_after = await sliding_counter.try_acquire()
assert allowed is True
assert retry_after == 0.0
@pytest.mark.asyncio
async def test_burst_limit(
self,
sliding_counter: SlidingWindowCounter,
) -> None:
"""Test burst limit is enforced."""
# Acquire up to burst limit (3)
for _ in range(3):
allowed, _ = await sliding_counter.try_acquire()
assert allowed is True
# Next should be denied (burst exceeded)
allowed, retry_after = await sliding_counter.try_acquire()
assert allowed is False
assert retry_after > 0
@pytest.mark.asyncio
async def test_get_status(
self,
sliding_counter: SlidingWindowCounter,
) -> None:
"""Test getting counter status."""
# Make some requests
await sliding_counter.try_acquire()
await sliding_counter.try_acquire()
current, remaining, reset_in = await sliding_counter.get_status()
assert current == 2
assert remaining == 3 # 5 - 2
assert reset_in >= 0
class TestRateLimiter:
"""Tests for RateLimiter class."""
@pytest.mark.asyncio
async def test_check_status(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test checking rate limit status."""
status = await rate_limiter.check("test_limit", "test-key")
assert status.name == "test_limit"
assert status.current_count == 0
assert status.limit == 5
assert status.remaining == 5
assert status.is_limited is False
@pytest.mark.asyncio
async def test_acquire_success(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test successful acquire."""
allowed, status = await rate_limiter.acquire("test_limit", "test-key")
assert allowed is True
assert status.current_count == 1
assert status.remaining == 4
@pytest.mark.asyncio
async def test_acquire_burst_exceeded(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test acquire fails when burst exceeded."""
# Acquire up to burst limit
for _ in range(3):
allowed, _ = await rate_limiter.acquire("test_limit", "test-key")
assert allowed is True
# Next should fail
allowed, status = await rate_limiter.acquire("test_limit", "test-key")
assert allowed is False
assert status.is_limited is True
assert status.retry_after_seconds > 0
@pytest.mark.asyncio
async def test_require_success(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test require passes when not limited."""
# Should not raise
await rate_limiter.require("test_limit", "test-key")
@pytest.mark.asyncio
async def test_require_raises(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test require raises when limited."""
# Use up burst limit
for _ in range(3):
await rate_limiter.acquire("test_limit", "test-key")
with pytest.raises(RateLimitExceededError) as exc_info:
await rate_limiter.require("test_limit", "test-key")
assert exc_info.value.limit_type == "test_limit"
assert exc_info.value.retry_after_seconds > 0
@pytest.mark.asyncio
async def test_check_action_allowed(
self,
rate_limiter: RateLimiter,
sample_metadata: ActionMetadata,
) -> None:
"""Test checking action is allowed."""
action = create_action(sample_metadata)
allowed, statuses = await rate_limiter.check_action(action)
assert allowed is True
assert len(statuses) >= 1 # At least "actions" limit
@pytest.mark.asyncio
async def test_check_action_llm_limits(
self,
rate_limiter: RateLimiter,
sample_metadata: ActionMetadata,
) -> None:
"""Test LLM actions check LLM-specific limits."""
action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
allowed, statuses = await rate_limiter.check_action(action)
assert allowed is True
# Should have checked both "actions" and "llm_calls"
limit_names = [s.name for s in statuses]
assert "actions" in limit_names
assert "llm_calls" in limit_names
@pytest.mark.asyncio
async def test_check_action_file_limits(
self,
rate_limiter: RateLimiter,
sample_metadata: ActionMetadata,
) -> None:
"""Test file actions check file-specific limits."""
action = create_action(sample_metadata, action_type=ActionType.FILE_READ)
allowed, statuses = await rate_limiter.check_action(action)
assert allowed is True
# Should have checked both "actions" and "file_ops"
limit_names = [s.name for s in statuses]
assert "actions" in limit_names
assert "file_ops" in limit_names
@pytest.mark.asyncio
async def test_check_action_does_not_consume_slot(
self,
rate_limiter: RateLimiter,
sample_metadata: ActionMetadata,
) -> None:
"""Test check_action only checks without consuming slots."""
action = create_action(sample_metadata)
# Check multiple times - should never consume
for _ in range(10):
allowed, _ = await rate_limiter.check_action(action)
assert allowed is True
# Verify no slots were consumed
status = await rate_limiter.check("actions", sample_metadata.agent_id)
assert status.current_count == 0
@pytest.mark.asyncio
async def test_record_action_consumes_slot(
self,
rate_limiter: RateLimiter,
sample_metadata: ActionMetadata,
) -> None:
"""Test record_action consumes rate limit slots."""
action = create_action(sample_metadata)
# Record the action
await rate_limiter.record_action(action)
# Verify slot was consumed
status = await rate_limiter.check("actions", sample_metadata.agent_id)
assert status.current_count == 1
@pytest.mark.asyncio
async def test_record_action_consumes_type_specific_slots(
self,
rate_limiter: RateLimiter,
sample_metadata: ActionMetadata,
) -> None:
"""Test record_action consumes type-specific slots."""
# LLM action
llm_action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
await rate_limiter.record_action(llm_action)
statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
assert statuses["actions"].current_count == 1
assert statuses["llm_calls"].current_count == 1
assert statuses["file_ops"].current_count == 0
# File action
file_action = create_action(sample_metadata, action_type=ActionType.FILE_READ)
await rate_limiter.record_action(file_action)
statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
assert statuses["actions"].current_count == 2
assert statuses["llm_calls"].current_count == 1
assert statuses["file_ops"].current_count == 1
@pytest.mark.asyncio
async def test_get_all_statuses(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test getting all rate limit statuses."""
# Make some requests
await rate_limiter.acquire("actions", "test-key")
await rate_limiter.acquire("llm_calls", "test-key")
statuses = await rate_limiter.get_all_statuses("test-key")
assert "actions" in statuses
assert "llm_calls" in statuses
assert "file_ops" in statuses
assert statuses["actions"].current_count >= 1
assert statuses["llm_calls"].current_count >= 1
@pytest.mark.asyncio
async def test_reset_single(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test resetting a single rate limit."""
# Make some requests
await rate_limiter.acquire("test_limit", "test-key")
await rate_limiter.acquire("test_limit", "test-key")
# Reset
result = await rate_limiter.reset("test_limit", "test-key")
assert result is True
# Check it's reset
status = await rate_limiter.check("test_limit", "test-key")
assert status.current_count == 0
@pytest.mark.asyncio
async def test_reset_nonexistent(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test resetting non-existent limit returns False."""
result = await rate_limiter.reset("nonexistent", "test-key")
assert result is False
@pytest.mark.asyncio
async def test_reset_all(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test resetting all rate limits for a key."""
# Make requests across multiple limits
await rate_limiter.acquire("actions", "test-key")
await rate_limiter.acquire("llm_calls", "test-key")
await rate_limiter.acquire("file_ops", "test-key")
# Reset all
count = await rate_limiter.reset_all("test-key")
assert count >= 3
# Check they're reset
statuses = await rate_limiter.get_all_statuses("test-key")
for status in statuses.values():
assert status.current_count == 0
@pytest.mark.asyncio
async def test_per_key_isolation(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test rate limits are isolated per key."""
# Use up burst limit for key-1
for _ in range(3):
await rate_limiter.acquire("test_limit", "key-1")
# key-1 should be limited
allowed1, _ = await rate_limiter.acquire("test_limit", "key-1")
assert allowed1 is False
# key-2 should still be allowed
allowed2, _ = await rate_limiter.acquire("test_limit", "key-2")
assert allowed2 is True
@pytest.mark.asyncio
async def test_configure_custom_limit(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test configuring custom rate limits."""
rate_limiter.configure(
RateLimitConfig(
name="custom",
limit=100,
window_seconds=120,
burst_limit=50,
)
)
status = await rate_limiter.check("custom", "test-key")
assert status.limit == 100
assert status.window_seconds == 120
@pytest.mark.asyncio
async def test_default_limit_fallback(
self,
rate_limiter: RateLimiter,
) -> None:
"""Test fallback to default limit for unknown limit names."""
# Request limit that doesn't exist
status = await rate_limiter.check("unknown_limit", "test-key")
# Should use default (60/60s)
assert status.limit == 60
assert status.window_seconds == 60