feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking

- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
2026-01-03 17:55:34 +01:00
parent 520c06175e
commit caf283bed2
9 changed files with 1782 additions and 92 deletions

View File

@@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None:
_manager_instance = None
def reset_mcp_client() -> None:
"""Reset the global MCP client manager (for testing)."""
async def reset_mcp_client() -> None:
"""
Reset the global MCP client manager (for testing).
This is an async function to properly acquire the manager lock
and avoid race conditions with get_mcp_client().
"""
global _manager_instance
_manager_instance = None
async with _manager_lock:
if _manager_instance is not None:
# Shutdown gracefully before resetting
try:
await _manager_instance.shutdown()
except Exception: # noqa: S110
pass # Ignore errors during test cleanup
_manager_instance = None

View File

@@ -161,7 +161,7 @@ class MCPConnection:
server_name=self.server_name,
url=self.config.url,
cause=e,
)
) from e
else:
# For STDIO and SSE transports, we'll implement later
raise NotImplementedError(
@@ -297,13 +297,13 @@ class MCPConnection:
server_name=self.server_name,
url=f"{self.config.url}{path}",
cause=e,
)
) from e
except Exception as e:
raise MCPConnectionError(
f"Request failed: {e}",
server_name=self.server_name,
cause=e,
)
) from e
class ConnectionPool:
@@ -322,8 +322,19 @@ class ConnectionPool:
"""
self._connections: dict[str, MCPConnection] = {}
self._lock = asyncio.Lock()
self._per_server_locks: dict[str, asyncio.Lock] = {}
self._max_per_server = max_connections_per_server
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
    """Return the lock dedicated to ``server_name``, creating it on first use.

    The lookup and the insert below run without any ``await`` in between,
    so coroutines sharing one event loop cannot interleave here and always
    receive the same lock object for a given server. A new lock is only
    constructed when the server does not have one yet.

    Args:
        server_name: Name of the server whose lock is requested.

    Returns:
        The lock stored in ``self._per_server_locks`` for that server.
    """
    lock = self._per_server_locks.get(server_name)
    if lock is None:
        # Build the lock lazily: dict.setdefault(key, asyncio.Lock()) would
        # evaluate its default eagerly and allocate a discarded lock on
        # every call, including all the calls where one already exists.
        lock = asyncio.Lock()
        self._per_server_locks[server_name] = lock
    return lock
async def get_connection(
self,
server_name: str,
@@ -332,6 +343,9 @@ class ConnectionPool:
"""
Get or create a connection to a server.
Uses per-server locking to avoid blocking all connections
when establishing a new connection.
Args:
server_name: Name of the server
config: Server configuration
@@ -339,17 +353,33 @@ class ConnectionPool:
Returns:
Active connection
"""
async with self._lock:
if server_name not in self._connections:
connection = MCPConnection(server_name, config)
await connection.connect()
self._connections[server_name] = connection
# Quick check without lock - if connection exists and is connected, return it
if server_name in self._connections:
connection = self._connections[server_name]
if connection.is_connected:
return connection
# Reconnect if not connected
if not connection.is_connected:
# Need to create or reconnect - use per-server lock to avoid blocking others
async with self._lock:
server_lock = self._get_server_lock(server_name)
async with server_lock:
# Double-check after acquiring per-server lock
if server_name in self._connections:
connection = self._connections[server_name]
if connection.is_connected:
return connection
# Connection exists but not connected - reconnect
await connection.connect()
return connection
# Create new connection (outside global lock, under per-server lock)
connection = MCPConnection(server_name, config)
await connection.connect()
# Store connection under global lock
async with self._lock:
self._connections[server_name] = connection
return connection
@@ -374,6 +404,9 @@ class ConnectionPool:
if server_name in self._connections:
await self._connections[server_name].disconnect()
del self._connections[server_name]
# Clean up per-server lock
if server_name in self._per_server_locks:
del self._per_server_locks[server_name]
async def close_all(self) -> None:
"""Close all connections in the pool."""
@@ -385,6 +418,7 @@ class ConnectionPool:
logger.warning("Error closing connection: %s", e)
self._connections.clear()
self._per_server_locks.clear()
logger.info("Closed all MCP connections")
async def health_check_all(self) -> dict[str, bool]:
@@ -394,8 +428,12 @@ class ConnectionPool:
Returns:
Dict mapping server names to health status
"""
# Copy connections under lock to prevent modification during iteration
async with self._lock:
connections_snapshot = dict(self._connections)
results = {}
for name, connection in self._connections.items():
for name, connection in connections_snapshot.items():
results[name] = await connection.health_check()
return results

