feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking
- Added `record_action` to `RateLimiter` for precise tracking of slot consumption after validation.
- Introduced a deduplication mechanism for warning alerts in `CostController` to prevent alert spam.
- Refactored `CostController`'s session and daily budget alert handling for clarity.
- Added test suites for `CostController` and `SafetyGuardian` to validate the changes.
- Expanded integration tests to cover deduplication, validation, and loop-detection edge cases.
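The alert-deduplication idea described above can be pictured with a minimal sketch: remember which (scope, scope_id, threshold) warnings have already fired and suppress repeats. The `DedupingAlerter` class, its `warn_once` method, and the `BudgetScope` enum below are illustrative stand-ins, not the actual `CostController` API in this repository.

```python
from enum import Enum


class BudgetScope(Enum):
    SESSION = "session"
    DAILY = "daily"


class DedupingAlerter:
    """Emit each (scope, scope_id, threshold) warning at most once."""

    def __init__(self) -> None:
        # Keys of warnings already emitted; checked before emitting again.
        self._sent_alerts: set[tuple[BudgetScope, str, int]] = set()

    def warn_once(self, scope: BudgetScope, scope_id: str, threshold_pct: int) -> bool:
        """Emit a budget warning only the first time its key is seen."""
        key = (scope, scope_id, threshold_pct)
        if key in self._sent_alerts:
            return False  # duplicate warning suppressed
        self._sent_alerts.add(key)
        self._emit(scope, scope_id, threshold_pct)
        return True

    def _emit(self, scope: BudgetScope, scope_id: str, threshold_pct: int) -> None:
        print(f"[cost-alert] {scope.value}:{scope_id} crossed {threshold_pct}% of budget")


if __name__ == "__main__":
    alerter = DedupingAlerter()
    assert alerter.warn_once(BudgetScope.SESSION, "test-session", 80) is True
    assert alerter.warn_once(BudgetScope.SESSION, "test-session", 80) is False  # deduped
```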
backend/tests/services/safety/test_guardian.py | 508 (new file)

@@ -0,0 +1,508 @@
"""Tests for SafetyGuardian integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.config import SafetyConfig
|
||||
from app.services.safety.costs.controller import CostController
|
||||
from app.services.safety.guardian import (
|
||||
SafetyGuardian,
|
||||
get_safety_guardian,
|
||||
reset_safety_guardian,
|
||||
shutdown_safety_guardian,
|
||||
)
|
||||
from app.services.safety.limits.limiter import RateLimiter
|
||||
from app.services.safety.loops.detector import LoopDetector
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def reset_guardian():
    """Reset the singleton guardian before and after each test."""
    await reset_safety_guardian()
    yield
    await reset_safety_guardian()


@pytest.fixture
def safety_config() -> SafetyConfig:
    """Create a test safety configuration."""
    return SafetyConfig(
        enabled=True,
        strict_mode=True,
        hitl_enabled=False,
        auto_checkpoint_destructive=False,
    )


@pytest.fixture
def cost_controller() -> CostController:
    """Create a cost controller for testing."""
    return CostController(
        default_session_tokens=1000,
        default_session_cost_usd=10.0,
        default_daily_tokens=5000,
        default_daily_cost_usd=50.0,
    )


@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Create a rate limiter for testing."""
    return RateLimiter()


@pytest.fixture
def loop_detector() -> LoopDetector:
    """Create a loop detector for testing."""
    return LoopDetector(
        history_size=10,
        max_exact_repetitions=3,
        max_semantic_repetitions=5,
    )


def _make_audit_event() -> AuditEvent:
    """Create a mock audit event."""
    return AuditEvent(
        event_type=AuditEventType.ACTION_REQUESTED,
        agent_id="test-agent",
        action_id="test-action",
    )


@pytest_asyncio.fixture
async def guardian(
    safety_config: SafetyConfig,
    cost_controller: CostController,
    rate_limiter: RateLimiter,
    loop_detector: LoopDetector,
) -> SafetyGuardian:
    """Create a SafetyGuardian for testing."""
    guardian = SafetyGuardian(
        config=safety_config,
        cost_controller=cost_controller,
        rate_limiter=rate_limiter,
        loop_detector=loop_detector,
    )
    # Patch the audit logger to avoid actual logging
    # Return proper AuditEvent objects instead of AsyncMock
    guardian._audit_logger = MagicMock()
    guardian._audit_logger.log = AsyncMock(return_value=_make_audit_event())
    guardian._audit_logger.log_action_request = AsyncMock(
        return_value=_make_audit_event()
    )
    guardian._audit_logger.log_action_executed = AsyncMock(return_value=None)
    await guardian.initialize()
    return guardian


@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Create sample action metadata."""
    return ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
        autonomy_level=AutonomyLevel.MILESTONE,
    )


def create_action(
    metadata: ActionMetadata,
    tool_name: str = "test_tool",
    action_type: ActionType = ActionType.LLM_CALL,
    resource: str = "/tmp/test.txt",  # noqa: S108
    estimated_tokens: int = 100,
    estimated_cost: float = 0.01,
) -> ActionRequest:
    """Helper to create test actions."""
    return ActionRequest(
        action_type=action_type,
        tool_name=tool_name,
        resource=resource,
        arguments={},
        metadata=metadata,
        estimated_cost_tokens=estimated_tokens,
        estimated_cost_usd=estimated_cost,
    )


class TestSafetyGuardianInit:
    """Tests for SafetyGuardian initialization."""

    @pytest.mark.asyncio
    async def test_init_creates_subsystems(
        self,
        safety_config: SafetyConfig,
    ) -> None:
        """Test initialization creates subsystems if not provided."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian = SafetyGuardian(config=safety_config)
            await guardian.initialize()

        assert guardian.cost_controller is not None
        assert guardian.rate_limiter is not None
        assert guardian.loop_detector is not None
        assert guardian.is_initialized is True

    @pytest.mark.asyncio
    async def test_init_with_provided_subsystems(
        self,
        safety_config: SafetyConfig,
        cost_controller: CostController,
        rate_limiter: RateLimiter,
        loop_detector: LoopDetector,
    ) -> None:
        """Test initialization uses provided subsystems."""
        guardian = SafetyGuardian(
            config=safety_config,
            cost_controller=cost_controller,
            rate_limiter=rate_limiter,
            loop_detector=loop_detector,
        )
        guardian._audit_logger = MagicMock()
        await guardian.initialize()

        # Should use the provided instances
        assert guardian.cost_controller is cost_controller
        assert guardian.rate_limiter is rate_limiter
        assert guardian.loop_detector is loop_detector


class TestSafetyGuardianValidation:
    """Tests for SafetyGuardian.validate()."""

    @pytest.mark.asyncio
    async def test_validate_success(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test successful validation passes all checks."""
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_validate_disabled_allows_all(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test validation with disabled safety allows all."""
        guardian._config.enabled = False
        action = create_action(sample_metadata)

        result = await guardian.validate(action)

        assert result.allowed is True
        assert "disabled" in result.reasons[0].lower()

    @pytest.mark.asyncio
    async def test_validate_budget_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test validation fails when budget exceeded."""
        # Use up the session budget
        await guardian.cost_controller.record_usage(
            agent_id=sample_metadata.agent_id,
            session_id=sample_metadata.session_id,
            tokens=1000,
            cost_usd=10.0,
        )

        action = create_action(sample_metadata, estimated_tokens=100)
        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("budget" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_rate_limit_exceeded(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test validation fails when rate limit exceeded."""
        # Exhaust rate limits by acquiring action slots many times
        for _ in range(100):  # More than default limit
            action = create_action(sample_metadata)
            await guardian.rate_limiter.acquire("actions", sample_metadata.agent_id)

        action = create_action(sample_metadata)
        result = await guardian.validate(action)

        # Should be denied or delayed
        assert result.allowed is False
        assert result.decision in (SafetyDecision.DENY, SafetyDecision.DELAY)

    @pytest.mark.asyncio
    async def test_validate_loop_detected(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test validation fails when loop detected."""
        action = create_action(sample_metadata)

        # Record the same action multiple times (to trigger loop)
        for _ in range(3):
            await guardian.loop_detector.record(action)

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("loop" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_denied_tool(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test validation fails for denied tools."""
        # Create action with tool that matches denied pattern
        action = create_action(sample_metadata, tool_name="shell_exec")

        # Create policy with denied pattern
        policy = SafetyPolicy(
            name="test-policy",
            allowed_tools=["*"],
            denied_tools=["shell_*"],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("denied" in r.lower() for r in result.reasons)

    @pytest.mark.asyncio
    async def test_validate_with_custom_policy(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test validation with custom policy."""
        action = create_action(sample_metadata, tool_name="allowed_tool")

        policy = SafetyPolicy(
            name="test-custom-policy",
            allowed_tools=["allowed_*"],
            denied_tools=[],
        )

        result = await guardian.validate(action, policy=policy)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW


class TestSafetyGuardianRecording:
    """Tests for SafetyGuardian.record_execution()."""

    @pytest.mark.asyncio
    async def test_record_execution_updates_cost(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test recording execution updates cost tracker."""
        action = create_action(sample_metadata)
        action_result = ActionResult(
            action_id=action.id,
            success=True,
            actual_cost_tokens=50,
            actual_cost_usd=0.005,
        )

        await guardian.record_execution(action, action_result)

        # Check cost was recorded
        status = await guardian.cost_controller.get_status(
            BudgetScope.SESSION, sample_metadata.session_id
        )
        assert status is not None
        assert status.tokens_used == 50

    @pytest.mark.asyncio
    async def test_record_execution_updates_loop_history(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test recording execution updates loop detector history."""
        action = create_action(sample_metadata)
        action_result = ActionResult(
            action_id=action.id,
            success=True,
        )

        await guardian.record_execution(action, action_result)

        # Check action was recorded in loop detector
        stats = await guardian.loop_detector.get_stats(sample_metadata.agent_id)
        assert stats["history_size"] == 1


class TestSafetyGuardianSingleton:
    """Tests for SafetyGuardian singleton functions."""

    @pytest.mark.asyncio
    async def test_get_safety_guardian_creates_singleton(
        self,
        reset_guardian,
    ) -> None:
        """Test get_safety_guardian creates singleton."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()
            guardian2 = await get_safety_guardian()

            assert guardian1 is guardian2
            assert guardian1.is_initialized is True

    @pytest.mark.asyncio
    async def test_shutdown_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Test shutdown cleans up singleton."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian = await get_safety_guardian()
            assert guardian.is_initialized is True

            await shutdown_safety_guardian()
            # Singleton should be cleared - next get creates new instance

    @pytest.mark.asyncio
    async def test_reset_safety_guardian(
        self,
        reset_guardian,
    ) -> None:
        """Test reset clears singleton."""
        with patch(
            "app.services.safety.guardian.get_audit_logger",
            new_callable=AsyncMock,
        ):
            guardian1 = await get_safety_guardian()

            await reset_safety_guardian()

            guardian2 = await get_safety_guardian()
            assert guardian1 is not guardian2


class TestPatternMatching:
    """Tests for pattern matching logic."""

    def test_exact_match(self) -> None:
        """Test exact pattern matching."""
        guardian = SafetyGuardian()
        assert guardian._matches_pattern("file_read", "file_read") is True
        assert guardian._matches_pattern("file_read", "file_write") is False

    def test_wildcard_all(self) -> None:
        """Test wildcard * matches all."""
        guardian = SafetyGuardian()
        assert guardian._matches_pattern("anything", "*") is True
        assert guardian._matches_pattern("", "*") is True

    def test_prefix_wildcard(self) -> None:
        """Test prefix wildcard matching."""
        guardian = SafetyGuardian()
        assert guardian._matches_pattern("test_read", "*_read") is True
        assert guardian._matches_pattern("test_write", "*_read") is False

    def test_suffix_wildcard(self) -> None:
        """Test suffix wildcard matching."""
        guardian = SafetyGuardian()
        assert guardian._matches_pattern("file_read", "file_*") is True
        assert guardian._matches_pattern("shell_read", "file_*") is False

    def test_contains_wildcard(self) -> None:
        """Test contains wildcard matching."""
        guardian = SafetyGuardian()
        assert guardian._matches_pattern("test_dangerous_action", "*dangerous*") is True
        assert guardian._matches_pattern("test_safe_action", "*dangerous*") is False


class TestErrorHandling:
    """Tests for error handling in SafetyGuardian."""

    @pytest.mark.asyncio
    async def test_strict_mode_fails_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test strict mode denies on unexpected errors."""
        action = create_action(sample_metadata)

        # Force an error by breaking the cost controller
        original_check = guardian.cost_controller.check_budget
        guardian.cost_controller.check_budget = AsyncMock(
            side_effect=Exception("Unexpected error")
        )

        result = await guardian.validate(action)

        assert result.allowed is False
        assert result.decision == SafetyDecision.DENY
        assert any("error" in r.lower() for r in result.reasons)

        # Restore
        guardian.cost_controller.check_budget = original_check

    @pytest.mark.asyncio
    async def test_non_strict_mode_allows_on_error(
        self,
        guardian: SafetyGuardian,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test non-strict mode allows on unexpected errors."""
        guardian._config.strict_mode = False
        action = create_action(sample_metadata)

        # Force an error by breaking the cost controller
        original_check = guardian.cost_controller.check_budget
        guardian.cost_controller.check_budget = AsyncMock(
            side_effect=Exception("Unexpected error")
        )

        result = await guardian.validate(action)

        assert result.allowed is True
        assert result.decision == SafetyDecision.ALLOW

        # Restore
        guardian.cost_controller.check_budget = original_check
        guardian._config.strict_mode = True