feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking

- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
2026-01-03 17:55:34 +01:00
parent 520c06175e
commit caf283bed2
9 changed files with 1782 additions and 92 deletions

View File

@@ -0,0 +1,436 @@
"""Tests for cost controller module."""
import pytest
from app.services.safety.costs.controller import (
BudgetTracker,
CostController,
)
from app.services.safety.exceptions import BudgetExceededError
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
BudgetScope,
)
@pytest.fixture
def budget_tracker() -> BudgetTracker:
    """Provide a session-scoped tracker limited to 1000 tokens and $10.

    The tracker is built with ``warning_threshold=0.8``.
    """
    tracker = BudgetTracker(
        scope=BudgetScope.SESSION,
        scope_id="test-session",
        tokens_limit=1000,
        cost_limit_usd=10.0,
        warning_threshold=0.8,
    )
    return tracker
@pytest.fixture
def cost_controller() -> CostController:
    """Provide a controller with small default session and daily budgets."""
    # Limits are deliberately tiny so tests can hit them with a few calls.
    default_limits = {
        "default_session_tokens": 1000,
        "default_session_cost_usd": 10.0,
        "default_daily_tokens": 5000,
        "default_daily_cost_usd": 50.0,
    }
    return CostController(**default_limits)
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide metadata identifying the test agent and its session."""
    metadata = ActionMetadata(agent_id="test-agent", session_id="test-session")
    return metadata
def create_action(
    metadata: ActionMetadata,
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an LLM-call ``ActionRequest`` carrying the given cost estimates.

    Args:
        metadata: Agent/session metadata attached to the request.
        estimated_tokens: Value stored as ``estimated_cost_tokens``.
        estimated_cost: Value stored as ``estimated_cost_usd``.

    Returns:
        A request for ``test_tool`` on ``test-resource`` with no arguments.
    """
    request = ActionRequest(
        action_type=ActionType.LLM_CALL,
        tool_name="test_tool",
        resource="test-resource",
        arguments={},
        metadata=metadata,
        estimated_cost_tokens=estimated_tokens,
        estimated_cost_usd=estimated_cost,
    )
    return request
class TestBudgetTracker:
    """Unit tests for BudgetTracker accounting and threshold flags."""

    @pytest.mark.asyncio
    async def test_initial_status(self, budget_tracker: BudgetTracker) -> None:
        """A fresh tracker reports zero usage and the full limits remaining."""
        snapshot = await budget_tracker.get_status()

        # Nothing has been consumed yet.
        assert snapshot.tokens_used == 0
        assert snapshot.cost_used_usd == 0.0
        # Remaining amounts equal the configured limits.
        assert snapshot.tokens_remaining == 1000
        assert snapshot.cost_remaining_usd == 10.0
        # Neither flag is raised.
        assert snapshot.is_warning is False
        assert snapshot.is_exceeded is False

    @pytest.mark.asyncio
    async def test_add_usage(self, budget_tracker: BudgetTracker) -> None:
        """Recorded usage shows up in both the used and remaining counters."""
        await budget_tracker.add_usage(tokens=100, cost_usd=1.0)

        snapshot = await budget_tracker.get_status()
        assert snapshot.tokens_used == 100
        assert snapshot.cost_used_usd == 1.0
        assert snapshot.tokens_remaining == 900
        assert snapshot.cost_remaining_usd == 9.0

    @pytest.mark.asyncio
    async def test_warning_threshold(self, budget_tracker: BudgetTracker) -> None:
        """Reaching 80% of the token limit raises the warning flag only."""
        # 800 of 1000 tokens sits exactly on the 0.8 threshold.
        await budget_tracker.add_usage(tokens=800, cost_usd=1.0)

        snapshot = await budget_tracker.get_status()
        assert snapshot.is_warning is True
        assert snapshot.is_exceeded is False

    @pytest.mark.asyncio
    async def test_budget_exceeded(self, budget_tracker: BudgetTracker) -> None:
        """Going past the token limit raises the exceeded flag."""
        # 1100 tokens is above the 1000-token limit.
        await budget_tracker.add_usage(tokens=1100, cost_usd=1.0)

        snapshot = await budget_tracker.get_status()
        assert snapshot.is_exceeded is True

    @pytest.mark.asyncio
    async def test_check_budget_allows(self, budget_tracker: BudgetTracker) -> None:
        """An estimate that fits within the budget is accepted."""
        fits = await budget_tracker.check_budget(
            estimated_tokens=500,
            estimated_cost_usd=5.0,
        )
        assert fits is True

    @pytest.mark.asyncio
    async def test_check_budget_denies(self, budget_tracker: BudgetTracker) -> None:
        """An estimate that would overshoot the budget is rejected."""
        # Consume 80% of both limits first.
        await budget_tracker.add_usage(tokens=800, cost_usd=8.0)

        # 800 + 300 tokens and 8.0 + 3.0 USD both exceed the limits.
        fits = await budget_tracker.check_budget(
            estimated_tokens=300,
            estimated_cost_usd=3.0,
        )
        assert fits is False

    @pytest.mark.asyncio
    async def test_reset(self, budget_tracker: BudgetTracker) -> None:
        """A manual reset zeroes the usage counters."""
        await budget_tracker.add_usage(tokens=500, cost_usd=5.0)
        await budget_tracker.reset()

        snapshot = await budget_tracker.get_status()
        assert snapshot.tokens_used == 0
        assert snapshot.cost_used_usd == 0.0
