feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation. - Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming. - Refactored `CostController`'s session and daily budget alert handling for improved clarity. - Implemented test suites for `CostController` and `SafetyGuardian` to validate changes. - Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
405
backend/tests/services/safety/test_limits.py
Normal file
405
backend/tests/services/safety/test_limits.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""Tests for rate limiter module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.exceptions import RateLimitExceededError
|
||||
from app.services.safety.limits.limiter import (
|
||||
RateLimiter,
|
||||
SlidingWindowCounter,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
RateLimitConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sliding_counter() -> SlidingWindowCounter:
|
||||
"""Create a sliding window counter for testing."""
|
||||
return SlidingWindowCounter(
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
burst_limit=3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limiter() -> RateLimiter:
|
||||
"""Create a rate limiter for testing."""
|
||||
limiter = RateLimiter()
|
||||
# Configure a test limit
|
||||
limiter.configure(
|
||||
RateLimitConfig(
|
||||
name="test_limit",
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
burst_limit=3,
|
||||
)
|
||||
)
|
||||
return limiter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_metadata() -> ActionMetadata:
|
||||
"""Create sample action metadata."""
|
||||
return ActionMetadata(
|
||||
agent_id="test-agent",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
|
||||
def create_action(
|
||||
metadata: ActionMetadata,
|
||||
action_type: ActionType = ActionType.LLM_CALL,
|
||||
) -> ActionRequest:
|
||||
"""Helper to create test actions."""
|
||||
return ActionRequest(
|
||||
action_type=action_type,
|
||||
tool_name="test_tool",
|
||||
resource="test-resource",
|
||||
arguments={},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
|
||||
"""Tests for SlidingWindowCounter class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_acquire_allowed(
|
||||
self,
|
||||
sliding_counter: SlidingWindowCounter,
|
||||
) -> None:
|
||||
"""Test first acquire is always allowed."""
|
||||
allowed, retry_after = await sliding_counter.try_acquire()
|
||||
|
||||
assert allowed is True
|
||||
assert retry_after == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_burst_limit(
|
||||
self,
|
||||
sliding_counter: SlidingWindowCounter,
|
||||
) -> None:
|
||||
"""Test burst limit is enforced."""
|
||||
# Acquire up to burst limit (3)
|
||||
for _ in range(3):
|
||||
allowed, _ = await sliding_counter.try_acquire()
|
||||
assert allowed is True
|
||||
|
||||
# Next should be denied (burst exceeded)
|
||||
allowed, retry_after = await sliding_counter.try_acquire()
|
||||
assert allowed is False
|
||||
assert retry_after > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_status(
|
||||
self,
|
||||
sliding_counter: SlidingWindowCounter,
|
||||
) -> None:
|
||||
"""Test getting counter status."""
|
||||
# Make some requests
|
||||
await sliding_counter.try_acquire()
|
||||
await sliding_counter.try_acquire()
|
||||
|
||||
current, remaining, reset_in = await sliding_counter.get_status()
|
||||
|
||||
assert current == 2
|
||||
assert remaining == 3 # 5 - 2
|
||||
assert reset_in >= 0
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Tests for RateLimiter class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_status(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test checking rate limit status."""
|
||||
status = await rate_limiter.check("test_limit", "test-key")
|
||||
|
||||
assert status.name == "test_limit"
|
||||
assert status.current_count == 0
|
||||
assert status.limit == 5
|
||||
assert status.remaining == 5
|
||||
assert status.is_limited is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_success(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test successful acquire."""
|
||||
allowed, status = await rate_limiter.acquire("test_limit", "test-key")
|
||||
|
||||
assert allowed is True
|
||||
assert status.current_count == 1
|
||||
assert status.remaining == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_burst_exceeded(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test acquire fails when burst exceeded."""
|
||||
# Acquire up to burst limit
|
||||
for _ in range(3):
|
||||
allowed, _ = await rate_limiter.acquire("test_limit", "test-key")
|
||||
assert allowed is True
|
||||
|
||||
# Next should fail
|
||||
allowed, status = await rate_limiter.acquire("test_limit", "test-key")
|
||||
assert allowed is False
|
||||
assert status.is_limited is True
|
||||
assert status.retry_after_seconds > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_success(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test require passes when not limited."""
|
||||
# Should not raise
|
||||
await rate_limiter.require("test_limit", "test-key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_raises(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test require raises when limited."""
|
||||
# Use up burst limit
|
||||
for _ in range(3):
|
||||
await rate_limiter.acquire("test_limit", "test-key")
|
||||
|
||||
with pytest.raises(RateLimitExceededError) as exc_info:
|
||||
await rate_limiter.require("test_limit", "test-key")
|
||||
|
||||
assert exc_info.value.limit_type == "test_limit"
|
||||
assert exc_info.value.retry_after_seconds > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_action_allowed(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test checking action is allowed."""
|
||||
action = create_action(sample_metadata)
|
||||
|
||||
allowed, statuses = await rate_limiter.check_action(action)
|
||||
|
||||
assert allowed is True
|
||||
assert len(statuses) >= 1 # At least "actions" limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_action_llm_limits(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test LLM actions check LLM-specific limits."""
|
||||
action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
|
||||
|
||||
allowed, statuses = await rate_limiter.check_action(action)
|
||||
|
||||
assert allowed is True
|
||||
# Should have checked both "actions" and "llm_calls"
|
||||
limit_names = [s.name for s in statuses]
|
||||
assert "actions" in limit_names
|
||||
assert "llm_calls" in limit_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_action_file_limits(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test file actions check file-specific limits."""
|
||||
action = create_action(sample_metadata, action_type=ActionType.FILE_READ)
|
||||
|
||||
allowed, statuses = await rate_limiter.check_action(action)
|
||||
|
||||
assert allowed is True
|
||||
# Should have checked both "actions" and "file_ops"
|
||||
limit_names = [s.name for s in statuses]
|
||||
assert "actions" in limit_names
|
||||
assert "file_ops" in limit_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_action_does_not_consume_slot(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test check_action only checks without consuming slots."""
|
||||
action = create_action(sample_metadata)
|
||||
|
||||
# Check multiple times - should never consume
|
||||
for _ in range(10):
|
||||
allowed, _ = await rate_limiter.check_action(action)
|
||||
assert allowed is True
|
||||
|
||||
# Verify no slots were consumed
|
||||
status = await rate_limiter.check("actions", sample_metadata.agent_id)
|
||||
assert status.current_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_action_consumes_slot(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test record_action consumes rate limit slots."""
|
||||
action = create_action(sample_metadata)
|
||||
|
||||
# Record the action
|
||||
await rate_limiter.record_action(action)
|
||||
|
||||
# Verify slot was consumed
|
||||
status = await rate_limiter.check("actions", sample_metadata.agent_id)
|
||||
assert status.current_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_action_consumes_type_specific_slots(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test record_action consumes type-specific slots."""
|
||||
# LLM action
|
||||
llm_action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
|
||||
await rate_limiter.record_action(llm_action)
|
||||
|
||||
statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
|
||||
assert statuses["actions"].current_count == 1
|
||||
assert statuses["llm_calls"].current_count == 1
|
||||
assert statuses["file_ops"].current_count == 0
|
||||
|
||||
# File action
|
||||
file_action = create_action(sample_metadata, action_type=ActionType.FILE_READ)
|
||||
await rate_limiter.record_action(file_action)
|
||||
|
||||
statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
|
||||
assert statuses["actions"].current_count == 2
|
||||
assert statuses["llm_calls"].current_count == 1
|
||||
assert statuses["file_ops"].current_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_statuses(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test getting all rate limit statuses."""
|
||||
# Make some requests
|
||||
await rate_limiter.acquire("actions", "test-key")
|
||||
await rate_limiter.acquire("llm_calls", "test-key")
|
||||
|
||||
statuses = await rate_limiter.get_all_statuses("test-key")
|
||||
|
||||
assert "actions" in statuses
|
||||
assert "llm_calls" in statuses
|
||||
assert "file_ops" in statuses
|
||||
assert statuses["actions"].current_count >= 1
|
||||
assert statuses["llm_calls"].current_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_single(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test resetting a single rate limit."""
|
||||
# Make some requests
|
||||
await rate_limiter.acquire("test_limit", "test-key")
|
||||
await rate_limiter.acquire("test_limit", "test-key")
|
||||
|
||||
# Reset
|
||||
result = await rate_limiter.reset("test_limit", "test-key")
|
||||
assert result is True
|
||||
|
||||
# Check it's reset
|
||||
status = await rate_limiter.check("test_limit", "test-key")
|
||||
assert status.current_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_nonexistent(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test resetting non-existent limit returns False."""
|
||||
result = await rate_limiter.reset("nonexistent", "test-key")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_all(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test resetting all rate limits for a key."""
|
||||
# Make requests across multiple limits
|
||||
await rate_limiter.acquire("actions", "test-key")
|
||||
await rate_limiter.acquire("llm_calls", "test-key")
|
||||
await rate_limiter.acquire("file_ops", "test-key")
|
||||
|
||||
# Reset all
|
||||
count = await rate_limiter.reset_all("test-key")
|
||||
assert count >= 3
|
||||
|
||||
# Check they're reset
|
||||
statuses = await rate_limiter.get_all_statuses("test-key")
|
||||
for status in statuses.values():
|
||||
assert status.current_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_key_isolation(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test rate limits are isolated per key."""
|
||||
# Use up burst limit for key-1
|
||||
for _ in range(3):
|
||||
await rate_limiter.acquire("test_limit", "key-1")
|
||||
|
||||
# key-1 should be limited
|
||||
allowed1, _ = await rate_limiter.acquire("test_limit", "key-1")
|
||||
assert allowed1 is False
|
||||
|
||||
# key-2 should still be allowed
|
||||
allowed2, _ = await rate_limiter.acquire("test_limit", "key-2")
|
||||
assert allowed2 is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_custom_limit(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test configuring custom rate limits."""
|
||||
rate_limiter.configure(
|
||||
RateLimitConfig(
|
||||
name="custom",
|
||||
limit=100,
|
||||
window_seconds=120,
|
||||
burst_limit=50,
|
||||
)
|
||||
)
|
||||
|
||||
status = await rate_limiter.check("custom", "test-key")
|
||||
assert status.limit == 100
|
||||
assert status.window_seconds == 120
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_limit_fallback(
|
||||
self,
|
||||
rate_limiter: RateLimiter,
|
||||
) -> None:
|
||||
"""Test fallback to default limit for unknown limit names."""
|
||||
# Request limit that doesn't exist
|
||||
status = await rate_limiter.check("unknown_limit", "test-key")
|
||||
|
||||
# Should use default (60/60s)
|
||||
assert status.limit == 60
|
||||
assert status.window_seconds == 60
|
||||
Reference in New Issue
Block a user