Refactor authentication services to async password handling; optimize bulk operations and queries

- Updated `verify_password` and `get_password_hash` to their async counterparts to prevent event loop blocking.
- Replaced N+1 query patterns in `admin.py` and `session_async.py` with optimized bulk operations for improved performance.
- Enhanced `user_async.py` with bulk update and soft delete methods for efficient user management.
- Added eager loading support in CRUD operations to prevent N+1 query issues.
- Updated test cases with stronger password examples for better security representation.
This commit is contained in:
Felipe Cardoso
2025-11-01 03:53:22 +01:00
parent 819f3ba963
commit 3fe5d301f8
17 changed files with 397 additions and 163 deletions

View File

@@ -13,6 +13,7 @@ from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.orm import Load
from app.core.database_async import Base
@@ -35,8 +36,29 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""
self.model = model
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
@@ -48,18 +70,39 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
async def get_multi(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
"""Get multiple records with pagination validation."""
"""
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
@@ -69,9 +112,14 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000")
try:
result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
query = select(self.model).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")

View File

@@ -5,7 +5,8 @@ from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select, update, func
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.orm import selectinload, joinedload
import logging
from app.crud.base_async import CRUDBaseAsync
@@ -68,15 +69,17 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
db: AsyncSession,
*,
user_id: str,
active_only: bool = True
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
"""
Get all sessions for a user.
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
@@ -87,6 +90,10 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
@@ -286,12 +293,14 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
@@ -301,28 +310,24 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
# Get sessions to delete
stmt = select(UserSession).where(
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
)
)
result = await db.execute(stmt)
sessions_to_delete = list(result.scalars().all())
# Delete them
for session in sessions_to_delete:
await db.delete(session)
await db.commit()
count = len(sessions_to_delete)
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count
except Exception as e:

View File

@@ -1,13 +1,15 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
from typing import Optional, Union, Dict, Any, List, Tuple
from uuid import UUID
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from sqlalchemy import or_, select
from sqlalchemy import or_, select, update
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.core.auth import get_password_hash
from app.core.auth import get_password_hash_async
import logging
logger = logging.getLogger(__name__)
@@ -28,11 +30,14 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with password hashing and error handling."""
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
@@ -63,15 +68,16 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
"""Update user with password hashing if password is updated."""
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"])
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
@@ -157,6 +163,100 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
logger.error(f"Error retrieving paginated users: {str(e)}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
is_active=False,
updated_at=datetime.now(timezone.utc)
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active