forked from cardosofelipe/fast-next-template
- Add RedisClient with async connection pool management - Add cache operations (get, set, delete, expire, pattern delete) - Add JSON serialization helpers for cache - Add pub/sub operations (publish, subscribe, psubscribe) - Add health check and pool statistics - Add FastAPI dependency injection support - Update config with Redis settings (URL, SSL, TLS) - Add comprehensive tests for Redis client Implements #17 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
785 lines
26 KiB
Python
785 lines
26 KiB
Python
"""
|
|
Tests for Redis client utility functions (app/core/redis.py).
|
|
|
|
Covers:
|
|
- Cache operations (get, set, delete, expire)
|
|
- JSON serialization helpers
|
|
- Pub/sub operations
|
|
- Health check
|
|
- Connection pooling
|
|
- Error handling
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
|
|
|
from app.core.redis import (
|
|
DEFAULT_CACHE_TTL,
|
|
POOL_MAX_CONNECTIONS,
|
|
RedisClient,
|
|
check_redis_health,
|
|
close_redis,
|
|
get_redis,
|
|
redis_client,
|
|
)
|
|
|
|
|
|
class TestRedisClientInit:
    """Verify how RedisClient resolves its URL and what state it starts in."""

    def test_default_url_from_settings(self):
        """Without an explicit url the client falls back to settings.REDIS_URL."""
        expected_url = "redis://test:6379/0"
        with patch("app.core.redis.settings") as fake_settings:
            fake_settings.REDIS_URL = expected_url
            redis = RedisClient()
            assert redis._url == expected_url

    def test_custom_url_override(self):
        """An explicit url argument takes precedence over configuration."""
        override = "redis://custom:6379/1"
        assert RedisClient(url=override)._url == override

    def test_initial_state(self):
        """A freshly constructed client holds neither a pool nor a client."""
        redis = RedisClient(url="redis://localhost:6379/0")
        assert redis._pool is None and redis._client is None
|
|
|
|
|
|
class TestCacheOperations:
    """Test cache get/set/delete operations.

    Redis commands on the mocked client are ``AsyncMock`` objects, so call
    checks use the ``assert_awaited_*`` family: a plain ``assert_called_*``
    also passes when the coroutine is created but never awaited.
    """

    @pytest.mark.asyncio
    async def test_cache_set_success(self):
        """Test setting a cache value with an explicit TTL."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value", ttl=60)

        assert result is True
        mock_redis.set.assert_awaited_once_with("test-key", "test-value", ex=60)

    @pytest.mark.asyncio
    async def test_cache_set_default_ttl(self):
        """Test setting a cache value with default TTL."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")

        assert result is True
        mock_redis.set.assert_awaited_once_with(
            "test-key", "test-value", ex=DEFAULT_CACHE_TTL
        )

    @pytest.mark.asyncio
    async def test_cache_set_connection_error(self):
        """Test cache_set returns False on connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")

        assert result is False

    @pytest.mark.asyncio
    async def test_cache_set_timeout_error(self):
        """Test cache_set returns False on timeout errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(side_effect=TimeoutError("Timeout"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")

        assert result is False

    @pytest.mark.asyncio
    async def test_cache_set_redis_error(self):
        """Test cache_set returns False on generic Redis errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(side_effect=RedisError("Unknown error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set("test-key", "test-value")

        assert result is False

    @pytest.mark.asyncio
    async def test_cache_get_success(self):
        """Test getting a cached value."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value="cached-value")

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get("test-key")

        assert result == "cached-value"
        mock_redis.get.assert_awaited_once_with("test-key")

    @pytest.mark.asyncio
    async def test_cache_get_miss(self):
        """Test cache miss returns None."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get("nonexistent-key")

        assert result is None

    @pytest.mark.asyncio
    async def test_cache_get_connection_error(self):
        """Test cache_get returns None on connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get("test-key")

        assert result is None

    @pytest.mark.asyncio
    async def test_cache_delete_success(self):
        """Test deleting a cache key."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete("test-key")

        assert result is True
        mock_redis.delete.assert_awaited_once_with("test-key")

    @pytest.mark.asyncio
    async def test_cache_delete_nonexistent_key(self):
        """Test deleting a nonexistent key returns False."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        # DEL reports 0 keys removed when the key does not exist.
        mock_redis.delete = AsyncMock(return_value=0)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete("nonexistent-key")

        assert result is False

    @pytest.mark.asyncio
    async def test_cache_delete_connection_error(self):
        """Test cache_delete returns False on connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete("test-key")

        assert result is False
|
|
|
|
|
|
class TestCacheDeletePattern:
    """Test cache_delete_pattern operation.

    The fake ``scan_iter`` accepts ``match`` either positionally or by keyword
    (plus any extra SCAN options such as ``count``). This mirrors redis-py's
    ``scan_iter(match=None, count=None, ...)`` signature, so the tests do not
    depend on how the client chooses to pass the pattern.
    """

    @pytest.mark.asyncio
    async def test_cache_delete_pattern_success(self):
        """Test that every key matched by the pattern is deleted."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)

        scanned_patterns = []

        # Async generator standing in for redis-py's scan_iter.
        async def mock_scan_iter(match=None, **kwargs):
            scanned_patterns.append(match)
            for key in ["user:1", "user:2", "user:3"]:
                yield key

        mock_redis.scan_iter = mock_scan_iter

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete_pattern("user:*")

        assert result == 3
        # The pattern must be forwarded to SCAN exactly once.
        assert scanned_patterns == ["user:*"]
        # Each scanned key must have been deleted, and each DEL awaited.
        deleted_keys = [c.args[0] for c in mock_redis.delete.await_args_list]
        assert deleted_keys == ["user:1", "user:2", "user:3"]

    @pytest.mark.asyncio
    async def test_cache_delete_pattern_no_matches(self):
        """Test deleting pattern with no matches issues no DEL."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()

        async def mock_scan_iter(match=None, **kwargs):
            # The unreachable yield makes this an (empty) async generator.
            if False:
                yield

        mock_redis.scan_iter = mock_scan_iter

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete_pattern("nonexistent:*")

        assert result == 0
        mock_redis.delete.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_delete_pattern_error(self):
        """Test cache_delete_pattern returns 0 when scanning fails."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()

        async def mock_scan_iter(match=None, **kwargs):
            # Raised on the first __anext__, i.e. when iteration begins.
            raise ConnectionError("Connection lost")
            # The unreachable yield makes this an async generator.
            if False:
                yield

        mock_redis.scan_iter = mock_scan_iter

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_delete_pattern("user:*")

        assert result == 0
|
|
|
|
|
|
class TestCacheExpire:
    """Test cache_expire operation.

    ``expire`` on the mocked client is an ``AsyncMock``. The call check uses
    ``assert_awaited_once_with`` so that an un-awaited coroutine fails the test.
    """

    @pytest.mark.asyncio
    async def test_cache_expire_success(self):
        """Test setting TTL on existing key."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.expire = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_expire("test-key", 120)

        assert result is True
        mock_redis.expire.assert_awaited_once_with("test-key", 120)

    @pytest.mark.asyncio
    async def test_cache_expire_nonexistent_key(self):
        """Test setting TTL on nonexistent key returns False."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.expire = AsyncMock(return_value=False)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_expire("nonexistent-key", 120)

        assert result is False

    @pytest.mark.asyncio
    async def test_cache_expire_error(self):
        """Test cache_expire returns False on connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.expire = AsyncMock(side_effect=ConnectionError("Error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_expire("test-key", 120)

        assert result is False
|
|
|
|
|
|
class TestCacheHelpers:
    """Exercise the cache_exists and cache_ttl convenience wrappers."""

    _URL = "redis://localhost:6379/0"

    @staticmethod
    def _redis_with(command, **mock_options):
        """Return a fake Redis client whose *command* is an AsyncMock(**mock_options)."""
        fake = AsyncMock()
        setattr(fake, command, AsyncMock(**mock_options))
        return fake

    async def _call_with(self, fake, method_name, key):
        """Await RedisClient.<method_name>(key) with _get_client patched to return *fake*."""
        client = RedisClient(url=self._URL)
        with patch.object(client, "_get_client", return_value=fake):
            return await getattr(client, method_name)(key)

    @pytest.mark.asyncio
    async def test_cache_exists_true(self):
        """An EXISTS reply of 1 is surfaced as True."""
        fake = self._redis_with("exists", return_value=1)
        assert await self._call_with(fake, "cache_exists", "test-key") is True

    @pytest.mark.asyncio
    async def test_cache_exists_false(self):
        """An EXISTS reply of 0 is surfaced as False."""
        fake = self._redis_with("exists", return_value=0)
        assert await self._call_with(fake, "cache_exists", "nonexistent-key") is False

    @pytest.mark.asyncio
    async def test_cache_exists_error(self):
        """A connection failure during EXISTS is reported as False."""
        fake = self._redis_with("exists", side_effect=ConnectionError("Error"))
        assert await self._call_with(fake, "cache_exists", "test-key") is False

    @pytest.mark.asyncio
    async def test_cache_ttl_with_ttl(self):
        """A positive TTL reply is passed through unchanged."""
        fake = self._redis_with("ttl", return_value=300)
        assert await self._call_with(fake, "cache_ttl", "test-key") == 300

    @pytest.mark.asyncio
    async def test_cache_ttl_no_ttl(self):
        """A key without expiry yields -1."""
        fake = self._redis_with("ttl", return_value=-1)
        assert await self._call_with(fake, "cache_ttl", "test-key") == -1

    @pytest.mark.asyncio
    async def test_cache_ttl_nonexistent_key(self):
        """A missing key yields -2."""
        fake = self._redis_with("ttl", return_value=-2)
        assert await self._call_with(fake, "cache_ttl", "nonexistent-key") == -2

    @pytest.mark.asyncio
    async def test_cache_ttl_error(self):
        """A connection failure during TTL is reported as -2."""
        fake = self._redis_with("ttl", side_effect=ConnectionError("Error"))
        assert await self._call_with(fake, "cache_ttl", "test-key") == -2
|
|
|
|
|
|
class TestJsonOperations:
    """Test JSON serialization cache operations."""

    @pytest.mark.asyncio
    async def test_cache_set_json_success(self):
        """Test setting a JSON value stores the serialized payload under the key."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock(return_value=True)

        data = {"user": "test", "count": 42}
        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set_json("test-key", data, ttl=60)

        assert result is True
        mock_redis.set.assert_awaited_once()
        # Verify the key and the JSON-serialized value were passed positionally.
        set_args = mock_redis.set.await_args.args
        assert set_args[0] == "test-key"
        assert set_args[1] == json.dumps(data)

    @pytest.mark.asyncio
    async def test_cache_set_json_serialization_error(self):
        """Test cache_set_json returns False and writes nothing for unserializable data."""
        client = RedisClient(url="redis://localhost:6379/0")

        # Object that the json module cannot serialize.
        class NonSerializable:
            pass

        # Patch the client so the test never reaches a real Redis server,
        # regardless of whether serialization happens before or after
        # _get_client is called.
        mock_redis = AsyncMock()
        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_set_json("test-key", NonSerializable())

        assert result is False
        mock_redis.set.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_get_json_success(self):
        """Test getting a JSON value from cache deserializes it."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        data = {"user": "test", "count": 42}
        mock_redis.get = AsyncMock(return_value=json.dumps(data))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get_json("test-key")

        assert result == data
        mock_redis.get.assert_awaited_once_with("test-key")

    @pytest.mark.asyncio
    async def test_cache_get_json_miss(self):
        """Test cache_get_json returns None on cache miss."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get_json("nonexistent-key")

        assert result is None

    @pytest.mark.asyncio
    async def test_cache_get_json_invalid_json(self):
        """Test cache_get_json returns None for a value that is not valid JSON."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value="not valid json {{{")

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.cache_get_json("test-key")

        assert result is None
|
|
|
|
|
|
class TestPubSubOperations:
    """Test pub/sub operations.

    Pub/sub commands are mocked with ``AsyncMock``. Checks use
    ``assert_awaited_*`` so that an un-awaited (silently dropped) command
    fails the test.
    """

    @pytest.mark.asyncio
    async def test_publish_string_message(self):
        """Test publishing a string message returns the receiver count."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=2)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.publish("test-channel", "hello world")

        assert result == 2
        mock_redis.publish.assert_awaited_once_with("test-channel", "hello world")

    @pytest.mark.asyncio
    async def test_publish_dict_message(self):
        """Test publishing a dict message (JSON serialized)."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=1)

        data = {"event": "user_created", "user_id": 123}
        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.publish("events", data)

        assert result == 1
        mock_redis.publish.assert_awaited_once_with("events", json.dumps(data))

    @pytest.mark.asyncio
    async def test_publish_connection_error(self):
        """Test publish returns 0 on connection errors."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(side_effect=ConnectionError("Connection lost"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.publish("test-channel", "hello")

        assert result == 0

    @pytest.mark.asyncio
    async def test_subscribe_context_manager(self):
        """Test subscribe context manager subscribes on entry and cleans up on exit."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_pubsub = AsyncMock()
        mock_pubsub.subscribe = AsyncMock()
        mock_pubsub.unsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()

        mock_redis = AsyncMock()
        # redis.pubsub() is a plain (synchronous) factory, hence MagicMock.
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)

        with patch.object(client, "_get_client", return_value=mock_redis):
            async with client.subscribe("channel1", "channel2") as pubsub:
                assert pubsub is mock_pubsub
                mock_pubsub.subscribe.assert_awaited_once_with("channel1", "channel2")

        # After exiting the context, the client should unsubscribe and close.
        mock_pubsub.unsubscribe.assert_awaited_once_with("channel1", "channel2")
        mock_pubsub.close.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_psubscribe_context_manager(self):
        """Test pattern subscribe context manager subscribes and cleans up."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_pubsub = AsyncMock()
        mock_pubsub.psubscribe = AsyncMock()
        mock_pubsub.punsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()

        mock_redis = AsyncMock()
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)

        with patch.object(client, "_get_client", return_value=mock_redis):
            async with client.psubscribe("user:*", "event:*") as pubsub:
                assert pubsub is mock_pubsub
                mock_pubsub.psubscribe.assert_awaited_once_with("user:*", "event:*")

        mock_pubsub.punsubscribe.assert_awaited_once_with("user:*", "event:*")
        mock_pubsub.close.assert_awaited_once()
|
|
|
|
|
|
class TestHealthCheck:
    """Test health check functionality."""

    @pytest.mark.asyncio
    async def test_health_check_success(self):
        """Test health check returns True when PING succeeds."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()

        assert result is True
        # assert_awaited_once (not assert_called_once) so that an
        # un-awaited ping() coroutine fails the test.
        mock_redis.ping.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        """Test health check returns False on connection error."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=ConnectionError("Connection refused"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()

        assert result is False

    @pytest.mark.asyncio
    async def test_health_check_timeout_error(self):
        """Test health check returns False on timeout."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=TimeoutError("Timeout"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()

        assert result is False

    @pytest.mark.asyncio
    async def test_health_check_redis_error(self):
        """Test health check returns False on a generic Redis error."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=RedisError("Unknown error"))

        with patch.object(client, "_get_client", return_value=mock_redis):
            result = await client.health_check()

        assert result is False
|
|
|
|
|
|
class TestConnectionPooling:
    """Test connection pooling functionality."""

    @pytest.mark.asyncio
    async def test_pool_initialization(self):
        """Test that the pool is lazily created from the URL with the expected options."""
        client = RedisClient(url="redis://localhost:6379/0")

        assert client._pool is None

        with patch("app.core.redis.ConnectionPool") as MockPool:
            mock_pool = MagicMock()
            MockPool.from_url = MagicMock(return_value=mock_pool)

            pool = await client._ensure_pool()

            assert pool is mock_pool
            MockPool.from_url.assert_called_once_with(
                "redis://localhost:6379/0",
                max_connections=POOL_MAX_CONNECTIONS,
                socket_timeout=10,
                socket_connect_timeout=10,
                decode_responses=True,
                health_check_interval=30,
            )

    @pytest.mark.asyncio
    async def test_pool_reuses_existing(self):
        """Test that an already-initialized pool is returned as-is."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_pool = MagicMock()
        client._pool = mock_pool

        pool = await client._ensure_pool()

        assert pool is mock_pool

    @pytest.mark.asyncio
    async def test_close_disposes_resources(self):
        """Test that close() awaits client close and pool disconnect, then resets state."""
        client = RedisClient(url="redis://localhost:6379/0")

        mock_client = AsyncMock()
        mock_pool = AsyncMock()
        mock_pool.disconnect = AsyncMock()

        client._client = mock_client
        client._pool = mock_pool

        await client.close()

        # Awaited (not merely called) so that a dropped coroutine is caught.
        mock_client.close.assert_awaited_once()
        mock_pool.disconnect.assert_awaited_once()
        assert client._client is None
        assert client._pool is None

    @pytest.mark.asyncio
    async def test_close_handles_none(self):
        """Test that close() handles None client and pool gracefully."""
        client = RedisClient(url="redis://localhost:6379/0")

        # Should not raise.
        await client.close()

        assert client._client is None
        assert client._pool is None

    @pytest.mark.asyncio
    async def test_get_pool_info_not_initialized(self):
        """Test pool info when the pool has not been created."""
        client = RedisClient(url="redis://localhost:6379/0")

        info = await client.get_pool_info()

        assert info == {"status": "not_initialized"}

    @pytest.mark.asyncio
    async def test_get_pool_info_active(self):
        """Test pool info when active, including credential masking in the URL."""
        client = RedisClient(url="redis://user:pass@localhost:6379/0")

        mock_pool = MagicMock()
        client._pool = mock_pool

        info = await client.get_pool_info()

        assert info["status"] == "active"
        assert info["max_connections"] == POOL_MAX_CONNECTIONS
        # The password must not leak into the reported URL.
        assert "pass" not in info["url"]
        assert "localhost:6379/0" in info["url"]
|
|
|
|
class TestModuleLevelFunctions:
    """Test module-level convenience functions."""

    @pytest.mark.asyncio
    async def test_get_redis_dependency(self):
        """Test get_redis FastAPI dependency yields the shared client exactly once."""
        redis_gen = get_redis()

        client = await redis_gen.__anext__()
        assert client is redis_client

        # Resuming the generator runs its cleanup and then ends it.
        with pytest.raises(StopAsyncIteration):
            await redis_gen.__anext__()

    @pytest.mark.asyncio
    async def test_check_redis_health(self):
        """Test check_redis_health awaits the shared client's health_check."""
        # patch.object detects that health_check is async and installs an
        # AsyncMock, so the awaited assertion below is meaningful.
        with patch.object(redis_client, "health_check", return_value=True) as mock:
            result = await check_redis_health()

        assert result is True
        mock.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_close_redis(self):
        """Test close_redis awaits the shared client's close()."""
        with patch.object(redis_client, "close") as mock:
            await close_redis()

        mock.assert_awaited_once()
|
|
|
|
|
|
class TestThreadSafety:
    """Concurrency guarantees of lazy connection-pool creation."""

    @pytest.mark.asyncio
    async def test_concurrent_pool_initialization(self):
        """Concurrent _ensure_pool calls must all share one freshly built pool."""
        client = RedisClient(url="redis://localhost:6379/0")
        shared_pool = MagicMock()

        with patch("app.core.redis.ConnectionPool") as MockPool:
            MockPool.from_url = MagicMock(return_value=shared_pool)

            # NOTE(review): from_url is synchronous here, so the calls can only
            # interleave if _ensure_pool awaits between its None-check and the
            # assignment (e.g. while acquiring a lock).
            first, second, third = await asyncio.gather(
                *(client._ensure_pool() for _ in range(3))
            )

            # Every caller receives the same instance...
            assert first is second is third
            assert first is shared_pool
            # ...and the factory ran only once.
            assert MockPool.from_url.call_count == 1
|