"""
Tests for Redis client utility functions (app/core/redis.py).

Covers:
- Cache operations (get, set, delete, expire)
- JSON serialization helpers
- Pub/sub operations
- Health check
- Connection pooling
- Error handling
"""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# NOTE: ConnectionError and TimeoutError below are redis-py's exception
# classes (imported from redis.exceptions), not the Python builtins of the
# same name; every side_effect in this module raises the redis variants.
from redis.exceptions import ConnectionError, RedisError, TimeoutError

from app.core.redis import (
    DEFAULT_CACHE_TTL,
    POOL_MAX_CONNECTIONS,
    RedisClient,
    check_redis_health,
    close_redis,
    get_redis,
    redis_client,
)


class TestRedisClientInit:
    """Test RedisClient initialization."""

    def test_default_url_from_settings(self):
        """Test that default URL comes from settings."""
        with patch("app.core.redis.settings") as mock_settings:
            mock_settings.REDIS_URL = "redis://test:6379/0"
            client = RedisClient()
            assert client._url == "redis://test:6379/0"

    def test_custom_url_override(self):
        """Test that custom URL overrides settings."""
        client = RedisClient(url="redis://custom:6379/1")
        assert client._url == "redis://custom:6379/1"

    def test_initial_state(self):
        """Test initial client state (pool and client are created lazily)."""
        client = RedisClient(url="redis://localhost:6379/0")
        assert client._pool is None
        assert client._client is None