View File

@@ -185,6 +185,9 @@ class CostController:
# Alert handlers
self._alert_handlers: list[Any] = []
# Track which budgets have had warning alerts sent (to avoid spam)
self._warned_budgets: set[str] = set()
async def get_or_create_tracker(
self,
scope: BudgetScope,
@@ -343,32 +346,44 @@ class CostController:
"""
# Update session budget
if session_id:
session_key = f"session:{session_id}"
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
await session_tracker.add_usage(tokens, cost_usd)
# Check for warning
# Check for warning (only alert once per budget to avoid spam)
status = await session_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
status,
)
if session_key not in self._warned_budgets:
self._warned_budgets.add(session_key)
await self._send_alert(
"warning",
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
status,
)
elif not status.is_warning:
# Clear warning flag if usage dropped below threshold (e.g., after reset)
self._warned_budgets.discard(session_key)
# Update agent daily budget
daily_key = f"daily:{agent_id}"
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
await agent_tracker.add_usage(tokens, cost_usd)
# Check for warning
# Check for warning (only alert once per budget to avoid spam)
status = await agent_tracker.get_status()
if status.is_warning and not status.is_exceeded:
await self._send_alert(
"warning",
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
status,
)
if daily_key not in self._warned_budgets:
self._warned_budgets.add(daily_key)
await self._send_alert(
"warning",
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
status,
)
elif not status.is_warning:
# Clear warning flag if usage dropped below threshold (e.g., after reset)
self._warned_budgets.discard(daily_key)
async def get_status(
self,
@@ -388,20 +403,18 @@ class CostController:
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
return await tracker.get_status()
return None
# Get status while holding lock to prevent TOCTOU race
if tracker:
return await tracker.get_status()
return None
async def get_all_statuses(self) -> list[BudgetStatus]:
"""Get status of all tracked budgets."""
statuses = []
async with self._lock:
trackers = list(self._trackers.values())
for tracker in trackers:
statuses.append(await tracker.get_status())
# Get all statuses while holding lock to prevent TOCTOU race
for tracker in self._trackers.values():
statuses.append(await tracker.get_status())
return statuses
async def set_budget(
@@ -453,11 +466,11 @@ class CostController:
key = f"{scope.value}:{scope_id}"
async with self._lock:
tracker = self._trackers.get(key)
if tracker:
await tracker.reset()
return True
return False
# Reset while holding lock to prevent TOCTOU race
if tracker:
await tracker.reset()
return True
return False
def add_alert_handler(self, handler: Any) -> None:
"""Add an alert handler."""

View File

@@ -15,13 +15,20 @@ from .config import (
get_policy_for_autonomy_level,
get_safety_config,
)
from .costs.controller import CostController
from .exceptions import (
BudgetExceededError,
LoopDetectedError,
RateLimitExceededError,
SafetyError,
)
from .limits.limiter import RateLimiter
from .loops.detector import LoopDetector
from .models import (
ActionRequest,
ActionResult,
AuditEventType,
BudgetScope,
GuardianResult,
SafetyDecision,
SafetyPolicy,
@@ -62,6 +69,9 @@ class SafetyGuardian:
self,
config: SafetyConfig | None = None,
audit_logger: AuditLogger | None = None,
cost_controller: CostController | None = None,
rate_limiter: RateLimiter | None = None,
loop_detector: LoopDetector | None = None,
) -> None:
"""
Initialize the SafetyGuardian.
@@ -69,17 +79,22 @@ class SafetyGuardian:
Args:
config: Optional safety configuration. If None, loads from environment.
audit_logger: Optional audit logger. If None, uses global instance.
cost_controller: Optional cost controller. If None, creates default.
rate_limiter: Optional rate limiter. If None, creates default.
loop_detector: Optional loop detector. If None, creates default.
"""
self._config = config or get_safety_config()
self._audit_logger = audit_logger
self._initialized = False
self._lock = asyncio.Lock()
# Subsystem references (will be initialized lazily)
# Core safety subsystems (always initialized)
self._cost_controller: CostController | None = cost_controller
self._rate_limiter: RateLimiter | None = rate_limiter
self._loop_detector: LoopDetector | None = loop_detector
# Optional subsystems (will be initialized when available)
self._permission_manager: Any = None
self._cost_controller: Any = None
self._rate_limiter: Any = None
self._loop_detector: Any = None
self._hitl_manager: Any = None
self._rollback_manager: Any = None
self._content_filter: Any = None
@@ -95,6 +110,21 @@ class SafetyGuardian:
"""Check if the guardian is initialized."""
return self._initialized
@property
def cost_controller(self) -> CostController | None:
    """Cost controller held by this guardian, or ``None`` if none is set."""
    controller = self._cost_controller
    return controller
@property
def rate_limiter(self) -> RateLimiter | None:
    """Rate limiter held by this guardian, or ``None`` if none is set."""
    limiter = self._rate_limiter
    return limiter
@property
def loop_detector(self) -> LoopDetector | None:
    """Loop detector held by this guardian, or ``None`` if none is set."""
    detector = self._loop_detector
    return detector
async def initialize(self) -> None:
"""Initialize the SafetyGuardian and all subsystems."""
async with self._lock:
@@ -108,11 +138,23 @@ class SafetyGuardian:
if self._audit_logger is None:
self._audit_logger = await get_audit_logger()
# Initialize subsystems lazily as they're implemented
# For now, we'll import and initialize them when available
# Initialize core safety subsystems
if self._cost_controller is None:
self._cost_controller = CostController()
logger.debug("Initialized CostController")
if self._rate_limiter is None:
self._rate_limiter = RateLimiter()
logger.debug("Initialized RateLimiter")
if self._loop_detector is None:
self._loop_detector = LoopDetector()
logger.debug("Initialized LoopDetector")
self._initialized = True
logger.info("SafetyGuardian initialized")
logger.info(
"SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
)
async def shutdown(self) -> None:
"""Shutdown the SafetyGuardian and all subsystems."""
@@ -309,13 +351,40 @@ class SafetyGuardian:
# Update cost tracking
if self._cost_controller:
# Track actual cost
pass
try:
# Use explicit None check - 0 is a valid cost value
tokens = (
result.actual_cost_tokens
if result.actual_cost_tokens is not None
else action.estimated_cost_tokens
)
cost_usd = (
result.actual_cost_usd
if result.actual_cost_usd is not None
else action.estimated_cost_usd
)
await self._cost_controller.record_usage(
agent_id=action.metadata.agent_id,
session_id=action.metadata.session_id,
tokens=tokens,
cost_usd=cost_usd,
)
except Exception as e:
logger.warning("Failed to record cost: %s", e)
# Update rate limiter - consume slots for executed actions
if self._rate_limiter:
try:
await self._rate_limiter.record_action(action)
except Exception as e:
logger.warning("Failed to record action in rate limiter: %s", e)
# Update loop detection history
if self._loop_detector:
# Add to action history
pass
try:
await self._loop_detector.record(action)
except Exception as e:
logger.warning("Failed to record action in loop detector: %s", e)
async def rollback(self, checkpoint_id: str) -> bool:
"""
@@ -442,14 +511,80 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within budget."""
# TODO: Implement with CostController
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed (not fully implemented)"],
)
if self._cost_controller is None:
logger.warning("CostController not initialized - skipping budget check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check skipped (controller not initialized)"],
)
agent_id = action.metadata.agent_id
session_id = action.metadata.session_id
try:
# Check if we have budget for this action
has_budget = await self._cost_controller.check_budget(
agent_id=agent_id,
session_id=session_id,
estimated_tokens=action.estimated_cost_tokens,
estimated_cost_usd=action.estimated_cost_usd,
)
if not has_budget:
# Get current status for better error message
if session_id:
session_status = await self._cost_controller.get_status(
BudgetScope.SESSION, session_id
)
if session_status and session_status.is_exceeded:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Session budget exceeded: {session_status.tokens_used}"
f"/{session_status.tokens_limit} tokens"
],
)
agent_status = await self._cost_controller.get_status(
BudgetScope.DAILY, agent_id
)
if agent_status and agent_status.is_exceeded:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Daily budget exceeded: {agent_status.tokens_used}"
f"/{agent_status.tokens_limit} tokens"
],
)
# Generic budget exceeded
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=["Budget exceeded"],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed"],
)
except BudgetExceededError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
)
async def _check_rate_limit(
self,
@@ -457,14 +592,78 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within rate limits."""
# TODO: Implement with RateLimiter
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed (not fully implemented)"],
)
if self._rate_limiter is None:
logger.warning("RateLimiter not initialized - skipping rate limit check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check skipped (limiter not initialized)"],
)
try:
# Check all applicable rate limits for this action
allowed, statuses = await self._rate_limiter.check_action(action)
if not allowed:
# Find the first exceeded limit for the error message
exceeded_status = next(
(s for s in statuses if s.is_limited),
statuses[0] if statuses else None,
)
if exceeded_status:
retry_after = exceeded_status.retry_after_seconds
# Determine if this is a soft limit (delay) or hard limit (deny)
if retry_after > 0 and retry_after <= 5.0:
# Short wait - suggest delay
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DELAY,
reasons=[
f"Rate limit '{exceeded_status.name}' exceeded. "
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
],
retry_after_seconds=retry_after,
)
else:
# Hard deny
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Rate limit '{exceeded_status.name}' exceeded. "
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
f"Retry after {retry_after:.1f}s"
],
retry_after_seconds=retry_after,
)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=["Rate limit exceeded"],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed"],
)
except RateLimitExceededError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
retry_after_seconds=e.retry_after_seconds,
)
async def _check_loops(
self,
@@ -472,14 +671,51 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check for action loops."""
# TODO: Implement with LoopDetector
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed (not fully implemented)"],
)
if self._loop_detector is None:
logger.warning("LoopDetector not initialized - skipping loop check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check skipped (detector not initialized)"],
)
try:
# Check if this action would create a loop
is_loop, loop_type = await self._loop_detector.check(action)
if is_loop:
# Get suggestions for breaking the loop
from .loops.detector import LoopBreaker
suggestions = await LoopBreaker.suggest_alternatives(
action, loop_type or "unknown"
)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Loop detected: {loop_type}",
*suggestions,
],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed"],
)
except LoopDetectedError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
)
async def _check_hitl(
self,
@@ -610,7 +846,19 @@ async def shutdown_safety_guardian() -> None:
_guardian_instance = None
def reset_safety_guardian() -> None:
"""Reset the SafetyGuardian (for testing)."""
async def reset_safety_guardian() -> None:
"""
Reset the SafetyGuardian (for testing).
This is an async function to properly acquire the guardian lock
and avoid race conditions with get_safety_guardian().
"""
global _guardian_instance
_guardian_instance = None
async with _guardian_lock:
if _guardian_instance is not None:
try:
await _guardian_instance.shutdown()
except Exception: # noqa: S110
pass # Ignore errors during test cleanup
_guardian_instance = None

View File

@@ -223,7 +223,10 @@ class RateLimiter:
action: ActionRequest,
) -> tuple[bool, list[RateLimitStatus]]:
"""
Check all applicable rate limits for an action.
Check all applicable rate limits for an action WITHOUT consuming slots.
Use this during validation to check if action would be allowed.
Call record_action() after successful execution to consume slots.
Args:
action: The action to check
@@ -235,28 +238,53 @@ class RateLimiter:
statuses: list[RateLimitStatus] = []
allowed = True
# Check general actions limit
actions_allowed, actions_status = await self.acquire("actions", agent_id)
# Check general actions limit (read-only)
actions_status = await self.check("actions", agent_id)
statuses.append(actions_status)
if not actions_allowed:
if actions_status.is_limited:
allowed = False
# Check LLM-specific limit for LLM calls
if action.action_type.value == "llm_call":
llm_allowed, llm_status = await self.acquire("llm_calls", agent_id)
llm_status = await self.check("llm_calls", agent_id)
statuses.append(llm_status)
if not llm_allowed:
if llm_status.is_limited:
allowed = False
# Check file ops limit for file operations
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
file_allowed, file_status = await self.acquire("file_ops", agent_id)
file_status = await self.check("file_ops", agent_id)
statuses.append(file_status)
if not file_allowed:
if file_status.is_limited:
allowed = False
return allowed, statuses
async def record_action(
    self,
    action: ActionRequest,
) -> None:
    """
    Consume the rate limit slots that apply to an executed action.

    Intended to be called once the action has actually run, so that the
    slots are counted at execution time.

    Args:
        action: The executed action
    """
    agent_id = action.metadata.agent_id

    # Every action counts against the general limit, whatever its type.
    await self.acquire("actions", agent_id)

    # Type-specific limits, consumed in this order when they apply.
    kind = action.action_type.value
    follow_up_limits = (
        ("llm_calls", kind == "llm_call"),
        ("file_ops", kind in {"file_read", "file_write", "file_delete"}),
    )
    for limit_name, applies in follow_up_limits:
        if applies:
            await self.acquire(limit_name, agent_id)
async def require(
self,
limit_name: str,