"""Tests for SafetyGuardian integration.""" from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from app.services.safety.config import SafetyConfig from app.services.safety.costs.controller import CostController from app.services.safety.guardian import ( SafetyGuardian, get_safety_guardian, reset_safety_guardian, shutdown_safety_guardian, ) from app.services.safety.limits.limiter import RateLimiter from app.services.safety.loops.detector import LoopDetector from app.services.safety.models import ( ActionMetadata, ActionRequest, ActionResult, ActionType, AuditEvent, AuditEventType, AutonomyLevel, BudgetScope, SafetyDecision, SafetyPolicy, ) @pytest_asyncio.fixture async def reset_guardian(): """Reset the singleton guardian before and after each test.""" await reset_safety_guardian() yield await reset_safety_guardian() @pytest.fixture def safety_config() -> SafetyConfig: """Create a test safety configuration.""" return SafetyConfig( enabled=True, strict_mode=True, hitl_enabled=False, auto_checkpoint_destructive=False, ) @pytest.fixture def cost_controller() -> CostController: """Create a cost controller for testing.""" return CostController( default_session_tokens=1000, default_session_cost_usd=10.0, default_daily_tokens=5000, default_daily_cost_usd=50.0, ) @pytest.fixture def rate_limiter() -> RateLimiter: """Create a rate limiter for testing.""" return RateLimiter() @pytest.fixture def loop_detector() -> LoopDetector: """Create a loop detector for testing.""" return LoopDetector( history_size=10, max_exact_repetitions=3, max_semantic_repetitions=5, ) def _make_audit_event() -> AuditEvent: """Create a mock audit event.""" return AuditEvent( event_type=AuditEventType.ACTION_REQUESTED, agent_id="test-agent", action_id="test-action", ) @pytest_asyncio.fixture async def guardian( safety_config: SafetyConfig, cost_controller: CostController, rate_limiter: RateLimiter, loop_detector: LoopDetector, ) -> SafetyGuardian: """Create a SafetyGuardian for testing.""" guardian = SafetyGuardian( config=safety_config, cost_controller=cost_controller, rate_limiter=rate_limiter, loop_detector=loop_detector, ) # Patch the audit logger to avoid actual logging # Return proper AuditEvent objects instead of AsyncMock guardian._audit_logger = MagicMock() guardian._audit_logger.log = AsyncMock(return_value=_make_audit_event()) guardian._audit_logger.log_action_request = AsyncMock( return_value=_make_audit_event() ) guardian._audit_logger.log_action_executed = AsyncMock(return_value=None) await guardian.initialize() return guardian @pytest.fixture def sample_metadata() -> ActionMetadata: """Create sample action metadata.""" return ActionMetadata( agent_id="test-agent", session_id="test-session", autonomy_level=AutonomyLevel.MILESTONE, ) def create_action( metadata: ActionMetadata, tool_name: str = "test_tool", action_type: ActionType = ActionType.LLM_CALL, resource: str = "/tmp/test.txt", # noqa: S108 estimated_tokens: int = 100, estimated_cost: float = 0.01, ) -> ActionRequest: """Helper to create test actions.""" return ActionRequest( action_type=action_type, tool_name=tool_name, resource=resource, arguments={}, metadata=metadata, estimated_cost_tokens=estimated_tokens, estimated_cost_usd=estimated_cost, ) class TestSafetyGuardianInit: """Tests for SafetyGuardian initialization.""" @pytest.mark.asyncio async def test_init_creates_subsystems( self, safety_config: SafetyConfig, ) -> None: """Test initialization creates subsystems if not provided.""" with patch( "app.services.safety.guardian.get_audit_logger", new_callable=AsyncMock, ): guardian = SafetyGuardian(config=safety_config) await guardian.initialize() assert guardian.cost_controller is not None assert guardian.rate_limiter is not None assert guardian.loop_detector is not None assert guardian.is_initialized is True @pytest.mark.asyncio async def test_init_with_provided_subsystems( self, safety_config: SafetyConfig, cost_controller: CostController, rate_limiter: RateLimiter, loop_detector: LoopDetector, ) -> None: """Test initialization uses provided subsystems.""" guardian = SafetyGuardian( config=safety_config, cost_controller=cost_controller, rate_limiter=rate_limiter, loop_detector=loop_detector, ) guardian._audit_logger = MagicMock() await guardian.initialize() # Should use the provided instances assert guardian.cost_controller is cost_controller assert guardian.rate_limiter is rate_limiter assert guardian.loop_detector is loop_detector class TestSafetyGuardianValidation: """Tests for SafetyGuardian.validate().""" @pytest.mark.asyncio async def test_validate_success( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test successful validation passes all checks.""" action = create_action(sample_metadata) result = await guardian.validate(action) assert result.allowed is True assert result.decision == SafetyDecision.ALLOW @pytest.mark.asyncio async def test_validate_disabled_allows_all( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test validation with disabled safety allows all.""" guardian._config.enabled = False action = create_action(sample_metadata) result = await guardian.validate(action) assert result.allowed is True assert "disabled" in result.reasons[0].lower() @pytest.mark.asyncio async def test_validate_budget_exceeded( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test validation fails when budget exceeded.""" # Use up the session budget await guardian.cost_controller.record_usage( agent_id=sample_metadata.agent_id, session_id=sample_metadata.session_id, tokens=1000, cost_usd=10.0, ) action = create_action(sample_metadata, estimated_tokens=100) result = await guardian.validate(action) assert result.allowed is False assert result.decision == SafetyDecision.DENY assert any("budget" in r.lower() for r in result.reasons) @pytest.mark.asyncio async def test_validate_rate_limit_exceeded( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test validation fails when rate limit exceeded.""" # Exhaust rate limits by calling validate many times for _ in range(100): # More than default limit action = create_action(sample_metadata) await guardian.rate_limiter.acquire("actions", sample_metadata.agent_id) action = create_action(sample_metadata) result = await guardian.validate(action) # Should be denied or delayed assert result.allowed is False assert result.decision in (SafetyDecision.DENY, SafetyDecision.DELAY) @pytest.mark.asyncio async def test_validate_loop_detected( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test validation fails when loop detected.""" action = create_action(sample_metadata) # Record the same action multiple times (to trigger loop) for _ in range(3): await guardian.loop_detector.record(action) result = await guardian.validate(action) assert result.allowed is False assert result.decision == SafetyDecision.DENY assert any("loop" in r.lower() for r in result.reasons) @pytest.mark.asyncio async def test_validate_denied_tool( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test validation fails for denied tools.""" # Create action with tool that matches denied pattern action = create_action(sample_metadata, tool_name="shell_exec") # Create policy with denied pattern policy = SafetyPolicy( name="test-policy", allowed_tools=["*"], denied_tools=["shell_*"], ) result = await guardian.validate(action, policy=policy) assert result.allowed is False assert result.decision == SafetyDecision.DENY assert any("denied" in r.lower() for r in result.reasons) @pytest.mark.asyncio async def test_validate_with_custom_policy( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test validation with custom policy.""" action = create_action(sample_metadata, tool_name="allowed_tool") policy = SafetyPolicy( name="test-custom-policy", allowed_tools=["allowed_*"], denied_tools=[], ) result = await guardian.validate(action, policy=policy) assert result.allowed is True assert result.decision == SafetyDecision.ALLOW class TestSafetyGuardianRecording: """Tests for SafetyGuardian.record_execution().""" @pytest.mark.asyncio async def test_record_execution_updates_cost( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test recording execution updates cost tracker.""" action = create_action(sample_metadata) action_result = ActionResult( action_id=action.id, success=True, actual_cost_tokens=50, actual_cost_usd=0.005, ) await guardian.record_execution(action, action_result) # Check cost was recorded status = await guardian.cost_controller.get_status( BudgetScope.SESSION, sample_metadata.session_id ) assert status is not None assert status.tokens_used == 50 @pytest.mark.asyncio async def test_record_execution_updates_loop_history( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test recording execution updates loop detector history.""" action = create_action(sample_metadata) action_result = ActionResult( action_id=action.id, success=True, ) await guardian.record_execution(action, action_result) # Check action was recorded in loop detector stats = await guardian.loop_detector.get_stats(sample_metadata.agent_id) assert stats["history_size"] == 1 class TestSafetyGuardianSingleton: """Tests for SafetyGuardian singleton functions.""" @pytest.mark.asyncio async def test_get_safety_guardian_creates_singleton( self, reset_guardian, ) -> None: """Test get_safety_guardian creates singleton.""" with patch( "app.services.safety.guardian.get_audit_logger", new_callable=AsyncMock, ): guardian1 = await get_safety_guardian() guardian2 = await get_safety_guardian() assert guardian1 is guardian2 assert guardian1.is_initialized is True @pytest.mark.asyncio async def test_shutdown_safety_guardian( self, reset_guardian, ) -> None: """Test shutdown cleans up singleton.""" with patch( "app.services.safety.guardian.get_audit_logger", new_callable=AsyncMock, ): guardian = await get_safety_guardian() assert guardian.is_initialized is True await shutdown_safety_guardian() # Singleton should be cleared - next get creates new instance @pytest.mark.asyncio async def test_reset_safety_guardian( self, reset_guardian, ) -> None: """Test reset clears singleton.""" with patch( "app.services.safety.guardian.get_audit_logger", new_callable=AsyncMock, ): guardian1 = await get_safety_guardian() await reset_safety_guardian() guardian2 = await get_safety_guardian() assert guardian1 is not guardian2 class TestPatternMatching: """Tests for pattern matching logic.""" def test_exact_match(self) -> None: """Test exact pattern matching.""" guardian = SafetyGuardian() assert guardian._matches_pattern("file_read", "file_read") is True assert guardian._matches_pattern("file_read", "file_write") is False def test_wildcard_all(self) -> None: """Test wildcard * matches all.""" guardian = SafetyGuardian() assert guardian._matches_pattern("anything", "*") is True assert guardian._matches_pattern("", "*") is True def test_prefix_wildcard(self) -> None: """Test prefix wildcard matching.""" guardian = SafetyGuardian() assert guardian._matches_pattern("test_read", "*_read") is True assert guardian._matches_pattern("test_write", "*_read") is False def test_suffix_wildcard(self) -> None: """Test suffix wildcard matching.""" guardian = SafetyGuardian() assert guardian._matches_pattern("file_read", "file_*") is True assert guardian._matches_pattern("shell_read", "file_*") is False def test_contains_wildcard(self) -> None: """Test contains wildcard matching.""" guardian = SafetyGuardian() assert guardian._matches_pattern("test_dangerous_action", "*dangerous*") is True assert guardian._matches_pattern("test_safe_action", "*dangerous*") is False class TestErrorHandling: """Tests for error handling in SafetyGuardian.""" @pytest.mark.asyncio async def test_strict_mode_fails_on_error( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test strict mode denies on unexpected errors.""" action = create_action(sample_metadata) # Force an error by breaking the cost controller original_check = guardian.cost_controller.check_budget guardian.cost_controller.check_budget = AsyncMock( side_effect=Exception("Unexpected error") ) result = await guardian.validate(action) assert result.allowed is False assert result.decision == SafetyDecision.DENY assert any("error" in r.lower() for r in result.reasons) # Restore guardian.cost_controller.check_budget = original_check @pytest.mark.asyncio async def test_non_strict_mode_allows_on_error( self, guardian: SafetyGuardian, sample_metadata: ActionMetadata, ) -> None: """Test non-strict mode allows on unexpected errors.""" guardian._config.strict_mode = False action = create_action(sample_metadata) # Force an error by breaking the cost controller original_check = guardian.cost_controller.check_budget guardian.cost_controller.check_budget = AsyncMock( side_effect=Exception("Unexpected error") ) result = await guardian.validate(action) assert result.allowed is True assert result.decision == SafetyDecision.ALLOW # Restore guardian.cost_controller.check_budget = original_check guardian._config.strict_mode = True