class TestCostController:
    """Tests for CostController budget checks, usage recording and alerts."""

    @pytest.mark.asyncio
    async def test_check_budget_success(
        self,
        cost_controller: CostController,
    ) -> None:
        """Budget check passes while budget is still available."""
        result = await cost_controller.check_budget(
            agent_id="test-agent",
            session_id="test-session",
            estimated_tokens=100,
            estimated_cost_usd=1.0,
        )
        assert result is True

    @pytest.mark.asyncio
    async def test_check_budget_session_exceeded(
        self,
        cost_controller: CostController,
    ) -> None:
        """Budget check fails when the estimate would exceed the session budget."""
        # Use 90% of the session budget (1000 tokens / $10).
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=900,
            cost_usd=9.0,
        )

        # 900 + 200 tokens would overshoot the session limit.
        result = await cost_controller.check_budget(
            agent_id="test-agent",
            session_id="test-session",
            estimated_tokens=200,
            estimated_cost_usd=2.0,
        )
        assert result is False

    @pytest.mark.asyncio
    async def test_check_budget_daily_exceeded(
        self,
        cost_controller: CostController,
    ) -> None:
        """Budget check fails when the estimate would exceed the daily budget."""
        # Use most of the daily budget (5000 tokens / $50) without a session.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id=None,
            tokens=4900,
            cost_usd=49.0,
        )

        # A brand-new session still shares the agent's daily budget.
        result = await cost_controller.check_budget(
            agent_id="test-agent",
            session_id="new-session",
            estimated_tokens=200,
            estimated_cost_usd=2.0,
        )
        assert result is False

    @pytest.mark.asyncio
    async def test_check_action(
        self,
        cost_controller: CostController,
        sample_metadata: ActionMetadata,
    ) -> None:
        """check_action accepts an action whose estimates fit the budget."""
        action = create_action(
            sample_metadata,
            estimated_tokens=100,
            estimated_cost=0.01,
        )
        result = await cost_controller.check_action(action)
        assert result is True

    @pytest.mark.asyncio
    async def test_require_budget_success(
        self,
        cost_controller: CostController,
    ) -> None:
        """require_budget returns normally when budget is available."""
        # The test passes as long as this call does not raise.
        await cost_controller.require_budget(
            agent_id="test-agent",
            session_id="test-session",
            estimated_tokens=100,
            estimated_cost_usd=1.0,
        )

    @pytest.mark.asyncio
    async def test_require_budget_raises(
        self,
        cost_controller: CostController,
    ) -> None:
        """require_budget raises BudgetExceededError once the session is spent."""
        # Consume the whole session budget.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=1000,
            cost_usd=10.0,
        )

        with pytest.raises(BudgetExceededError) as exc_info:
            await cost_controller.require_budget(
                agent_id="test-agent",
                session_id="test-session",
                estimated_tokens=100,
                estimated_cost_usd=1.0,
            )

        # Must be checked after the context manager exits; exc_info is only
        # populated at that point.
        assert "session" in exc_info.value.budget_type.lower()

    @pytest.mark.asyncio
    async def test_record_usage(
        self,
        cost_controller: CostController,
    ) -> None:
        """Recorded usage is reflected in both the session and daily trackers."""
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=100,
            cost_usd=1.0,
        )

        # Session budget is keyed by session id.
        session_status = await cost_controller.get_status(
            BudgetScope.SESSION, "test-session"
        )
        assert session_status is not None
        assert session_status.tokens_used == 100

        # Daily budget is keyed by agent id.
        daily_status = await cost_controller.get_status(BudgetScope.DAILY, "test-agent")
        assert daily_status is not None
        assert daily_status.tokens_used == 100

    @pytest.mark.asyncio
    async def test_get_all_statuses(
        self,
        cost_controller: CostController,
    ) -> None:
        """get_all_statuses reports the trackers created by recorded usage."""
        await cost_controller.record_usage(
            agent_id="agent-1",
            session_id="session-1",
            tokens=100,
            cost_usd=1.0,
        )
        await cost_controller.record_usage(
            agent_id="agent-2",
            session_id="session-2",
            tokens=200,
            cost_usd=2.0,
        )

        statuses = await cost_controller.get_all_statuses()
        assert len(statuses) >= 2

    @pytest.mark.asyncio
    async def test_set_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """set_budget installs custom limits for a scope/id pair."""
        await cost_controller.set_budget(
            scope=BudgetScope.SESSION,
            scope_id="custom-session",
            tokens_limit=5000,
            cost_limit_usd=50.0,
        )

        status = await cost_controller.get_status(BudgetScope.SESSION, "custom-session")
        assert status is not None
        assert status.tokens_limit == 5000
        assert status.cost_limit_usd == 50.0

    @pytest.mark.asyncio
    async def test_reset_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """reset_budget clears usage on an existing tracker and returns True."""
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=500,
            cost_usd=5.0,
        )

        result = await cost_controller.reset_budget(BudgetScope.SESSION, "test-session")
        assert result is True

        # The tracker still exists but its counters are back to zero.
        status = await cost_controller.get_status(BudgetScope.SESSION, "test-session")
        assert status is not None
        assert status.tokens_used == 0

    @pytest.mark.asyncio
    async def test_reset_nonexistent_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """reset_budget returns False for a budget that was never created."""
        result = await cost_controller.reset_budget(BudgetScope.SESSION, "nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_alert_handler(
        self,
        cost_controller: CostController,
    ) -> None:
        """A registered alert handler is called at the warning threshold."""
        alerts_received: list[tuple[str, str]] = []

        def alert_handler(alert_type: str, message: str, status: object) -> None:
            alerts_received.append((alert_type, message))

        cost_controller.add_alert_handler(alert_handler)

        # Record usage to reach warning threshold (80%)
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=850,  # 85% of 1000
            cost_usd=0.0,
        )

        assert len(alerts_received) > 0
        assert alerts_received[0][0] == "warning"

    @pytest.mark.asyncio
    async def test_remove_alert_handler(
        self,
        cost_controller: CostController,
    ) -> None:
        """A removed alert handler is no longer called."""
        alerts_received: list[tuple[str, str]] = []

        def alert_handler(alert_type: str, message: str, status: object) -> None:
            alerts_received.append((alert_type, message))

        cost_controller.add_alert_handler(alert_handler)
        cost_controller.remove_alert_handler(alert_handler)

        # Same usage that triggers a warning in test_alert_handler.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=850,
            cost_usd=0.0,
        )

        assert len(alerts_received) == 0

    @pytest.mark.asyncio
    async def test_alert_deduplication(
        self,
        cost_controller: CostController,
    ) -> None:
        """Test alerts are only sent once per budget (no spam)."""
        alerts_received: list[tuple[str, str]] = []

        def alert_handler(alert_type: str, message: str, status: object) -> None:
            alerts_received.append((alert_type, message))

        cost_controller.add_alert_handler(alert_handler)

        # Session budget is 1000 tokens with an 80% threshold = 800 tokens.
        # Calls 1-9 stay below the threshold (9 * 85 = 765). Calls 10 and 11
        # (850 and 935 tokens) are BOTH at warning level yet below the limit,
        # so without deduplication two warnings would be emitted. Stopping at
        # 10 calls would leave only one call in the warning zone and the test
        # would pass even if deduplication were broken.
        for _ in range(11):
            await cost_controller.record_usage(
                agent_id="test-agent",
                session_id="test-session",
                tokens=85,  # Each call adds 85 tokens
                cost_usd=0.0,
            )

        # Exactly ONE session warning is expected; the daily budget of 5000
        # tokens is nowhere near its threshold.
        assert len(alerts_received) == 1
        assert alerts_received[0][0] == "warning"
        assert "Session" in alerts_received[0][1]

