diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 43e1ec0..54c8f90 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -39,6 +39,32 @@ class Settings(BaseSettings): db_pool_timeout: int = 30 # Seconds to wait for a connection db_pool_recycle: int = 3600 # Recycle connections after 1 hour + # Redis configuration (Syndarix: cache, pub/sub, Celery broker) + REDIS_URL: str = Field( + default="redis://localhost:6379/0", + description="Redis URL for cache, pub/sub, and Celery broker", + ) + + # Celery configuration (Syndarix: background task processing) + CELERY_BROKER_URL: str | None = Field( + default=None, + description="Celery broker URL (defaults to REDIS_URL if not set)", + ) + CELERY_RESULT_BACKEND: str | None = Field( + default=None, + description="Celery result backend URL (defaults to REDIS_URL if not set)", + ) + + @property + def celery_broker_url(self) -> str: + """Get Celery broker URL, defaulting to Redis.""" + return self.CELERY_BROKER_URL or self.REDIS_URL + + @property + def celery_result_backend(self) -> str: + """Get Celery result backend URL, defaulting to Redis.""" + return self.CELERY_RESULT_BACKEND or self.REDIS_URL + # SQL debugging (disable in production) sql_echo: bool = False # Log SQL statements sql_echo_pool: bool = False # Log connection pool events diff --git a/backend/app/core/redis.py b/backend/app/core/redis.py new file mode 100644 index 0000000..43da57b --- /dev/null +++ b/backend/app/core/redis.py @@ -0,0 +1,476 @@ +# app/core/redis.py +""" +Redis client configuration for caching and pub/sub. + +This module provides async Redis connectivity with connection pooling +for FastAPI endpoints and background tasks. + +Features: +- Connection pooling for efficient resource usage +- Cache operations (get, set, delete, expire) +- Pub/sub operations (publish, subscribe) +- Health check for monitoring +""" + +import asyncio +import json +import logging +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +from redis.asyncio import ConnectionPool, Redis +from redis.asyncio.client import PubSub +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +from app.core.config import settings + +# Configure logging +logger = logging.getLogger(__name__) + +# Default TTL for cache entries (1 hour) +DEFAULT_CACHE_TTL = 3600 + +# Connection pool settings +POOL_MAX_CONNECTIONS = 50 +POOL_TIMEOUT = 10 # seconds + + +class RedisClient: + """ + Async Redis client with connection pooling. + + Provides high-level operations for caching and pub/sub + with proper error handling and connection management. + """ + + def __init__(self, url: str | None = None) -> None: + """ + Initialize Redis client. + + Args: + url: Redis connection URL. Defaults to settings.REDIS_URL. + """ + self._url = url or settings.REDIS_URL + self._pool: ConnectionPool | None = None + self._client: Redis | None = None + self._lock = asyncio.Lock() + + async def _ensure_pool(self) -> ConnectionPool: + """Ensure connection pool is initialized (thread-safe).""" + if self._pool is None: + async with self._lock: + # Double-check after acquiring lock + if self._pool is None: + self._pool = ConnectionPool.from_url( + self._url, + max_connections=POOL_MAX_CONNECTIONS, + socket_timeout=POOL_TIMEOUT, + socket_connect_timeout=POOL_TIMEOUT, + decode_responses=True, + health_check_interval=30, + ) + logger.info("Redis connection pool initialized") + return self._pool + + async def _get_client(self) -> Redis: + """Get Redis client instance from pool.""" + pool = await self._ensure_pool() + if self._client is None: + self._client = Redis(connection_pool=pool) + return self._client + + # ========================================================================= + # Cache Operations + # ========================================================================= + + async def cache_get(self, key: str) -> str | None: + """ + Get a value from cache. + + Args: + key: Cache key. + + Returns: + Cached value or None if not found. + """ + try: + client = await self._get_client() + value = await client.get(key) + if value is not None: + logger.debug(f"Cache hit for key: {key}") + else: + logger.debug(f"Cache miss for key: {key}") + return value + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_get failed for key '{key}': {e}") + return None + except RedisError as e: + logger.error(f"Redis error in cache_get for key '{key}': {e}") + return None + + async def cache_get_json(self, key: str) -> Any | None: + """ + Get a JSON-serialized value from cache. + + Args: + key: Cache key. + + Returns: + Deserialized value or None if not found. + """ + value = await self.cache_get(key) + if value is not None: + try: + return json.loads(value) + except json.JSONDecodeError as e: + logger.error(f"Failed to decode JSON for key '{key}': {e}") + return None + return None + + async def cache_set( + self, + key: str, + value: str, + ttl: int | None = None, + ) -> bool: + """ + Set a value in cache. + + Args: + key: Cache key. + value: Value to cache. + ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL. + + Returns: + True if successful, False otherwise. + """ + try: + client = await self._get_client() + ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL + await client.set(key, value, ex=ttl) + logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)") + return True + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_set failed for key '{key}': {e}") + return False + except RedisError as e: + logger.error(f"Redis error in cache_set for key '{key}': {e}") + return False + + async def cache_set_json( + self, + key: str, + value: Any, + ttl: int | None = None, + ) -> bool: + """ + Set a JSON-serialized value in cache. + + Args: + key: Cache key. + value: Value to serialize and cache. + ttl: Time-to-live in seconds. + + Returns: + True if successful, False otherwise. + """ + try: + serialized = json.dumps(value) + return await self.cache_set(key, serialized, ttl) + except (TypeError, ValueError) as e: + logger.error(f"Failed to serialize value for key '{key}': {e}") + return False + + async def cache_delete(self, key: str) -> bool: + """ + Delete a key from cache. + + Args: + key: Cache key to delete. + + Returns: + True if key was deleted, False otherwise. + """ + try: + client = await self._get_client() + result = await client.delete(key) + logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})") + return result > 0 + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_delete failed for key '{key}': {e}") + return False + except RedisError as e: + logger.error(f"Redis error in cache_delete for key '{key}': {e}") + return False + + async def cache_delete_pattern(self, pattern: str) -> int: + """ + Delete all keys matching a pattern. + + Args: + pattern: Glob-style pattern (e.g., "user:*"). + + Returns: + Number of keys deleted. + """ + try: + client = await self._get_client() + deleted = 0 + async for key in client.scan_iter(pattern): + await client.delete(key) + deleted += 1 + logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted") + return deleted + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}") + return 0 + except RedisError as e: + logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}") + return 0 + + async def cache_expire(self, key: str, ttl: int) -> bool: + """ + Set or update TTL for a key. + + Args: + key: Cache key. + ttl: New TTL in seconds. + + Returns: + True if TTL was set, False if key doesn't exist. + """ + try: + client = await self._get_client() + result = await client.expire(key, ttl) + logger.debug(f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})") + return result + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_expire failed for key '{key}': {e}") + return False + except RedisError as e: + logger.error(f"Redis error in cache_expire for key '{key}': {e}") + return False + + async def cache_exists(self, key: str) -> bool: + """ + Check if a key exists in cache. + + Args: + key: Cache key. + + Returns: + True if key exists, False otherwise. + """ + try: + client = await self._get_client() + result = await client.exists(key) + return result > 0 + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_exists failed for key '{key}': {e}") + return False + except RedisError as e: + logger.error(f"Redis error in cache_exists for key '{key}': {e}") + return False + + async def cache_ttl(self, key: str) -> int: + """ + Get remaining TTL for a key. + + Args: + key: Cache key. + + Returns: + TTL in seconds, -1 if no TTL, -2 if key doesn't exist. + """ + try: + client = await self._get_client() + return await client.ttl(key) + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis cache_ttl failed for key '{key}': {e}") + return -2 + except RedisError as e: + logger.error(f"Redis error in cache_ttl for key '{key}': {e}") + return -2 + + # ========================================================================= + # Pub/Sub Operations + # ========================================================================= + + async def publish(self, channel: str, message: str | dict) -> int: + """ + Publish a message to a channel. + + Args: + channel: Channel name. + message: Message to publish (string or dict for JSON serialization). + + Returns: + Number of subscribers that received the message. + """ + try: + client = await self._get_client() + if isinstance(message, dict): + message = json.dumps(message) + result = await client.publish(channel, message) + logger.debug(f"Published to channel '{channel}': {result} subscribers") + return result + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis publish failed for channel '{channel}': {e}") + return 0 + except RedisError as e: + logger.error(f"Redis error in publish for channel '{channel}': {e}") + return 0 + + @asynccontextmanager + async def subscribe( + self, *channels: str + ) -> AsyncGenerator[PubSub, None]: + """ + Subscribe to one or more channels. + + Usage: + async with redis_client.subscribe("channel1", "channel2") as pubsub: + async for message in pubsub.listen(): + if message["type"] == "message": + print(message["data"]) + + Args: + channels: Channel names to subscribe to. + + Yields: + PubSub instance for receiving messages. + """ + client = await self._get_client() + pubsub = client.pubsub() + try: + await pubsub.subscribe(*channels) + logger.debug(f"Subscribed to channels: {channels}") + yield pubsub + finally: + await pubsub.unsubscribe(*channels) + await pubsub.close() + logger.debug(f"Unsubscribed from channels: {channels}") + + @asynccontextmanager + async def psubscribe( + self, *patterns: str + ) -> AsyncGenerator[PubSub, None]: + """ + Subscribe to channels matching patterns. + + Usage: + async with redis_client.psubscribe("user:*") as pubsub: + async for message in pubsub.listen(): + if message["type"] == "pmessage": + print(message["pattern"], message["channel"], message["data"]) + + Args: + patterns: Glob-style patterns to subscribe to. + + Yields: + PubSub instance for receiving messages. + """ + client = await self._get_client() + pubsub = client.pubsub() + try: + await pubsub.psubscribe(*patterns) + logger.debug(f"Pattern subscribed: {patterns}") + yield pubsub + finally: + await pubsub.punsubscribe(*patterns) + await pubsub.close() + logger.debug(f"Pattern unsubscribed: {patterns}") + + # ========================================================================= + # Health & Connection Management + # ========================================================================= + + async def health_check(self) -> bool: + """ + Check if Redis connection is healthy. + + Returns: + True if connection is successful, False otherwise. + """ + try: + client = await self._get_client() + result = await client.ping() + return result is True + except (ConnectionError, TimeoutError) as e: + logger.error(f"Redis health check failed: {e}") + return False + except RedisError as e: + logger.error(f"Redis health check error: {e}") + return False + + async def close(self) -> None: + """ + Close Redis connections and cleanup resources. + + Should be called during application shutdown. + """ + if self._client: + await self._client.close() + self._client = None + logger.debug("Redis client closed") + + if self._pool: + await self._pool.disconnect() + self._pool = None + logger.info("Redis connection pool closed") + + async def get_pool_info(self) -> dict[str, Any]: + """ + Get connection pool statistics. + + Returns: + Dictionary with pool information. + """ + if self._pool is None: + return {"status": "not_initialized"} + + return { + "status": "active", + "max_connections": POOL_MAX_CONNECTIONS, + "url": self._url.split("@")[-1] if "@" in self._url else self._url, + } + + +# Global Redis client instance +redis_client = RedisClient() + + +# FastAPI dependency for Redis client +async def get_redis() -> AsyncGenerator[RedisClient, None]: + """ + FastAPI dependency that provides the Redis client. + + Usage: + @router.get("/cached-data") + async def get_data(redis: RedisClient = Depends(get_redis)): + cached = await redis.cache_get("my-key") + ... + """ + yield redis_client + + +# Health check function for use in /health endpoint +async def check_redis_health() -> bool: + """ + Check if Redis connection is healthy. + + Returns: + True if connection is successful, False otherwise. + """ + return await redis_client.health_check() + + +# Cleanup function for application shutdown +async def close_redis() -> None: + """ + Close Redis connections. + + Should be called during application shutdown. + """ + await redis_client.close() diff --git a/backend/tests/core/test_redis.py b/backend/tests/core/test_redis.py new file mode 100644 index 0000000..58214e8 --- /dev/null +++ b/backend/tests/core/test_redis.py @@ -0,0 +1,784 @@ +""" +Tests for Redis client utility functions (app/core/redis.py). + +Covers: +- Cache operations (get, set, delete, expire) +- JSON serialization helpers +- Pub/sub operations +- Health check +- Connection pooling +- Error handling +""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +from app.core.redis import ( + DEFAULT_CACHE_TTL, + POOL_MAX_CONNECTIONS, + RedisClient, + check_redis_health, + close_redis, + get_redis, + redis_client, +) + + +class TestRedisClientInit: + """Test RedisClient initialization.""" + + def test_default_url_from_settings(self): + """Test that default URL comes from settings.""" + with patch("app.core.redis.settings") as mock_settings: + mock_settings.REDIS_URL = "redis://test:6379/0" + client = RedisClient() + assert client._url == "redis://test:6379/0" + + def test_custom_url_override(self): + """Test that custom URL overrides settings.""" + client = RedisClient(url="redis://custom:6379/1") + assert client._url == "redis://custom:6379/1" + + def test_initial_state(self): + """Test initial client state.""" + client = RedisClient(url="redis://localhost:6379/0") + assert client._pool is None + assert client._client is None + + +class TestCacheOperations: + """Test cache get/set/delete operations.""" + + @pytest.mark.asyncio + async def test_cache_set_success(self): + """Test setting a cache value.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=True) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_set("test-key", "test-value", ttl=60) + + assert result is True + mock_redis.set.assert_called_once_with("test-key", "test-value", ex=60) + + @pytest.mark.asyncio + async def test_cache_set_default_ttl(self): + """Test setting a cache value with default TTL.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=True) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_set("test-key", "test-value") + + assert result is True + mock_redis.set.assert_called_once_with( + "test-key", "test-value", ex=DEFAULT_CACHE_TTL + ) + + @pytest.mark.asyncio + async def test_cache_set_connection_error(self): + """Test cache_set handles connection errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(side_effect=ConnectionError("Connection refused")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_set("test-key", "test-value") + + assert result is False + + @pytest.mark.asyncio + async def test_cache_set_timeout_error(self): + """Test cache_set handles timeout errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(side_effect=TimeoutError("Timeout")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_set("test-key", "test-value") + + assert result is False + + @pytest.mark.asyncio + async def test_cache_set_redis_error(self): + """Test cache_set handles generic Redis errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(side_effect=RedisError("Unknown error")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_set("test-key", "test-value") + + assert result is False + + @pytest.mark.asyncio + async def test_cache_get_success(self): + """Test getting a cached value.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value="cached-value") + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_get("test-key") + + assert result == "cached-value" + mock_redis.get.assert_called_once_with("test-key") + + @pytest.mark.asyncio + async def test_cache_get_miss(self): + """Test cache miss returns None.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=None) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_get("nonexistent-key") + + assert result is None + + @pytest.mark.asyncio + async def test_cache_get_connection_error(self): + """Test cache_get handles connection errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(side_effect=ConnectionError("Connection refused")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_get("test-key") + + assert result is None + + @pytest.mark.asyncio + async def test_cache_delete_success(self): + """Test deleting a cache key.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.delete = AsyncMock(return_value=1) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_delete("test-key") + + assert result is True + mock_redis.delete.assert_called_once_with("test-key") + + @pytest.mark.asyncio + async def test_cache_delete_nonexistent_key(self): + """Test deleting a nonexistent key returns False.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.delete = AsyncMock(return_value=0) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_delete("nonexistent-key") + + assert result is False + + @pytest.mark.asyncio + async def test_cache_delete_connection_error(self): + """Test cache_delete handles connection errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.delete = AsyncMock(side_effect=ConnectionError("Connection refused")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_delete("test-key") + + assert result is False + + +class TestCacheDeletePattern: + """Test cache_delete_pattern operation.""" + + @pytest.mark.asyncio + async def test_cache_delete_pattern_success(self): + """Test deleting keys by pattern.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.delete = AsyncMock(return_value=1) + + # Create async iterator for scan_iter + async def mock_scan_iter(pattern): + for key in ["user:1", "user:2", "user:3"]: + yield key + + mock_redis.scan_iter = mock_scan_iter + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_delete_pattern("user:*") + + assert result == 3 + assert mock_redis.delete.call_count == 3 + + @pytest.mark.asyncio + async def test_cache_delete_pattern_no_matches(self): + """Test deleting pattern with no matches.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + + async def mock_scan_iter(pattern): + if False: # Empty iterator + yield + + mock_redis.scan_iter = mock_scan_iter + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_delete_pattern("nonexistent:*") + + assert result == 0 + + @pytest.mark.asyncio + async def test_cache_delete_pattern_error(self): + """Test cache_delete_pattern handles errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + + async def mock_scan_iter(pattern): + raise ConnectionError("Connection lost") + if False: # Make it a generator + yield + + mock_redis.scan_iter = mock_scan_iter + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_delete_pattern("user:*") + + assert result == 0 + + +class TestCacheExpire: + """Test cache_expire operation.""" + + @pytest.mark.asyncio + async def test_cache_expire_success(self): + """Test setting TTL on existing key.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.expire = AsyncMock(return_value=True) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_expire("test-key", 120) + + assert result is True + mock_redis.expire.assert_called_once_with("test-key", 120) + + @pytest.mark.asyncio + async def test_cache_expire_nonexistent_key(self): + """Test setting TTL on nonexistent key.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.expire = AsyncMock(return_value=False) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_expire("nonexistent-key", 120) + + assert result is False + + @pytest.mark.asyncio + async def test_cache_expire_error(self): + """Test cache_expire handles errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.expire = AsyncMock(side_effect=ConnectionError("Error")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_expire("test-key", 120) + + assert result is False + + +class TestCacheHelpers: + """Test cache helper methods (exists, ttl).""" + + @pytest.mark.asyncio + async def test_cache_exists_true(self): + """Test cache_exists returns True for existing key.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=1) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_exists("test-key") + + assert result is True + + @pytest.mark.asyncio + async def test_cache_exists_false(self): + """Test cache_exists returns False for nonexistent key.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(return_value=0) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_exists("nonexistent-key") + + assert result is False + + @pytest.mark.asyncio + async def test_cache_exists_error(self): + """Test cache_exists handles errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.exists = AsyncMock(side_effect=ConnectionError("Error")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_exists("test-key") + + assert result is False + + @pytest.mark.asyncio + async def test_cache_ttl_with_ttl(self): + """Test cache_ttl returns remaining TTL.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ttl = AsyncMock(return_value=300) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_ttl("test-key") + + assert result == 300 + + @pytest.mark.asyncio + async def test_cache_ttl_no_ttl(self): + """Test cache_ttl returns -1 for key without TTL.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ttl = AsyncMock(return_value=-1) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_ttl("test-key") + + assert result == -1 + + @pytest.mark.asyncio + async def test_cache_ttl_nonexistent_key(self): + """Test cache_ttl returns -2 for nonexistent key.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ttl = AsyncMock(return_value=-2) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_ttl("nonexistent-key") + + assert result == -2 + + @pytest.mark.asyncio + async def test_cache_ttl_error(self): + """Test cache_ttl handles errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ttl = AsyncMock(side_effect=ConnectionError("Error")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_ttl("test-key") + + assert result == -2 + + +class TestJsonOperations: + """Test JSON serialization cache operations.""" + + @pytest.mark.asyncio + async def test_cache_set_json_success(self): + """Test setting a JSON value in cache.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=True) + + with patch.object(client, "_get_client", return_value=mock_redis): + data = {"user": "test", "count": 42} + result = await client.cache_set_json("test-key", data, ttl=60) + + assert result is True + mock_redis.set.assert_called_once() + # Verify JSON was serialized + call_args = mock_redis.set.call_args + assert call_args[0][1] == json.dumps(data) + + @pytest.mark.asyncio + async def test_cache_set_json_serialization_error(self): + """Test cache_set_json handles serialization errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + # Object that can't be serialized + class NonSerializable: + pass + + result = await client.cache_set_json("test-key", NonSerializable()) + assert result is False + + @pytest.mark.asyncio + async def test_cache_get_json_success(self): + """Test getting a JSON value from cache.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + data = {"user": "test", "count": 42} + mock_redis.get = AsyncMock(return_value=json.dumps(data)) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_get_json("test-key") + + assert result == data + + @pytest.mark.asyncio + async def test_cache_get_json_miss(self): + """Test cache_get_json returns None on cache miss.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=None) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_get_json("nonexistent-key") + + assert result is None + + @pytest.mark.asyncio + async def test_cache_get_json_invalid_json(self): + """Test cache_get_json handles invalid JSON.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value="not valid json {{{") + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.cache_get_json("test-key") + + assert result is None + + +class TestPubSubOperations: + """Test pub/sub operations.""" + + @pytest.mark.asyncio + async def test_publish_string_message(self): + """Test publishing a string message.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock(return_value=2) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.publish("test-channel", "hello world") + + assert result == 2 + mock_redis.publish.assert_called_once_with("test-channel", "hello world") + + @pytest.mark.asyncio + async def test_publish_dict_message(self): + """Test publishing a dict message (JSON serialized).""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock(return_value=1) + + with patch.object(client, "_get_client", return_value=mock_redis): + data = {"event": "user_created", "user_id": 123} + result = await client.publish("events", data) + + assert result == 1 + mock_redis.publish.assert_called_once_with("events", json.dumps(data)) + + @pytest.mark.asyncio + async def test_publish_connection_error(self): + """Test publish handles connection errors.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock(side_effect=ConnectionError("Connection lost")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.publish("test-channel", "hello") + + assert result == 0 + + @pytest.mark.asyncio + async def test_subscribe_context_manager(self): + """Test subscribe context manager.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_pubsub = AsyncMock() + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + + with patch.object(client, "_get_client", return_value=mock_redis): + async with client.subscribe("channel1", "channel2") as pubsub: + assert pubsub is mock_pubsub + mock_pubsub.subscribe.assert_called_once_with("channel1", "channel2") + + # After exiting context, should unsubscribe and close + mock_pubsub.unsubscribe.assert_called_once_with("channel1", "channel2") + mock_pubsub.close.assert_called_once() + + @pytest.mark.asyncio + async def test_psubscribe_context_manager(self): + """Test pattern subscribe context manager.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_pubsub = AsyncMock() + mock_pubsub.psubscribe = AsyncMock() + mock_pubsub.punsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + + with patch.object(client, "_get_client", return_value=mock_redis): + async with client.psubscribe("user:*", "event:*") as pubsub: + assert pubsub is mock_pubsub + mock_pubsub.psubscribe.assert_called_once_with("user:*", "event:*") + + mock_pubsub.punsubscribe.assert_called_once_with("user:*", "event:*") + mock_pubsub.close.assert_called_once() + + +class TestHealthCheck: + """Test health check functionality.""" + + @pytest.mark.asyncio + async def test_health_check_success(self): + """Test health check returns True when Redis is healthy.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(return_value=True) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.health_check() + + assert result is True + mock_redis.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_health_check_connection_error(self): + """Test health check returns False on connection error.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(side_effect=ConnectionError("Connection refused")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.health_check() + + assert result is False + + @pytest.mark.asyncio + async def test_health_check_timeout_error(self): + """Test health check returns False on timeout.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(side_effect=TimeoutError("Timeout")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.health_check() + + assert result is False + + @pytest.mark.asyncio + async def test_health_check_redis_error(self): + """Test health check returns False on Redis error.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(side_effect=RedisError("Unknown error")) + + with patch.object(client, "_get_client", return_value=mock_redis): + result = await client.health_check() + + assert result is False + + +class TestConnectionPooling: + """Test connection pooling functionality.""" + + @pytest.mark.asyncio + async def test_pool_initialization(self): + """Test that pool is lazily initialized.""" + client = RedisClient(url="redis://localhost:6379/0") + + assert client._pool is None + + with patch("app.core.redis.ConnectionPool") as MockPool: + mock_pool = MagicMock() + MockPool.from_url = MagicMock(return_value=mock_pool) + + pool = await client._ensure_pool() + + assert pool is mock_pool + MockPool.from_url.assert_called_once_with( + "redis://localhost:6379/0", + max_connections=POOL_MAX_CONNECTIONS, + socket_timeout=10, + socket_connect_timeout=10, + decode_responses=True, + health_check_interval=30, + ) + + @pytest.mark.asyncio + async def test_pool_reuses_existing(self): + """Test that pool is reused after initialization.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_pool = MagicMock() + client._pool = mock_pool + + pool = await client._ensure_pool() + + assert pool is mock_pool + + @pytest.mark.asyncio + async def test_close_disposes_resources(self): + """Test that close() disposes pool and client.""" + client = RedisClient(url="redis://localhost:6379/0") + + mock_client = AsyncMock() + mock_pool = AsyncMock() + mock_pool.disconnect = AsyncMock() + + client._client = mock_client + client._pool = mock_pool + + await client.close() + + mock_client.close.assert_called_once() + mock_pool.disconnect.assert_called_once() + assert client._client is None + assert client._pool is None + + @pytest.mark.asyncio + async def test_close_handles_none(self): + """Test that close() handles None client and pool gracefully.""" + client = RedisClient(url="redis://localhost:6379/0") + + # Should not raise + await client.close() + + assert client._client is None + assert client._pool is None + + @pytest.mark.asyncio + async def test_get_pool_info_not_initialized(self): + """Test pool info when not initialized.""" + client = RedisClient(url="redis://localhost:6379/0") + + info = await client.get_pool_info() + + assert info == {"status": "not_initialized"} + + @pytest.mark.asyncio + async def test_get_pool_info_active(self): + """Test pool info when active.""" + client = RedisClient(url="redis://user:pass@localhost:6379/0") + + mock_pool = MagicMock() + client._pool = mock_pool + + info = await client.get_pool_info() + + assert info["status"] == "active" + assert info["max_connections"] == POOL_MAX_CONNECTIONS + # Password should be hidden + assert "pass" not in info["url"] + assert "localhost:6379/0" in info["url"] + + +class TestModuleLevelFunctions: + """Test module-level convenience functions.""" + + @pytest.mark.asyncio + async def test_get_redis_dependency(self): + """Test get_redis FastAPI dependency.""" + redis_gen = get_redis() + + client = await redis_gen.__anext__() + assert client is redis_client + + # Cleanup + with pytest.raises(StopAsyncIteration): + await redis_gen.__anext__() + + @pytest.mark.asyncio + async def test_check_redis_health(self): + """Test module-level check_redis_health function.""" + with patch.object(redis_client, "health_check", return_value=True) as mock: + result = await check_redis_health() + + assert result is True + mock.assert_called_once() + + @pytest.mark.asyncio + async def test_close_redis(self): + """Test module-level close_redis function.""" + with patch.object(redis_client, "close") as mock: + await close_redis() + + mock.assert_called_once() + + +class TestThreadSafety: + """Test thread-safety of pool initialization.""" + + @pytest.mark.asyncio + async def test_concurrent_pool_initialization(self): + """Test that concurrent _ensure_pool calls create only one pool.""" + client = RedisClient(url="redis://localhost:6379/0") + + call_count = 0 + mock_pool = MagicMock() + + def counting_from_url(*args, **kwargs): + nonlocal call_count + call_count += 1 + return mock_pool + + with patch("app.core.redis.ConnectionPool") as MockPool: + MockPool.from_url = MagicMock(side_effect=counting_from_url) + + # Start multiple concurrent _ensure_pool calls + results = await asyncio.gather( + client._ensure_pool(), + client._ensure_pool(), + client._ensure_pool(), + ) + + # All results should be the same pool instance + assert results[0] is results[1] is results[2] + assert results[0] is mock_pool + # Pool should only be created once despite concurrent calls + assert call_count == 1