forked from cardosofelipe/fast-next-template
test(safety): add Phase E comprehensive safety tests
- Add tests for models: ActionMetadata, ActionRequest, ActionResult, ValidationRule, BudgetStatus, RateLimitConfig, ApprovalRequest/Response, Checkpoint, RollbackResult, AuditEvent, SafetyPolicy, GuardianResult - Add tests for validation: ActionValidator rules, priorities, patterns, bypass mode, batch validation, rule creation helpers - Add tests for loops: LoopDetector exact/semantic/oscillation detection, LoopBreaker throttle/backoff, history management - Add tests for content filter: PII filtering (email, phone, SSN, credit card), secret blocking (API keys, GitHub tokens, private keys), custom patterns, scan without filtering, dict filtering - Add tests for emergency controls: state management, pause/resume/reset, scoped emergency stops, callbacks, EmergencyTrigger events - Fix exception kwargs in content filter and emergency controls to match exception class signatures All 108 tests passing with lint and type checks clean. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -370,8 +370,8 @@ class ContentFilter:
|
||||
if raise_on_block:
|
||||
raise ContentFilterError(
|
||||
block_reason or "Content blocked",
|
||||
detected_category=all_matches[0].category.value if all_matches else "unknown",
|
||||
pattern_name=all_matches[0].pattern_name if all_matches else None,
|
||||
filter_type=all_matches[0].category.value if all_matches else "unknown",
|
||||
detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
|
||||
)
|
||||
elif all_matches:
|
||||
logger.debug(
|
||||
|
||||
@@ -328,7 +328,7 @@ class EmergencyControls:
|
||||
if raise_if_blocked:
|
||||
raise EmergencyStopError(
|
||||
f"Global emergency state: {self._global_state.value}",
|
||||
stop_reason=self._get_last_reason("global"),
|
||||
stop_type=self._get_last_reason("global") or "emergency",
|
||||
triggered_by=self._get_last_triggered_by("global"),
|
||||
)
|
||||
return False
|
||||
@@ -340,9 +340,9 @@ class EmergencyControls:
|
||||
if raise_if_blocked:
|
||||
raise EmergencyStopError(
|
||||
f"Emergency state for {scope}: {state.value}",
|
||||
stop_reason=self._get_last_reason(scope),
|
||||
stop_type=self._get_last_reason(scope) or "emergency",
|
||||
triggered_by=self._get_last_triggered_by(scope),
|
||||
scope=scope,
|
||||
details={"scope": scope},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
0
backend/tests/services/safety/__init__.py
Normal file
0
backend/tests/services/safety/__init__.py
Normal file
345
backend/tests/services/safety/test_content_filter.py
Normal file
345
backend/tests/services/safety/test_content_filter.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for content filtering module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.content.filter import (
|
||||
ContentCategory,
|
||||
ContentFilter,
|
||||
FilterAction,
|
||||
FilterPattern,
|
||||
filter_content,
|
||||
scan_for_secrets,
|
||||
)
|
||||
from app.services.safety.exceptions import ContentFilterError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def filter_all() -> ContentFilter:
|
||||
"""Create a ContentFilter with all filters enabled."""
|
||||
return ContentFilter(
|
||||
enable_pii_filter=True,
|
||||
enable_secret_filter=True,
|
||||
enable_injection_filter=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def filter_pii_only() -> ContentFilter:
|
||||
"""Create a ContentFilter with only PII filter."""
|
||||
return ContentFilter(
|
||||
enable_pii_filter=True,
|
||||
enable_secret_filter=False,
|
||||
enable_injection_filter=False,
|
||||
)
|
||||
|
||||
|
||||
class TestContentFilter:
|
||||
"""Tests for ContentFilter class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_email(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test filtering email addresses."""
|
||||
content = "Contact me at john.doe@example.com for details."
|
||||
result = await filter_pii_only.filter(content)
|
||||
|
||||
assert result.has_sensitive_content
|
||||
assert "[EMAIL]" in result.filtered_content
|
||||
assert "john.doe@example.com" not in result.filtered_content
|
||||
assert len(result.matches) == 1
|
||||
assert result.matches[0].category == ContentCategory.PII
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_phone_number(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test filtering phone numbers."""
|
||||
content = "Call me at 555-123-4567 or (555) 987-6543."
|
||||
result = await filter_pii_only.filter(content)
|
||||
|
||||
assert result.has_sensitive_content
|
||||
assert "[PHONE]" in result.filtered_content
|
||||
# Should redact both phone numbers
|
||||
assert "555-123-4567" not in result.filtered_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_ssn(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test filtering Social Security Numbers."""
|
||||
content = "SSN: 123-45-6789"
|
||||
result = await filter_pii_only.filter(content)
|
||||
|
||||
assert result.has_sensitive_content
|
||||
assert "[SSN]" in result.filtered_content
|
||||
assert "123-45-6789" not in result.filtered_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_credit_card(self, filter_all: ContentFilter) -> None:
|
||||
"""Test filtering credit card numbers."""
|
||||
content = "Card: 4111-2222-3333-4444"
|
||||
result = await filter_all.filter(content)
|
||||
|
||||
assert result.has_sensitive_content
|
||||
assert "[CREDIT_CARD]" in result.filtered_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_api_key(self, filter_all: ContentFilter) -> None:
|
||||
"""Test blocking API keys."""
|
||||
content = "api_key: sk-abcdef1234567890abcdef1234567890"
|
||||
result = await filter_all.filter(content)
|
||||
|
||||
assert result.blocked
|
||||
assert result.block_reason is not None
|
||||
assert "api_key" in result.block_reason.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_github_token(self, filter_all: ContentFilter) -> None:
|
||||
"""Test blocking GitHub tokens."""
|
||||
content = "token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
result = await filter_all.filter(content)
|
||||
|
||||
assert result.blocked
|
||||
assert len(result.matches) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_private_key(self, filter_all: ContentFilter) -> None:
|
||||
"""Test blocking private keys."""
|
||||
content = """
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEA...
|
||||
-----END RSA PRIVATE KEY-----
|
||||
"""
|
||||
result = await filter_all.filter(content)
|
||||
|
||||
assert result.blocked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_block_password_in_url(self, filter_all: ContentFilter) -> None:
|
||||
"""Test blocking passwords in URLs."""
|
||||
content = "Connect to: postgres://user:secretpassword@localhost/db"
|
||||
result = await filter_all.filter(content)
|
||||
|
||||
assert result.blocked or result.has_sensitive_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_warn_ip_address(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test warning on IP addresses."""
|
||||
content = "Server IP: 192.168.1.100"
|
||||
result = await filter_pii_only.filter(content)
|
||||
|
||||
# IP addresses generate warnings, not blocks
|
||||
assert len(result.warnings) > 0 or result.has_sensitive_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_false_positives_clean_content(
|
||||
self,
|
||||
filter_all: ContentFilter,
|
||||
) -> None:
|
||||
"""Test that clean content passes through."""
|
||||
content = "This is a normal message with no sensitive data."
|
||||
result = await filter_all.filter(content)
|
||||
|
||||
assert not result.blocked
|
||||
assert result.filtered_content == content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raise_on_block(self, filter_all: ContentFilter) -> None:
|
||||
"""Test raising exception on blocked content."""
|
||||
content = "-----BEGIN RSA PRIVATE KEY-----"
|
||||
|
||||
with pytest.raises(ContentFilterError):
|
||||
await filter_all.filter(content, raise_on_block=True)
|
||||
|
||||
|
||||
class TestFilterDict:
|
||||
"""Tests for dictionary filtering."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_dict_values(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test filtering string values in a dictionary."""
|
||||
data = {
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": 30,
|
||||
}
|
||||
result = await filter_pii_only.filter_dict(data)
|
||||
|
||||
assert "[EMAIL]" in result["email"]
|
||||
assert result["age"] == 30 # Non-string unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_dict_recursive(
|
||||
self,
|
||||
filter_pii_only: ContentFilter,
|
||||
) -> None:
|
||||
"""Test recursive dictionary filtering."""
|
||||
data = {
|
||||
"user": {
|
||||
"contact": {
|
||||
"email": "test@example.com",
|
||||
}
|
||||
}
|
||||
}
|
||||
result = await filter_pii_only.filter_dict(data, recursive=True)
|
||||
|
||||
assert "[EMAIL]" in result["user"]["contact"]["email"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_dict_specific_keys(
|
||||
self,
|
||||
filter_pii_only: ContentFilter,
|
||||
) -> None:
|
||||
"""Test filtering specific keys only."""
|
||||
data = {
|
||||
"public_email": "public@example.com",
|
||||
"private_email": "private@example.com",
|
||||
}
|
||||
result = await filter_pii_only.filter_dict(
|
||||
data,
|
||||
keys_to_filter=["private_email"],
|
||||
)
|
||||
|
||||
# Only private_email should be filtered
|
||||
assert "public@example.com" in result["public_email"]
|
||||
assert "[EMAIL]" in result["private_email"]
|
||||
|
||||
|
||||
class TestScan:
|
||||
"""Tests for content scanning."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_without_filtering(
|
||||
self,
|
||||
filter_all: ContentFilter,
|
||||
) -> None:
|
||||
"""Test scanning without modifying content."""
|
||||
content = "Email: test@example.com, SSN: 123-45-6789"
|
||||
matches = await filter_all.scan(content)
|
||||
|
||||
assert len(matches) >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_specific_categories(
|
||||
self,
|
||||
filter_all: ContentFilter,
|
||||
) -> None:
|
||||
"""Test scanning for specific categories only."""
|
||||
content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
|
||||
|
||||
# Scan only for secrets
|
||||
matches = await filter_all.scan(
|
||||
content,
|
||||
categories=[ContentCategory.SECRETS],
|
||||
)
|
||||
|
||||
# Should only find the token, not the email
|
||||
assert all(m.category == ContentCategory.SECRETS for m in matches)
|
||||
|
||||
|
||||
class TestValidateSafe:
|
||||
"""Tests for safe validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_safe_clean(self, filter_all: ContentFilter) -> None:
|
||||
"""Test validation of clean content."""
|
||||
content = "This is safe content."
|
||||
is_safe, issues = await filter_all.validate_safe(content)
|
||||
|
||||
assert is_safe is True
|
||||
assert len(issues) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_safe_with_secrets(
|
||||
self,
|
||||
filter_all: ContentFilter,
|
||||
) -> None:
|
||||
"""Test validation of content with secrets."""
|
||||
content = "-----BEGIN RSA PRIVATE KEY-----"
|
||||
is_safe, issues = await filter_all.validate_safe(content)
|
||||
|
||||
assert is_safe is False
|
||||
assert len(issues) > 0
|
||||
|
||||
|
||||
class TestCustomPatterns:
|
||||
"""Tests for custom pattern support."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_custom_pattern(self) -> None:
|
||||
"""Test adding a custom filter pattern."""
|
||||
content_filter = ContentFilter(
|
||||
enable_pii_filter=False,
|
||||
enable_secret_filter=False,
|
||||
enable_injection_filter=False,
|
||||
)
|
||||
|
||||
# Add custom pattern for internal IDs
|
||||
content_filter.add_pattern(
|
||||
FilterPattern(
|
||||
name="internal_id",
|
||||
category=ContentCategory.CUSTOM,
|
||||
pattern=r"INTERNAL-[A-Z0-9]{8}",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[INTERNAL_ID]",
|
||||
)
|
||||
)
|
||||
|
||||
content = "Reference: INTERNAL-ABC12345"
|
||||
result = await content_filter.filter(content)
|
||||
|
||||
assert "[INTERNAL_ID]" in result.filtered_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_pattern(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test disabling a pattern."""
|
||||
filter_pii_only.enable_pattern("email", enabled=False)
|
||||
|
||||
content = "Email: test@example.com"
|
||||
result = await filter_pii_only.filter(content)
|
||||
|
||||
# Email should not be filtered
|
||||
assert "test@example.com" in result.filtered_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_pattern(self, filter_pii_only: ContentFilter) -> None:
|
||||
"""Test removing a pattern."""
|
||||
removed = filter_pii_only.remove_pattern("email")
|
||||
assert removed is True
|
||||
|
||||
content = "Email: test@example.com"
|
||||
result = await filter_pii_only.filter(content)
|
||||
|
||||
# Email should not be filtered
|
||||
assert "test@example.com" in result.filtered_content
|
||||
|
||||
|
||||
class TestPatternStats:
|
||||
"""Tests for pattern statistics."""
|
||||
|
||||
def test_get_pattern_stats(self, filter_all: ContentFilter) -> None:
|
||||
"""Test getting pattern statistics."""
|
||||
stats = filter_all.get_pattern_stats()
|
||||
|
||||
assert stats["total_patterns"] > 0
|
||||
assert "by_category" in stats
|
||||
assert "by_action" in stats
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Tests for convenience functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_content_function(self) -> None:
|
||||
"""Test the quick filter function."""
|
||||
content = "Email: test@example.com"
|
||||
filtered = await filter_content(content)
|
||||
|
||||
assert "test@example.com" not in filtered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_for_secrets_function(self) -> None:
|
||||
"""Test the quick secret scan function."""
|
||||
content = "Token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
matches = await scan_for_secrets(content)
|
||||
|
||||
assert len(matches) > 0
|
||||
assert matches[0].category in (
|
||||
ContentCategory.SECRETS,
|
||||
ContentCategory.CREDENTIALS,
|
||||
)
|
||||
425
backend/tests/services/safety/test_emergency.py
Normal file
425
backend/tests/services/safety/test_emergency.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""Tests for emergency controls module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.emergency.controls import (
|
||||
EmergencyControls,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
)
|
||||
from app.services.safety.exceptions import EmergencyStopError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controls() -> EmergencyControls:
|
||||
"""Create fresh EmergencyControls."""
|
||||
return EmergencyControls()
|
||||
|
||||
|
||||
class TestEmergencyControls:
|
||||
"""Tests for EmergencyControls class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_state_is_normal(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test that initial state is normal."""
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.NORMAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_stop_changes_state(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test that emergency stop changes state to stopped."""
|
||||
event = await controls.emergency_stop(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Test emergency stop",
|
||||
)
|
||||
|
||||
assert event.state == EmergencyState.STOPPED
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.STOPPED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_changes_state(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test that pause changes state to paused."""
|
||||
event = await controls.pause(
|
||||
reason=EmergencyReason.BUDGET_EXCEEDED,
|
||||
triggered_by="budget_controller",
|
||||
message="Budget exceeded",
|
||||
)
|
||||
|
||||
assert event.state == EmergencyState.PAUSED
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.PAUSED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_from_paused(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test resuming from paused state."""
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.RATE_LIMIT,
|
||||
triggered_by="limiter",
|
||||
message="Rate limited",
|
||||
)
|
||||
|
||||
resumed = await controls.resume(resumed_by="admin")
|
||||
|
||||
assert resumed is True
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.NORMAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_resume_from_stopped(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test that you cannot resume from stopped state."""
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.SAFETY_VIOLATION,
|
||||
triggered_by="safety",
|
||||
message="Critical violation",
|
||||
)
|
||||
|
||||
resumed = await controls.resume(resumed_by="admin")
|
||||
|
||||
assert resumed is False
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.STOPPED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_from_stopped(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test resetting from stopped state."""
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.SAFETY_VIOLATION,
|
||||
triggered_by="safety",
|
||||
message="Critical violation",
|
||||
)
|
||||
|
||||
reset = await controls.reset(reset_by="admin")
|
||||
|
||||
assert reset is True
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.NORMAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoped_emergency_stop(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test emergency stop with specific scope."""
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.LOOP_DETECTED,
|
||||
triggered_by="detector",
|
||||
message="Loop in agent",
|
||||
scope="agent:agent-123",
|
||||
)
|
||||
|
||||
# Agent scope should be stopped
|
||||
agent_state = await controls.get_state("agent:agent-123")
|
||||
assert agent_state == EmergencyState.STOPPED
|
||||
|
||||
# Global should still be normal
|
||||
global_state = await controls.get_state("global")
|
||||
assert global_state == EmergencyState.NORMAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_allowed_when_normal(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test check_allowed returns True when state is normal."""
|
||||
allowed = await controls.check_allowed()
|
||||
assert allowed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_allowed_when_stopped(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test check_allowed returns False when stopped."""
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Stop",
|
||||
)
|
||||
|
||||
allowed = await controls.check_allowed(raise_if_blocked=False)
|
||||
assert allowed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_allowed_raises_when_blocked(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test check_allowed raises exception when blocked."""
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Stop",
|
||||
)
|
||||
|
||||
with pytest.raises(EmergencyStopError):
|
||||
await controls.check_allowed(raise_if_blocked=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_allowed_with_scope(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test check_allowed with specific scope."""
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.BUDGET_EXCEEDED,
|
||||
triggered_by="budget",
|
||||
message="Paused",
|
||||
scope="project:proj-123",
|
||||
)
|
||||
|
||||
# Project scope should be blocked
|
||||
allowed_project = await controls.check_allowed(
|
||||
scope="project:proj-123",
|
||||
raise_if_blocked=False,
|
||||
)
|
||||
assert allowed_project is False
|
||||
|
||||
# Different scope should be allowed
|
||||
allowed_other = await controls.check_allowed(
|
||||
scope="project:proj-456",
|
||||
raise_if_blocked=False,
|
||||
)
|
||||
assert allowed_other is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_states(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test getting all states."""
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Pause",
|
||||
scope="agent:a1",
|
||||
)
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Stop",
|
||||
scope="agent:a2",
|
||||
)
|
||||
|
||||
states = await controls.get_all_states()
|
||||
|
||||
assert states["global"] == EmergencyState.NORMAL
|
||||
assert states["agent:a1"] == EmergencyState.PAUSED
|
||||
assert states["agent:a2"] == EmergencyState.STOPPED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_events(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test getting active (unresolved) events."""
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Pause 1",
|
||||
)
|
||||
|
||||
events = await controls.get_active_events()
|
||||
assert len(events) == 1
|
||||
|
||||
# Resume should resolve the event
|
||||
await controls.resume()
|
||||
events_after = await controls.get_active_events()
|
||||
assert len(events_after) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_history(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test getting event history."""
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.RATE_LIMIT,
|
||||
triggered_by="test",
|
||||
message="Rate limited",
|
||||
)
|
||||
await controls.resume()
|
||||
|
||||
history = await controls.get_event_history()
|
||||
|
||||
assert len(history) == 1
|
||||
assert history[0].reason == EmergencyReason.RATE_LIMIT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_metadata(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test event metadata storage."""
|
||||
event = await controls.emergency_stop(
|
||||
reason=EmergencyReason.BUDGET_EXCEEDED,
|
||||
triggered_by="budget_controller",
|
||||
message="Over budget",
|
||||
metadata={"budget_type": "tokens", "usage": 150000},
|
||||
)
|
||||
|
||||
assert event.metadata["budget_type"] == "tokens"
|
||||
assert event.metadata["usage"] == 150000
|
||||
|
||||
|
||||
class TestCallbacks:
|
||||
"""Tests for callback functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_stop_callback(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test on_stop callback is called."""
|
||||
callback_called = []
|
||||
|
||||
def callback(event: object) -> None:
|
||||
callback_called.append(event)
|
||||
|
||||
controls.on_stop(callback)
|
||||
|
||||
await controls.emergency_stop(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Stop",
|
||||
)
|
||||
|
||||
assert len(callback_called) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_pause_callback(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test on_pause callback is called."""
|
||||
callback_called = []
|
||||
|
||||
def callback(event: object) -> None:
|
||||
callback_called.append(event)
|
||||
|
||||
controls.on_pause(callback)
|
||||
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Pause",
|
||||
)
|
||||
|
||||
assert len(callback_called) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_resume_callback(
|
||||
self,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test on_resume callback is called."""
|
||||
callback_called = []
|
||||
|
||||
def callback(data: object) -> None:
|
||||
callback_called.append(data)
|
||||
|
||||
controls.on_resume(callback)
|
||||
|
||||
await controls.pause(
|
||||
reason=EmergencyReason.MANUAL,
|
||||
triggered_by="test",
|
||||
message="Pause",
|
||||
)
|
||||
await controls.resume()
|
||||
|
||||
assert len(callback_called) == 1
|
||||
|
||||
|
||||
class TestEmergencyTrigger:
|
||||
"""Tests for EmergencyTrigger class."""
|
||||
|
||||
@pytest.fixture
|
||||
def trigger(self, controls: EmergencyControls) -> EmergencyTrigger:
|
||||
"""Create an EmergencyTrigger."""
|
||||
return EmergencyTrigger(controls)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_on_safety_violation(
|
||||
self,
|
||||
trigger: EmergencyTrigger,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test triggering emergency on safety violation."""
|
||||
event = await trigger.trigger_on_safety_violation(
|
||||
violation_type="unauthorized_access",
|
||||
details={"resource": "/secrets/key"},
|
||||
)
|
||||
|
||||
assert event.reason == EmergencyReason.SAFETY_VIOLATION
|
||||
assert event.state == EmergencyState.STOPPED
|
||||
|
||||
state = await controls.get_state("global")
|
||||
assert state == EmergencyState.STOPPED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_on_budget_exceeded(
|
||||
self,
|
||||
trigger: EmergencyTrigger,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test triggering pause on budget exceeded."""
|
||||
event = await trigger.trigger_on_budget_exceeded(
|
||||
budget_type="tokens",
|
||||
current=150000,
|
||||
limit=100000,
|
||||
)
|
||||
|
||||
assert event.reason == EmergencyReason.BUDGET_EXCEEDED
|
||||
assert event.state == EmergencyState.PAUSED # Pause, not stop
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_on_loop_detected(
|
||||
self,
|
||||
trigger: EmergencyTrigger,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test triggering pause on loop detection."""
|
||||
event = await trigger.trigger_on_loop_detected(
|
||||
loop_type="exact",
|
||||
agent_id="agent-123",
|
||||
details={"pattern": "file_read"},
|
||||
)
|
||||
|
||||
assert event.reason == EmergencyReason.LOOP_DETECTED
|
||||
assert event.scope == "agent:agent-123"
|
||||
assert event.state == EmergencyState.PAUSED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_on_content_violation(
|
||||
self,
|
||||
trigger: EmergencyTrigger,
|
||||
controls: EmergencyControls,
|
||||
) -> None:
|
||||
"""Test triggering stop on content violation."""
|
||||
event = await trigger.trigger_on_content_violation(
|
||||
category="secrets",
|
||||
pattern="private_key",
|
||||
)
|
||||
|
||||
assert event.reason == EmergencyReason.CONTENT_VIOLATION
|
||||
assert event.state == EmergencyState.STOPPED
|
||||
316
backend/tests/services/safety/test_loops.py
Normal file
316
backend/tests/services/safety/test_loops.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Tests for loop detection module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.exceptions import LoopDetectedError
|
||||
from app.services.safety.loops.detector import (
|
||||
ActionSignature,
|
||||
LoopBreaker,
|
||||
LoopDetector,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def detector() -> LoopDetector:
|
||||
"""Create a fresh LoopDetector with low thresholds for testing."""
|
||||
return LoopDetector(
|
||||
history_size=20,
|
||||
max_exact_repetitions=3,
|
||||
max_semantic_repetitions=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_metadata() -> ActionMetadata:
|
||||
"""Create sample action metadata."""
|
||||
return ActionMetadata(
|
||||
agent_id="test-agent",
|
||||
session_id="test-session",
|
||||
autonomy_level=AutonomyLevel.MILESTONE,
|
||||
)
|
||||
|
||||
|
||||
def create_action(
|
||||
metadata: ActionMetadata,
|
||||
tool_name: str,
|
||||
resource: str = "/tmp/test.txt", # noqa: S108
|
||||
arguments: dict | None = None,
|
||||
) -> ActionRequest:
|
||||
"""Helper to create test actions."""
|
||||
return ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name=tool_name,
|
||||
resource=resource,
|
||||
arguments=arguments or {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class TestActionSignature:
|
||||
"""Tests for ActionSignature class."""
|
||||
|
||||
def test_exact_key_includes_args(self, sample_metadata: ActionMetadata) -> None:
|
||||
"""Test that exact key includes argument hash."""
|
||||
action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"})
|
||||
action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"})
|
||||
|
||||
sig1 = ActionSignature(action1)
|
||||
sig2 = ActionSignature(action2)
|
||||
|
||||
assert sig1.exact_key() != sig2.exact_key()
|
||||
|
||||
def test_semantic_key_ignores_args(self, sample_metadata: ActionMetadata) -> None:
|
||||
"""Test that semantic key ignores arguments."""
|
||||
action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"})
|
||||
action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"})
|
||||
|
||||
sig1 = ActionSignature(action1)
|
||||
sig2 = ActionSignature(action2)
|
||||
|
||||
assert sig1.semantic_key() == sig2.semantic_key()
|
||||
|
||||
def test_type_key(self, sample_metadata: ActionMetadata) -> None:
|
||||
"""Test type key extraction."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
sig = ActionSignature(action)
|
||||
|
||||
assert sig.type_key() == "file_read"
|
||||
|
||||
|
||||
class TestLoopDetector:
|
||||
"""Tests for LoopDetector class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_loop_on_first_action(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test that first action is never a loop."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
|
||||
is_loop, loop_type = await detector.check(action)
|
||||
|
||||
assert is_loop is False
|
||||
assert loop_type is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_loop_detection(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test detection of exact repetitions."""
|
||||
action = create_action(
|
||||
sample_metadata,
|
||||
"file_read",
|
||||
resource="/tmp/same.txt", # noqa: S108
|
||||
arguments={"path": "/tmp/same.txt"}, # noqa: S108
|
||||
)
|
||||
|
||||
# Record the same action 3 times (threshold)
|
||||
for _ in range(3):
|
||||
await detector.record(action)
|
||||
|
||||
# Next should be detected as a loop
|
||||
is_loop, loop_type = await detector.check(action)
|
||||
|
||||
assert is_loop is True
|
||||
assert loop_type == "exact"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_loop_detection(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test detection of semantic (similar) repetitions."""
|
||||
# Record same tool/resource but different arguments
|
||||
test_resource = "/tmp/test.txt" # noqa: S108
|
||||
for i in range(5):
|
||||
action = create_action(
|
||||
sample_metadata,
|
||||
"file_read",
|
||||
resource=test_resource,
|
||||
arguments={"offset": i},
|
||||
)
|
||||
await detector.record(action)
|
||||
|
||||
# Next similar action should be detected as semantic loop
|
||||
action = create_action(
|
||||
sample_metadata,
|
||||
"file_read",
|
||||
resource=test_resource,
|
||||
arguments={"offset": 100},
|
||||
)
|
||||
is_loop, loop_type = await detector.check(action)
|
||||
|
||||
assert is_loop is True
|
||||
assert loop_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oscillation_detection(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test detection of A→B→A→B oscillation pattern."""
|
||||
action_a = create_action(sample_metadata, "tool_a", resource="/a")
|
||||
action_b = create_action(sample_metadata, "tool_b", resource="/b")
|
||||
|
||||
# Create A→B→A pattern
|
||||
await detector.record(action_a)
|
||||
await detector.record(action_b)
|
||||
await detector.record(action_a)
|
||||
|
||||
# Fourth action completing A→B→A→B should be detected as oscillation
|
||||
is_loop, loop_type = await detector.check(action_b)
|
||||
|
||||
assert is_loop is True
|
||||
assert loop_type == "oscillation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_actions_no_loop(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test that different actions don't trigger loops."""
|
||||
for i in range(10):
|
||||
action = create_action(
|
||||
sample_metadata,
|
||||
f"tool_{i}",
|
||||
resource=f"/resource_{i}",
|
||||
)
|
||||
is_loop, _ = await detector.check(action)
|
||||
assert is_loop is False
|
||||
await detector.record(action)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_and_raise(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test check_and_raise raises on loop detection."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
|
||||
# Record threshold number of times
|
||||
for _ in range(3):
|
||||
await detector.record(action)
|
||||
|
||||
# Should raise
|
||||
with pytest.raises(LoopDetectedError) as exc_info:
|
||||
await detector.check_and_raise(action)
|
||||
|
||||
assert "exact" in exc_info.value.loop_type.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test clearing agent history."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
|
||||
# Record multiple times
|
||||
for _ in range(3):
|
||||
await detector.record(action)
|
||||
|
||||
# Clear history
|
||||
await detector.clear_history(sample_metadata.agent_id)
|
||||
|
||||
# Should no longer detect loop
|
||||
is_loop, _ = await detector.check(action)
|
||||
assert is_loop is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_agent_history(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
) -> None:
|
||||
"""Test that history is tracked per agent."""
|
||||
metadata1 = ActionMetadata(agent_id="agent-1", session_id="s1")
|
||||
metadata2 = ActionMetadata(agent_id="agent-2", session_id="s2")
|
||||
|
||||
action1 = create_action(metadata1, "file_read")
|
||||
action2 = create_action(metadata2, "file_read")
|
||||
|
||||
# Record for agent 1 (threshold times)
|
||||
for _ in range(3):
|
||||
await detector.record(action1)
|
||||
|
||||
# Agent 1 should detect loop
|
||||
is_loop1, _ = await detector.check(action1)
|
||||
assert is_loop1 is True
|
||||
|
||||
# Agent 2 should not detect loop
|
||||
is_loop2, _ = await detector.check(action2)
|
||||
assert is_loop2 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(
|
||||
self,
|
||||
detector: LoopDetector,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test getting loop detection stats."""
|
||||
for i in range(5):
|
||||
action = create_action(
|
||||
sample_metadata,
|
||||
f"tool_{i % 2}", # Alternate between 2 tools
|
||||
resource=f"/resource_{i}",
|
||||
)
|
||||
await detector.record(action)
|
||||
|
||||
stats = await detector.get_stats(sample_metadata.agent_id)
|
||||
|
||||
assert stats["history_size"] == 5
|
||||
assert len(stats["action_type_counts"]) > 0
|
||||
|
||||
|
||||
class TestLoopBreaker:
|
||||
"""Tests for LoopBreaker class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggest_alternatives_exact(
|
||||
self,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test suggestions for exact loops."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
suggestions = await LoopBreaker.suggest_alternatives(action, "exact")
|
||||
|
||||
assert len(suggestions) > 0
|
||||
assert "same action" in suggestions[0].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggest_alternatives_semantic(
|
||||
self,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test suggestions for semantic loops."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
suggestions = await LoopBreaker.suggest_alternatives(action, "semantic")
|
||||
|
||||
assert len(suggestions) > 0
|
||||
assert "similar" in suggestions[0].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggest_alternatives_oscillation(
|
||||
self,
|
||||
sample_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test suggestions for oscillation loops."""
|
||||
action = create_action(sample_metadata, "file_read")
|
||||
suggestions = await LoopBreaker.suggest_alternatives(action, "oscillation")
|
||||
|
||||
assert len(suggestions) > 0
|
||||
assert "oscillat" in suggestions[0].lower()
|
||||
437
backend/tests/services/safety/test_models.py
Normal file
437
backend/tests/services/safety/test_models.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""Tests for safety framework models."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
GuardianResult,
|
||||
PermissionLevel,
|
||||
RateLimitConfig,
|
||||
RollbackResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
|
||||
class TestActionMetadata:
|
||||
"""Tests for ActionMetadata model."""
|
||||
|
||||
def test_create_with_defaults(self) -> None:
|
||||
"""Test creating metadata with default values."""
|
||||
metadata = ActionMetadata(
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert metadata.agent_id == "agent-1"
|
||||
assert metadata.autonomy_level == AutonomyLevel.MILESTONE
|
||||
assert metadata.project_id is None
|
||||
assert metadata.session_id is None
|
||||
|
||||
def test_create_with_all_fields(self) -> None:
|
||||
"""Test creating metadata with all fields."""
|
||||
metadata = ActionMetadata(
|
||||
agent_id="agent-1",
|
||||
session_id="session-1",
|
||||
project_id="project-1",
|
||||
user_id="user-1",
|
||||
autonomy_level=AutonomyLevel.AUTONOMOUS,
|
||||
)
|
||||
|
||||
assert metadata.project_id == "project-1"
|
||||
assert metadata.user_id == "user-1"
|
||||
assert metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
|
||||
|
||||
class TestActionRequest:
|
||||
"""Tests for ActionRequest model."""
|
||||
|
||||
def test_create_basic_action(self) -> None:
|
||||
"""Test creating a basic action request."""
|
||||
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert action.action_type == ActionType.FILE_READ
|
||||
assert action.tool_name == "file_read"
|
||||
assert action.id is not None
|
||||
assert action.metadata.agent_id == "agent-1"
|
||||
|
||||
def test_action_with_arguments(self) -> None:
|
||||
"""Test action with arguments."""
|
||||
test_path = "/tmp/test.txt" # noqa: S108
|
||||
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_WRITE,
|
||||
tool_name="file_write",
|
||||
arguments={"path": test_path, "content": "hello"},
|
||||
resource=test_path,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert action.arguments["path"] == test_path
|
||||
assert action.resource == test_path
|
||||
|
||||
|
||||
class TestActionResult:
|
||||
"""Tests for ActionResult model."""
|
||||
|
||||
def test_successful_result(self) -> None:
|
||||
"""Test creating a successful result."""
|
||||
result = ActionResult(
|
||||
action_id="action-1",
|
||||
success=True,
|
||||
data={"output": "done"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["output"] == "done"
|
||||
assert result.error is None
|
||||
|
||||
def test_failed_result(self) -> None:
|
||||
"""Test creating a failed result."""
|
||||
result = ActionResult(
|
||||
action_id="action-1",
|
||||
success=False,
|
||||
error="Permission denied",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Permission denied"
|
||||
|
||||
|
||||
class TestValidationRule:
|
||||
"""Tests for ValidationRule model."""
|
||||
|
||||
def test_create_rule(self) -> None:
|
||||
"""Test creating a validation rule."""
|
||||
rule = ValidationRule(
|
||||
name="deny_shell",
|
||||
description="Deny shell commands",
|
||||
priority=100,
|
||||
tool_patterns=["shell_*"],
|
||||
decision=SafetyDecision.DENY,
|
||||
reason="Shell commands are not allowed",
|
||||
)
|
||||
|
||||
assert rule.name == "deny_shell"
|
||||
assert rule.priority == 100
|
||||
assert rule.decision == SafetyDecision.DENY
|
||||
assert rule.enabled is True
|
||||
|
||||
def test_rule_defaults(self) -> None:
|
||||
"""Test rule default values."""
|
||||
rule = ValidationRule(name="test_rule", decision=SafetyDecision.ALLOW)
|
||||
|
||||
assert rule.id is not None
|
||||
assert rule.priority == 0
|
||||
assert rule.enabled is True
|
||||
assert rule.decision == SafetyDecision.ALLOW
|
||||
|
||||
|
||||
class TestValidationResult:
|
||||
"""Tests for ValidationResult model."""
|
||||
|
||||
def test_allow_result(self) -> None:
|
||||
"""Test an allow result."""
|
||||
result = ValidationResult(
|
||||
action_id="action-1",
|
||||
decision=SafetyDecision.ALLOW,
|
||||
applied_rules=["rule-1"],
|
||||
reasons=["Action is allowed"],
|
||||
)
|
||||
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
assert len(result.applied_rules) == 1
|
||||
|
||||
def test_deny_result(self) -> None:
|
||||
"""Test a deny result."""
|
||||
result = ValidationResult(
|
||||
action_id="action-1",
|
||||
decision=SafetyDecision.DENY,
|
||||
applied_rules=["deny_rule"],
|
||||
reasons=["Action is not permitted"],
|
||||
)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
|
||||
|
||||
class TestBudgetStatus:
|
||||
"""Tests for BudgetStatus model."""
|
||||
|
||||
def test_under_budget(self) -> None:
|
||||
"""Test status when under budget."""
|
||||
status = BudgetStatus(
|
||||
scope=BudgetScope.SESSION,
|
||||
scope_id="session-1",
|
||||
tokens_used=5000,
|
||||
tokens_limit=10000,
|
||||
tokens_remaining=5000,
|
||||
)
|
||||
|
||||
assert status.tokens_remaining == 5000
|
||||
assert status.tokens_used == 5000
|
||||
|
||||
def test_over_budget(self) -> None:
|
||||
"""Test status when over budget."""
|
||||
status = BudgetStatus(
|
||||
scope=BudgetScope.SESSION,
|
||||
scope_id="session-1",
|
||||
cost_used_usd=15.0,
|
||||
cost_limit_usd=10.0,
|
||||
cost_remaining_usd=0.0,
|
||||
is_exceeded=True,
|
||||
)
|
||||
|
||||
assert status.is_exceeded is True
|
||||
assert status.cost_remaining_usd == 0.0
|
||||
|
||||
|
||||
class TestRateLimitConfig:
|
||||
"""Tests for RateLimitConfig model."""
|
||||
|
||||
def test_create_config(self) -> None:
|
||||
"""Test creating rate limit config."""
|
||||
config = RateLimitConfig(
|
||||
name="actions",
|
||||
limit=60,
|
||||
window_seconds=60,
|
||||
)
|
||||
|
||||
assert config.name == "actions"
|
||||
assert config.limit == 60
|
||||
assert config.window_seconds == 60
|
||||
|
||||
|
||||
class TestApprovalRequest:
|
||||
"""Tests for ApprovalRequest model."""
|
||||
|
||||
def test_create_request(self) -> None:
|
||||
"""Test creating an approval request."""
|
||||
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.DATABASE_MUTATE,
|
||||
tool_name="db_delete",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
request = ApprovalRequest(
|
||||
id="approval-1",
|
||||
action=action,
|
||||
reason="Database mutation requires approval",
|
||||
urgency="high",
|
||||
timeout_seconds=300,
|
||||
)
|
||||
|
||||
assert request.id == "approval-1"
|
||||
assert request.urgency == "high"
|
||||
assert request.timeout_seconds == 300
|
||||
|
||||
|
||||
class TestApprovalResponse:
|
||||
"""Tests for ApprovalResponse model."""
|
||||
|
||||
def test_approved_response(self) -> None:
|
||||
"""Test an approved response."""
|
||||
response = ApprovalResponse(
|
||||
request_id="approval-1",
|
||||
status=ApprovalStatus.APPROVED,
|
||||
decided_by="admin",
|
||||
reason="Looks safe",
|
||||
)
|
||||
|
||||
assert response.status == ApprovalStatus.APPROVED
|
||||
assert response.decided_by == "admin"
|
||||
|
||||
def test_denied_response(self) -> None:
|
||||
"""Test a denied response."""
|
||||
response = ApprovalResponse(
|
||||
request_id="approval-1",
|
||||
status=ApprovalStatus.DENIED,
|
||||
decided_by="admin",
|
||||
reason="Too risky",
|
||||
)
|
||||
|
||||
assert response.status == ApprovalStatus.DENIED
|
||||
|
||||
|
||||
class TestCheckpoint:
|
||||
"""Tests for Checkpoint model."""
|
||||
|
||||
def test_create_checkpoint(self) -> None:
|
||||
"""Test creating a checkpoint."""
|
||||
test_path = "/tmp/test.txt" # noqa: S108
|
||||
checkpoint = Checkpoint(
|
||||
id="checkpoint-1",
|
||||
checkpoint_type=CheckpointType.FILE,
|
||||
action_id="action-1",
|
||||
created_at=datetime.utcnow(),
|
||||
data={"path": test_path},
|
||||
description="File checkpoint",
|
||||
)
|
||||
|
||||
assert checkpoint.id == "checkpoint-1"
|
||||
assert checkpoint.is_valid is True
|
||||
|
||||
def test_expired_checkpoint(self) -> None:
|
||||
"""Test an expired checkpoint."""
|
||||
checkpoint = Checkpoint(
|
||||
id="checkpoint-1",
|
||||
checkpoint_type=CheckpointType.FILE,
|
||||
action_id="action-1",
|
||||
created_at=datetime.utcnow() - timedelta(hours=2),
|
||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||
data={},
|
||||
)
|
||||
|
||||
# is_valid is a simple bool, not computed from expires_at
|
||||
# The RollbackManager handles expiration logic
|
||||
assert checkpoint.is_valid is True # Default value
|
||||
|
||||
|
||||
class TestRollbackResult:
|
||||
"""Tests for RollbackResult model."""
|
||||
|
||||
def test_successful_rollback(self) -> None:
|
||||
"""Test a successful rollback."""
|
||||
result = RollbackResult(
|
||||
checkpoint_id="checkpoint-1",
|
||||
success=True,
|
||||
actions_rolled_back=["file:/tmp/test.txt"],
|
||||
failed_actions=[],
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.actions_rolled_back) == 1
|
||||
|
||||
def test_partial_rollback(self) -> None:
|
||||
"""Test a partial rollback."""
|
||||
result = RollbackResult(
|
||||
checkpoint_id="checkpoint-1",
|
||||
success=False,
|
||||
actions_rolled_back=["file:/tmp/a.txt"],
|
||||
failed_actions=["file:/tmp/b.txt"],
|
||||
error="Failed to rollback 1 item",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert len(result.failed_actions) == 1
|
||||
|
||||
|
||||
class TestAuditEvent:
|
||||
"""Tests for AuditEvent model."""
|
||||
|
||||
def test_create_event(self) -> None:
|
||||
"""Test creating an audit event."""
|
||||
event = AuditEvent(
|
||||
id="event-1",
|
||||
event_type=AuditEventType.ACTION_EXECUTED,
|
||||
timestamp=datetime.utcnow(),
|
||||
agent_id="agent-1",
|
||||
action_id="action-1",
|
||||
data={"tool": "file_read"},
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.ACTION_EXECUTED
|
||||
assert event.agent_id == "agent-1"
|
||||
|
||||
|
||||
class TestSafetyPolicy:
|
||||
"""Tests for SafetyPolicy model."""
|
||||
|
||||
def test_default_policy(self) -> None:
|
||||
"""Test creating a default policy."""
|
||||
policy = SafetyPolicy(
|
||||
name="default",
|
||||
description="Default safety policy",
|
||||
)
|
||||
|
||||
assert policy.name == "default"
|
||||
assert policy.enabled is True
|
||||
|
||||
def test_restrictive_policy(self) -> None:
|
||||
"""Test creating a restrictive policy."""
|
||||
policy = SafetyPolicy(
|
||||
name="restrictive",
|
||||
description="Restrictive policy",
|
||||
denied_tools=["shell_*", "exec_*"],
|
||||
require_approval_for=["database_*", "git_push"],
|
||||
)
|
||||
|
||||
assert len(policy.denied_tools) == 2
|
||||
assert len(policy.require_approval_for) == 2
|
||||
|
||||
|
||||
class TestGuardianResult:
|
||||
"""Tests for GuardianResult model."""
|
||||
|
||||
def test_allowed_result(self) -> None:
|
||||
"""Test an allowed result."""
|
||||
result = GuardianResult(
|
||||
action_id="action-1",
|
||||
allowed=True,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
reasons=["All checks passed"],
|
||||
)
|
||||
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
assert result.allowed is True
|
||||
assert result.approval_id is None
|
||||
|
||||
def test_approval_required_result(self) -> None:
|
||||
"""Test a result requiring approval."""
|
||||
result = GuardianResult(
|
||||
action_id="action-1",
|
||||
allowed=False,
|
||||
decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||
reasons=["Action requires human approval"],
|
||||
approval_id="approval-123",
|
||||
)
|
||||
|
||||
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
assert result.approval_id == "approval-123"
|
||||
|
||||
|
||||
class TestEnums:
|
||||
"""Tests for enum values."""
|
||||
|
||||
def test_action_types(self) -> None:
|
||||
"""Test action type enum values."""
|
||||
assert ActionType.FILE_READ.value == "file_read"
|
||||
assert ActionType.SHELL_COMMAND.value == "shell_command"
|
||||
|
||||
def test_autonomy_levels(self) -> None:
|
||||
"""Test autonomy level enum values."""
|
||||
assert AutonomyLevel.FULL_CONTROL.value == "full_control"
|
||||
assert AutonomyLevel.MILESTONE.value == "milestone"
|
||||
assert AutonomyLevel.AUTONOMOUS.value == "autonomous"
|
||||
|
||||
def test_permission_levels(self) -> None:
|
||||
"""Test permission level enum values."""
|
||||
assert PermissionLevel.NONE.value == "none"
|
||||
assert PermissionLevel.READ.value == "read"
|
||||
assert PermissionLevel.WRITE.value == "write"
|
||||
assert PermissionLevel.ADMIN.value == "admin"
|
||||
|
||||
def test_safety_decisions(self) -> None:
|
||||
"""Test safety decision enum values."""
|
||||
assert SafetyDecision.ALLOW.value == "allow"
|
||||
assert SafetyDecision.DENY.value == "deny"
|
||||
assert SafetyDecision.REQUIRE_APPROVAL.value == "require_approval"
|
||||
404
backend/tests/services/safety/test_validation.py
Normal file
404
backend/tests/services/safety/test_validation.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Tests for safety validation module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationRule,
|
||||
)
|
||||
from app.services.safety.validation.validator import (
|
||||
ActionValidator,
|
||||
create_allow_rule,
|
||||
create_approval_rule,
|
||||
create_deny_rule,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validator() -> ActionValidator:
|
||||
"""Create a fresh ActionValidator."""
|
||||
return ActionValidator(cache_enabled=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_action() -> ActionRequest:
|
||||
"""Create a sample action request."""
|
||||
metadata = ActionMetadata(
|
||||
agent_id="test-agent",
|
||||
session_id="test-session",
|
||||
autonomy_level=AutonomyLevel.MILESTONE,
|
||||
)
|
||||
return ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource="/tmp/test.txt", # noqa: S108
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class TestActionValidator:
|
||||
"""Tests for ActionValidator class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_rules_allows_by_default(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
sample_action: ActionRequest,
|
||||
) -> None:
|
||||
"""Test that actions are allowed by default with no rules."""
|
||||
result = await validator.validate(sample_action)
|
||||
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
assert "No matching rules" in result.reasons[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_rule_blocks_action(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that a deny rule blocks matching actions."""
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_shell",
|
||||
tool_patterns=["shell_*"],
|
||||
reason="Shell commands not allowed",
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.SHELL_COMMAND,
|
||||
tool_name="shell_exec",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
assert len(result.applied_rules) == 1 # One rule applied
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_rule_requires_approval(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that an approval rule requires approval."""
|
||||
validator.add_rule(
|
||||
create_approval_rule(
|
||||
name="approve_db",
|
||||
tool_patterns=["database_*"],
|
||||
reason="Database operations require approval",
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.DATABASE_MUTATE,
|
||||
tool_name="database_delete",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_takes_precedence(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that deny rules take precedence over allow rules."""
|
||||
validator.add_rule(
|
||||
create_allow_rule(
|
||||
name="allow_files",
|
||||
tool_patterns=["file_*"],
|
||||
priority=10,
|
||||
)
|
||||
)
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_delete",
|
||||
action_types=[ActionType.FILE_DELETE],
|
||||
priority=100,
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_DELETE,
|
||||
tool_name="file_delete",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rule_priority_ordering(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test that rules are evaluated in priority order."""
|
||||
validator.add_rule(
|
||||
ValidationRule(
|
||||
name="low_priority",
|
||||
priority=1,
|
||||
decision=SafetyDecision.ALLOW,
|
||||
)
|
||||
)
|
||||
validator.add_rule(
|
||||
ValidationRule(
|
||||
name="high_priority",
|
||||
priority=100,
|
||||
decision=SafetyDecision.DENY,
|
||||
)
|
||||
)
|
||||
|
||||
# High priority should be first in the list
|
||||
assert validator._rules[0].name == "high_priority"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disabled_rule_not_applied(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
sample_action: ActionRequest,
|
||||
) -> None:
|
||||
"""Test that disabled rules are not applied."""
|
||||
rule = create_deny_rule(
|
||||
name="deny_all",
|
||||
tool_patterns=["*"],
|
||||
)
|
||||
rule.enabled = False
|
||||
validator.add_rule(rule)
|
||||
|
||||
result = await validator.validate(sample_action)
|
||||
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_pattern_matching(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test resource pattern matching."""
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_secrets",
|
||||
resource_patterns=["*/secrets/*", "*.env"],
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource="/app/secrets/api_key.txt",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_id_filter(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test filtering by agent ID."""
|
||||
rule = ValidationRule(
|
||||
name="restrict_agent",
|
||||
agent_ids=["restricted-agent"],
|
||||
decision=SafetyDecision.DENY,
|
||||
reason="Restricted agent",
|
||||
)
|
||||
validator.add_rule(rule)
|
||||
|
||||
# Restricted agent should be denied
|
||||
metadata1 = ActionMetadata(agent_id="restricted-agent")
|
||||
action1 = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata1,
|
||||
)
|
||||
result1 = await validator.validate(action1)
|
||||
assert result1.decision == SafetyDecision.DENY
|
||||
|
||||
# Other agents should be allowed
|
||||
metadata2 = ActionMetadata(agent_id="normal-agent")
|
||||
action2 = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata2,
|
||||
)
|
||||
result2 = await validator.validate(action2)
|
||||
assert result2.decision == SafetyDecision.ALLOW
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bypass_mode(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
sample_action: ActionRequest,
|
||||
) -> None:
|
||||
"""Test validation bypass mode."""
|
||||
validator.add_rule(create_deny_rule(name="deny_all", tool_patterns=["*"]))
|
||||
|
||||
# Should be denied normally
|
||||
result1 = await validator.validate(sample_action)
|
||||
assert result1.decision == SafetyDecision.DENY
|
||||
|
||||
# Enable bypass
|
||||
validator.enable_bypass("Emergency situation")
|
||||
result2 = await validator.validate(sample_action)
|
||||
assert result2.decision == SafetyDecision.ALLOW
|
||||
assert "bypassed" in result2.reasons[0].lower()
|
||||
|
||||
# Disable bypass
|
||||
validator.disable_bypass()
|
||||
result3 = await validator.validate(sample_action)
|
||||
assert result3.decision == SafetyDecision.DENY
|
||||
|
||||
def test_remove_rule(self, validator: ActionValidator) -> None:
|
||||
"""Test removing a rule."""
|
||||
rule = create_deny_rule(name="test_rule", tool_patterns=["test"])
|
||||
validator.add_rule(rule)
|
||||
|
||||
assert len(validator._rules) == 1
|
||||
assert validator.remove_rule(rule.id) is True
|
||||
assert len(validator._rules) == 0
|
||||
|
||||
def test_remove_nonexistent_rule(self, validator: ActionValidator) -> None:
|
||||
"""Test removing a nonexistent rule returns False."""
|
||||
assert validator.remove_rule("nonexistent") is False
|
||||
|
||||
def test_clear_rules(self, validator: ActionValidator) -> None:
|
||||
"""Test clearing all rules."""
|
||||
validator.add_rule(create_deny_rule(name="rule1", tool_patterns=["a"]))
|
||||
validator.add_rule(create_deny_rule(name="rule2", tool_patterns=["b"]))
|
||||
|
||||
assert len(validator._rules) == 2
|
||||
validator.clear_rules()
|
||||
assert len(validator._rules) == 0
|
||||
|
||||
|
||||
class TestLoadRulesFromPolicy:
|
||||
"""Tests for loading rules from policies."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_denied_tools(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test loading denied tools from policy."""
|
||||
policy = SafetyPolicy(
|
||||
name="test",
|
||||
denied_tools=["shell_*", "exec_*"],
|
||||
)
|
||||
|
||||
validator.load_rules_from_policy(policy)
|
||||
|
||||
# Should have 2 deny rules
|
||||
deny_rules = [r for r in validator._rules if r.decision == SafetyDecision.DENY]
|
||||
assert len(deny_rules) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_approval_patterns(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test loading approval patterns from policy."""
|
||||
policy = SafetyPolicy(
|
||||
name="test",
|
||||
require_approval_for=["database_*"],
|
||||
)
|
||||
|
||||
validator.load_rules_from_policy(policy)
|
||||
|
||||
approval_rules = [
|
||||
r for r in validator._rules
|
||||
if r.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
]
|
||||
assert len(approval_rules) == 1
|
||||
|
||||
|
||||
class TestValidationBatch:
|
||||
"""Tests for batch validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_batch(
|
||||
self,
|
||||
validator: ActionValidator,
|
||||
) -> None:
|
||||
"""Test validating multiple actions."""
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_shell",
|
||||
tool_patterns=["shell_*"],
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
actions = [
|
||||
ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata,
|
||||
),
|
||||
ActionRequest(
|
||||
action_type=ActionType.SHELL_COMMAND,
|
||||
tool_name="shell_exec",
|
||||
metadata=metadata,
|
||||
),
|
||||
]
|
||||
|
||||
results = await validator.validate_batch(actions)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0].decision == SafetyDecision.ALLOW
|
||||
assert results[1].decision == SafetyDecision.DENY
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for rule creation helper functions."""
|
||||
|
||||
def test_create_allow_rule(self) -> None:
|
||||
"""Test creating an allow rule."""
|
||||
rule = create_allow_rule(
|
||||
name="allow_test",
|
||||
tool_patterns=["test_*"],
|
||||
priority=50,
|
||||
)
|
||||
|
||||
assert rule.name == "allow_test"
|
||||
assert rule.decision == SafetyDecision.ALLOW
|
||||
assert rule.priority == 50
|
||||
|
||||
def test_create_deny_rule(self) -> None:
|
||||
"""Test creating a deny rule."""
|
||||
rule = create_deny_rule(
|
||||
name="deny_test",
|
||||
tool_patterns=["dangerous_*"],
|
||||
reason="Too dangerous",
|
||||
)
|
||||
|
||||
assert rule.name == "deny_test"
|
||||
assert rule.decision == SafetyDecision.DENY
|
||||
assert rule.reason == "Too dangerous"
|
||||
assert rule.priority == 100 # Default priority for deny
|
||||
|
||||
def test_create_approval_rule(self) -> None:
|
||||
"""Test creating an approval rule."""
|
||||
rule = create_approval_rule(
|
||||
name="approve_test",
|
||||
action_types=[ActionType.DATABASE_MUTATE],
|
||||
)
|
||||
|
||||
assert rule.name == "approve_test"
|
||||
assert rule.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
assert rule.priority == 50 # Default priority for approval
|
||||
Reference in New Issue
Block a user