"""Tests for rate limiter module."""

import pytest

from app.services.safety.exceptions import RateLimitExceededError
from app.services.safety.limits.limiter import (
    RateLimiter,
    SlidingWindowCounter,
)
from app.services.safety.models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    RateLimitConfig,
)


@pytest.fixture
def sliding_counter() -> SlidingWindowCounter:
    """Create a sliding window counter for testing.

    Configured for 5 requests per 60-second window with a burst limit of 3.
    """
    return SlidingWindowCounter(
        limit=5,
        window_seconds=60,
        burst_limit=3,
    )


@pytest.fixture
def rate_limiter() -> RateLimiter:
    """Create a rate limiter for testing.

    Adds a "test_limit" entry (5 per 60 s, burst 3). Tests below also use the
    "actions", "llm_calls" and "file_ops" limits without configuring them,
    relying on whatever ``RateLimiter()`` sets up by default.
    """
    limiter = RateLimiter()
    # Configure a test limit
    limiter.configure(
        RateLimitConfig(
            name="test_limit",
            limit=5,
            window_seconds=60,
            burst_limit=3,
        )
    )
    return limiter


@pytest.fixture
def sample_metadata() -> ActionMetadata:
    """Create sample action metadata."""
    return ActionMetadata(
        agent_id="test-agent",
        session_id="test-session",
    )


def create_action(
    metadata: ActionMetadata,
    action_type: ActionType = ActionType.LLM_CALL,
) -> ActionRequest:
    """Helper to create test actions.

    Args:
        metadata: Metadata (agent/session) attached to the action.
        action_type: Kind of action; defaults to an LLM call.

    Returns:
        An ``ActionRequest`` with fixed tool/resource values and no arguments.
    """
    return ActionRequest(
        action_type=action_type,
        tool_name="test_tool",
        resource="test-resource",
        arguments={},
        metadata=metadata,
    )


