Refactor authentication services to async password handling; optimize bulk operations and queries
- Updated `verify_password` and `get_password_hash` to their async counterparts to prevent event loop blocking. - Replaced N+1 query patterns in `admin.py` and `session_async.py` with optimized bulk operations for improved performance. - Enhanced `user_async.py` with bulk update and soft delete methods for efficient user management. - Added eager loading support in CRUD operations to prevent N+1 query issues. - Updated test cases with stronger password examples for better security representation.
This commit is contained in:
@@ -345,54 +345,50 @@ async def admin_bulk_user_action(
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Perform bulk actions on multiple users.
|
||||
Perform bulk actions on multiple users using optimized bulk operations.
|
||||
|
||||
Uses single UPDATE query instead of N individual queries for efficiency.
|
||||
Supported actions: activate, deactivate, delete
|
||||
"""
|
||||
affected_count = 0
|
||||
failed_count = 0
|
||||
failed_ids = []
|
||||
|
||||
try:
|
||||
for user_id in bulk_action.user_ids:
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
failed_count += 1
|
||||
failed_ids.append(user_id)
|
||||
continue
|
||||
# Use efficient bulk operations instead of loop
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=True
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=False
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
# bulk_soft_delete automatically excludes the admin user
|
||||
affected_count = await user_crud.bulk_soft_delete(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
exclude_user_id=admin.id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
|
||||
|
||||
# Prevent affecting yourself
|
||||
if user.id == admin.id:
|
||||
failed_count += 1
|
||||
failed_ids.append(user_id)
|
||||
continue
|
||||
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
await user_crud.soft_delete(db, id=user_id)
|
||||
|
||||
affected_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing user {user_id} in bulk action: {str(e)}")
|
||||
failed_count += 1
|
||||
failed_ids.append(user_id)
|
||||
# Calculate failed count (requested - affected)
|
||||
requested_count = len(bulk_action.user_ids)
|
||||
failed_count = requested_count - affected_count
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} performed bulk {bulk_action.action.value} "
|
||||
f"on {affected_count} users ({failed_count} failed)"
|
||||
f"on {affected_count} users ({failed_count} skipped/failed)"
|
||||
)
|
||||
|
||||
return BulkActionResult(
|
||||
success=failed_count == 0,
|
||||
affected_count=affected_count,
|
||||
failed_count=failed_count,
|
||||
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} failed",
|
||||
failed_ids=failed_ids if failed_ids else None
|
||||
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
|
||||
failed_ids=None # Bulk operations don't track individual failures
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -51,23 +51,20 @@ async def get_my_organizations(
|
||||
Get all organizations the current user belongs to.
|
||||
|
||||
Returns organizations with member count for each.
|
||||
Uses optimized single query to avoid N+1 problem.
|
||||
"""
|
||||
try:
|
||||
orgs = await organization_crud.get_user_organizations(
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
is_active=is_active
|
||||
)
|
||||
|
||||
# Add member count and role to each organization
|
||||
# Transform to response objects
|
||||
orgs_with_data = []
|
||||
for org in orgs:
|
||||
role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=org.id
|
||||
)
|
||||
|
||||
for item in orgs_data:
|
||||
org = item['organization']
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -77,7 +74,7 @@ async def get_my_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": item['member_count']
|
||||
}
|
||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
@@ -44,6 +46,49 @@ def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify a password against a hash asynchronously.
|
||||
|
||||
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password to verify
|
||||
hashed_password: Hashed password to verify against
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(pwd_context.verify, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
async def get_password_hash_async(password: str) -> str:
|
||||
"""
|
||||
Generate a password hash asynchronously.
|
||||
|
||||
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop. This is especially important during user
|
||||
registration and password changes.
|
||||
|
||||
Args:
|
||||
password: Plain text password to hash
|
||||
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
pwd_context.hash,
|
||||
password
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database_async import Base
|
||||
|
||||
@@ -35,8 +36,29 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
id: Record UUID
|
||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||
for eager loading relationships to prevent N+1 queries
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
|
||||
Example:
|
||||
# Eager load user relationship
|
||||
from sqlalchemy.orm import joinedload
|
||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
@@ -48,18 +70,39 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
"""Get multiple records with pagination validation."""
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
options: Optional list of SQLAlchemy load options for eager loading
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
@@ -69,9 +112,14 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).offset(skip).limit(limit)
|
||||
)
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
|
||||
@@ -5,7 +5,8 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import and_, select, update, func
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
import logging
|
||||
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
@@ -68,15 +69,17 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
@@ -87,6 +90,10 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
|
||||
@@ -286,12 +293,14 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
Clean up expired sessions using optimized bulk DELETE.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
@@ -301,28 +310,24 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Get sessions to delete
|
||||
stmt = select(UserSession).where(
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < datetime.now(timezone.utc),
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
sessions_to_delete = list(result.scalars().all())
|
||||
|
||||
# Delete them
|
||||
for session in sessions_to_delete:
|
||||
await db.delete(session)
|
||||
|
||||
await db.commit()
|
||||
|
||||
count = len(sessions_to_delete)
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy import or_, select, update
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.auth import get_password_hash_async
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,11 +30,14 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with password hashing and error handling."""
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
@@ -63,15 +68,16 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
"""Update user with password hashing if password is updated."""
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
@@ -157,6 +163,100 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to update
|
||||
is_active: New active status
|
||||
|
||||
Returns:
|
||||
Number of users updated
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to delete
|
||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||
|
||||
Returns:
|
||||
Number of users deleted
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Remove excluded user from list
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from math import ceil
|
||||
@@ -138,6 +139,46 @@ class MessageResponse(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class BulkActionRequest(BaseModel):
|
||||
"""Request schema for bulk operations on multiple items."""
|
||||
|
||||
ids: List[UUID] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="List of item IDs to perform action on (max 100)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BulkActionResponse(BaseModel):
|
||||
"""Response schema for bulk operations."""
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
affected_count: int = Field(..., description="Number of items affected by the operation")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Successfully deactivated 5 users",
|
||||
"affected_count": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
|
||||
@@ -7,8 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
verify_password_async,
|
||||
get_password_hash_async,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
@@ -31,7 +31,7 @@ class AuthService:
|
||||
@staticmethod
|
||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
Authenticate a user with email and password using async password verification.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
@@ -47,7 +47,8 @@ class AuthService:
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
# Verify password asynchronously to avoid blocking event loop
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
@@ -77,8 +78,9 @@ class AuthService:
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
# Create new user with async password hashing
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
hashed_password = await get_password_hash_async(user_data.password)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
@@ -202,12 +204,12 @@ class AuthService:
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
# Verify current password asynchronously
|
||||
if not await verify_password_async(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
# Hash new password asynchronously to avoid blocking event loop
|
||||
user.password_hash = await get_password_hash_async(new_password)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Password changed successfully for user {user_id}")
|
||||
|
||||
Reference in New Issue
Block a user