forked from cardosofelipe/fast-next-template
Frontend: - Fix debounce race condition in UserListTable search handler - Use useRef to properly track and cleanup timeout between keystrokes Backend (LLM Gateway): - Add thread-safe double-checked locking for global singletons (providers, circuit registry, cost tracker) - Fix Redis URL parsing with proper urlparse validation - Add explicit error handling for malformed Redis URLs - Document circuit breaker state transition safety 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
382 lines
12 KiB
Python
382 lines
12 KiB
Python
"""
|
|
Circuit Breaker implementation for LLM Gateway.
|
|
|
|
Provides fault tolerance by tracking provider failures and
|
|
temporarily disabling providers that are experiencing issues.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
import time
|
|
from collections.abc import Awaitable, Callable
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any, TypeVar
|
|
|
|
from config import Settings, get_settings
|
|
from exceptions import CircuitOpenError
|
|
|
|
# Module-level logger; circuit state transitions are reported through it.
logger = logging.getLogger(__name__)

# Result type of the coroutine wrapped by CircuitBreaker.execute.
T = TypeVar("T")
|
|
|
|
|
|
class CircuitState(str, Enum):
    """
    Lifecycle states of a circuit breaker.

    The enum mixes in ``str``, so each member compares equal to its
    string value (e.g. ``CircuitState.OPEN == "open"``).
    """

    # Healthy: requests are forwarded to the provider.
    CLOSED = "closed"
    # Tripped: the failure threshold was reached and requests are rejected.
    OPEN = "open"
    # Probing: a limited number of trial requests check whether it recovered.
    HALF_OPEN = "half_open"
|
|
|
|
|
|
@dataclass
|
|
class CircuitStats:
|
|
"""Statistics for a circuit breaker."""
|
|
|
|
failures: int = 0
|
|
successes: int = 0
|
|
last_failure_time: float | None = None
|
|
last_success_time: float | None = None
|
|
state_changed_at: float = field(default_factory=time.time)
|
|
half_open_calls: int = 0
|
|
|
|
|
|
class CircuitBreaker:
    """
    Circuit breaker for individual providers.

    States:
        - CLOSED: Normal operation. Failures increment counter.
        - OPEN: Too many failures. Requests immediately fail.
        - HALF_OPEN: Testing recovery. Limited requests allowed.

    Transitions:
        - CLOSED -> OPEN: When failures >= threshold
        - OPEN -> HALF_OPEN: After recovery_timeout (applied lazily whenever
          ``state`` / ``is_available()`` is read)
        - HALF_OPEN -> CLOSED: On success
        - HALF_OPEN -> OPEN: On failure

    Concurrency: every state mutation happens in synchronous code with no
    ``await`` in between, so on a single asyncio event loop each mutation runs
    without interleaving with other coroutines. ``asyncio.Lock`` is not
    thread-safe; presumably an instance is meant to be used from one event
    loop only -- confirm before sharing across threads.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_calls: int = 3,
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            name: Identifier for this circuit (usually provider name)
            failure_threshold: Failures before opening circuit
            recovery_timeout: Seconds before attempting recovery
            half_open_max_calls: Max calls allowed in half-open state
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls

        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        self._lock = asyncio.Lock()
        # Incremented on every entry into HALF_OPEN, so a half-open slot taken
        # in one probing window is never released into a later window.
        self._half_open_epoch = 0

    @property
    def state(self) -> CircuitState:
        """Get current circuit state (may trigger an OPEN -> HALF_OPEN transition)."""
        # No lock needed: _check_state_transition is synchronous, so it cannot
        # interleave with other coroutines running on the same event loop.
        self._check_state_transition()
        return self._state

    @property
    def stats(self) -> CircuitStats:
        """Get circuit statistics (the live object, not a copy)."""
        return self._stats

    def _check_state_transition(self) -> None:
        """
        Move OPEN -> HALF_OPEN once ``recovery_timeout`` seconds have elapsed.

        Synchronous and lock-free on purpose: it is called from the ``state``
        property, and because it contains no ``await`` it runs to completion
        without yielding to other coroutines on the event loop.
        """
        if self._state != CircuitState.OPEN:
            return

        time_in_open = time.time() - self._stats.state_changed_at
        if time_in_open >= self.recovery_timeout:
            self._transition_to(CircuitState.HALF_OPEN)
            logger.info(
                f"Circuit {self.name} transitioned to HALF_OPEN "
                f"after {time_in_open:.1f}s"
            )

    def _transition_to(self, new_state: CircuitState) -> None:
        """Switch to ``new_state`` unconditionally, stamping the time and resetting per-state counters."""
        old_state = self._state
        self._state = new_state
        self._stats.state_changed_at = time.time()

        if new_state == CircuitState.HALF_OPEN:
            self._stats.half_open_calls = 0
            # New probing window: slots claimed in earlier windows become stale.
            self._half_open_epoch += 1
        elif new_state == CircuitState.CLOSED:
            self._stats.failures = 0

        logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}")

    def is_available(self) -> bool:
        """Check if circuit allows requests (CLOSED, or HALF_OPEN with free probe slots)."""
        state = self.state  # Triggers state check
        if state == CircuitState.CLOSED:
            return True
        if state == CircuitState.HALF_OPEN:
            return self._stats.half_open_calls < self.half_open_max_calls
        return False

    def time_until_recovery(self) -> int | None:
        """Get seconds until circuit may recover (None if not open)."""
        if self._state != CircuitState.OPEN:
            return None
        elapsed = time.time() - self._stats.state_changed_at
        # int() truncates, so a partially elapsed second still counts as remaining.
        return max(0, self.recovery_timeout - int(elapsed))

    async def record_success(self) -> None:
        """
        Record a successful call.

        In HALF_OPEN a single success closes the circuit (which clears
        ``failures``). In CLOSED the failure count is left untouched.
        """
        # NOTE(review): failures therefore accumulate across successes while
        # CLOSED, so sporadic errors can eventually trip the breaker; confirm
        # whether consecutive-failure semantics were intended.
        async with self._lock:
            self._stats.successes += 1
            self._stats.last_success_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                # Success in half-open state closes the circuit
                self._transition_to(CircuitState.CLOSED)
                logger.info(f"Circuit {self.name} closed after successful recovery")

    async def record_failure(self, error: Exception | None = None) -> None:  # noqa: ARG002
        """
        Record a failed call.

        Args:
            error: The exception that occurred (currently unused)
        """
        async with self._lock:
            self._stats.failures += 1
            self._stats.last_failure_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                # Failure in half-open state opens the circuit
                self._transition_to(CircuitState.OPEN)
                logger.warning(
                    f"Circuit {self.name} reopened after failure in half-open state"
                )
            elif self._state == CircuitState.CLOSED:
                if self._stats.failures >= self.failure_threshold:
                    self._transition_to(CircuitState.OPEN)
                    logger.warning(
                        f"Circuit {self.name} opened after "
                        f"{self._stats.failures} failures"
                    )

    def _release_half_open_slot(self, slot_epoch: int | None) -> None:
        """Give back a half-open slot whose call ended without a success or failure verdict."""
        if (
            slot_epoch is not None
            and self._state == CircuitState.HALF_OPEN
            and self._half_open_epoch == slot_epoch
            and self._stats.half_open_calls > 0
        ):
            self._stats.half_open_calls -= 1

    async def execute(
        self,
        func: Callable[..., Awaitable[T]],
        *args: Any,
        **kwargs: Any,
    ) -> T:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result

        Raises:
            CircuitOpenError: If circuit is open, or half-open with all probe
                slots in use
            Exception: Whatever ``func`` raises, after it is recorded as a failure
        """
        slot_epoch: int | None = None
        async with self._lock:
            # Check availability and claim a half-open slot in one step, so
            # concurrent callers cannot exceed half_open_max_calls.
            if not self.is_available():
                raise CircuitOpenError(
                    provider=self.name,
                    recovery_time=self.time_until_recovery(),
                )
            if self._state == CircuitState.HALF_OPEN:
                self._stats.half_open_calls += 1
                slot_epoch = self._half_open_epoch

        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            await self.record_failure(e)
            raise
        except BaseException:
            # Cancellation (asyncio.CancelledError) or interpreter shutdown gives
            # no verdict on provider health. Return the half-open slot; otherwise
            # enough cancelled probes would leave the circuit stuck in HALF_OPEN
            # with no free slots and no transition out.
            self._release_half_open_slot(slot_epoch)
            raise

        await self.record_success()
        return result

    def reset(self) -> None:
        """Reset circuit to closed state, discarding all statistics."""
        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        logger.info(f"Circuit {self.name} reset to CLOSED")

    def to_dict(self) -> dict[str, Any]:
        """Convert circuit state to dictionary."""
        # Read `state` first so a due OPEN -> HALF_OPEN transition is applied
        # before any field is captured; otherwise "state" and "is_available"
        # could disagree within the same snapshot.
        state = self.state
        return {
            "name": self.name,
            "state": state.value,
            "failures": self._stats.failures,
            "successes": self._stats.successes,
            "last_failure_time": self._stats.last_failure_time,
            "last_success_time": self._stats.last_success_time,
            "time_until_recovery": self.time_until_recovery(),
            "is_available": self.is_available(),
        }
|
|
|
|
|
|
class CircuitBreakerRegistry:
    """
    Registry for managing multiple circuit breakers.

    Provides centralized management of circuits for different providers/models.
    Circuits are created lazily on first lookup, configured from the
    ``circuit_*`` fields of the registry's settings.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """
        Initialize registry.

        Args:
            settings: Application settings (uses default if None)
        """
        self._settings = settings if settings is not None else get_settings()
        self._circuits: dict[str, CircuitBreaker] = {}
        self._lock = asyncio.Lock()

    def _create_circuit(self, name: str) -> CircuitBreaker:
        """Build a new circuit breaker configured from this registry's settings."""
        return CircuitBreaker(
            name=name,
            failure_threshold=self._settings.circuit_failure_threshold,
            recovery_timeout=self._settings.circuit_recovery_timeout,
            half_open_max_calls=self._settings.circuit_half_open_max_calls,
        )

    def _get_or_create(self, name: str) -> CircuitBreaker:
        """Return the circuit registered under ``name``, creating it on first use."""
        circuit = self._circuits.get(name)
        if circuit is None:
            circuit = self._create_circuit(name)
            self._circuits[name] = circuit
        return circuit

    async def get_circuit(self, name: str) -> CircuitBreaker:
        """
        Get or create a circuit breaker.

        Args:
            name: Circuit identifier (e.g., provider name)

        Returns:
            CircuitBreaker instance
        """
        async with self._lock:
            return self._get_or_create(name)

    def get_circuit_sync(self, name: str) -> CircuitBreaker:
        """
        Get or create a circuit breaker (sync version, does not take the lock).

        Args:
            name: Circuit identifier

        Returns:
            CircuitBreaker instance
        """
        return self._get_or_create(name)

    async def is_available(self, name: str) -> bool:
        """
        Check if a circuit is available (creating it if unknown).

        Args:
            name: Circuit identifier

        Returns:
            True if circuit allows requests
        """
        circuit = await self.get_circuit(name)
        return circuit.is_available()

    async def record_success(self, name: str) -> None:
        """Record success for a circuit (creating it if unknown)."""
        circuit = await self.get_circuit(name)
        await circuit.record_success()

    async def record_failure(self, name: str, error: Exception | None = None) -> None:
        """Record failure for a circuit (creating it if unknown)."""
        circuit = await self.get_circuit(name)
        await circuit.record_failure(error)

    async def reset(self, name: str) -> None:
        """Reset a specific circuit; unknown names are ignored."""
        async with self._lock:
            if name in self._circuits:
                self._circuits[name].reset()

    async def reset_all(self) -> None:
        """Reset all circuits."""
        async with self._lock:
            for circuit in self._circuits.values():
                circuit.reset()

    def get_all_states(self) -> dict[str, dict[str, Any]]:
        """Get state of all circuits, keyed by circuit name."""
        return {name: circuit.to_dict() for name, circuit in self._circuits.items()}

    def get_open_circuits(self) -> list[str]:
        """Get list of circuits that are currently open (after applying due recoveries)."""
        return [
            name
            for name, circuit in self._circuits.items()
            if circuit.state == CircuitState.OPEN
        ]

    def get_available_circuits(self) -> list[str]:
        """Get list of circuits that are available for requests."""
        return [
            name for name, circuit in self._circuits.items() if circuit.is_available()
        ]