View File

@@ -0,0 +1,508 @@
"""Tests for SafetyGuardian integration."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.services.safety.config import SafetyConfig
from app.services.safety.costs.controller import CostController
from app.services.safety.guardian import (
SafetyGuardian,
get_safety_guardian,
reset_safety_guardian,
shutdown_safety_guardian,
)
from app.services.safety.limits.limiter import RateLimiter
from app.services.safety.loops.detector import LoopDetector
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionResult,
ActionType,
AuditEvent,
AuditEventType,
AutonomyLevel,
BudgetScope,
SafetyDecision,
SafetyPolicy,
)
@pytest_asyncio.fixture
async def reset_guardian():
    """Reset the module-level guardian singleton around a test."""
    # Setup: start from a clean singleton.
    await reset_safety_guardian()
    yield
    # Teardown: do not leak this test's singleton into the next one.
    await reset_safety_guardian()
@pytest.fixture
def safety_config() -> SafetyConfig:
    """Provide an enabled, strict config with HITL and auto-checkpoint off."""
    config = SafetyConfig(
        enabled=True,
        strict_mode=True,
        hitl_enabled=False,
        auto_checkpoint_destructive=False,
    )
    return config
@pytest.fixture
def cost_controller() -> CostController:
    """Provide a controller with small default session and daily budgets."""
    # Limits are deliberately tiny so tests can exhaust them in one call.
    default_limits = {
        "default_session_tokens": 1000,
        "default_session_cost_usd": 10.0,
        "default_daily_tokens": 5000,
        "default_daily_cost_usd": 50.0,
    }
    return CostController(**default_limits)
@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Provide a rate limiter built with its constructor defaults."""
    limiter = RateLimiter()
    return limiter
