feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking

- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
2026-01-03 17:55:34 +01:00
parent 520c06175e
commit caf283bed2
9 changed files with 1782 additions and 92 deletions

View File

@@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None:
_manager_instance = None
def reset_mcp_client() -> None:
"""Reset the global MCP client manager (for testing)."""
async def reset_mcp_client() -> None:
"""
Reset the global MCP client manager (for testing).
This is an async function to properly acquire the manager lock
and avoid race conditions with get_mcp_client().
"""
global _manager_instance
_manager_instance = None
async with _manager_lock:
if _manager_instance is not None:
# Shutdown gracefully before resetting
try:
await _manager_instance.shutdown()
except Exception: # noqa: S110
pass # Ignore errors during test cleanup
_manager_instance = None

View File

@@ -161,7 +161,7 @@ class MCPConnection:
server_name=self.server_name,
url=self.config.url,
cause=e,
)
) from e
else:
# For STDIO and SSE transports, we'll implement later
raise NotImplementedError(
@@ -297,13 +297,13 @@ class MCPConnection:
server_name=self.server_name,
url=f"{self.config.url}{path}",
cause=e,
)
) from e
except Exception as e:
raise MCPConnectionError(
f"Request failed: {e}",
server_name=self.server_name,
cause=e,
)
) from e
class ConnectionPool:
@@ -322,8 +322,19 @@ class ConnectionPool:
"""
self._connections: dict[str, MCPConnection] = {}
self._lock = asyncio.Lock()
self._per_server_locks: dict[str, asyncio.Lock] = {}
self._max_per_server = max_connections_per_server
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
    """Return the lock dedicated to ``server_name``, creating it on first use.

    The lookup and the insert below contain no ``await``, so no other
    coroutine on the same event loop can run between them; callers asking
    for the same server therefore always receive the same lock object.

    Args:
        server_name: Key into the per-server lock table.

    Returns:
        The lock stored for ``server_name`` in ``self._per_server_locks``.
    """
    lock = self._per_server_locks.get(server_name)
    if lock is None:
        # Allocate only on a miss: setdefault(name, asyncio.Lock()) would
        # construct (and throw away) a new Lock on every call for a server
        # that already has one.
        lock = asyncio.Lock()
        self._per_server_locks[server_name] = lock
    return lock
async def get_connection(
self,
server_name: str,
@@ -332,6 +343,9 @@ class ConnectionPool:
"""
Get or create a connection to a server.
Uses per-server locking to avoid blocking all connections
when establishing a new connection.
Args:
server_name: Name of the server
config: Server configuration
@@ -339,17 +353,33 @@ class ConnectionPool:
Returns:
Active connection
"""
async with self._lock:
if server_name not in self._connections:
connection = MCPConnection(server_name, config)
await connection.connect()
self._connections[server_name] = connection
# Quick check without lock - if connection exists and is connected, return it
if server_name in self._connections:
connection = self._connections[server_name]
if connection.is_connected:
return connection
# Reconnect if not connected
if not connection.is_connected:
# Need to create or reconnect - use per-server lock to avoid blocking others
async with self._lock:
server_lock = self._get_server_lock(server_name)
async with server_lock:
# Double-check after acquiring per-server lock
if server_name in self._connections:
connection = self._connections[server_name]
if connection.is_connected:
return connection
# Connection exists but not connected - reconnect
await connection.connect()
return connection
# Create new connection (outside global lock, under per-server lock)
connection = MCPConnection(server_name, config)
await connection.connect()
# Store connection under global lock
async with self._lock:
self._connections[server_name] = connection
return connection
@@ -374,6 +404,9 @@ class ConnectionPool:
if server_name in self._connections:
await self._connections[server_name].disconnect()
del self._connections[server_name]
# Clean up per-server lock
if server_name in self._per_server_locks:
del self._per_server_locks[server_name]
async def close_all(self) -> None:
"""Close all connections in the pool."""
@@ -385,6 +418,7 @@ class ConnectionPool:
logger.warning("Error closing connection: %s", e)
self._connections.clear()
self._per_server_locks.clear()
logger.info("Closed all MCP connections")
async def health_check_all(self) -> dict[str, bool]:
@@ -394,8 +428,12 @@ class ConnectionPool:
Returns:
Dict mapping server names to health status
"""
# Copy connections under lock to prevent modification during iteration
async with self._lock:
connections_snapshot = dict(self._connections)
results = {}
for name, connection in self._connections.items():
for name, connection in connections_snapshot.items():
results[name] = await connection.health_check()
return results

View File