class TestCacheOperations:
    """Test cache get/set/delete operations.

    Every test patches ``_get_client`` so no real Redis server is contacted.
    """

    @pytest.mark.asyncio
    async def test_cache_set_success(self):
        """Test setting a cache value."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value", ttl=60)
            assert result is True
            mock_redis.set.assert_called_once_with("test-key", "test-value", ex=60)

    @pytest.mark.asyncio
    async def test_cache_set_default_ttl(self):
        """Test setting a cache value with default TTL."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")
            assert result is True
            mock_redis.set.assert_called_once_with(
                "test-key", "test-value", ex=DEFAULT_CACHE_TTL
            )

    @pytest.mark.asyncio
    async def test_cache_set_connection_error(self):
        """Test cache_set handles connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_set_timeout_error(self):
        """Test cache_set handles timeout errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(side_effect=TimeoutError("Timeout"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_set_redis_error(self):
        """Test cache_set handles generic Redis errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(side_effect=RedisError("Unknown error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_get_success(self):
        """Test getting a cached value."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value="cached-value")

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get("test-key")
            assert result == "cached-value"
            mock_redis.get.assert_called_once_with("test-key")

    @pytest.mark.asyncio
    async def test_cache_get_miss(self):
        """Test cache miss returns None."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get("nonexistent-key")
            assert result is None

    @pytest.mark.asyncio
    async def test_cache_get_connection_error(self):
        """Test cache_get handles connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get("test-key")
            assert result is None

    @pytest.mark.asyncio
    async def test_cache_delete_success(self):
        """Test deleting a cache key."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete("test-key")
            assert result is True
            mock_redis.delete.assert_called_once_with("test-key")

    @pytest.mark.asyncio
    async def test_cache_delete_nonexistent_key(self):
        """Test deleting a nonexistent key returns False."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=0)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete("nonexistent-key")
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_delete_connection_error(self):
        """Test cache_delete handles connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete("test-key")
            assert result is False


class TestCacheDeletePattern:
    """Test cache_delete_pattern operation.

    ``scan_iter`` is replaced with a plain async generator function so the
    client can iterate it with ``async for``.
    """

    @pytest.mark.asyncio
    async def test_cache_delete_pattern_success(self):
        """Test deleting keys by pattern."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)

        # Create async iterator for scan_iter
        async def mock_scan_iter(pattern):
            for key in ["user:1", "user:2", "user:3"]:
                yield key

        mock_redis.scan_iter = mock_scan_iter

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete_pattern("user:*")
            assert result == 3
            assert mock_redis.delete.call_count == 3

    @pytest.mark.asyncio
    async def test_cache_delete_pattern_no_matches(self):
        """Test deleting pattern with no matches."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()

        async def mock_scan_iter(pattern):
            # The unreachable yield makes this an empty async generator.
            if False:
                yield

        mock_redis.scan_iter = mock_scan_iter

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete_pattern("nonexistent:*")
            assert result == 0

    @pytest.mark.asyncio
    async def test_cache_delete_pattern_error(self):
        """Test cache_delete_pattern handles errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()

        async def mock_scan_iter(pattern):
            raise ConnectionError("Connection lost")
            # The unreachable yield makes this an async generator, so the
            # error surfaces on the first iteration rather than at call time.
            if False:
                yield

        mock_redis.scan_iter = mock_scan_iter

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete_pattern("user:*")
            assert result == 0


class TestCacheExpire:
    """Test cache_expire operation."""

    @pytest.mark.asyncio
    async def test_cache_expire_success(self):
        """Test setting TTL on existing key."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.expire = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_expire("test-key", 120)
            assert result is True
            mock_redis.expire.assert_called_once_with("test-key", 120)

    @pytest.mark.asyncio
    async def test_cache_expire_nonexistent_key(self):
        """Test setting TTL on nonexistent key."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.expire = AsyncMock(return_value=False)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_expire("nonexistent-key", 120)
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_expire_error(self):
        """Test cache_expire handles errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.expire = AsyncMock(side_effect=ConnectionError("Error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_expire("test-key", 120)
            assert result is False


class TestCacheHelpers:
    """Test cache helper methods (exists, ttl)."""

    @pytest.mark.asyncio
    async def test_cache_exists_true(self):
        """Test cache_exists returns True for existing key."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.exists = AsyncMock(return_value=1)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_exists("test-key")
            assert result is True

    @pytest.mark.asyncio
    async def test_cache_exists_false(self):
        """Test cache_exists returns False for nonexistent key."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.exists = AsyncMock(return_value=0)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_exists("nonexistent-key")
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_exists_error(self):
        """Test cache_exists handles errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.exists = AsyncMock(side_effect=ConnectionError("Error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_exists("test-key")
            assert result is False

    @pytest.mark.asyncio
    async def test_cache_ttl_with_ttl(self):
        """Test cache_ttl returns remaining TTL."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ttl = AsyncMock(return_value=300)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_ttl("test-key")
            assert result == 300

    @pytest.mark.asyncio
    async def test_cache_ttl_no_ttl(self):
        """Test cache_ttl returns -1 for key without TTL."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ttl = AsyncMock(return_value=-1)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_ttl("test-key")
            assert result == -1

    @pytest.mark.asyncio
    async def test_cache_ttl_nonexistent_key(self):
        """Test cache_ttl returns -2 for nonexistent key."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ttl = AsyncMock(return_value=-2)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_ttl("nonexistent-key")
            assert result == -2

    @pytest.mark.asyncio
    async def test_cache_ttl_error(self):
        """Test cache_ttl handles errors (reported as -2, like a missing key)."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ttl = AsyncMock(side_effect=ConnectionError("Error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_ttl("test-key")
            assert result == -2


class TestJsonOperations:
    """Test JSON serialization cache operations."""

    @pytest.mark.asyncio
    async def test_cache_set_json_success(self):
        """Test setting a JSON value in cache."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            data = {"user": "test", "count": 42}
            result = await client.cache_set_json("test-key", data, ttl=60)
            assert result is True
            mock_redis.set.assert_called_once()

            # Verify JSON was serialized
            call_args = mock_redis.set.call_args
            assert call_args[0][1] == json.dumps(data)

    @pytest.mark.asyncio
    async def test_cache_set_json_serialization_error(self):
        """Test cache_set_json returns False for values json cannot encode.

        ``_get_client`` is patched so the test never reaches a real Redis
        server. Without the patch, a failed connection could make it pass
        for the wrong reason.
        """
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        # Object that can't be serialized
        class NonSerializable:
            pass

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set_json("test-key", NonSerializable())
            assert result is False
            # Nothing may be written when serialization fails.
            mock_redis.set.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_get_json_success(self):
        """Test getting a JSON value from cache."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        data = {"user": "test", "count": 42}
        mock_redis.get = AsyncMock(return_value=json.dumps(data))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get_json("test-key")
            assert result == data

    @pytest.mark.asyncio
    async def test_cache_get_json_miss(self):
        """Test cache_get_json returns None on cache miss."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get_json("nonexistent-key")
            assert result is None

    @pytest.mark.asyncio
    async def test_cache_get_json_invalid_json(self):
        """Test cache_get_json handles invalid JSON."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value="not valid json {{{")

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get_json("test-key")
            assert result is None


class TestPubSubOperations:
    """Test pub/sub operations."""

    @pytest.mark.asyncio
    async def test_publish_string_message(self):
        """Test publishing a string message."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=2)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.publish("test-channel", "hello world")
            assert result == 2
            mock_redis.publish.assert_called_once_with("test-channel", "hello world")

    @pytest.mark.asyncio
    async def test_publish_dict_message(self):
        """Test publishing a dict message (JSON serialized)."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=1)

        with patch.object(client, "_get_client", return_value=mock_redis):
            data = {"event": "user_created", "user_id": 123}
            result = await client.publish("events", data)
            assert result == 1
            mock_redis.publish.assert_called_once_with("events", json.dumps(data))

    @pytest.mark.asyncio
    async def test_publish_connection_error(self):
        """Test publish handles connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(side_effect=ConnectionError("Connection lost"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.publish("test-channel", "hello")
            assert result == 0

    @pytest.mark.asyncio
    async def test_subscribe_context_manager(self):
        """Test subscribe context manager."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_pubsub = AsyncMock()
        mock_pubsub.subscribe = AsyncMock()
        mock_pubsub.unsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()

        mock_redis = AsyncMock()
        # pubsub() is called synchronously, so it needs a plain MagicMock.
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)

        with patch.object(client, "_get_client", return_value=mock_redis):
            async with client.subscribe("channel1", "channel2") as pubsub:
                assert pubsub is mock_pubsub
                mock_pubsub.subscribe.assert_called_once_with("channel1", "channel2")

            # After exiting context, should unsubscribe and close
            mock_pubsub.unsubscribe.assert_called_once_with("channel1", "channel2")
            mock_pubsub.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_psubscribe_context_manager(self):
        """Test pattern subscribe context manager."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_pubsub = AsyncMock()
        mock_pubsub.psubscribe = AsyncMock()
        mock_pubsub.punsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()

        mock_redis = AsyncMock()
        # pubsub() is called synchronously, so it needs a plain MagicMock.
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)

        with patch.object(client, "_get_client", return_value=mock_redis):
            async with client.psubscribe("user:*", "event:*") as pubsub:
                assert pubsub is mock_pubsub
                mock_pubsub.psubscribe.assert_called_once_with("user:*", "event:*")

            # After exiting context, should punsubscribe and close
            mock_pubsub.punsubscribe.assert_called_once_with("user:*", "event:*")
            mock_pubsub.close.assert_called_once()


class TestHealthCheck:
    """Test health check functionality."""

    @pytest.mark.asyncio
    async def test_health_check_success(self):
        """Test health check returns True when Redis is healthy."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()
            assert result is True
            mock_redis.ping.assert_called_once()

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        """Test health check returns False on connection error."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()
            assert result is False

    @pytest.mark.asyncio
    async def test_health_check_timeout_error(self):
        """Test health check returns False on timeout."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=TimeoutError("Timeout"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()
            assert result is False

    @pytest.mark.asyncio
    async def test_health_check_redis_error(self):
        """Test health check returns False on Redis error."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=RedisError("Unknown error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()
            assert result is False


class TestConnectionPooling:
    """Test connection pooling functionality."""

    @pytest.mark.asyncio
    async def test_pool_initialization(self):
        """Test that pool is lazily initialized."""
        client = RedisClient(url="redis://localhost:6379/0")
        assert client._pool is None

        with patch("app.core.redis.ConnectionPool") as MockPool:
            mock_pool = MagicMock()
            MockPool.from_url = MagicMock(return_value=mock_pool)

            pool = await client._ensure_pool()

            assert pool is mock_pool
            MockPool.from_url.assert_called_once_with(
                "redis://localhost:6379/0",
                max_connections=POOL_MAX_CONNECTIONS,
                socket_timeout=10,
                socket_connect_timeout=10,
                decode_responses=True,
                health_check_interval=30,
            )

    @pytest.mark.asyncio
    async def test_pool_reuses_existing(self):
        """Test that pool is reused after initialization."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_pool = MagicMock()
        client._pool = mock_pool

        pool = await client._ensure_pool()
        assert pool is mock_pool

    @pytest.mark.asyncio
    async def test_close_disposes_resources(self):
        """Test that close() disposes pool and client."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_client = AsyncMock()
        mock_pool = AsyncMock()
        mock_pool.disconnect = AsyncMock()
        client._client = mock_client
        client._pool = mock_pool

        await client.close()

        mock_client.close.assert_called_once()
        mock_pool.disconnect.assert_called_once()
        assert client._client is None
        assert client._pool is None

    @pytest.mark.asyncio
    async def test_close_handles_none(self):
        """Test that close() handles None client and pool gracefully."""
        client = RedisClient(url="redis://localhost:6379/0")

        # Should not raise
        await client.close()
        assert client._client is None
        assert client._pool is None

    @pytest.mark.asyncio
    async def test_get_pool_info_not_initialized(self):
        """Test pool info when not initialized."""
        client = RedisClient(url="redis://localhost:6379/0")
        info = await client.get_pool_info()
        assert info == {"status": "not_initialized"}

    @pytest.mark.asyncio
    async def test_get_pool_info_active(self):
        """Test pool info when active, with the URL password masked."""
        # A distinctive secret keeps the masking assertion independent of
        # whatever placeholder text the implementation substitutes.
        secret = "s3cr3t-pw"
        client = RedisClient(url=f"redis://user:{secret}@localhost:6379/0")
        mock_pool = MagicMock()
        client._pool = mock_pool

        info = await client.get_pool_info()

        assert info["status"] == "active"
        assert info["max_connections"] == POOL_MAX_CONNECTIONS
        # Password should be hidden
        assert secret not in info["url"]
        assert "localhost:6379/0" in info["url"]


class TestModuleLevelFunctions:
    """Test module-level convenience functions."""

    @pytest.mark.asyncio
    async def test_get_redis_dependency(self):
        """Test get_redis FastAPI dependency yields the shared client once."""
        redis_gen = get_redis()
        client = await redis_gen.__anext__()
        assert client is redis_client

        # Cleanup: the generator must finish after a single yield.
        with pytest.raises(StopAsyncIteration):
            await redis_gen.__anext__()

    @pytest.mark.asyncio
    async def test_check_redis_health(self):
        """Test module-level check_redis_health function."""
        # patch.object swaps the async method for an AsyncMock, so
        # assert_awaited_once() also catches a missing ``await``.
        with patch.object(
            redis_client, "health_check", return_value=True
        ) as mock_health_check:
            result = await check_redis_health()
            assert result is True
            mock_health_check.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_close_redis(self):
        """Test module-level close_redis function."""
        # assert_called_once() would pass even if close_redis forgot to await
        # redis_client.close(); assert_awaited_once() does not.
        with patch.object(redis_client, "close") as mock_close:
            await close_redis()
            mock_close.assert_awaited_once()


class TestThreadSafety:
    """Test concurrency-safety of pool initialization.

    The callers are asyncio coroutines sharing one event loop, not OS threads.
    """

    @pytest.mark.asyncio
    async def test_concurrent_pool_initialization(self):
        """Test that concurrent _ensure_pool calls create only one pool."""
        client = RedisClient(url="redis://localhost:6379/0")
        mock_pool = MagicMock()

        with patch("app.core.redis.ConnectionPool") as MockPool:
            MockPool.from_url = MagicMock(return_value=mock_pool)

            # NOTE(review): from_url is a synchronous mock, so these coroutines
            # only interleave if _ensure_pool itself suspends before it assigns
            # the pool. If it never suspends, gather runs them back-to-back and
            # this test passes with or without a lock.
            results = await asyncio.gather(
                client._ensure_pool(),
                client._ensure_pool(),
                client._ensure_pool(),
            )

            # All results should be the same pool instance
            assert results[0] is results[1] is results[2]
            assert results[0] is mock_pool

            # Pool should only be created once despite concurrent calls
            MockPool.from_url.assert_called_once()