feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking

- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
2026-01-03 17:55:34 +01:00
parent 520c06175e
commit caf283bed2
9 changed files with 1782 additions and 92 deletions

View File

@@ -15,13 +15,20 @@ from .config import (
get_policy_for_autonomy_level,
get_safety_config,
)
from .costs.controller import CostController
from .exceptions import (
BudgetExceededError,
LoopDetectedError,
RateLimitExceededError,
SafetyError,
)
from .limits.limiter import RateLimiter
from .loops.detector import LoopDetector
from .models import (
ActionRequest,
ActionResult,
AuditEventType,
BudgetScope,
GuardianResult,
SafetyDecision,
SafetyPolicy,
@@ -62,6 +69,9 @@ class SafetyGuardian:
self,
config: SafetyConfig | None = None,
audit_logger: AuditLogger | None = None,
cost_controller: CostController | None = None,
rate_limiter: RateLimiter | None = None,
loop_detector: LoopDetector | None = None,
) -> None:
"""
Initialize the SafetyGuardian.
@@ -69,17 +79,22 @@ class SafetyGuardian:
Args:
config: Optional safety configuration. If None, loads from environment.
audit_logger: Optional audit logger. If None, uses global instance.
cost_controller: Optional cost controller. If None, creates default.
rate_limiter: Optional rate limiter. If None, creates default.
loop_detector: Optional loop detector. If None, creates default.
"""
self._config = config or get_safety_config()
self._audit_logger = audit_logger
self._initialized = False
self._lock = asyncio.Lock()
# Subsystem references (will be initialized lazily)
# Core safety subsystems (always initialized)
self._cost_controller: CostController | None = cost_controller
self._rate_limiter: RateLimiter | None = rate_limiter
self._loop_detector: LoopDetector | None = loop_detector
# Optional subsystems (will be initialized when available)
self._permission_manager: Any = None
self._cost_controller: Any = None
self._rate_limiter: Any = None
self._loop_detector: Any = None
self._hitl_manager: Any = None
self._rollback_manager: Any = None
self._content_filter: Any = None
@@ -95,6 +110,21 @@ class SafetyGuardian:
"""Check if the guardian is initialized."""
return self._initialized
@property
def cost_controller(self) -> CostController | None:
"""Get the cost controller instance."""
return self._cost_controller
@property
def rate_limiter(self) -> RateLimiter | None:
"""Get the rate limiter instance."""
return self._rate_limiter
@property
def loop_detector(self) -> LoopDetector | None:
"""Get the loop detector instance."""
return self._loop_detector
async def initialize(self) -> None:
"""Initialize the SafetyGuardian and all subsystems."""
async with self._lock:
@@ -108,11 +138,23 @@ class SafetyGuardian:
if self._audit_logger is None:
self._audit_logger = await get_audit_logger()
# Initialize subsystems lazily as they're implemented
# For now, we'll import and initialize them when available
# Initialize core safety subsystems
if self._cost_controller is None:
self._cost_controller = CostController()
logger.debug("Initialized CostController")
if self._rate_limiter is None:
self._rate_limiter = RateLimiter()
logger.debug("Initialized RateLimiter")
if self._loop_detector is None:
self._loop_detector = LoopDetector()
logger.debug("Initialized LoopDetector")
self._initialized = True
logger.info("SafetyGuardian initialized")
logger.info(
"SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
)
async def shutdown(self) -> None:
"""Shutdown the SafetyGuardian and all subsystems."""
@@ -309,13 +351,40 @@ class SafetyGuardian:
# Update cost tracking
if self._cost_controller:
# Track actual cost
pass
try:
# Use explicit None check - 0 is a valid cost value
tokens = (
result.actual_cost_tokens
if result.actual_cost_tokens is not None
else action.estimated_cost_tokens
)
cost_usd = (
result.actual_cost_usd
if result.actual_cost_usd is not None
else action.estimated_cost_usd
)
await self._cost_controller.record_usage(
agent_id=action.metadata.agent_id,
session_id=action.metadata.session_id,
tokens=tokens,
cost_usd=cost_usd,
)
except Exception as e:
logger.warning("Failed to record cost: %s", e)
# Update rate limiter - consume slots for executed actions
if self._rate_limiter:
try:
await self._rate_limiter.record_action(action)
except Exception as e:
logger.warning("Failed to record action in rate limiter: %s", e)
# Update loop detection history
if self._loop_detector:
# Add to action history
pass
try:
await self._loop_detector.record(action)
except Exception as e:
logger.warning("Failed to record action in loop detector: %s", e)
async def rollback(self, checkpoint_id: str) -> bool:
"""
@@ -442,14 +511,80 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within budget."""
# TODO: Implement with CostController
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed (not fully implemented)"],
)
if self._cost_controller is None:
logger.warning("CostController not initialized - skipping budget check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check skipped (controller not initialized)"],
)
agent_id = action.metadata.agent_id
session_id = action.metadata.session_id
try:
# Check if we have budget for this action
has_budget = await self._cost_controller.check_budget(
agent_id=agent_id,
session_id=session_id,
estimated_tokens=action.estimated_cost_tokens,
estimated_cost_usd=action.estimated_cost_usd,
)
if not has_budget:
# Get current status for better error message
if session_id:
session_status = await self._cost_controller.get_status(
BudgetScope.SESSION, session_id
)
if session_status and session_status.is_exceeded:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Session budget exceeded: {session_status.tokens_used}"
f"/{session_status.tokens_limit} tokens"
],
)
agent_status = await self._cost_controller.get_status(
BudgetScope.DAILY, agent_id
)
if agent_status and agent_status.is_exceeded:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Daily budget exceeded: {agent_status.tokens_used}"
f"/{agent_status.tokens_limit} tokens"
],
)
# Generic budget exceeded
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=["Budget exceeded"],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Budget check passed"],
)
except BudgetExceededError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
)
async def _check_rate_limit(
self,
@@ -457,14 +592,78 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check if action is within rate limits."""
# TODO: Implement with RateLimiter
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed (not fully implemented)"],
)
if self._rate_limiter is None:
logger.warning("RateLimiter not initialized - skipping rate limit check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check skipped (limiter not initialized)"],
)
try:
# Check all applicable rate limits for this action
allowed, statuses = await self._rate_limiter.check_action(action)
if not allowed:
# Find the first exceeded limit for the error message
exceeded_status = next(
(s for s in statuses if s.is_limited),
statuses[0] if statuses else None,
)
if exceeded_status:
retry_after = exceeded_status.retry_after_seconds
# Determine if this is a soft limit (delay) or hard limit (deny)
if retry_after > 0 and retry_after <= 5.0:
# Short wait - suggest delay
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DELAY,
reasons=[
f"Rate limit '{exceeded_status.name}' exceeded. "
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
],
retry_after_seconds=retry_after,
)
else:
# Hard deny
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Rate limit '{exceeded_status.name}' exceeded. "
f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
f"Retry after {retry_after:.1f}s"
],
retry_after_seconds=retry_after,
)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=["Rate limit exceeded"],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Rate limit check passed"],
)
except RateLimitExceededError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
retry_after_seconds=e.retry_after_seconds,
)
async def _check_loops(
self,
@@ -472,14 +671,51 @@ class SafetyGuardian:
policy: SafetyPolicy,
) -> GuardianResult:
"""Check for action loops."""
# TODO: Implement with LoopDetector
# For now, return allow
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed (not fully implemented)"],
)
if self._loop_detector is None:
logger.warning("LoopDetector not initialized - skipping loop check")
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check skipped (detector not initialized)"],
)
try:
# Check if this action would create a loop
is_loop, loop_type = await self._loop_detector.check(action)
if is_loop:
# Get suggestions for breaking the loop
from .loops.detector import LoopBreaker
suggestions = await LoopBreaker.suggest_alternatives(
action, loop_type or "unknown"
)
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[
f"Loop detected: {loop_type}",
*suggestions,
],
)
return GuardianResult(
action_id=action.id,
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["Loop check passed"],
)
except LoopDetectedError as e:
return GuardianResult(
action_id=action.id,
allowed=False,
decision=SafetyDecision.DENY,
reasons=[str(e)],
)
async def _check_hitl(
self,
@@ -610,7 +846,19 @@ async def shutdown_safety_guardian() -> None:
_guardian_instance = None
def reset_safety_guardian() -> None:
"""Reset the SafetyGuardian (for testing)."""
async def reset_safety_guardian() -> None:
"""
Reset the SafetyGuardian (for testing).
This is an async function to properly acquire the guardian lock
and avoid race conditions with get_safety_guardian().
"""
global _guardian_instance
_guardian_instance = None
async with _guardian_lock:
if _guardian_instance is not None:
try:
await _guardian_instance.shutdown()
except Exception: # noqa: S110
pass # Ignore errors during test cleanup
_guardian_instance = None