@pytest.fixture
def loop_detector() -> LoopDetector:
    """Provide a loop detector with a short history and low repetition caps."""
    detector = LoopDetector(
        history_size=10,
        max_exact_repetitions=3,
        max_semantic_repetitions=5,
    )
    return detector
def _make_audit_event() -> AuditEvent:
    """Build an ``ACTION_REQUESTED`` audit event for the stubbed audit logger."""
    event = AuditEvent(
        event_type=AuditEventType.ACTION_REQUESTED,
        agent_id="test-agent",
        action_id="test-action",
    )
    return event
@pytest_asyncio.fixture
async def guardian(
    safety_config: SafetyConfig,
    cost_controller: CostController,
    rate_limiter: RateLimiter,
    loop_detector: LoopDetector,
) -> SafetyGuardian:
    """Provide an initialized SafetyGuardian wired to the fixture subsystems.

    The audit logger is replaced with a stub so no real logging happens.
    """
    instance = SafetyGuardian(
        config=safety_config,
        cost_controller=cost_controller,
        rate_limiter=rate_limiter,
        loop_detector=loop_detector,
    )

    # Stub the audit logger. The request-logging coroutines return real
    # AuditEvent objects rather than bare AsyncMock results.
    instance._audit_logger = MagicMock()
    audit_stub = instance._audit_logger
    audit_stub.log = AsyncMock(return_value=_make_audit_event())
    audit_stub.log_action_request = AsyncMock(return_value=_make_audit_event())
    audit_stub.log_action_executed = AsyncMock(return_value=None)

    await instance.initialize()
    return instance
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide metadata for the test agent at MILESTONE autonomy."""
    metadata = ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
        autonomy_level=AutonomyLevel.MILESTONE,
    )
    return metadata
def create_action(
    metadata: ActionMetadata,
    tool_name: str = "test_tool",
    action_type: ActionType = ActionType.LLM_CALL,
    resource: str = "/tmp/test.txt",  # noqa: S108
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an ``ActionRequest`` for guardian tests.

    Args:
        metadata: Agent/session metadata attached to the request.
        tool_name: Tool the action claims to invoke.
        action_type: Kind of action being requested.
        resource: Resource the action targets.
        estimated_tokens: Value stored as ``estimated_cost_tokens``.
        estimated_cost: Value stored as ``estimated_cost_usd``.

    Returns:
        The constructed request, always with empty ``arguments``.
    """
    request = ActionRequest(
        action_type=action_type,
        tool_name=tool_name,
        resource=resource,
        arguments={},
        metadata=metadata,
        estimated_cost_tokens=estimated_tokens,
        estimated_cost_usd=estimated_cost,
    )
    return request
class TestSafetyGuardianInit:
    """Initialization behaviour of SafetyGuardian."""

    @pytest.mark.asyncio
    async def test_init_creates_subsystems(
        self,
        safety_config: SafetyConfig,
    ) -> None:
        """Subsystems are created when none are passed to the constructor."""
        # Replace the audit-logger factory so initialize() gets a mock.
        audit_patch = patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        )
        with audit_patch:
            instance = SafetyGuardian(config=safety_config)
            await instance.initialize()

            assert instance.cost_controller is not None
            assert instance.rate_limiter is not None
            assert instance.loop_detector is not None
            assert instance.is_initialized is True

    @pytest.mark.asyncio
    async def test_init_with_provided_subsystems(
        self,
        safety_config: SafetyConfig,
        cost_controller: CostController,
        rate_limiter: RateLimiter,
        loop_detector: LoopDetector,
    ) -> None:
        """Subsystems passed to the constructor are kept by identity."""
        instance = SafetyGuardian(
            config=safety_config,
            cost_controller=cost_controller,
            rate_limiter=rate_limiter,
            loop_detector=loop_detector,
        )
        instance._audit_logger = MagicMock()
        await instance.initialize()

        # Identity checks: the exact fixture objects must be reused.
        assert instance.cost_controller is cost_controller
        assert instance.rate_limiter is rate_limiter
        assert instance.loop_detector is loop_detector
