"""Shared pytest fixtures for the LLM Gateway test suite."""

import os
import sys
from collections.abc import AsyncIterator, Iterator
from typing import Any
from unittest.mock import MagicMock, patch

import fakeredis.aioredis
import pytest

# Put the directory above the one containing this file at the front of
# sys.path so the project modules imported below can be resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import Settings
from cost_tracking import CostTracker, reset_cost_tracker
from failover import CircuitBreakerRegistry, reset_circuit_registry
from providers import LLMProvider, reset_provider
from routing import ModelRouter, reset_model_router


@pytest.fixture
def test_settings() -> Settings:
    """Return gateway settings populated with placeholder provider API keys.

    Caching is switched off, cost tracking is switched on, and the circuit
    breaker is configured with a failure threshold of 3 and a recovery
    timeout of 10.
    """
    settings = Settings(
        host="127.0.0.1",
        port=8001,
        debug=True,
        redis_url="redis://localhost:6379/0",
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        litellm_timeout=30,
        litellm_max_retries=2,
        litellm_cache_enabled=False,
        cost_tracking_enabled=True,
        circuit_failure_threshold=3,
        circuit_recovery_timeout=10,
    )
    return settings


@pytest.fixture
def settings_no_providers() -> Settings:
    """Return gateway settings in which every provider API key is ``None``."""
    settings = Settings(
        host="127.0.0.1",
        port=8001,
        debug=False,
        redis_url="redis://localhost:6379/0",
        anthropic_api_key=None,
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )
    return settings


@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
    """Return an async fake Redis client that decodes responses to ``str``."""
    client = fakeredis.aioredis.FakeRedis(decode_responses=True)
    return client


@pytest.fixture
async def cost_tracker(
    fake_redis: fakeredis.aioredis.FakeRedis,
    test_settings: Settings,
) -> AsyncIterator[CostTracker]:
    """Yield a ``CostTracker`` wired to the fake Redis client.

    The module-level tracker is reset before the instance is built; after the
    test the instance is closed and the module-level tracker is reset again.
    """
    reset_cost_tracker()
    instance = CostTracker(redis_client=fake_redis, settings=test_settings)

    yield instance

    await instance.close()
    reset_cost_tracker()


@pytest.fixture
def circuit_registry(test_settings: Settings) -> Iterator[CircuitBreakerRegistry]:
    """Yield a ``CircuitBreakerRegistry`` built from the test settings.

    The module-level registry is reset both before creation and after the
    test finishes.
    """
    reset_circuit_registry()
    instance = CircuitBreakerRegistry(settings=test_settings)

    yield instance

    reset_circuit_registry()


@pytest.fixture
def model_router(
    test_settings: Settings,
    circuit_registry: CircuitBreakerRegistry,
) -> Iterator[ModelRouter]:
    """Yield a ``ModelRouter`` using the test settings and circuit registry.

    The module-level router is reset both before creation and after the test
    finishes.
    """
    reset_model_router()
    instance = ModelRouter(
        settings=test_settings,
        circuit_registry=circuit_registry,
    )

    yield instance

    reset_model_router()


@pytest.fixture
def llm_provider(test_settings: Settings) -> Iterator[LLMProvider]:
    """Yield an ``LLMProvider`` built from the test settings.

    The module-level provider is reset both before creation and after the
    test finishes.
    """
    reset_provider()
    instance = LLMProvider(settings=test_settings)

    yield instance

    reset_provider()


@pytest.fixture
def mock_litellm_response() -> dict[str, Any]:
    """Return a dict shaped like a single-choice LiteLLM completion response."""
    message = {
        "role": "assistant",
        "content": "This is a test response.",
    }
    choice = {
        "index": 0,
        "message": message,
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": 10,
        "completion_tokens": 20,
        "total_tokens": 30,
    }
    return {
        "id": "test-response-id",
        "model": "claude-opus-4",
        "choices": [choice],
        "usage": usage,
    }


@pytest.fixture
def mock_completion_response() -> MagicMock:
    """Return a ``MagicMock`` exposing completion-response attributes.

    The mock carries ``id``, ``model``, a single entry in ``choices`` and a
    ``usage`` object with token counts. ``message.role`` is not configured, so
    reading it yields an auto-created ``MagicMock``.
    """
    message = MagicMock(content="This is a test response.")
    choice = MagicMock(index=0, message=message, finish_reason="stop")
    usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
    return MagicMock(
        id="test-response-id",
        model="claude-opus-4",
        choices=[choice],
        usage=usage,
    )


@pytest.fixture
def sample_messages() -> list[dict[str, str]]:
    """Return a short system + user chat exchange."""
    turns = (
        ("system", "You are a helpful assistant."),
        ("user", "Hello, how are you?"),
    )
    return [{"role": role, "content": content} for role, content in turns]


@pytest.fixture
def sample_project_id() -> str:
    """Return a fixed project identifier."""
    return "proj-12345678-1234-1234-1234-123456789abc"


@pytest.fixture
def sample_agent_id() -> str:
    """Return a fixed agent identifier."""
    return "agent-87654321-4321-4321-4321-cba987654321"


@pytest.fixture
def sample_session_id() -> str:
    """Return a fixed session identifier."""
    return "session-11111111-2222-3333-4444-555555555555"


# Applied to every test automatically so module-level state is cleared once
# each test has finished.
@pytest.fixture(autouse=True)
def reset_globals() -> Iterator[None]:
    """Reset the cost tracker, circuit registry, router and provider after each test."""
    yield

    resetters = (
        reset_cost_tracker,
        reset_circuit_registry,
        reset_model_router,
        reset_provider,
    )
    for reset in resetters:
        reset()


# Deliberately not autouse: tests that check default values must not have
# these variables set.
@pytest.fixture
def mock_env_vars() -> Iterator[None]:
    """Patch ``os.environ`` with test gateway host, port and debug values.

    Existing environment variables are left in place, and the patch is undone
    when the test ends.
    """
    overrides = {
        "LLM_GATEWAY_HOST": "127.0.0.1",
        "LLM_GATEWAY_PORT": "8001",
        "LLM_GATEWAY_DEBUG": "true",
    }
    with patch.dict(os.environ, overrides, clear=False):
        yield