forked from cardosofelipe/fast-next-template
feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking
- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced a deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
@@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None:
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
async def reset_mcp_client() -> None:
    """
    Reset the global MCP client manager (for testing).

    This is an async function so that it can acquire the manager lock and
    avoid race conditions with get_mcp_client(). If a manager instance
    exists, it is shut down before the global reference is cleared; errors
    raised by that shutdown are ignored because this is test cleanup.
    """
    global _manager_instance

    async with _manager_lock:
        # The instance must still be set here: clearing it before taking the
        # lock would make the shutdown branch below unreachable.
        if _manager_instance is not None:
            # Shutdown gracefully before resetting
            try:
                await _manager_instance.shutdown()
            except Exception:  # noqa: S110
                pass  # Ignore errors during test cleanup
        _manager_instance = None
|
||||
|
||||
@@ -161,7 +161,7 @@ class MCPConnection:
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
@@ -297,13 +297,13 @@ class MCPConnection:
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
@@ -322,8 +322,19 @@ class ConnectionPool:
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific server.
|
||||
|
||||
Uses setdefault for atomic dict access to prevent race conditions
|
||||
where two coroutines could create different locks for the same server.
|
||||
"""
|
||||
# setdefault is atomic - if key exists, returns existing value
|
||||
# if key doesn't exist, inserts new value and returns it
|
||||
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
@@ -332,6 +343,9 @@ class ConnectionPool:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Uses per-server locking to avoid blocking all connections
|
||||
when establishing a new connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
@@ -339,17 +353,33 @@ class ConnectionPool:
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name not in self._connections:
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
self._connections[server_name] = connection
|
||||
|
||||
# Quick check without lock - if connection exists and is connected, return it
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
|
||||
# Reconnect if not connected
|
||||
if not connection.is_connected:
|
||||
# Need to create or reconnect - use per-server lock to avoid blocking others
|
||||
async with self._lock:
|
||||
server_lock = self._get_server_lock(server_name)
|
||||
|
||||
async with server_lock:
|
||||
# Double-check after acquiring per-server lock
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
# Connection exists but not connected - reconnect
|
||||
await connection.connect()
|
||||
return connection
|
||||
|
||||
# Create new connection (outside global lock, under per-server lock)
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
|
||||
# Store connection under global lock
|
||||
async with self._lock:
|
||||
self._connections[server_name] = connection
|
||||
|
||||
return connection
|
||||
|
||||
@@ -374,6 +404,9 @@ class ConnectionPool:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
# Clean up per-server lock
|
||||
if server_name in self._per_server_locks:
|
||||
del self._per_server_locks[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
@@ -385,6 +418,7 @@ class ConnectionPool:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
self._per_server_locks.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
@@ -394,8 +428,12 @@ class ConnectionPool:
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
# Copy connections under lock to prevent modification during iteration
|
||||
async with self._lock:
|
||||
connections_snapshot = dict(self._connections)
|
||||
|
||||
results = {}
|
||||
for name, connection in self._connections.items():
|
||||
for name, connection in connections_snapshot.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
|
||||
@@ -185,6 +185,9 @@ class CostController:
|
||||
# Alert handlers
|
||||
self._alert_handlers: list[Any] = []
|
||||
|
||||
# Track which budgets have had warning alerts sent (to avoid spam)
|
||||
self._warned_budgets: set[str] = set()
|
||||
|
||||
async def get_or_create_tracker(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
@@ -343,32 +346,44 @@ class CostController:
|
||||
"""
|
||||
# Update session budget
|
||||
if session_id:
|
||||
session_key = f"session:{session_id}"
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
await session_tracker.add_usage(tokens, cost_usd)
|
||||
|
||||
# Check for warning
|
||||
# Check for warning (only alert once per budget to avoid spam)
|
||||
status = await session_tracker.get_status()
|
||||
if status.is_warning and not status.is_exceeded:
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
|
||||
status,
|
||||
)
|
||||
if session_key not in self._warned_budgets:
|
||||
self._warned_budgets.add(session_key)
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
|
||||
status,
|
||||
)
|
||||
elif not status.is_warning:
|
||||
# Clear warning flag if usage dropped below threshold (e.g., after reset)
|
||||
self._warned_budgets.discard(session_key)
|
||||
|
||||
# Update agent daily budget
|
||||
daily_key = f"daily:{agent_id}"
|
||||
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
|
||||
await agent_tracker.add_usage(tokens, cost_usd)
|
||||
|
||||
# Check for warning
|
||||
# Check for warning (only alert once per budget to avoid spam)
|
||||
status = await agent_tracker.get_status()
|
||||
if status.is_warning and not status.is_exceeded:
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
|
||||
status,
|
||||
)
|
||||
if daily_key not in self._warned_budgets:
|
||||
self._warned_budgets.add(daily_key)
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
|
||||
status,
|
||||
)
|
||||
elif not status.is_warning:
|
||||
# Clear warning flag if usage dropped below threshold (e.g., after reset)
|
||||
self._warned_budgets.discard(daily_key)
|
||||
|
||||
async def get_status(
|
||||
self,
|
||||
@@ -388,20 +403,18 @@ class CostController:
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
async with self._lock:
|
||||
tracker = self._trackers.get(key)
|
||||
|
||||
if tracker:
|
||||
return await tracker.get_status()
|
||||
return None
|
||||
# Get status while holding lock to prevent TOCTOU race
|
||||
if tracker:
|
||||
return await tracker.get_status()
|
||||
return None
|
||||
|
||||
async def get_all_statuses(self) -> list[BudgetStatus]:
    """Get status of all tracked budgets.

    Returns:
        One status per tracker in ``self._trackers``, in the dict's
        iteration order. Empty list when nothing is tracked.
    """
    # Get all statuses while holding lock to prevent TOCTOU race.
    # A single pass only: collecting from both a snapshot and the live dict
    # would report every budget twice.
    async with self._lock:
        return [await tracker.get_status() for tracker in self._trackers.values()]
|
||||
|
||||
async def set_budget(
|
||||
@@ -453,11 +466,11 @@ class CostController:
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
async with self._lock:
|
||||
tracker = self._trackers.get(key)
|
||||
|
||||
if tracker:
|
||||
await tracker.reset()
|
||||
return True
|
||||
return False
|
||||
# Reset while holding lock to prevent TOCTOU race
|
||||
if tracker:
|
||||
await tracker.reset()
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_alert_handler(self, handler: Any) -> None:
|
||||
"""Add an alert handler."""
|
||||
|
||||
@@ -15,13 +15,20 @@ from .config import (
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
)
|
||||
from .costs.controller import CostController
|
||||
from .exceptions import (
|
||||
BudgetExceededError,
|
||||
LoopDetectedError,
|
||||
RateLimitExceededError,
|
||||
SafetyError,
|
||||
)
|
||||
from .limits.limiter import RateLimiter
|
||||
from .loops.detector import LoopDetector
|
||||
from .models import (
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
AuditEventType,
|
||||
BudgetScope,
|
||||
GuardianResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
@@ -62,6 +69,9 @@ class SafetyGuardian:
|
||||
self,
|
||||
config: SafetyConfig | None = None,
|
||||
audit_logger: AuditLogger | None = None,
|
||||
cost_controller: CostController | None = None,
|
||||
rate_limiter: RateLimiter | None = None,
|
||||
loop_detector: LoopDetector | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the SafetyGuardian.
|
||||
@@ -69,17 +79,22 @@ class SafetyGuardian:
|
||||
Args:
|
||||
config: Optional safety configuration. If None, loads from environment.
|
||||
audit_logger: Optional audit logger. If None, uses global instance.
|
||||
cost_controller: Optional cost controller. If None, creates default.
|
||||
rate_limiter: Optional rate limiter. If None, creates default.
|
||||
loop_detector: Optional loop detector. If None, creates default.
|
||||
"""
|
||||
self._config = config or get_safety_config()
|
||||
self._audit_logger = audit_logger
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Subsystem references (will be initialized lazily)
|
||||
# Core safety subsystems (always initialized)
|
||||
self._cost_controller: CostController | None = cost_controller
|
||||
self._rate_limiter: RateLimiter | None = rate_limiter
|
||||
self._loop_detector: LoopDetector | None = loop_detector
|
||||
|
||||
# Optional subsystems (will be initialized when available)
|
||||
self._permission_manager: Any = None
|
||||
self._cost_controller: Any = None
|
||||
self._rate_limiter: Any = None
|
||||
self._loop_detector: Any = None
|
||||
self._hitl_manager: Any = None
|
||||
self._rollback_manager: Any = None
|
||||
self._content_filter: Any = None
|
||||
@@ -95,6 +110,21 @@ class SafetyGuardian:
|
||||
"""Check if the guardian is initialized."""
|
||||
return self._initialized
|
||||
|
||||
@property
def cost_controller(self) -> CostController | None:
    """Return the cost controller held by this guardian (may be None)."""
    controller = self._cost_controller
    return controller
|
||||
|
||||
@property
def rate_limiter(self) -> RateLimiter | None:
    """Return the rate limiter held by this guardian (may be None)."""
    limiter = self._rate_limiter
    return limiter
|
||||
|
||||
@property
def loop_detector(self) -> LoopDetector | None:
    """Return the loop detector held by this guardian (may be None)."""
    detector = self._loop_detector
    return detector
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the SafetyGuardian and all subsystems."""
|
||||
async with self._lock:
|
||||
@@ -108,11 +138,23 @@ class SafetyGuardian:
|
||||
if self._audit_logger is None:
|
||||
self._audit_logger = await get_audit_logger()
|
||||
|
||||
# Initialize subsystems lazily as they're implemented
|
||||
# For now, we'll import and initialize them when available
|
||||
# Initialize core safety subsystems
|
||||
if self._cost_controller is None:
|
||||
self._cost_controller = CostController()
|
||||
logger.debug("Initialized CostController")
|
||||
|
||||
if self._rate_limiter is None:
|
||||
self._rate_limiter = RateLimiter()
|
||||
logger.debug("Initialized RateLimiter")
|
||||
|
||||
if self._loop_detector is None:
|
||||
self._loop_detector = LoopDetector()
|
||||
logger.debug("Initialized LoopDetector")
|
||||
|
||||
self._initialized = True
|
||||
logger.info("SafetyGuardian initialized")
|
||||
logger.info(
|
||||
"SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the SafetyGuardian and all subsystems."""
|
||||
@@ -309,13 +351,40 @@ class SafetyGuardian:
|
||||
|
||||
# Update cost tracking
|
||||
if self._cost_controller:
|
||||
# Track actual cost
|
||||
pass
|
||||
try:
|
||||
# Use explicit None check - 0 is a valid cost value
|
||||
tokens = (
|
||||
result.actual_cost_tokens
|
||||
if result.actual_cost_tokens is not None
|
||||
else action.estimated_cost_tokens
|
||||
)
|
||||
cost_usd = (
|
||||
result.actual_cost_usd
|
||||
if result.actual_cost_usd is not None
|
||||
else action.estimated_cost_usd
|
||||
)
|
||||
await self._cost_controller.record_usage(
|
||||
agent_id=action.metadata.agent_id,
|
||||
session_id=action.metadata.session_id,
|
||||
tokens=tokens,
|
||||
cost_usd=cost_usd,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record cost: %s", e)
|
||||
|
||||
# Update rate limiter - consume slots for executed actions
|
||||
if self._rate_limiter:
|
||||
try:
|
||||
await self._rate_limiter.record_action(action)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record action in rate limiter: %s", e)
|
||||
|
||||
# Update loop detection history
|
||||
if self._loop_detector:
|
||||
# Add to action history
|
||||
pass
|
||||
try:
|
||||
await self._loop_detector.record(action)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record action in loop detector: %s", e)
|
||||
|
||||
async def rollback(self, checkpoint_id: str) -> bool:
|
||||
"""
|
||||
@@ -442,14 +511,80 @@ class SafetyGuardian:
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check if action is within budget."""
|
||||
# TODO: Implement with CostController
|
||||
# For now, return allow
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Budget check passed (not fully implemented)"],
|
||||
)
|
||||
if self._cost_controller is None:
|
||||
logger.warning("CostController not initialized - skipping budget check")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Budget check skipped (controller not initialized)"],
|
||||
)
|
||||
|
||||
agent_id = action.metadata.agent_id
|
||||
session_id = action.metadata.session_id
|
||||
|
||||
try:
|
||||
# Check if we have budget for this action
|
||||
has_budget = await self._cost_controller.check_budget(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
estimated_tokens=action.estimated_cost_tokens,
|
||||
estimated_cost_usd=action.estimated_cost_usd,
|
||||
)
|
||||
|
||||
if not has_budget:
|
||||
# Get current status for better error message
|
||||
if session_id:
|
||||
session_status = await self._cost_controller.get_status(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
if session_status and session_status.is_exceeded:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[
|
||||
f"Session budget exceeded: {session_status.tokens_used}"
|
||||
f"/{session_status.tokens_limit} tokens"
|
||||
],
|
||||
)
|
||||
|
||||
agent_status = await self._cost_controller.get_status(
|
||||
BudgetScope.DAILY, agent_id
|
||||
)
|
||||
if agent_status and agent_status.is_exceeded:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[
|
||||
f"Daily budget exceeded: {agent_status.tokens_used}"
|
||||
f"/{agent_status.tokens_limit} tokens"
|
||||
],
|
||||
)
|
||||
|
||||
# Generic budget exceeded
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=["Budget exceeded"],
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Budget check passed"],
|
||||
)
|
||||
|
||||
except BudgetExceededError as e:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[str(e)],
|
||||
)
|
||||
|
||||
async def _check_rate_limit(
|
||||
self,
|
||||
@@ -457,14 +592,78 @@ class SafetyGuardian:
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check if action is within rate limits."""
|
||||
# TODO: Implement with RateLimiter
|
||||
# For now, return allow
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Rate limit check passed (not fully implemented)"],
|
||||
)
|
||||
if self._rate_limiter is None:
|
||||
logger.warning("RateLimiter not initialized - skipping rate limit check")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Rate limit check skipped (limiter not initialized)"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Check all applicable rate limits for this action
|
||||
allowed, statuses = await self._rate_limiter.check_action(action)
|
||||
|
||||
if not allowed:
|
||||
# Find the first exceeded limit for the error message
|
||||
exceeded_status = next(
|
||||
(s for s in statuses if s.is_limited),
|
||||
statuses[0] if statuses else None,
|
||||
)
|
||||
|
||||
if exceeded_status:
|
||||
retry_after = exceeded_status.retry_after_seconds
|
||||
|
||||
# Determine if this is a soft limit (delay) or hard limit (deny)
|
||||
if retry_after > 0 and retry_after <= 5.0:
|
||||
# Short wait - suggest delay
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DELAY,
|
||||
reasons=[
|
||||
f"Rate limit '{exceeded_status.name}' exceeded. "
|
||||
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
|
||||
],
|
||||
retry_after_seconds=retry_after,
|
||||
)
|
||||
else:
|
||||
# Hard deny
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[
|
||||
f"Rate limit '{exceeded_status.name}' exceeded. "
|
||||
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
|
||||
f"Retry after {retry_after:.1f}s"
|
||||
],
|
||||
retry_after_seconds=retry_after,
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=["Rate limit exceeded"],
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Rate limit check passed"],
|
||||
)
|
||||
|
||||
except RateLimitExceededError as e:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[str(e)],
|
||||
retry_after_seconds=e.retry_after_seconds,
|
||||
)
|
||||
|
||||
async def _check_loops(
|
||||
self,
|
||||
@@ -472,14 +671,51 @@ class SafetyGuardian:
|
||||
policy: SafetyPolicy,
|
||||
) -> GuardianResult:
|
||||
"""Check for action loops."""
|
||||
# TODO: Implement with LoopDetector
|
||||
# For now, return allow
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Loop check passed (not fully implemented)"],
|
||||
)
|
||||
if self._loop_detector is None:
|
||||
logger.warning("LoopDetector not initialized - skipping loop check")
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Loop check skipped (detector not initialized)"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if this action would create a loop
|
||||
is_loop, loop_type = await self._loop_detector.check(action)
|
||||
|
||||
if is_loop:
|
||||
# Get suggestions for breaking the loop
|
||||
from .loops.detector import LoopBreaker
|
||||
|
||||
suggestions = await LoopBreaker.suggest_alternatives(
|
||||
action, loop_type or "unknown"
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[
|
||||
f"Loop detected: {loop_type}",
|
||||
*suggestions,
|
||||
],
|
||||
)
|
||||
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["Loop check passed"],
|
||||
)
|
||||
|
||||
except LoopDetectedError as e:
|
||||
return GuardianResult(
|
||||
action_id=action.id,
|
||||
allowed=False,
|
||||
decision=SafetyDecision.DENY,
|
||||
reasons=[str(e)],
|
||||
)
|
||||
|
||||
async def _check_hitl(
|
||||
self,
|
||||
@@ -610,7 +846,19 @@ async def shutdown_safety_guardian() -> None:
|
||||
_guardian_instance = None
|
||||
|
||||
|
||||
async def reset_safety_guardian() -> None:
    """
    Reset the SafetyGuardian (for testing).

    This is an async function so that it can acquire the guardian lock and
    avoid race conditions with get_safety_guardian(). If a guardian
    instance exists, it is shut down before the global reference is
    cleared; errors raised by that shutdown are ignored because this is
    test cleanup.
    """
    global _guardian_instance

    async with _guardian_lock:
        # The instance must still be set here: clearing it before taking the
        # lock would make the shutdown branch below unreachable.
        if _guardian_instance is not None:
            try:
                await _guardian_instance.shutdown()
            except Exception:  # noqa: S110
                pass  # Ignore errors during test cleanup
        _guardian_instance = None
|
||||
|
||||
@@ -223,7 +223,10 @@ class RateLimiter:
|
||||
action: ActionRequest,
|
||||
) -> tuple[bool, list[RateLimitStatus]]:
|
||||
"""
|
||||
Check all applicable rate limits for an action.
|
||||
Check all applicable rate limits for an action WITHOUT consuming slots.
|
||||
|
||||
Use this during validation to check if action would be allowed.
|
||||
Call record_action() after successful execution to consume slots.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
@@ -235,28 +238,53 @@ class RateLimiter:
|
||||
statuses: list[RateLimitStatus] = []
|
||||
allowed = True
|
||||
|
||||
# Check general actions limit
|
||||
actions_allowed, actions_status = await self.acquire("actions", agent_id)
|
||||
# Check general actions limit (read-only)
|
||||
actions_status = await self.check("actions", agent_id)
|
||||
statuses.append(actions_status)
|
||||
if not actions_allowed:
|
||||
if actions_status.is_limited:
|
||||
allowed = False
|
||||
|
||||
# Check LLM-specific limit for LLM calls
|
||||
if action.action_type.value == "llm_call":
|
||||
llm_allowed, llm_status = await self.acquire("llm_calls", agent_id)
|
||||
llm_status = await self.check("llm_calls", agent_id)
|
||||
statuses.append(llm_status)
|
||||
if not llm_allowed:
|
||||
if llm_status.is_limited:
|
||||
allowed = False
|
||||
|
||||
# Check file ops limit for file operations
|
||||
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
|
||||
file_allowed, file_status = await self.acquire("file_ops", agent_id)
|
||||
file_status = await self.check("file_ops", agent_id)
|
||||
statuses.append(file_status)
|
||||
if not file_allowed:
|
||||
if file_status.is_limited:
|
||||
allowed = False
|
||||
|
||||
return allowed, statuses
|
||||
|
||||
async def record_action(
    self,
    action: ActionRequest,
) -> None:
    """
    Consume rate limit slots for an action that has been executed.

    Intended to be called AFTER successful execution so the action is
    counted. The "actions" limit is always consumed; "llm_calls" and
    "file_ops" are consumed only for the matching action types.

    Args:
        action: The executed action
    """
    agent_id = action.metadata.agent_id

    # Every action counts against the general limit.
    await self.acquire("actions", agent_id)

    # Work out which type-specific limits also apply.
    kind = action.action_type.value
    extra_limits = []
    if kind == "llm_call":
        extra_limits.append("llm_calls")
    if kind in {"file_read", "file_write", "file_delete"}:
        extra_limits.append("file_ops")

    for limit_name in extra_limits:
        await self.acquire(limit_name, agent_id)
|
||||
|
||||
async def require(
|
||||
self,
|
||||
limit_name: str,
|
||||
|
||||
@@ -20,13 +20,13 @@ from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
# NOTE(review): an async fixture declared with plain ``pytest.fixture``
# presumably relies on pytest-asyncio running in auto mode - confirm.
@pytest.fixture
async def reset_registry():
    """Reset the singleton registry before and after each test."""
    MCPServerRegistry.reset_instance()
    # reset_mcp_client() is a coroutine; it must be awaited or it never runs.
    await reset_mcp_client()
    yield
    MCPServerRegistry.reset_instance()
    await reset_mcp_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -388,7 +388,8 @@ class TestModuleLevelFunctions:
|
||||
mock_shutdown.return_value = None
|
||||
await shutdown_mcp_client()
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_mcp_client(self, reset_registry):
    """Test resetting the global client."""
    # reset_mcp_client() is a coroutine; it must be awaited to execute.
    await reset_mcp_client()
    # Should not raise
|
||||
|
||||
436
backend/tests/services/safety/test_costs.py
Normal file
436
backend/tests/services/safety/test_costs.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""Tests for cost controller module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.costs.controller import (
|
||||
BudgetTracker,
|
||||
CostController,
|
||||
)
|
||||
from app.services.safety.exceptions import BudgetExceededError
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
BudgetScope,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def budget_tracker() -> BudgetTracker:
    """Provide a session-scoped BudgetTracker with small, round limits."""
    settings = {
        "scope": BudgetScope.SESSION,
        "scope_id": "test-session",
        "tokens_limit": 1000,
        "cost_limit_usd": 10.0,
        "warning_threshold": 0.8,
    }
    return BudgetTracker(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def cost_controller() -> CostController:
    """Provide a CostController with explicit session and daily defaults."""
    defaults = {
        "default_session_tokens": 1000,
        "default_session_cost_usd": 10.0,
        "default_daily_tokens": 5000,
        "default_daily_cost_usd": 50.0,
    }
    return CostController(**defaults)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide ActionMetadata for the test agent and test session."""
    metadata = ActionMetadata(agent_id="test-agent", session_id="test-session")
    return metadata
|
||||
|
||||
|
||||
def create_action(
    metadata: ActionMetadata,
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an LLM-call ActionRequest for tests.

    Args:
        metadata: Metadata attached to the request.
        estimated_tokens: Value passed as ``estimated_cost_tokens``.
        estimated_cost: Value passed as ``estimated_cost_usd``.

    Returns:
        The constructed ActionRequest.
    """
    request_fields = {
        "action_type": ActionType.LLM_CALL,
        "tool_name": "test_tool",
        "resource": "test-resource",
        "arguments": {},
        "metadata": metadata,
        "estimated_cost_tokens": estimated_tokens,
        "estimated_cost_usd": estimated_cost,
    }
    return ActionRequest(**request_fields)
|
||||
|
||||
|
||||
class TestBudgetTracker:
    """Tests for BudgetTracker class."""

    @pytest.mark.asyncio
    async def test_initial_status(self, budget_tracker: BudgetTracker) -> None:
        """A fresh tracker reports no usage, full remaining budget, no flags."""
        fresh = await budget_tracker.get_status()

        assert fresh.tokens_used == 0
        assert fresh.cost_used_usd == 0.0
        assert fresh.tokens_remaining == 1000
        assert fresh.cost_remaining_usd == 10.0
        assert fresh.is_warning is False
        assert fresh.is_exceeded is False

    @pytest.mark.asyncio
    async def test_add_usage(self, budget_tracker: BudgetTracker) -> None:
        """Recorded usage shows up in both used and remaining counters."""
        await budget_tracker.add_usage(tokens=100, cost_usd=1.0)

        after = await budget_tracker.get_status()
        assert after.tokens_used == 100
        assert after.cost_used_usd == 1.0
        assert after.tokens_remaining == 900
        assert after.cost_remaining_usd == 9.0

    @pytest.mark.asyncio
    async def test_warning_threshold(self, budget_tracker: BudgetTracker) -> None:
        """Reaching 80% of the token limit warns without exceeding."""
        # 800 of 1000 tokens is exactly the 0.8 warning threshold
        await budget_tracker.add_usage(tokens=800, cost_usd=1.0)

        at_threshold = await budget_tracker.get_status()
        assert at_threshold.is_warning is True
        assert at_threshold.is_exceeded is False

    @pytest.mark.asyncio
    async def test_budget_exceeded(self, budget_tracker: BudgetTracker) -> None:
        """Going past the token limit is reported as exceeded."""
        # 1100 tokens is above the 1000-token limit
        await budget_tracker.add_usage(tokens=1100, cost_usd=1.0)

        over_limit = await budget_tracker.get_status()
        assert over_limit.is_exceeded is True

    @pytest.mark.asyncio
    async def test_check_budget_allows(self, budget_tracker: BudgetTracker) -> None:
        """check_budget returns True for an estimate that fits the budget."""
        fits = await budget_tracker.check_budget(
            estimated_tokens=500,
            estimated_cost_usd=5.0,
        )
        assert fits is True

    @pytest.mark.asyncio
    async def test_check_budget_denies(self, budget_tracker: BudgetTracker) -> None:
        """check_budget returns False when the estimate would overshoot."""
        # Consume most of the budget first
        await budget_tracker.add_usage(tokens=800, cost_usd=8.0)

        # 800 + 300 tokens would pass the 1000-token limit
        fits = await budget_tracker.check_budget(
            estimated_tokens=300,
            estimated_cost_usd=3.0,
        )
        assert fits is False

    @pytest.mark.asyncio
    async def test_reset(self, budget_tracker: BudgetTracker) -> None:
        """A manual reset brings the usage counters back to zero."""
        await budget_tracker.add_usage(tokens=500, cost_usd=5.0)
        await budget_tracker.reset()

        cleared = await budget_tracker.get_status()
        assert cleared.tokens_used == 0
        assert cleared.cost_used_usd == 0.0
|
||||
|
||||
|
||||
class TestCostController:
    """Exercise CostController: budget checks, usage recording, budget admin, alerts."""

    @pytest.mark.asyncio
    async def test_check_budget_success(
        self,
        cost_controller: CostController,
    ) -> None:
        """A modest estimate against an unused controller is accepted."""
        request = {
            "agent_id": "test-agent",
            "session_id": "test-session",
            "estimated_tokens": 100,
            "estimated_cost_usd": 1.0,
        }
        allowed = await cost_controller.check_budget(**request)
        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_budget_session_exceeded(
        self,
        cost_controller: CostController,
    ) -> None:
        """An estimate that would overrun the session budget is rejected."""
        agent, session = "test-agent", "test-session"

        # Consume most of the session budget first.
        await cost_controller.record_usage(
            agent_id=agent,
            session_id=session,
            tokens=900,
            cost_usd=9.0,
        )

        # The follow-up estimate no longer fits.
        allowed = await cost_controller.check_budget(
            agent_id=agent,
            session_id=session,
            estimated_tokens=200,
            estimated_cost_usd=2.0,
        )
        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_budget_daily_exceeded(
        self,
        cost_controller: CostController,
    ) -> None:
        """An estimate that would overrun the daily budget is rejected."""
        agent = "test-agent"

        # Consume most of the daily budget without attaching a session.
        await cost_controller.record_usage(
            agent_id=agent,
            session_id=None,
            tokens=4900,
            cost_usd=49.0,
        )

        # A different session for the same agent is still rejected.
        allowed = await cost_controller.check_budget(
            agent_id=agent,
            session_id="new-session",
            estimated_tokens=200,
            estimated_cost_usd=2.0,
        )
        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_action(
        self,
        cost_controller: CostController,
        sample_metadata: ActionMetadata,
    ) -> None:
        """check_action accepts an action carrying small cost estimates."""
        cheap_action = create_action(
            sample_metadata,
            estimated_tokens=100,
            estimated_cost=0.01,
        )

        allowed = await cost_controller.check_action(cheap_action)
        assert allowed is True

    @pytest.mark.asyncio
    async def test_require_budget_success(
        self,
        cost_controller: CostController,
    ) -> None:
        """require_budget completes silently while budget is available."""
        request = {
            "agent_id": "test-agent",
            "session_id": "test-session",
            "estimated_tokens": 100,
            "estimated_cost_usd": 1.0,
        }
        # Passing means no exception was raised.
        await cost_controller.require_budget(**request)

    @pytest.mark.asyncio
    async def test_require_budget_raises(
        self,
        cost_controller: CostController,
    ) -> None:
        """require_budget raises BudgetExceededError naming the session budget."""
        agent, session = "test-agent", "test-session"

        # Spend the entire session budget.
        await cost_controller.record_usage(
            agent_id=agent,
            session_id=session,
            tokens=1000,
            cost_usd=10.0,
        )

        with pytest.raises(BudgetExceededError) as exc_info:
            await cost_controller.require_budget(
                agent_id=agent,
                session_id=session,
                estimated_tokens=100,
                estimated_cost_usd=1.0,
            )

        budget_type = exc_info.value.budget_type
        assert "session" in budget_type.lower()

    @pytest.mark.asyncio
    async def test_record_usage(
        self,
        cost_controller: CostController,
    ) -> None:
        """Recorded usage shows up in both the session and the daily status."""
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=100,
            cost_usd=1.0,
        )

        # Session scope is keyed by the session id.
        for_session = await cost_controller.get_status(
            BudgetScope.SESSION, "test-session"
        )
        assert for_session is not None
        assert for_session.tokens_used == 100

        # Daily scope is keyed by the agent id.
        for_day = await cost_controller.get_status(BudgetScope.DAILY, "test-agent")
        assert for_day is not None
        assert for_day.tokens_used == 100

    @pytest.mark.asyncio
    async def test_get_all_statuses(
        self,
        cost_controller: CostController,
    ) -> None:
        """get_all_statuses reports at least one entry per recorded usage."""
        usages = (
            ("agent-1", "session-1", 100, 1.0),
            ("agent-2", "session-2", 200, 2.0),
        )
        for agent, session, tokens, cost in usages:
            await cost_controller.record_usage(
                agent_id=agent,
                session_id=session,
                tokens=tokens,
                cost_usd=cost,
            )

        everything = await cost_controller.get_all_statuses()
        assert len(everything) >= 2

    @pytest.mark.asyncio
    async def test_set_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """Limits passed to set_budget are reflected by get_status."""
        scope, scope_id = BudgetScope.SESSION, "custom-session"
        await cost_controller.set_budget(
            scope=scope,
            scope_id=scope_id,
            tokens_limit=5000,
            cost_limit_usd=50.0,
        )

        configured = await cost_controller.get_status(scope, scope_id)
        assert configured is not None
        assert configured.tokens_limit == 5000
        assert configured.cost_limit_usd == 50.0

    @pytest.mark.asyncio
    async def test_reset_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """reset_budget returns True and zeroes the session's token usage."""
        session = "test-session"

        # Put some usage on the books.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id=session,
            tokens=500,
            cost_usd=5.0,
        )

        # Clear the session-scoped budget.
        was_reset = await cost_controller.reset_budget(BudgetScope.SESSION, session)
        assert was_reset is True

        # The counter is back at zero.
        after = await cost_controller.get_status(BudgetScope.SESSION, session)
        assert after is not None
        assert after.tokens_used == 0

    @pytest.mark.asyncio
    async def test_reset_nonexistent_budget(
        self,
        cost_controller: CostController,
    ) -> None:
        """reset_budget returns False for a scope id that was never used."""
        was_reset = await cost_controller.reset_budget(
            BudgetScope.SESSION, "nonexistent"
        )
        assert was_reset is False

    @pytest.mark.asyncio
    async def test_alert_handler(
        self,
        cost_controller: CostController,
    ) -> None:
        """A registered handler receives a warning once usage passes the threshold."""
        seen = []

        def on_alert(alert_type: str, message: str, status):
            seen.append((alert_type, message))

        cost_controller.add_alert_handler(on_alert)

        # Cross the warning threshold (80%) in one step.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=850,  # 85% of 1000
            cost_usd=0.0,
        )

        assert len(seen) > 0
        first_type, _first_message = seen[0]
        assert first_type == "warning"

    @pytest.mark.asyncio
    async def test_remove_alert_handler(
        self,
        cost_controller: CostController,
    ) -> None:
        """A handler that was removed again receives no alerts."""
        seen = []

        def on_alert(alert_type: str, message: str, status):
            seen.append((alert_type, message))

        cost_controller.add_alert_handler(on_alert)
        cost_controller.remove_alert_handler(on_alert)

        # Same usage as the alert test: enough to reach the warning threshold.
        await cost_controller.record_usage(
            agent_id="test-agent",
            session_id="test-session",
            tokens=850,
            cost_usd=0.0,
        )

        assert len(seen) == 0

    @pytest.mark.asyncio
    async def test_alert_deduplication(
        self,
        cost_controller: CostController,
    ) -> None:
        """Repeated usage at warning level produces a single alert, not one per call."""
        seen = []

        def on_alert(alert_type: str, message: str, status):
            seen.append((alert_type, message))

        cost_controller.add_alert_handler(on_alert)

        # Session budget is 1000 tokens with an 80% threshold (800 tokens).
        # Ten calls of 85 tokens total 850, crossing that threshold once.
        calls = 10
        for _ in range(calls):
            await cost_controller.record_usage(
                agent_id="test-agent",
                session_id="test-session",
                tokens=85,  # per-call increment
                cost_usd=0.0,
            )

        # Exactly one session warning is expected; the daily budget of 5000
        # is not reached. The point is that we do not get ten alerts.
        assert len(seen) == 1
        alert_type, message = seen[0]
        assert alert_type == "warning"
        assert "Session" in message