class TestSafetyGuardianValidation:
    """Tests for SafetyGuardian.validate()."""

    @pytest.mark.asyncio
    async def test_validate_success(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A default action on a fresh guardian is allowed."""
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_validate_disabled_allows_all(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """With safety disabled, validation allows the action and says why."""
        # The config fixture is function-scoped, so this mutation does not
        # leak into other tests.
        guardian._config.enabled = False
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert "disabled" in result.reasons[0].lower()

    @pytest.mark.asyncio
    async def test_validate_budget_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation denies an action once the session budget is spent."""
        # Use up the whole session budget (1000 tokens / $10).
        await guardian.cost_controller.record_usage(
            agent_id=sample_metadata.agent_id,
            session_id=sample_metadata.session_id,
            tokens=1000,
            cost_usd=10.0,
        )
        action = create_action(sample_metadata, estimated_tokens=100)

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("budget" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_rate_limit_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied or delayed once the rate limit is exhausted."""
        # Exhaust the "actions" limit by acquiring slots directly on the
        # limiter; 100 acquisitions is more than the default limit.
        for _ in range(100):
            await guardian.rate_limiter.acquire("actions", sample_metadata.agent_id)

        action = create_action(sample_metadata)
        result = await guardian.validate(action)

        # Should be denied or delayed
        assert result.allowed is False
        assert result.decision in (SafetyDecision.DENY, SafetyDecision.DELAY)

    @pytest.mark.asyncio
    async def test_validate_loop_detected(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation denies an action the loop detector has seen repeatedly."""
        action = create_action(sample_metadata)

        # Three identical records match max_exact_repetitions=3 in the fixture.
        for _ in range(3):
            await guardian.loop_detector.record(action)

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("loop" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_denied_tool(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A tool matching a denied pattern is rejected despite allow-all."""
        # "shell_exec" matches the "shell_*" deny pattern below.
        action = create_action(sample_metadata, tool_name="shell_exec")
        policy = SafetyPolicy(
            name="test-policy",
            allowed_tools=["*"],
            denied_tools=["shell_*"],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("denied" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_with_custom_policy(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A tool matching the custom policy's allow pattern is permitted."""
        action = create_action(sample_metadata, tool_name="allowed_tool")
        policy = SafetyPolicy(
            name="test-custom-policy",
            allowed_tools=["allowed_*"],
            denied_tools=[],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW
class TestSafetyGuardianRecording:
    """Behaviour of SafetyGuardian.record_execution()."""

    @pytest.mark.asyncio
    async def test_record_execution_updates_cost(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Actual token cost from the result lands on the session budget."""
        request = create_action(sample_metadata)
        outcome = ActionResult(
            action_id=request.id,
            success=True,
            actual_cost_tokens=50,
            actual_cost_usd=0.005,
        )

        await guardian.record_execution(request, outcome)

        # The session tracker should show the 50 tokens actually spent.
        session_budget = await guardian.cost_controller.get_status(
            BudgetScope.SESSION, sample_metadata.session_id
        )
        assert session_budget is not None
        assert session_budget.tokens_used == 50

    @pytest.mark.asyncio
    async def test_record_execution_updates_loop_history(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A recorded execution adds one entry to the loop detector history."""
        request = create_action(sample_metadata)
        outcome = ActionResult(action_id=request.id, success=True)

        await guardian.record_execution(request, outcome)

        # Exactly one action has been recorded for this agent.
        detector_stats = await guardian.loop_detector.get_stats(
            sample_metadata.agent_id
        )
        assert detector_stats["history_size"] == 1
class TestSafetyGuardianSingleton:
    """Tests for the module-level SafetyGuardian singleton helpers."""

    @pytest.mark.asyncio
    async def test_get_safety_guardian_creates_singleton(
        self,
        reset_guardian,
    ) -> None:
        """Repeated get_safety_guardian calls return the same instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()
            guardian2 = await get_safety_guardian()

            assert guardian1 is guardian2
            assert guardian1.is_initialized is True

    @pytest.mark.asyncio
    async def test_shutdown_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Shutdown clears the singleton so the next get builds a new one."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian = await get_safety_guardian()
            assert guardian.is_initialized is True

            await shutdown_safety_guardian()

            # Previously nothing was asserted after shutdown. Verify the
            # documented expectation: the singleton is cleared, so the next
            # lookup must produce a different instance.
            # NOTE(review): assumes shutdown_safety_guardian() drops the cached
            # instance, as the original comment stated - confirm against the
            # guardian module.
            replacement = await get_safety_guardian()
            assert replacement is not guardian

    @pytest.mark.asyncio
    async def test_reset_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Reset clears the singleton so the next get builds a new one."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()
            await reset_safety_guardian()
            guardian2 = await get_safety_guardian()

            assert guardian1 is not guardian2
class TestPatternMatching:
    """Checks for SafetyGuardian._matches_pattern wildcard handling."""

    def test_exact_match(self) -> None:
        """A pattern without wildcards matches only the identical name."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("file_read", "file_read") is True
        assert matches("file_read", "file_write") is False

    def test_wildcard_all(self) -> None:
        """A lone ``*`` matches any name, including the empty string."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("anything", "*") is True
        assert matches("", "*") is True

    def test_prefix_wildcard(self) -> None:
        """A leading ``*`` matches names that end with the given suffix."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("test_read", "*_read") is True
        assert matches("test_write", "*_read") is False

    def test_suffix_wildcard(self) -> None:
        """A trailing ``*`` matches names that start with the given prefix."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("file_read", "file_*") is True
        assert matches("shell_read", "file_*") is False

    def test_contains_wildcard(self) -> None:
        """``*`` on both sides matches names containing the middle text."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("test_dangerous_action", "*dangerous*") is True
        assert matches("test_safe_action", "*dangerous*") is False
class TestErrorHandling:
    """Tests for error handling in SafetyGuardian."""

    @pytest.mark.asyncio
    async def test_strict_mode_fails_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Strict mode denies the action when a subsystem raises unexpectedly."""
        action = create_action(sample_metadata)

        # Force an error by breaking the cost controller. patch.object undoes
        # the replacement on exit even if an assertion below fails, which the
        # previous manual save/restore did not guarantee.
        with patch.object(
            guardian.cost_controller,
            "check_budget",
            new=AsyncMock(side_effect=Exception("Unexpected error")),
        ):
            result = await guardian.validate(action)

            assert result.allowed is False
            assert result.decision == SafetyDecision.DENY
            assert any("error" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_non_strict_mode_allows_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Non-strict mode allows the action when a subsystem raises."""
        guardian._config.strict_mode = False
        try:
            action = create_action(sample_metadata)

            # Force an error by breaking the cost controller; the patch is
            # reverted automatically when the block exits.
            with patch.object(
                guardian.cost_controller,
                "check_budget",
                new=AsyncMock(side_effect=Exception("Unexpected error")),
            ):
                result = await guardian.validate(action)

                assert result.allowed is True
                assert result.decision == SafetyDecision.ALLOW
        finally:
            # Always put strict mode back, even when an assertion fails.
            guardian._config.strict_mode = True

View File

@@ -0,0 +1,405 @@
"""Tests for rate limiter module."""
import pytest
from app.services.safety.exceptions import RateLimitExceededError
from app.services.safety.limits.limiter import (
RateLimiter,
SlidingWindowCounter,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
RateLimitConfig,
)
@pytest.fixture
def sliding_counter() -> SlidingWindowCounter:
    """Provide a counter allowing 5 per 60 s window with a burst limit of 3."""
    counter = SlidingWindowCounter(
        limit=5,
        window_seconds=60,
        burst_limit=3,
    )
    return counter
@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Provide a limiter with an extra ``test_limit`` rule registered."""
    # Small numbers (5 per 60 s, burst 3) keep the exhaustion tests short.
    test_rule = RateLimitConfig(
        name="test_limit",
        limit=5,
        window_seconds=60,
        burst_limit=3,
    )
    limiter = RateLimiter()
    limiter.configure(test_rule)
    return limiter
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide metadata identifying the test agent and its session."""
    metadata = ActionMetadata(agent_id="test-agent", session_id="test-session")
    return metadata
def create_action(
    metadata: ActionMetadata,
    action_type: ActionType = ActionType.LLM_CALL,
) -> ActionRequest:
    """Build an ``ActionRequest`` of the given type for rate-limit tests.

    Args:
        metadata: Agent/session metadata attached to the request.
        action_type: Kind of action being requested.

    Returns:
        A request for ``test_tool`` on ``test-resource`` with no arguments.
    """
    request = ActionRequest(
        action_type=action_type,
        tool_name="test_tool",
        resource="test-resource",
        arguments={},
        metadata=metadata,
    )
    return request
class TestSlidingWindowCounter:
    """Unit tests for the SlidingWindowCounter primitive."""

    @pytest.mark.asyncio
    async def test_first_acquire_allowed(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """The very first acquire succeeds with no retry delay."""
        granted, wait_seconds = await sliding_counter.try_acquire()

        assert granted is True
        assert wait_seconds == 0.0

    @pytest.mark.asyncio
    async def test_burst_limit(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Acquires beyond the burst limit are refused with a retry delay."""
        # The fixture's burst limit is 3, so three acquires go through.
        for _ in range(3):
            granted, _ = await sliding_counter.try_acquire()
            assert granted is True

        # The fourth acquire exceeds the burst limit.
        granted, wait_seconds = await sliding_counter.try_acquire()
        assert granted is False
        assert wait_seconds > 0

    @pytest.mark.asyncio
    async def test_get_status(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Status reports the current count, remaining slots and reset time."""
        # Consume two of the five slots.
        for _ in range(2):
            await sliding_counter.try_acquire()

        used, left, reset_in = await sliding_counter.get_status()

        assert used == 2
        assert left == 3  # 5 - 2
        assert reset_in >= 0
class TestRateLimiter:
    """Behavioural tests for the RateLimiter facade."""

    @pytest.mark.asyncio
    async def test_check_status(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """check() on an unused key reports a full, unlimited window."""
        snapshot = await rate_limiter.check("test_limit", "test-key")

        assert snapshot.name == "test_limit"
        assert snapshot.current_count == 0
        assert snapshot.limit == 5
        assert snapshot.remaining == 5
        assert snapshot.is_limited is False

    @pytest.mark.asyncio
    async def test_acquire_success(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """A first acquire is granted and consumes one slot."""
        granted, snapshot = await rate_limiter.acquire("test_limit", "test-key")

        assert granted is True
        assert snapshot.current_count == 1
        assert snapshot.remaining == 4

    @pytest.mark.asyncio
    async def test_acquire_burst_exceeded(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """acquire() is refused once the burst limit has been used up."""
        # Three acquires fit within the burst limit of 3.
        for _ in range(3):
            granted, _ = await rate_limiter.acquire("test_limit", "test-key")
            assert granted is True

        # The fourth one is over the burst limit.
        granted, snapshot = await rate_limiter.acquire("test_limit", "test-key")
        assert granted is False
        assert snapshot.is_limited is True
        assert snapshot.retry_after_seconds > 0

    @pytest.mark.asyncio
    async def test_require_success(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """require() returns normally while the key is not limited."""
        # The test passes as long as this call does not raise.
        await rate_limiter.require("test_limit", "test-key")

    @pytest.mark.asyncio
    async def test_require_raises(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """require() raises RateLimitExceededError once the key is limited."""
        # Burn through the burst limit first.
        for _ in range(3):
            await rate_limiter.acquire("test_limit", "test-key")

        with pytest.raises(RateLimitExceededError) as exc_info:
            await rate_limiter.require("test_limit", "test-key")

        raised = exc_info.value
        assert raised.limit_type == "test_limit"
        assert raised.retry_after_seconds > 0

    @pytest.mark.asyncio
    async def test_check_action_allowed(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """check_action allows a fresh action and reports the limits checked."""
        request = create_action(sample_metadata)

        granted, snapshots = await rate_limiter.check_action(request)

        assert granted is True
        assert len(snapshots) >= 1  # At least "actions" limit

    @pytest.mark.asyncio
    async def test_check_action_llm_limits(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """LLM actions are checked against both generic and LLM limits."""
        request = create_action(sample_metadata, action_type=ActionType.LLM_CALL)

        granted, snapshots = await rate_limiter.check_action(request)
        assert granted is True

        # Both the "actions" and "llm_calls" limits must have been consulted.
        checked = [snap.name for snap in snapshots]
        assert "actions" in checked
        assert "llm_calls" in checked

    @pytest.mark.asyncio
    async def test_check_action_file_limits(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """File actions are checked against both generic and file limits."""
        request = create_action(sample_metadata, action_type=ActionType.FILE_READ)

        granted, snapshots = await rate_limiter.check_action(request)
        assert granted is True

        # Both the "actions" and "file_ops" limits must have been consulted.
        checked = [snap.name for snap in snapshots]
        assert "actions" in checked
        assert "file_ops" in checked

    @pytest.mark.asyncio
    async def test_check_action_does_not_consume_slot(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """check_action is read-only: repeated checks never use up slots."""
        request = create_action(sample_metadata)

        # Ten checks in a row must all be allowed.
        for _ in range(10):
            granted, _ = await rate_limiter.check_action(request)
            assert granted is True

        # The "actions" counter for this agent is still untouched.
        snapshot = await rate_limiter.check("actions", sample_metadata.agent_id)
        assert snapshot.current_count == 0

    @pytest.mark.asyncio
    async def test_record_action_consumes_slot(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """record_action uses up one slot of the generic "actions" limit."""
        request = create_action(sample_metadata)

        await rate_limiter.record_action(request)

        snapshot = await rate_limiter.check("actions", sample_metadata.agent_id)
        assert snapshot.current_count == 1

    @pytest.mark.asyncio
    async def test_record_action_consumes_type_specific_slots(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """record_action also uses the limit matching the action's type."""
        agent_key = sample_metadata.agent_id

        # An LLM call counts against "actions" and "llm_calls" only.
        llm_request = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
        await rate_limiter.record_action(llm_request)

        by_name = await rate_limiter.get_all_statuses(agent_key)
        assert by_name["actions"].current_count == 1
        assert by_name["llm_calls"].current_count == 1
        assert by_name["file_ops"].current_count == 0

        # A file read counts against "actions" and "file_ops" only.
        file_request = create_action(sample_metadata, action_type=ActionType.FILE_READ)
        await rate_limiter.record_action(file_request)

        by_name = await rate_limiter.get_all_statuses(agent_key)
        assert by_name["actions"].current_count == 2
        assert by_name["llm_calls"].current_count == 1
        assert by_name["file_ops"].current_count == 1

    @pytest.mark.asyncio
    async def test_get_all_statuses(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """get_all_statuses returns an entry per built-in limit for a key."""
        # Touch two of the limits so their counters are non-zero.
        await rate_limiter.acquire("actions", "test-key")
        await rate_limiter.acquire("llm_calls", "test-key")

        by_name = await rate_limiter.get_all_statuses("test-key")

        assert "actions" in by_name
        assert "llm_calls" in by_name
        assert "file_ops" in by_name
        assert by_name["actions"].current_count >= 1
        assert by_name["llm_calls"].current_count >= 1

    @pytest.mark.asyncio
    async def test_reset_single(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """reset() clears one limit/key pair and returns True."""
        # Consume two slots.
        for _ in range(2):
            await rate_limiter.acquire("test_limit", "test-key")

        was_reset = await rate_limiter.reset("test_limit", "test-key")
        assert was_reset is True

        # The counter is back to zero.
        snapshot = await rate_limiter.check("test_limit", "test-key")
        assert snapshot.current_count == 0

    @pytest.mark.asyncio
    async def test_reset_nonexistent(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """reset() returns False for a limit that was never used."""
        was_reset = await rate_limiter.reset("nonexistent", "test-key")
        assert was_reset is False

    @pytest.mark.asyncio
    async def test_reset_all(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """reset_all() clears every limit held by a key and returns the count."""
        # Touch three different limits for the same key.
        for limit_name in ("actions", "llm_calls", "file_ops"):
            await rate_limiter.acquire(limit_name, "test-key")

        cleared = await rate_limiter.reset_all("test-key")
        assert cleared >= 3

        # Every counter for the key is now zero.
        by_name = await rate_limiter.get_all_statuses("test-key")
        for snapshot in by_name.values():
            assert snapshot.current_count == 0

    @pytest.mark.asyncio
    async def test_per_key_isolation(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Exhausting one key's limit leaves other keys unaffected."""
        # Exhaust the burst limit for key-1.
        for _ in range(3):
            await rate_limiter.acquire("test_limit", "key-1")

        # key-1 is now refused.
        first_granted, _ = await rate_limiter.acquire("test_limit", "key-1")
        assert first_granted is False

        # key-2 has its own counter and is still allowed.
        second_granted, _ = await rate_limiter.acquire("test_limit", "key-2")
        assert second_granted is True

    @pytest.mark.asyncio
    async def test_configure_custom_limit(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """A limit registered through configure() is visible via check()."""
        custom_rule = RateLimitConfig(
            name="custom",
            limit=100,
            window_seconds=120,
            burst_limit=50,
        )
        rate_limiter.configure(custom_rule)

        snapshot = await rate_limiter.check("custom", "test-key")
        assert snapshot.limit == 100
        assert snapshot.window_seconds == 120

    @pytest.mark.asyncio
    async def test_default_limit_fallback(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """An unknown limit name falls back to the default 60-per-60 s rule."""
        # "unknown_limit" was never configured.
        snapshot = await rate_limiter.check("unknown_limit", "test-key")

        # Should use default (60/60s)
        assert snapshot.limit == 60
        assert snapshot.window_seconds == 60