"""
Tests for failover module (circuit breaker).
"""

import asyncio
import time
from collections.abc import Iterator

import pytest

from config import Settings
from exceptions import CircuitOpenError
from failover import (
    CircuitBreaker,
    CircuitBreakerRegistry,
    CircuitState,
    CircuitStats,
    get_circuit_registry,
    reset_circuit_registry,
)


class TestCircuitState:
    """Tests for CircuitState enum."""

    def test_circuit_states(self) -> None:
        """Test circuit state values."""
        assert CircuitState.CLOSED.value == "closed"
        assert CircuitState.OPEN.value == "open"
        assert CircuitState.HALF_OPEN.value == "half_open"


class TestCircuitStats:
    """Tests for CircuitStats dataclass."""

    def test_default_stats(self) -> None:
        """Test default stats values."""
        stats = CircuitStats()
        assert stats.failures == 0
        assert stats.successes == 0
        assert stats.last_failure_time is None
        assert stats.last_success_time is None
        assert stats.half_open_calls == 0


class TestCircuitBreaker:
    """Tests for CircuitBreaker class."""

    def test_initial_state(self) -> None:
        """Test circuit breaker initial state."""
        cb = CircuitBreaker(name="test", failure_threshold=5)
        assert cb.name == "test"
        assert cb.state == CircuitState.CLOSED
        assert cb.failure_threshold == 5
        assert cb.is_available() is True

    def test_state_remains_closed_below_threshold(self) -> None:
        """Test circuit stays closed below failure threshold."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        # Record 2 failures (below threshold)
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_failure())

        assert cb.state == CircuitState.CLOSED
        assert cb.stats.failures == 2
        assert cb.is_available() is True

    def test_state_opens_at_threshold(self) -> None:
        """Test circuit opens at failure threshold."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        # Record 3 failures (at threshold)
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_failure())

        assert cb.state == CircuitState.OPEN
        assert cb.is_available() is False

    def test_success_resets_in_closed(self) -> None:
        """Test a success in the closed state is recorded and keeps it closed."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_success())

        assert cb.state == CircuitState.CLOSED
        assert cb.stats.successes == 1
        assert cb.stats.last_success_time is not None

    def test_half_open_transition(self) -> None:
        """Test transition to half-open after recovery timeout."""
        # Short timeout (same values as test_half_open_failure_reopens) keeps
        # the suite fast while still exercising a real, non-zero wait.
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0.05,
        )

        # Open the circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Wait comfortably past the recovery timeout
        time.sleep(0.1)

        # State should transition to half-open
        assert cb.state == CircuitState.HALF_OPEN
        assert cb.is_available() is True

    def test_half_open_success_closes(self) -> None:
        """Test success in half-open closes circuit."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0,  # Immediate recovery for testing
        )

        # Open and transition to half-open (reading .state triggers the check)
        asyncio.run(cb.record_failure())
        time.sleep(0.1)
        assert cb.state == CircuitState.HALF_OPEN

        # Success should close
        asyncio.run(cb.record_success())
        assert cb.state == CircuitState.CLOSED

    def test_half_open_failure_reopens(self) -> None:
        """Test failure in half-open reopens circuit."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0.05,  # Small but non-zero for reliable timing
        )

        # Open and transition to half-open
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Wait for recovery timeout
        time.sleep(0.1)
        assert cb.state == CircuitState.HALF_OPEN

        # Failure should reopen
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

    def test_half_open_call_limit(self) -> None:
        """Test half-open call limit."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=0,
            half_open_max_calls=2,
        )

        # Open and transition to half-open (reading .state triggers the check)
        asyncio.run(cb.record_failure())
        time.sleep(0.1)
        assert cb.state == CircuitState.HALF_OPEN
        assert cb.is_available() is True

        # Simulate in-flight trial calls by setting the private counter
        # directly, rather than running real calls through execute().
        cb._stats.half_open_calls = 1
        assert cb.is_available() is True

        cb._stats.half_open_calls = 2
        assert cb.is_available() is False

    def test_time_until_recovery(self) -> None:
        """Test time until recovery calculation."""
        cb = CircuitBreaker(
            name="test",
            failure_threshold=1,
            recovery_timeout=60,
        )

        # Closed circuit has no recovery time
        assert cb.time_until_recovery() is None

        # Open circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Should have recovery time
        remaining = cb.time_until_recovery()
        assert remaining is not None
        assert 0 <= remaining <= 60

    def test_execute_success(self) -> None:
        """Test execute with successful function."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        async def success_func() -> str:
            return "success"

        result = asyncio.run(cb.execute(success_func))
        assert result == "success"
        assert cb.stats.successes == 1

    def test_execute_failure(self) -> None:
        """Test execute with failing function."""
        cb = CircuitBreaker(name="test", failure_threshold=3)

        async def fail_func() -> None:
            raise ValueError("Error")

        with pytest.raises(ValueError):
            asyncio.run(cb.execute(fail_func))

        assert cb.stats.failures == 1

    def test_execute_when_open(self) -> None:
        """Test execute raises when circuit is open."""
        cb = CircuitBreaker(name="test", failure_threshold=1)

        # Open the circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        async def success_func() -> str:
            return "success"

        with pytest.raises(CircuitOpenError) as exc_info:
            asyncio.run(cb.execute(success_func))

        assert exc_info.value.provider == "test"

    def test_reset(self) -> None:
        """Test circuit reset."""
        cb = CircuitBreaker(name="test", failure_threshold=1)

        # Open the circuit
        asyncio.run(cb.record_failure())
        assert cb.state == CircuitState.OPEN

        # Reset
        cb.reset()
        assert cb.state == CircuitState.CLOSED
        assert cb.stats.failures == 0
        assert cb.stats.successes == 0

    def test_to_dict(self) -> None:
        """Test converting circuit to dict."""
        cb = CircuitBreaker(name="test", failure_threshold=3)
        asyncio.run(cb.record_failure())
        asyncio.run(cb.record_success())

        result = cb.to_dict()
        assert result["name"] == "test"
        assert result["state"] == "closed"
        assert result["failures"] == 1
        assert result["successes"] == 1
        assert result["is_available"] is True


class TestCircuitBreakerRegistry:
    """Tests for CircuitBreakerRegistry class."""

    def test_get_circuit_creates_new(self) -> None:
        """Test getting a new circuit."""
        settings = Settings(circuit_failure_threshold=5)
        registry = CircuitBreakerRegistry(settings=settings)

        circuit = asyncio.run(registry.get_circuit("anthropic"))
        assert circuit.name == "anthropic"
        assert circuit.failure_threshold == 5

    def test_get_circuit_returns_same(self) -> None:
        """Test getting same circuit twice."""
        registry = CircuitBreakerRegistry()

        circuit1 = asyncio.run(registry.get_circuit("openai"))
        circuit2 = asyncio.run(registry.get_circuit("openai"))

        assert circuit1 is circuit2

    def test_get_circuit_sync(self) -> None:
        """Test sync circuit getter."""
        registry = CircuitBreakerRegistry()

        circuit = registry.get_circuit_sync("google")
        assert circuit.name == "google"

    def test_is_available(self) -> None:
        """Test checking if circuit is available."""
        registry = CircuitBreakerRegistry()

        assert asyncio.run(registry.is_available("test")) is True

        # Open the circuit. Use the circuit's own threshold rather than a
        # hard-coded count so the test does not depend on Settings defaults.
        circuit = asyncio.run(registry.get_circuit("test"))
        for _ in range(circuit.failure_threshold):
            asyncio.run(circuit.record_failure())

        assert asyncio.run(registry.is_available("test")) is False

    def test_record_success(self) -> None:
        """Test recording success through registry."""
        registry = CircuitBreakerRegistry()

        asyncio.run(registry.record_success("test"))

        circuit = asyncio.run(registry.get_circuit("test"))
        assert circuit.stats.successes == 1

    def test_record_failure(self) -> None:
        """Test recording failure through registry."""
        registry = CircuitBreakerRegistry()

        asyncio.run(registry.record_failure("test"))

        circuit = asyncio.run(registry.get_circuit("test"))
        assert circuit.stats.failures == 1

    def test_reset(self) -> None:
        """Test resetting a specific circuit."""
        registry = CircuitBreakerRegistry()

        # Create and fail a circuit
        asyncio.run(registry.record_failure("test"))
        asyncio.run(registry.reset("test"))

        circuit = asyncio.run(registry.get_circuit("test"))
        assert circuit.stats.failures == 0

    def test_reset_all(self) -> None:
        """Test resetting all circuits."""
        registry = CircuitBreakerRegistry()

        # Create multiple circuits with failures
        asyncio.run(registry.record_failure("circuit1"))
        asyncio.run(registry.record_failure("circuit2"))

        asyncio.run(registry.reset_all())

        circuit1 = asyncio.run(registry.get_circuit("circuit1"))
        circuit2 = asyncio.run(registry.get_circuit("circuit2"))
        assert circuit1.stats.failures == 0
        assert circuit2.stats.failures == 0

    def test_get_all_states(self) -> None:
        """Test getting all circuit states."""
        registry = CircuitBreakerRegistry()

        asyncio.run(registry.get_circuit("circuit1"))
        asyncio.run(registry.get_circuit("circuit2"))

        states = registry.get_all_states()
        assert "circuit1" in states
        assert "circuit2" in states
        assert states["circuit1"]["state"] == "closed"

    def test_get_open_circuits(self) -> None:
        """Test getting open circuits."""
        settings = Settings(circuit_failure_threshold=1)
        registry = CircuitBreakerRegistry(settings=settings)

        asyncio.run(registry.get_circuit("healthy"))
        asyncio.run(registry.record_failure("failing"))

        open_circuits = registry.get_open_circuits()
        assert "failing" in open_circuits
        assert "healthy" not in open_circuits

    def test_get_available_circuits(self) -> None:
        """Test getting available circuits."""
        settings = Settings(circuit_failure_threshold=1)
        registry = CircuitBreakerRegistry(settings=settings)

        asyncio.run(registry.get_circuit("healthy"))
        asyncio.run(registry.record_failure("failing"))

        available = registry.get_available_circuits()
        assert "healthy" in available
        assert "failing" not in available


class TestGlobalRegistry:
    """Tests for global registry functions."""

    @pytest.fixture(autouse=True)
    def _fresh_global_registry(self) -> Iterator[None]:
        """Reset the global registry before and after each test.

        Resetting on teardown as well keeps registries created here from
        leaking into tests that run later in the session.
        """
        reset_circuit_registry()
        yield
        reset_circuit_registry()

    def test_get_circuit_registry(self) -> None:
        """Test getting global registry."""
        registry = get_circuit_registry()
        assert isinstance(registry, CircuitBreakerRegistry)

    def test_get_circuit_registry_singleton(self) -> None:
        """Test registry is singleton."""
        registry1 = get_circuit_registry()
        registry2 = get_circuit_registry()
        assert registry1 is registry2

    def test_reset_circuit_registry(self) -> None:
        """Test resetting global registry."""
        registry1 = get_circuit_registry()
        reset_circuit_registry()
        registry2 = get_circuit_registry()
        assert registry1 is not registry2