Files
syndarix/mcp-servers/llm-gateway/failover.py
Felipe Cardoso 95342cc94d fix(mcp-gateway): address critical issues from deep review
Frontend:
- Fix debounce race condition in UserListTable search handler
- Use useRef to properly track and cleanup timeout between keystrokes

Backend (LLM Gateway):
- Add thread-safe double-checked locking for global singletons
  (providers, circuit registry, cost tracker)
- Fix Redis URL parsing with proper urlparse validation
- Add explicit error handling for malformed Redis URLs
- Document circuit breaker state transition safety

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:36:55 +01:00

382 lines
12 KiB
Python

"""
Circuit Breaker implementation for LLM Gateway.
Provides fault tolerance by tracking provider failures and
temporarily disabling providers that are experiencing issues.
"""
import asyncio
import logging
import threading
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypeVar
from config import Settings, get_settings
from exceptions import CircuitOpenError
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic result type of the awaitable passed to CircuitBreaker.execute.
T = TypeVar("T")
class CircuitState(str, Enum):
    """
    The three states a circuit breaker can be in.

    Members are also plain strings, so they compare equal to their values.
    """

    # Requests flow through normally.
    CLOSED = "closed"
    # Failure threshold was exceeded; requests are rejected.
    OPEN = "open"
    # Recovery is being probed with a limited number of requests.
    HALF_OPEN = "half_open"
