""" Circuit Breaker implementation for LLM Gateway. Provides fault tolerance by tracking provider failures and temporarily disabling providers that are experiencing issues. """ import asyncio import logging import threading import time from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from enum import Enum from typing import Any, TypeVar from config import Settings, get_settings from exceptions import CircuitOpenError logger = logging.getLogger(__name__) T = TypeVar("T") class CircuitState(str, Enum): """Circuit breaker states.""" CLOSED = "closed" # Normal operation, requests pass through OPEN = "open" # Failures exceeded threshold, requests blocked HALF_OPEN = "half_open" # Testing if service recovered @dataclass class CircuitStats: """Statistics for a circuit breaker.""" failures: int = 0 successes: int = 0 last_failure_time: float | None = None last_success_time: float | None = None state_changed_at: float = field(default_factory=time.time) half_open_calls: int = 0 class CircuitBreaker: """ Circuit breaker for individual providers. States: - CLOSED: Normal operation. Failures increment counter. - OPEN: Too many failures. Requests immediately fail. - HALF_OPEN: Testing recovery. Limited requests allowed. Transitions: - CLOSED -> OPEN: When failures >= threshold - OPEN -> HALF_OPEN: After recovery_timeout - HALF_OPEN -> CLOSED: On success - HALF_OPEN -> OPEN: On failure """ def __init__( self, name: str, failure_threshold: int = 5, recovery_timeout: int = 60, half_open_max_calls: int = 3, ) -> None: """ Initialize circuit breaker. Args: name: Identifier for this circuit (usually provider name) failure_threshold: Failures before opening circuit recovery_timeout: Seconds before attempting recovery half_open_max_calls: Max calls allowed in half-open state """ self.name = name self.failure_threshold = failure_threshold self.recovery_timeout = recovery_timeout self.half_open_max_calls = half_open_max_calls self._state = CircuitState.CLOSED self._stats = CircuitStats() self._lock = asyncio.Lock() @property def state(self) -> CircuitState: """Get current circuit state (may trigger state transition).""" # Check transition outside lock since it only reads/writes _state # which is atomic in Python, and we use the lock inside _check_state_transition # if a state change is needed self._check_state_transition() return self._state @property def stats(self) -> CircuitStats: """Get circuit statistics.""" return self._stats def _check_state_transition(self) -> None: """ Check if state should transition based on time. Note: This method is intentionally not async and doesn't acquire the lock because it's called frequently from the state property. The transition is safe because: 1. We check the current state first (atomic read) 2. _transition_to only modifies state if we're still in OPEN state 3. Multiple concurrent transitions to HALF_OPEN are idempotent """ if self._state == CircuitState.OPEN: time_in_open = time.time() - self._stats.state_changed_at # Double-check state after time calculation (for thread safety) if ( time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN ): self._transition_to(CircuitState.HALF_OPEN) logger.info( f"Circuit {self.name} transitioned to HALF_OPEN " f"after {time_in_open:.1f}s" ) def _transition_to(self, new_state: CircuitState) -> None: """Transition to a new state.""" old_state = self._state self._state = new_state self._stats.state_changed_at = time.time() if new_state == CircuitState.HALF_OPEN: self._stats.half_open_calls = 0 elif new_state == CircuitState.CLOSED: self._stats.failures = 0 logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}") def is_available(self) -> bool: """Check if circuit allows requests.""" state = self.state # Triggers state check if state == CircuitState.CLOSED: return True if state == CircuitState.HALF_OPEN: return self._stats.half_open_calls < self.half_open_max_calls return False def time_until_recovery(self) -> int | None: """Get seconds until circuit may recover (None if not open).""" if self._state != CircuitState.OPEN: return None elapsed = time.time() - self._stats.state_changed_at remaining = max(0, self.recovery_timeout - int(elapsed)) return remaining if remaining > 0 else 0 async def record_success(self) -> None: """Record a successful call.""" async with self._lock: self._stats.successes += 1 self._stats.last_success_time = time.time() if self._state == CircuitState.HALF_OPEN: # Success in half-open state closes the circuit self._transition_to(CircuitState.CLOSED) logger.info(f"Circuit {self.name} closed after successful recovery") async def record_failure(self, error: Exception | None = None) -> None: # noqa: ARG002 """ Record a failed call. Args: error: The exception that occurred """ async with self._lock: self._stats.failures += 1 self._stats.last_failure_time = time.time() if self._state == CircuitState.HALF_OPEN: # Failure in half-open state opens the circuit self._transition_to(CircuitState.OPEN) logger.warning( f"Circuit {self.name} reopened after failure in half-open state" ) elif self._state == CircuitState.CLOSED: if self._stats.failures >= self.failure_threshold: self._transition_to(CircuitState.OPEN) logger.warning( f"Circuit {self.name} opened after " f"{self._stats.failures} failures" ) async def execute( self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any, ) -> T: """ Execute a function with circuit breaker protection. Args: func: Async function to execute *args: Positional arguments **kwargs: Keyword arguments Returns: Function result Raises: CircuitOpenError: If circuit is open """ if not self.is_available(): raise CircuitOpenError( provider=self.name, recovery_time=self.time_until_recovery(), ) async with self._lock: if self._state == CircuitState.HALF_OPEN: self._stats.half_open_calls += 1 try: result = await func(*args, **kwargs) await self.record_success() return result except Exception as e: await self.record_failure(e) raise def reset(self) -> None: """Reset circuit to closed state.""" self._state = CircuitState.CLOSED self._stats = CircuitStats() logger.info(f"Circuit {self.name} reset to CLOSED") def to_dict(self) -> dict[str, Any]: """Convert circuit state to dictionary.""" return { "name": self.name, "state": self._state.value, "failures": self._stats.failures, "successes": self._stats.successes, "last_failure_time": self._stats.last_failure_time, "last_success_time": self._stats.last_success_time, "time_until_recovery": self.time_until_recovery(), "is_available": self.is_available(), } class CircuitBreakerRegistry: """ Registry for managing multiple circuit breakers. Provides centralized management of circuits for different providers/models. """ def __init__(self, settings: Settings | None = None) -> None: """ Initialize registry. Args: settings: Application settings (uses default if None) """ self._settings = settings or get_settings() self._circuits: dict[str, CircuitBreaker] = {} self._lock = asyncio.Lock() async def get_circuit(self, name: str) -> CircuitBreaker: """ Get or create a circuit breaker. Args: name: Circuit identifier (e.g., provider name) Returns: CircuitBreaker instance """ async with self._lock: if name not in self._circuits: self._circuits[name] = CircuitBreaker( name=name, failure_threshold=self._settings.circuit_failure_threshold, recovery_timeout=self._settings.circuit_recovery_timeout, half_open_max_calls=self._settings.circuit_half_open_max_calls, ) return self._circuits[name] def get_circuit_sync(self, name: str) -> CircuitBreaker: """ Get or create a circuit breaker (sync version). Args: name: Circuit identifier Returns: CircuitBreaker instance """ if name not in self._circuits: self._circuits[name] = CircuitBreaker( name=name, failure_threshold=self._settings.circuit_failure_threshold, recovery_timeout=self._settings.circuit_recovery_timeout, half_open_max_calls=self._settings.circuit_half_open_max_calls, ) return self._circuits[name] async def is_available(self, name: str) -> bool: """ Check if a circuit is available. Args: name: Circuit identifier Returns: True if circuit allows requests """ circuit = await self.get_circuit(name) return circuit.is_available() async def record_success(self, name: str) -> None: """Record success for a circuit.""" circuit = await self.get_circuit(name) await circuit.record_success() async def record_failure(self, name: str, error: Exception | None = None) -> None: """Record failure for a circuit.""" circuit = await self.get_circuit(name) await circuit.record_failure(error) async def reset(self, name: str) -> None: """Reset a specific circuit.""" async with self._lock: if name in self._circuits: self._circuits[name].reset() async def reset_all(self) -> None: """Reset all circuits.""" async with self._lock: for circuit in self._circuits.values(): circuit.reset() def get_all_states(self) -> dict[str, dict[str, Any]]: """Get state of all circuits.""" return {name: circuit.to_dict() for name, circuit in self._circuits.items()} def get_open_circuits(self) -> list[str]: """Get list of circuits that are currently open.""" return [ name for name, circuit in self._circuits.items() if circuit.state == CircuitState.OPEN ] def get_available_circuits(self) -> list[str]: """Get list of circuits that are available for requests.""" return [ name for name, circuit in self._circuits.items() if circuit.is_available() ] # Global registry instance with thread-safe lazy initialization _registry: CircuitBreakerRegistry | None = None _registry_lock = threading.Lock() def get_circuit_registry() -> CircuitBreakerRegistry: """ Get the global circuit breaker registry. Thread-safe with double-checked locking pattern. """ global _registry if _registry is None: with _registry_lock: # Double-check after acquiring lock if _registry is None: _registry = CircuitBreakerRegistry() return _registry def reset_circuit_registry() -> None: """Reset the global registry (for testing).""" global _registry with _registry_lock: _registry = None