class TestSlidingWindowCounter:
    """Tests for SlidingWindowCounter class."""

    @pytest.mark.asyncio
    async def test_first_acquire_allowed(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Test first acquire is always allowed."""
        allowed, retry_after = await sliding_counter.try_acquire()

        assert allowed is True
        assert retry_after == 0.0

    @pytest.mark.asyncio
    async def test_burst_limit(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Test burst limit is enforced."""
        # Acquire up to burst limit (3)
        for _ in range(3):
            allowed, _ = await sliding_counter.try_acquire()
            assert allowed is True

        # Next should be denied (burst exceeded)
        allowed, retry_after = await sliding_counter.try_acquire()
        assert allowed is False
        assert retry_after > 0

    @pytest.mark.asyncio
    async def test_get_status(
        self,
        sliding_counter: SlidingWindowCounter,
    ) -> None:
        """Test getting counter status."""
        # Make some requests
        await sliding_counter.try_acquire()
        await sliding_counter.try_acquire()

        current, remaining, reset_in = await sliding_counter.get_status()

        assert current == 2
        assert remaining == 3  # 5 - 2
        assert reset_in >= 0


class TestRateLimiter:
    """Tests for RateLimiter class."""

    @pytest.mark.asyncio
    async def test_check_status(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test checking rate limit status."""
        status = await rate_limiter.check("test_limit", "test-key")

        assert status.name == "test_limit"
        assert status.current_count == 0
        assert status.limit == 5
        assert status.window_seconds == 60
        assert status.remaining == 5
        assert status.is_limited is False

    @pytest.mark.asyncio
    async def test_acquire_success(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test successful acquire."""
        allowed, status = await rate_limiter.acquire("test_limit", "test-key")

        assert allowed is True
        assert status.current_count == 1
        assert status.remaining == 4

    @pytest.mark.asyncio
    async def test_acquire_burst_exceeded(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test acquire fails when burst exceeded."""
        # Acquire up to burst limit
        for _ in range(3):
            allowed, _ = await rate_limiter.acquire("test_limit", "test-key")
            assert allowed is True

        # Next should fail
        allowed, status = await rate_limiter.acquire("test_limit", "test-key")
        assert allowed is False
        assert status.is_limited is True
        assert status.retry_after_seconds > 0

    @pytest.mark.asyncio
    async def test_require_success(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test require passes when not limited."""
        # Should not raise
        await rate_limiter.require("test_limit", "test-key")

    @pytest.mark.asyncio
    async def test_require_raises(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test require raises when limited."""
        # Use up burst limit
        for _ in range(3):
            await rate_limiter.acquire("test_limit", "test-key")

        with pytest.raises(RateLimitExceededError) as exc_info:
            await rate_limiter.require("test_limit", "test-key")

        assert exc_info.value.limit_type == "test_limit"
        assert exc_info.value.retry_after_seconds > 0

    @pytest.mark.asyncio
    async def test_check_action_allowed(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test checking action is allowed."""
        action = create_action(sample_metadata)

        allowed, statuses = await rate_limiter.check_action(action)

        assert allowed is True
        assert len(statuses) >= 1  # At least "actions" limit

    @pytest.mark.asyncio
    async def test_check_action_llm_limits(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test LLM actions check LLM-specific limits."""
        action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)

        allowed, statuses = await rate_limiter.check_action(action)

        assert allowed is True
        # Should have checked both "actions" and "llm_calls"
        limit_names = [s.name for s in statuses]
        assert "actions" in limit_names
        assert "llm_calls" in limit_names

    @pytest.mark.asyncio
    async def test_check_action_file_limits(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test file actions check file-specific limits."""
        action = create_action(sample_metadata, action_type=ActionType.FILE_READ)

        allowed, statuses = await rate_limiter.check_action(action)

        assert allowed is True
        # Should have checked both "actions" and "file_ops"
        limit_names = [s.name for s in statuses]
        assert "actions" in limit_names
        assert "file_ops" in limit_names

    @pytest.mark.asyncio
    async def test_check_action_does_not_consume_slot(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test check_action only checks without consuming slots."""
        # Default action type is LLM_CALL, so both "actions" and
        # "llm_calls" are checked on every call.
        action = create_action(sample_metadata)

        # Check multiple times - should never consume
        for _ in range(10):
            allowed, _ = await rate_limiter.check_action(action)
            assert allowed is True

        # Verify no slots were consumed on either limit that was checked
        for limit_name in ("actions", "llm_calls"):
            status = await rate_limiter.check(limit_name, sample_metadata.agent_id)
            assert status.current_count == 0, limit_name

    @pytest.mark.asyncio
    async def test_record_action_consumes_slot(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test record_action consumes rate limit slots."""
        action = create_action(sample_metadata)

        # Record the action
        await rate_limiter.record_action(action)

        # Verify slot was consumed
        status = await rate_limiter.check("actions", sample_metadata.agent_id)
        assert status.current_count == 1

    @pytest.mark.asyncio
    async def test_record_action_consumes_type_specific_slots(
        self,
        rate_limiter: RateLimiter,
        sample_metadata: ActionMetadata,
    ) -> None:
        """Test record_action consumes type-specific slots."""
        # LLM action
        llm_action = create_action(sample_metadata, action_type=ActionType.LLM_CALL)
        await rate_limiter.record_action(llm_action)

        statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
        assert statuses["actions"].current_count == 1
        assert statuses["llm_calls"].current_count == 1
        assert statuses["file_ops"].current_count == 0

        # File action
        file_action = create_action(sample_metadata, action_type=ActionType.FILE_READ)
        await rate_limiter.record_action(file_action)

        statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id)
        assert statuses["actions"].current_count == 2
        assert statuses["llm_calls"].current_count == 1
        assert statuses["file_ops"].current_count == 1

    @pytest.mark.asyncio
    async def test_get_all_statuses(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test getting all rate limit statuses."""
        # Make exactly one request on each of two limits (fresh limiter)
        await rate_limiter.acquire("actions", "test-key")
        await rate_limiter.acquire("llm_calls", "test-key")

        statuses = await rate_limiter.get_all_statuses("test-key")

        assert "actions" in statuses
        assert "llm_calls" in statuses
        assert "file_ops" in statuses
        # Exact counts: a loose ">= 1" would hide double-counting bugs.
        assert statuses["actions"].current_count == 1
        assert statuses["llm_calls"].current_count == 1
        assert statuses["file_ops"].current_count == 0

    @pytest.mark.asyncio
    async def test_reset_single(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test resetting a single rate limit."""
        # Make some requests
        await rate_limiter.acquire("test_limit", "test-key")
        await rate_limiter.acquire("test_limit", "test-key")

        # Precondition: the requests were recorded, otherwise the reset
        # below would pass vacuously.
        before = await rate_limiter.check("test_limit", "test-key")
        assert before.current_count == 2

        # Reset
        result = await rate_limiter.reset("test_limit", "test-key")
        assert result is True

        # Check it's reset
        status = await rate_limiter.check("test_limit", "test-key")
        assert status.current_count == 0

    @pytest.mark.asyncio
    async def test_reset_nonexistent(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test resetting non-existent limit returns False."""
        result = await rate_limiter.reset("nonexistent", "test-key")
        assert result is False

    @pytest.mark.asyncio
    async def test_reset_all(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test resetting all rate limits for a key."""
        # Make requests across multiple limits
        await rate_limiter.acquire("actions", "test-key")
        await rate_limiter.acquire("llm_calls", "test-key")
        await rate_limiter.acquire("file_ops", "test-key")

        # Reset all
        count = await rate_limiter.reset_all("test-key")
        assert count >= 3

        # Check they're reset
        statuses = await rate_limiter.get_all_statuses("test-key")
        for status in statuses.values():
            assert status.current_count == 0

    @pytest.mark.asyncio
    async def test_per_key_isolation(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test rate limits are isolated per key."""
        # Use up burst limit for key-1
        for _ in range(3):
            await rate_limiter.acquire("test_limit", "key-1")

        # key-1 should be limited
        allowed1, _ = await rate_limiter.acquire("test_limit", "key-1")
        assert allowed1 is False

        # key-2 should still be allowed
        allowed2, _ = await rate_limiter.acquire("test_limit", "key-2")
        assert allowed2 is True

    @pytest.mark.asyncio
    async def test_configure_custom_limit(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test configuring custom rate limits."""
        rate_limiter.configure(
            RateLimitConfig(
                name="custom",
                limit=100,
                window_seconds=120,
                burst_limit=50,
            )
        )

        status = await rate_limiter.check("custom", "test-key")
        assert status.limit == 100
        assert status.window_seconds == 120

    @pytest.mark.asyncio
    async def test_default_limit_fallback(
        self,
        rate_limiter: RateLimiter,
    ) -> None:
        """Test fallback to default limit for unknown limit names."""
        # Request limit that doesn't exist
        status = await rate_limiter.check("unknown_limit", "test-key")

        # Should use default (60/60s)
        assert status.limit == 60
        assert status.window_seconds == 60