|
||||
508
backend/tests/services/safety/test_guardian.py
Normal file
508
backend/tests/services/safety/test_guardian.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""Tests for SafetyGuardian integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.config import SafetyConfig
|
||||
from app.services.safety.costs.controller import CostController
|
||||
from app.services.safety.guardian import (
|
||||
SafetyGuardian,
|
||||
get_safety_guardian,
|
||||
reset_safety_guardian,
|
||||
shutdown_safety_guardian,
|
||||
)
|
||||
from app.services.safety.limits.limiter import RateLimiter
|
||||
from app.services.safety.loops.detector import LoopDetector
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def reset_guardian():
    """Clear the SafetyGuardian singleton on both sides of the test body."""
    # Setup: drop whatever instance an earlier test may have left behind.
    await reset_safety_guardian()
    yield None
    # Teardown: do not leak this test's instance into the next one.
    await reset_safety_guardian()
|
||||
|
||||
|
||||
@pytest.fixture
def safety_config() -> SafetyConfig:
    """Build a SafetyConfig: enabled, strict, no HITL, no auto-checkpointing."""
    options = {
        "enabled": True,
        "strict_mode": True,
        "hitl_enabled": False,
        "auto_checkpoint_destructive": False,
    }
    return SafetyConfig(**options)
|
||||
|
||||
|
||||
@pytest.fixture
def cost_controller() -> CostController:
    """Build a CostController with small, easy-to-exhaust default budgets."""
    session_defaults = {
        "default_session_tokens": 1000,
        "default_session_cost_usd": 10.0,
    }
    daily_defaults = {
        "default_daily_tokens": 5000,
        "default_daily_cost_usd": 50.0,
    }
    return CostController(**session_defaults, **daily_defaults)
