Files
syndarix/mcp-servers/llm-gateway/failover.py
Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00

358 lines
11 KiB
Python

"""
Circuit Breaker implementation for LLM Gateway.
Provides fault tolerance by tracking provider failures and
temporarily disabling providers that are experiencing issues.
"""
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypeVar
from config import Settings, get_settings
from exceptions import CircuitOpenError
# Generic result type of the coroutine wrapped by CircuitBreaker.execute().
T = TypeVar("T")
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
class CircuitState(str, Enum):
    """Lifecycle states a circuit breaker moves through (values are plain strings)."""

    # Healthy: requests are passed straight through.
    CLOSED = "closed"
    # Tripped: the failure threshold was exceeded, so requests are blocked.
    OPEN = "open"
    # Probing: checking whether the service has recovered.
    HALF_OPEN = "half_open"
@dataclass
class CircuitStats:
"""Statistics for a circuit breaker."""
failures: int = 0
successes: int = 0
last_failure_time: float | None = None
last_success_time: float | None = None
state_changed_at: float = field(default_factory=time.time)
half_open_calls: int = 0
class CircuitBreaker:
    """
    Circuit breaker for individual providers.

    States:
    - CLOSED: Normal operation. Failures increment counter.
    - OPEN: Too many failures. Requests immediately fail.
    - HALF_OPEN: Testing recovery. Limited requests allowed.

    Transitions:
    - CLOSED -> OPEN: When failures >= threshold
    - OPEN -> HALF_OPEN: After recovery_timeout (evaluated lazily whenever
      ``state``, ``is_available()`` or ``to_dict()`` is read)
    - HALF_OPEN -> CLOSED: On success
    - HALF_OPEN -> OPEN: On failure

    NOTE(review): while CLOSED, a success does not reset ``failures``; the
    counter is cumulative until the circuit closes again (or ``reset()``),
    so sparse, non-consecutive failures eventually trip the breaker.
    Confirm this is intended rather than a consecutive-failure policy.

    NOTE(review): elapsed time uses wall-clock ``time.time()``, so a system
    clock step can shorten or lengthen the OPEN period. ``time.monotonic()``
    would avoid that, but ``stats.state_changed_at`` is publicly readable —
    confirm no consumer treats it as a wall-clock timestamp before switching.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_calls: int = 3,
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            name: Identifier for this circuit (usually provider name)
            failure_threshold: Failures before opening circuit
            recovery_timeout: Seconds before attempting recovery
            half_open_max_calls: Max calls allowed in half-open state
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls
        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        # Held around counter/state updates in record_success, record_failure
        # and execute's admission step; nothing awaits while holding it.
        self._lock = asyncio.Lock()
        # Incremented on every entry into HALF_OPEN so that a trial slot taken
        # in one half-open period is never "returned" to a later one.
        self._half_open_epoch = 0

    @property
    def state(self) -> CircuitState:
        """Get current circuit state (may trigger the OPEN -> HALF_OPEN transition)."""
        self._check_state_transition()
        return self._state

    @property
    def stats(self) -> CircuitStats:
        """Get circuit statistics (the live object, not a copy)."""
        return self._stats

    def _check_state_transition(self) -> None:
        """Move OPEN -> HALF_OPEN once recovery_timeout seconds have elapsed."""
        if self._state == CircuitState.OPEN:
            time_in_open = time.time() - self._stats.state_changed_at
            if time_in_open >= self.recovery_timeout:
                self._transition_to(CircuitState.HALF_OPEN)
                logger.info(
                    f"Circuit {self.name} transitioned to HALF_OPEN "
                    f"after {time_in_open:.1f}s"
                )

    def _transition_to(self, new_state: CircuitState) -> None:
        """Switch to new_state, stamp the time and reset per-state counters."""
        old_state = self._state
        self._state = new_state
        self._stats.state_changed_at = time.time()
        if new_state == CircuitState.HALF_OPEN:
            # A fresh probing period: no trial calls admitted yet.
            self._stats.half_open_calls = 0
            self._half_open_epoch += 1
        elif new_state == CircuitState.CLOSED:
            self._stats.failures = 0
        logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}")

    def _release_half_open_slot(self, epoch: int | None) -> None:
        """
        Give back a HALF_OPEN trial slot whose call never reported an outcome.

        Args:
            epoch: Half-open period the slot was taken in, or None if the call
                was not admitted as a trial call (nothing to release).
        """
        if (
            epoch is not None
            and epoch == self._half_open_epoch
            and self._state == CircuitState.HALF_OPEN
            and self._stats.half_open_calls > 0
        ):
            self._stats.half_open_calls -= 1

    def is_available(self) -> bool:
        """Check if circuit allows requests (may trigger OPEN -> HALF_OPEN)."""
        state = self.state  # Triggers state check
        if state == CircuitState.CLOSED:
            return True
        if state == CircuitState.HALF_OPEN:
            return self._stats.half_open_calls < self.half_open_max_calls
        return False

    def time_until_recovery(self) -> int | None:
        """
        Get whole seconds until the circuit may recover (None if not OPEN).

        Reads the stored state without triggering a transition, so an OPEN
        circuit whose timeout has already elapsed reports 0.
        """
        if self._state != CircuitState.OPEN:
            return None
        elapsed = time.time() - self._stats.state_changed_at
        return max(0, self.recovery_timeout - int(elapsed))

    async def record_success(self) -> None:
        """Record a successful call; a success while HALF_OPEN closes the circuit."""
        async with self._lock:
            self._stats.successes += 1
            self._stats.last_success_time = time.time()
            if self._state == CircuitState.HALF_OPEN:
                # Success in half-open state closes the circuit
                self._transition_to(CircuitState.CLOSED)
                logger.info(f"Circuit {self.name} closed after successful recovery")

    async def record_failure(self, error: Exception | None = None) -> None:  # noqa: ARG002
        """
        Record a failed call.

        A failure while HALF_OPEN reopens the circuit; while CLOSED it opens
        the circuit once ``failures`` reaches ``failure_threshold``.

        Args:
            error: The exception that occurred (currently not inspected)
        """
        async with self._lock:
            self._stats.failures += 1
            self._stats.last_failure_time = time.time()
            if self._state == CircuitState.HALF_OPEN:
                # Failure in half-open state opens the circuit
                self._transition_to(CircuitState.OPEN)
                logger.warning(
                    f"Circuit {self.name} reopened after failure in half-open state"
                )
            elif self._state == CircuitState.CLOSED:
                if self._stats.failures >= self.failure_threshold:
                    self._transition_to(CircuitState.OPEN)
                    logger.warning(
                        f"Circuit {self.name} opened after "
                        f"{self._stats.failures} failures"
                    )

    async def execute(
        self,
        func: Callable[..., Awaitable[T]],
        *args: Any,
        **kwargs: Any,
    ) -> T:
        """
        Execute a function with circuit breaker protection.

        The outcome is recorded automatically. A returned value counts as a
        success; any ``Exception`` counts as a failure and is re-raised.
        Cancellation and other non-``Exception`` errors are re-raised without
        being counted either way.

        Args:
            func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result

        Raises:
            CircuitOpenError: If the circuit does not currently admit requests
                (OPEN, or HALF_OPEN with every trial slot taken)
        """
        trial_epoch: int | None = None
        async with self._lock:
            # Admission check and slot accounting under a single lock hold, so
            # two callers cannot both claim the last half-open slot.
            if not self.is_available():
                raise CircuitOpenError(
                    provider=self.name,
                    recovery_time=self.time_until_recovery(),
                )
            if self._state == CircuitState.HALF_OPEN:
                self._stats.half_open_calls += 1
                trial_epoch = self._half_open_epoch

        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            await self.record_failure(e)
            raise
        except BaseException:
            # e.g. asyncio.CancelledError: no verdict on the provider. Without
            # this, each cancelled trial call would permanently consume a
            # half-open slot. Once all slots were gone, the circuit would stay
            # HALF_OPEN and reject every request, since only OPEN has a
            # time-based exit.
            self._release_half_open_slot(trial_epoch)
            raise

        await self.record_success()
        return result

    def reset(self) -> None:
        """Reset circuit to closed state with fresh statistics."""
        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        logger.info(f"Circuit {self.name} reset to CLOSED")

    def to_dict(self) -> dict[str, Any]:
        """Convert circuit state to a dictionary snapshot."""
        # Apply any pending OPEN -> HALF_OPEN transition first. Otherwise
        # "state" could say "open" while "is_available" (evaluated later)
        # performs the transition and reports True.
        state = self.state
        return {
            "name": self.name,
            "state": state.value,
            "failures": self._stats.failures,
            "successes": self._stats.successes,
            "last_failure_time": self._stats.last_failure_time,
            "last_success_time": self._stats.last_success_time,
            "time_until_recovery": self.time_until_recovery(),
            "is_available": self.is_available(),
        }
class CircuitBreakerRegistry:
    """
    Registry for managing multiple circuit breakers.

    Provides centralized management of circuits for different providers/models.
    Circuits are created on first lookup and configured from the registry's
    settings (circuit_failure_threshold, circuit_recovery_timeout,
    circuit_half_open_max_calls).
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """
        Initialize registry.

        Args:
            settings: Application settings (uses get_settings() if None)
        """
        self._settings = settings if settings is not None else get_settings()
        self._circuits: dict[str, CircuitBreaker] = {}
        self._lock = asyncio.Lock()

    def _create_circuit(self, name: str) -> CircuitBreaker:
        """Build a new breaker using the thresholds from this registry's settings."""
        return CircuitBreaker(
            name=name,
            failure_threshold=self._settings.circuit_failure_threshold,
            recovery_timeout=self._settings.circuit_recovery_timeout,
            half_open_max_calls=self._settings.circuit_half_open_max_calls,
        )

    def _get_or_create(self, name: str) -> CircuitBreaker:
        """Return the named breaker, creating and storing it if absent."""
        circuit = self._circuits.get(name)
        if circuit is None:
            circuit = self._create_circuit(name)
            self._circuits[name] = circuit
        return circuit

    async def get_circuit(self, name: str) -> CircuitBreaker:
        """
        Get or create a circuit breaker.

        Args:
            name: Circuit identifier (e.g., provider name)

        Returns:
            CircuitBreaker instance
        """
        async with self._lock:
            return self._get_or_create(name)

    def get_circuit_sync(self, name: str) -> CircuitBreaker:
        """
        Get or create a circuit breaker (sync version, does not take the lock).

        Args:
            name: Circuit identifier

        Returns:
            CircuitBreaker instance
        """
        return self._get_or_create(name)

    async def is_available(self, name: str) -> bool:
        """
        Check if a circuit is available (creating it if needed).

        Args:
            name: Circuit identifier

        Returns:
            True if circuit allows requests
        """
        circuit = await self.get_circuit(name)
        return circuit.is_available()

    async def record_success(self, name: str) -> None:
        """Record success for a circuit (creating it if needed)."""
        circuit = await self.get_circuit(name)
        await circuit.record_success()

    async def record_failure(self, name: str, error: Exception | None = None) -> None:
        """Record failure for a circuit (creating it if needed)."""
        circuit = await self.get_circuit(name)
        await circuit.record_failure(error)

    async def reset(self, name: str) -> None:
        """Reset a specific circuit; unknown names are ignored."""
        async with self._lock:
            if name in self._circuits:
                self._circuits[name].reset()

    async def reset_all(self) -> None:
        """Reset all circuits."""
        async with self._lock:
            for circuit in self._circuits.values():
                circuit.reset()

    def get_all_states(self) -> dict[str, dict[str, Any]]:
        """Get state of all circuits, keyed by circuit name."""
        return {name: circuit.to_dict() for name, circuit in self._circuits.items()}

    def get_open_circuits(self) -> list[str]:
        """Get list of circuits that are currently open."""
        return [
            name
            for name, circuit in self._circuits.items()
            if circuit.state == CircuitState.OPEN
        ]

    def get_available_circuits(self) -> list[str]:
        """Get list of circuits that are available for requests."""
        return [
            name for name, circuit in self._circuits.items() if circuit.is_available()
        ]
# Process-wide registry; stays None until get_circuit_registry() first runs.
_registry: CircuitBreakerRegistry | None = None


def get_circuit_registry() -> CircuitBreakerRegistry:
    """
    Return the shared circuit breaker registry.

    The registry is built on the first call with CircuitBreakerRegistry()
    (default settings) and the same instance is returned afterwards.
    """
    global _registry
    if _registry is not None:
        return _registry
    _registry = CircuitBreakerRegistry()
    return _registry
def reset_circuit_registry() -> None:
    """
    Forget the shared registry (intended for tests).

    The next get_circuit_registry() call will build a fresh one.
    """
    global _registry
    _registry = None