diff --git a/backend/app/services/mcp/client_manager.py b/backend/app/services/mcp/client_manager.py index 7f80a7c..927652a 100644 --- a/backend/app/services/mcp/client_manager.py +++ b/backend/app/services/mcp/client_manager.py @@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None: _manager_instance = None -def reset_mcp_client() -> None: - """Reset the global MCP client manager (for testing).""" +async def reset_mcp_client() -> None: + """ + Reset the global MCP client manager (for testing). + + This is an async function to properly acquire the manager lock + and avoid race conditions with get_mcp_client(). + """ global _manager_instance - _manager_instance = None + + async with _manager_lock: + if _manager_instance is not None: + # Shutdown gracefully before resetting + try: + await _manager_instance.shutdown() + except Exception: # noqa: S110 + pass # Ignore errors during test cleanup + _manager_instance = None diff --git a/backend/app/services/mcp/connection.py b/backend/app/services/mcp/connection.py index fc71d7d..1775e45 100644 --- a/backend/app/services/mcp/connection.py +++ b/backend/app/services/mcp/connection.py @@ -161,7 +161,7 @@ class MCPConnection: server_name=self.server_name, url=self.config.url, cause=e, - ) + ) from e else: # For STDIO and SSE transports, we'll implement later raise NotImplementedError( @@ -297,13 +297,13 @@ class MCPConnection: server_name=self.server_name, url=f"{self.config.url}{path}", cause=e, - ) + ) from e except Exception as e: raise MCPConnectionError( f"Request failed: {e}", server_name=self.server_name, cause=e, - ) + ) from e class ConnectionPool: @@ -322,8 +322,19 @@ class ConnectionPool: """ self._connections: dict[str, MCPConnection] = {} self._lock = asyncio.Lock() + self._per_server_locks: dict[str, asyncio.Lock] = {} self._max_per_server = max_connections_per_server + def _get_server_lock(self, server_name: str) -> asyncio.Lock: + """Get or create a lock for a specific server. + + Uses setdefault for atomic dict access to prevent race conditions + where two coroutines could create different locks for the same server. + """ + # setdefault is atomic - if key exists, returns existing value + # if key doesn't exist, inserts new value and returns it + return self._per_server_locks.setdefault(server_name, asyncio.Lock()) + async def get_connection( self, server_name: str, @@ -332,6 +343,9 @@ class ConnectionPool: """ Get or create a connection to a server. + Uses per-server locking to avoid blocking all connections + when establishing a new connection. + Args: server_name: Name of the server config: Server configuration @@ -339,17 +353,33 @@ class ConnectionPool: Returns: Active connection """ - async with self._lock: - if server_name not in self._connections: - connection = MCPConnection(server_name, config) - await connection.connect() - self._connections[server_name] = connection - + # Quick check without lock - if connection exists and is connected, return it + if server_name in self._connections: connection = self._connections[server_name] + if connection.is_connected: + return connection - # Reconnect if not connected - if not connection.is_connected: + # Need to create or reconnect - use per-server lock to avoid blocking others + async with self._lock: + server_lock = self._get_server_lock(server_name) + + async with server_lock: + # Double-check after acquiring per-server lock + if server_name in self._connections: + connection = self._connections[server_name] + if connection.is_connected: + return connection + # Connection exists but not connected - reconnect await connection.connect() + return connection + + # Create new connection (outside global lock, under per-server lock) + connection = MCPConnection(server_name, config) + await connection.connect() + + # Store connection under global lock + async with self._lock: + self._connections[server_name] = connection return connection @@ -374,6 +404,9 @@ class ConnectionPool: if server_name in self._connections: await self._connections[server_name].disconnect() del self._connections[server_name] + # Clean up per-server lock + if server_name in self._per_server_locks: + del self._per_server_locks[server_name] async def close_all(self) -> None: """Close all connections in the pool.""" @@ -385,6 +418,7 @@ class ConnectionPool: logger.warning("Error closing connection: %s", e) self._connections.clear() + self._per_server_locks.clear() logger.info("Closed all MCP connections") async def health_check_all(self) -> dict[str, bool]: @@ -394,8 +428,12 @@ class ConnectionPool: Returns: Dict mapping server names to health status """ + # Copy connections under lock to prevent modification during iteration + async with self._lock: + connections_snapshot = dict(self._connections) + results = {} - for name, connection in self._connections.items(): + for name, connection in connections_snapshot.items(): results[name] = await connection.health_check() return results diff --git a/backend/app/services/safety/costs/controller.py b/backend/app/services/safety/costs/controller.py index 1c30ce6..cacf011 100644 --- a/backend/app/services/safety/costs/controller.py +++ b/backend/app/services/safety/costs/controller.py @@ -185,6 +185,9 @@ class CostController: # Alert handlers self._alert_handlers: list[Any] = [] + # Track which budgets have had warning alerts sent (to avoid spam) + self._warned_budgets: set[str] = set() + async def get_or_create_tracker( self, scope: BudgetScope, @@ -343,32 +346,44 @@ class CostController: """ # Update session budget if session_id: + session_key = f"session:{session_id}" session_tracker = await self.get_or_create_tracker( BudgetScope.SESSION, session_id ) await session_tracker.add_usage(tokens, cost_usd) - # Check for warning + # Check for warning (only alert once per budget to avoid spam) status = await session_tracker.get_status() if status.is_warning and not status.is_exceeded: - await self._send_alert( - "warning", - f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens", - status, - ) + if session_key not in self._warned_budgets: + self._warned_budgets.add(session_key) + await self._send_alert( + "warning", + f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens", + status, + ) + elif not status.is_warning: + # Clear warning flag if usage dropped below threshold (e.g., after reset) + self._warned_budgets.discard(session_key) # Update agent daily budget + daily_key = f"daily:{agent_id}" agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id) await agent_tracker.add_usage(tokens, cost_usd) - # Check for warning + # Check for warning (only alert once per budget to avoid spam) status = await agent_tracker.get_status() if status.is_warning and not status.is_exceeded: - await self._send_alert( - "warning", - f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens", - status, - ) + if daily_key not in self._warned_budgets: + self._warned_budgets.add(daily_key) + await self._send_alert( + "warning", + f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens", + status, + ) + elif not status.is_warning: + # Clear warning flag if usage dropped below threshold (e.g., after reset) + self._warned_budgets.discard(daily_key) async def get_status( self, @@ -388,20 +403,18 @@ class CostController: key = f"{scope.value}:{scope_id}" async with self._lock: tracker = self._trackers.get(key) - - if tracker: - return await tracker.get_status() - return None + # Get status while holding lock to prevent TOCTOU race + if tracker: + return await tracker.get_status() + return None async def get_all_statuses(self) -> list[BudgetStatus]: """Get status of all tracked budgets.""" statuses = [] async with self._lock: - trackers = list(self._trackers.values()) - - for tracker in trackers: - statuses.append(await tracker.get_status()) - + # Get all statuses while holding lock to prevent TOCTOU race + for tracker in self._trackers.values(): + statuses.append(await tracker.get_status()) return statuses async def set_budget( @@ -453,11 +466,11 @@ class CostController: key = f"{scope.value}:{scope_id}" async with self._lock: tracker = self._trackers.get(key) - - if tracker: - await tracker.reset() - return True - return False + # Reset while holding lock to prevent TOCTOU race + if tracker: + await tracker.reset() + return True + return False def add_alert_handler(self, handler: Any) -> None: """Add an alert handler.""" diff --git a/backend/app/services/safety/guardian.py b/backend/app/services/safety/guardian.py index b79c173..035afc5 100644 --- a/backend/app/services/safety/guardian.py +++ b/backend/app/services/safety/guardian.py @@ -15,13 +15,20 @@ from .config import ( get_policy_for_autonomy_level, get_safety_config, ) +from .costs.controller import CostController from .exceptions import ( + BudgetExceededError, + LoopDetectedError, + RateLimitExceededError, SafetyError, ) +from .limits.limiter import RateLimiter +from .loops.detector import LoopDetector from .models import ( ActionRequest, ActionResult, AuditEventType, + BudgetScope, GuardianResult, SafetyDecision, SafetyPolicy, @@ -62,6 +69,9 @@ class SafetyGuardian: self, config: SafetyConfig | None = None, audit_logger: AuditLogger | None = None, + cost_controller: CostController | None = None, + rate_limiter: RateLimiter | None = None, + loop_detector: LoopDetector | None = None, ) -> None: """ Initialize the SafetyGuardian. @@ -69,17 +79,22 @@ class SafetyGuardian: Args: config: Optional safety configuration. If None, loads from environment. audit_logger: Optional audit logger. If None, uses global instance. + cost_controller: Optional cost controller. If None, creates default. + rate_limiter: Optional rate limiter. If None, creates default. + loop_detector: Optional loop detector. If None, creates default. """ self._config = config or get_safety_config() self._audit_logger = audit_logger self._initialized = False self._lock = asyncio.Lock() - # Subsystem references (will be initialized lazily) + # Core safety subsystems (always initialized) + self._cost_controller: CostController | None = cost_controller + self._rate_limiter: RateLimiter | None = rate_limiter + self._loop_detector: LoopDetector | None = loop_detector + + # Optional subsystems (will be initialized when available) self._permission_manager: Any = None - self._cost_controller: Any = None - self._rate_limiter: Any = None - self._loop_detector: Any = None self._hitl_manager: Any = None self._rollback_manager: Any = None self._content_filter: Any = None @@ -95,6 +110,21 @@ class SafetyGuardian: """Check if the guardian is initialized.""" return self._initialized + @property + def cost_controller(self) -> CostController | None: + """Get the cost controller instance.""" + return self._cost_controller + + @property + def rate_limiter(self) -> RateLimiter | None: + """Get the rate limiter instance.""" + return self._rate_limiter + + @property + def loop_detector(self) -> LoopDetector | None: + """Get the loop detector instance.""" + return self._loop_detector + async def initialize(self) -> None: """Initialize the SafetyGuardian and all subsystems.""" async with self._lock: @@ -108,11 +138,23 @@ class SafetyGuardian: if self._audit_logger is None: self._audit_logger = await get_audit_logger() - # Initialize subsystems lazily as they're implemented - # For now, we'll import and initialize them when available + # Initialize core safety subsystems + if self._cost_controller is None: + self._cost_controller = CostController() + logger.debug("Initialized CostController") + + if self._rate_limiter is None: + self._rate_limiter = RateLimiter() + logger.debug("Initialized RateLimiter") + + if self._loop_detector is None: + self._loop_detector = LoopDetector() + logger.debug("Initialized LoopDetector") self._initialized = True - logger.info("SafetyGuardian initialized") + logger.info( + "SafetyGuardian initialized with CostController, RateLimiter, LoopDetector" + ) async def shutdown(self) -> None: """Shutdown the SafetyGuardian and all subsystems.""" @@ -309,13 +351,40 @@ class SafetyGuardian: # Update cost tracking if self._cost_controller: - # Track actual cost - pass + try: + # Use explicit None check - 0 is a valid cost value + tokens = ( + result.actual_cost_tokens + if result.actual_cost_tokens is not None + else action.estimated_cost_tokens + ) + cost_usd = ( + result.actual_cost_usd + if result.actual_cost_usd is not None + else action.estimated_cost_usd + ) + await self._cost_controller.record_usage( + agent_id=action.metadata.agent_id, + session_id=action.metadata.session_id, + tokens=tokens, + cost_usd=cost_usd, + ) + except Exception as e: + logger.warning("Failed to record cost: %s", e) + + # Update rate limiter - consume slots for executed actions + if self._rate_limiter: + try: + await self._rate_limiter.record_action(action) + except Exception as e: + logger.warning("Failed to record action in rate limiter: %s", e) # Update loop detection history if self._loop_detector: - # Add to action history - pass + try: + await self._loop_detector.record(action) + except Exception as e: + logger.warning("Failed to record action in loop detector: %s", e) async def rollback(self, checkpoint_id: str) -> bool: """ @@ -442,14 +511,80 @@ class SafetyGuardian: policy: SafetyPolicy, ) -> GuardianResult: """Check if action is within budget.""" - # TODO: Implement with CostController - # For now, return allow - return GuardianResult( - action_id=action.id, - allowed=True, - decision=SafetyDecision.ALLOW, - reasons=["Budget check passed (not fully implemented)"], - ) + if self._cost_controller is None: + logger.warning("CostController not initialized - skipping budget check") + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Budget check skipped (controller not initialized)"], + ) + + agent_id = action.metadata.agent_id + session_id = action.metadata.session_id + + try: + # Check if we have budget for this action + has_budget = await self._cost_controller.check_budget( + agent_id=agent_id, + session_id=session_id, + estimated_tokens=action.estimated_cost_tokens, + estimated_cost_usd=action.estimated_cost_usd, + ) + + if not has_budget: + # Get current status for better error message + if session_id: + session_status = await self._cost_controller.get_status( + BudgetScope.SESSION, session_id + ) + if session_status and session_status.is_exceeded: + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[ + f"Session budget exceeded: {session_status.tokens_used}" + f"/{session_status.tokens_limit} tokens" + ], + ) + + agent_status = await self._cost_controller.get_status( + BudgetScope.DAILY, agent_id + ) + if agent_status and agent_status.is_exceeded: + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[ + f"Daily budget exceeded: {agent_status.tokens_used}" + f"/{agent_status.tokens_limit} tokens" + ], + ) + + # Generic budget exceeded + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=["Budget exceeded"], + ) + + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Budget check passed"], + ) + + except BudgetExceededError as e: + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[str(e)], + ) async def _check_rate_limit( self, @@ -457,14 +592,78 @@ class SafetyGuardian: policy: SafetyPolicy, ) -> GuardianResult: """Check if action is within rate limits.""" - # TODO: Implement with RateLimiter - # For now, return allow - return GuardianResult( - action_id=action.id, - allowed=True, - decision=SafetyDecision.ALLOW, - reasons=["Rate limit check passed (not fully implemented)"], - ) + if self._rate_limiter is None: + logger.warning("RateLimiter not initialized - skipping rate limit check") + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Rate limit check skipped (limiter not initialized)"], + ) + + try: + # Check all applicable rate limits for this action + allowed, statuses = await self._rate_limiter.check_action(action) + + if not allowed: + # Find the first exceeded limit for the error message + exceeded_status = next( + (s for s in statuses if s.is_limited), + statuses[0] if statuses else None, + ) + + if exceeded_status: + retry_after = exceeded_status.retry_after_seconds + + # Determine if this is a soft limit (delay) or hard limit (deny) + if retry_after > 0 and retry_after <= 5.0: + # Short wait - suggest delay + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DELAY, + reasons=[ + f"Rate limit '{exceeded_status.name}' exceeded. " + f"Current: {exceeded_status.current_count}/{exceeded_status.limit}" + ], + retry_after_seconds=retry_after, + ) + else: + # Hard deny + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[ + f"Rate limit '{exceeded_status.name}' exceeded. " + f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. " + f"Retry after {retry_after:.1f}s" + ], + retry_after_seconds=retry_after, + ) + + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=["Rate limit exceeded"], + ) + + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Rate limit check passed"], + ) + + except RateLimitExceededError as e: + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[str(e)], + retry_after_seconds=e.retry_after_seconds, + ) async def _check_loops( self, @@ -472,14 +671,51 @@ class SafetyGuardian: policy: SafetyPolicy, ) -> GuardianResult: """Check for action loops.""" - # TODO: Implement with LoopDetector - # For now, return allow - return GuardianResult( - action_id=action.id, - allowed=True, - decision=SafetyDecision.ALLOW, - reasons=["Loop check passed (not fully implemented)"], - ) + if self._loop_detector is None: + logger.warning("LoopDetector not initialized - skipping loop check") + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Loop check skipped (detector not initialized)"], + ) + + try: + # Check if this action would create a loop + is_loop, loop_type = await self._loop_detector.check(action) + + if is_loop: + # Get suggestions for breaking the loop + from .loops.detector import LoopBreaker + + suggestions = await LoopBreaker.suggest_alternatives( + action, loop_type or "unknown" + ) + + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[ + f"Loop detected: {loop_type}", + *suggestions, + ], + ) + + return GuardianResult( + action_id=action.id, + allowed=True, + decision=SafetyDecision.ALLOW, + reasons=["Loop check passed"], + ) + + except LoopDetectedError as e: + return GuardianResult( + action_id=action.id, + allowed=False, + decision=SafetyDecision.DENY, + reasons=[str(e)], + ) async def _check_hitl( self, @@ -610,7 +846,19 @@ async def shutdown_safety_guardian() -> None: _guardian_instance = None -def reset_safety_guardian() -> None: - """Reset the SafetyGuardian (for testing).""" +async def reset_safety_guardian() -> None: + """ + Reset the SafetyGuardian (for testing). + + This is an async function to properly acquire the guardian lock + and avoid race conditions with get_safety_guardian(). + """ global _guardian_instance - _guardian_instance = None + + async with _guardian_lock: + if _guardian_instance is not None: + try: + await _guardian_instance.shutdown() + except Exception: # noqa: S110 + pass # Ignore errors during test cleanup + _guardian_instance = None diff --git a/backend/app/services/safety/limits/limiter.py b/backend/app/services/safety/limits/limiter.py index bc94ab0..bd94a6d 100644 --- a/backend/app/services/safety/limits/limiter.py +++ b/backend/app/services/safety/limits/limiter.py @@ -223,7 +223,10 @@ class RateLimiter: action: ActionRequest, ) -> tuple[bool, list[RateLimitStatus]]: """ - Check all applicable rate limits for an action. + Check all applicable rate limits for an action WITHOUT consuming slots. + + Use this during validation to check if action would be allowed. + Call record_action() after successful execution to consume slots. Args: action: The action to check @@ -235,28 +238,53 @@ class RateLimiter: statuses: list[RateLimitStatus] = [] allowed = True - # Check general actions limit - actions_allowed, actions_status = await self.acquire("actions", agent_id) + # Check general actions limit (read-only) + actions_status = await self.check("actions", agent_id) statuses.append(actions_status) - if not actions_allowed: + if actions_status.is_limited: allowed = False # Check LLM-specific limit for LLM calls if action.action_type.value == "llm_call": - llm_allowed, llm_status = await self.acquire("llm_calls", agent_id) + llm_status = await self.check("llm_calls", agent_id) statuses.append(llm_status) - if not llm_allowed: + if llm_status.is_limited: allowed = False # Check file ops limit for file operations if action.action_type.value in {"file_read", "file_write", "file_delete"}: - file_allowed, file_status = await self.acquire("file_ops", agent_id) + file_status = await self.check("file_ops", agent_id) statuses.append(file_status) - if not file_allowed: + if file_status.is_limited: allowed = False return allowed, statuses + async def record_action( + self, + action: ActionRequest, + ) -> None: + """ + Record an action by consuming rate limit slots. + + Call this AFTER successful execution to properly count the action. + + Args: + action: The executed action + """ + agent_id = action.metadata.agent_id + + # Consume general actions slot + await self.acquire("actions", agent_id) + + # Consume LLM-specific slot for LLM calls + if action.action_type.value == "llm_call": + await self.acquire("llm_calls", agent_id) + + # Consume file ops slot for file operations + if action.action_type.value in {"file_read", "file_write", "file_delete"}: + await self.acquire("file_ops", agent_id) + async def require( self, limit_name: str, diff --git a/backend/tests/services/mcp/test_client_manager.py b/backend/tests/services/mcp/test_client_manager.py index c7ddada..13d2e88 100644 --- a/backend/tests/services/mcp/test_client_manager.py +++ b/backend/tests/services/mcp/test_client_manager.py @@ -20,13 +20,13 @@ from app.services.mcp.routing import ToolInfo, ToolResult @pytest.fixture -def reset_registry(): +async def reset_registry(): """Reset the singleton registry before and after each test.""" MCPServerRegistry.reset_instance() - reset_mcp_client() + await reset_mcp_client() yield MCPServerRegistry.reset_instance() - reset_mcp_client() + await reset_mcp_client() @pytest.fixture @@ -388,7 +388,8 @@ class TestModuleLevelFunctions: mock_shutdown.return_value = None await shutdown_mcp_client() - def test_reset_mcp_client(self, reset_registry): + @pytest.mark.asyncio + async def test_reset_mcp_client(self, reset_registry): """Test resetting the global client.""" - reset_mcp_client() + await reset_mcp_client() # Should not raise diff --git a/backend/tests/services/safety/test_costs.py b/backend/tests/services/safety/test_costs.py new file mode 100644 index 0000000..0897cd6 --- /dev/null +++ b/backend/tests/services/safety/test_costs.py @@ -0,0 +1,436 @@ +"""Tests for cost controller module.""" + +import pytest + +from app.services.safety.costs.controller import ( + BudgetTracker, + CostController, +) +from app.services.safety.exceptions import BudgetExceededError +from app.services.safety.models import ( + ActionMetadata, + ActionRequest, + ActionType, + BudgetScope, +) + + +@pytest.fixture +def budget_tracker() -> BudgetTracker: + """Create a budget tracker for testing.""" + return BudgetTracker( + scope=BudgetScope.SESSION, + scope_id="test-session", + tokens_limit=1000, + cost_limit_usd=10.0, + warning_threshold=0.8, + ) + + +@pytest.fixture +def cost_controller() -> CostController: + """Create a cost controller for testing.""" + return CostController( + default_session_tokens=1000, + default_session_cost_usd=10.0, + default_daily_tokens=5000, + default_daily_cost_usd=50.0, + ) + + +@pytest.fixture +def sample_metadata() -> ActionMetadata: + """Create sample action metadata.""" + return ActionMetadata( + agent_id="test-agent", + session_id="test-session", + ) + + +def create_action( + metadata: ActionMetadata, + estimated_tokens: int = 100, + estimated_cost: float = 0.01, +) -> ActionRequest: + """Helper to create test actions.""" + return ActionRequest( + action_type=ActionType.LLM_CALL, + tool_name="test_tool", + resource="test-resource", + arguments={}, + metadata=metadata, + estimated_cost_tokens=estimated_tokens, + estimated_cost_usd=estimated_cost, + ) + + +class TestBudgetTracker: + """Tests for BudgetTracker class.""" + + @pytest.mark.asyncio + async def test_initial_status(self, budget_tracker: BudgetTracker) -> None: + """Test initial budget status is clean.""" + status = await budget_tracker.get_status() + + assert status.tokens_used == 0 + assert status.cost_used_usd == 0.0 + assert status.tokens_remaining == 1000 + assert status.cost_remaining_usd == 10.0 + assert status.is_warning is False + assert status.is_exceeded is False + + @pytest.mark.asyncio + async def test_add_usage(self, budget_tracker: BudgetTracker) -> None: + """Test adding usage updates counters.""" + await budget_tracker.add_usage(tokens=100, cost_usd=1.0) + + status = await budget_tracker.get_status() + assert status.tokens_used == 100 + assert status.cost_used_usd == 1.0 + assert status.tokens_remaining == 900 + assert status.cost_remaining_usd == 9.0 + + @pytest.mark.asyncio + async def test_warning_threshold(self, budget_tracker: BudgetTracker) -> None: + """Test warning is triggered at threshold.""" + # Add usage to reach 80% of tokens + await budget_tracker.add_usage(tokens=800, cost_usd=1.0) + + status = await budget_tracker.get_status() + assert status.is_warning is True + assert status.is_exceeded is False + + @pytest.mark.asyncio + async def test_budget_exceeded(self, budget_tracker: BudgetTracker) -> None: + """Test budget exceeded detection.""" + # Exceed token limit + await budget_tracker.add_usage(tokens=1100, cost_usd=1.0) + + status = await budget_tracker.get_status() + assert status.is_exceeded is True + + @pytest.mark.asyncio + async def test_check_budget_allows(self, budget_tracker: BudgetTracker) -> None: + """Test check_budget allows within budget.""" + result = await budget_tracker.check_budget( + estimated_tokens=500, + estimated_cost_usd=5.0, + ) + assert result is True + + @pytest.mark.asyncio + async def test_check_budget_denies(self, budget_tracker: BudgetTracker) -> None: + """Test check_budget denies when would exceed.""" + # Use most of the budget + await budget_tracker.add_usage(tokens=800, cost_usd=8.0) + + # Check would exceed + result = await budget_tracker.check_budget( + estimated_tokens=300, + estimated_cost_usd=3.0, + ) + assert result is False + + @pytest.mark.asyncio + async def test_reset(self, budget_tracker: BudgetTracker) -> None: + """Test manual reset clears counters.""" + await budget_tracker.add_usage(tokens=500, cost_usd=5.0) + await budget_tracker.reset() + + status = await budget_tracker.get_status() + assert status.tokens_used == 0 + assert status.cost_used_usd == 0.0 + + +class TestCostController: + """Tests for CostController class.""" + + @pytest.mark.asyncio + async def test_check_budget_success( + self, + cost_controller: CostController, + ) -> None: + """Test budget check passes with available budget.""" + result = await cost_controller.check_budget( + agent_id="test-agent", + session_id="test-session", + estimated_tokens=100, + estimated_cost_usd=1.0, + ) + assert result is True + + @pytest.mark.asyncio + async def test_check_budget_session_exceeded( + self, + cost_controller: CostController, + ) -> None: + """Test budget check fails when session budget exceeded.""" + # Use most of session budget + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=900, + cost_usd=9.0, + ) + + # Check would exceed + result = await cost_controller.check_budget( + agent_id="test-agent", + session_id="test-session", + estimated_tokens=200, + estimated_cost_usd=2.0, + ) + assert result is False + + @pytest.mark.asyncio + async def test_check_budget_daily_exceeded( + self, + cost_controller: CostController, + ) -> None: + """Test budget check fails when daily budget exceeded.""" + # Use most of daily budget + await cost_controller.record_usage( + agent_id="test-agent", + session_id=None, + tokens=4900, + cost_usd=49.0, + ) + + # Check would exceed daily + result = await cost_controller.check_budget( + agent_id="test-agent", + session_id="new-session", + estimated_tokens=200, + estimated_cost_usd=2.0, + ) + assert result is False + + @pytest.mark.asyncio + async def test_check_action( + self, + cost_controller: CostController, + sample_metadata: ActionMetadata, + ) -> None: + """Test checking action budget.""" + action = create_action( + sample_metadata, + estimated_tokens=100, + estimated_cost=0.01, + ) + + result = await cost_controller.check_action(action) + assert result is True + + @pytest.mark.asyncio + async def test_require_budget_success( + self, + cost_controller: CostController, + ) -> None: + """Test require_budget passes when budget available.""" + # Should not raise + await cost_controller.require_budget( + agent_id="test-agent", + session_id="test-session", + estimated_tokens=100, + estimated_cost_usd=1.0, + ) + + @pytest.mark.asyncio + async def test_require_budget_raises( + self, + cost_controller: CostController, + ) -> None: + """Test require_budget raises when budget exceeded.""" + # Use all session budget + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=1000, + cost_usd=10.0, + ) + + with pytest.raises(BudgetExceededError) as exc_info: + await cost_controller.require_budget( + agent_id="test-agent", + session_id="test-session", + estimated_tokens=100, + estimated_cost_usd=1.0, + ) + + assert "session" in exc_info.value.budget_type.lower() + + @pytest.mark.asyncio + async def test_record_usage( + self, + cost_controller: CostController, + ) -> None: + """Test recording usage updates trackers.""" + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=100, + cost_usd=1.0, + ) + + # Check session budget was updated + session_status = await cost_controller.get_status( + BudgetScope.SESSION, "test-session" + ) + assert session_status is not None + assert session_status.tokens_used == 100 + + # Check daily budget was updated + daily_status = await cost_controller.get_status(BudgetScope.DAILY, "test-agent") + assert daily_status is not None + assert daily_status.tokens_used == 100 + + @pytest.mark.asyncio + async def test_get_all_statuses( + self, + cost_controller: CostController, + ) -> None: + """Test getting all budget statuses.""" + # Record some usage + await cost_controller.record_usage( + agent_id="agent-1", + session_id="session-1", + tokens=100, + cost_usd=1.0, + ) + await cost_controller.record_usage( + agent_id="agent-2", + session_id="session-2", + tokens=200, + cost_usd=2.0, + ) + + statuses = await cost_controller.get_all_statuses() + assert len(statuses) >= 2 + + @pytest.mark.asyncio + async def test_set_budget( + self, + cost_controller: CostController, + ) -> None: + """Test setting custom budget.""" + await cost_controller.set_budget( + scope=BudgetScope.SESSION, + scope_id="custom-session", + tokens_limit=5000, + cost_limit_usd=50.0, + ) + + status = await cost_controller.get_status(BudgetScope.SESSION, "custom-session") + assert status is not None + assert status.tokens_limit == 5000 + assert status.cost_limit_usd == 50.0 + + @pytest.mark.asyncio + async def test_reset_budget( + self, + cost_controller: CostController, + ) -> None: + """Test resetting budget.""" + # Record usage + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=500, + cost_usd=5.0, + ) + + # Reset session budget + result = await cost_controller.reset_budget(BudgetScope.SESSION, "test-session") + assert result is True + + # Verify reset + status = await cost_controller.get_status(BudgetScope.SESSION, "test-session") + assert status is not None + assert status.tokens_used == 0 + + @pytest.mark.asyncio + async def test_reset_nonexistent_budget( + self, + cost_controller: CostController, + ) -> None: + """Test resetting non-existent budget returns False.""" + result = await cost_controller.reset_budget(BudgetScope.SESSION, "nonexistent") + assert result is False + + @pytest.mark.asyncio + async def test_alert_handler( + self, + cost_controller: CostController, + ) -> None: + """Test alert handler is called at warning threshold.""" + alerts_received = [] + + def alert_handler(alert_type: str, message: str, status): + alerts_received.append((alert_type, message)) + + cost_controller.add_alert_handler(alert_handler) + + # Record usage to reach warning threshold (80%) + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=850, # 85% of 1000 + cost_usd=0.0, + ) + + assert len(alerts_received) > 0 + assert alerts_received[0][0] == "warning" + + @pytest.mark.asyncio + async def test_remove_alert_handler( + self, + cost_controller: CostController, + ) -> None: + """Test removing alert handler.""" + alerts_received = [] + + def alert_handler(alert_type: str, message: str, status): + alerts_received.append((alert_type, message)) + + cost_controller.add_alert_handler(alert_handler) + cost_controller.remove_alert_handler(alert_handler) + + # Record usage to reach warning threshold + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=850, + cost_usd=0.0, + ) + + assert len(alerts_received) == 0 + + @pytest.mark.asyncio + async def test_alert_deduplication( + self, + cost_controller: CostController, + ) -> None: + """Test alerts are only sent once per budget (no spam).""" + alerts_received = [] + + def alert_handler(alert_type: str, message: str, status): + alerts_received.append((alert_type, message)) + + cost_controller.add_alert_handler(alert_handler) + + # Record usage multiple times at warning level + # Session budget is 1000 with 80% threshold = 800 tokens + # 10 * 85 = 850 tokens triggers session warning once + for _ in range(10): + await cost_controller.record_usage( + agent_id="test-agent", + session_id="test-session", + tokens=85, # Each call adds 85 tokens + cost_usd=0.0, + ) + + # Should only receive ONE session warning (daily budget of 5000 + # isn't reached yet). The key point is we don't get 10 alerts! + assert len(alerts_received) == 1 + assert alerts_received[0][0] == "warning" + assert "Session" in alerts_received[0][1] diff --git a/backend/tests/services/safety/test_guardian.py b/backend/tests/services/safety/test_guardian.py new file mode 100644 index 0000000..ebdd2dd --- /dev/null +++ b/backend/tests/services/safety/test_guardian.py @@ -0,0 +1,508 @@ +"""Tests for SafetyGuardian integration.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from app.services.safety.config import SafetyConfig +from app.services.safety.costs.controller import CostController +from app.services.safety.guardian import ( + SafetyGuardian, + get_safety_guardian, + reset_safety_guardian, + shutdown_safety_guardian, +) +from app.services.safety.limits.limiter import RateLimiter +from app.services.safety.loops.detector import LoopDetector +from app.services.safety.models import ( + ActionMetadata, + ActionRequest, + ActionResult, + ActionType, + AuditEvent, + AuditEventType, + AutonomyLevel, + BudgetScope, + SafetyDecision, + SafetyPolicy, +) + + +@pytest_asyncio.fixture +async def reset_guardian(): + """Reset the singleton guardian before and after each test.""" + await reset_safety_guardian() + yield + await reset_safety_guardian() + + +@pytest.fixture +def safety_config() -> SafetyConfig: + """Create a test safety configuration.""" + return SafetyConfig( + enabled=True, + strict_mode=True, + hitl_enabled=False, + auto_checkpoint_destructive=False, + ) + + +@pytest.fixture +def cost_controller() -> CostController: + """Create a cost controller for testing.""" + return CostController( + default_session_tokens=1000, + default_session_cost_usd=10.0, + default_daily_tokens=5000, + default_daily_cost_usd=50.0, + ) + + +@pytest.fixture +def rate_limiter() -> RateLimiter: + """Create a rate limiter for testing.""" + return RateLimiter() + + +@pytest.fixture +def loop_detector() -> LoopDetector: + """Create a loop detector for testing.""" + return LoopDetector( + history_size=10, + max_exact_repetitions=3, + max_semantic_repetitions=5, + ) + + +def _make_audit_event() -> AuditEvent: + """Create a mock audit event.""" + return AuditEvent( + event_type=AuditEventType.ACTION_REQUESTED, + agent_id="test-agent", + action_id="test-action", + ) + + +@pytest_asyncio.fixture +async def guardian( + safety_config: SafetyConfig, + cost_controller: CostController, + rate_limiter: RateLimiter, + loop_detector: LoopDetector, +) -> SafetyGuardian: + """Create a SafetyGuardian for testing.""" + guardian = SafetyGuardian( + config=safety_config, + cost_controller=cost_controller, + rate_limiter=rate_limiter, + loop_detector=loop_detector, + ) + # Patch the audit logger to avoid actual logging + # Return proper AuditEvent objects instead of AsyncMock + guardian._audit_logger = MagicMock() + guardian._audit_logger.log = AsyncMock(return_value=_make_audit_event()) + guardian._audit_logger.log_action_request = AsyncMock( + return_value=_make_audit_event() + ) + guardian._audit_logger.log_action_executed = AsyncMock(return_value=None) + await guardian.initialize() + return guardian + + +@pytest.fixture +def sample_metadata() -> ActionMetadata: + """Create sample action metadata.""" + return ActionMetadata( + agent_id="test-agent", + session_id="test-session", + autonomy_level=AutonomyLevel.MILESTONE, + ) + + +def create_action( + metadata: ActionMetadata, + tool_name: str = "test_tool", + action_type: ActionType = ActionType.LLM_CALL, + resource: str = "/tmp/test.txt", # noqa: S108 + estimated_tokens: int = 100, + estimated_cost: float = 0.01, +) -> ActionRequest: + """Helper to create test actions.""" + return ActionRequest( + action_type=action_type, + tool_name=tool_name, + resource=resource, + arguments={}, + metadata=metadata, + estimated_cost_tokens=estimated_tokens, + estimated_cost_usd=estimated_cost, + ) + + +class TestSafetyGuardianInit: + """Tests for SafetyGuardian initialization.""" + + @pytest.mark.asyncio + async def test_init_creates_subsystems( + self, + safety_config: SafetyConfig, + ) -> None: + """Test initialization creates subsystems if not provided.""" + with patch( + "app.services.safety.guardian.get_audit_logger", + new_callable=AsyncMock, + ): + guardian = SafetyGuardian(config=safety_config) + await guardian.initialize() + + assert guardian.cost_controller is not None + assert guardian.rate_limiter is not None + assert guardian.loop_detector is not None + assert guardian.is_initialized is True + + @pytest.mark.asyncio + async def test_init_with_provided_subsystems( + self, + safety_config: SafetyConfig, + cost_controller: CostController, + rate_limiter: RateLimiter, + loop_detector: LoopDetector, + ) -> None: + """Test initialization uses provided subsystems.""" + guardian = SafetyGuardian( + config=safety_config, + cost_controller=cost_controller, + rate_limiter=rate_limiter, + loop_detector=loop_detector, + ) + guardian._audit_logger = MagicMock() + await guardian.initialize() + + # Should use the provided instances + assert guardian.cost_controller is cost_controller + assert guardian.rate_limiter is rate_limiter + assert guardian.loop_detector is loop_detector + + +class TestSafetyGuardianValidation: + """Tests for SafetyGuardian.validate().""" + + @pytest.mark.asyncio + async def test_validate_success( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test successful validation passes all checks.""" + action = create_action(sample_metadata) + + result = await guardian.validate(action) + + assert result.allowed is True + assert result.decision == SafetyDecision.ALLOW + + @pytest.mark.asyncio + async def test_validate_disabled_allows_all( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test validation with disabled safety allows all.""" + guardian._config.enabled = False + action = create_action(sample_metadata) + + result = await guardian.validate(action) + + assert result.allowed is True + assert "disabled" in result.reasons[0].lower() + + @pytest.mark.asyncio + async def test_validate_budget_exceeded( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test validation fails when budget exceeded.""" + # Use up the session budget + await guardian.cost_controller.record_usage( + agent_id=sample_metadata.agent_id, + session_id=sample_metadata.session_id, + tokens=1000, + cost_usd=10.0, + ) + + action = create_action(sample_metadata, estimated_tokens=100) + result = await guardian.validate(action) + + assert result.allowed is False + assert result.decision == SafetyDecision.DENY + assert any("budget" in r.lower() for r in result.reasons) + + @pytest.mark.asyncio + async def test_validate_rate_limit_exceeded( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test validation fails when rate limit exceeded.""" + # Exhaust rate limits by calling validate many times + for _ in range(100): # More than default limit + action = create_action(sample_metadata) + await guardian.rate_limiter.acquire("actions", sample_metadata.agent_id) + + action = create_action(sample_metadata) + result = await guardian.validate(action) + + # Should be denied or delayed + assert result.allowed is False + assert result.decision in (SafetyDecision.DENY, SafetyDecision.DELAY) + + @pytest.mark.asyncio + async def test_validate_loop_detected( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test validation fails when loop detected.""" + action = create_action(sample_metadata) + + # Record the same action multiple times (to trigger loop) + for _ in range(3): + await guardian.loop_detector.record(action) + + result = await guardian.validate(action) + + assert result.allowed is False + assert result.decision == SafetyDecision.DENY + assert any("loop" in r.lower() for r in result.reasons) + + @pytest.mark.asyncio + async def test_validate_denied_tool( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test validation fails for denied tools.""" + # Create action with tool that matches denied pattern + action = create_action(sample_metadata, tool_name="shell_exec") + + # Create policy with denied pattern + policy = SafetyPolicy( + name="test-policy", + allowed_tools=["*"], + denied_tools=["shell_*"], + ) + + result = await guardian.validate(action, policy=policy) + + assert result.allowed is False + assert result.decision == SafetyDecision.DENY + assert any("denied" in r.lower() for r in result.reasons) + + @pytest.mark.asyncio + async def test_validate_with_custom_policy( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test validation with custom policy.""" + action = create_action(sample_metadata, tool_name="allowed_tool") + + policy = SafetyPolicy( + name="test-custom-policy", + allowed_tools=["allowed_*"], + denied_tools=[], + ) + + result = await guardian.validate(action, policy=policy) + + assert result.allowed is True + assert result.decision == SafetyDecision.ALLOW + + +class TestSafetyGuardianRecording: + """Tests for SafetyGuardian.record_execution().""" + + @pytest.mark.asyncio + async def test_record_execution_updates_cost( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test recording execution updates cost tracker.""" + action = create_action(sample_metadata) + action_result = ActionResult( + action_id=action.id, + success=True, + actual_cost_tokens=50, + actual_cost_usd=0.005, + ) + + await guardian.record_execution(action, action_result) + + # Check cost was recorded + status = await guardian.cost_controller.get_status( + BudgetScope.SESSION, sample_metadata.session_id + ) + assert status is not None + assert status.tokens_used == 50 + + @pytest.mark.asyncio + async def test_record_execution_updates_loop_history( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test recording execution updates loop detector history.""" + action = create_action(sample_metadata) + action_result = ActionResult( + action_id=action.id, + success=True, + ) + + await guardian.record_execution(action, action_result) + + # Check action was recorded in loop detector + stats = await guardian.loop_detector.get_stats(sample_metadata.agent_id) + assert stats["history_size"] == 1 + + +class TestSafetyGuardianSingleton: + """Tests for SafetyGuardian singleton functions.""" + + @pytest.mark.asyncio + async def test_get_safety_guardian_creates_singleton( + self, + reset_guardian, + ) -> None: + """Test get_safety_guardian creates singleton.""" + with patch( + "app.services.safety.guardian.get_audit_logger", + new_callable=AsyncMock, + ): + guardian1 = await get_safety_guardian() + guardian2 = await get_safety_guardian() + + assert guardian1 is guardian2 + assert guardian1.is_initialized is True + + @pytest.mark.asyncio + async def test_shutdown_safety_guardian( + self, + reset_guardian, + ) -> None: + """Test shutdown cleans up singleton.""" + with patch( + "app.services.safety.guardian.get_audit_logger", + new_callable=AsyncMock, + ): + guardian = await get_safety_guardian() + assert guardian.is_initialized is True + + await shutdown_safety_guardian() + # Singleton should be cleared - next get creates new instance + + @pytest.mark.asyncio + async def test_reset_safety_guardian( + self, + reset_guardian, + ) -> None: + """Test reset clears singleton.""" + with patch( + "app.services.safety.guardian.get_audit_logger", + new_callable=AsyncMock, + ): + guardian1 = await get_safety_guardian() + + await reset_safety_guardian() + + guardian2 = await get_safety_guardian() + assert guardian1 is not guardian2 + + +class TestPatternMatching: + """Tests for pattern matching logic.""" + + def test_exact_match(self) -> None: + """Test exact pattern matching.""" + guardian = SafetyGuardian() + assert guardian._matches_pattern("file_read", "file_read") is True + assert guardian._matches_pattern("file_read", "file_write") is False + + def test_wildcard_all(self) -> None: + """Test wildcard * matches all.""" + guardian = SafetyGuardian() + assert guardian._matches_pattern("anything", "*") is True + assert guardian._matches_pattern("", "*") is True + + def test_prefix_wildcard(self) -> None: + """Test prefix wildcard matching.""" + guardian = SafetyGuardian() + assert guardian._matches_pattern("test_read", "*_read") is True + assert guardian._matches_pattern("test_write", "*_read") is False + + def test_suffix_wildcard(self) -> None: + """Test suffix wildcard matching.""" + guardian = SafetyGuardian() + assert guardian._matches_pattern("file_read", "file_*") is True + assert guardian._matches_pattern("shell_read", "file_*") is False + + def test_contains_wildcard(self) -> None: + """Test contains wildcard matching.""" + guardian = SafetyGuardian() + assert guardian._matches_pattern("test_dangerous_action", "*dangerous*") is True + assert guardian._matches_pattern("test_safe_action", "*dangerous*") is False + + +class TestErrorHandling: + """Tests for error handling in SafetyGuardian.""" + + @pytest.mark.asyncio + async def test_strict_mode_fails_on_error( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test strict mode denies on unexpected errors.""" + action = create_action(sample_metadata) + + # Force an error by breaking the cost controller + original_check = guardian.cost_controller.check_budget + guardian.cost_controller.check_budget = AsyncMock( + side_effect=Exception("Unexpected error") + ) + + result = await guardian.validate(action) + + assert result.allowed is False + assert result.decision == SafetyDecision.DENY + assert any("error" in r.lower() for r in result.reasons) + + # Restore + guardian.cost_controller.check_budget = original_check + + @pytest.mark.asyncio + async def test_non_strict_mode_allows_on_error( + self, + guardian: SafetyGuardian, + sample_metadata: ActionMetadata, + ) -> None: + """Test non-strict mode allows on unexpected errors.""" + guardian._config.strict_mode = False + action = create_action(sample_metadata) + + # Force an error by breaking the cost controller + original_check = guardian.cost_controller.check_budget + guardian.cost_controller.check_budget = AsyncMock( + side_effect=Exception("Unexpected error") + ) + + result = await guardian.validate(action) + + assert result.allowed is True + assert result.decision == SafetyDecision.ALLOW + + # Restore + guardian.cost_controller.check_budget = original_check + guardian._config.strict_mode = True diff --git a/backend/tests/services/safety/test_limits.py b/backend/tests/services/safety/test_limits.py new file mode 100644 index 0000000..960a3d3 --- /dev/null +++ b/backend/tests/services/safety/test_limits.py @@ -0,0 +1,405 @@ +"""Tests for rate limiter module.""" + +import pytest + +from app.services.safety.exceptions import RateLimitExceededError +from app.services.safety.limits.limiter import ( + RateLimiter, + SlidingWindowCounter, +) +from app.services.safety.models import ( + ActionMetadata, + ActionRequest, + ActionType, + RateLimitConfig, +) + + +@pytest.fixture +def sliding_counter() -> SlidingWindowCounter: + """Create a sliding window counter for testing.""" + return SlidingWindowCounter( + limit=5, + window_seconds=60, + burst_limit=3, + ) + + +@pytest.fixture +def rate_limiter() -> RateLimiter: + """Create a rate limiter for testing.""" + limiter = RateLimiter() + # Configure a test limit + limiter.configure( + RateLimitConfig( + name="test_limit", + limit=5, + window_seconds=60, + burst_limit=3, + ) + ) + return limiter + + +@pytest.fixture +def sample_metadata() -> ActionMetadata: + """Create sample action metadata.""" + return ActionMetadata( + agent_id="test-agent", + session_id="test-session", + ) + + +def create_action( + metadata: ActionMetadata, + action_type: ActionType = ActionType.LLM_CALL, +) -> ActionRequest: + """Helper to create test actions.""" + return ActionRequest( + action_type=action_type, + tool_name="test_tool", + resource="test-resource", + arguments={}, + metadata=metadata, + ) + + +class TestSlidingWindowCounter: + """Tests for SlidingWindowCounter class.""" + + @pytest.mark.asyncio + async def test_first_acquire_allowed( + self, + sliding_counter: SlidingWindowCounter, + ) -> None: + """Test first acquire is always allowed.""" + allowed, retry_after = await sliding_counter.try_acquire() + + assert allowed is True + assert retry_after == 0.0 + + @pytest.mark.asyncio + async def test_burst_limit( + self, + sliding_counter: SlidingWindowCounter, + ) -> None: + """Test burst limit is enforced.""" + # Acquire up to burst limit (3) + for _ in range(3): + allowed, _ = await sliding_counter.try_acquire() + assert allowed is True + + # Next should be denied (burst exceeded) + allowed, retry_after = await sliding_counter.try_acquire() + assert allowed is False + assert retry_after > 0 + + @pytest.mark.asyncio + async def test_get_status( + self, + sliding_counter: SlidingWindowCounter, + ) -> None: + """Test getting counter status.""" + # Make some requests + await sliding_counter.try_acquire() + await sliding_counter.try_acquire() + + current, remaining, reset_in = await sliding_counter.get_status() + + assert current == 2 + assert remaining == 3 # 5 - 2 + assert reset_in >= 0 + + +class TestRateLimiter: + """Tests for RateLimiter class.""" + + @pytest.mark.asyncio + async def test_check_status( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test checking rate limit status.""" + status = await rate_limiter.check("test_limit", "test-key") + + assert status.name == "test_limit" + assert status.current_count == 0 + assert status.limit == 5 + assert status.remaining == 5 + assert status.is_limited is False + + @pytest.mark.asyncio + async def test_acquire_success( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test successful acquire.""" + allowed, status = await rate_limiter.acquire("test_limit", "test-key") + + assert allowed is True + assert status.current_count == 1 + assert status.remaining == 4 + + @pytest.mark.asyncio + async def test_acquire_burst_exceeded( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test acquire fails when burst exceeded.""" + # Acquire up to burst limit + for _ in range(3): + allowed, _ = await rate_limiter.acquire("test_limit", "test-key") + assert allowed is True + + # Next should fail + allowed, status = await rate_limiter.acquire("test_limit", "test-key") + assert allowed is False + assert status.is_limited is True + assert status.retry_after_seconds > 0 + + @pytest.mark.asyncio + async def test_require_success( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test require passes when not limited.""" + # Should not raise + await rate_limiter.require("test_limit", "test-key") + + @pytest.mark.asyncio + async def test_require_raises( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test require raises when limited.""" + # Use up burst limit + for _ in range(3): + await rate_limiter.acquire("test_limit", "test-key") + + with pytest.raises(RateLimitExceededError) as exc_info: + await rate_limiter.require("test_limit", "test-key") + + assert exc_info.value.limit_type == "test_limit" + assert exc_info.value.retry_after_seconds > 0 + + @pytest.mark.asyncio + async def test_check_action_allowed( + self, + rate_limiter: RateLimiter, + sample_metadata: ActionMetadata, + ) -> None: + """Test checking action is allowed.""" + action = create_action(sample_metadata) + + allowed, statuses = await rate_limiter.check_action(action) + + assert allowed is True + assert len(statuses) >= 1 # At least "actions" limit + + @pytest.mark.asyncio + async def test_check_action_llm_limits( + self, + rate_limiter: RateLimiter, + sample_metadata: ActionMetadata, + ) -> None: + """Test LLM actions check LLM-specific limits.""" + action = create_action(sample_metadata, action_type=ActionType.LLM_CALL) + + allowed, statuses = await rate_limiter.check_action(action) + + assert allowed is True + # Should have checked both "actions" and "llm_calls" + limit_names = [s.name for s in statuses] + assert "actions" in limit_names + assert "llm_calls" in limit_names + + @pytest.mark.asyncio + async def test_check_action_file_limits( + self, + rate_limiter: RateLimiter, + sample_metadata: ActionMetadata, + ) -> None: + """Test file actions check file-specific limits.""" + action = create_action(sample_metadata, action_type=ActionType.FILE_READ) + + allowed, statuses = await rate_limiter.check_action(action) + + assert allowed is True + # Should have checked both "actions" and "file_ops" + limit_names = [s.name for s in statuses] + assert "actions" in limit_names + assert "file_ops" in limit_names + + @pytest.mark.asyncio + async def test_check_action_does_not_consume_slot( + self, + rate_limiter: RateLimiter, + sample_metadata: ActionMetadata, + ) -> None: + """Test check_action only checks without consuming slots.""" + action = create_action(sample_metadata) + + # Check multiple times - should never consume + for _ in range(10): + allowed, _ = await rate_limiter.check_action(action) + assert allowed is True + + # Verify no slots were consumed + status = await rate_limiter.check("actions", sample_metadata.agent_id) + assert status.current_count == 0 + + @pytest.mark.asyncio + async def test_record_action_consumes_slot( + self, + rate_limiter: RateLimiter, + sample_metadata: ActionMetadata, + ) -> None: + """Test record_action consumes rate limit slots.""" + action = create_action(sample_metadata) + + # Record the action + await rate_limiter.record_action(action) + + # Verify slot was consumed + status = await rate_limiter.check("actions", sample_metadata.agent_id) + assert status.current_count == 1 + + @pytest.mark.asyncio + async def test_record_action_consumes_type_specific_slots( + self, + rate_limiter: RateLimiter, + sample_metadata: ActionMetadata, + ) -> None: + """Test record_action consumes type-specific slots.""" + # LLM action + llm_action = create_action(sample_metadata, action_type=ActionType.LLM_CALL) + await rate_limiter.record_action(llm_action) + + statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id) + assert statuses["actions"].current_count == 1 + assert statuses["llm_calls"].current_count == 1 + assert statuses["file_ops"].current_count == 0 + + # File action + file_action = create_action(sample_metadata, action_type=ActionType.FILE_READ) + await rate_limiter.record_action(file_action) + + statuses = await rate_limiter.get_all_statuses(sample_metadata.agent_id) + assert statuses["actions"].current_count == 2 + assert statuses["llm_calls"].current_count == 1 + assert statuses["file_ops"].current_count == 1 + + @pytest.mark.asyncio + async def test_get_all_statuses( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test getting all rate limit statuses.""" + # Make some requests + await rate_limiter.acquire("actions", "test-key") + await rate_limiter.acquire("llm_calls", "test-key") + + statuses = await rate_limiter.get_all_statuses("test-key") + + assert "actions" in statuses + assert "llm_calls" in statuses + assert "file_ops" in statuses + assert statuses["actions"].current_count >= 1 + assert statuses["llm_calls"].current_count >= 1 + + @pytest.mark.asyncio + async def test_reset_single( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test resetting a single rate limit.""" + # Make some requests + await rate_limiter.acquire("test_limit", "test-key") + await rate_limiter.acquire("test_limit", "test-key") + + # Reset + result = await rate_limiter.reset("test_limit", "test-key") + assert result is True + + # Check it's reset + status = await rate_limiter.check("test_limit", "test-key") + assert status.current_count == 0 + + @pytest.mark.asyncio + async def test_reset_nonexistent( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test resetting non-existent limit returns False.""" + result = await rate_limiter.reset("nonexistent", "test-key") + assert result is False + + @pytest.mark.asyncio + async def test_reset_all( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test resetting all rate limits for a key.""" + # Make requests across multiple limits + await rate_limiter.acquire("actions", "test-key") + await rate_limiter.acquire("llm_calls", "test-key") + await rate_limiter.acquire("file_ops", "test-key") + + # Reset all + count = await rate_limiter.reset_all("test-key") + assert count >= 3 + + # Check they're reset + statuses = await rate_limiter.get_all_statuses("test-key") + for status in statuses.values(): + assert status.current_count == 0 + + @pytest.mark.asyncio + async def test_per_key_isolation( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test rate limits are isolated per key.""" + # Use up burst limit for key-1 + for _ in range(3): + await rate_limiter.acquire("test_limit", "key-1") + + # key-1 should be limited + allowed1, _ = await rate_limiter.acquire("test_limit", "key-1") + assert allowed1 is False + + # key-2 should still be allowed + allowed2, _ = await rate_limiter.acquire("test_limit", "key-2") + assert allowed2 is True + + @pytest.mark.asyncio + async def test_configure_custom_limit( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test configuring custom rate limits.""" + rate_limiter.configure( + RateLimitConfig( + name="custom", + limit=100, + window_seconds=120, + burst_limit=50, + ) + ) + + status = await rate_limiter.check("custom", "test-key") + assert status.limit == 100 + assert status.window_seconds == 120 + + @pytest.mark.asyncio + async def test_default_limit_fallback( + self, + rate_limiter: RateLimiter, + ) -> None: + """Test fallback to default limit for unknown limit names.""" + # Request limit that doesn't exist + status = await rate_limiter.check("unknown_limit", "test-key") + + # Should use default (60/60s) + assert status.limit == 60 + assert status.window_seconds == 60