|
||||
|
||||
|
||||
@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Provide a RateLimiter constructed with no arguments."""
    limiter = RateLimiter()
    return limiter
|
||||
|
||||
|
||||
@pytest.fixture
def loop_detector() -> LoopDetector:
    """Build a LoopDetector with a short history and low repetition limits."""
    settings = {
        "history_size": 10,
        "max_exact_repetitions": 3,
        "max_semantic_repetitions": 5,
    }
    return LoopDetector(**settings)
|
||||
|
||||
|
||||
def _make_audit_event() -> AuditEvent:
    """Return an ACTION_REQUESTED AuditEvent for use as a mock return value."""
    fields = {
        "event_type": AuditEventType.ACTION_REQUESTED,
        "agent_id": "test-agent",
        "action_id": "test-action",
    }
    return AuditEvent(**fields)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def guardian(
    safety_config: SafetyConfig,
    cost_controller: CostController,
    rate_limiter: RateLimiter,
    loop_detector: LoopDetector,
) -> SafetyGuardian:
    """Provide an initialized SafetyGuardian wired to the test subsystems.

    The audit logger is replaced by a mock so no real logging happens.
    """
    instance = SafetyGuardian(
        config=safety_config,
        cost_controller=cost_controller,
        rate_limiter=rate_limiter,
        loop_detector=loop_detector,
    )

    # The logging coroutines return real AuditEvent objects rather than
    # nested mocks.
    audit_stub = MagicMock()
    audit_stub.log = AsyncMock(return_value=_make_audit_event())
    audit_stub.log_action_request = AsyncMock(return_value=_make_audit_event())
    audit_stub.log_action_executed = AsyncMock(return_value=None)
    instance._audit_logger = audit_stub

    await instance.initialize()
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide ActionMetadata for the test agent/session at MILESTONE autonomy."""
    fields = {
        "agent_id": "test-agent",
        "session_id": "test-session",
        "autonomy_level": AutonomyLevel.MILESTONE,
    }
    return ActionMetadata(**fields)
