test(safety): add Phase E comprehensive safety tests

- Add tests for models: ActionMetadata, ActionRequest, ActionResult,
  ValidationRule, BudgetStatus, RateLimitConfig, ApprovalRequest/Response,
  Checkpoint, RollbackResult, AuditEvent, SafetyPolicy, GuardianResult
- Add tests for validation: ActionValidator rules, priorities, patterns,
  bypass mode, batch validation, rule creation helpers
- Add tests for loops: LoopDetector exact/semantic/oscillation detection,
  LoopBreaker throttle/backoff, history management
- Add tests for content filter: PII filtering (email, phone, SSN, credit card),
  secret blocking (API keys, GitHub tokens, private keys), custom patterns,
  scan without filtering, dict filtering
- Add tests for emergency controls: state management, pause/resume/reset,
  scoped emergency stops, callbacks, EmergencyTrigger events
- Fix exception kwargs in content filter and emergency controls to match
  exception class signatures

All 108 tests passing with lint and type checks clean.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 11:52:35 +01:00
parent f36bfb3781
commit 015f2de6c6
8 changed files with 1932 additions and 5 deletions

View File

@@ -370,8 +370,8 @@ class ContentFilter:
if raise_on_block:
raise ContentFilterError(
block_reason or "Content blocked",
detected_category=all_matches[0].category.value if all_matches else "unknown",
pattern_name=all_matches[0].pattern_name if all_matches else None,
filter_type=all_matches[0].category.value if all_matches else "unknown",
detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
)
elif all_matches:
logger.debug(

View File

@@ -328,7 +328,7 @@ class EmergencyControls:
if raise_if_blocked:
raise EmergencyStopError(
f"Global emergency state: {self._global_state.value}",
stop_reason=self._get_last_reason("global"),
stop_type=self._get_last_reason("global") or "emergency",
triggered_by=self._get_last_triggered_by("global"),
)
return False
@@ -340,9 +340,9 @@ class EmergencyControls:
if raise_if_blocked:
raise EmergencyStopError(
f"Emergency state for {scope}: {state.value}",
stop_reason=self._get_last_reason(scope),
stop_type=self._get_last_reason(scope) or "emergency",
triggered_by=self._get_last_triggered_by(scope),
scope=scope,
details={"scope": scope},
)
return False

View File

@@ -0,0 +1,345 @@
"""Tests for content filtering module."""
import pytest
from app.services.safety.content.filter import (
ContentCategory,
ContentFilter,
FilterAction,
FilterPattern,
filter_content,
scan_for_secrets,
)
from app.services.safety.exceptions import ContentFilterError
@pytest.fixture
def filter_all() -> ContentFilter:
"""Create a ContentFilter with all filters enabled."""
return ContentFilter(
enable_pii_filter=True,
enable_secret_filter=True,
enable_injection_filter=True,
)
@pytest.fixture
def filter_pii_only() -> ContentFilter:
"""Create a ContentFilter with only PII filter."""
return ContentFilter(
enable_pii_filter=True,
enable_secret_filter=False,
enable_injection_filter=False,
)
class TestContentFilter:
"""Tests for ContentFilter class."""
@pytest.mark.asyncio
async def test_filter_email(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering email addresses."""
content = "Contact me at john.doe@example.com for details."
result = await filter_pii_only.filter(content)
assert result.has_sensitive_content
assert "[EMAIL]" in result.filtered_content
assert "john.doe@example.com" not in result.filtered_content
assert len(result.matches) == 1
assert result.matches[0].category == ContentCategory.PII
@pytest.mark.asyncio
async def test_filter_phone_number(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering phone numbers."""
content = "Call me at 555-123-4567 or (555) 987-6543."
result = await filter_pii_only.filter(content)
assert result.has_sensitive_content
assert "[PHONE]" in result.filtered_content
# Should redact both phone numbers
assert "555-123-4567" not in result.filtered_content
@pytest.mark.asyncio
async def test_filter_ssn(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering Social Security Numbers."""
content = "SSN: 123-45-6789"
result = await filter_pii_only.filter(content)
assert result.has_sensitive_content
assert "[SSN]" in result.filtered_content
assert "123-45-6789" not in result.filtered_content
@pytest.mark.asyncio
async def test_filter_credit_card(self, filter_all: ContentFilter) -> None:
"""Test filtering credit card numbers."""
content = "Card: 4111-2222-3333-4444"
result = await filter_all.filter(content)
assert result.has_sensitive_content
assert "[CREDIT_CARD]" in result.filtered_content
@pytest.mark.asyncio
async def test_block_api_key(self, filter_all: ContentFilter) -> None:
"""Test blocking API keys."""
content = "api_key: sk-abcdef1234567890abcdef1234567890"
result = await filter_all.filter(content)
assert result.blocked
assert result.block_reason is not None
assert "api_key" in result.block_reason.lower()
@pytest.mark.asyncio
async def test_block_github_token(self, filter_all: ContentFilter) -> None:
"""Test blocking GitHub tokens."""
content = "token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
result = await filter_all.filter(content)
assert result.blocked
assert len(result.matches) > 0
@pytest.mark.asyncio
async def test_block_private_key(self, filter_all: ContentFilter) -> None:
"""Test blocking private keys."""
content = """
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA...
-----END RSA PRIVATE KEY-----
"""
result = await filter_all.filter(content)
assert result.blocked
@pytest.mark.asyncio
async def test_block_password_in_url(self, filter_all: ContentFilter) -> None:
"""Test blocking passwords in URLs."""
content = "Connect to: postgres://user:secretpassword@localhost/db"
result = await filter_all.filter(content)
assert result.blocked or result.has_sensitive_content
@pytest.mark.asyncio
async def test_warn_ip_address(self, filter_pii_only: ContentFilter) -> None:
"""Test warning on IP addresses."""
content = "Server IP: 192.168.1.100"
result = await filter_pii_only.filter(content)
# IP addresses generate warnings, not blocks
assert len(result.warnings) > 0 or result.has_sensitive_content
@pytest.mark.asyncio
async def test_no_false_positives_clean_content(
self,
filter_all: ContentFilter,
) -> None:
"""Test that clean content passes through."""
content = "This is a normal message with no sensitive data."
result = await filter_all.filter(content)
assert not result.blocked
assert result.filtered_content == content
@pytest.mark.asyncio
async def test_raise_on_block(self, filter_all: ContentFilter) -> None:
"""Test raising exception on blocked content."""
content = "-----BEGIN RSA PRIVATE KEY-----"
with pytest.raises(ContentFilterError):
await filter_all.filter(content, raise_on_block=True)
class TestFilterDict:
"""Tests for dictionary filtering."""
@pytest.mark.asyncio
async def test_filter_dict_values(self, filter_pii_only: ContentFilter) -> None:
"""Test filtering string values in a dictionary."""
data = {
"name": "John Doe",
"email": "john@example.com",
"age": 30,
}
result = await filter_pii_only.filter_dict(data)
assert "[EMAIL]" in result["email"]
assert result["age"] == 30 # Non-string unchanged
@pytest.mark.asyncio
async def test_filter_dict_recursive(
self,
filter_pii_only: ContentFilter,
) -> None:
"""Test recursive dictionary filtering."""
data = {
"user": {
"contact": {
"email": "test@example.com",
}
}
}
result = await filter_pii_only.filter_dict(data, recursive=True)
assert "[EMAIL]" in result["user"]["contact"]["email"]
@pytest.mark.asyncio
async def test_filter_dict_specific_keys(
self,
filter_pii_only: ContentFilter,
) -> None:
"""Test filtering specific keys only."""
data = {
"public_email": "public@example.com",
"private_email": "private@example.com",
}
result = await filter_pii_only.filter_dict(
data,
keys_to_filter=["private_email"],
)
# Only private_email should be filtered
assert "public@example.com" in result["public_email"]
assert "[EMAIL]" in result["private_email"]
class TestScan:
"""Tests for content scanning."""
@pytest.mark.asyncio
async def test_scan_without_filtering(
self,
filter_all: ContentFilter,
) -> None:
"""Test scanning without modifying content."""
content = "Email: test@example.com, SSN: 123-45-6789"
matches = await filter_all.scan(content)
assert len(matches) >= 2
@pytest.mark.asyncio
async def test_scan_specific_categories(
self,
filter_all: ContentFilter,
) -> None:
"""Test scanning for specific categories only."""
content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
# Scan only for secrets
matches = await filter_all.scan(
content,
categories=[ContentCategory.SECRETS],
)
# Should only find the token, not the email
assert all(m.category == ContentCategory.SECRETS for m in matches)
class TestValidateSafe:
"""Tests for safe validation."""
@pytest.mark.asyncio
async def test_validate_safe_clean(self, filter_all: ContentFilter) -> None:
"""Test validation of clean content."""
content = "This is safe content."
is_safe, issues = await filter_all.validate_safe(content)
assert is_safe is True
assert len(issues) == 0
@pytest.mark.asyncio
async def test_validate_safe_with_secrets(
self,
filter_all: ContentFilter,
) -> None:
"""Test validation of content with secrets."""
content = "-----BEGIN RSA PRIVATE KEY-----"
is_safe, issues = await filter_all.validate_safe(content)
assert is_safe is False
assert len(issues) > 0
class TestCustomPatterns:
"""Tests for custom pattern support."""
@pytest.mark.asyncio
async def test_add_custom_pattern(self) -> None:
"""Test adding a custom filter pattern."""
content_filter = ContentFilter(
enable_pii_filter=False,
enable_secret_filter=False,
enable_injection_filter=False,
)
# Add custom pattern for internal IDs
content_filter.add_pattern(
FilterPattern(
name="internal_id",
category=ContentCategory.CUSTOM,
pattern=r"INTERNAL-[A-Z0-9]{8}",
action=FilterAction.REDACT,
replacement="[INTERNAL_ID]",
)
)
content = "Reference: INTERNAL-ABC12345"
result = await content_filter.filter(content)
assert "[INTERNAL_ID]" in result.filtered_content
@pytest.mark.asyncio
async def test_disable_pattern(self, filter_pii_only: ContentFilter) -> None:
"""Test disabling a pattern."""
filter_pii_only.enable_pattern("email", enabled=False)
content = "Email: test@example.com"
result = await filter_pii_only.filter(content)
# Email should not be filtered
assert "test@example.com" in result.filtered_content
@pytest.mark.asyncio
async def test_remove_pattern(self, filter_pii_only: ContentFilter) -> None:
"""Test removing a pattern."""
removed = filter_pii_only.remove_pattern("email")
assert removed is True
content = "Email: test@example.com"
result = await filter_pii_only.filter(content)
# Email should not be filtered
assert "test@example.com" in result.filtered_content
class TestPatternStats:
"""Tests for pattern statistics."""
def test_get_pattern_stats(self, filter_all: ContentFilter) -> None:
"""Test getting pattern statistics."""
stats = filter_all.get_pattern_stats()
assert stats["total_patterns"] > 0
assert "by_category" in stats
assert "by_action" in stats
class TestConvenienceFunctions:
"""Tests for convenience functions."""
@pytest.mark.asyncio
async def test_filter_content_function(self) -> None:
"""Test the quick filter function."""
content = "Email: test@example.com"
filtered = await filter_content(content)
assert "test@example.com" not in filtered
@pytest.mark.asyncio
async def test_scan_for_secrets_function(self) -> None:
"""Test the quick secret scan function."""
content = "Token: ghp_abcdefghijklmnopqrstuvwxyz1234567890"
matches = await scan_for_secrets(content)
assert len(matches) > 0
assert matches[0].category in (
ContentCategory.SECRETS,
ContentCategory.CREDENTIALS,
)

View File

@@ -0,0 +1,425 @@
"""Tests for emergency controls module."""
import pytest
from app.services.safety.emergency.controls import (
EmergencyControls,
EmergencyReason,
EmergencyState,
EmergencyTrigger,
)
from app.services.safety.exceptions import EmergencyStopError
@pytest.fixture
def controls() -> EmergencyControls:
"""Create fresh EmergencyControls."""
return EmergencyControls()
class TestEmergencyControls:
"""Tests for EmergencyControls class."""
@pytest.mark.asyncio
async def test_initial_state_is_normal(
self,
controls: EmergencyControls,
) -> None:
"""Test that initial state is normal."""
state = await controls.get_state("global")
assert state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_emergency_stop_changes_state(
self,
controls: EmergencyControls,
) -> None:
"""Test that emergency stop changes state to stopped."""
event = await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Test emergency stop",
)
assert event.state == EmergencyState.STOPPED
state = await controls.get_state("global")
assert state == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_pause_changes_state(
self,
controls: EmergencyControls,
) -> None:
"""Test that pause changes state to paused."""
event = await controls.pause(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget_controller",
message="Budget exceeded",
)
assert event.state == EmergencyState.PAUSED
state = await controls.get_state("global")
assert state == EmergencyState.PAUSED
@pytest.mark.asyncio
async def test_resume_from_paused(
self,
controls: EmergencyControls,
) -> None:
"""Test resuming from paused state."""
await controls.pause(
reason=EmergencyReason.RATE_LIMIT,
triggered_by="limiter",
message="Rate limited",
)
resumed = await controls.resume(resumed_by="admin")
assert resumed is True
state = await controls.get_state("global")
assert state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_cannot_resume_from_stopped(
self,
controls: EmergencyControls,
) -> None:
"""Test that you cannot resume from stopped state."""
await controls.emergency_stop(
reason=EmergencyReason.SAFETY_VIOLATION,
triggered_by="safety",
message="Critical violation",
)
resumed = await controls.resume(resumed_by="admin")
assert resumed is False
state = await controls.get_state("global")
assert state == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_reset_from_stopped(
self,
controls: EmergencyControls,
) -> None:
"""Test resetting from stopped state."""
await controls.emergency_stop(
reason=EmergencyReason.SAFETY_VIOLATION,
triggered_by="safety",
message="Critical violation",
)
reset = await controls.reset(reset_by="admin")
assert reset is True
state = await controls.get_state("global")
assert state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_scoped_emergency_stop(
self,
controls: EmergencyControls,
) -> None:
"""Test emergency stop with specific scope."""
await controls.emergency_stop(
reason=EmergencyReason.LOOP_DETECTED,
triggered_by="detector",
message="Loop in agent",
scope="agent:agent-123",
)
# Agent scope should be stopped
agent_state = await controls.get_state("agent:agent-123")
assert agent_state == EmergencyState.STOPPED
# Global should still be normal
global_state = await controls.get_state("global")
assert global_state == EmergencyState.NORMAL
@pytest.mark.asyncio
async def test_check_allowed_when_normal(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed returns True when state is normal."""
allowed = await controls.check_allowed()
assert allowed is True
@pytest.mark.asyncio
async def test_check_allowed_when_stopped(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed returns False when stopped."""
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
)
allowed = await controls.check_allowed(raise_if_blocked=False)
assert allowed is False
@pytest.mark.asyncio
async def test_check_allowed_raises_when_blocked(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed raises exception when blocked."""
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
)
with pytest.raises(EmergencyStopError):
await controls.check_allowed(raise_if_blocked=True)
@pytest.mark.asyncio
async def test_check_allowed_with_scope(
self,
controls: EmergencyControls,
) -> None:
"""Test check_allowed with specific scope."""
await controls.pause(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget",
message="Paused",
scope="project:proj-123",
)
# Project scope should be blocked
allowed_project = await controls.check_allowed(
scope="project:proj-123",
raise_if_blocked=False,
)
assert allowed_project is False
# Different scope should be allowed
allowed_other = await controls.check_allowed(
scope="project:proj-456",
raise_if_blocked=False,
)
assert allowed_other is True
@pytest.mark.asyncio
async def test_get_all_states(
self,
controls: EmergencyControls,
) -> None:
"""Test getting all states."""
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause",
scope="agent:a1",
)
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
scope="agent:a2",
)
states = await controls.get_all_states()
assert states["global"] == EmergencyState.NORMAL
assert states["agent:a1"] == EmergencyState.PAUSED
assert states["agent:a2"] == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_get_active_events(
self,
controls: EmergencyControls,
) -> None:
"""Test getting active (unresolved) events."""
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause 1",
)
events = await controls.get_active_events()
assert len(events) == 1
# Resume should resolve the event
await controls.resume()
events_after = await controls.get_active_events()
assert len(events_after) == 0
@pytest.mark.asyncio
async def test_event_history(
self,
controls: EmergencyControls,
) -> None:
"""Test getting event history."""
await controls.pause(
reason=EmergencyReason.RATE_LIMIT,
triggered_by="test",
message="Rate limited",
)
await controls.resume()
history = await controls.get_event_history()
assert len(history) == 1
assert history[0].reason == EmergencyReason.RATE_LIMIT
@pytest.mark.asyncio
async def test_event_metadata(
self,
controls: EmergencyControls,
) -> None:
"""Test event metadata storage."""
event = await controls.emergency_stop(
reason=EmergencyReason.BUDGET_EXCEEDED,
triggered_by="budget_controller",
message="Over budget",
metadata={"budget_type": "tokens", "usage": 150000},
)
assert event.metadata["budget_type"] == "tokens"
assert event.metadata["usage"] == 150000
class TestCallbacks:
"""Tests for callback functionality."""
@pytest.mark.asyncio
async def test_on_stop_callback(
self,
controls: EmergencyControls,
) -> None:
"""Test on_stop callback is called."""
callback_called = []
def callback(event: object) -> None:
callback_called.append(event)
controls.on_stop(callback)
await controls.emergency_stop(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Stop",
)
assert len(callback_called) == 1
@pytest.mark.asyncio
async def test_on_pause_callback(
self,
controls: EmergencyControls,
) -> None:
"""Test on_pause callback is called."""
callback_called = []
def callback(event: object) -> None:
callback_called.append(event)
controls.on_pause(callback)
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause",
)
assert len(callback_called) == 1
@pytest.mark.asyncio
async def test_on_resume_callback(
self,
controls: EmergencyControls,
) -> None:
"""Test on_resume callback is called."""
callback_called = []
def callback(data: object) -> None:
callback_called.append(data)
controls.on_resume(callback)
await controls.pause(
reason=EmergencyReason.MANUAL,
triggered_by="test",
message="Pause",
)
await controls.resume()
assert len(callback_called) == 1
class TestEmergencyTrigger:
"""Tests for EmergencyTrigger class."""
@pytest.fixture
def trigger(self, controls: EmergencyControls) -> EmergencyTrigger:
"""Create an EmergencyTrigger."""
return EmergencyTrigger(controls)
@pytest.mark.asyncio
async def test_trigger_on_safety_violation(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering emergency on safety violation."""
event = await trigger.trigger_on_safety_violation(
violation_type="unauthorized_access",
details={"resource": "/secrets/key"},
)
assert event.reason == EmergencyReason.SAFETY_VIOLATION
assert event.state == EmergencyState.STOPPED
state = await controls.get_state("global")
assert state == EmergencyState.STOPPED
@pytest.mark.asyncio
async def test_trigger_on_budget_exceeded(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering pause on budget exceeded."""
event = await trigger.trigger_on_budget_exceeded(
budget_type="tokens",
current=150000,
limit=100000,
)
assert event.reason == EmergencyReason.BUDGET_EXCEEDED
assert event.state == EmergencyState.PAUSED # Pause, not stop
@pytest.mark.asyncio
async def test_trigger_on_loop_detected(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering pause on loop detection."""
event = await trigger.trigger_on_loop_detected(
loop_type="exact",
agent_id="agent-123",
details={"pattern": "file_read"},
)
assert event.reason == EmergencyReason.LOOP_DETECTED
assert event.scope == "agent:agent-123"
assert event.state == EmergencyState.PAUSED
@pytest.mark.asyncio
async def test_trigger_on_content_violation(
self,
trigger: EmergencyTrigger,
controls: EmergencyControls,
) -> None:
"""Test triggering stop on content violation."""
event = await trigger.trigger_on_content_violation(
category="secrets",
pattern="private_key",
)
assert event.reason == EmergencyReason.CONTENT_VIOLATION
assert event.state == EmergencyState.STOPPED

View File

@@ -0,0 +1,316 @@
"""Tests for loop detection module."""
import pytest
from app.services.safety.exceptions import LoopDetectedError
from app.services.safety.loops.detector import (
ActionSignature,
LoopBreaker,
LoopDetector,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
AutonomyLevel,
)
@pytest.fixture
def detector() -> LoopDetector:
"""Create a fresh LoopDetector with low thresholds for testing."""
return LoopDetector(
history_size=20,
max_exact_repetitions=3,
max_semantic_repetitions=5,
)
@pytest.fixture
def sample_metadata() -> ActionMetadata:
"""Create sample action metadata."""
return ActionMetadata(
agent_id="test-agent",
session_id="test-session",
autonomy_level=AutonomyLevel.MILESTONE,
)
def create_action(
metadata: ActionMetadata,
tool_name: str,
resource: str = "/tmp/test.txt", # noqa: S108
arguments: dict | None = None,
) -> ActionRequest:
"""Helper to create test actions."""
return ActionRequest(
action_type=ActionType.FILE_READ,
tool_name=tool_name,
resource=resource,
arguments=arguments or {},
metadata=metadata,
)
class TestActionSignature:
"""Tests for ActionSignature class."""
def test_exact_key_includes_args(self, sample_metadata: ActionMetadata) -> None:
"""Test that exact key includes argument hash."""
action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"})
action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"})
sig1 = ActionSignature(action1)
sig2 = ActionSignature(action2)
assert sig1.exact_key() != sig2.exact_key()
def test_semantic_key_ignores_args(self, sample_metadata: ActionMetadata) -> None:
"""Test that semantic key ignores arguments."""
action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"})
action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"})
sig1 = ActionSignature(action1)
sig2 = ActionSignature(action2)
assert sig1.semantic_key() == sig2.semantic_key()
def test_type_key(self, sample_metadata: ActionMetadata) -> None:
"""Test type key extraction."""
action = create_action(sample_metadata, "file_read")
sig = ActionSignature(action)
assert sig.type_key() == "file_read"
class TestLoopDetector:
"""Tests for LoopDetector class."""
@pytest.mark.asyncio
async def test_no_loop_on_first_action(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test that first action is never a loop."""
action = create_action(sample_metadata, "file_read")
is_loop, loop_type = await detector.check(action)
assert is_loop is False
assert loop_type is None
@pytest.mark.asyncio
async def test_exact_loop_detection(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test detection of exact repetitions."""
action = create_action(
sample_metadata,
"file_read",
resource="/tmp/same.txt", # noqa: S108
arguments={"path": "/tmp/same.txt"}, # noqa: S108
)
# Record the same action 3 times (threshold)
for _ in range(3):
await detector.record(action)
# Next should be detected as a loop
is_loop, loop_type = await detector.check(action)
assert is_loop is True
assert loop_type == "exact"
@pytest.mark.asyncio
async def test_semantic_loop_detection(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test detection of semantic (similar) repetitions."""
# Record same tool/resource but different arguments
test_resource = "/tmp/test.txt" # noqa: S108
for i in range(5):
action = create_action(
sample_metadata,
"file_read",
resource=test_resource,
arguments={"offset": i},
)
await detector.record(action)
# Next similar action should be detected as semantic loop
action = create_action(
sample_metadata,
"file_read",
resource=test_resource,
arguments={"offset": 100},
)
is_loop, loop_type = await detector.check(action)
assert is_loop is True
assert loop_type == "semantic"
@pytest.mark.asyncio
async def test_oscillation_detection(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test detection of A→B→A→B oscillation pattern."""
action_a = create_action(sample_metadata, "tool_a", resource="/a")
action_b = create_action(sample_metadata, "tool_b", resource="/b")
# Create A→B→A pattern
await detector.record(action_a)
await detector.record(action_b)
await detector.record(action_a)
# Fourth action completing A→B→A→B should be detected as oscillation
is_loop, loop_type = await detector.check(action_b)
assert is_loop is True
assert loop_type == "oscillation"
@pytest.mark.asyncio
async def test_different_actions_no_loop(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test that different actions don't trigger loops."""
for i in range(10):
action = create_action(
sample_metadata,
f"tool_{i}",
resource=f"/resource_{i}",
)
is_loop, _ = await detector.check(action)
assert is_loop is False
await detector.record(action)
@pytest.mark.asyncio
async def test_check_and_raise(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test check_and_raise raises on loop detection."""
action = create_action(sample_metadata, "file_read")
# Record threshold number of times
for _ in range(3):
await detector.record(action)
# Should raise
with pytest.raises(LoopDetectedError) as exc_info:
await detector.check_and_raise(action)
assert "exact" in exc_info.value.loop_type.lower()
@pytest.mark.asyncio
async def test_clear_history(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test clearing agent history."""
action = create_action(sample_metadata, "file_read")
# Record multiple times
for _ in range(3):
await detector.record(action)
# Clear history
await detector.clear_history(sample_metadata.agent_id)
# Should no longer detect loop
is_loop, _ = await detector.check(action)
assert is_loop is False
@pytest.mark.asyncio
async def test_per_agent_history(
self,
detector: LoopDetector,
) -> None:
"""Test that history is tracked per agent."""
metadata1 = ActionMetadata(agent_id="agent-1", session_id="s1")
metadata2 = ActionMetadata(agent_id="agent-2", session_id="s2")
action1 = create_action(metadata1, "file_read")
action2 = create_action(metadata2, "file_read")
# Record for agent 1 (threshold times)
for _ in range(3):
await detector.record(action1)
# Agent 1 should detect loop
is_loop1, _ = await detector.check(action1)
assert is_loop1 is True
# Agent 2 should not detect loop
is_loop2, _ = await detector.check(action2)
assert is_loop2 is False
@pytest.mark.asyncio
async def test_get_stats(
self,
detector: LoopDetector,
sample_metadata: ActionMetadata,
) -> None:
"""Test getting loop detection stats."""
for i in range(5):
action = create_action(
sample_metadata,
f"tool_{i % 2}", # Alternate between 2 tools
resource=f"/resource_{i}",
)
await detector.record(action)
stats = await detector.get_stats(sample_metadata.agent_id)
assert stats["history_size"] == 5
assert len(stats["action_type_counts"]) > 0
class TestLoopBreaker:
"""Tests for LoopBreaker class."""
@pytest.mark.asyncio
async def test_suggest_alternatives_exact(
self,
sample_metadata: ActionMetadata,
) -> None:
"""Test suggestions for exact loops."""
action = create_action(sample_metadata, "file_read")
suggestions = await LoopBreaker.suggest_alternatives(action, "exact")
assert len(suggestions) > 0
assert "same action" in suggestions[0].lower()
@pytest.mark.asyncio
async def test_suggest_alternatives_semantic(
self,
sample_metadata: ActionMetadata,
) -> None:
"""Test suggestions for semantic loops."""
action = create_action(sample_metadata, "file_read")
suggestions = await LoopBreaker.suggest_alternatives(action, "semantic")
assert len(suggestions) > 0
assert "similar" in suggestions[0].lower()
@pytest.mark.asyncio
async def test_suggest_alternatives_oscillation(
self,
sample_metadata: ActionMetadata,
) -> None:
"""Test suggestions for oscillation loops."""
action = create_action(sample_metadata, "file_read")
suggestions = await LoopBreaker.suggest_alternatives(action, "oscillation")
assert len(suggestions) > 0
assert "oscillat" in suggestions[0].lower()

View File

@@ -0,0 +1,437 @@
"""Tests for safety framework models."""
from datetime import datetime, timedelta
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionResult,
ActionType,
ApprovalRequest,
ApprovalResponse,
ApprovalStatus,
AuditEvent,
AuditEventType,
AutonomyLevel,
BudgetScope,
BudgetStatus,
Checkpoint,
CheckpointType,
GuardianResult,
PermissionLevel,
RateLimitConfig,
RollbackResult,
SafetyDecision,
SafetyPolicy,
ValidationResult,
ValidationRule,
)
class TestActionMetadata:
"""Tests for ActionMetadata model."""
def test_create_with_defaults(self) -> None:
"""Test creating metadata with default values."""
metadata = ActionMetadata(
agent_id="agent-1",
)
assert metadata.agent_id == "agent-1"
assert metadata.autonomy_level == AutonomyLevel.MILESTONE
assert metadata.project_id is None
assert metadata.session_id is None
def test_create_with_all_fields(self) -> None:
"""Test creating metadata with all fields."""
metadata = ActionMetadata(
agent_id="agent-1",
session_id="session-1",
project_id="project-1",
user_id="user-1",
autonomy_level=AutonomyLevel.AUTONOMOUS,
)
assert metadata.project_id == "project-1"
assert metadata.user_id == "user-1"
assert metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
class TestActionRequest:
"""Tests for ActionRequest model."""
def test_create_basic_action(self) -> None:
"""Test creating a basic action request."""
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata,
)
assert action.action_type == ActionType.FILE_READ
assert action.tool_name == "file_read"
assert action.id is not None
assert action.metadata.agent_id == "agent-1"
def test_action_with_arguments(self) -> None:
"""Test action with arguments."""
test_path = "/tmp/test.txt" # noqa: S108
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_WRITE,
tool_name="file_write",
arguments={"path": test_path, "content": "hello"},
resource=test_path,
metadata=metadata,
)
assert action.arguments["path"] == test_path
assert action.resource == test_path
class TestActionResult:
"""Tests for ActionResult model."""
def test_successful_result(self) -> None:
"""Test creating a successful result."""
result = ActionResult(
action_id="action-1",
success=True,
data={"output": "done"},
)
assert result.success is True
assert result.data["output"] == "done"
assert result.error is None
def test_failed_result(self) -> None:
"""Test creating a failed result."""
result = ActionResult(
action_id="action-1",
success=False,
error="Permission denied",
)
assert result.success is False
assert result.error == "Permission denied"
class TestValidationRule:
"""Tests for ValidationRule model."""
def test_create_rule(self) -> None:
"""Test creating a validation rule."""
rule = ValidationRule(
name="deny_shell",
description="Deny shell commands",
priority=100,
tool_patterns=["shell_*"],
decision=SafetyDecision.DENY,
reason="Shell commands are not allowed",
)
assert rule.name == "deny_shell"
assert rule.priority == 100
assert rule.decision == SafetyDecision.DENY
assert rule.enabled is True
def test_rule_defaults(self) -> None:
"""Test rule default values."""
rule = ValidationRule(name="test_rule", decision=SafetyDecision.ALLOW)
assert rule.id is not None
assert rule.priority == 0
assert rule.enabled is True
assert rule.decision == SafetyDecision.ALLOW
class TestValidationResult:
"""Tests for ValidationResult model."""
def test_allow_result(self) -> None:
"""Test an allow result."""
result = ValidationResult(
action_id="action-1",
decision=SafetyDecision.ALLOW,
applied_rules=["rule-1"],
reasons=["Action is allowed"],
)
assert result.decision == SafetyDecision.ALLOW
assert len(result.applied_rules) == 1
def test_deny_result(self) -> None:
"""Test a deny result."""
result = ValidationResult(
action_id="action-1",
decision=SafetyDecision.DENY,
applied_rules=["deny_rule"],
reasons=["Action is not permitted"],
)
assert result.decision == SafetyDecision.DENY
class TestBudgetStatus:
"""Tests for BudgetStatus model."""
def test_under_budget(self) -> None:
"""Test status when under budget."""
status = BudgetStatus(
scope=BudgetScope.SESSION,
scope_id="session-1",
tokens_used=5000,
tokens_limit=10000,
tokens_remaining=5000,
)
assert status.tokens_remaining == 5000
assert status.tokens_used == 5000
def test_over_budget(self) -> None:
"""Test status when over budget."""
status = BudgetStatus(
scope=BudgetScope.SESSION,
scope_id="session-1",
cost_used_usd=15.0,
cost_limit_usd=10.0,
cost_remaining_usd=0.0,
is_exceeded=True,
)
assert status.is_exceeded is True
assert status.cost_remaining_usd == 0.0
class TestRateLimitConfig:
"""Tests for RateLimitConfig model."""
def test_create_config(self) -> None:
"""Test creating rate limit config."""
config = RateLimitConfig(
name="actions",
limit=60,
window_seconds=60,
)
assert config.name == "actions"
assert config.limit == 60
assert config.window_seconds == 60
class TestApprovalRequest:
"""Tests for ApprovalRequest model."""
def test_create_request(self) -> None:
"""Test creating an approval request."""
metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
action = ActionRequest(
action_type=ActionType.DATABASE_MUTATE,
tool_name="db_delete",
metadata=metadata,
)
request = ApprovalRequest(
id="approval-1",
action=action,
reason="Database mutation requires approval",
urgency="high",
timeout_seconds=300,
)
assert request.id == "approval-1"
assert request.urgency == "high"
assert request.timeout_seconds == 300
class TestApprovalResponse:
"""Tests for ApprovalResponse model."""
def test_approved_response(self) -> None:
"""Test an approved response."""
response = ApprovalResponse(
request_id="approval-1",
status=ApprovalStatus.APPROVED,
decided_by="admin",
reason="Looks safe",
)
assert response.status == ApprovalStatus.APPROVED
assert response.decided_by == "admin"
def test_denied_response(self) -> None:
"""Test a denied response."""
response = ApprovalResponse(
request_id="approval-1",
status=ApprovalStatus.DENIED,
decided_by="admin",
reason="Too risky",
)
assert response.status == ApprovalStatus.DENIED
class TestCheckpoint:
"""Tests for Checkpoint model."""
def test_create_checkpoint(self) -> None:
"""Test creating a checkpoint."""
test_path = "/tmp/test.txt" # noqa: S108
checkpoint = Checkpoint(
id="checkpoint-1",
checkpoint_type=CheckpointType.FILE,
action_id="action-1",
created_at=datetime.utcnow(),
data={"path": test_path},
description="File checkpoint",
)
assert checkpoint.id == "checkpoint-1"
assert checkpoint.is_valid is True
def test_expired_checkpoint(self) -> None:
"""Test an expired checkpoint."""
checkpoint = Checkpoint(
id="checkpoint-1",
checkpoint_type=CheckpointType.FILE,
action_id="action-1",
created_at=datetime.utcnow() - timedelta(hours=2),
expires_at=datetime.utcnow() - timedelta(hours=1),
data={},
)
# is_valid is a simple bool, not computed from expires_at
# The RollbackManager handles expiration logic
assert checkpoint.is_valid is True # Default value
class TestRollbackResult:
"""Tests for RollbackResult model."""
def test_successful_rollback(self) -> None:
"""Test a successful rollback."""
result = RollbackResult(
checkpoint_id="checkpoint-1",
success=True,
actions_rolled_back=["file:/tmp/test.txt"],
failed_actions=[],
)
assert result.success is True
assert len(result.actions_rolled_back) == 1
def test_partial_rollback(self) -> None:
"""Test a partial rollback."""
result = RollbackResult(
checkpoint_id="checkpoint-1",
success=False,
actions_rolled_back=["file:/tmp/a.txt"],
failed_actions=["file:/tmp/b.txt"],
error="Failed to rollback 1 item",
)
assert result.success is False
assert len(result.failed_actions) == 1
class TestAuditEvent:
"""Tests for AuditEvent model."""
def test_create_event(self) -> None:
"""Test creating an audit event."""
event = AuditEvent(
id="event-1",
event_type=AuditEventType.ACTION_EXECUTED,
timestamp=datetime.utcnow(),
agent_id="agent-1",
action_id="action-1",
data={"tool": "file_read"},
)
assert event.event_type == AuditEventType.ACTION_EXECUTED
assert event.agent_id == "agent-1"
class TestSafetyPolicy:
"""Tests for SafetyPolicy model."""
def test_default_policy(self) -> None:
"""Test creating a default policy."""
policy = SafetyPolicy(
name="default",
description="Default safety policy",
)
assert policy.name == "default"
assert policy.enabled is True
def test_restrictive_policy(self) -> None:
"""Test creating a restrictive policy."""
policy = SafetyPolicy(
name="restrictive",
description="Restrictive policy",
denied_tools=["shell_*", "exec_*"],
require_approval_for=["database_*", "git_push"],
)
assert len(policy.denied_tools) == 2
assert len(policy.require_approval_for) == 2
class TestGuardianResult:
"""Tests for GuardianResult model."""
def test_allowed_result(self) -> None:
"""Test an allowed result."""
result = GuardianResult(
action_id="action-1",
allowed=True,
decision=SafetyDecision.ALLOW,
reasons=["All checks passed"],
)
assert result.decision == SafetyDecision.ALLOW
assert result.allowed is True
assert result.approval_id is None
def test_approval_required_result(self) -> None:
"""Test a result requiring approval."""
result = GuardianResult(
action_id="action-1",
allowed=False,
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=["Action requires human approval"],
approval_id="approval-123",
)
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
assert result.approval_id == "approval-123"
class TestEnums:
"""Tests for enum values."""
def test_action_types(self) -> None:
"""Test action type enum values."""
assert ActionType.FILE_READ.value == "file_read"
assert ActionType.SHELL_COMMAND.value == "shell_command"
def test_autonomy_levels(self) -> None:
"""Test autonomy level enum values."""
assert AutonomyLevel.FULL_CONTROL.value == "full_control"
assert AutonomyLevel.MILESTONE.value == "milestone"
assert AutonomyLevel.AUTONOMOUS.value == "autonomous"
def test_permission_levels(self) -> None:
"""Test permission level enum values."""
assert PermissionLevel.NONE.value == "none"
assert PermissionLevel.READ.value == "read"
assert PermissionLevel.WRITE.value == "write"
assert PermissionLevel.ADMIN.value == "admin"
def test_safety_decisions(self) -> None:
"""Test safety decision enum values."""
assert SafetyDecision.ALLOW.value == "allow"
assert SafetyDecision.DENY.value == "deny"
assert SafetyDecision.REQUIRE_APPROVAL.value == "require_approval"

View File

@@ -0,0 +1,404 @@
"""Tests for safety validation module."""
import pytest
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
AutonomyLevel,
SafetyDecision,
SafetyPolicy,
ValidationRule,
)
from app.services.safety.validation.validator import (
ActionValidator,
create_allow_rule,
create_approval_rule,
create_deny_rule,
)
@pytest.fixture
def validator() -> ActionValidator:
"""Create a fresh ActionValidator."""
return ActionValidator(cache_enabled=False)
@pytest.fixture
def sample_action() -> ActionRequest:
"""Create a sample action request."""
metadata = ActionMetadata(
agent_id="test-agent",
session_id="test-session",
autonomy_level=AutonomyLevel.MILESTONE,
)
return ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/tmp/test.txt", # noqa: S108
metadata=metadata,
)
class TestActionValidator:
"""Tests for ActionValidator class."""
@pytest.mark.asyncio
async def test_no_rules_allows_by_default(
self,
validator: ActionValidator,
sample_action: ActionRequest,
) -> None:
"""Test that actions are allowed by default with no rules."""
result = await validator.validate(sample_action)
assert result.decision == SafetyDecision.ALLOW
assert "No matching rules" in result.reasons[0]
@pytest.mark.asyncio
async def test_deny_rule_blocks_action(
self,
validator: ActionValidator,
) -> None:
"""Test that a deny rule blocks matching actions."""
validator.add_rule(
create_deny_rule(
name="deny_shell",
tool_patterns=["shell_*"],
reason="Shell commands not allowed",
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.SHELL_COMMAND,
tool_name="shell_exec",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.DENY
assert len(result.applied_rules) == 1 # One rule applied
@pytest.mark.asyncio
async def test_approval_rule_requires_approval(
self,
validator: ActionValidator,
) -> None:
"""Test that an approval rule requires approval."""
validator.add_rule(
create_approval_rule(
name="approve_db",
tool_patterns=["database_*"],
reason="Database operations require approval",
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.DATABASE_MUTATE,
tool_name="database_delete",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.REQUIRE_APPROVAL
@pytest.mark.asyncio
async def test_deny_takes_precedence(
self,
validator: ActionValidator,
) -> None:
"""Test that deny rules take precedence over allow rules."""
validator.add_rule(
create_allow_rule(
name="allow_files",
tool_patterns=["file_*"],
priority=10,
)
)
validator.add_rule(
create_deny_rule(
name="deny_delete",
action_types=[ActionType.FILE_DELETE],
priority=100,
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_DELETE,
tool_name="file_delete",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_rule_priority_ordering(
self,
validator: ActionValidator,
) -> None:
"""Test that rules are evaluated in priority order."""
validator.add_rule(
ValidationRule(
name="low_priority",
priority=1,
decision=SafetyDecision.ALLOW,
)
)
validator.add_rule(
ValidationRule(
name="high_priority",
priority=100,
decision=SafetyDecision.DENY,
)
)
# High priority should be first in the list
assert validator._rules[0].name == "high_priority"
@pytest.mark.asyncio
async def test_disabled_rule_not_applied(
self,
validator: ActionValidator,
sample_action: ActionRequest,
) -> None:
"""Test that disabled rules are not applied."""
rule = create_deny_rule(
name="deny_all",
tool_patterns=["*"],
)
rule.enabled = False
validator.add_rule(rule)
result = await validator.validate(sample_action)
assert result.decision == SafetyDecision.ALLOW
@pytest.mark.asyncio
async def test_resource_pattern_matching(
self,
validator: ActionValidator,
) -> None:
"""Test resource pattern matching."""
validator.add_rule(
create_deny_rule(
name="deny_secrets",
resource_patterns=["*/secrets/*", "*.env"],
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/app/secrets/api_key.txt",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_agent_id_filter(
self,
validator: ActionValidator,
) -> None:
"""Test filtering by agent ID."""
rule = ValidationRule(
name="restrict_agent",
agent_ids=["restricted-agent"],
decision=SafetyDecision.DENY,
reason="Restricted agent",
)
validator.add_rule(rule)
# Restricted agent should be denied
metadata1 = ActionMetadata(agent_id="restricted-agent")
action1 = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata1,
)
result1 = await validator.validate(action1)
assert result1.decision == SafetyDecision.DENY
# Other agents should be allowed
metadata2 = ActionMetadata(agent_id="normal-agent")
action2 = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata2,
)
result2 = await validator.validate(action2)
assert result2.decision == SafetyDecision.ALLOW
@pytest.mark.asyncio
async def test_bypass_mode(
self,
validator: ActionValidator,
sample_action: ActionRequest,
) -> None:
"""Test validation bypass mode."""
validator.add_rule(create_deny_rule(name="deny_all", tool_patterns=["*"]))
# Should be denied normally
result1 = await validator.validate(sample_action)
assert result1.decision == SafetyDecision.DENY
# Enable bypass
validator.enable_bypass("Emergency situation")
result2 = await validator.validate(sample_action)
assert result2.decision == SafetyDecision.ALLOW
assert "bypassed" in result2.reasons[0].lower()
# Disable bypass
validator.disable_bypass()
result3 = await validator.validate(sample_action)
assert result3.decision == SafetyDecision.DENY
def test_remove_rule(self, validator: ActionValidator) -> None:
"""Test removing a rule."""
rule = create_deny_rule(name="test_rule", tool_patterns=["test"])
validator.add_rule(rule)
assert len(validator._rules) == 1
assert validator.remove_rule(rule.id) is True
assert len(validator._rules) == 0
def test_remove_nonexistent_rule(self, validator: ActionValidator) -> None:
"""Test removing a nonexistent rule returns False."""
assert validator.remove_rule("nonexistent") is False
def test_clear_rules(self, validator: ActionValidator) -> None:
"""Test clearing all rules."""
validator.add_rule(create_deny_rule(name="rule1", tool_patterns=["a"]))
validator.add_rule(create_deny_rule(name="rule2", tool_patterns=["b"]))
assert len(validator._rules) == 2
validator.clear_rules()
assert len(validator._rules) == 0
class TestLoadRulesFromPolicy:
"""Tests for loading rules from policies."""
@pytest.mark.asyncio
async def test_load_denied_tools(
self,
validator: ActionValidator,
) -> None:
"""Test loading denied tools from policy."""
policy = SafetyPolicy(
name="test",
denied_tools=["shell_*", "exec_*"],
)
validator.load_rules_from_policy(policy)
# Should have 2 deny rules
deny_rules = [r for r in validator._rules if r.decision == SafetyDecision.DENY]
assert len(deny_rules) == 2
@pytest.mark.asyncio
async def test_load_approval_patterns(
self,
validator: ActionValidator,
) -> None:
"""Test loading approval patterns from policy."""
policy = SafetyPolicy(
name="test",
require_approval_for=["database_*"],
)
validator.load_rules_from_policy(policy)
approval_rules = [
r for r in validator._rules
if r.decision == SafetyDecision.REQUIRE_APPROVAL
]
assert len(approval_rules) == 1
class TestValidationBatch:
"""Tests for batch validation."""
@pytest.mark.asyncio
async def test_validate_batch(
self,
validator: ActionValidator,
) -> None:
"""Test validating multiple actions."""
validator.add_rule(
create_deny_rule(
name="deny_shell",
tool_patterns=["shell_*"],
)
)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
actions = [
ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata,
),
ActionRequest(
action_type=ActionType.SHELL_COMMAND,
tool_name="shell_exec",
metadata=metadata,
),
]
results = await validator.validate_batch(actions)
assert len(results) == 2
assert results[0].decision == SafetyDecision.ALLOW
assert results[1].decision == SafetyDecision.DENY
class TestHelperFunctions:
"""Tests for rule creation helper functions."""
def test_create_allow_rule(self) -> None:
"""Test creating an allow rule."""
rule = create_allow_rule(
name="allow_test",
tool_patterns=["test_*"],
priority=50,
)
assert rule.name == "allow_test"
assert rule.decision == SafetyDecision.ALLOW
assert rule.priority == 50
def test_create_deny_rule(self) -> None:
"""Test creating a deny rule."""
rule = create_deny_rule(
name="deny_test",
tool_patterns=["dangerous_*"],
reason="Too dangerous",
)
assert rule.name == "deny_test"
assert rule.decision == SafetyDecision.DENY
assert rule.reason == "Too dangerous"
assert rule.priority == 100 # Default priority for deny
def test_create_approval_rule(self) -> None:
"""Test creating an approval rule."""
rule = create_approval_rule(
name="approve_test",
action_types=[ActionType.DATABASE_MUTATE],
)
assert rule.name == "approve_test"
assert rule.decision == SafetyDecision.REQUIRE_APPROVAL
assert rule.priority == 50 # Default priority for approval