@@ -185,6 +185,9 @@ class CostController:
# Alert handlers
self._alert_handlers: list[Any] = []
# Track which budgets have had warning alerts sent (to avoid spam)
self._warned_budgets: set[str] = set()
async def get_or_create_tracker(
self,
scope: BudgetScope,
@@ -343,32 +346,44 @@ class CostController:
"""
# Update session budget
if session_id:
session_key = f"session:{session_id}"
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
await session_tracker.add_usage(tokens, cost_usd)
# Check for warning
# Check for warning (only alert once per budget to avoid spam)
status = await session_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
status,
)
if session_key not in self._warned_budgets:
self._warned_budgets.add(session_key)
await self._send_alert(
"warning",
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
status,
)
elif not status.is_warning:
# Clear warning flag if usage dropped below threshold (e.g., after reset)
self._warned_budgets.discard(session_key)
# Update agent daily budget
daily_key = f"daily:{agent_id}"
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
await agent_tracker.add_usage(tokens, cost_usd)
# Check for warning
# Check for warning (only alert once per budget to avoid spam)
status = await agent_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
status,
)
if daily_key not in self._warned_budgets:
self._warned_budgets.add(daily_key)
await self._send_alert(
"warning",
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
status,
)
elif not status.is_warning:
# Clear warning flag if usage dropped below threshold (e.g., after reset)
self._warned_budgets.discard(daily_key)
async def get_status(
self,
@@ -388,20 +403,18 @@ class CostController:
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
return await tracker.get_status()
return None
# Get status while holding lock to prevent TOCTOU race
if tracker:
return await tracker.get_status()
return None
async def get_all_statuses(self) -> list[BudgetStatus]:
"""Get status of all tracked budgets."""
statuses = []
async with self._lock:
trackers = list(self._trackers.values())
for tracker in trackers:
statuses.append(await tracker.get_status())
# Get all statuses while holding lock to prevent TOCTOU race
for tracker in self._trackers.values():
statuses.append(await tracker.get_status())
return statuses
async def set_budget(
@@ -453,11 +466,11 @@ class CostController:
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
await tracker.reset()
return True
return False
# Reset while holding lock to prevent TOCTOU race
if tracker:
await tracker.reset()
return True
return False
def add_alert_handler(self, handler: Any) -> None:
"""Add an alert handler."""

View File

@@ -15,13 +15,20 @@ from .config import (
get_policy_for_autonomy_level,
get_safety_config,
)
from .costs.controller import CostController
from .exceptions import (
BudgetExceededError,
LoopDetectedError,
RateLimitExceededError,
SafetyError,
)
from .limits.limiter import RateLimiter
from .loops.detector import LoopDetector
from .models import (
ActionRequest,
ActionResult,
AuditEventType,
BudgetScope,
GuardianResult,
SafetyDecision,
SafetyPolicy,
@@ -62,6 +69,9 @@ class SafetyGuardian:
self,
config: SafetyConfig | None = None,
audit_logger: AuditLogger | None = None,
cost_controller: CostController | None = None,
rate_limiter: RateLimiter | None = None,
loop_detector: LoopDetector | None = None,
) -> None:
"""
Initialize the SafetyGuardian.
@@ -69,17 +79,22 @@ class SafetyGuardian:
Args:
config: Optional safety configuration. If None, loads from environment.
audit_logger: Optional audit logger. If None, uses global instance.
cost_controller: Optional cost controller. If None, creates default.
rate_limiter: Optional rate limiter. If None, creates default.
loop_detector: Optional loop detector. If None, creates default.
"""
self._config = config or get_safety_config()
self._audit_logger = audit_logger
self._initialized = False
self._lock = asyncio.Lock()
# Subsystem references (will be initialized lazily)
# Core safety subsystems (always initialized)
self._cost_controller: CostController | None = cost_controller
self._rate_limiter: RateLimiter | None = rate_limiter
self._loop_detector: LoopDetector | None = loop_detector
# Optional subsystems (will be initialized when available)
self._permission_manager: Any = None
self._cost_controller: Any = None
self._rate_limiter: Any = None
self._loop_detector: Any = None
self._hitl_manager: Any = None
self._rollback_manager: Any = None
self._content_filter: Any = None
@@ -95,6 +110,21 @@ class SafetyGuardian:
"""Check if the guardian is initialized."""
return self._initialized
@property
def cost_controller(self) -> CostController | None:
"""Return the CostController held by this guardian.

The value is None when no controller was passed to the constructor and
initialize() has not yet created a default one.
"""
return self._cost_controller
@property
def rate_limiter(self) -> RateLimiter | None:
"""Return the RateLimiter held by this guardian.

The value is None when no limiter was passed to the constructor and
initialize() has not yet created a default one.
"""
return self._rate_limiter
@property
def loop_detector(self) -> LoopDetector | None:
"""Return the LoopDetector held by this guardian.

The value is None when no detector was passed to the constructor and
initialize() has not yet created a default one.
"""
return self._loop_detector
async def initialize(self) -> None:
"""Initialize the SafetyGuardian and all subsystems."""
async with self._lock:
@@ -108,11 +138,23 @@ class SafetyGuardian:
if self._audit_logger is None:
self._audit_logger = await get_audit_logger()
# Initialize subsystems lazily as they're implemented
# For now, we'll import and initialize them when available
# Initialize core safety subsystems
if self._cost_controller is None:
self._cost_controller = CostController()
logger.debug("Initialized CostController")
if self._rate_limiter is None:
self._rate_limiter = RateLimiter()
logger.debug("Initialized RateLimiter")
if self._loop_detector is None:
self._loop_detector = LoopDetector()
logger.debug("Initialized LoopDetector")
self._initialized = True
logger.info("SafetyGuardian initialized")
logger.info(
"SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
)
async def shutdown(self) -> None:
"""Shutdown the SafetyGuardian and all subsystems."""
@@ -309,13 +351,40 @@ class SafetyGuardian:
# Update cost tracking
if self._cost_controller:
# Track actual cost
pass
try:
# Use explicit None check - 0 is a valid cost value
tokens = (
result.actual_cost_tokens
if result.actual_cost_tokens is not None
else action.estimated_cost_tokens
)
cost_usd = (
result.actual_cost_usd
if result.actual_cost_usd is not None
else action.estimated_cost_usd
)
await self._cost_controller.record_usage(
agent_id=action.metadata.agent_id,
session_id=action.metadata.session_id,
tokens=tokens,
cost_usd=cost_usd,
)
except Exception as e:
logger.warning("Failed to record cost: %s", e)
# Update rate limiter - consume slots for executed actions
if self._rate_limiter:
try:
await self._rate_limiter.record_action(action)
except Exception as e:
logger.warning("Failed to record action in rate limiter: %s", e)
# Update loop detection history
if self._loop_detector:
# Add to action history
pass
try:
await self._loop_detector.record(action)
except Exception as e:
logger.warning("Failed to record action in loop detector: %s", e)
async def rollback(self, checkpoint_id: str) -> bool:
"""
@@ -442,14 +511,80 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within budget."""
# TODO: Implement with CostController
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed (not fully implemented)"],
)
if self._cost_controller is None:
logger.warning("CostController not initialized - skipping budget check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check skipped (controller not initialized)"],
)
agent_id = action.metadata.agent_id
session_id = action.metadata.session_id
try:
# Check if we have budget for this action
has_budget = await self._cost_controller.check_budget(
agent_id=agent_id,
session_id=session_id,
estimated_tokens=action.estimated_cost_tokens,
estimated_cost_usd=action.estimated_cost_usd,
)
if not has_budget:
# Get current status for better error message
if session_id:
session_status = await self._cost_controller.get_status(
BudgetScope.SESSION, session_id
)
if session_status and session_status.is_exceeded:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Session budget exceeded: {session_status.tokens_used}"
f"/{session_status.tokens_limit} tokens"
],
)
agent_status = await self._cost_controller.get_status(
BudgetScope.DAILY, agent_id
)
if agent_status and agent_status.is_exceeded:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Daily budget exceeded: {agent_status.tokens_used}"
f"/{agent_status.tokens_limit} tokens"
],
)
# Generic budget exceeded
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=["Budget exceeded"],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed"],
)
except BudgetExceededError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
)
async def _check_rate_limit(
self,
@@ -457,14 +592,78 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within rate limits."""
# TODO: Implement with RateLimiter
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed (not fully implemented)"],
)
if self._rate_limiter is None:
logger.warning("RateLimiter not initialized - skipping rate limit check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check skipped (limiter not initialized)"],
)
try:
# Check all applicable rate limits for this action
allowed, statuses = await self._rate_limiter.check_action(action)
if not allowed:
# Find the first exceeded limit for the error message
exceeded_status = next(
(s for s in statuses if s.is_limited),
statuses[0] if statuses else None,
)
if exceeded_status:
retry_after = exceeded_status.retry_after_seconds
# Determine if this is a soft limit (delay) or hard limit (deny)
if retry_after > 0 and retry_after <= 5.0:
# Short wait - suggest delay
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DELAY,
reasons=[
f"Rate limit '{exceeded_status.name}' exceeded. "
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
],
retry_after_seconds=retry_after,
)
else:
# Hard deny
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Rate limit '{exceeded_status.name}' exceeded. "
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
f"Retry after {retry_after:.1f}s"
],
retry_after_seconds=retry_after,
)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=["Rate limit exceeded"],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed"],
)
except RateLimitExceededError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
retry_after_seconds=e.retry_after_seconds,
)
async def _check_loops(
self,
@@ -472,14 +671,51 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check for action loops."""
# TODO: Implement with LoopDetector
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed (not fully implemented)"],
)
if self._loop_detector is None:
logger.warning("LoopDetector not initialized - skipping loop check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check skipped (detector not initialized)"],
)
try:
# Check if this action would create a loop
is_loop, loop_type = await self._loop_detector.check(action)
if is_loop:
# Get suggestions for breaking the loop
from .loops.detector import LoopBreaker
suggestions = await LoopBreaker.suggest_alternatives(
action, loop_type or "unknown"
)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Loop detected: {loop_type}",
*suggestions,
],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed"],
)
except LoopDetectedError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
)
async def _check_hitl(
self,
@@ -610,7 +846,19 @@ async def shutdown_safety_guardian() -> None:
_guardian_instance = None
def reset_safety_guardian() -> None:
"""Reset the SafetyGuardian (for testing)."""
async def reset_safety_guardian() -> None:
"""
Reset the SafetyGuardian (for testing).
This is an async function to properly acquire the guardian lock
and avoid race conditions with get_safety_guardian().
"""
global _guardian_instance
_guardian_instance = None
async with _guardian_lock:
if _guardian_instance is not None:
try:
await _guardian_instance.shutdown()
except Exception: # noqa: S110
pass # Ignore errors during test cleanup
_guardian_instance = None

View File

@@ -223,7 +223,10 @@ class RateLimiter:
action: ActionRequest,
) -> tuple[bool, list[RateLimitStatus]]:
"""
Check all applicable rate limits for an action.
Check all applicable rate limits for an action WITHOUT consuming slots.
Use this during validation to check if action would be allowed.
Call record_action() after successful execution to consume slots.
Args:
action: The action to check
@@ -235,28 +238,53 @@ class RateLimiter:
statuses: list[RateLimitStatus] = []
allowed = True
# Check general actions limit
actions_allowed, actions_status = await self.acquire("actions", agent_id)
# Check general actions limit (read-only)
actions_status = await self.check("actions", agent_id)
statuses.append(actions_status)
if not actions_allowed:
if actions_status.is_limited:
allowed = False
# Check LLM-specific limit for LLM calls
if action.action_type.value == "llm_call":
llm_allowed, llm_status = await self.acquire("llm_calls", agent_id)
llm_status = await self.check("llm_calls", agent_id)
statuses.append(llm_status)
if not llm_allowed:
if llm_status.is_limited:
allowed = False
# Check file ops limit for file operations
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
file_allowed, file_status = await self.acquire("file_ops", agent_id)
file_status = await self.check("file_ops", agent_id)
statuses.append(file_status)
if not file_allowed:
if file_status.is_limited:
allowed = False
return allowed, statuses
async def record_action(
    self,
    action: ActionRequest,
) -> None:
    """
    Consume one slot from every rate limit that applies to ``action``.

    Meant to be called AFTER the action has executed, so that only actions
    that were actually performed count against the limits.

    Args:
        action: The executed action
    """
    owner = action.metadata.agent_id

    # Every action counts against the general limit.
    await self.acquire("actions", owner)

    # On top of that, at most one category-specific limit applies.
    category_limit = {
        "llm_call": "llm_calls",
        "file_read": "file_ops",
        "file_write": "file_ops",
        "file_delete": "file_ops",
    }.get(action.action_type.value)
    if category_limit is not None:
        await self.acquire(category_limit, owner)
async def require(
self,
limit_name: str,

View File

@@ -20,13 +20,13 @@ from app.services.mcp.routing import ToolInfo, ToolResult
@pytest.fixture
def reset_registry():
async def reset_registry():
"""Reset the singleton registry before and after each test."""
MCPServerRegistry.reset_instance()
reset_mcp_client()
await reset_mcp_client()
yield
MCPServerRegistry.reset_instance()
reset_mcp_client()
await reset_mcp_client()
@pytest.fixture
@@ -388,7 +388,8 @@ class TestModuleLevelFunctions:
mock_shutdown.return_value = None
await shutdown_mcp_client()
def test_reset_mcp_client(self, reset_registry):
@pytest.mark.asyncio
async def test_reset_mcp_client(self, reset_registry):
"""Test resetting the global client."""
reset_mcp_client()
await reset_mcp_client()
# Should not raise

View File

@@ -0,0 +1,436 @@
"""Tests for cost controller module."""
import pytest
from app.services.safety.costs.controller import (
BudgetTracker,
CostController,
)
from app.services.safety.exceptions import BudgetExceededError
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
BudgetScope,
)
@pytest.fixture
def budget_tracker() -> BudgetTracker:
"""Create a budget tracker for testing."""
return BudgetTracker(
scope=BudgetScope.SESSION,
scope_id="test-session",
tokens_limit=1000,
cost_limit_usd=10.0,
warning_threshold=0.8,
)
@pytest.fixture
def cost_controller() -> CostController:
"""Create a cost controller for testing."""
return CostController(
default_session_tokens=1000,
default_session_cost_usd=10.0,
default_daily_tokens=5000,
default_daily_cost_usd=50.0,
)
@pytest.fixture
def sample_metadata() -> ActionMetadata:
"""Create sample action metadata."""
return ActionMetadata(
agent_id="test-agent",
session_id="test-session",
)
def create_action(
    metadata: ActionMetadata,
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an LLM-call ActionRequest for tests with the given cost estimates."""
    request_fields = {
        "action_type": ActionType.LLM_CALL,
        "tool_name": "test_tool",
        "resource": "test-resource",
        "arguments": {},
        "metadata": metadata,
        "estimated_cost_tokens": estimated_tokens,
        "estimated_cost_usd": estimated_cost,
    }
    return ActionRequest(**request_fields)
class TestBudgetTracker:
"""Tests for BudgetTracker class."""
@pytest.mark.asyncio
async def test_initial_status(self, budget_tracker: BudgetTracker) -> None:
"""Test initial budget status is clean."""
status = await budget_tracker.get_status()
assert status.tokens_used == 0
assert status.cost_used_usd == 0.0
assert status.tokens_remaining == 1000
assert status.cost_remaining_usd == 10.0
assert status.is_warning is False
assert status.is_exceeded is False
@pytest.mark.asyncio
async def test_add_usage(self, budget_tracker: BudgetTracker) -> None:
"""Test adding usage updates counters."""
await budget_tracker.add_usage(tokens=100, cost_usd=1.0)
status = await budget_tracker.get_status()
assert status.tokens_used == 100
assert status.cost_used_usd == 1.0
assert status.tokens_remaining == 900
assert status.cost_remaining_usd == 9.0
@pytest.mark.asyncio
async def test_warning_threshold(self, budget_tracker: BudgetTracker) -> None:
"""Test warning is triggered at threshold."""
# Add usage to reach 80% of tokens
await budget_tracker.add_usage(tokens=800, cost_usd=1.0)
status = await budget_tracker.get_status()
assert status.is_warning is True
assert status.is_exceeded is False
@pytest.mark.asyncio
async def test_budget_exceeded(self, budget_tracker: BudgetTracker) -> None:
"""Test budget exceeded detection."""
# Exceed token limit
await budget_tracker.add_usage(tokens=1100, cost_usd=1.0)
status = await budget_tracker.get_status()
assert status.is_exceeded is True
@pytest.mark.asyncio
async def test_check_budget_allows(self, budget_tracker: BudgetTracker) -> None:
"""Test check_budget allows within budget."""
result = await budget_tracker.check_budget(
estimated_tokens=500,
estimated_cost_usd=5.0,
)
assert result is True
@pytest.mark.asyncio
async def test_check_budget_denies(self, budget_tracker: BudgetTracker) -> None:
"""Test check_budget denies when would exceed."""
# Use most of the budget
await budget_tracker.add_usage(tokens=800, cost_usd=8.0)
# Check would exceed
result = await budget_tracker.check_budget(
estimated_tokens=300,
estimated_cost_usd=3.0,
)
assert result is False
@pytest.mark.asyncio
async def test_reset(self, budget_tracker: BudgetTracker) -> None:
"""Test manual reset clears counters."""
await budget_tracker.add_usage(tokens=500, cost_usd=5.0)
await budget_tracker.reset()
status = await budget_tracker.get_status()
assert status.tokens_used == 0
assert status.cost_used_usd == 0.0
class TestCostController:
"""Tests for CostController class."""
@pytest.mark.asyncio
async def test_check_budget_success(
self,
cost_controller: CostController,
) -> None:
"""Test budget check passes with available budget."""
result = await cost_controller.check_budget(
agent_id="test-agent",
session_id="test-session",
estimated_tokens=100,
estimated_cost_usd=1.0,
)
assert result is True
@pytest.mark.asyncio
async def test_check_budget_session_exceeded(
self,
cost_controller: CostController,
) -> None:
"""Test budget check fails when session budget exceeded."""
# Use most of session budget
await cost_controller.record_usage(
agent_id="test-agent",
session_id="test-session",
tokens=900,
cost_usd=9.0,
)
# Check would exceed
result = await cost_controller.check_budget(
agent_id="test-agent",
session_id="test-session",
estimated_tokens=200,
estimated_cost_usd=2.0,
)
assert result is False
@pytest.mark.asyncio
async def test_check_budget_daily_exceeded(
self,
cost_controller: CostController,
) -> None:
"""Test budget check fails when daily budget exceeded."""
# Use most of daily budget
await cost_controller.record_usage(
agent_id="test-agent",
session_id=None,
tokens=4900,
cost_usd=49.0,
)
# Check would exceed daily
result = await cost_controller.check_budget(
agent_id="test-agent",
session_id="new-session",
estimated_tokens=200,
estimated_cost_usd=2.0,
)
assert result is False
@pytest.mark.asyncio
async def test_check_action(
self,
cost_controller: CostController,
sample_metadata: ActionMetadata,
) -> None:
"""Test checking action budget."""
action = create_action(
sample_metadata,
estimated_tokens=100,
estimated_cost=0.01,
)
result = await cost_controller.check_action(action)
assert result is True
@pytest.mark.asyncio
async def test_require_budget_success(
self,
cost_controller: CostController,
) -> None:
"""Test require_budget passes when budget available."""
# Should not raise
await cost_controller.require_budget(
agent_id="test-agent",
session_id="test-session",
estimated_tokens=100,
estimated_cost_usd=1.0,
)
@pytest.mark.asyncio
async def test_require_budget_raises(
self,
cost_controller: CostController,
) -> None:
"""Test require_budget raises when budget exceeded."""
# Use all session budget
await cost_controller.record_usage(
agent_id="test-agent",
session_id="test-session",
tokens=1000,
cost_usd=10.0,
)
with pytest.raises(BudgetExceededError) as exc_info:
await cost_controller.require_budget(
agent_id="test-agent",
session_id="test-session",
estimated_tokens=100,
estimated_cost_usd=1.0,
)
assert "session" in exc_info.value.budget_type.lower()
@pytest.mark.asyncio
async def test_record_usage(
self,
cost_controller: CostController,
) -> None:
"""Test recording usage updates trackers."""
await cost_controller.record_usage(
agent_id="test-agent",
session_id="test-session",
tokens=100,
cost_usd=1.0,
)
# Check session budget was updated
session_status = await cost_controller.get_status(
BudgetScope.SESSION, "test-session"
)
assert session_status is not None
assert session_status.tokens_used == 100
# Check daily budget was updated
daily_status = await cost_controller.get_status(BudgetScope.DAILY, "test-agent")
assert daily_status is not None
assert daily_status.tokens_used == 100
@pytest.mark.asyncio
async def test_get_all_statuses(
    self,
    cost_controller: CostController,
) -> None:
    """Test getting all budget statuses.

    Recording usage for two different agents in two different sessions
    creates a session tracker and a daily tracker for each, so at least
    four statuses must be reported.
    """
    # Record some usage
    await cost_controller.record_usage(
        agent_id="agent-1",
        session_id="session-1",
        tokens=100,
        cost_usd=1.0,
    )
    await cost_controller.record_usage(
        agent_id="agent-2",
        session_id="session-2",
        tokens=200,
        cost_usd=2.0,
    )

    statuses = await cost_controller.get_all_statuses()

    # session-1, session-2, daily agent-1, daily agent-2
    assert len(statuses) >= 4
    # Both recorded amounts must be visible among the reported statuses
    assert {100, 200} <= {status.tokens_used for status in statuses}
@pytest.mark.asyncio
async def test_set_budget(
self,
cost_controller: CostController,
) -> None:
"""Test setting custom budget."""
await cost_controller.set_budget(
scope=BudgetScope.SESSION,
scope_id="custom-session",
tokens_limit=5000,
cost_limit_usd=50.0,
)
status = await cost_controller.get_status(BudgetScope.SESSION, "custom-session")
assert status is not None
assert status.tokens_limit == 5000
assert status.cost_limit_usd == 50.0
@pytest.mark.asyncio
async def test_reset_budget(
self,
cost_controller: CostController,
) -> None:
"""Test resetting budget."""
# Record usage
await cost_controller.record_usage(
agent_id="test-agent",
session_id="test-session",
tokens=500,
cost_usd=5.0,
)
# Reset session budget
result = await cost_controller.reset_budget(BudgetScope.SESSION, "test-session")
assert result is True
# Verify reset
status = await cost_controller.get_status(BudgetScope.SESSION, "test-session")
assert status is not None
assert status.tokens_used == 0
@pytest.mark.asyncio
async def test_reset_nonexistent_budget(
self,
cost_controller: CostController,
) -> None:
"""Test resetting non-existent budget returns False."""
result = await cost_controller.reset_budget(BudgetScope.SESSION, "nonexistent")
assert result is False
@pytest.mark.asyncio
async def test_alert_handler(
self,
cost_controller: CostController,
) -> None:
"""Test alert handler is called at warning threshold."""
alerts_received = []
def alert_handler(alert_type: str, message: str, status):
alerts_received.append((alert_type, message))
cost_controller.add_alert_handler(alert_handler)
# Record usage to reach warning threshold (80%)
await cost_controller.record_usage(
agent_id="test-agent",
session_id="test-session",
tokens=850, # 85% of 1000
cost_usd=0.0,
)
assert len(alerts_received) > 0
assert alerts_received[0][0] == "warning"
@pytest.mark.asyncio
async def test_remove_alert_handler(
self,
cost_controller: CostController,
) -> None:
"""Test removing alert handler."""
alerts_received = []
def alert_handler(alert_type: str, message: str, status):
alerts_received.append((alert_type, message))
cost_controller.add_alert_handler(alert_handler)
cost_controller.remove_alert_handler(alert_handler)
# Record usage to reach warning threshold
await cost_controller.record_usage(
agent_id="test-agent",
session_id="test-session",
tokens=850,
cost_usd=0.0,
)
assert len(alerts_received) == 0
@pytest.mark.asyncio
async def test_alert_deduplication(
    self,
    cost_controller: CostController,
) -> None:
    """Test alerts are only sent once per budget (no spam).

    The first call crosses the session warning threshold; the following
    calls keep the session in the warning state without exceeding it.
    Without deduplication every one of those calls would raise an alert.
    """
    alerts_received = []

    def alert_handler(alert_type: str, message: str, status):
        alerts_received.append((alert_type, message))

    cost_controller.add_alert_handler(alert_handler)

    # Session budget is 1000 with 80% threshold = 800 tokens, so a single
    # 850-token call puts the session into the warning state immediately.
    await cost_controller.record_usage(
        agent_id="test-agent",
        session_id="test-session",
        tokens=850,
        cost_usd=0.0,
    )
    assert len(alerts_received) == 1

    # Keep recording while still in warning state: 850 + 9 * 10 = 940,
    # which stays below the 1000-token limit (not exceeded) and far below
    # the daily warning threshold (80% of 5000).
    for _ in range(9):
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=10,
            cost_usd=0.0,
        )

    # Still only the ONE session warning from the first call
    assert len(alerts_received) == 1
    assert alerts_received[0][0] == "warning"
    assert "Session" in alerts_received[0][1]

View File

@@ -0,0 +1,508 @@
"""Tests for SafetyGuardian integration."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.services.safety.config import SafetyConfig
from app.services.safety.costs.controller import CostController
from app.services.safety.guardian import (
SafetyGuardian,
get_safety_guardian,
reset_safety_guardian,
shutdown_safety_guardian,
)
from app.services.safety.limits.limiter import RateLimiter
from app.services.safety.loops.detector import LoopDetector
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionResult,
ActionType,
AuditEvent,
AuditEventType,
AutonomyLevel,
BudgetScope,
SafetyDecision,
SafetyPolicy,
)
@pytest_asyncio.fixture
async def reset_guardian():
    """Clear the guardian singleton on both sides of the test that uses this fixture."""
    # Setup: start from a clean singleton.
    await reset_safety_guardian()

    yield

    # Teardown: drop whatever instance the test created.
    await reset_safety_guardian()
@pytest.fixture
def safety_config() -> SafetyConfig:
    """Provide a SafetyConfig with safety and strict mode on, HITL and auto-checkpointing off."""
    settings = {
        "enabled": True,
        "strict_mode": True,
        "hitl_enabled": False,
        "auto_checkpoint_destructive": False,
    }
    return SafetyConfig(**settings)
@pytest.fixture
def cost_controller() -> CostController:
    """Provide a CostController with small default budgets suited to tests."""
    # Session defaults: 1000 tokens / 10.0 USD; daily defaults: 5000 tokens / 50.0 USD.
    controller = CostController(
        default_session_tokens=1000,
        default_session_cost_usd=10.0,
        default_daily_tokens=5000,
        default_daily_cost_usd=50.0,
    )
    return controller
@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Provide a RateLimiter constructed with no arguments."""
    limiter = RateLimiter()
    return limiter
@pytest.fixture
def loop_detector() -> LoopDetector:
    """Provide a LoopDetector with a short history and low repetition limits."""
    detector = LoopDetector(
        history_size=10,
        max_exact_repetitions=3,
        max_semantic_repetitions=5,
    )
    return detector
def _make_audit_event() -> AuditEvent:
    """Return an ACTION_REQUESTED AuditEvent with fixed test identifiers."""
    event = AuditEvent(
        event_type=AuditEventType.ACTION_REQUESTED,
        agent_id="test-agent",
        action_id="test-action",
    )
    return event
@pytest_asyncio.fixture
async def guardian(
    safety_config: SafetyConfig,
    cost_controller: CostController,
    rate_limiter: RateLimiter,
    loop_detector: LoopDetector,
) -> SafetyGuardian:
    """Provide an initialized SafetyGuardian wired to the fixture subsystems.

    The guardian's audit logger is swapped for a mock before initialization
    so that tests do not perform real audit logging.
    """
    instance = SafetyGuardian(
        config=safety_config,
        cost_controller=cost_controller,
        rate_limiter=rate_limiter,
        loop_detector=loop_detector,
    )

    # The awaitable logging methods hand back real AuditEvent objects (not
    # bare AsyncMock results); log_action_executed resolves to None.
    stub_logger = MagicMock()
    stub_logger.log = AsyncMock(return_value=_make_audit_event())
    stub_logger.log_action_request = AsyncMock(return_value=_make_audit_event())
    stub_logger.log_action_executed = AsyncMock(return_value=None)
    instance._audit_logger = stub_logger

    await instance.initialize()
    return instance
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide ActionMetadata for the test agent/session at MILESTONE autonomy."""
    metadata = ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
        autonomy_level=AutonomyLevel.MILESTONE,
    )
    return metadata