@dataclass
class CircuitStats:
"""Statistics for a circuit breaker."""
failures: int = 0
successes: int = 0
last_failure_time: float | None = None
last_success_time: float | None = None
state_changed_at: float = field(default_factory=time.time)
half_open_calls: int = 0
class CircuitBreaker:
    """
    Circuit breaker for individual providers.

    States:
    - CLOSED: Normal operation. Failures increment counter.
    - OPEN: Too many failures. Requests immediately fail.
    - HALF_OPEN: Testing recovery. Limited requests allowed.

    Transitions:
    - CLOSED -> OPEN: When failures >= threshold
    - OPEN -> HALF_OPEN: After recovery_timeout
    - HALF_OPEN -> CLOSED: On success
    - HALF_OPEN -> OPEN: On failure

    Note: the failure counter is cleared only when the circuit transitions
    to CLOSED or on reset(); successes recorded while CLOSED do not clear it.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_calls: int = 3,
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            name: Identifier for this circuit (usually provider name)
            failure_threshold: Failures before opening circuit
            recovery_timeout: Seconds before attempting recovery
            half_open_max_calls: Max calls allowed in half-open state
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls
        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        self._lock = asyncio.Lock()

    @property
    def state(self) -> CircuitState:
        """Get current circuit state (may trigger state transition)."""
        # The time-based OPEN -> HALF_OPEN check is synchronous and does not
        # take the asyncio lock; see _check_state_transition.
        self._check_state_transition()
        return self._state

    @property
    def stats(self) -> CircuitStats:
        """Get circuit statistics (the live object, not a copy)."""
        return self._stats

    def _check_state_transition(self) -> None:
        """
        Move OPEN -> HALF_OPEN once recovery_timeout seconds have elapsed.

        Note: This method is intentionally synchronous and does not acquire
        the asyncio lock because it is called from the state property. It
        contains no await, so within one event loop it runs to completion
        without interleaving with other coroutines. It is not synchronized
        against other threads.
        """
        if self._state != CircuitState.OPEN:
            return
        time_in_open = time.time() - self._stats.state_changed_at
        if time_in_open < self.recovery_timeout:
            return
        # Cheap re-check right before mutating; _transition_to itself is
        # unconditional, so the guard has to live here.
        if self._state == CircuitState.OPEN:
            self._transition_to(CircuitState.HALF_OPEN)
            logger.info(
                f"Circuit {self.name} transitioned to HALF_OPEN "
                f"after {time_in_open:.1f}s"
            )

    def _transition_to(self, new_state: CircuitState) -> None:
        """Unconditionally switch to new_state and reset per-state counters."""
        old_state = self._state
        self._state = new_state
        self._stats.state_changed_at = time.time()
        if new_state == CircuitState.HALF_OPEN:
            # A new probing window starts with no admitted calls.
            self._stats.half_open_calls = 0
        elif new_state == CircuitState.CLOSED:
            self._stats.failures = 0
        logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}")

    def is_available(self) -> bool:
        """
        Check if circuit allows requests.

        Returns:
            True when CLOSED, or when HALF_OPEN with probe slots remaining;
            False otherwise.
        """
        state = self.state  # Triggers state check
        if state == CircuitState.CLOSED:
            return True
        if state == CircuitState.HALF_OPEN:
            return self._stats.half_open_calls < self.half_open_max_calls
        return False

    def time_until_recovery(self) -> int | None:
        """
        Get seconds until circuit may recover (None if not open).

        Reads the stored state without triggering a transition, so it
        returns 0 when the timeout has elapsed but the circuit has not yet
        been moved to HALF_OPEN.
        """
        if self._state != CircuitState.OPEN:
            return None
        elapsed = time.time() - self._stats.state_changed_at
        return max(0, self.recovery_timeout - int(elapsed))

    async def record_success(self) -> None:
        """Record a successful call; closes the circuit if it was half-open."""
        async with self._lock:
            self._stats.successes += 1
            self._stats.last_success_time = time.time()
            if self._state == CircuitState.HALF_OPEN:
                # Success in half-open state closes the circuit
                self._transition_to(CircuitState.CLOSED)
                logger.info(f"Circuit {self.name} closed after successful recovery")

    async def record_failure(self, error: Exception | None = None) -> None:  # noqa: ARG002
        """
        Record a failed call.

        Reopens a half-open circuit immediately; opens a closed circuit once
        the failure count reaches failure_threshold.

        Args:
            error: The exception that occurred (currently unused)
        """
        async with self._lock:
            self._stats.failures += 1
            self._stats.last_failure_time = time.time()
            if self._state == CircuitState.HALF_OPEN:
                # Failure in half-open state opens the circuit
                self._transition_to(CircuitState.OPEN)
                logger.warning(
                    f"Circuit {self.name} reopened after failure in half-open state"
                )
            elif self._state == CircuitState.CLOSED:
                if self._stats.failures >= self.failure_threshold:
                    self._transition_to(CircuitState.OPEN)
                    logger.warning(
                        f"Circuit {self.name} opened after "
                        f"{self._stats.failures} failures"
                    )

    def _release_half_open_slot(self, window_started_at: float) -> None:
        """Give back a probe slot claimed in the half-open window that began at window_started_at."""
        stats = self._stats
        # Only release if the same half-open window is still active; a state
        # change or reset() in the meantime has already reset the counter.
        if (
            self._state == CircuitState.HALF_OPEN
            and stats.state_changed_at == window_started_at
            and stats.half_open_calls > 0
        ):
            stats.half_open_calls -= 1

    async def execute(
        self,
        func: Callable[..., Awaitable[T]],
        *args: Any,
        **kwargs: Any,
    ) -> T:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result

        Raises:
            CircuitOpenError: If the circuit is open, or half-open with no
                probe slots left
            Exception: Whatever func raises is recorded as a failure and
                re-raised
        """
        if not self.is_available():
            raise CircuitOpenError(
                provider=self.name,
                recovery_time=self.time_until_recovery(),
            )

        # Start time of the half-open window in which this call claimed a
        # probe slot; None when no slot was claimed.
        probe_window: float | None = None
        async with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._stats.half_open_calls += 1
                probe_window = self._stats.state_changed_at

        try:
            result = await func(*args, **kwargs)
            await self.record_success()
            return result
        except Exception as e:
            await self.record_failure(e)
            raise
        except BaseException:
            # Cancellation (or another non-Exception exit) is neither a
            # success nor a failure. Without returning the slot, repeated
            # cancellations would exhaust half_open_max_calls and leave the
            # circuit HALF_OPEN and unavailable, since nothing time-based
            # leaves that state. Done synchronously: no critical section in
            # this class awaits while holding the lock, and awaiting here
            # could itself be interrupted by a further cancellation.
            if probe_window is not None:
                self._release_half_open_slot(probe_window)
            raise

    def reset(self) -> None:
        """Reset circuit to closed state with fresh statistics."""
        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        logger.info(f"Circuit {self.name} reset to CLOSED")

    def to_dict(self) -> dict[str, Any]:
        """
        Convert circuit state to dictionary.

        Any due OPEN -> HALF_OPEN transition is applied before the fields
        are read, so "state", "time_until_recovery" and "is_available" all
        describe the same state.
        """
        current_state = self.state
        return {
            "name": self.name,
            "state": current_state.value,
            "failures": self._stats.failures,
            "successes": self._stats.successes,
            "last_failure_time": self._stats.last_failure_time,
            "last_success_time": self._stats.last_success_time,
            "time_until_recovery": self.time_until_recovery(),
            "is_available": self.is_available(),
        }
