# app/core/redis.py
"""
Redis client configuration for caching and pub/sub.

This module provides async Redis connectivity with connection pooling
for FastAPI endpoints and background tasks.

Features:
- Connection pooling for efficient resource usage
- Cache operations (get, set, delete, expire)
- Pub/sub operations (publish, subscribe)
- Health check for monitoring
"""

import asyncio
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.client import PubSub

# NOTE: ConnectionError / TimeoutError here are redis-py's exception classes
# (subclasses of RedisError) and shadow the builtins of the same name within
# this module.
from redis.exceptions import ConnectionError, RedisError, TimeoutError

from app.core.config import settings

# Configure logging
logger = logging.getLogger(__name__)

# Default TTL for cache entries (1 hour)
DEFAULT_CACHE_TTL = 3600

# Connection pool settings
POOL_MAX_CONNECTIONS = 50
POOL_TIMEOUT = 10  # seconds

# Default number of keys removed per DEL round trip in cache_delete_pattern().
SCAN_DELETE_BATCH_SIZE = 500


class RedisClient:
    """
    Async Redis client with connection pooling.

    Provides high-level operations for caching and pub/sub with error
    handling and connection management. The cache, publish and health-check
    methods catch redis-py errors (``RedisError`` and subclasses), log them,
    and return a documented "failure" value (None, False, 0 or -2) instead
    of raising. ``subscribe``/``psubscribe`` do not swallow errors.
    """

    def __init__(self, url: str | None = None) -> None:
        """
        Initialize Redis client.

        No connection is opened here; the pool is created lazily on first use.

        Args:
            url: Redis connection URL. Defaults to settings.REDIS_URL.
        """
        self._url = url or settings.REDIS_URL
        self._pool: ConnectionPool | None = None
        self._client: Redis | None = None
        # Serializes lazy pool creation between concurrent coroutines.
        self._lock = asyncio.Lock()

    async def _ensure_pool(self) -> ConnectionPool:
        """
        Return the connection pool, creating it on first call.

        Safe against concurrent coroutines on the same event loop (uses an
        ``asyncio.Lock``); it is not a cross-thread lock.
        """
        if self._pool is None:
            async with self._lock:
                # Double-check after acquiring lock
                if self._pool is None:
                    self._pool = ConnectionPool.from_url(
                        self._url,
                        max_connections=POOL_MAX_CONNECTIONS,
                        socket_timeout=POOL_TIMEOUT,
                        socket_connect_timeout=POOL_TIMEOUT,
                        # Responses come back as str rather than bytes.
                        decode_responses=True,
                        health_check_interval=30,
                    )
                    logger.info("Redis connection pool initialized")
        return self._pool

    async def _get_client(self) -> Redis:
        """Return the shared Redis client bound to the pool, creating it if needed."""
        pool = await self._ensure_pool()
        if self._client is None:
            self._client = Redis(connection_pool=pool)
        return self._client

    # =========================================================================
    # Cache Operations
    # =========================================================================

    async def cache_get(self, key: str) -> str | None:
        """
        Get a value from cache.

        Args:
            key: Cache key.

        Returns:
            Cached value, or None if not found or on a Redis error.
        """
        try:
            client = await self._get_client()
            value = await client.get(key)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_get failed for key '%s': %s", key, e)
            return None
        except RedisError as e:
            logger.error("Redis error in cache_get for key '%s': %s", key, e)
            return None

        if value is not None:
            logger.debug("Cache hit for key: %s", key)
        else:
            logger.debug("Cache miss for key: %s", key)
        return value

    async def cache_get_json(self, key: str) -> Any | None:
        """
        Get a JSON-serialized value from cache.

        Args:
            key: Cache key.

        Returns:
            Deserialized value, or None if not found, on a Redis error, or if
            the stored value is not valid JSON. A stored JSON ``null`` also
            yields None.
        """
        value = await self.cache_get(key)
        if value is None:
            return None
        try:
            return json.loads(value)
        except json.JSONDecodeError as e:
            logger.error("Failed to decode JSON for key '%s': %s", key, e)
            return None

    async def cache_set(
        self,
        key: str,
        value: str,
        ttl: int | None = None,
    ) -> bool:
        """
        Set a value in cache.

        Args:
            key: Cache key.
            value: Value to cache.
            ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.

        Returns:
            True if successful, False otherwise.
        """
        ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
        try:
            client = await self._get_client()
            await client.set(key, value, ex=ttl)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_set failed for key '%s': %s", key, e)
            return False
        except RedisError as e:
            logger.error("Redis error in cache_set for key '%s': %s", key, e)
            return False

        logger.debug("Cache set for key: %s (TTL: %ss)", key, ttl)
        return True

    async def cache_set_json(
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
    ) -> bool:
        """
        Set a JSON-serialized value in cache.

        Args:
            key: Cache key.
            value: Value to serialize and cache.
            ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.

        Returns:
            True if successful, False if serialization or the Redis write fails.
        """
        try:
            serialized = json.dumps(value)
        except (TypeError, ValueError) as e:
            logger.error("Failed to serialize value for key '%s': %s", key, e)
            return False
        return await self.cache_set(key, serialized, ttl)

    async def cache_delete(self, key: str) -> bool:
        """
        Delete a key from cache.

        Args:
            key: Cache key to delete.

        Returns:
            True if the key was deleted, False if it did not exist or on error.
        """
        try:
            client = await self._get_client()
            result = await client.delete(key)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_delete failed for key '%s': %s", key, e)
            return False
        except RedisError as e:
            logger.error("Redis error in cache_delete for key '%s': %s", key, e)
            return False

        logger.debug("Cache delete for key: %s (deleted: %s)", key, result > 0)
        return result > 0

    async def cache_delete_pattern(
        self,
        pattern: str,
        batch_size: int = SCAN_DELETE_BATCH_SIZE,
    ) -> int:
        """
        Delete all keys matching a pattern.

        Keys are discovered with SCAN (non-blocking, unlike KEYS) and removed
        in batches of up to ``batch_size`` keys per DEL call. The count is the
        sum of what DEL reports. SCAN may yield the same key more than once,
        and keys may expire between SCAN and DEL, so counting yielded keys
        would over-report.

        Args:
            pattern: Glob-style pattern (e.g., "user:*").
            batch_size: Maximum keys per DEL round trip. Values below 1
                behave like 1.

        Returns:
            Number of keys deleted, or 0 on a Redis error.
        """
        try:
            client = await self._get_client()
            deleted = 0
            batch: list[str] = []
            async for key in client.scan_iter(match=pattern):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += await client.delete(*batch)
                    batch.clear()
            if batch:
                deleted += await client.delete(*batch)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_delete_pattern failed for '%s': %s", pattern, e)
            return 0
        except RedisError as e:
            logger.error("Redis error in cache_delete_pattern for '%s': %s", pattern, e)
            return 0

        logger.debug("Cache delete pattern '%s': %s keys deleted", pattern, deleted)
        return deleted

    async def cache_expire(self, key: str, ttl: int) -> bool:
        """
        Set or update TTL for a key.

        Args:
            key: Cache key.
            ttl: New TTL in seconds.

        Returns:
            True if TTL was set, False if key doesn't exist or on error.
        """
        try:
            client = await self._get_client()
            result = await client.expire(key, ttl)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_expire failed for key '%s': %s", key, e)
            return False
        except RedisError as e:
            logger.error("Redis error in cache_expire for key '%s': %s", key, e)
            return False

        logger.debug(
            "Cache expire for key: %s (TTL: %ss, success: %s)", key, ttl, result
        )
        return result

    async def cache_exists(self, key: str) -> bool:
        """
        Check if a key exists in cache.

        Args:
            key: Cache key.

        Returns:
            True if key exists, False if it does not or on error.
        """
        try:
            client = await self._get_client()
            result = await client.exists(key)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_exists failed for key '%s': %s", key, e)
            return False
        except RedisError as e:
            logger.error("Redis error in cache_exists for key '%s': %s", key, e)
            return False
        return result > 0

    async def cache_ttl(self, key: str) -> int:
        """
        Get remaining TTL for a key.

        Args:
            key: Cache key.

        Returns:
            TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
            Also returns -2 on a Redis error.
        """
        try:
            client = await self._get_client()
            return await client.ttl(key)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis cache_ttl failed for key '%s': %s", key, e)
            return -2
        except RedisError as e:
            logger.error("Redis error in cache_ttl for key '%s': %s", key, e)
            return -2

    # =========================================================================
    # Pub/Sub Operations
    # =========================================================================

    async def publish(self, channel: str, message: str | dict | list) -> int:
        """
        Publish a message to a channel.

        Args:
            channel: Channel name.
            message: Message to publish. Dicts and lists are JSON-serialized;
                anything else is passed to Redis unchanged.

        Returns:
            Number of subscribers that received the message, or 0 if the
            message could not be serialized or on a Redis error.
        """
        if isinstance(message, (dict, list)):
            try:
                message = json.dumps(message)
            except (TypeError, ValueError) as e:
                # Same contract as cache_set_json: log and report failure
                # rather than letting a serialization error escape.
                logger.error(
                    "Failed to serialize message for channel '%s': %s", channel, e
                )
                return 0

        try:
            client = await self._get_client()
            result = await client.publish(channel, message)
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis publish failed for channel '%s': %s", channel, e)
            return 0
        except RedisError as e:
            logger.error("Redis error in publish for channel '%s': %s", channel, e)
            return 0

        logger.debug("Published to channel '%s': %s subscribers", channel, result)
        return result

    @asynccontextmanager
    async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
        """
        Subscribe to one or more channels.

        Usage:
            async with redis_client.subscribe("channel1", "channel2") as pubsub:
                async for message in pubsub.listen():
                    if message["type"] == "message":
                        print(message["data"])

        The pub/sub object is always closed on exit, even if unsubscribing
        fails (e.g. because the connection already dropped).

        Args:
            channels: Channel names to subscribe to.

        Yields:
            PubSub instance for receiving messages.
        """
        client = await self._get_client()
        pubsub = client.pubsub()
        try:
            await pubsub.subscribe(*channels)
            logger.debug("Subscribed to channels: %s", channels)
            yield pubsub
        finally:
            try:
                await pubsub.unsubscribe(*channels)
                logger.debug("Unsubscribed from channels: %s", channels)
            except RedisError as e:
                # Don't let cleanup failure mask the original exception.
                logger.warning(
                    "Redis unsubscribe failed for channels %s: %s", channels, e
                )
            finally:
                await pubsub.close()

    @asynccontextmanager
    async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
        """
        Subscribe to channels matching patterns.

        Usage:
            async with redis_client.psubscribe("user:*") as pubsub:
                async for message in pubsub.listen():
                    if message["type"] == "pmessage":
                        print(message["pattern"], message["channel"], message["data"])

        The pub/sub object is always closed on exit, even if unsubscribing
        fails (e.g. because the connection already dropped).

        Args:
            patterns: Glob-style patterns to subscribe to.

        Yields:
            PubSub instance for receiving messages.
        """
        client = await self._get_client()
        pubsub = client.pubsub()
        try:
            await pubsub.psubscribe(*patterns)
            logger.debug("Pattern subscribed: %s", patterns)
            yield pubsub
        finally:
            try:
                await pubsub.punsubscribe(*patterns)
                logger.debug("Pattern unsubscribed: %s", patterns)
            except RedisError as e:
                # Don't let cleanup failure mask the original exception.
                logger.warning(
                    "Redis punsubscribe failed for patterns %s: %s", patterns, e
                )
            finally:
                await pubsub.close()

    # =========================================================================
    # Health & Connection Management
    # =========================================================================

    async def health_check(self) -> bool:
        """
        Check if Redis connection is healthy.

        Returns:
            True if PING succeeds, False otherwise.
        """
        try:
            client = await self._get_client()
            result = await client.ping()
        except (ConnectionError, TimeoutError) as e:
            logger.error("Redis health check failed: %s", e)
            return False
        except RedisError as e:
            logger.error("Redis health check error: %s", e)
            return False
        return result is True

    async def close(self) -> None:
        """
        Close Redis connections and cleanup resources.

        Should be called during application shutdown. The pool is
        disconnected even if closing the client raises, and both references
        are cleared, so a later call lazily creates a fresh pool.
        """
        try:
            if self._client is not None:
                await self._client.close()
                logger.debug("Redis client closed")
        finally:
            self._client = None
            if self._pool is not None:
                try:
                    await self._pool.disconnect()
                    logger.info("Redis connection pool closed")
                finally:
                    self._pool = None

    async def get_pool_info(self) -> dict[str, Any]:
        """
        Get connection pool statistics.

        Returns:
            Dictionary with pool information. The URL is reduced to the part
            after the last "@" so embedded credentials are not exposed; this
            also drops the scheme.
        """
        if self._pool is None:
            return {"status": "not_initialized"}

        return {
            "status": "active",
            "max_connections": POOL_MAX_CONNECTIONS,
            "url": self._url.split("@")[-1] if "@" in self._url else self._url,
        }


# Global Redis client instance
redis_client = RedisClient()


# FastAPI dependency for Redis client
async def get_redis() -> AsyncGenerator[RedisClient, None]:
    """
    FastAPI dependency that provides the Redis client.

    Usage:
        @router.get("/cached-data")
        async def get_data(redis: RedisClient = Depends(get_redis)):
            cached = await redis.cache_get("my-key")
            ...
    """
    yield redis_client


# Health check function for use in /health endpoint
async def check_redis_health() -> bool:
    """
    Check if Redis connection is healthy.

    Returns:
        True if connection is successful, False otherwise.
    """
    return await redis_client.health_check()


# Cleanup function for application shutdown
async def close_redis() -> None:
    """
    Close Redis connections.

    Should be called during application shutdown.
    """
    await redis_client.close()