Refactor authentication services to async password handling; optimize bulk operations and queries
- Updated `verify_password` and `get_password_hash` to their async counterparts to prevent event loop blocking. - Replaced N+1 query patterns in `admin.py` and `session_async.py` with optimized bulk operations for improved performance. - Enhanced `user_async.py` with bulk update and soft delete methods for efficient user management. - Added eager loading support in CRUD operations to prevent N+1 query issues. - Updated test cases with stronger password examples for better security representation.
This commit is contained in:
@@ -345,54 +345,50 @@ async def admin_bulk_user_action(
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Perform bulk actions on multiple users.
|
||||
Perform bulk actions on multiple users using optimized bulk operations.
|
||||
|
||||
Uses single UPDATE query instead of N individual queries for efficiency.
|
||||
Supported actions: activate, deactivate, delete
|
||||
"""
|
||||
affected_count = 0
|
||||
failed_count = 0
|
||||
failed_ids = []
|
||||
|
||||
try:
|
||||
for user_id in bulk_action.user_ids:
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
failed_count += 1
|
||||
failed_ids.append(user_id)
|
||||
continue
|
||||
# Use efficient bulk operations instead of loop
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=True
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=False
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
# bulk_soft_delete automatically excludes the admin user
|
||||
affected_count = await user_crud.bulk_soft_delete(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
exclude_user_id=admin.id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
|
||||
|
||||
# Prevent affecting yourself
|
||||
if user.id == admin.id:
|
||||
failed_count += 1
|
||||
failed_ids.append(user_id)
|
||||
continue
|
||||
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
await user_crud.soft_delete(db, id=user_id)
|
||||
|
||||
affected_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing user {user_id} in bulk action: {str(e)}")
|
||||
failed_count += 1
|
||||
failed_ids.append(user_id)
|
||||
# Calculate failed count (requested - affected)
|
||||
requested_count = len(bulk_action.user_ids)
|
||||
failed_count = requested_count - affected_count
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} performed bulk {bulk_action.action.value} "
|
||||
f"on {affected_count} users ({failed_count} failed)"
|
||||
f"on {affected_count} users ({failed_count} skipped/failed)"
|
||||
)
|
||||
|
||||
return BulkActionResult(
|
||||
success=failed_count == 0,
|
||||
affected_count=affected_count,
|
||||
failed_count=failed_count,
|
||||
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} failed",
|
||||
failed_ids=failed_ids if failed_ids else None
|
||||
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
|
||||
failed_ids=None # Bulk operations don't track individual failures
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -51,23 +51,20 @@ async def get_my_organizations(
|
||||
Get all organizations the current user belongs to.
|
||||
|
||||
Returns organizations with member count for each.
|
||||
Uses optimized single query to avoid N+1 problem.
|
||||
"""
|
||||
try:
|
||||
orgs = await organization_crud.get_user_organizations(
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
is_active=is_active
|
||||
)
|
||||
|
||||
# Add member count and role to each organization
|
||||
# Transform to response objects
|
||||
orgs_with_data = []
|
||||
for org in orgs:
|
||||
role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=org.id
|
||||
)
|
||||
|
||||
for item in orgs_data:
|
||||
org = item['organization']
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -77,7 +74,7 @@ async def get_my_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": item['member_count']
|
||||
}
|
||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
@@ -44,6 +46,49 @@ def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify a password against a hash asynchronously.
|
||||
|
||||
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password to verify
|
||||
hashed_password: Hashed password to verify against
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(pwd_context.verify, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
async def get_password_hash_async(password: str) -> str:
|
||||
"""
|
||||
Generate a password hash asynchronously.
|
||||
|
||||
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop. This is especially important during user
|
||||
registration and password changes.
|
||||
|
||||
Args:
|
||||
password: Plain text password to hash
|
||||
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
pwd_context.hash,
|
||||
password
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database_async import Base
|
||||
|
||||
@@ -35,8 +36,29 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
id: Record UUID
|
||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||
for eager loading relationships to prevent N+1 queries
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
|
||||
Example:
|
||||
# Eager load user relationship
|
||||
from sqlalchemy.orm import joinedload
|
||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
@@ -48,18 +70,39 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
"""Get multiple records with pagination validation."""
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
options: Optional list of SQLAlchemy load options for eager loading
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
@@ -69,9 +112,14 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).offset(skip).limit(limit)
|
||||
)
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
|
||||
@@ -5,7 +5,8 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import and_, select, update, func
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
import logging
|
||||
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
@@ -68,15 +69,17 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
@@ -87,6 +90,10 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
|
||||
@@ -286,12 +293,14 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
Clean up expired sessions using optimized bulk DELETE.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
@@ -301,28 +310,24 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Get sessions to delete
|
||||
stmt = select(UserSession).where(
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < datetime.now(timezone.utc),
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
sessions_to_delete = list(result.scalars().all())
|
||||
|
||||
# Delete them
|
||||
for session in sessions_to_delete:
|
||||
await db.delete(session)
|
||||
|
||||
await db.commit()
|
||||
|
||||
count = len(sessions_to_delete)
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy import or_, select, update
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.auth import get_password_hash_async
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,11 +30,14 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with password hashing and error handling."""
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
@@ -63,15 +68,16 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
"""Update user with password hashing if password is updated."""
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
@@ -157,6 +163,100 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to update
|
||||
is_active: New active status
|
||||
|
||||
Returns:
|
||||
Number of users updated
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to delete
|
||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||
|
||||
Returns:
|
||||
Number of users deleted
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Remove excluded user from list
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from math import ceil
|
||||
@@ -138,6 +139,46 @@ class MessageResponse(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class BulkActionRequest(BaseModel):
|
||||
"""Request schema for bulk operations on multiple items."""
|
||||
|
||||
ids: List[UUID] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="List of item IDs to perform action on (max 100)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BulkActionResponse(BaseModel):
|
||||
"""Response schema for bulk operations."""
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
affected_count: int = Field(..., description="Number of items affected by the operation")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Successfully deactivated 5 users",
|
||||
"affected_count": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
|
||||
@@ -7,8 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
verify_password_async,
|
||||
get_password_hash_async,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
@@ -31,7 +31,7 @@ class AuthService:
|
||||
@staticmethod
|
||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
Authenticate a user with email and password using async password verification.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
@@ -47,7 +47,8 @@ class AuthService:
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
# Verify password asynchronously to avoid blocking event loop
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
@@ -77,8 +78,9 @@ class AuthService:
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
# Create new user with async password hashing
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
hashed_password = await get_password_hash_async(user_data.password)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
@@ -202,12 +204,12 @@ class AuthService:
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
# Verify current password asynchronously
|
||||
if not await verify_password_async(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
# Hash new password asynchronously to avoid blocking event loop
|
||||
user.password_hash = await get_password_hash_async(new_password)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Password changed successfully for user {user_id}")
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestRegisterEndpoint:
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": "newuser@example.com",
|
||||
"password": "SecurePassword123",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
@@ -49,7 +49,7 @@ class TestRegisterEndpoint:
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "SecurePassword123",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Duplicate",
|
||||
"last_name": "User"
|
||||
}
|
||||
@@ -103,7 +103,7 @@ class TestLoginEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -133,7 +133,7 @@ class TestLoginEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nonexistent@example.com",
|
||||
"password": "Password123"
|
||||
"password": "Password123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,7 +154,7 @@ class TestLoginEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestLoginEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -187,7 +187,7 @@ class TestOAuthLoginEndpoint:
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -224,7 +224,7 @@ class TestOAuthLoginEndpoint:
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -240,7 +240,7 @@ class TestOAuthLoginEndpoint:
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -258,7 +258,7 @@ class TestRefreshTokenEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
@@ -307,7 +307,7 @@ class TestRefreshTokenEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
@@ -334,7 +334,7 @@ class TestGetCurrentUserEndpoint:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123"
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
access_token = login_response.json()["access_token"]
|
||||
|
||||
@@ -148,7 +148,7 @@ class TestPasswordResetConfirm:
|
||||
"""Test password reset confirmation with valid token."""
|
||||
# Generate valid token
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
new_password = "NewSecure123"
|
||||
new_password = "NewSecure123!"
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
@@ -186,7 +186,7 @@ class TestPasswordResetConfirm:
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123"
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -204,7 +204,7 @@ class TestPasswordResetConfirm:
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": "invalid_token_xyz",
|
||||
"new_password": "NewSecure123"
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -233,7 +233,7 @@ class TestPasswordResetConfirm:
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": tampered,
|
||||
"new_password": "NewSecure123"
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -249,7 +249,7 @@ class TestPasswordResetConfirm:
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123"
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -276,7 +276,7 @@ class TestPasswordResetConfirm:
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123"
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -315,7 +315,7 @@ class TestPasswordResetConfirm:
|
||||
# Missing token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={"new_password": "NewSecure123"}
|
||||
json={"new_password": "NewSecure123!"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@@ -340,7 +340,7 @@ class TestPasswordResetConfirm:
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123"
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -354,7 +354,7 @@ class TestPasswordResetConfirm:
|
||||
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db):
|
||||
"""Test complete password reset flow."""
|
||||
original_password = async_test_user.password_hash
|
||||
new_password = "BrandNew123"
|
||||
new_password = "BrandNew123!"
|
||||
|
||||
# Step 1: Request password reset
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_superuser(self, client, async_test_superuser):
|
||||
"""Test listing users as superuser."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.get("/api/v1/users", headers=headers)
|
||||
|
||||
@@ -53,7 +53,7 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_regular_user(self, client, async_test_user):
|
||||
"""Test that regular users cannot list users."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.get("/api/v1/users", headers=headers)
|
||||
|
||||
@@ -77,7 +77,7 @@ class TestListUsers:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
# Get first page
|
||||
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
|
||||
@@ -111,7 +111,7 @@ class TestListUsers:
|
||||
session.add_all([active_user, inactive_user])
|
||||
await session.commit()
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
# Filter for active users
|
||||
response = await client.get("/api/v1/users?is_active=true", headers=headers)
|
||||
@@ -130,7 +130,7 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_sort_by_email(self, client, async_test_superuser):
|
||||
"""Test sorting users by email."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -154,7 +154,7 @@ class TestGetCurrentUserProfile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_own_profile(self, client, async_test_user):
|
||||
"""Test getting own profile."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.get("/api/v1/users/me", headers=headers)
|
||||
|
||||
@@ -176,7 +176,7 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_own_profile(self, client, async_test_user):
|
||||
"""Test updating own profile."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
@@ -192,7 +192,7 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
|
||||
"""Test updating phone number with validation."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
@@ -207,7 +207,7 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_invalid_phone(self, client, async_test_user):
|
||||
"""Test that invalid phone numbers are rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
@@ -220,7 +220,7 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_elevate_to_superuser(self, client, async_test_user):
|
||||
"""Test that users cannot make themselves superuser."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
# Note: is_superuser is not in UserUpdate schema, but the endpoint checks for it
|
||||
# This tests that even if someone tries to send it, it's rejected
|
||||
@@ -255,7 +255,7 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_own_profile_by_id(self, client, async_test_user):
|
||||
"""Test getting own profile by ID."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
|
||||
|
||||
@@ -278,7 +278,7 @@ class TestGetUserById:
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
|
||||
|
||||
@@ -287,7 +287,7 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user):
|
||||
"""Test that superusers can view other profiles."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
|
||||
|
||||
@@ -298,7 +298,7 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test getting non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
|
||||
@@ -308,7 +308,7 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_invalid_uuid(self, client, async_test_superuser):
|
||||
"""Test getting user with invalid UUID format."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
|
||||
|
||||
@@ -321,7 +321,7 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
|
||||
"""Test updating own profile by ID."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
@@ -348,7 +348,7 @@ class TestUpdateUserById:
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{other_user.id}",
|
||||
@@ -365,7 +365,7 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db):
|
||||
"""Test that superusers can update other profiles."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
@@ -380,7 +380,7 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user):
|
||||
"""Test that regular users cannot change superuser status even if they try."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
|
||||
# Just verify the user stays the same
|
||||
@@ -397,7 +397,7 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db):
|
||||
"""Test that superusers can update other users."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
@@ -413,7 +413,7 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test updating non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.patch(
|
||||
@@ -433,14 +433,14 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_success(self, client, async_test_user, test_db):
|
||||
"""Test successful password change."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "TestPassword123",
|
||||
"new_password": "NewPassword123"
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -453,7 +453,7 @@ class TestChangePassword:
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "NewPassword123"
|
||||
"password": "NewPassword123!"
|
||||
}
|
||||
)
|
||||
assert login_response.status_code == status.HTTP_200_OK
|
||||
@@ -461,14 +461,14 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_current(self, client, async_test_user):
|
||||
"""Test that wrong current password is rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "WrongPassword123",
|
||||
"new_password": "NewPassword123"
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -477,13 +477,13 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_weak_new_password(self, client, async_test_user):
|
||||
"""Test that weak new passwords are rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "TestPassword123",
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "weak"
|
||||
}
|
||||
)
|
||||
@@ -496,8 +496,8 @@ class TestChangePassword:
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
json={
|
||||
"current_password": "TestPassword123",
|
||||
"new_password": "NewPassword123"
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -527,7 +527,7 @@ class TestDeleteUser:
|
||||
await session.refresh(user_to_delete)
|
||||
user_id = user_to_delete.id
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
|
||||
|
||||
@@ -545,7 +545,7 @@ class TestDeleteUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_delete_self(self, client, async_test_superuser):
|
||||
"""Test that users cannot delete their own account."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers)
|
||||
|
||||
@@ -566,7 +566,7 @@ class TestDeleteUser:
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers)
|
||||
|
||||
@@ -575,7 +575,7 @@ class TestDeleteUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test deleting non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)
|
||||
|
||||
@@ -131,7 +131,7 @@ def test_user(test_db):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="testuser@example.com",
|
||||
password_hash=get_password_hash("TestPassword123"),
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+1234567890",
|
||||
@@ -155,7 +155,7 @@ def test_superuser(test_db):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="superuser@example.com",
|
||||
password_hash=get_password_hash("SuperPassword123"),
|
||||
password_hash=get_password_hash("SuperPassword123!"),
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
phone_number="+9876543210",
|
||||
@@ -181,7 +181,7 @@ async def async_test_user(async_test_db):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="testuser@example.com",
|
||||
password_hash=get_password_hash("TestPassword123"),
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+1234567890",
|
||||
@@ -207,7 +207,7 @@ async def async_test_superuser(async_test_db):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="superuser@example.com",
|
||||
password_hash=get_password_hash("SuperPassword123"),
|
||||
password_hash=get_password_hash("SuperPassword123!"),
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
phone_number="+9876543210",
|
||||
|
||||
@@ -24,26 +24,26 @@ class TestPasswordHandling:
|
||||
|
||||
def test_password_hash_different_from_password(self):
|
||||
"""Test that a password hash is different from the original password"""
|
||||
password = "TestPassword123"
|
||||
password = "TestPassword123!"
|
||||
hashed = get_password_hash(password)
|
||||
assert hashed != password
|
||||
|
||||
def test_verify_correct_password(self):
|
||||
"""Test that verify_password returns True for the correct password"""
|
||||
password = "TestPassword123"
|
||||
password = "TestPassword123!"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_incorrect_password(self):
|
||||
"""Test that verify_password returns False for an incorrect password"""
|
||||
password = "TestPassword123"
|
||||
wrong_password = "WrongPassword123"
|
||||
password = "TestPassword123!"
|
||||
wrong_password = "WrongPassword123!"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(wrong_password, hashed) is False
|
||||
|
||||
def test_same_password_different_hash(self):
|
||||
"""Test that the same password gets a different hash each time"""
|
||||
password = "TestPassword123"
|
||||
password = "TestPassword123!"
|
||||
hash1 = get_password_hash(password)
|
||||
hash2 = get_password_hash(password)
|
||||
assert hash1 != hash2
|
||||
|
||||
@@ -318,7 +318,7 @@ class TestCRUDCreate:
|
||||
"""Test basic record creation."""
|
||||
user_data = UserCreate(
|
||||
email="create@example.com",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
first_name="Create",
|
||||
last_name="Test"
|
||||
)
|
||||
@@ -333,7 +333,7 @@ class TestCRUDCreate:
|
||||
"""Test that creating duplicate email raises error."""
|
||||
user_data = UserCreate(
|
||||
email="duplicate@example.com",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
first_name="First"
|
||||
)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestCRUDErrorPaths:
|
||||
# Create first user
|
||||
user_data = UserCreate(
|
||||
email="unique@example.com",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
first_name="First"
|
||||
)
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
@@ -52,7 +52,7 @@ class TestCRUDErrorPaths:
|
||||
"""Test create handles other integrity errors."""
|
||||
user_data = UserCreate(
|
||||
email="integrityerror@example.com",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
first_name="Integrity"
|
||||
)
|
||||
|
||||
@@ -71,7 +71,7 @@ class TestCRUDErrorPaths:
|
||||
"""Test create handles unexpected errors."""
|
||||
user_data = UserCreate(
|
||||
email="unexpectederror@example.com",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
first_name="Unexpected"
|
||||
)
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class TestPhoneNumberValidation:
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
phone_number="+41791234567"
|
||||
)
|
||||
assert user.phone_number == "+41791234567"
|
||||
@@ -122,6 +122,6 @@ class TestPhoneNumberValidation:
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123",
|
||||
password="Password123!",
|
||||
phone_number="invalid-number"
|
||||
)
|
||||
@@ -20,7 +20,7 @@ class TestAuthServiceAuthentication:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123"
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
user = result.scalar_one_or_none()
|
||||
@@ -59,7 +59,7 @@ class TestAuthServiceAuthentication:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123"
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
user = result.scalar_one_or_none()
|
||||
@@ -82,7 +82,7 @@ class TestAuthServiceAuthentication:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password and make user inactive
|
||||
password = "TestPassword123"
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
user = result.scalar_one_or_none()
|
||||
@@ -110,10 +110,10 @@ class TestAuthServiceUserCreation:
|
||||
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="TestPassword123",
|
||||
password="TestPassword123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="1234567890"
|
||||
phone_number="+1234567890"
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -141,7 +141,7 @@ class TestAuthServiceUserCreation:
|
||||
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Use existing email
|
||||
password="TestPassword123",
|
||||
password="TestPassword123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user