class CircuitBreakerRegistry:
    """
    Keeps one CircuitBreaker per name and offers bulk operations on them.

    Breakers are created on first lookup using the thresholds found on the
    settings object.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """
        Create an empty registry.

        Args:
            settings: Application settings; get_settings() is used when
                this is None (or otherwise falsy)
        """
        self._settings = settings or get_settings()
        self._circuits: dict[str, CircuitBreaker] = {}
        self._lock = asyncio.Lock()

    def _build_circuit(self, name: str) -> CircuitBreaker:
        """Construct a new breaker for name from the registry settings."""
        cfg = self._settings
        return CircuitBreaker(
            name=name,
            failure_threshold=cfg.circuit_failure_threshold,
            recovery_timeout=cfg.circuit_recovery_timeout,
            half_open_max_calls=cfg.circuit_half_open_max_calls,
        )

    async def get_circuit(self, name: str) -> CircuitBreaker:
        """
        Return the breaker registered under name, creating it if missing.

        The lookup and creation happen while holding the registry lock.

        Args:
            name: Circuit identifier (e.g., provider name)

        Returns:
            CircuitBreaker instance
        """
        async with self._lock:
            if name in self._circuits:
                return self._circuits[name]
            created = self._build_circuit(name)
            self._circuits[name] = created
            return created

    def get_circuit_sync(self, name: str) -> CircuitBreaker:
        """
        Synchronous variant of get_circuit; does not take the registry lock.

        Args:
            name: Circuit identifier

        Returns:
            CircuitBreaker instance
        """
        if name in self._circuits:
            return self._circuits[name]
        created = self._build_circuit(name)
        self._circuits[name] = created
        return created

    async def is_available(self, name: str) -> bool:
        """
        Report whether the named circuit currently admits requests.

        The circuit is created first if it does not exist yet.

        Args:
            name: Circuit identifier

        Returns:
            True if circuit allows requests
        """
        breaker = await self.get_circuit(name)
        return breaker.is_available()

    async def record_success(self, name: str) -> None:
        """Forward a success to the named circuit (created if missing)."""
        breaker = await self.get_circuit(name)
        await breaker.record_success()

    async def record_failure(self, name: str, error: Exception | None = None) -> None:
        """Forward a failure to the named circuit (created if missing)."""
        breaker = await self.get_circuit(name)
        await breaker.record_failure(error)

    async def reset(self, name: str) -> None:
        """Reset the named circuit; unknown names are ignored."""
        async with self._lock:
            if name not in self._circuits:
                return
            self._circuits[name].reset()

    async def reset_all(self) -> None:
        """Reset every registered circuit."""
        async with self._lock:
            for breaker in self._circuits.values():
                breaker.reset()

    def get_all_states(self) -> dict[str, dict[str, Any]]:
        """Return each circuit's to_dict() output, keyed by circuit name."""
        snapshot: dict[str, dict[str, Any]] = {}
        for name, breaker in self._circuits.items():
            snapshot[name] = breaker.to_dict()
        return snapshot

    def get_open_circuits(self) -> list[str]:
        """Return the names of circuits whose state is OPEN."""
        open_names: list[str] = []
        for name, breaker in self._circuits.items():
            if breaker.state == CircuitState.OPEN:
                open_names.append(name)
        return open_names

    def get_available_circuits(self) -> list[str]:
        """Return the names of circuits that currently admit requests."""
        usable: list[str] = []
        for name, breaker in self._circuits.items():
            if breaker.is_available():
                usable.append(name)
        return usable
# Global registry instance; stays None until get_circuit_registry() creates it.
_registry: CircuitBreakerRegistry | None = None
# Thread lock held while the global registry is created or cleared.
_registry_lock = threading.Lock()
def get_circuit_registry() -> CircuitBreakerRegistry:
    """
    Return the process-wide circuit breaker registry, creating it on first use.

    The common case (already created) returns without taking the lock;
    creation re-checks under _registry_lock so only one instance is built.
    """
    global _registry
    if _registry is not None:
        # Fast path: already initialized.
        return _registry
    with _registry_lock:
        # Another thread may have created it while we waited for the lock.
        if _registry is None:
            _registry = CircuitBreakerRegistry()
    return _registry
def reset_circuit_registry() -> None:
    """
    Drop the global registry so the next lookup builds a new one (for testing).

    The assignment is made while holding _registry_lock.
    """
    global _registry
    with _registry_lock:
        _registry = None