def create_action(
    metadata: ActionMetadata,
    tool_name: str = "test_tool",
    action_type: ActionType = ActionType.LLM_CALL,
    resource: str = "/tmp/test.txt",  # noqa: S108
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an ActionRequest for tests, always with empty arguments.

    ``estimated_tokens`` and ``estimated_cost`` are passed through as the
    request's ``estimated_cost_tokens`` and ``estimated_cost_usd`` fields.
    """
    request_fields = {
        "action_type": action_type,
        "tool_name": tool_name,
        "resource": resource,
        "arguments": {},
        "metadata": metadata,
        "estimated_cost_tokens": estimated_tokens,
        "estimated_cost_usd": estimated_cost,
    }
    return ActionRequest(**request_fields)
class TestSafetyGuardianInit:
    """Initialization behaviour of SafetyGuardian."""

    @pytest.mark.asyncio
    async def test_init_creates_subsystems(
        self,
        safety_config: SafetyConfig,
    ) -> None:
        """With only a config supplied, initialize() leaves every subsystem populated."""
        audit_logger_target = "app.services.safety.guardian.get_audit_logger"

        with patch(audit_logger_target, new_callable=AsyncMock):
            instance = SafetyGuardian(config=safety_config)
            await instance.initialize()

            assert instance.cost_controller is not None
            assert instance.rate_limiter is not None
            assert instance.loop_detector is not None
            assert instance.is_initialized is True

    @pytest.mark.asyncio
    async def test_init_with_provided_subsystems(
        self,
        safety_config: SafetyConfig,
        cost_controller: CostController,
        rate_limiter: RateLimiter,
        loop_detector: LoopDetector,
    ) -> None:
        """Subsystems passed to the constructor are the ones exposed after initialize()."""
        instance = SafetyGuardian(
            config=safety_config,
            cost_controller=cost_controller,
            rate_limiter=rate_limiter,
            loop_detector=loop_detector,
        )
        # Replace the audit logger before initializing.
        instance._audit_logger = MagicMock()

        await instance.initialize()

        # Identity, not equality: the very same objects must be kept.
        assert instance.cost_controller is cost_controller
        assert instance.rate_limiter is rate_limiter
        assert instance.loop_detector is loop_detector
class TestSafetyGuardianValidation:
    """Tests for SafetyGuardian.validate()."""

    @pytest.mark.asyncio
    async def test_validate_success(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A default action on a fresh guardian is allowed."""
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_validate_disabled_allows_all(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """With safety disabled in the config, validation allows the action."""
        guardian._config.enabled = False
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        # The first reason is expected to mention that safety is disabled.
        assert "disabled" in result.reasons[0].lower()

    @pytest.mark.asyncio
    async def test_validate_budget_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied once the session budget has been used up."""
        # Consume the whole session budget (1000 tokens / 10.0 USD in the fixture).
        await guardian.cost_controller.record_usage(
            agent_id=sample_metadata.agent_id,
            session_id=sample_metadata.session_id,
            tokens=1000,
            cost_usd=10.0,
        )
        action = create_action(sample_metadata, estimated_tokens=100)

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("budget" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_rate_limit_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied or delayed once the "actions" limit is exhausted."""
        # Consume "actions" slots directly on the rate limiter (not via
        # validate()); 100 acquisitions is meant to exceed the default limit.
        for _ in range(100):
            await guardian.rate_limiter.acquire("actions", sample_metadata.agent_id)
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        # Either outcome counts as "not allowed".
        assert result.allowed is False
        assert result.decision in (SafetyDecision.DENY, SafetyDecision.DELAY)

    @pytest.mark.asyncio
    async def test_validate_loop_detected(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied when the same action has already been recorded repeatedly."""
        action = create_action(sample_metadata)
        # Three recordings match max_exact_repetitions=3 in the loop_detector fixture.
        for _ in range(3):
            await guardian.loop_detector.record(action)

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("loop" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_denied_tool(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A tool matching a denied pattern is rejected even when all tools are allowed."""
        # "shell_exec" matches the "shell_*" deny pattern below.
        action = create_action(sample_metadata, tool_name="shell_exec")
        policy = SafetyPolicy(
            name="test-policy",
            allowed_tools=["*"],
            denied_tools=["shell_*"],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("denied" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_with_custom_policy(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A tool matching a custom policy's allow pattern is allowed."""
        action = create_action(sample_metadata, tool_name="allowed_tool")
        policy = SafetyPolicy(
            name="test-custom-policy",
            allowed_tools=["allowed_*"],
            denied_tools=[],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW
class TestSafetyGuardianRecording:
    """Behaviour of SafetyGuardian.record_execution()."""

    @pytest.mark.asyncio
    async def test_record_execution_updates_cost(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """The actual token cost of an executed action shows up in the session budget status."""
        request = create_action(sample_metadata)
        outcome = ActionResult(
            action_id=request.id,
            success=True,
            actual_cost_tokens=50,
            actual_cost_usd=0.005,
        )

        await guardian.record_execution(request, outcome)

        # The session budget should now report the 50 tokens recorded above.
        session_status = await guardian.cost_controller.get_status(
            BudgetScope.SESSION,
            sample_metadata.session_id,
        )
        assert session_status is not None
        assert session_status.tokens_used == 50

    @pytest.mark.asyncio
    async def test_record_execution_updates_loop_history(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """An executed action adds one entry to the agent's loop-detector history."""
        request = create_action(sample_metadata)
        outcome = ActionResult(
            action_id=request.id,
            success=True,
        )

        await guardian.record_execution(request, outcome)

        # Exactly one action has been recorded for this agent.
        detector_stats = await guardian.loop_detector.get_stats(
            sample_metadata.agent_id
        )
        assert detector_stats["history_size"] == 1
class TestSafetyGuardianSingleton:
    """Tests for SafetyGuardian singleton functions."""

    @pytest.mark.asyncio
    async def test_get_safety_guardian_creates_singleton(
        self,
        reset_guardian,
    ) -> None:
        """Two consecutive get_safety_guardian() calls return the same initialized instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()
            guardian2 = await get_safety_guardian()

            assert guardian1 is guardian2
            assert guardian1.is_initialized is True

    @pytest.mark.asyncio
    async def test_shutdown_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Shutting down clears the singleton so the next get builds a new instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian = await get_safety_guardian()
            assert guardian.is_initialized is True

            await shutdown_safety_guardian()

            # Verify the singleton was actually cleared: a subsequent get
            # must hand back a different object than the one shut down.
            replacement = await get_safety_guardian()
            assert replacement is not guardian

    @pytest.mark.asyncio
    async def test_reset_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Resetting clears the singleton so the next get builds a new instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()

            await reset_safety_guardian()

            guardian2 = await get_safety_guardian()
            assert guardian1 is not guardian2
class TestPatternMatching:
    """Checks of SafetyGuardian._matches_pattern against wildcard patterns."""

    def test_exact_match(self) -> None:
        """A pattern without wildcards matches only the identical name."""
        matches = SafetyGuardian()._matches_pattern

        assert matches("file_read", "file_read") is True
        assert matches("file_read", "file_write") is False

    def test_wildcard_all(self) -> None:
        """A lone ``*`` matches any name, including the empty string."""
        matches = SafetyGuardian()._matches_pattern

        assert matches("anything", "*") is True
        assert matches("", "*") is True

    def test_prefix_wildcard(self) -> None:
        """A leading ``*`` matches names that end with the rest of the pattern."""
        matches = SafetyGuardian()._matches_pattern

        assert matches("test_read", "*_read") is True
        assert matches("test_write", "*_read") is False

    def test_suffix_wildcard(self) -> None:
        """A trailing ``*`` matches names that start with the rest of the pattern."""
        matches = SafetyGuardian()._matches_pattern

        assert matches("file_read", "file_*") is True
        assert matches("shell_read", "file_*") is False

    def test_contains_wildcard(self) -> None:
        """``*`` on both sides matches names containing the enclosed text."""
        matches = SafetyGuardian()._matches_pattern

        assert matches("test_dangerous_action", "*dangerous*") is True
        assert matches("test_safe_action", "*dangerous*") is False
class TestErrorHandling:
    """Tests for error handling in SafetyGuardian."""

    @pytest.mark.asyncio
    async def test_strict_mode_fails_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test strict mode denies on unexpected errors."""
        action = create_action(sample_metadata)
        failing_check = AsyncMock(side_effect=Exception("Unexpected error"))

        # Force an error by breaking the cost controller. patch.object undoes
        # the replacement on exit, even if validate() itself raises.
        with patch.object(guardian.cost_controller, "check_budget", failing_check):
            result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("error" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_non_strict_mode_allows_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test non-strict mode allows on unexpected errors."""
        action = create_action(sample_metadata)
        failing_check = AsyncMock(side_effect=Exception("Unexpected error"))

        guardian._config.strict_mode = False
        try:
            # Force an error by breaking the cost controller.
            with patch.object(
                guardian.cost_controller, "check_budget", failing_check
            ):
                result = await guardian.validate(action)

            assert result.allowed is True
            assert result.decision == SafetyDecision.ALLOW
        finally:
            # Put strict mode back regardless of how the assertions went.
            guardian._config.strict_mode = True

View File

@@ -0,0 +1,405 @@
"""Tests for rate limiter module."""
import pytest
from app.services.safety.exceptions import RateLimitExceededError
from app.services.safety.limits.limiter import (
RateLimiter,
SlidingWindowCounter,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
RateLimitConfig,
)
@pytest.fixture
def sliding_counter() -> SlidingWindowCounter:
    """Provide a SlidingWindowCounter: limit 5 per 60-second window, burst limit 3."""
    counter = SlidingWindowCounter(
        limit=5,
        window_seconds=60,
        burst_limit=3,
    )
    return counter
@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Provide a RateLimiter that has an extra ``test_limit`` configured."""
    limiter = RateLimiter()

    # Dedicated limit used by the tests: 5 per 60 seconds, burst of 3.
    test_limit = RateLimitConfig(
        name="test_limit",
        limit=5,
        window_seconds=60,
        burst_limit=3,
    )
    limiter.configure(test_limit)

    return limiter
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide ActionMetadata identifying the test agent and test session."""
    metadata = ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
    )
    return metadata
def create_action(
    metadata: ActionMetadata,
    action_type: ActionType = ActionType.LLM_CALL,
) -> ActionRequest:
    """Build an ActionRequest of the given type with fixed tool, resource and empty arguments."""
    request_fields = {
        "action_type": action_type,
        "tool_name": "test_tool",
        "resource": "test-resource",
        "arguments": {},
        "metadata": metadata,
    }
    return ActionRequest(**request_fields)
class TestSlidingWindowCounter:
    """Behaviour of SlidingWindowCounter.try_acquire() and get_status()."""

    @pytest.mark.asyncio
    async def test_first_acquire_allowed(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """The first acquire on a fresh counter succeeds with no retry delay."""
        granted, wait_seconds = await sliding_counter.try_acquire()

        assert granted is True
        assert wait_seconds == 0.0

    @pytest.mark.asyncio
    async def test_burst_limit(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Acquires beyond the burst limit are refused with a positive retry delay."""
        # The fixture's burst limit is 3, so three acquires succeed.
        for _ in range(3):
            granted, _wait = await sliding_counter.try_acquire()
            assert granted is True

        # The fourth acquire exceeds the burst limit.
        granted, wait_seconds = await sliding_counter.try_acquire()
        assert granted is False
        assert wait_seconds > 0

    @pytest.mark.asyncio
    async def test_get_status(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """get_status() reports the used count, what is left, and a non-negative reset time."""
        # Two acquires against a limit of 5.
        await sliding_counter.try_acquire()
        await sliding_counter.try_acquire()

        used, left, seconds_to_reset = await sliding_counter.get_status()

        assert used == 2
        assert left == 3  # 5 - 2
        assert seconds_to_reset >= 0
class TestRateLimiter:
"""Tests for RateLimiter class."""
@pytest.mark.asyncio
async def test_check_status(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """check() on an unused key reports a zero count and the full configured limit."""
    fresh = await rate_limiter.check("test_limit", "test-key")

    assert fresh.name == "test_limit"
    assert fresh.current_count == 0
    assert fresh.limit == 5
    assert fresh.remaining == 5
    assert fresh.is_limited is False
@pytest.mark.asyncio
async def test_acquire_success(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """A first acquire succeeds and the returned status reflects one used slot."""
    granted, after = await rate_limiter.acquire("test_limit", "test-key")

    assert granted is True
    assert after.current_count == 1
    assert after.remaining == 4
@pytest.mark.asyncio
async def test_acquire_burst_exceeded(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Once the burst limit is used up, acquire() is refused and reports a retry delay."""
    # test_limit has burst_limit=3, so three acquires go through.
    for _ in range(3):
        granted, _status = await rate_limiter.acquire("test_limit", "test-key")
        assert granted is True

    # The fourth acquire is over the burst limit.
    granted, limited = await rate_limiter.acquire("test_limit", "test-key")
    assert granted is False
    assert limited.is_limited is True
    assert limited.retry_after_seconds > 0
@pytest.mark.asyncio
async def test_require_success(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """require() completes without raising when the key is not limited."""
    # The test passes as long as this await does not raise.
    await rate_limiter.require("test_limit", "test-key")
@pytest.mark.asyncio
async def test_require_raises(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """require() raises RateLimitExceededError once the burst limit is used up."""
    # Consume the three burst slots of test_limit.
    for _ in range(3):
        await rate_limiter.acquire("test_limit", "test-key")

    with pytest.raises(RateLimitExceededError) as raised:
        await rate_limiter.require("test_limit", "test-key")

    # The error carries the limit name and a positive retry delay.
    error = raised.value
    assert error.limit_type == "test_limit"
    assert error.retry_after_seconds > 0
@pytest.mark.asyncio
async def test_check_action_allowed(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """check_action() allows a fresh action and returns at least one status."""
    request = create_action(sample_metadata)

    permitted, checked = await rate_limiter.check_action(request)

    assert permitted is True
    assert len(checked) >= 1  # At least "actions" limit
@pytest.mark.asyncio
async def test_check_action_llm_limits(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """An LLM_CALL action is checked against both "actions" and "llm_calls"."""
    request = create_action(sample_metadata, action_type=ActionType.LLM_CALL)

    permitted, checked = await rate_limiter.check_action(request)

    assert permitted is True
    # Collect the names of every limit that was consulted.
    consulted = [entry.name for entry in checked]
    assert "actions" in consulted
    assert "llm_calls" in consulted
@pytest.mark.asyncio
async def test_check_action_file_limits(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """A FILE_READ action is checked against both "actions" and "file_ops"."""
    request = create_action(sample_metadata, action_type=ActionType.FILE_READ)

    permitted, checked = await rate_limiter.check_action(request)

    assert permitted is True
    # Collect the names of every limit that was consulted.
    consulted = [entry.name for entry in checked]
    assert "actions" in consulted
    assert "file_ops" in consulted
@pytest.mark.asyncio
async def test_check_action_does_not_consume_slot(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """Repeated check_action() calls leave the "actions" count at zero."""
    request = create_action(sample_metadata)

    # Ten checks in a row; each one must still be allowed.
    for _ in range(10):
        permitted, _checked = await rate_limiter.check_action(request)
        assert permitted is True

    # No slot may have been taken by the checks above.
    actions_status = await rate_limiter.check("actions", sample_metadata.agent_id)
    assert actions_status.current_count == 0
@pytest.mark.asyncio
async def test_record_action_consumes_slot(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """record_action() takes one slot from the agent's "actions" limit."""
    request = create_action(sample_metadata)

    await rate_limiter.record_action(request)

    # One recorded action means a count of exactly one.
    actions_status = await rate_limiter.check("actions", sample_metadata.agent_id)
    assert actions_status.current_count == 1
@pytest.mark.asyncio
async def test_record_action_consumes_type_specific_slots(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """record_action() counts against "actions" plus the limit matching the action type."""
    agent_key = sample_metadata.agent_id

    # Step 1: an LLM call touches "actions" and "llm_calls" only.
    llm_request = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
    await rate_limiter.record_action(llm_request)

    after_llm = await rate_limiter.get_all_statuses(agent_key)
    assert after_llm["actions"].current_count == 1
    assert after_llm["llm_calls"].current_count == 1
    assert after_llm["file_ops"].current_count == 0

    # Step 2: a file read touches "actions" and "file_ops" only.
    file_request = create_action(sample_metadata, action_type=ActionType.FILE_READ)
    await rate_limiter.record_action(file_request)

    after_file = await rate_limiter.get_all_statuses(agent_key)
    assert after_file["actions"].current_count == 2
    assert after_file["llm_calls"].current_count == 1
    assert after_file["file_ops"].current_count == 1
@pytest.mark.asyncio
async def test_get_all_statuses(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """get_all_statuses() includes the built-in limits and reflects earlier acquires."""
    # One acquire on each of two limits for the same key.
    await rate_limiter.acquire("actions", "test-key")
    await rate_limiter.acquire("llm_calls", "test-key")

    by_name = await rate_limiter.get_all_statuses("test-key")

    assert "actions" in by_name
    assert "llm_calls" in by_name
    assert "file_ops" in by_name
    assert by_name["actions"].current_count >= 1
    assert by_name["llm_calls"].current_count >= 1
@pytest.mark.asyncio
async def test_reset_single(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """reset() on a used limit returns True and brings its count back to zero."""
    # Use two slots so there is something to reset.
    await rate_limiter.acquire("test_limit", "test-key")
    await rate_limiter.acquire("test_limit", "test-key")

    was_reset = await rate_limiter.reset("test_limit", "test-key")
    assert was_reset is True

    # The count must be back at zero afterwards.
    after_reset = await rate_limiter.check("test_limit", "test-key")
    assert after_reset.current_count == 0
@pytest.mark.asyncio
async def test_reset_nonexistent(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """reset() returns False for a limit name that was never used."""
    was_reset = await rate_limiter.reset("nonexistent", "test-key")

    assert was_reset is False
@pytest.mark.asyncio
async def test_reset_all(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """reset_all() clears every limit for a key and reports how many were reset."""
    # Touch three different limits with the same key.
    await rate_limiter.acquire("actions", "test-key")
    await rate_limiter.acquire("llm_calls", "test-key")
    await rate_limiter.acquire("file_ops", "test-key")

    reset_count = await rate_limiter.reset_all("test-key")
    assert reset_count >= 3

    # Every status for the key must show a zero count afterwards.
    by_name = await rate_limiter.get_all_statuses("test-key")
    for entry in by_name.values():
        assert entry.current_count == 0
@pytest.mark.asyncio
async def test_per_key_isolation(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Exhausting a limit for one key does not affect another key."""
    # Consume the three burst slots of test_limit for key-1 only.
    for _ in range(3):
        await rate_limiter.acquire("test_limit", "key-1")

    # key-1 is now over its burst limit.
    first_key_granted, _ = await rate_limiter.acquire("test_limit", "key-1")
    assert first_key_granted is False

    # key-2 has its own counter and is unaffected.
    second_key_granted, _ = await rate_limiter.acquire("test_limit", "key-2")
    assert second_key_granted is True
@pytest.mark.asyncio
async def test_configure_custom_limit(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """A limit added through configure() is reflected by check()."""
    custom_limit = RateLimitConfig(
        name="custom",
        limit=100,
        window_seconds=120,
        burst_limit=50,
    )
    rate_limiter.configure(custom_limit)

    custom_status = await rate_limiter.check("custom", "test-key")

    assert custom_status.limit == 100
    assert custom_status.window_seconds == 120
@pytest.mark.asyncio
async def test_default_limit_fallback(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """An unconfigured limit name falls back to the default of 60 per 60 seconds."""
    # "unknown_limit" was never configured on this limiter.
    fallback = await rate_limiter.check("unknown_limit", "test-key")

    assert fallback.limit == 60
    assert fallback.window_seconds == 60