|
|
|
|
|
|
# Global registry instance with thread-safe lazy initialization.
# None until get_circuit_registry() first builds it; reset_circuit_registry()
# sets it back to None.
_registry: CircuitBreakerRegistry | None = None
# Serializes creation and reset of _registry across threads.
_registry_lock = threading.Lock()
|
|
|
|
|
|
def get_circuit_registry() -> CircuitBreakerRegistry:
    """
    Get the global circuit breaker registry.

    Thread-safe with double-checked locking: an existing registry is returned
    without locking; creation happens under ``_registry_lock`` so concurrent
    first callers share a single instance.

    Returns:
        The process-wide CircuitBreakerRegistry (never None).
    """
    global _registry
    # Read the global exactly once. Re-reading it for the return could observe
    # a concurrent reset_circuit_registry() and hand back None.
    registry = _registry
    if registry is None:
        with _registry_lock:
            # Double-check after acquiring lock: another thread may have built it.
            if _registry is None:
                _registry = CircuitBreakerRegistry()
            registry = _registry
    return registry
|
|
|
|
|
|
def reset_circuit_registry() -> None:
    """
    Discard the global registry so the next lookup creates a fresh one.

    Intended for tests. Holds ``_registry_lock`` so the reset cannot interleave
    with creation inside ``get_circuit_registry``.
    """
    global _registry
    with _registry_lock:
        # Dropping the reference is sufficient; a new registry (and its
        # circuits) is built lazily on the next get_circuit_registry() call.
        _registry = None
|