feat(llm-gateway): implement LLM Gateway MCP Server (#56)
Implements complete LLM Gateway MCP Server with: - FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens - LiteLLM Router with multi-provider failover chains - Circuit breaker pattern for fault tolerance - Redis-based cost tracking per project/agent - Comprehensive test suite (209 tests, 92% coverage) Model groups defined per ADR-004: - reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro - code: claude-sonnet-4 → gpt-4.1 → deepseek-coder - fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
357
mcp-servers/llm-gateway/failover.py
Normal file
357
mcp-servers/llm-gateway/failover.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Circuit Breaker implementation for LLM Gateway.
|
||||
|
||||
Provides fault tolerance by tracking provider failures and
|
||||
temporarily disabling providers that are experiencing issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import CircuitOpenError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class CircuitState(str, Enum):
|
||||
"""Circuit breaker states."""
|
||||
|
||||
CLOSED = "closed" # Normal operation, requests pass through
|
||||
OPEN = "open" # Failures exceeded threshold, requests blocked
|
||||
HALF_OPEN = "half_open" # Testing if service recovered
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitStats:
|
||||
"""Statistics for a circuit breaker."""
|
||||
|
||||
failures: int = 0
|
||||
successes: int = 0
|
||||
last_failure_time: float | None = None
|
||||
last_success_time: float | None = None
|
||||
state_changed_at: float = field(default_factory=time.time)
|
||||
half_open_calls: int = 0
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker for individual providers.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation. Failures increment counter.
|
||||
- OPEN: Too many failures. Requests immediately fail.
|
||||
- HALF_OPEN: Testing recovery. Limited requests allowed.
|
||||
|
||||
Transitions:
|
||||
- CLOSED -> OPEN: When failures >= threshold
|
||||
- OPEN -> HALF_OPEN: After recovery_timeout
|
||||
- HALF_OPEN -> CLOSED: On success
|
||||
- HALF_OPEN -> OPEN: On failure
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: int = 60,
|
||||
half_open_max_calls: int = 3,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
name: Identifier for this circuit (usually provider name)
|
||||
failure_threshold: Failures before opening circuit
|
||||
recovery_timeout: Seconds before attempting recovery
|
||||
half_open_max_calls: Max calls allowed in half-open state
|
||||
"""
|
||||
self.name = name
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
self.half_open_max_calls = half_open_max_calls
|
||||
|
||||
self._state = CircuitState.CLOSED
|
||||
self._stats = CircuitStats()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
"""Get current circuit state (may trigger state transition)."""
|
||||
self._check_state_transition()
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def stats(self) -> CircuitStats:
|
||||
"""Get circuit statistics."""
|
||||
return self._stats
|
||||
|
||||
def _check_state_transition(self) -> None:
|
||||
"""Check if state should transition based on time."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
time_in_open = time.time() - self._stats.state_changed_at
|
||||
if time_in_open >= self.recovery_timeout:
|
||||
self._transition_to(CircuitState.HALF_OPEN)
|
||||
logger.info(
|
||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||
f"after {time_in_open:.1f}s"
|
||||
)
|
||||
|
||||
def _transition_to(self, new_state: CircuitState) -> None:
|
||||
"""Transition to a new state."""
|
||||
old_state = self._state
|
||||
self._state = new_state
|
||||
self._stats.state_changed_at = time.time()
|
||||
|
||||
if new_state == CircuitState.HALF_OPEN:
|
||||
self._stats.half_open_calls = 0
|
||||
elif new_state == CircuitState.CLOSED:
|
||||
self._stats.failures = 0
|
||||
|
||||
logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}")
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if circuit allows requests."""
|
||||
state = self.state # Triggers state check
|
||||
if state == CircuitState.CLOSED:
|
||||
return True
|
||||
if state == CircuitState.HALF_OPEN:
|
||||
return self._stats.half_open_calls < self.half_open_max_calls
|
||||
return False
|
||||
|
||||
def time_until_recovery(self) -> int | None:
|
||||
"""Get seconds until circuit may recover (None if not open)."""
|
||||
if self._state != CircuitState.OPEN:
|
||||
return None
|
||||
elapsed = time.time() - self._stats.state_changed_at
|
||||
remaining = max(0, self.recovery_timeout - int(elapsed))
|
||||
return remaining if remaining > 0 else 0
|
||||
|
||||
async def record_success(self) -> None:
|
||||
"""Record a successful call."""
|
||||
async with self._lock:
|
||||
self._stats.successes += 1
|
||||
self._stats.last_success_time = time.time()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# Success in half-open state closes the circuit
|
||||
self._transition_to(CircuitState.CLOSED)
|
||||
logger.info(f"Circuit {self.name} closed after successful recovery")
|
||||
|
||||
async def record_failure(self, error: Exception | None = None) -> None: # noqa: ARG002
|
||||
"""
|
||||
Record a failed call.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred
|
||||
"""
|
||||
async with self._lock:
|
||||
self._stats.failures += 1
|
||||
self._stats.last_failure_time = time.time()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# Failure in half-open state opens the circuit
|
||||
self._transition_to(CircuitState.OPEN)
|
||||
logger.warning(
|
||||
f"Circuit {self.name} reopened after failure in half-open state"
|
||||
)
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
if self._stats.failures >= self.failure_threshold:
|
||||
self._transition_to(CircuitState.OPEN)
|
||||
logger.warning(
|
||||
f"Circuit {self.name} opened after "
|
||||
f"{self._stats.failures} failures"
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
func: Callable[..., T],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Execute a function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Function result
|
||||
|
||||
Raises:
|
||||
CircuitOpenError: If circuit is open
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise CircuitOpenError(
|
||||
provider=self.name,
|
||||
recovery_time=self.time_until_recovery(),
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
self._stats.half_open_calls += 1
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
await self.record_success()
|
||||
return result
|
||||
except Exception as e:
|
||||
await self.record_failure(e)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset circuit to closed state."""
|
||||
self._state = CircuitState.CLOSED
|
||||
self._stats = CircuitStats()
|
||||
logger.info(f"Circuit {self.name} reset to CLOSED")
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert circuit state to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"state": self._state.value,
|
||||
"failures": self._stats.failures,
|
||||
"successes": self._stats.successes,
|
||||
"last_failure_time": self._stats.last_failure_time,
|
||||
"last_success_time": self._stats.last_success_time,
|
||||
"time_until_recovery": self.time_until_recovery(),
|
||||
"is_available": self.is_available(),
|
||||
}
|
||||
|
||||
|
||||
class CircuitBreakerRegistry:
|
||||
"""
|
||||
Registry for managing multiple circuit breakers.
|
||||
|
||||
Provides centralized management of circuits for different providers/models.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings | None = None) -> None:
|
||||
"""
|
||||
Initialize registry.
|
||||
|
||||
Args:
|
||||
settings: Application settings (uses default if None)
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self._circuits: dict[str, CircuitBreaker] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_circuit(self, name: str) -> CircuitBreaker:
|
||||
"""
|
||||
Get or create a circuit breaker.
|
||||
|
||||
Args:
|
||||
name: Circuit identifier (e.g., provider name)
|
||||
|
||||
Returns:
|
||||
CircuitBreaker instance
|
||||
"""
|
||||
async with self._lock:
|
||||
if name not in self._circuits:
|
||||
self._circuits[name] = CircuitBreaker(
|
||||
name=name,
|
||||
failure_threshold=self._settings.circuit_failure_threshold,
|
||||
recovery_timeout=self._settings.circuit_recovery_timeout,
|
||||
half_open_max_calls=self._settings.circuit_half_open_max_calls,
|
||||
)
|
||||
return self._circuits[name]
|
||||
|
||||
def get_circuit_sync(self, name: str) -> CircuitBreaker:
|
||||
"""
|
||||
Get or create a circuit breaker (sync version).
|
||||
|
||||
Args:
|
||||
name: Circuit identifier
|
||||
|
||||
Returns:
|
||||
CircuitBreaker instance
|
||||
"""
|
||||
if name not in self._circuits:
|
||||
self._circuits[name] = CircuitBreaker(
|
||||
name=name,
|
||||
failure_threshold=self._settings.circuit_failure_threshold,
|
||||
recovery_timeout=self._settings.circuit_recovery_timeout,
|
||||
half_open_max_calls=self._settings.circuit_half_open_max_calls,
|
||||
)
|
||||
return self._circuits[name]
|
||||
|
||||
async def is_available(self, name: str) -> bool:
|
||||
"""
|
||||
Check if a circuit is available.
|
||||
|
||||
Args:
|
||||
name: Circuit identifier
|
||||
|
||||
Returns:
|
||||
True if circuit allows requests
|
||||
"""
|
||||
circuit = await self.get_circuit(name)
|
||||
return circuit.is_available()
|
||||
|
||||
async def record_success(self, name: str) -> None:
|
||||
"""Record success for a circuit."""
|
||||
circuit = await self.get_circuit(name)
|
||||
await circuit.record_success()
|
||||
|
||||
async def record_failure(self, name: str, error: Exception | None = None) -> None:
|
||||
"""Record failure for a circuit."""
|
||||
circuit = await self.get_circuit(name)
|
||||
await circuit.record_failure(error)
|
||||
|
||||
async def reset(self, name: str) -> None:
|
||||
"""Reset a specific circuit."""
|
||||
async with self._lock:
|
||||
if name in self._circuits:
|
||||
self._circuits[name].reset()
|
||||
|
||||
async def reset_all(self) -> None:
|
||||
"""Reset all circuits."""
|
||||
async with self._lock:
|
||||
for circuit in self._circuits.values():
|
||||
circuit.reset()
|
||||
|
||||
def get_all_states(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get state of all circuits."""
|
||||
return {name: circuit.to_dict() for name, circuit in self._circuits.items()}
|
||||
|
||||
def get_open_circuits(self) -> list[str]:
|
||||
"""Get list of circuits that are currently open."""
|
||||
return [
|
||||
name
|
||||
for name, circuit in self._circuits.items()
|
||||
if circuit.state == CircuitState.OPEN
|
||||
]
|
||||
|
||||
def get_available_circuits(self) -> list[str]:
|
||||
"""Get list of circuits that are available for requests."""
|
||||
return [
|
||||
name for name, circuit in self._circuits.items() if circuit.is_available()
|
||||
]
|
||||
|
||||
|
||||
# Global registry instance (lazy initialization)
|
||||
_registry: CircuitBreakerRegistry | None = None
|
||||
|
||||
|
||||
def get_circuit_registry() -> CircuitBreakerRegistry:
|
||||
"""Get the global circuit breaker registry."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = CircuitBreakerRegistry()
|
||||
return _registry
|
||||
|
||||
|
||||
def reset_circuit_registry() -> None:
|
||||
"""Reset the global registry (for testing)."""
|
||||
global _registry
|
||||
_registry = None
|
||||
Reference in New Issue
Block a user