"""Tests for emergency controls module."""

import pytest

from app.services.safety.emergency.controls import (
    EmergencyControls,
    EmergencyReason,
    EmergencyState,
    EmergencyTrigger,
)
from app.services.safety.exceptions import EmergencyStopError


@pytest.fixture
def controls() -> EmergencyControls:
    """Provide a newly constructed EmergencyControls for each test."""
    return EmergencyControls()


class TestEmergencyControls:
    """State-transition and query tests for EmergencyControls."""

    @pytest.mark.asyncio
    async def test_initial_state_is_normal(
        self,
        controls: EmergencyControls,
    ) -> None:
        """The global scope reports NORMAL before anything is triggered."""
        assert await controls.get_state("global") == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_emergency_stop_changes_state(
        self,
        controls: EmergencyControls,
    ) -> None:
        """emergency_stop returns a STOPPED event and the global state follows."""
        stop_event = await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Test emergency stop",
        )
        assert stop_event.state == EmergencyState.STOPPED

        assert await controls.get_state("global") == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_pause_changes_state(
        self,
        controls: EmergencyControls,
    ) -> None:
        """pause returns a PAUSED event and the global state follows."""
        pause_event = await controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message="Budget exceeded",
        )
        assert pause_event.state == EmergencyState.PAUSED

        assert await controls.get_state("global") == EmergencyState.PAUSED

    @pytest.mark.asyncio
    async def test_resume_from_paused(
        self,
        controls: EmergencyControls,
    ) -> None:
        """resume after a pause returns True and restores NORMAL."""
        await controls.pause(
            reason=EmergencyReason.RATE_LIMIT,
            triggered_by="limiter",
            message="Rate limited",
        )

        did_resume = await controls.resume(resumed_by="admin")
        assert did_resume is True

        assert await controls.get_state("global") == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_cannot_resume_from_stopped(
        self,
        controls: EmergencyControls,
    ) -> None:
        """resume after an emergency stop returns False and stays STOPPED."""
        await controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety",
            message="Critical violation",
        )

        did_resume = await controls.resume(resumed_by="admin")
        assert did_resume is False

        assert await controls.get_state("global") == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_reset_from_stopped(
        self,
        controls: EmergencyControls,
    ) -> None:
        """reset after an emergency stop returns True and restores NORMAL."""
        await controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety",
            message="Critical violation",
        )

        did_reset = await controls.reset(reset_by="admin")
        assert did_reset is True

        assert await controls.get_state("global") == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_scoped_emergency_stop(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A stop issued for one agent scope leaves the global scope NORMAL."""
        agent_scope = "agent:agent-123"
        await controls.emergency_stop(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="detector",
            message="Loop in agent",
            scope=agent_scope,
        )

        # The targeted agent scope is the one that gets stopped...
        assert await controls.get_state(agent_scope) == EmergencyState.STOPPED

        # ...while the global scope is untouched.
        assert await controls.get_state("global") == EmergencyState.NORMAL

    @pytest.mark.asyncio
    async def test_check_allowed_when_normal(
        self,
        controls: EmergencyControls,
    ) -> None:
        """check_allowed returns True when nothing has been triggered."""
        is_allowed = await controls.check_allowed()
        assert is_allowed is True

    @pytest.mark.asyncio
    async def test_check_allowed_when_stopped(
        self,
        controls: EmergencyControls,
    ) -> None:
        """check_allowed(raise_if_blocked=False) returns False after a stop."""
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
        )

        is_allowed = await controls.check_allowed(raise_if_blocked=False)
        assert is_allowed is False

    @pytest.mark.asyncio
    async def test_check_allowed_raises_when_blocked(
        self,
        controls: EmergencyControls,
    ) -> None:
        """check_allowed(raise_if_blocked=True) raises EmergencyStopError after a stop."""
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
        )

        with pytest.raises(EmergencyStopError):
            await controls.check_allowed(raise_if_blocked=True)

    @pytest.mark.asyncio
    async def test_check_allowed_with_scope(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A pause on one project scope blocks that scope but not another."""
        await controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget",
            message="Paused",
            scope="project:proj-123",
        )

        # The paused project is reported as blocked.
        paused_project_allowed = await controls.check_allowed(
            scope="project:proj-123",
            raise_if_blocked=False,
        )
        assert paused_project_allowed is False

        # A different project scope is still allowed.
        other_project_allowed = await controls.check_allowed(
            scope="project:proj-456",
            raise_if_blocked=False,
        )
        assert other_project_allowed is True

    @pytest.mark.asyncio
    async def test_get_all_states(
        self,
        controls: EmergencyControls,
    ) -> None:
        """get_all_states reports global, paused and stopped scopes together."""
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause",
            scope="agent:a1",
        )
        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
            scope="agent:a2",
        )

        all_states = await controls.get_all_states()

        assert all_states["global"] == EmergencyState.NORMAL
        assert all_states["agent:a1"] == EmergencyState.PAUSED
        assert all_states["agent:a2"] == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_get_active_events(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A pause yields one active event; resuming leaves none active."""
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause 1",
        )

        active_before = await controls.get_active_events()
        assert len(active_before) == 1

        # Resuming is expected to resolve the pause event.
        await controls.resume()
        active_after = await controls.get_active_events()
        assert len(active_after) == 0

    @pytest.mark.asyncio
    async def test_event_history(
        self,
        controls: EmergencyControls,
    ) -> None:
        """History holds one entry with its reason after a pause/resume cycle."""
        await controls.pause(
            reason=EmergencyReason.RATE_LIMIT,
            triggered_by="test",
            message="Rate limited",
        )
        await controls.resume()

        past_events = await controls.get_event_history()

        assert len(past_events) == 1
        assert past_events[0].reason == EmergencyReason.RATE_LIMIT

    @pytest.mark.asyncio
    async def test_event_metadata(
        self,
        controls: EmergencyControls,
    ) -> None:
        """Metadata passed to emergency_stop is readable on the returned event."""
        stop_event = await controls.emergency_stop(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message="Over budget",
            metadata={"budget_type": "tokens", "usage": 150000},
        )

        stored = stop_event.metadata
        assert stored["budget_type"] == "tokens"
        assert stored["usage"] == 150000


class TestCallbacks:
    """Tests that registered callbacks are invoked on state changes."""

    @pytest.mark.asyncio
    async def test_on_stop_callback(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A callback registered via on_stop is invoked once per stop."""
        received = []

        def record(event: object) -> None:
            received.append(event)

        controls.on_stop(record)

        await controls.emergency_stop(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Stop",
        )

        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_on_pause_callback(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A callback registered via on_pause is invoked once per pause."""
        received = []

        def record(event: object) -> None:
            received.append(event)

        controls.on_pause(record)

        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause",
        )

        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_on_resume_callback(
        self,
        controls: EmergencyControls,
    ) -> None:
        """A callback registered via on_resume is invoked once on resume."""
        received = []

        def record(data: object) -> None:
            received.append(data)

        controls.on_resume(record)

        # Only the resume should reach this callback, not the preceding pause.
        await controls.pause(
            reason=EmergencyReason.MANUAL,
            triggered_by="test",
            message="Pause",
        )
        await controls.resume()

        assert len(received) == 1


class TestEmergencyTrigger:
    """Tests for the EmergencyTrigger convenience methods."""

    @pytest.fixture
    def trigger(self, controls: EmergencyControls) -> EmergencyTrigger:
        """Build an EmergencyTrigger around the shared controls fixture."""
        return EmergencyTrigger(controls)

    @pytest.mark.asyncio
    async def test_trigger_on_safety_violation(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """A safety violation yields a STOPPED event and stops the global scope."""
        triggered = await trigger.trigger_on_safety_violation(
            violation_type="unauthorized_access",
            details={"resource": "/secrets/key"},
        )

        assert triggered.reason == EmergencyReason.SAFETY_VIOLATION
        assert triggered.state == EmergencyState.STOPPED

        assert await controls.get_state("global") == EmergencyState.STOPPED

    @pytest.mark.asyncio
    async def test_trigger_on_budget_exceeded(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """Exceeding a budget yields a PAUSED event rather than a stop."""
        triggered = await trigger.trigger_on_budget_exceeded(
            budget_type="tokens",
            current=150000,
            limit=100000,
        )

        assert triggered.reason == EmergencyReason.BUDGET_EXCEEDED
        # Budget overruns pause; they do not stop.
        assert triggered.state == EmergencyState.PAUSED

    @pytest.mark.asyncio
    async def test_trigger_on_loop_detected(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """A detected loop yields a PAUSED event scoped to the given agent."""
        triggered = await trigger.trigger_on_loop_detected(
            loop_type="exact",
            agent_id="agent-123",
            details={"pattern": "file_read"},
        )

        assert triggered.reason == EmergencyReason.LOOP_DETECTED
        assert triggered.scope == "agent:agent-123"
        assert triggered.state == EmergencyState.PAUSED

    @pytest.mark.asyncio
    async def test_trigger_on_content_violation(
        self,
        trigger: EmergencyTrigger,
        controls: EmergencyControls,
    ) -> None:
        """A content violation yields a STOPPED event."""
        triggered = await trigger.trigger_on_content_violation(
            category="secrets",
            pattern="private_key",
        )

        assert triggered.reason == EmergencyReason.CONTENT_VIOLATION
        assert triggered.state == EmergencyState.STOPPED