forked from cardosofelipe/fast-next-template
feat(backend): Add Redis client with connection pooling
- Add RedisClient with async connection pool management - Add cache operations (get, set, delete, expire, pattern delete) - Add JSON serialization helpers for cache - Add pub/sub operations (publish, subscribe, psubscribe) - Add health check and pool statistics - Add FastAPI dependency injection support - Update config with Redis settings (URL, SSL, TLS) - Add comprehensive tests for Redis client Implements #17 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -39,6 +39,32 @@ class Settings(BaseSettings):
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# Redis configuration (Syndarix: cache, pub/sub, Celery broker)
|
||||
REDIS_URL: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis URL for cache, pub/sub, and Celery broker",
|
||||
)
|
||||
|
||||
# Celery configuration (Syndarix: background task processing)
|
||||
CELERY_BROKER_URL: str | None = Field(
|
||||
default=None,
|
||||
description="Celery broker URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
CELERY_RESULT_BACKEND: str | None = Field(
|
||||
default=None,
|
||||
description="Celery result backend URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
"""Get Celery broker URL, defaulting to Redis."""
|
||||
return self.CELERY_BROKER_URL or self.REDIS_URL
|
||||
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
"""Get Celery result backend URL, defaulting to Redis."""
|
||||
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
|
||||
476
backend/app/core/redis.py
Normal file
476
backend/app/core/redis.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# app/core/redis.py
|
||||
"""
|
||||
Redis client configuration for caching and pub/sub.
|
||||
|
||||
This module provides async Redis connectivity with connection pooling
|
||||
for FastAPI endpoints and background tasks.
|
||||
|
||||
Features:
|
||||
- Connection pooling for efficient resource usage
|
||||
- Cache operations (get, set, delete, expire)
|
||||
- Pub/sub operations (publish, subscribe)
|
||||
- Health check for monitoring
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
from redis.asyncio.client import PubSub
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default TTL for cache entries (1 hour)
|
||||
DEFAULT_CACHE_TTL = 3600
|
||||
|
||||
# Connection pool settings
|
||||
POOL_MAX_CONNECTIONS = 50
|
||||
POOL_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class RedisClient:
|
||||
"""
|
||||
Async Redis client with connection pooling.
|
||||
|
||||
Provides high-level operations for caching and pub/sub
|
||||
with proper error handling and connection management.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize Redis client.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self._url = url or settings.REDIS_URL
|
||||
self._pool: ConnectionPool | None = None
|
||||
self._client: Redis | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_pool(self) -> ConnectionPool:
|
||||
"""Ensure connection pool is initialized (thread-safe)."""
|
||||
if self._pool is None:
|
||||
async with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._pool is None:
|
||||
self._pool = ConnectionPool.from_url(
|
||||
self._url,
|
||||
max_connections=POOL_MAX_CONNECTIONS,
|
||||
socket_timeout=POOL_TIMEOUT,
|
||||
socket_connect_timeout=POOL_TIMEOUT,
|
||||
decode_responses=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
logger.info("Redis connection pool initialized")
|
||||
return self._pool
|
||||
|
||||
async def _get_client(self) -> Redis:
|
||||
"""Get Redis client instance from pool."""
|
||||
pool = await self._ensure_pool()
|
||||
if self._client is None:
|
||||
self._client = Redis(connection_pool=pool)
|
||||
return self._client
|
||||
|
||||
# =========================================================================
|
||||
# Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def cache_get(self, key: str) -> str | None:
|
||||
"""
|
||||
Get a value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
value = await client.get(key)
|
||||
if value is not None:
|
||||
logger.debug(f"Cache hit for key: {key}")
|
||||
else:
|
||||
logger.debug(f"Cache miss for key: {key}")
|
||||
return value
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_get failed for key '{key}': {e}")
|
||||
return None
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_get for key '{key}': {e}")
|
||||
return None
|
||||
|
||||
async def cache_get_json(self, key: str) -> Any | None:
|
||||
"""
|
||||
Get a JSON-serialized value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Deserialized value or None if not found.
|
||||
"""
|
||||
value = await self.cache_get(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode JSON for key '{key}': {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
async def cache_set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to cache.
|
||||
ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
|
||||
await client.set(key, value, ex=ttl)
|
||||
logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
|
||||
return True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_set failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_set for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_set_json(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a JSON-serialized value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to serialize and cache.
|
||||
ttl: Time-to-live in seconds.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
return await self.cache_set(key, serialized, ttl)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"Failed to serialize value for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a key from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete.
|
||||
|
||||
Returns:
|
||||
True if key was deleted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.delete(key)
|
||||
logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob-style pattern (e.g., "user:*").
|
||||
|
||||
Returns:
|
||||
Number of keys deleted.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
deleted = 0
|
||||
async for key in client.scan_iter(pattern):
|
||||
await client.delete(key)
|
||||
deleted += 1
|
||||
logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
|
||||
return deleted
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
|
||||
return 0
|
||||
|
||||
async def cache_expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set or update TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
ttl: New TTL in seconds.
|
||||
|
||||
Returns:
|
||||
True if TTL was set, False if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.expire(key, ttl)
|
||||
logger.debug(f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_expire for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.exists(key)
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_exists failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_exists for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_ttl(self, key: str) -> int:
|
||||
"""
|
||||
Get remaining TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
return await client.ttl(key)
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
|
||||
return -2
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
|
||||
return -2
|
||||
|
||||
# =========================================================================
|
||||
# Pub/Sub Operations
|
||||
# =========================================================================
|
||||
|
||||
async def publish(self, channel: str, message: str | dict) -> int:
|
||||
"""
|
||||
Publish a message to a channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name.
|
||||
message: Message to publish (string or dict for JSON serialization).
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message)
|
||||
result = await client.publish(channel, message)
|
||||
logger.debug(f"Published to channel '{channel}': {result} subscribers")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis publish failed for channel '{channel}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in publish for channel '{channel}': {e}")
|
||||
return 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
self, *channels: str
|
||||
) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to one or more channels.
|
||||
|
||||
Usage:
|
||||
async with redis_client.subscribe("channel1", "channel2") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
print(message["data"])
|
||||
|
||||
Args:
|
||||
channels: Channel names to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(*channels)
|
||||
logger.debug(f"Subscribed to channels: {channels}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.unsubscribe(*channels)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def psubscribe(
|
||||
self, *patterns: str
|
||||
) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to channels matching patterns.
|
||||
|
||||
Usage:
|
||||
async with redis_client.psubscribe("user:*") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "pmessage":
|
||||
print(message["pattern"], message["channel"], message["data"])
|
||||
|
||||
Args:
|
||||
patterns: Glob-style patterns to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.psubscribe(*patterns)
|
||||
logger.debug(f"Pattern subscribed: {patterns}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.punsubscribe(*patterns)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Pattern unsubscribed: {patterns}")
|
||||
|
||||
# =========================================================================
|
||||
# Health & Connection Management
|
||||
# =========================================================================
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.ping()
|
||||
return result is True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis health check failed: {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis health check error: {e}")
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close Redis connections and cleanup resources.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.debug("Redis client closed")
|
||||
|
||||
if self._pool:
|
||||
await self._pool.disconnect()
|
||||
self._pool = None
|
||||
logger.info("Redis connection pool closed")
|
||||
|
||||
async def get_pool_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get connection pool statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with pool information.
|
||||
"""
|
||||
if self._pool is None:
|
||||
return {"status": "not_initialized"}
|
||||
|
||||
return {
|
||||
"status": "active",
|
||||
"max_connections": POOL_MAX_CONNECTIONS,
|
||||
"url": self._url.split("@")[-1] if "@" in self._url else self._url,
|
||||
}
|
||||
|
||||
|
||||
# Global Redis client instance
|
||||
redis_client = RedisClient()
|
||||
|
||||
|
||||
# FastAPI dependency for Redis client
|
||||
async def get_redis() -> AsyncGenerator[RedisClient, None]:
|
||||
"""
|
||||
FastAPI dependency that provides the Redis client.
|
||||
|
||||
Usage:
|
||||
@router.get("/cached-data")
|
||||
async def get_data(redis: RedisClient = Depends(get_redis)):
|
||||
cached = await redis.cache_get("my-key")
|
||||
...
|
||||
"""
|
||||
yield redis_client
|
||||
|
||||
|
||||
# Health check function for use in /health endpoint
|
||||
async def check_redis_health() -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
return await redis_client.health_check()
|
||||
|
||||
|
||||
# Cleanup function for application shutdown
|
||||
async def close_redis() -> None:
|
||||
"""
|
||||
Close Redis connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await redis_client.close()
|
||||
784
backend/tests/core/test_redis.py
Normal file
784
backend/tests/core/test_redis.py
Normal file
@@ -0,0 +1,784 @@
|
||||
"""
|
||||
Tests for Redis client utility functions (app/core/redis.py).
|
||||
|
||||
Covers:
|
||||
- Cache operations (get, set, delete, expire)
|
||||
- JSON serialization helpers
|
||||
- Pub/sub operations
|
||||
- Health check
|
||||
- Connection pooling
|
||||
- Error handling
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from app.core.redis import (
|
||||
DEFAULT_CACHE_TTL,
|
||||
POOL_MAX_CONNECTIONS,
|
||||
RedisClient,
|
||||
check_redis_health,
|
||||
close_redis,
|
||||
get_redis,
|
||||
redis_client,
|
||||
)
|
||||
|
||||
|
||||
class TestRedisClientInit:
|
||||
"""Test RedisClient initialization."""
|
||||
|
||||
def test_default_url_from_settings(self):
|
||||
"""Test that default URL comes from settings."""
|
||||
with patch("app.core.redis.settings") as mock_settings:
|
||||
mock_settings.REDIS_URL = "redis://test:6379/0"
|
||||
client = RedisClient()
|
||||
assert client._url == "redis://test:6379/0"
|
||||
|
||||
def test_custom_url_override(self):
|
||||
"""Test that custom URL overrides settings."""
|
||||
client = RedisClient(url="redis://custom:6379/1")
|
||||
assert client._url == "redis://custom:6379/1"
|
||||
|
||||
def test_initial_state(self):
|
||||
"""Test initial client state."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
assert client._pool is None
|
||||
assert client._client is None
|
||||
|
||||
|
||||
class TestCacheOperations:
|
||||
"""Test cache get/set/delete operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_success(self):
|
||||
"""Test setting a cache value."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value", ttl=60)
|
||||
|
||||
assert result is True
|
||||
mock_redis.set.assert_called_once_with("test-key", "test-value", ex=60)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_default_ttl(self):
|
||||
"""Test setting a cache value with default TTL."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is True
|
||||
mock_redis.set.assert_called_once_with(
|
||||
"test-key", "test-value", ex=DEFAULT_CACHE_TTL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_connection_error(self):
|
||||
"""Test cache_set handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_timeout_error(self):
|
||||
"""Test cache_set handles timeout errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=TimeoutError("Timeout"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_redis_error(self):
|
||||
"""Test cache_set handles generic Redis errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=RedisError("Unknown error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_success(self):
|
||||
"""Test getting a cached value."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value="cached-value")
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get("test-key")
|
||||
|
||||
assert result == "cached-value"
|
||||
mock_redis.get.assert_called_once_with("test-key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_miss(self):
|
||||
"""Test cache miss returns None."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get("nonexistent-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_connection_error(self):
|
||||
"""Test cache_get handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get("test-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_success(self):
|
||||
"""Test deleting a cache key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete("test-key")
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with("test-key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_nonexistent_key(self):
|
||||
"""Test deleting a nonexistent key returns False."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=0)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete("nonexistent-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_connection_error(self):
|
||||
"""Test cache_delete handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete("test-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCacheDeletePattern:
|
||||
"""Test cache_delete_pattern operation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_pattern_success(self):
|
||||
"""Test deleting keys by pattern."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create async iterator for scan_iter
|
||||
async def mock_scan_iter(pattern):
|
||||
for key in ["user:1", "user:2", "user:3"]:
|
||||
yield key
|
||||
|
||||
mock_redis.scan_iter = mock_scan_iter
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete_pattern("user:*")
|
||||
|
||||
assert result == 3
|
||||
assert mock_redis.delete.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_pattern_no_matches(self):
|
||||
"""Test deleting pattern with no matches."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
async def mock_scan_iter(pattern):
|
||||
if False: # Empty iterator
|
||||
yield
|
||||
|
||||
mock_redis.scan_iter = mock_scan_iter
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete_pattern("nonexistent:*")
|
||||
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_pattern_error(self):
|
||||
"""Test cache_delete_pattern handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
async def mock_scan_iter(pattern):
|
||||
raise ConnectionError("Connection lost")
|
||||
if False: # Make it a generator
|
||||
yield
|
||||
|
||||
mock_redis.scan_iter = mock_scan_iter
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete_pattern("user:*")
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestCacheExpire:
|
||||
"""Test cache_expire operation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expire_success(self):
|
||||
"""Test setting TTL on existing key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.expire = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_expire("test-key", 120)
|
||||
|
||||
assert result is True
|
||||
mock_redis.expire.assert_called_once_with("test-key", 120)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expire_nonexistent_key(self):
|
||||
"""Test setting TTL on nonexistent key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.expire = AsyncMock(return_value=False)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_expire("nonexistent-key", 120)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expire_error(self):
|
||||
"""Test cache_expire handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.expire = AsyncMock(side_effect=ConnectionError("Error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_expire("test-key", 120)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCacheHelpers:
|
||||
"""Test cache helper methods (exists, ttl)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exists_true(self):
|
||||
"""Test cache_exists returns True for existing key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(return_value=1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_exists("test-key")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exists_false(self):
|
||||
"""Test cache_exists returns False for nonexistent key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(return_value=0)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_exists("nonexistent-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exists_error(self):
|
||||
"""Test cache_exists handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(side_effect=ConnectionError("Error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_exists("test-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_with_ttl(self):
|
||||
"""Test cache_ttl returns remaining TTL."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=300)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("test-key")
|
||||
|
||||
assert result == 300
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_no_ttl(self):
|
||||
"""Test cache_ttl returns -1 for key without TTL."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=-1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("test-key")
|
||||
|
||||
assert result == -1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_nonexistent_key(self):
|
||||
"""Test cache_ttl returns -2 for nonexistent key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=-2)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("nonexistent-key")
|
||||
|
||||
assert result == -2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_error(self):
|
||||
"""Test cache_ttl handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(side_effect=ConnectionError("Error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("test-key")
|
||||
|
||||
assert result == -2
|
||||
|
||||
|
||||
class TestJsonOperations:
|
||||
"""Test JSON serialization cache operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_json_success(self):
|
||||
"""Test setting a JSON value in cache."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
data = {"user": "test", "count": 42}
|
||||
result = await client.cache_set_json("test-key", data, ttl=60)
|
||||
|
||||
assert result is True
|
||||
mock_redis.set.assert_called_once()
|
||||
# Verify JSON was serialized
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args[0][1] == json.dumps(data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_json_serialization_error(self):
|
||||
"""Test cache_set_json handles serialization errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
# Object that can't be serialized
|
||||
class NonSerializable:
|
||||
pass
|
||||
|
||||
result = await client.cache_set_json("test-key", NonSerializable())
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_json_success(self):
|
||||
"""Test getting a JSON value from cache."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
data = {"user": "test", "count": 42}
|
||||
mock_redis.get = AsyncMock(return_value=json.dumps(data))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get_json("test-key")
|
||||
|
||||
assert result == data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_json_miss(self):
|
||||
"""Test cache_get_json returns None on cache miss."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get_json("nonexistent-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_json_invalid_json(self):
|
||||
"""Test cache_get_json handles invalid JSON."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value="not valid json {{{")
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get_json("test-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPubSubOperations:
|
||||
"""Test pub/sub operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_string_message(self):
|
||||
"""Test publishing a string message."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock(return_value=2)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.publish("test-channel", "hello world")
|
||||
|
||||
assert result == 2
|
||||
mock_redis.publish.assert_called_once_with("test-channel", "hello world")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_dict_message(self):
|
||||
"""Test publishing a dict message (JSON serialized)."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock(return_value=1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
data = {"event": "user_created", "user_id": 123}
|
||||
result = await client.publish("events", data)
|
||||
|
||||
assert result == 1
|
||||
mock_redis.publish.assert_called_once_with("events", json.dumps(data))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_connection_error(self):
|
||||
"""Test publish handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock(side_effect=ConnectionError("Connection lost"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.publish("test-channel", "hello")
|
||||
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_context_manager(self):
|
||||
"""Test subscribe context manager."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_pubsub.subscribe = AsyncMock()
|
||||
mock_pubsub.unsubscribe = AsyncMock()
|
||||
mock_pubsub.close = AsyncMock()
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
async with client.subscribe("channel1", "channel2") as pubsub:
|
||||
assert pubsub is mock_pubsub
|
||||
mock_pubsub.subscribe.assert_called_once_with("channel1", "channel2")
|
||||
|
||||
# After exiting context, should unsubscribe and close
|
||||
mock_pubsub.unsubscribe.assert_called_once_with("channel1", "channel2")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psubscribe_context_manager(self):
|
||||
"""Test pattern subscribe context manager."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_pubsub.psubscribe = AsyncMock()
|
||||
mock_pubsub.punsubscribe = AsyncMock()
|
||||
mock_pubsub.close = AsyncMock()
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
async with client.psubscribe("user:*", "event:*") as pubsub:
|
||||
assert pubsub is mock_pubsub
|
||||
mock_pubsub.psubscribe.assert_called_once_with("user:*", "event:*")
|
||||
|
||||
mock_pubsub.punsubscribe.assert_called_once_with("user:*", "event:*")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Test health check functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self):
|
||||
"""Test health check returns True when Redis is healthy."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is True
|
||||
mock_redis.ping.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_connection_error(self):
|
||||
"""Test health check returns False on connection error."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_timeout_error(self):
|
||||
"""Test health check returns False on timeout."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=TimeoutError("Timeout"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_redis_error(self):
|
||||
"""Test health check returns False on Redis error."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=RedisError("Unknown error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestConnectionPooling:
|
||||
"""Test connection pooling functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_initialization(self):
|
||||
"""Test that pool is lazily initialized."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
assert client._pool is None
|
||||
|
||||
with patch("app.core.redis.ConnectionPool") as MockPool:
|
||||
mock_pool = MagicMock()
|
||||
MockPool.from_url = MagicMock(return_value=mock_pool)
|
||||
|
||||
pool = await client._ensure_pool()
|
||||
|
||||
assert pool is mock_pool
|
||||
MockPool.from_url.assert_called_once_with(
|
||||
"redis://localhost:6379/0",
|
||||
max_connections=POOL_MAX_CONNECTIONS,
|
||||
socket_timeout=10,
|
||||
socket_connect_timeout=10,
|
||||
decode_responses=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_reuses_existing(self):
|
||||
"""Test that pool is reused after initialization."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_pool = MagicMock()
|
||||
client._pool = mock_pool
|
||||
|
||||
pool = await client._ensure_pool()
|
||||
|
||||
assert pool is mock_pool
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_disposes_resources(self):
|
||||
"""Test that close() disposes pool and client."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_pool = AsyncMock()
|
||||
mock_pool.disconnect = AsyncMock()
|
||||
|
||||
client._client = mock_client
|
||||
client._pool = mock_pool
|
||||
|
||||
await client.close()
|
||||
|
||||
mock_client.close.assert_called_once()
|
||||
mock_pool.disconnect.assert_called_once()
|
||||
assert client._client is None
|
||||
assert client._pool is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_handles_none(self):
|
||||
"""Test that close() handles None client and pool gracefully."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
# Should not raise
|
||||
await client.close()
|
||||
|
||||
assert client._client is None
|
||||
assert client._pool is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pool_info_not_initialized(self):
|
||||
"""Test pool info when not initialized."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
info = await client.get_pool_info()
|
||||
|
||||
assert info == {"status": "not_initialized"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pool_info_active(self):
|
||||
"""Test pool info when active."""
|
||||
client = RedisClient(url="redis://user:pass@localhost:6379/0")
|
||||
|
||||
mock_pool = MagicMock()
|
||||
client._pool = mock_pool
|
||||
|
||||
info = await client.get_pool_info()
|
||||
|
||||
assert info["status"] == "active"
|
||||
assert info["max_connections"] == POOL_MAX_CONNECTIONS
|
||||
# Password should be hidden
|
||||
assert "pass" not in info["url"]
|
||||
assert "localhost:6379/0" in info["url"]
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
"""Test module-level convenience functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_redis_dependency(self):
|
||||
"""Test get_redis FastAPI dependency."""
|
||||
redis_gen = get_redis()
|
||||
|
||||
client = await redis_gen.__anext__()
|
||||
assert client is redis_client
|
||||
|
||||
# Cleanup
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await redis_gen.__anext__()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_redis_health(self):
|
||||
"""Test module-level check_redis_health function."""
|
||||
with patch.object(redis_client, "health_check", return_value=True) as mock:
|
||||
result = await check_redis_health()
|
||||
|
||||
assert result is True
|
||||
mock.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_redis(self):
|
||||
"""Test module-level close_redis function."""
|
||||
with patch.object(redis_client, "close") as mock:
|
||||
await close_redis()
|
||||
|
||||
mock.assert_called_once()
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Test thread-safety of pool initialization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_pool_initialization(self):
|
||||
"""Test that concurrent _ensure_pool calls create only one pool."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
call_count = 0
|
||||
mock_pool = MagicMock()
|
||||
|
||||
def counting_from_url(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_pool
|
||||
|
||||
with patch("app.core.redis.ConnectionPool") as MockPool:
|
||||
MockPool.from_url = MagicMock(side_effect=counting_from_url)
|
||||
|
||||
# Start multiple concurrent _ensure_pool calls
|
||||
results = await asyncio.gather(
|
||||
client._ensure_pool(),
|
||||
client._ensure_pool(),
|
||||
client._ensure_pool(),
|
||||
)
|
||||
|
||||
# All results should be the same pool instance
|
||||
assert results[0] is results[1] is results[2]
|
||||
assert results[0] is mock_pool
|
||||
# Pool should only be created once despite concurrent calls
|
||||
assert call_count == 1
|
||||
Reference in New Issue
Block a user