|
||||
|
||||
|
||||
def create_action(
    metadata: ActionMetadata,
    tool_name: str = "test_tool",
    action_type: ActionType = ActionType.LLM_CALL,
    resource: str = "/tmp/test.txt",  # noqa: S108
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Build an ActionRequest with empty arguments for use in tests.

    ``estimated_tokens`` and ``estimated_cost`` are forwarded as the request's
    ``estimated_cost_tokens`` and ``estimated_cost_usd`` fields.
    """
    request_fields = {
        "action_type": action_type,
        "tool_name": tool_name,
        "resource": resource,
        "arguments": {},
        "metadata": metadata,
        "estimated_cost_tokens": estimated_tokens,
        "estimated_cost_usd": estimated_cost,
    }
    return ActionRequest(**request_fields)
|
||||
|
||||
|
||||
class TestSafetyGuardianInit:
    """Check how SafetyGuardian sets up its subsystems on initialization."""

    @pytest.mark.asyncio
    async def test_init_creates_subsystems(
        self,
        safety_config: SafetyConfig,
    ) -> None:
        """With only a config, initialize() leaves every subsystem populated."""
        audit_target = "app.services.safety.guardian.get_audit_logger"
        with patch(audit_target, new_callable=AsyncMock):
            instance = SafetyGuardian(config=safety_config)
            await instance.initialize()

        assert instance.cost_controller is not None
        assert instance.rate_limiter is not None
        assert instance.loop_detector is not None
        assert instance.is_initialized is True

    @pytest.mark.asyncio
    async def test_init_with_provided_subsystems(
        self,
        safety_config: SafetyConfig,
        cost_controller: CostController,
        rate_limiter: RateLimiter,
        loop_detector: LoopDetector,
    ) -> None:
        """Subsystems passed to the constructor are kept, not replaced."""
        provided = {
            "cost_controller": cost_controller,
            "rate_limiter": rate_limiter,
            "loop_detector": loop_detector,
        }
        instance = SafetyGuardian(config=safety_config, **provided)
        instance._audit_logger = MagicMock()
        await instance.initialize()

        # Identity, not equality: the very same objects must be exposed.
        assert instance.cost_controller is cost_controller
        assert instance.rate_limiter is rate_limiter
        assert instance.loop_detector is loop_detector
|
||||
|
||||
|
||||
class TestSafetyGuardianValidation:
    """Tests for SafetyGuardian.validate()."""

    @pytest.mark.asyncio
    async def test_validate_success(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A default action on a fresh guardian is allowed."""
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_validate_disabled_allows_all(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """With safety disabled, validation allows and says so in its first reason."""
        guardian._config.enabled = False
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert "disabled" in result.reasons[0].lower()

    @pytest.mark.asyncio
    async def test_validate_budget_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied with a budget reason once the session budget is spent."""
        # Use up the session budget
        await guardian.cost_controller.record_usage(
            agent_id=sample_metadata.agent_id,
            session_id=sample_metadata.session_id,
            tokens=1000,
            cost_usd=10.0,
        )

        action = create_action(sample_metadata, estimated_tokens=100)
        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("budget" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_rate_limit_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied or delayed after the "actions" limit is exhausted."""
        # Exhaust the "actions" limit by acquiring slots directly on the rate
        # limiter; no action object is needed for that.
        for _ in range(100):  # More than default limit
            await guardian.rate_limiter.acquire("actions", sample_metadata.agent_id)

        action = create_action(sample_metadata)
        result = await guardian.validate(action)

        # Should be denied or delayed
        assert result.allowed is False
        assert result.decision in (SafetyDecision.DENY, SafetyDecision.DELAY)

    @pytest.mark.asyncio
    async def test_validate_loop_detected(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Validation is denied with a loop reason after repeated identical actions."""
        action = create_action(sample_metadata)

        # Record the same action three times, matching the loop_detector
        # fixture's max_exact_repetitions=3.
        for _ in range(3):
            await guardian.loop_detector.record(action)

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("loop" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_denied_tool(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A tool matching a denied pattern is rejected even if "*" is allowed."""
        # Create action with tool that matches denied pattern
        action = create_action(sample_metadata, tool_name="shell_exec")

        # Create policy with denied pattern
        policy = SafetyPolicy(
            name="test-policy",
            allowed_tools=["*"],
            denied_tools=["shell_*"],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("denied" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_with_custom_policy(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """A tool matching the custom policy's allowed pattern is accepted."""
        action = create_action(sample_metadata, tool_name="allowed_tool")

        policy = SafetyPolicy(
            name="test-custom-policy",
            allowed_tools=["allowed_*"],
            denied_tools=[],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW
|
||||
|
||||
|
||||
class TestSafetyGuardianRecording:
    """Check what SafetyGuardian.record_execution() writes to its subsystems."""

    @pytest.mark.asyncio
    async def test_record_execution_updates_cost(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """The result's actual token cost lands in the session budget."""
        request = create_action(sample_metadata)
        outcome = ActionResult(
            action_id=request.id,
            success=True,
            actual_cost_tokens=50,
            actual_cost_usd=0.005,
        )

        await guardian.record_execution(request, outcome)

        # The session budget reflects the recorded tokens.
        session_status = await guardian.cost_controller.get_status(
            BudgetScope.SESSION, sample_metadata.session_id
        )
        assert session_status is not None
        assert session_status.tokens_used == 50

    @pytest.mark.asyncio
    async def test_record_execution_updates_loop_history(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """One recorded execution yields a loop-detector history of size one."""
        request = create_action(sample_metadata)
        outcome = ActionResult(action_id=request.id, success=True)

        await guardian.record_execution(request, outcome)

        # The loop detector keeps per-agent stats.
        detector_stats = await guardian.loop_detector.get_stats(
            sample_metadata.agent_id
        )
        assert detector_stats["history_size"] == 1
|
||||
|
||||
|
||||
class TestSafetyGuardianSingleton:
    """Tests for SafetyGuardian singleton functions."""

    @pytest.mark.asyncio
    async def test_get_safety_guardian_creates_singleton(
        self,
        reset_guardian,
    ) -> None:
        """Two consecutive get_safety_guardian() calls return the same instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()
            guardian2 = await get_safety_guardian()

            assert guardian1 is guardian2
            assert guardian1.is_initialized is True

    @pytest.mark.asyncio
    async def test_shutdown_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Shutdown clears the singleton so the next get builds a new instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian = await get_safety_guardian()
            assert guardian.is_initialized is True

            await shutdown_safety_guardian()

            # Verify the singleton was cleared: without this check the test
            # passed even if shutdown left the old instance in place.
            replacement = await get_safety_guardian()
            assert replacement is not guardian

    @pytest.mark.asyncio
    async def test_reset_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Reset clears the singleton so the next get builds a new instance."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()

            await reset_safety_guardian()

            guardian2 = await get_safety_guardian()
            assert guardian1 is not guardian2
|
||||
|
||||
|
||||
class TestPatternMatching:
    """Check SafetyGuardian._matches_pattern against literal and wildcard patterns."""

    def test_exact_match(self) -> None:
        """A pattern without wildcards matches only the identical name."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("file_read", "file_read") is True
        assert matches("file_read", "file_write") is False

    def test_wildcard_all(self) -> None:
        """A lone "*" matches any name, including the empty string."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("anything", "*") is True
        assert matches("", "*") is True

    def test_prefix_wildcard(self) -> None:
        """A leading "*" matches names that end with the rest of the pattern."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("test_read", "*_read") is True
        assert matches("test_write", "*_read") is False

    def test_suffix_wildcard(self) -> None:
        """A trailing "*" matches names that start with the rest of the pattern."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("file_read", "file_*") is True
        assert matches("shell_read", "file_*") is False

    def test_contains_wildcard(self) -> None:
        """"*" on both sides matches names that contain the middle text."""
        matches = SafetyGuardian()._matches_pattern
        assert matches("test_dangerous_action", "*dangerous*") is True
        assert matches("test_safe_action", "*dangerous*") is False
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Tests for error handling in SafetyGuardian."""

    @pytest.mark.asyncio
    async def test_strict_mode_fails_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """In strict mode an unexpected error during validation results in DENY."""
        action = create_action(sample_metadata)

        # Force an error by breaking the cost controller. patch.object undoes
        # the replacement on exit, even when an assertion below fails.
        with patch.object(
            guardian.cost_controller,
            "check_budget",
            new=AsyncMock(side_effect=Exception("Unexpected error")),
        ):
            result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("error" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_non_strict_mode_allows_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Outside strict mode an unexpected error during validation results in ALLOW."""
        guardian._config.strict_mode = False
        try:
            action = create_action(sample_metadata)

            # Force an error by breaking the cost controller.
            with patch.object(
                guardian.cost_controller,
                "check_budget",
                new=AsyncMock(side_effect=Exception("Unexpected error")),
            ):
                result = await guardian.validate(action)

            assert result.allowed is True
            assert result.decision == SafetyDecision.ALLOW
        finally:
            # Put strict mode back regardless of the test outcome.
            guardian._config.strict_mode = True
|
||||
405
backend/tests/services/safety/test_limits.py
Normal file
405
backend/tests/services/safety/test_limits.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""Tests for rate limiter module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.exceptions import RateLimitExceededError
|
||||
from app.services.safety.limits.limiter import (
|
||||
RateLimiter,
|
||||
SlidingWindowCounter,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
RateLimitConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def sliding_counter() -> SlidingWindowCounter:
    """Provide a counter: limit 5 per 60-second window, burst limit 3."""
    settings = {"limit": 5, "window_seconds": 60, "burst_limit": 3}
    return SlidingWindowCounter(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Provide a RateLimiter with an extra "test_limit" rule configured."""
    # Same numbers as the sliding_counter fixture: 5 per 60 seconds, burst 3.
    test_rule = RateLimitConfig(
        name="test_limit",
        limit=5,
        window_seconds=60,
        burst_limit=3,
    )
    limiter = RateLimiter()
    limiter.configure(test_rule)
    return limiter
|
||||
|
||||
|
||||
@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Provide ActionMetadata identifying the test agent and session."""
    identity = {"agent_id": "test-agent", "session_id": "test-session"}
    return ActionMetadata(**identity)
|
||||
|
||||
|
||||
def create_action(
    metadata: ActionMetadata,
    action_type: ActionType = ActionType.LLM_CALL,
) -> ActionRequest:
    """Build an ActionRequest of the given type with fixed tool and resource."""
    request_fields = {
        "action_type": action_type,
        "tool_name": "test_tool",
        "resource": "test-resource",
        "arguments": {},
        "metadata": metadata,
    }
    return ActionRequest(**request_fields)
|
||||
|
||||
|
||||
class TestSlidingWindowCounter:
    """Check SlidingWindowCounter acquisition, burst limiting and status."""

    @pytest.mark.asyncio
    async def test_first_acquire_allowed(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """The first acquire on a fresh counter succeeds with no retry delay."""
        outcome = await sliding_counter.try_acquire()
        allowed, retry_after = outcome

        assert allowed is True
        assert retry_after == 0.0

    @pytest.mark.asyncio
    async def test_burst_limit(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Acquires beyond the burst limit are refused with a positive retry delay."""
        burst = 3  # burst_limit of the sliding_counter fixture
        for _attempt in range(burst):
            granted, _ = await sliding_counter.try_acquire()
            assert granted is True

        # One more than the burst limit is refused.
        granted, retry_after = await sliding_counter.try_acquire()
        assert granted is False
        assert retry_after > 0

    @pytest.mark.asyncio
    async def test_get_status(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """get_status reports the used count, the remainder and a reset time."""
        # Two acquires against a limit of five.
        for _attempt in range(2):
            await sliding_counter.try_acquire()

        used, left, reset_in = await sliding_counter.get_status()

        assert used == 2
        assert left == 3  # 5 - 2
        assert reset_in >= 0
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Tests for RateLimiter class."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_status(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """check() on an unused key reports a full, unlimited "test_limit" rule."""
    report = await rate_limiter.check("test_limit", "test-key")

    assert report.name == "test_limit"
    assert report.current_count == 0
    assert report.limit == 5
    assert report.remaining == 5
    assert report.is_limited is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_success(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """A first acquire succeeds and uses one of the five slots."""
    outcome = await rate_limiter.acquire("test_limit", "test-key")
    granted, report = outcome

    assert granted is True
    assert report.current_count == 1
    assert report.remaining == 4
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_burst_exceeded(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """acquire() is refused once the burst limit has been used up."""
    rule, key = "test_limit", "test-key"

    # Fill the burst allowance (3 for the test rule).
    for _attempt in range(3):
        granted, _ = await rate_limiter.acquire(rule, key)
        assert granted is True

    # The acquire after that is refused and reports a retry delay.
    granted, report = await rate_limiter.acquire(rule, key)
    assert granted is False
    assert report.is_limited is True
    assert report.retry_after_seconds > 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_require_success(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """require() returns without raising while the key is not limited."""
    rule, key = "test_limit", "test-key"
    # Passing means no exception was raised.
    await rate_limiter.require(rule, key)
|
||||
|
||||
@pytest.mark.asyncio
async def test_require_raises(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """require() raises RateLimitExceededError once the burst limit is used up."""
    rule, key = "test_limit", "test-key"

    # Fill the burst allowance first.
    for _attempt in range(3):
        await rate_limiter.acquire(rule, key)

    with pytest.raises(RateLimitExceededError) as exc_info:
        await rate_limiter.require(rule, key)

    raised = exc_info.value
    assert raised.limit_type == "test_limit"
    assert raised.retry_after_seconds > 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_action_allowed(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """check_action allows a fresh action and returns at least one status."""
    request = create_action(sample_metadata)

    granted, reports = await rate_limiter.check_action(request)

    assert granted is True
    assert len(reports) >= 1  # At least "actions" limit
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_action_llm_limits(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """Verify that an LLM_CALL action is checked against the LLM limit too."""
    llm_action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)

    is_allowed, checked_statuses = await rate_limiter.check_action(llm_action)

    assert is_allowed is True

    # Both the generic and the LLM-specific limit must appear in the result.
    checked_names = [status.name for status in checked_statuses]
    assert "actions" in checked_names
    assert "llm_calls" in checked_names
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_action_file_limits(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """Verify that a FILE_READ action is checked against the file limit too."""
    file_action = create_action(sample_metadata, action_type=ActionType.FILE_READ)

    is_allowed, checked_statuses = await rate_limiter.check_action(file_action)

    assert is_allowed is True

    # Both the generic and the file-specific limit must appear in the result.
    checked_names = [status.name for status in checked_statuses]
    assert "actions" in checked_names
    assert "file_ops" in checked_names
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_action_does_not_consume_slot(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """Verify that check_action is a pure check and uses up no slots.

    The same action is checked ten times; each check must be allowed and
    the "actions" counter for the agent must still read zero afterwards.
    """
    candidate = create_action(sample_metadata)

    # Repeated checks must all pass, since checking should not consume.
    for _attempt in range(10):
        is_allowed, _statuses = await rate_limiter.check_action(candidate)
        assert is_allowed is True

    # The counter for the agent must be untouched.
    actions_status = await rate_limiter.check("actions", sample_metadata.agent_id)
    assert actions_status.current_count == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_record_action_consumes_slot(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """Verify that record_action uses up one slot of the "actions" limit."""
    recorded = create_action(sample_metadata)

    await rate_limiter.record_action(recorded)

    # Exactly one slot should now be counted for the agent.
    actions_status = await rate_limiter.check("actions", sample_metadata.agent_id)
    assert actions_status.current_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_record_action_consumes_type_specific_slots(
    self,
    rate_limiter: RateLimiter,
    sample_metadata: ActionMetadata,
) -> None:
    """Verify that record_action counts against the matching typed limit.

    An LLM_CALL action must bump "actions" and "llm_calls" only; a
    following FILE_READ action must bump "actions" and "file_ops" only.
    """
    # Step 1: record an LLM call and inspect the counters.
    llm_call = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
    await rate_limiter.record_action(llm_call)

    after_llm = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
    assert after_llm["actions"].current_count == 1
    assert after_llm["llm_calls"].current_count == 1
    assert after_llm["file_ops"].current_count == 0

    # Step 2: record a file read and inspect the counters again.
    file_read = create_action(sample_metadata, action_type=ActionType.FILE_READ)
    await rate_limiter.record_action(file_read)

    after_file = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
    assert after_file["actions"].current_count == 2
    assert after_file["llm_calls"].current_count == 1
    assert after_file["file_ops"].current_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_all_statuses(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that get_all_statuses reports every expected limit for a key."""
    key = "test-key"

    # Generate some usage on two of the limits.
    for limit_name in ("actions", "llm_calls"):
        await rate_limiter.acquire(limit_name, key)

    all_statuses = await rate_limiter.get_all_statuses(key)

    # All three limits must be listed, including the unused one.
    for expected_name in ("actions", "llm_calls", "file_ops"):
        assert expected_name in all_statuses

    # The two limits that were exercised must show usage.
    assert all_statuses["actions"].current_count >= 1
    assert all_statuses["llm_calls"].current_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_single(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that reset clears the counter of one limit for one key."""
    limit_name, key = "test_limit", "test-key"

    # Build up some usage so there is something to reset.
    for _attempt in range(2):
        await rate_limiter.acquire(limit_name, key)

    was_reset = await rate_limiter.reset(limit_name, key)
    assert was_reset is True

    # After the reset the counter must be back at zero.
    status_after = await rate_limiter.check(limit_name, key)
    assert status_after.current_count == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_nonexistent(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that reset reports False for a limit that was never used."""
    was_reset = await rate_limiter.reset("nonexistent", "test-key")

    assert was_reset is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_all(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that reset_all clears every limit tracked for a key."""
    key = "test-key"

    # Touch three different limits for the same key.
    for limit_name in ("actions", "llm_calls", "file_ops"):
        await rate_limiter.acquire(limit_name, key)

    cleared = await rate_limiter.reset_all(key)
    assert cleared >= 3

    # Every reported status must now show an empty counter.
    statuses_after = await rate_limiter.get_all_statuses(key)
    for limit_status in statuses_after.values():
        assert limit_status.current_count == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_per_key_isolation(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that exhausting a limit for one key leaves other keys alone."""
    limit_name = "test_limit"
    exhausted_key, fresh_key = "key-1", "key-2"

    # Burn through the burst allowance of the first key only.
    for _attempt in range(3):
        await rate_limiter.acquire(limit_name, exhausted_key)

    # The first key is now refused.
    exhausted_allowed, _status = await rate_limiter.acquire(limit_name, exhausted_key)
    assert exhausted_allowed is False

    # The second key has its own counter and is still granted.
    fresh_allowed, _status = await rate_limiter.acquire(limit_name, fresh_key)
    assert fresh_allowed is True
|
||||
|
||||
@pytest.mark.asyncio
async def test_configure_custom_limit(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that a limit registered via configure is used by check."""
    custom_config = RateLimitConfig(
        name="custom",
        limit=100,
        window_seconds=120,
        burst_limit=50,
    )
    rate_limiter.configure(custom_config)

    # The status for the new limit must reflect the configured values.
    custom_status = await rate_limiter.check("custom", "test-key")
    assert custom_status.limit == 100
    assert custom_status.window_seconds == 120
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_limit_fallback(
    self,
    rate_limiter: RateLimiter,
) -> None:
    """Verify that an unconfigured limit name falls back to the default."""
    # "unknown_limit" has no explicit configuration.
    fallback_status = await rate_limiter.check("unknown_limit", "test-key")

    # The default is expected to be 60 requests per 60-second window.
    assert fallback_status.limit == 60
    assert fallback_status.window_seconds == 60
|
||||
Reference in New Issue
Block a user