# app/core/redis.py
"""
Redis client configuration for caching and pub/sub.

This module provides async Redis connectivity with connection pooling
for FastAPI endpoints and background tasks.

Features:
- Connection pooling for efficient resource usage
- Cache operations (get, set, delete, expire)
- Pub/sub operations (publish, subscribe)
- Health check for monitoring
"""

import asyncio
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.client import PubSub
from redis.exceptions import ConnectionError, RedisError, TimeoutError

from app.core.config import settings

# Configure logging
logger = logging.getLogger(__name__)

# Default TTL for cache entries (1 hour)
DEFAULT_CACHE_TTL = 3600

# Connection pool settings
POOL_MAX_CONNECTIONS = 50
POOL_TIMEOUT = 10  # seconds


class RedisClient:
    """
    Async Redis client with connection pooling.

    Provides high-level operations for caching and pub/sub
    with proper error handling and connection management.
    """

    def __init__(self, url: str | None = None) -> None:
        """
        Initialize the Redis client.

        Args:
            url: Redis connection URL. Defaults to settings.REDIS_URL.
        """
        self._url = url or settings.REDIS_URL
        self._pool: ConnectionPool | None = None
        self._client: Redis | None = None
        self._lock = asyncio.Lock()

    async def _ensure_pool(self) -> ConnectionPool:
        """Ensure the connection pool is initialized (coroutine-safe)."""
        if self._pool is None:
            async with self._lock:
                # Double-check after acquiring the lock: another coroutine
                # may have created the pool while we were waiting.
                if self._pool is None:
                    self._pool = ConnectionPool.from_url(
                        self._url,
                        max_connections=POOL_MAX_CONNECTIONS,
                        socket_timeout=POOL_TIMEOUT,
                        socket_connect_timeout=POOL_TIMEOUT,
                        decode_responses=True,
                        health_check_interval=30,
                    )
                    logger.info("Redis connection pool initialized")
        return self._pool

    async def _get_client(self) -> Redis:
        """Get a Redis client instance backed by the shared pool."""
        pool = await self._ensure_pool()
        if self._client is None:
            self._client = Redis(connection_pool=pool)
        return self._client

    # =========================================================================
    # Cache Operations
    # =========================================================================

    async def cache_get(self, key: str) -> str | None:
        """
        Get a value from cache.

        Args:
            key: Cache key.

        Returns:
            Cached value or None if not found.
        """
        try:
            client = await self._get_client()
            value = await client.get(key)
            if value is not None:
                logger.debug(f"Cache hit for key: {key}")
            else:
                logger.debug(f"Cache miss for key: {key}")
            return value
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_get failed for key '{key}': {e}")
            return None
        except RedisError as e:
            logger.error(f"Redis error in cache_get for key '{key}': {e}")
            return None

    async def cache_get_json(self, key: str) -> Any | None:
        """
        Get a JSON-serialized value from cache.

        Args:
            key: Cache key.

        Returns:
            Deserialized value, or None if not found or not valid JSON.
        """
        value = await self.cache_get(key)
        if value is not None:
            try:
                return json.loads(value)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to decode JSON for key '{key}': {e}")
                return None
        return None

    async def cache_set(
        self,
        key: str,
        value: str,
        ttl: int | None = None,
    ) -> bool:
        """
        Set a value in cache.

        Args:
            key: Cache key.
            value: Value to cache.
            ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.

        Returns:
            True if successful, False otherwise.
        """
        try:
            client = await self._get_client()
            ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
            await client.set(key, value, ex=ttl)
            logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
            return True
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_set failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_set for key '{key}': {e}")
            return False

    async def cache_set_json(
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
    ) -> bool:
        """
        Set a JSON-serialized value in cache.

        Args:
            key: Cache key.
            value: Value to serialize and cache.
            ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.

        Returns:
            True if successful, False otherwise.
        """
        try:
            serialized = json.dumps(value)
            return await self.cache_set(key, serialized, ttl)
        except (TypeError, ValueError) as e:
            logger.error(f"Failed to serialize value for key '{key}': {e}")
            return False

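    # Illustrative sketch of the cache-aside pattern these JSON helpers enable
    # (the "user:{id}" key scheme and load_user_from_db() are hypothetical):
    #
    #   async def get_user(user_id: int) -> dict | None:
    #       key = f"user:{user_id}"
    #       cached = await redis_client.cache_get_json(key)
    #       if cached is not None:
    #           return cached
    #       user = await load_user_from_db(user_id)  # hypothetical loader
    #       await redis_client.cache_set_json(key, user, ttl=300)
    #       return user
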
    async def cache_delete(self, key: str) -> bool:
        """
        Delete a key from cache.

        Args:
            key: Cache key to delete.

        Returns:
            True if key was deleted, False otherwise.
        """
        try:
            client = await self._get_client()
            result = await client.delete(key)
            logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
            return result > 0
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_delete failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_delete for key '{key}': {e}")
            return False

    async def cache_delete_pattern(self, pattern: str) -> int:
        """
        Delete all keys matching a pattern.

        Args:
            pattern: Glob-style pattern (e.g., "user:*").

        Returns:
            Number of keys deleted.
        """
        try:
            client = await self._get_client()
            deleted = 0
            async for key in client.scan_iter(match=pattern):
                await client.delete(key)
                deleted += 1
            logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
            return deleted
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
            return 0
        except RedisError as e:
            logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
            return 0

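    # Illustrative sketch: invalidating every cached view of one user after a
    # write (the "user:42:*" key scheme is hypothetical). scan_iter walks the
    # keyspace incrementally, so this never blocks the server the way KEYS
    # would, but its cost still grows with the size of the keyspace:
    #
    #   await redis_client.cache_delete_pattern("user:42:*")
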
    async def cache_expire(self, key: str, ttl: int) -> bool:
        """
        Set or update TTL for a key.

        Args:
            key: Cache key.
            ttl: New TTL in seconds.

        Returns:
            True if TTL was set, False if key doesn't exist.
        """
        try:
            client = await self._get_client()
            result = await client.expire(key, ttl)
            logger.debug(
                f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})"
            )
            return result
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_expire failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_expire for key '{key}': {e}")
            return False

    async def cache_exists(self, key: str) -> bool:
        """
        Check if a key exists in cache.

        Args:
            key: Cache key.

        Returns:
            True if key exists, False otherwise.
        """
        try:
            client = await self._get_client()
            result = await client.exists(key)
            return result > 0
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_exists failed for key '{key}': {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis error in cache_exists for key '{key}': {e}")
            return False

    async def cache_ttl(self, key: str) -> int:
        """
        Get remaining TTL for a key.

        Args:
            key: Cache key.

        Returns:
            TTL in seconds; -1 if the key has no TTL, -2 if it doesn't
            exist (Redis errors also return -2).
        """
        try:
            client = await self._get_client()
            return await client.ttl(key)
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
            return -2
        except RedisError as e:
            logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
            return -2

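    # Illustrative sketch of a sliding expiration built from these helpers:
    # refresh the TTL on every read so hot keys stay cached (the "session:abc"
    # key is hypothetical):
    #
    #   value = await redis_client.cache_get("session:abc")
    #   if value is not None:
    #       await redis_client.cache_expire("session:abc", DEFAULT_CACHE_TTL)
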
    # =========================================================================
    # Pub/Sub Operations
    # =========================================================================

    async def publish(self, channel: str, message: str | dict) -> int:
        """
        Publish a message to a channel.

        Args:
            channel: Channel name.
            message: Message to publish (string, or dict for JSON serialization).

        Returns:
            Number of subscribers that received the message.
        """
        try:
            client = await self._get_client()
            if isinstance(message, dict):
                message = json.dumps(message)
            result = await client.publish(channel, message)
            logger.debug(f"Published to channel '{channel}': {result} subscribers")
            return result
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis publish failed for channel '{channel}': {e}")
            return 0
        except RedisError as e:
            logger.error(f"Redis error in publish for channel '{channel}': {e}")
            return 0

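    # Illustrative sketch: publishing a JSON event (channel name and payload
    # are hypothetical); dict messages are serialized automatically:
    #
    #   delivered = await redis_client.publish(
    #       "events:user", {"type": "created", "id": 42}
    #   )  # returns the number of subscribers that received it
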
    @asynccontextmanager
    async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
        """
        Subscribe to one or more channels.

        Usage:
            async with redis_client.subscribe("channel1", "channel2") as pubsub:
                async for message in pubsub.listen():
                    if message["type"] == "message":
                        print(message["data"])

        Args:
            channels: Channel names to subscribe to.

        Yields:
            PubSub instance for receiving messages.
        """
        client = await self._get_client()
        pubsub = client.pubsub()
        try:
            await pubsub.subscribe(*channels)
            logger.debug(f"Subscribed to channels: {channels}")
            yield pubsub
        finally:
            await pubsub.unsubscribe(*channels)
            await pubsub.close()
            logger.debug(f"Unsubscribed from channels: {channels}")

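    # Illustrative sketch: running a subscriber as a long-lived background
    # task so the listen() loop does not block a request handler (channel
    # name and handle_event() are hypothetical):
    #
    #   async def listen_for_events() -> None:
    #       async with redis_client.subscribe("events:user") as pubsub:
    #           async for message in pubsub.listen():
    #               if message["type"] == "message":
    #                   await handle_event(message["data"])
    #
    #   task = asyncio.create_task(listen_for_events())
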
    @asynccontextmanager
    async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
        """
        Subscribe to channels matching patterns.

        Usage:
            async with redis_client.psubscribe("user:*") as pubsub:
                async for message in pubsub.listen():
                    if message["type"] == "pmessage":
                        print(message["pattern"], message["channel"], message["data"])

        Args:
            patterns: Glob-style patterns to subscribe to.

        Yields:
            PubSub instance for receiving messages.
        """
        client = await self._get_client()
        pubsub = client.pubsub()
        try:
            await pubsub.psubscribe(*patterns)
            logger.debug(f"Pattern subscribed: {patterns}")
            yield pubsub
        finally:
            await pubsub.punsubscribe(*patterns)
            await pubsub.close()
            logger.debug(f"Pattern unsubscribed: {patterns}")

    # =========================================================================
    # Health & Connection Management
    # =========================================================================

    async def health_check(self) -> bool:
        """
        Check if the Redis connection is healthy.

        Returns:
            True if connection is successful, False otherwise.
        """
        try:
            client = await self._get_client()
            result = await client.ping()
            return result is True
        except (ConnectionError, TimeoutError) as e:
            logger.error(f"Redis health check failed: {e}")
            return False
        except RedisError as e:
            logger.error(f"Redis health check error: {e}")
            return False

    async def close(self) -> None:
        """
        Close Redis connections and clean up resources.

        Should be called during application shutdown.
        """
        if self._client:
            await self._client.close()
            self._client = None
            logger.debug("Redis client closed")

        if self._pool:
            await self._pool.disconnect()
            self._pool = None
            logger.info("Redis connection pool closed")

    async def get_pool_info(self) -> dict[str, Any]:
        """
        Get connection pool statistics.

        Returns:
            Dictionary with pool information.
        """
        if self._pool is None:
            return {"status": "not_initialized"}

        return {
            "status": "active",
            "max_connections": POOL_MAX_CONNECTIONS,
            # Report the URL without any credential part before '@'.
            "url": self._url.split("@")[-1] if "@" in self._url else self._url,
        }


# Global Redis client instance
redis_client = RedisClient()


# FastAPI dependency for Redis client
async def get_redis() -> AsyncGenerator[RedisClient, None]:
    """
    FastAPI dependency that provides the Redis client.

    Usage:
        @router.get("/cached-data")
        async def get_data(redis: RedisClient = Depends(get_redis)):
            cached = await redis.cache_get("my-key")
            ...
    """
    yield redis_client


# Health check function for use in /health endpoint
async def check_redis_health() -> bool:
    """
    Check if the Redis connection is healthy.

    Returns:
        True if connection is successful, False otherwise.
    """
    return await redis_client.health_check()


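# Illustrative sketch: wiring the check into a FastAPI health endpoint
# (the router object and response shape are hypothetical examples):
#
#   @router.get("/health")
#   async def health() -> dict[str, str]:
#       return {"redis": "ok" if await check_redis_health() else "unavailable"}

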
# Cleanup function for application shutdown
async def close_redis() -> None:
    """
    Close Redis connections.

    Should be called during application shutdown.
    """
    await redis_client.close()
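

# Illustrative sketch: tearing the pool down from a FastAPI lifespan handler
# (the FastAPI import and app object are hypothetical for this module):
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       yield
#       await close_redis()
#
#   app = FastAPI(lifespan=lifespan)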