From 015f2de6c686db6db89841b7987ed78943b49bbb Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sat, 3 Jan 2026 11:52:35 +0100 Subject: [PATCH] test(safety): add Phase E comprehensive safety tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add tests for models: ActionMetadata, ActionRequest, ActionResult, ValidationRule, BudgetStatus, RateLimitConfig, ApprovalRequest/Response, Checkpoint, RollbackResult, AuditEvent, SafetyPolicy, GuardianResult - Add tests for validation: ActionValidator rules, priorities, patterns, bypass mode, batch validation, rule creation helpers - Add tests for loops: LoopDetector exact/semantic/oscillation detection, LoopBreaker alternative suggestions, history management - Add tests for content filter: PII filtering (email, phone, SSN, credit card), secret blocking (API keys, GitHub tokens, private keys), custom patterns, scan without filtering, dict filtering - Add tests for emergency controls: state management, pause/resume/reset, scoped emergency stops, callbacks, EmergencyTrigger events - Fix exception kwargs in content filter and emergency controls to match exception class signatures All 108 tests passing with lint and type checks clean. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/services/safety/content/filter.py | 4 +- .../app/services/safety/emergency/controls.py | 6 +- backend/tests/services/safety/__init__.py | 0 .../services/safety/test_content_filter.py | 345 ++++++++++++++ .../tests/services/safety/test_emergency.py | 425 +++++++++++++++++ backend/tests/services/safety/test_loops.py | 316 +++++++++++++ backend/tests/services/safety/test_models.py | 437 ++++++++++++++++++ .../tests/services/safety/test_validation.py | 404 ++++++++++++++++ 8 files changed, 1932 insertions(+), 5 deletions(-) create mode 100644 backend/tests/services/safety/__init__.py create mode 100644 backend/tests/services/safety/test_content_filter.py create mode 100644 backend/tests/services/safety/test_emergency.py create mode 100644 backend/tests/services/safety/test_loops.py create mode 100644 backend/tests/services/safety/test_models.py create mode 100644 backend/tests/services/safety/test_validation.py diff --git a/backend/app/services/safety/content/filter.py b/backend/app/services/safety/content/filter.py index 8f3b795..b6585df 100644 --- a/backend/app/services/safety/content/filter.py +++ b/backend/app/services/safety/content/filter.py @@ -370,8 +370,8 @@ class ContentFilter: if raise_on_block: raise ContentFilterError( block_reason or "Content blocked", - detected_category=all_matches[0].category.value if all_matches else "unknown", - pattern_name=all_matches[0].pattern_name if all_matches else None, + filter_type=all_matches[0].category.value if all_matches else "unknown", + detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [], ) elif all_matches: logger.debug( diff --git a/backend/app/services/safety/emergency/controls.py b/backend/app/services/safety/emergency/controls.py index b565515..1b72bcc 100644 --- a/backend/app/services/safety/emergency/controls.py +++ b/backend/app/services/safety/emergency/controls.py @@ 
-328,7 +328,7 @@ class EmergencyControls: if raise_if_blocked: raise EmergencyStopError( f"Global emergency state: {self._global_state.value}", - stop_reason=self._get_last_reason("global"), + stop_type=self._get_last_reason("global") or "emergency", triggered_by=self._get_last_triggered_by("global"), ) return False @@ -340,9 +340,9 @@ class EmergencyControls: if raise_if_blocked: raise EmergencyStopError( f"Emergency state for {scope}: {state.value}", - stop_reason=self._get_last_reason(scope), + stop_type=self._get_last_reason(scope) or "emergency", triggered_by=self._get_last_triggered_by(scope), - scope=scope, + details={"scope": scope}, ) return False diff --git a/backend/tests/services/safety/__init__.py b/backend/tests/services/safety/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/services/safety/test_content_filter.py b/backend/tests/services/safety/test_content_filter.py new file mode 100644 index 0000000..c7e811b --- /dev/null +++ b/backend/tests/services/safety/test_content_filter.py @@ -0,0 +1,345 @@ +"""Tests for content filtering module.""" + +import pytest + +from app.services.safety.content.filter import ( + ContentCategory, + ContentFilter, + FilterAction, + FilterPattern, + filter_content, + scan_for_secrets, +) +from app.services.safety.exceptions import ContentFilterError + + +@pytest.fixture +def filter_all() -> ContentFilter: + """Create a ContentFilter with all filters enabled.""" + return ContentFilter( + enable_pii_filter=True, + enable_secret_filter=True, + enable_injection_filter=True, + ) + + +@pytest.fixture +def filter_pii_only() -> ContentFilter: + """Create a ContentFilter with only PII filter.""" + return ContentFilter( + enable_pii_filter=True, + enable_secret_filter=False, + enable_injection_filter=False, + ) + + +class TestContentFilter: + """Tests for ContentFilter class.""" + + @pytest.mark.asyncio + async def test_filter_email(self, filter_pii_only: ContentFilter) -> None: + """Test 
filtering email addresses.""" + content = "Contact me at john.doe@example.com for details." + result = await filter_pii_only.filter(content) + + assert result.has_sensitive_content + assert "[EMAIL]" in result.filtered_content + assert "john.doe@example.com" not in result.filtered_content + assert len(result.matches) == 1 + assert result.matches[0].category == ContentCategory.PII + + @pytest.mark.asyncio + async def test_filter_phone_number(self, filter_pii_only: ContentFilter) -> None: + """Test filtering phone numbers.""" + content = "Call me at 555-123-4567 or (555) 987-6543." + result = await filter_pii_only.filter(content) + + assert result.has_sensitive_content + assert "[PHONE]" in result.filtered_content + # Should redact both phone numbers + assert "555-123-4567" not in result.filtered_content + + @pytest.mark.asyncio + async def test_filter_ssn(self, filter_pii_only: ContentFilter) -> None: + """Test filtering Social Security Numbers.""" + content = "SSN: 123-45-6789" + result = await filter_pii_only.filter(content) + + assert result.has_sensitive_content + assert "[SSN]" in result.filtered_content + assert "123-45-6789" not in result.filtered_content + + @pytest.mark.asyncio + async def test_filter_credit_card(self, filter_all: ContentFilter) -> None: + """Test filtering credit card numbers.""" + content = "Card: 4111-2222-3333-4444" + result = await filter_all.filter(content) + + assert result.has_sensitive_content + assert "[CREDIT_CARD]" in result.filtered_content + + @pytest.mark.asyncio + async def test_block_api_key(self, filter_all: ContentFilter) -> None: + """Test blocking API keys.""" + content = "api_key: sk-abcdef1234567890abcdef1234567890" + result = await filter_all.filter(content) + + assert result.blocked + assert result.block_reason is not None + assert "api_key" in result.block_reason.lower() + + @pytest.mark.asyncio + async def test_block_github_token(self, filter_all: ContentFilter) -> None: + """Test blocking GitHub tokens.""" + 
content = "token: ghp_abcdefghijklmnopqrstuvwxyz1234567890" + result = await filter_all.filter(content) + + assert result.blocked + assert len(result.matches) > 0 + + @pytest.mark.asyncio + async def test_block_private_key(self, filter_all: ContentFilter) -> None: + """Test blocking private keys.""" + content = """ + -----BEGIN RSA PRIVATE KEY----- + MIIEpAIBAAKCAQEA... + -----END RSA PRIVATE KEY----- + """ + result = await filter_all.filter(content) + + assert result.blocked + + @pytest.mark.asyncio + async def test_block_password_in_url(self, filter_all: ContentFilter) -> None: + """Test blocking passwords in URLs.""" + content = "Connect to: postgres://user:secretpassword@localhost/db" + result = await filter_all.filter(content) + + assert result.blocked or result.has_sensitive_content + + @pytest.mark.asyncio + async def test_warn_ip_address(self, filter_pii_only: ContentFilter) -> None: + """Test warning on IP addresses.""" + content = "Server IP: 192.168.1.100" + result = await filter_pii_only.filter(content) + + # IP addresses generate warnings, not blocks + assert len(result.warnings) > 0 or result.has_sensitive_content + + @pytest.mark.asyncio + async def test_no_false_positives_clean_content( + self, + filter_all: ContentFilter, + ) -> None: + """Test that clean content passes through.""" + content = "This is a normal message with no sensitive data." 
+ result = await filter_all.filter(content) + + assert not result.blocked + assert result.filtered_content == content + + @pytest.mark.asyncio + async def test_raise_on_block(self, filter_all: ContentFilter) -> None: + """Test raising exception on blocked content.""" + content = "-----BEGIN RSA PRIVATE KEY-----" + + with pytest.raises(ContentFilterError): + await filter_all.filter(content, raise_on_block=True) + + +class TestFilterDict: + """Tests for dictionary filtering.""" + + @pytest.mark.asyncio + async def test_filter_dict_values(self, filter_pii_only: ContentFilter) -> None: + """Test filtering string values in a dictionary.""" + data = { + "name": "John Doe", + "email": "john@example.com", + "age": 30, + } + result = await filter_pii_only.filter_dict(data) + + assert "[EMAIL]" in result["email"] + assert result["age"] == 30 # Non-string unchanged + + @pytest.mark.asyncio + async def test_filter_dict_recursive( + self, + filter_pii_only: ContentFilter, + ) -> None: + """Test recursive dictionary filtering.""" + data = { + "user": { + "contact": { + "email": "test@example.com", + } + } + } + result = await filter_pii_only.filter_dict(data, recursive=True) + + assert "[EMAIL]" in result["user"]["contact"]["email"] + + @pytest.mark.asyncio + async def test_filter_dict_specific_keys( + self, + filter_pii_only: ContentFilter, + ) -> None: + """Test filtering specific keys only.""" + data = { + "public_email": "public@example.com", + "private_email": "private@example.com", + } + result = await filter_pii_only.filter_dict( + data, + keys_to_filter=["private_email"], + ) + + # Only private_email should be filtered + assert "public@example.com" in result["public_email"] + assert "[EMAIL]" in result["private_email"] + + +class TestScan: + """Tests for content scanning.""" + + @pytest.mark.asyncio + async def test_scan_without_filtering( + self, + filter_all: ContentFilter, + ) -> None: + """Test scanning without modifying content.""" + content = "Email: 
test@example.com, SSN: 123-45-6789" + matches = await filter_all.scan(content) + + assert len(matches) >= 2 + + @pytest.mark.asyncio + async def test_scan_specific_categories( + self, + filter_all: ContentFilter, + ) -> None: + """Test scanning for specific categories only.""" + content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123" + + # Scan only for secrets + matches = await filter_all.scan( + content, + categories=[ContentCategory.SECRETS], + ) + + # Should only find the token, not the email + assert all(m.category == ContentCategory.SECRETS for m in matches) + + +class TestValidateSafe: + """Tests for safe validation.""" + + @pytest.mark.asyncio + async def test_validate_safe_clean(self, filter_all: ContentFilter) -> None: + """Test validation of clean content.""" + content = "This is safe content." + is_safe, issues = await filter_all.validate_safe(content) + + assert is_safe is True + assert len(issues) == 0 + + @pytest.mark.asyncio + async def test_validate_safe_with_secrets( + self, + filter_all: ContentFilter, + ) -> None: + """Test validation of content with secrets.""" + content = "-----BEGIN RSA PRIVATE KEY-----" + is_safe, issues = await filter_all.validate_safe(content) + + assert is_safe is False + assert len(issues) > 0 + + +class TestCustomPatterns: + """Tests for custom pattern support.""" + + @pytest.mark.asyncio + async def test_add_custom_pattern(self) -> None: + """Test adding a custom filter pattern.""" + content_filter = ContentFilter( + enable_pii_filter=False, + enable_secret_filter=False, + enable_injection_filter=False, + ) + + # Add custom pattern for internal IDs + content_filter.add_pattern( + FilterPattern( + name="internal_id", + category=ContentCategory.CUSTOM, + pattern=r"INTERNAL-[A-Z0-9]{8}", + action=FilterAction.REDACT, + replacement="[INTERNAL_ID]", + ) + ) + + content = "Reference: INTERNAL-ABC12345" + result = await content_filter.filter(content) + + assert "[INTERNAL_ID]" in 
result.filtered_content + + @pytest.mark.asyncio + async def test_disable_pattern(self, filter_pii_only: ContentFilter) -> None: + """Test disabling a pattern.""" + filter_pii_only.enable_pattern("email", enabled=False) + + content = "Email: test@example.com" + result = await filter_pii_only.filter(content) + + # Email should not be filtered + assert "test@example.com" in result.filtered_content + + @pytest.mark.asyncio + async def test_remove_pattern(self, filter_pii_only: ContentFilter) -> None: + """Test removing a pattern.""" + removed = filter_pii_only.remove_pattern("email") + assert removed is True + + content = "Email: test@example.com" + result = await filter_pii_only.filter(content) + + # Email should not be filtered + assert "test@example.com" in result.filtered_content + + +class TestPatternStats: + """Tests for pattern statistics.""" + + def test_get_pattern_stats(self, filter_all: ContentFilter) -> None: + """Test getting pattern statistics.""" + stats = filter_all.get_pattern_stats() + + assert stats["total_patterns"] > 0 + assert "by_category" in stats + assert "by_action" in stats + + +class TestConvenienceFunctions: + """Tests for convenience functions.""" + + @pytest.mark.asyncio + async def test_filter_content_function(self) -> None: + """Test the quick filter function.""" + content = "Email: test@example.com" + filtered = await filter_content(content) + + assert "test@example.com" not in filtered + + @pytest.mark.asyncio + async def test_scan_for_secrets_function(self) -> None: + """Test the quick secret scan function.""" + content = "Token: ghp_abcdefghijklmnopqrstuvwxyz1234567890" + matches = await scan_for_secrets(content) + + assert len(matches) > 0 + assert matches[0].category in ( + ContentCategory.SECRETS, + ContentCategory.CREDENTIALS, + ) diff --git a/backend/tests/services/safety/test_emergency.py b/backend/tests/services/safety/test_emergency.py new file mode 100644 index 0000000..8c1ace2 --- /dev/null +++ 
b/backend/tests/services/safety/test_emergency.py @@ -0,0 +1,425 @@ +"""Tests for emergency controls module.""" + +import pytest + +from app.services.safety.emergency.controls import ( + EmergencyControls, + EmergencyReason, + EmergencyState, + EmergencyTrigger, +) +from app.services.safety.exceptions import EmergencyStopError + + +@pytest.fixture +def controls() -> EmergencyControls: + """Create fresh EmergencyControls.""" + return EmergencyControls() + + +class TestEmergencyControls: + """Tests for EmergencyControls class.""" + + @pytest.mark.asyncio + async def test_initial_state_is_normal( + self, + controls: EmergencyControls, + ) -> None: + """Test that initial state is normal.""" + state = await controls.get_state("global") + assert state == EmergencyState.NORMAL + + @pytest.mark.asyncio + async def test_emergency_stop_changes_state( + self, + controls: EmergencyControls, + ) -> None: + """Test that emergency stop changes state to stopped.""" + event = await controls.emergency_stop( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Test emergency stop", + ) + + assert event.state == EmergencyState.STOPPED + state = await controls.get_state("global") + assert state == EmergencyState.STOPPED + + @pytest.mark.asyncio + async def test_pause_changes_state( + self, + controls: EmergencyControls, + ) -> None: + """Test that pause changes state to paused.""" + event = await controls.pause( + reason=EmergencyReason.BUDGET_EXCEEDED, + triggered_by="budget_controller", + message="Budget exceeded", + ) + + assert event.state == EmergencyState.PAUSED + state = await controls.get_state("global") + assert state == EmergencyState.PAUSED + + @pytest.mark.asyncio + async def test_resume_from_paused( + self, + controls: EmergencyControls, + ) -> None: + """Test resuming from paused state.""" + await controls.pause( + reason=EmergencyReason.RATE_LIMIT, + triggered_by="limiter", + message="Rate limited", + ) + + resumed = await 
controls.resume(resumed_by="admin") + + assert resumed is True + state = await controls.get_state("global") + assert state == EmergencyState.NORMAL + + @pytest.mark.asyncio + async def test_cannot_resume_from_stopped( + self, + controls: EmergencyControls, + ) -> None: + """Test that you cannot resume from stopped state.""" + await controls.emergency_stop( + reason=EmergencyReason.SAFETY_VIOLATION, + triggered_by="safety", + message="Critical violation", + ) + + resumed = await controls.resume(resumed_by="admin") + + assert resumed is False + state = await controls.get_state("global") + assert state == EmergencyState.STOPPED + + @pytest.mark.asyncio + async def test_reset_from_stopped( + self, + controls: EmergencyControls, + ) -> None: + """Test resetting from stopped state.""" + await controls.emergency_stop( + reason=EmergencyReason.SAFETY_VIOLATION, + triggered_by="safety", + message="Critical violation", + ) + + reset = await controls.reset(reset_by="admin") + + assert reset is True + state = await controls.get_state("global") + assert state == EmergencyState.NORMAL + + @pytest.mark.asyncio + async def test_scoped_emergency_stop( + self, + controls: EmergencyControls, + ) -> None: + """Test emergency stop with specific scope.""" + await controls.emergency_stop( + reason=EmergencyReason.LOOP_DETECTED, + triggered_by="detector", + message="Loop in agent", + scope="agent:agent-123", + ) + + # Agent scope should be stopped + agent_state = await controls.get_state("agent:agent-123") + assert agent_state == EmergencyState.STOPPED + + # Global should still be normal + global_state = await controls.get_state("global") + assert global_state == EmergencyState.NORMAL + + @pytest.mark.asyncio + async def test_check_allowed_when_normal( + self, + controls: EmergencyControls, + ) -> None: + """Test check_allowed returns True when state is normal.""" + allowed = await controls.check_allowed() + assert allowed is True + + @pytest.mark.asyncio + async def 
test_check_allowed_when_stopped( + self, + controls: EmergencyControls, + ) -> None: + """Test check_allowed returns False when stopped.""" + await controls.emergency_stop( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Stop", + ) + + allowed = await controls.check_allowed(raise_if_blocked=False) + assert allowed is False + + @pytest.mark.asyncio + async def test_check_allowed_raises_when_blocked( + self, + controls: EmergencyControls, + ) -> None: + """Test check_allowed raises exception when blocked.""" + await controls.emergency_stop( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Stop", + ) + + with pytest.raises(EmergencyStopError): + await controls.check_allowed(raise_if_blocked=True) + + @pytest.mark.asyncio + async def test_check_allowed_with_scope( + self, + controls: EmergencyControls, + ) -> None: + """Test check_allowed with specific scope.""" + await controls.pause( + reason=EmergencyReason.BUDGET_EXCEEDED, + triggered_by="budget", + message="Paused", + scope="project:proj-123", + ) + + # Project scope should be blocked + allowed_project = await controls.check_allowed( + scope="project:proj-123", + raise_if_blocked=False, + ) + assert allowed_project is False + + # Different scope should be allowed + allowed_other = await controls.check_allowed( + scope="project:proj-456", + raise_if_blocked=False, + ) + assert allowed_other is True + + @pytest.mark.asyncio + async def test_get_all_states( + self, + controls: EmergencyControls, + ) -> None: + """Test getting all states.""" + await controls.pause( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Pause", + scope="agent:a1", + ) + await controls.emergency_stop( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Stop", + scope="agent:a2", + ) + + states = await controls.get_all_states() + + assert states["global"] == EmergencyState.NORMAL + assert states["agent:a1"] == EmergencyState.PAUSED + assert states["agent:a2"] == 
EmergencyState.STOPPED + + @pytest.mark.asyncio + async def test_get_active_events( + self, + controls: EmergencyControls, + ) -> None: + """Test getting active (unresolved) events.""" + await controls.pause( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Pause 1", + ) + + events = await controls.get_active_events() + assert len(events) == 1 + + # Resume should resolve the event + await controls.resume() + events_after = await controls.get_active_events() + assert len(events_after) == 0 + + @pytest.mark.asyncio + async def test_event_history( + self, + controls: EmergencyControls, + ) -> None: + """Test getting event history.""" + await controls.pause( + reason=EmergencyReason.RATE_LIMIT, + triggered_by="test", + message="Rate limited", + ) + await controls.resume() + + history = await controls.get_event_history() + + assert len(history) == 1 + assert history[0].reason == EmergencyReason.RATE_LIMIT + + @pytest.mark.asyncio + async def test_event_metadata( + self, + controls: EmergencyControls, + ) -> None: + """Test event metadata storage.""" + event = await controls.emergency_stop( + reason=EmergencyReason.BUDGET_EXCEEDED, + triggered_by="budget_controller", + message="Over budget", + metadata={"budget_type": "tokens", "usage": 150000}, + ) + + assert event.metadata["budget_type"] == "tokens" + assert event.metadata["usage"] == 150000 + + +class TestCallbacks: + """Tests for callback functionality.""" + + @pytest.mark.asyncio + async def test_on_stop_callback( + self, + controls: EmergencyControls, + ) -> None: + """Test on_stop callback is called.""" + callback_called = [] + + def callback(event: object) -> None: + callback_called.append(event) + + controls.on_stop(callback) + + await controls.emergency_stop( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Stop", + ) + + assert len(callback_called) == 1 + + @pytest.mark.asyncio + async def test_on_pause_callback( + self, + controls: EmergencyControls, + ) -> None: + """Test 
on_pause callback is called.""" + callback_called = [] + + def callback(event: object) -> None: + callback_called.append(event) + + controls.on_pause(callback) + + await controls.pause( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Pause", + ) + + assert len(callback_called) == 1 + + @pytest.mark.asyncio + async def test_on_resume_callback( + self, + controls: EmergencyControls, + ) -> None: + """Test on_resume callback is called.""" + callback_called = [] + + def callback(data: object) -> None: + callback_called.append(data) + + controls.on_resume(callback) + + await controls.pause( + reason=EmergencyReason.MANUAL, + triggered_by="test", + message="Pause", + ) + await controls.resume() + + assert len(callback_called) == 1 + + +class TestEmergencyTrigger: + """Tests for EmergencyTrigger class.""" + + @pytest.fixture + def trigger(self, controls: EmergencyControls) -> EmergencyTrigger: + """Create an EmergencyTrigger.""" + return EmergencyTrigger(controls) + + @pytest.mark.asyncio + async def test_trigger_on_safety_violation( + self, + trigger: EmergencyTrigger, + controls: EmergencyControls, + ) -> None: + """Test triggering emergency on safety violation.""" + event = await trigger.trigger_on_safety_violation( + violation_type="unauthorized_access", + details={"resource": "/secrets/key"}, + ) + + assert event.reason == EmergencyReason.SAFETY_VIOLATION + assert event.state == EmergencyState.STOPPED + + state = await controls.get_state("global") + assert state == EmergencyState.STOPPED + + @pytest.mark.asyncio + async def test_trigger_on_budget_exceeded( + self, + trigger: EmergencyTrigger, + controls: EmergencyControls, + ) -> None: + """Test triggering pause on budget exceeded.""" + event = await trigger.trigger_on_budget_exceeded( + budget_type="tokens", + current=150000, + limit=100000, + ) + + assert event.reason == EmergencyReason.BUDGET_EXCEEDED + assert event.state == EmergencyState.PAUSED # Pause, not stop + + @pytest.mark.asyncio + 
async def test_trigger_on_loop_detected( + self, + trigger: EmergencyTrigger, + controls: EmergencyControls, + ) -> None: + """Test triggering pause on loop detection.""" + event = await trigger.trigger_on_loop_detected( + loop_type="exact", + agent_id="agent-123", + details={"pattern": "file_read"}, + ) + + assert event.reason == EmergencyReason.LOOP_DETECTED + assert event.scope == "agent:agent-123" + assert event.state == EmergencyState.PAUSED + + @pytest.mark.asyncio + async def test_trigger_on_content_violation( + self, + trigger: EmergencyTrigger, + controls: EmergencyControls, + ) -> None: + """Test triggering stop on content violation.""" + event = await trigger.trigger_on_content_violation( + category="secrets", + pattern="private_key", + ) + + assert event.reason == EmergencyReason.CONTENT_VIOLATION + assert event.state == EmergencyState.STOPPED diff --git a/backend/tests/services/safety/test_loops.py b/backend/tests/services/safety/test_loops.py new file mode 100644 index 0000000..eb8a319 --- /dev/null +++ b/backend/tests/services/safety/test_loops.py @@ -0,0 +1,316 @@ +"""Tests for loop detection module.""" + +import pytest + +from app.services.safety.exceptions import LoopDetectedError +from app.services.safety.loops.detector import ( + ActionSignature, + LoopBreaker, + LoopDetector, +) +from app.services.safety.models import ( + ActionMetadata, + ActionRequest, + ActionType, + AutonomyLevel, +) + + +@pytest.fixture +def detector() -> LoopDetector: + """Create a fresh LoopDetector with low thresholds for testing.""" + return LoopDetector( + history_size=20, + max_exact_repetitions=3, + max_semantic_repetitions=5, + ) + + +@pytest.fixture +def sample_metadata() -> ActionMetadata: + """Create sample action metadata.""" + return ActionMetadata( + agent_id="test-agent", + session_id="test-session", + autonomy_level=AutonomyLevel.MILESTONE, + ) + + +def create_action( + metadata: ActionMetadata, + tool_name: str, + resource: str = "/tmp/test.txt", # noqa: 
S108 + arguments: dict | None = None, +) -> ActionRequest: + """Helper to create test actions.""" + return ActionRequest( + action_type=ActionType.FILE_READ, + tool_name=tool_name, + resource=resource, + arguments=arguments or {}, + metadata=metadata, + ) + + +class TestActionSignature: + """Tests for ActionSignature class.""" + + def test_exact_key_includes_args(self, sample_metadata: ActionMetadata) -> None: + """Test that exact key includes argument hash.""" + action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"}) + action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"}) + + sig1 = ActionSignature(action1) + sig2 = ActionSignature(action2) + + assert sig1.exact_key() != sig2.exact_key() + + def test_semantic_key_ignores_args(self, sample_metadata: ActionMetadata) -> None: + """Test that semantic key ignores arguments.""" + action1 = create_action(sample_metadata, "file_read", arguments={"path": "a"}) + action2 = create_action(sample_metadata, "file_read", arguments={"path": "b"}) + + sig1 = ActionSignature(action1) + sig2 = ActionSignature(action2) + + assert sig1.semantic_key() == sig2.semantic_key() + + def test_type_key(self, sample_metadata: ActionMetadata) -> None: + """Test type key extraction.""" + action = create_action(sample_metadata, "file_read") + sig = ActionSignature(action) + + assert sig.type_key() == "file_read" + + +class TestLoopDetector: + """Tests for LoopDetector class.""" + + @pytest.mark.asyncio + async def test_no_loop_on_first_action( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + ) -> None: + """Test that first action is never a loop.""" + action = create_action(sample_metadata, "file_read") + + is_loop, loop_type = await detector.check(action) + + assert is_loop is False + assert loop_type is None + + @pytest.mark.asyncio + async def test_exact_loop_detection( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + ) -> None: + """Test detection of 
exact repetitions.""" + action = create_action( + sample_metadata, + "file_read", + resource="/tmp/same.txt", # noqa: S108 + arguments={"path": "/tmp/same.txt"}, # noqa: S108 + ) + + # Record the same action 3 times (threshold) + for _ in range(3): + await detector.record(action) + + # Next should be detected as a loop + is_loop, loop_type = await detector.check(action) + + assert is_loop is True + assert loop_type == "exact" + + @pytest.mark.asyncio + async def test_semantic_loop_detection( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + ) -> None: + """Test detection of semantic (similar) repetitions.""" + # Record same tool/resource but different arguments + test_resource = "/tmp/test.txt" # noqa: S108 + for i in range(5): + action = create_action( + sample_metadata, + "file_read", + resource=test_resource, + arguments={"offset": i}, + ) + await detector.record(action) + + # Next similar action should be detected as semantic loop + action = create_action( + sample_metadata, + "file_read", + resource=test_resource, + arguments={"offset": 100}, + ) + is_loop, loop_type = await detector.check(action) + + assert is_loop is True + assert loop_type == "semantic" + + @pytest.mark.asyncio + async def test_oscillation_detection( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + ) -> None: + """Test detection of A→B→A→B oscillation pattern.""" + action_a = create_action(sample_metadata, "tool_a", resource="/a") + action_b = create_action(sample_metadata, "tool_b", resource="/b") + + # Create A→B→A pattern + await detector.record(action_a) + await detector.record(action_b) + await detector.record(action_a) + + # Fourth action completing A→B→A→B should be detected as oscillation + is_loop, loop_type = await detector.check(action_b) + + assert is_loop is True + assert loop_type == "oscillation" + + @pytest.mark.asyncio + async def test_different_actions_no_loop( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + 
) -> None: + """Test that different actions don't trigger loops.""" + for i in range(10): + action = create_action( + sample_metadata, + f"tool_{i}", + resource=f"/resource_{i}", + ) + is_loop, _ = await detector.check(action) + assert is_loop is False + await detector.record(action) + + @pytest.mark.asyncio + async def test_check_and_raise( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + ) -> None: + """Test check_and_raise raises on loop detection.""" + action = create_action(sample_metadata, "file_read") + + # Record threshold number of times + for _ in range(3): + await detector.record(action) + + # Should raise + with pytest.raises(LoopDetectedError) as exc_info: + await detector.check_and_raise(action) + + assert "exact" in exc_info.value.loop_type.lower() + + @pytest.mark.asyncio + async def test_clear_history( + self, + detector: LoopDetector, + sample_metadata: ActionMetadata, + ) -> None: + """Test clearing agent history.""" + action = create_action(sample_metadata, "file_read") + + # Record multiple times + for _ in range(3): + await detector.record(action) + + # Clear history + await detector.clear_history(sample_metadata.agent_id) + + # Should no longer detect loop + is_loop, _ = await detector.check(action) + assert is_loop is False + + @pytest.mark.asyncio + async def test_per_agent_history( + self, + detector: LoopDetector, + ) -> None: + """Test that history is tracked per agent.""" + metadata1 = ActionMetadata(agent_id="agent-1", session_id="s1") + metadata2 = ActionMetadata(agent_id="agent-2", session_id="s2") + + action1 = create_action(metadata1, "file_read") + action2 = create_action(metadata2, "file_read") + + # Record for agent 1 (threshold times) + for _ in range(3): + await detector.record(action1) + + # Agent 1 should detect loop + is_loop1, _ = await detector.check(action1) + assert is_loop1 is True + + # Agent 2 should not detect loop + is_loop2, _ = await detector.check(action2) + assert is_loop2 is False + + 
@pytest.mark.asyncio
+    async def test_get_stats(
+        self,
+        detector: LoopDetector,
+        sample_metadata: ActionMetadata,
+    ) -> None:
+        """Test that get_stats reports history size and action type counts."""
+        for i in range(5):
+            action = create_action(
+                sample_metadata,
+                f"tool_{i % 2}",  # Alternate between 2 tools
+                resource=f"/resource_{i}",
+            )
+            await detector.record(action)
+
+        stats = await detector.get_stats(sample_metadata.agent_id)
+
+        assert stats["history_size"] == 5
+        assert len(stats["action_type_counts"]) > 0
+
+
+class TestLoopBreaker:
+    """Tests for LoopBreaker class."""
+
+    @pytest.mark.asyncio
+    async def test_suggest_alternatives_exact(
+        self,
+        sample_metadata: ActionMetadata,
+    ) -> None:
+        """Test that the first "exact" suggestion mentions the same action."""
+        action = create_action(sample_metadata, "file_read")
+        suggestions = await LoopBreaker.suggest_alternatives(action, "exact")
+
+        assert len(suggestions) > 0
+        assert "same action" in suggestions[0].lower()
+
+    @pytest.mark.asyncio
+    async def test_suggest_alternatives_semantic(
+        self,
+        sample_metadata: ActionMetadata,
+    ) -> None:
+        """Test suggestions for semantic loops."""
+        action = create_action(sample_metadata, "file_read")
+        suggestions = await LoopBreaker.suggest_alternatives(action, "semantic")
+
+        assert len(suggestions) > 0
+        assert "similar" in suggestions[0].lower()
+
+    @pytest.mark.asyncio
+    async def test_suggest_alternatives_oscillation(
+        self,
+        sample_metadata: ActionMetadata,
+    ) -> None:
+        """Test suggestions for oscillation loops."""
+        action = create_action(sample_metadata, "file_read")
+        suggestions = await LoopBreaker.suggest_alternatives(action, "oscillation")
+
+        assert len(suggestions) > 0
+        assert "oscillat" in suggestions[0].lower()
diff --git a/backend/tests/services/safety/test_models.py b/backend/tests/services/safety/test_models.py
new file mode 100644
index 0000000..ff273cc
--- /dev/null
+++ b/backend/tests/services/safety/test_models.py
@@ -0,0 +1,437 @@
+"""Tests for safety framework models."""
+
+from 
datetime import datetime, timedelta
+
+from app.services.safety.models import (
+    ActionMetadata,
+    ActionRequest,
+    ActionResult,
+    ActionType,
+    ApprovalRequest,
+    ApprovalResponse,
+    ApprovalStatus,
+    AuditEvent,
+    AuditEventType,
+    AutonomyLevel,
+    BudgetScope,
+    BudgetStatus,
+    Checkpoint,
+    CheckpointType,
+    GuardianResult,
+    PermissionLevel,
+    RateLimitConfig,
+    RollbackResult,
+    SafetyDecision,
+    SafetyPolicy,
+    ValidationResult,
+    ValidationRule,
+)
+
+
+class TestActionMetadata:
+    """Tests for ActionMetadata model."""
+
+    def test_create_with_defaults(self) -> None:
+        """Test metadata built from agent_id alone falls back to defaults."""
+        metadata = ActionMetadata(
+            agent_id="agent-1",
+        )
+
+        assert metadata.agent_id == "agent-1"
+        assert metadata.autonomy_level == AutonomyLevel.MILESTONE
+        assert metadata.project_id is None
+        assert metadata.session_id is None
+
+    def test_create_with_all_fields(self) -> None:
+        """Test creating metadata with all fields."""
+        metadata = ActionMetadata(
+            agent_id="agent-1",
+            session_id="session-1",
+            project_id="project-1",
+            user_id="user-1",
+            autonomy_level=AutonomyLevel.AUTONOMOUS,
+        )
+
+        assert metadata.project_id == "project-1"
+        assert metadata.user_id == "user-1"
+        assert metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
+
+
+class TestActionRequest:
+    """Tests for ActionRequest model."""
+
+    def test_create_basic_action(self) -> None:
+        """Test that a basic request keeps its fields and gets an id."""
+        metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
+        action = ActionRequest(
+            action_type=ActionType.FILE_READ,
+            tool_name="file_read",
+            metadata=metadata,
+        )
+
+        assert action.action_type == ActionType.FILE_READ
+        assert action.tool_name == "file_read"
+        assert action.id is not None
+        assert action.metadata.agent_id == "agent-1"
+
+    def test_action_with_arguments(self) -> None:
+        """Test action with arguments."""
+        test_path = "/tmp/test.txt"  # noqa: S108
+        metadata = ActionMetadata(agent_id="agent-1", 
session_id="session-1")
+        action = ActionRequest(
+            action_type=ActionType.FILE_WRITE,
+            tool_name="file_write",
+            arguments={"path": test_path, "content": "hello"},
+            resource=test_path,
+            metadata=metadata,
+        )
+
+        assert action.arguments["path"] == test_path
+        assert action.resource == test_path
+
+
+class TestActionResult:
+    """Tests for ActionResult model."""
+
+    def test_successful_result(self) -> None:
+        """Test that a successful result carries data and no error."""
+        result = ActionResult(
+            action_id="action-1",
+            success=True,
+            data={"output": "done"},
+        )
+
+        assert result.success is True
+        assert result.data["output"] == "done"
+        assert result.error is None
+
+    def test_failed_result(self) -> None:
+        """Test creating a failed result."""
+        result = ActionResult(
+            action_id="action-1",
+            success=False,
+            error="Permission denied",
+        )
+
+        assert result.success is False
+        assert result.error == "Permission denied"
+
+
+class TestValidationRule:
+    """Tests for ValidationRule model."""
+
+    def test_create_rule(self) -> None:
+        """Test creating a validation rule."""
+        rule = ValidationRule(
+            name="deny_shell",
+            description="Deny shell commands",
+            priority=100,
+            tool_patterns=["shell_*"],
+            decision=SafetyDecision.DENY,
+            reason="Shell commands are not allowed",
+        )
+
+        assert rule.name == "deny_shell"
+        assert rule.priority == 100
+        assert rule.decision == SafetyDecision.DENY
+        assert rule.enabled is True
+
+    def test_rule_defaults(self) -> None:
+        """Test defaults when only name and decision are supplied."""
+        rule = ValidationRule(name="test_rule", decision=SafetyDecision.ALLOW)
+
+        assert rule.id is not None
+        assert rule.priority == 0
+        assert rule.enabled is True
+        assert rule.decision == SafetyDecision.ALLOW
+
+
+class TestValidationResult:
+    """Tests for ValidationResult model."""
+
+    def test_allow_result(self) -> None:
+        """Test an allow result."""
+        result = ValidationResult(
+            action_id="action-1",
+            decision=SafetyDecision.ALLOW,
+            applied_rules=["rule-1"],
+            reasons=["Action is allowed"],
)
+
+        assert result.decision == SafetyDecision.ALLOW
+        assert len(result.applied_rules) == 1
+
+    def test_deny_result(self) -> None:
+        """Test a deny result."""
+        result = ValidationResult(
+            action_id="action-1",
+            decision=SafetyDecision.DENY,
+            applied_rules=["deny_rule"],
+            reasons=["Action is not permitted"],
+        )
+
+        assert result.decision == SafetyDecision.DENY
+
+
+class TestBudgetStatus:
+    """Tests for BudgetStatus model."""
+
+    def test_under_budget(self) -> None:
+        """Test status when token usage is below the limit."""
+        status = BudgetStatus(
+            scope=BudgetScope.SESSION,
+            scope_id="session-1",
+            tokens_used=5000,
+            tokens_limit=10000,
+            tokens_remaining=5000,
+        )
+
+        assert status.tokens_remaining == 5000
+        assert status.tokens_used == 5000
+
+    def test_over_budget(self) -> None:
+        """Test status when cost used exceeds the limit and is_exceeded is set."""
+        status = BudgetStatus(
+            scope=BudgetScope.SESSION,
+            scope_id="session-1",
+            cost_used_usd=15.0,
+            cost_limit_usd=10.0,
+            cost_remaining_usd=0.0,
+            is_exceeded=True,
+        )
+
+        assert status.is_exceeded is True
+        assert status.cost_remaining_usd == 0.0
+
+
+class TestRateLimitConfig:
+    """Tests for RateLimitConfig model."""
+
+    def test_create_config(self) -> None:
+        """Test creating rate limit config."""
+        config = RateLimitConfig(
+            name="actions",
+            limit=60,
+            window_seconds=60,
+        )
+
+        assert config.name == "actions"
+        assert config.limit == 60
+        assert config.window_seconds == 60
+
+
+class TestApprovalRequest:
+    """Tests for ApprovalRequest model."""
+
+    def test_create_request(self) -> None:
+        """Test creating an approval request."""
+        metadata = ActionMetadata(agent_id="agent-1", session_id="session-1")
+        action = ActionRequest(
+            action_type=ActionType.DATABASE_MUTATE,
+            tool_name="db_delete",
+            metadata=metadata,
+        )
+
+        request = ApprovalRequest(
+            id="approval-1",
+            action=action,
+            reason="Database mutation requires approval",
+            urgency="high",
+            timeout_seconds=300,
+        )
+
+        assert request.id == "approval-1"
+        assert request.urgency == "high"
assert request.timeout_seconds == 300
+
+
+class TestApprovalResponse:
+    """Tests for ApprovalResponse model."""
+
+    def test_approved_response(self) -> None:
+        """Test an approved response."""
+        response = ApprovalResponse(
+            request_id="approval-1",
+            status=ApprovalStatus.APPROVED,
+            decided_by="admin",
+            reason="Looks safe",
+        )
+
+        assert response.status == ApprovalStatus.APPROVED
+        assert response.decided_by == "admin"
+
+    def test_denied_response(self) -> None:
+        """Test a denied response."""
+        response = ApprovalResponse(
+            request_id="approval-1",
+            status=ApprovalStatus.DENIED,
+            decided_by="admin",
+            reason="Too risky",
+        )
+
+        assert response.status == ApprovalStatus.DENIED
+
+
+class TestCheckpoint:
+    """Tests for Checkpoint model."""
+
+    def test_create_checkpoint(self) -> None:
+        """Test that a new checkpoint keeps its id and defaults to valid."""
+        test_path = "/tmp/test.txt"  # noqa: S108
+        checkpoint = Checkpoint(
+            id="checkpoint-1",
+            checkpoint_type=CheckpointType.FILE,
+            action_id="action-1",
+            created_at=datetime(2026, 1, 1, 12, 0, 0),  # fixed naive timestamp
+            data={"path": test_path},
+            description="File checkpoint",
+        )
+
+        assert checkpoint.id == "checkpoint-1"
+        assert checkpoint.is_valid is True
+
+    def test_expired_checkpoint(self) -> None:
+        """Test that a past expires_at leaves is_valid at its default."""
+        checkpoint = Checkpoint(
+            id="checkpoint-1",
+            checkpoint_type=CheckpointType.FILE,
+            action_id="action-1",
+            created_at=datetime(2026, 1, 1, 12, 0, 0) - timedelta(hours=2),
+            expires_at=datetime(2026, 1, 1, 12, 0, 0) - timedelta(hours=1),
+            data={},
+        )
+
+        # is_valid is a simple bool, not computed from expires_at
+        # The RollbackManager handles expiration logic
+        assert checkpoint.is_valid is True  # Default value
+
+
+class TestRollbackResult:
+    """Tests for RollbackResult model."""
+
+    def test_successful_rollback(self) -> None:
+        """Test a successful rollback."""
+        result = RollbackResult(
+            checkpoint_id="checkpoint-1",
+            success=True,
+            actions_rolled_back=["file:/tmp/test.txt"],
+            failed_actions=[],
+        )
+
+        assert result.success is True
+        assert 
len(result.actions_rolled_back) == 1
+
+    def test_partial_rollback(self) -> None:
+        """Test a rollback that reports both rolled-back and failed actions."""
+        result = RollbackResult(
+            checkpoint_id="checkpoint-1",
+            success=False,
+            actions_rolled_back=["file:/tmp/a.txt"],
+            failed_actions=["file:/tmp/b.txt"],
+            error="Failed to rollback 1 item",
+        )
+
+        assert result.success is False
+        assert len(result.failed_actions) == 1
+
+
+class TestAuditEvent:
+    """Tests for AuditEvent model."""
+
+    def test_create_event(self) -> None:
+        """Test creating an audit event."""
+        event = AuditEvent(
+            id="event-1",
+            event_type=AuditEventType.ACTION_EXECUTED,
+            timestamp=datetime(2026, 1, 1, 12, 0, 0),  # fixed naive timestamp
+            agent_id="agent-1",
+            action_id="action-1",
+            data={"tool": "file_read"},
+        )
+
+        assert event.event_type == AuditEventType.ACTION_EXECUTED
+        assert event.agent_id == "agent-1"
+
+
+class TestSafetyPolicy:
+    """Tests for SafetyPolicy model."""
+
+    def test_default_policy(self) -> None:
+        """Test that a policy with only name/description is enabled."""
+        policy = SafetyPolicy(
+            name="default",
+            description="Default safety policy",
+        )
+
+        assert policy.name == "default"
+        assert policy.enabled is True
+
+    def test_restrictive_policy(self) -> None:
+        """Test creating a restrictive policy."""
+        policy = SafetyPolicy(
+            name="restrictive",
+            description="Restrictive policy",
+            denied_tools=["shell_*", "exec_*"],
+            require_approval_for=["database_*", "git_push"],
+        )
+
+        assert len(policy.denied_tools) == 2
+        assert len(policy.require_approval_for) == 2
+
+
+class TestGuardianResult:
+    """Tests for GuardianResult model."""
+
+    def test_allowed_result(self) -> None:
+        """Test an allowed result."""
+        result = GuardianResult(
+            action_id="action-1",
+            allowed=True,
+            decision=SafetyDecision.ALLOW,
+            reasons=["All checks passed"],
+        )
+
+        assert result.decision == SafetyDecision.ALLOW
+        assert result.allowed is True
+        assert result.approval_id is None
+
+    def test_approval_required_result(self) -> None:
+        """Test a result requiring approval."""
+        result = 
GuardianResult(
+            action_id="action-1",
+            allowed=False,
+            decision=SafetyDecision.REQUIRE_APPROVAL,
+            reasons=["Action requires human approval"],
+            approval_id="approval-123",
+        )
+
+        assert result.decision == SafetyDecision.REQUIRE_APPROVAL
+        assert result.approval_id == "approval-123"
+
+
+class TestEnums:
+    """Tests pinning the string values of the safety enums."""
+
+    def test_action_types(self) -> None:
+        """Test action type enum values."""
+        assert ActionType.FILE_READ.value == "file_read"
+        assert ActionType.SHELL_COMMAND.value == "shell_command"
+
+    def test_autonomy_levels(self) -> None:
+        """Test autonomy level enum values."""
+        assert AutonomyLevel.FULL_CONTROL.value == "full_control"
+        assert AutonomyLevel.MILESTONE.value == "milestone"
+        assert AutonomyLevel.AUTONOMOUS.value == "autonomous"
+
+    def test_permission_levels(self) -> None:
+        """Test permission level enum values."""
+        assert PermissionLevel.NONE.value == "none"
+        assert PermissionLevel.READ.value == "read"
+        assert PermissionLevel.WRITE.value == "write"
+        assert PermissionLevel.ADMIN.value == "admin"
+
+    def test_safety_decisions(self) -> None:
+        """Test safety decision enum values."""
+        assert SafetyDecision.ALLOW.value == "allow"
+        assert SafetyDecision.DENY.value == "deny"
+        assert SafetyDecision.REQUIRE_APPROVAL.value == "require_approval"
diff --git a/backend/tests/services/safety/test_validation.py b/backend/tests/services/safety/test_validation.py
new file mode 100644
index 0000000..f326d50
--- /dev/null
+++ b/backend/tests/services/safety/test_validation.py
@@ -0,0 +1,404 @@
+"""Tests for safety validation module."""
+
+import pytest
+
+from app.services.safety.models import (
+    ActionMetadata,
+    ActionRequest,
+    ActionType,
+    AutonomyLevel,
+    SafetyDecision,
+    SafetyPolicy,
+    ValidationRule,
+)
+from app.services.safety.validation.validator import (
+    ActionValidator,
+    create_allow_rule,
+    create_approval_rule,
+    create_deny_rule,
+)
+
+
+@pytest.fixture
+def validator() -> ActionValidator:
+    """Create a 
fresh ActionValidator."""
+    return ActionValidator(cache_enabled=False)
+
+
+@pytest.fixture
+def sample_action() -> ActionRequest:
+    """Create a sample file_read action for the test agent."""
+    metadata = ActionMetadata(
+        agent_id="test-agent",
+        session_id="test-session",
+        autonomy_level=AutonomyLevel.MILESTONE,
+    )
+    return ActionRequest(
+        action_type=ActionType.FILE_READ,
+        tool_name="file_read",
+        resource="/tmp/test.txt",  # noqa: S108
+        metadata=metadata,
+    )
+
+
+class TestActionValidator:
+    """Tests for ActionValidator class."""
+
+    @pytest.mark.asyncio
+    async def test_no_rules_allows_by_default(
+        self,
+        validator: ActionValidator,
+        sample_action: ActionRequest,
+    ) -> None:
+        """Test that actions are allowed by default with no rules."""
+        result = await validator.validate(sample_action)
+
+        assert result.decision == SafetyDecision.ALLOW
+        assert "No matching rules" in result.reasons[0]
+
+    @pytest.mark.asyncio
+    async def test_deny_rule_blocks_action(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that a deny rule blocks matching actions."""
+        validator.add_rule(
+            create_deny_rule(
+                name="deny_shell",
+                tool_patterns=["shell_*"],
+                reason="Shell commands not allowed",
+            )
+        )
+
+        metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
+        action = ActionRequest(
+            action_type=ActionType.SHELL_COMMAND,
+            tool_name="shell_exec",
+            metadata=metadata,
+        )
+
+        result = await validator.validate(action)
+
+        assert result.decision == SafetyDecision.DENY
+        assert len(result.applied_rules) == 1  # One rule applied
+
+    @pytest.mark.asyncio
+    async def test_approval_rule_requires_approval(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that a matching approval rule yields REQUIRE_APPROVAL."""
+        validator.add_rule(
+            create_approval_rule(
+                name="approve_db",
+                tool_patterns=["database_*"],
+                reason="Database operations require approval",
+            )
+        )
+
+        metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
+        action = ActionRequest(
action_type=ActionType.DATABASE_MUTATE,
+            tool_name="database_delete",
+            metadata=metadata,
+        )
+
+        result = await validator.validate(action)
+
+        assert result.decision == SafetyDecision.REQUIRE_APPROVAL
+
+    @pytest.mark.asyncio
+    async def test_deny_takes_precedence(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that a higher-priority deny rule wins over an allow rule."""
+        validator.add_rule(
+            create_allow_rule(
+                name="allow_files",
+                tool_patterns=["file_*"],
+                priority=10,
+            )
+        )
+        validator.add_rule(
+            create_deny_rule(
+                name="deny_delete",
+                action_types=[ActionType.FILE_DELETE],
+                priority=100,
+            )
+        )
+
+        metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
+        action = ActionRequest(
+            action_type=ActionType.FILE_DELETE,
+            tool_name="file_delete",
+            metadata=metadata,
+        )
+
+        result = await validator.validate(action)
+
+        assert result.decision == SafetyDecision.DENY
+
+    @pytest.mark.asyncio
+    async def test_rule_priority_ordering(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that add_rule keeps the highest-priority rule first."""
+        validator.add_rule(
+            ValidationRule(
+                name="low_priority",
+                priority=1,
+                decision=SafetyDecision.ALLOW,
+            )
+        )
+        validator.add_rule(
+            ValidationRule(
+                name="high_priority",
+                priority=100,
+                decision=SafetyDecision.DENY,
+            )
+        )
+
+        # High priority should be first in the list
+        assert validator._rules[0].name == "high_priority"
+
+    @pytest.mark.asyncio
+    async def test_disabled_rule_not_applied(
+        self,
+        validator: ActionValidator,
+        sample_action: ActionRequest,
+    ) -> None:
+        """Test that disabled rules are not applied."""
+        rule = create_deny_rule(
+            name="deny_all",
+            tool_patterns=["*"],
+        )
+        rule.enabled = False
+        validator.add_rule(rule)
+
+        result = await validator.validate(sample_action)
+
+        assert result.decision == SafetyDecision.ALLOW
+
+    @pytest.mark.asyncio
+    async def test_resource_pattern_matching(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test resource 
pattern matching."""
+        validator.add_rule(
+            create_deny_rule(
+                name="deny_secrets",
+                resource_patterns=["*/secrets/*", "*.env"],
+            )
+        )
+
+        metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
+        action = ActionRequest(
+            action_type=ActionType.FILE_READ,
+            tool_name="file_read",
+            resource="/app/secrets/api_key.txt",
+            metadata=metadata,
+        )
+
+        result = await validator.validate(action)
+
+        assert result.decision == SafetyDecision.DENY
+
+    @pytest.mark.asyncio
+    async def test_agent_id_filter(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that a rule with agent_ids only applies to those agents."""
+        rule = ValidationRule(
+            name="restrict_agent",
+            agent_ids=["restricted-agent"],
+            decision=SafetyDecision.DENY,
+            reason="Restricted agent",
+        )
+        validator.add_rule(rule)
+
+        # Restricted agent should be denied
+        metadata1 = ActionMetadata(agent_id="restricted-agent")
+        action1 = ActionRequest(
+            action_type=ActionType.FILE_READ,
+            tool_name="file_read",
+            metadata=metadata1,
+        )
+        result1 = await validator.validate(action1)
+        assert result1.decision == SafetyDecision.DENY
+
+        # Other agents should be allowed
+        metadata2 = ActionMetadata(agent_id="normal-agent")
+        action2 = ActionRequest(
+            action_type=ActionType.FILE_READ,
+            tool_name="file_read",
+            metadata=metadata2,
+        )
+        result2 = await validator.validate(action2)
+        assert result2.decision == SafetyDecision.ALLOW
+
+    @pytest.mark.asyncio
+    async def test_bypass_mode(
+        self,
+        validator: ActionValidator,
+        sample_action: ActionRequest,
+    ) -> None:
+        """Test that enabling bypass allows actions until it is disabled."""
+        validator.add_rule(create_deny_rule(name="deny_all", tool_patterns=["*"]))
+
+        # Should be denied normally
+        result1 = await validator.validate(sample_action)
+        assert result1.decision == SafetyDecision.DENY
+
+        # Enable bypass
+        validator.enable_bypass("Emergency situation")
+        result2 = await validator.validate(sample_action)
+        assert result2.decision == SafetyDecision.ALLOW
+        assert "bypassed" in result2.reasons[0].lower()
+
+        # Disable bypass
+        validator.disable_bypass()
+        result3 = await validator.validate(sample_action)
+        assert result3.decision == SafetyDecision.DENY
+
+    def test_remove_rule(self, validator: ActionValidator) -> None:
+        """Test that remove_rule deletes a rule by id and returns True."""
+        rule = create_deny_rule(name="test_rule", tool_patterns=["test"])
+        validator.add_rule(rule)
+
+        assert len(validator._rules) == 1
+        assert validator.remove_rule(rule.id) is True
+        assert len(validator._rules) == 0
+
+    def test_remove_nonexistent_rule(self, validator: ActionValidator) -> None:
+        """Test removing a nonexistent rule returns False."""
+        assert validator.remove_rule("nonexistent") is False
+
+    def test_clear_rules(self, validator: ActionValidator) -> None:
+        """Test clearing all rules."""
+        validator.add_rule(create_deny_rule(name="rule1", tool_patterns=["a"]))
+        validator.add_rule(create_deny_rule(name="rule2", tool_patterns=["b"]))
+
+        assert len(validator._rules) == 2
+        validator.clear_rules()
+        assert len(validator._rules) == 0
+
+
+class TestLoadRulesFromPolicy:
+    """Tests for loading rules from policies."""
+
+    @pytest.mark.asyncio
+    async def test_load_denied_tools(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that each denied_tools pattern becomes one deny rule."""
+        policy = SafetyPolicy(
+            name="test",
+            denied_tools=["shell_*", "exec_*"],
+        )
+
+        validator.load_rules_from_policy(policy)
+
+        # Should have 2 deny rules
+        deny_rules = [r for r in validator._rules if r.decision == SafetyDecision.DENY]
+        assert len(deny_rules) == 2
+
+    @pytest.mark.asyncio
+    async def test_load_approval_patterns(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test loading approval patterns from policy."""
+        policy = SafetyPolicy(
+            name="test",
+            require_approval_for=["database_*"],
+        )
+
+        validator.load_rules_from_policy(policy)
+
+        approval_rules = [
+            r for r in validator._rules
+            if r.decision == SafetyDecision.REQUIRE_APPROVAL
+        ]
+        assert len(approval_rules) == 1
+
+
+class TestValidationBatch:
+    """Tests 
for batch validation."""
+
+    @pytest.mark.asyncio
+    async def test_validate_batch(
+        self,
+        validator: ActionValidator,
+    ) -> None:
+        """Test that validate_batch returns one result per action, in order."""
+        validator.add_rule(
+            create_deny_rule(
+                name="deny_shell",
+                tool_patterns=["shell_*"],
+            )
+        )
+
+        metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
+        actions = [
+            ActionRequest(
+                action_type=ActionType.FILE_READ,
+                tool_name="file_read",
+                metadata=metadata,
+            ),
+            ActionRequest(
+                action_type=ActionType.SHELL_COMMAND,
+                tool_name="shell_exec",
+                metadata=metadata,
+            ),
+        ]
+
+        results = await validator.validate_batch(actions)
+
+        assert len(results) == 2
+        assert results[0].decision == SafetyDecision.ALLOW
+        assert results[1].decision == SafetyDecision.DENY
+
+
+class TestHelperFunctions:
+    """Tests for rule creation helper functions."""
+
+    def test_create_allow_rule(self) -> None:
+        """Test creating an allow rule."""
+        rule = create_allow_rule(
+            name="allow_test",
+            tool_patterns=["test_*"],
+            priority=50,
+        )
+
+        assert rule.name == "allow_test"
+        assert rule.decision == SafetyDecision.ALLOW
+        assert rule.priority == 50
+
+    def test_create_deny_rule(self) -> None:
+        """Test that create_deny_rule sets DENY, the reason and priority 100."""
+        rule = create_deny_rule(
+            name="deny_test",
+            tool_patterns=["dangerous_*"],
+            reason="Too dangerous",
+        )
+
+        assert rule.name == "deny_test"
+        assert rule.decision == SafetyDecision.DENY
+        assert rule.reason == "Too dangerous"
+        assert rule.priority == 100  # Default priority for deny
+
+    def test_create_approval_rule(self) -> None:
+        """Test creating an approval rule."""
+        rule = create_approval_rule(
+            name="approve_test",
+            action_types=[ActionType.DATABASE_MUTATE],
+        )
+
+        assert rule.name == "approve_test"
+        assert rule.decision == SafetyDecision.REQUIRE_APPROVAL
+        assert rule.priority == 50  # Default priority for approval