Files
fast-next-template/backend/app/core/redis.py
Felipe Cardoso acd18ff694 chore(backend): standardize multiline formatting across modules
Reformatted multiline function calls, object definitions, and queries for improved code readability and consistency. Adjusted imports and constraints where necessary.
2026-01-03 01:35:18 +01:00

475 lines
15 KiB
Python

# app/core/redis.py
"""
Redis client configuration for caching and pub/sub.
This module provides async Redis connectivity with connection pooling
for FastAPI endpoints and background tasks.
Features:
- Connection pooling for efficient resource usage
- Cache operations (get, set, delete, expire)
- Pub/sub operations (publish, subscribe)
- Health check for monitoring
"""
import asyncio
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.client import PubSub
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from app.core.config import settings
# Module-level logger named after this module's import path
logger = logging.getLogger(__name__)

# Cache entries expire after one hour unless the caller supplies a TTL
DEFAULT_CACHE_TTL = 60 * 60  # seconds

# Connection pool settings
POOL_MAX_CONNECTIONS = 50
POOL_TIMEOUT = 10  # seconds; used for both socket and connect timeouts
class RedisClient:
    """
    Async Redis client with connection pooling.

    Provides high-level operations for caching and pub/sub
    with proper error handling and connection management.

    The connection pool and the underlying ``Redis`` client are created
    lazily on first use, so constructing an instance performs no I/O.
    Cache, publish and health-check helpers log Redis errors and return a
    fallback value instead of raising; ``subscribe`` and ``psubscribe``
    let errors propagate to the caller.
    """

    # Maximum number of keys sent in a single DEL by cache_delete_pattern.
    _DELETE_BATCH_SIZE = 500

    def __init__(self, url: str | None = None) -> None:
        """
        Initialize Redis client.

        Args:
            url: Redis connection URL. Defaults to settings.REDIS_URL.
        """
        self._url = url or settings.REDIS_URL
        self._pool: ConnectionPool | None = None
        self._client: Redis | None = None
        self._lock = asyncio.Lock()

    async def _ensure_pool(self) -> ConnectionPool:
        """
        Ensure the connection pool is initialized.

        Creation is guarded by an ``asyncio.Lock`` so concurrent coroutines
        on the same event loop build the pool only once. This is NOT
        thread-safe: ``asyncio.Lock`` does not protect against other threads.

        Returns:
            The (possibly newly created) connection pool.
        """
        if self._pool is None:
            async with self._lock:
                # Double-check after acquiring lock
                if self._pool is None:
                    self._pool = ConnectionPool.from_url(
                        self._url,
                        max_connections=POOL_MAX_CONNECTIONS,
                        socket_timeout=POOL_TIMEOUT,
                        socket_connect_timeout=POOL_TIMEOUT,
                        decode_responses=True,
                        health_check_interval=30,
                    )
                    logger.info("Redis connection pool initialized")
        return self._pool

    async def _get_client(self) -> Redis:
        """Get the shared Redis client, creating it on top of the pool if needed."""
        pool = await self._ensure_pool()
        if self._client is None:
            self._client = Redis(connection_pool=pool)
        return self._client

    # =========================================================================
    # Cache Operations
    # =========================================================================

    async def cache_get(self, key: str) -> str | None:
        """
        Get a value from cache.

        Args:
            key: Cache key.

        Returns:
            Cached value, or None if the key is not found or a Redis error
            occurred (errors are logged, not raised).
        """
        try:
            client = await self._get_client()
            value = await client.get(key)
            if value is not None:
                logger.debug(f"Cache hit for key: {key}")
            else:
                logger.debug(f"Cache miss for key: {key}")
            return value
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_get failed for key '{key}': {e}")
            return None
        except RedisError as e:
            logger.error(f"Redis error in cache_get for key '{key}': {e}")
            return None

    async def cache_get_json(self, key: str) -> Any | None:
        """
        Get a JSON-serialized value from cache.

        Args:
            key: Cache key.

        Returns:
            Deserialized value, or None if the key is not found, the lookup
            failed, or the stored value is not valid JSON.
        """
        value = await self.cache_get(key)
        if value is not None:
            try:
                return json.loads(value)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to decode JSON for key '{key}': {e}")
                return None
        return None

    async def cache_set(
        self,
        key: str,
        value: str,
        ttl: int | None = None,
    ) -> bool:
        """
        Set a value in cache.

        Args:
            key: Cache key.
            value: Value to cache.
            ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.

        Returns:
            True if successful, False otherwise.
        """
        try:
            client = await self._get_client()
            ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
            await client.set(key, value, ex=ttl)
            logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
            return True
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_set failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_set for key '{key}': {e}")
            return False

    async def cache_set_json(
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
    ) -> bool:
        """
        Set a JSON-serialized value in cache.

        Args:
            key: Cache key.
            value: Value to serialize and cache.
            ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.

        Returns:
            True if successful, False if serialization or the write failed.
        """
        try:
            serialized = json.dumps(value)
            return await self.cache_set(key, serialized, ttl)
        except (TypeError, ValueError) as e:
            logger.error(f"Failed to serialize value for key '{key}': {e}")
            return False

    async def cache_delete(self, key: str) -> bool:
        """
        Delete a key from cache.

        Args:
            key: Cache key to delete.

        Returns:
            True if key was deleted, False otherwise.
        """
        try:
            client = await self._get_client()
            result = await client.delete(key)
            logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
            return result > 0
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_delete failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_delete for key '{key}': {e}")
            return False

    async def cache_delete_pattern(self, pattern: str) -> int:
        """
        Delete all keys matching a pattern.

        Keys are discovered incrementally with SCAN and removed in batches
        of up to ``_DELETE_BATCH_SIZE`` keys per DEL command, rather than
        one round trip per key.

        Args:
            pattern: Glob-style pattern (e.g., "user:*").

        Returns:
            Number of keys actually deleted, as reported by Redis. Returns 0
            if a Redis error occurs, even if earlier batches were deleted.
        """
        try:
            client = await self._get_client()
            deleted = 0
            batch: list[str] = []
            async for key in client.scan_iter(match=pattern):
                batch.append(key)
                if len(batch) >= self._DELETE_BATCH_SIZE:
                    deleted += await client.delete(*batch)
                    batch.clear()
            # Flush the final, partially filled batch.
            if batch:
                deleted += await client.delete(*batch)
            logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
            return deleted
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
            return 0
        except RedisError as e:
            logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
            return 0

    async def cache_expire(self, key: str, ttl: int) -> bool:
        """
        Set or update TTL for a key.

        Args:
            key: Cache key.
            ttl: New TTL in seconds.

        Returns:
            True if TTL was set, False if key doesn't exist or a Redis
            error occurred.
        """
        try:
            client = await self._get_client()
            result = await client.expire(key, ttl)
            logger.debug(
                f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})"
            )
            return result
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_expire failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_expire for key '{key}': {e}")
            return False

    async def cache_exists(self, key: str) -> bool:
        """
        Check if a key exists in cache.

        Args:
            key: Cache key.

        Returns:
            True if key exists, False otherwise (including on Redis errors).
        """
        try:
            client = await self._get_client()
            result = await client.exists(key)
            return result > 0
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_exists failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_exists for key '{key}': {e}")
            return False

    async def cache_ttl(self, key: str) -> int:
        """
        Get remaining TTL for a key.

        Args:
            key: Cache key.

        Returns:
            TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
            -2 is also returned when a Redis error occurs, so callers
            cannot tell a missing key from a failed lookup.
        """
        try:
            client = await self._get_client()
            return await client.ttl(key)
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
            return -2
        except RedisError as e:
            logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
            return -2

    # =========================================================================
    # Pub/Sub Operations
    # =========================================================================

    async def publish(self, channel: str, message: str | dict) -> int:
        """
        Publish a message to a channel.

        Args:
            channel: Channel name.
            message: Message to publish (string or dict for JSON serialization).

        Returns:
            Number of subscribers that received the message, or 0 if a
            Redis error occurred.
        """
        try:
            client = await self._get_client()
            if isinstance(message, dict):
                message = json.dumps(message)
            result = await client.publish(channel, message)
            logger.debug(f"Published to channel '{channel}': {result} subscribers")
            return result
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis publish failed for channel '{channel}': {e}")
            return 0
        except RedisError as e:
            logger.error(f"Redis error in publish for channel '{channel}': {e}")
            return 0

    @asynccontextmanager
    async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
        """
        Subscribe to one or more channels.

        Usage:
            async with redis_client.subscribe("channel1", "channel2") as pubsub:
                async for message in pubsub.listen():
                    if message["type"] == "message":
                        print(message["data"])

        Args:
            channels: Channel names to subscribe to.

        Yields:
            PubSub instance for receiving messages.

        Raises:
            RedisError: If subscribing or unsubscribing fails. Unlike the
                cache helpers, errors are not swallowed here.
        """
        client = await self._get_client()
        pubsub = client.pubsub()
        try:
            await pubsub.subscribe(*channels)
            logger.debug(f"Subscribed to channels: {channels}")
            yield pubsub
        finally:
            # Close even if unsubscribe fails (e.g. the connection dropped);
            # otherwise the pub/sub connection is never released.
            try:
                await pubsub.unsubscribe(*channels)
            finally:
                await pubsub.close()
            logger.debug(f"Unsubscribed from channels: {channels}")

    @asynccontextmanager
    async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
        """
        Subscribe to channels matching patterns.

        Usage:
            async with redis_client.psubscribe("user:*") as pubsub:
                async for message in pubsub.listen():
                    if message["type"] == "pmessage":
                        print(message["pattern"], message["channel"], message["data"])

        Args:
            patterns: Glob-style patterns to subscribe to.

        Yields:
            PubSub instance for receiving messages.

        Raises:
            RedisError: If subscribing or unsubscribing fails. Unlike the
                cache helpers, errors are not swallowed here.
        """
        client = await self._get_client()
        pubsub = client.pubsub()
        try:
            await pubsub.psubscribe(*patterns)
            logger.debug(f"Pattern subscribed: {patterns}")
            yield pubsub
        finally:
            # Close even if punsubscribe fails (e.g. the connection dropped);
            # otherwise the pub/sub connection is never released.
            try:
                await pubsub.punsubscribe(*patterns)
            finally:
                await pubsub.close()
            logger.debug(f"Pattern unsubscribed: {patterns}")

    # =========================================================================
    # Health & Connection Management
    # =========================================================================

    async def health_check(self) -> bool:
        """
        Check if Redis connection is healthy.

        Returns:
            True if PING succeeds, False otherwise (errors are logged).
        """
        try:
            client = await self._get_client()
            result = await client.ping()
            return result is True
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis health check failed: {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis health check error: {e}")
            return False

    async def close(self) -> None:
        """
        Close Redis connections and cleanup resources.

        Should be called during application shutdown. Calling it when
        nothing is open is a no-op. The client and pool are re-created
        lazily if the instance is used again afterwards.
        """
        client, pool = self._client, self._pool
        # Drop the references first so the instance is left in a clean,
        # re-initializable state even if closing raises part-way through.
        self._client = None
        self._pool = None
        try:
            if client is not None:
                await client.close()
                logger.debug("Redis client closed")
        finally:
            # Always disconnect the pool, even when closing the client failed.
            if pool is not None:
                await pool.disconnect()
                logger.info("Redis connection pool closed")

    async def get_pool_info(self) -> dict[str, Any]:
        """
        Get connection pool statistics.

        Returns:
            ``{"status": "not_initialized"}`` before the pool exists;
            otherwise a dict with ``status``, ``max_connections`` and ``url``.
            When the URL contains "@", only the part after the last "@" is
            reported, which hides credentials (and also drops the scheme).
        """
        if self._pool is None:
            return {"status": "not_initialized"}
        return {
            "status": "active",
            "max_connections": POOL_MAX_CONNECTIONS,
            "url": self._url.split("@")[-1] if "@" in self._url else self._url,
        }
# Global Redis client instance, shared by get_redis(), check_redis_health()
# and close_redis(). It is built with the default URL (settings.REDIS_URL);
# constructing it opens no connection, since the pool is created on first use.
redis_client: RedisClient = RedisClient()
# FastAPI dependency exposing the shared Redis client
async def get_redis() -> AsyncGenerator[RedisClient, None]:
    """
    Yield the module-level Redis client for dependency injection.

    The same shared instance is yielded on every call, and nothing is
    closed when the generator finishes.

    Usage:
        @router.get("/cached-data")
        async def get_data(redis: RedisClient = Depends(get_redis)):
            cached = await redis.cache_get("my-key")
            ...
    """
    shared_client = redis_client
    yield shared_client
# Health probe for use in the /health endpoint
async def check_redis_health() -> bool:
    """
    Report whether the shared Redis client can reach the server.

    Delegates to ``redis_client.health_check()``.

    Returns:
        True if connection is successful, False otherwise.
    """
    is_healthy = await redis_client.health_check()
    return is_healthy
# Shutdown hook: release the shared client's Redis resources
async def close_redis() -> None:
    """
    Close the shared Redis client's connections.

    Should be called during application shutdown. Delegates to
    ``redis_client.close()``.
    """
    shutdown = redis_client.close()
    await shutdown