Refactor authentication services to async password handling; optimize bulk operations and queries

- Updated `verify_password` and `get_password_hash` to their async counterparts to prevent event loop blocking.
- Replaced N+1 query patterns in `admin.py` and `session_async.py` with optimized bulk operations for improved performance.
- Enhanced `user_async.py` with bulk update and soft delete methods for efficient user management.
- Added eager loading support in CRUD operations to prevent N+1 query issues.
- Updated test cases with stronger password examples for better security representation.
This commit is contained in:
Felipe Cardoso
2025-11-01 03:53:22 +01:00
parent 819f3ba963
commit 3fe5d301f8
17 changed files with 397 additions and 163 deletions

View File

@@ -345,54 +345,50 @@ async def admin_bulk_user_action(
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Perform bulk actions on multiple users.
Perform bulk actions on multiple users using optimized bulk operations.
Uses single UPDATE query instead of N individual queries for efficiency.
Supported actions: activate, deactivate, delete
"""
affected_count = 0
failed_count = 0
failed_ids = []
try:
for user_id in bulk_action.user_ids:
try:
user = await user_crud.get(db, id=user_id)
if not user:
failed_count += 1
failed_ids.append(user_id)
continue
# Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status(
db,
user_ids=bulk_action.user_ids,
is_active=True
)
elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status(
db,
user_ids=bulk_action.user_ids,
is_active=False
)
elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete(
db,
user_ids=bulk_action.user_ids,
exclude_user_id=admin.id
)
else:
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
# Prevent affecting yourself
if user.id == admin.id:
failed_count += 1
failed_ids.append(user_id)
continue
if bulk_action.action == BulkAction.ACTIVATE:
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
elif bulk_action.action == BulkAction.DEACTIVATE:
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
elif bulk_action.action == BulkAction.DELETE:
await user_crud.soft_delete(db, id=user_id)
affected_count += 1
except Exception as e:
logger.error(f"Error processing user {user_id} in bulk action: {str(e)}")
failed_count += 1
failed_ids.append(user_id)
# Calculate failed count (requested - affected)
requested_count = len(bulk_action.user_ids)
failed_count = requested_count - affected_count
logger.info(
f"Admin {admin.email} performed bulk {bulk_action.action.value} "
f"on {affected_count} users ({failed_count} failed)"
f"on {affected_count} users ({failed_count} skipped/failed)"
)
return BulkActionResult(
success=failed_count == 0,
affected_count=affected_count,
failed_count=failed_count,
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} failed",
failed_ids=failed_ids if failed_ids else None
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
failed_ids=None # Bulk operations don't track individual failures
)
except Exception as e:

View File

@@ -51,23 +51,20 @@ async def get_my_organizations(
Get all organizations the current user belongs to.
Returns organizations with member count for each.
Uses optimized single query to avoid N+1 problem.
"""
try:
orgs = await organization_crud.get_user_organizations(
# Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details(
db,
user_id=current_user.id,
is_active=is_active
)
# Add member count and role to each organization
# Transform to response objects
orgs_with_data = []
for org in orgs:
role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=org.id
)
for item in orgs_data:
org = item['organization']
org_dict = {
"id": org.id,
"name": org.name,
@@ -77,7 +74,7 @@ async def get_my_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
"member_count": item['member_count']
}
orgs_with_data.append(OrganizationResponse(**org_dict))

View File

@@ -4,6 +4,8 @@ logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
import asyncio
from functools import partial
from jose import jwt, JWTError
from passlib.context import CryptContext
@@ -44,6 +46,49 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""
Verify a password against a hash asynchronously.
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
blocking the event loop.
Args:
plain_password: Plain text password to verify
hashed_password: Hashed password to verify against
Returns:
True if password matches, False otherwise
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(pwd_context.verify, plain_password, hashed_password)
)
async def get_password_hash_async(password: str) -> str:
"""
Generate a password hash asynchronously.
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
blocking the event loop. This is especially important during user
registration and password changes.
Args:
password: Plain text password to hash
Returns:
Hashed password string
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
pwd_context.hash,
password
)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,

View File

@@ -13,6 +13,7 @@ from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.orm import Load
from app.core.database_async import Base
@@ -35,8 +36,29 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""
self.model = model
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
@@ -48,18 +70,39 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
async def get_multi(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
"""Get multiple records with pagination validation."""
"""
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
@@ -69,9 +112,14 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000")
try:
result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
query = select(self.model).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")

View File

@@ -5,7 +5,8 @@ from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select, update, func
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.orm import selectinload, joinedload
import logging
from app.crud.base_async import CRUDBaseAsync
@@ -68,15 +69,17 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
db: AsyncSession,
*,
user_id: str,
active_only: bool = True
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
"""
Get all sessions for a user.
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
@@ -87,6 +90,10 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
@@ -286,12 +293,14 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
@@ -301,28 +310,24 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
# Get sessions to delete
stmt = select(UserSession).where(
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
)
)
result = await db.execute(stmt)
sessions_to_delete = list(result.scalars().all())
# Delete them
for session in sessions_to_delete:
await db.delete(session)
await db.commit()
count = len(sessions_to_delete)
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count
except Exception as e:

View File

@@ -1,13 +1,15 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
from typing import Optional, Union, Dict, Any, List, Tuple
from uuid import UUID
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from sqlalchemy import or_, select
from sqlalchemy import or_, select, update
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.core.auth import get_password_hash
from app.core.auth import get_password_hash_async
import logging
logger = logging.getLogger(__name__)
@@ -28,11 +30,14 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with password hashing and error handling."""
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
@@ -63,15 +68,16 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
"""Update user with password hashing if password is updated."""
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"])
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
@@ -157,6 +163,100 @@ class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
logger.error(f"Error retrieving paginated users: {str(e)}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
is_active=False,
updated_at=datetime.now(timezone.utc)
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active

View File

@@ -2,6 +2,7 @@
Common schemas used across the API for pagination, responses, filtering, and sorting.
"""
from typing import Generic, TypeVar, List, Optional
from uuid import UUID
from enum import Enum
from pydantic import BaseModel, Field
from math import ceil
@@ -138,6 +139,46 @@ class MessageResponse(BaseModel):
}
class BulkActionRequest(BaseModel):
"""Request schema for bulk operations on multiple items."""
ids: List[UUID] = Field(
...,
min_length=1,
max_length=100,
description="List of item IDs to perform action on (max 100)"
)
model_config = {
"json_schema_extra": {
"example": {
"ids": [
"550e8400-e29b-41d4-a716-446655440000",
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
]
}
}
}
class BulkActionResponse(BaseModel):
"""Response schema for bulk operations."""
success: bool = Field(default=True, description="Operation success status")
message: str = Field(..., description="Human-readable message")
affected_count: int = Field(..., description="Number of items affected by the operation")
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Successfully deactivated 5 users",
"affected_count": 5
}
}
}
def create_pagination_meta(
total: int,
page: int,

View File

@@ -7,8 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.auth import (
verify_password,
get_password_hash,
verify_password_async,
get_password_hash_async,
create_access_token,
create_refresh_token,
TokenExpiredError,
@@ -31,7 +31,7 @@ class AuthService:
@staticmethod
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
"""
Authenticate a user with email and password.
Authenticate a user with email and password using async password verification.
Args:
db: Database session
@@ -47,7 +47,8 @@ class AuthService:
if not user:
return None
if not verify_password(password, user.password_hash):
# Verify password asynchronously to avoid blocking event loop
if not await verify_password_async(password, user.password_hash):
return None
if not user.is_active:
@@ -77,8 +78,9 @@ class AuthService:
if existing_user:
raise AuthenticationError("User with this email already exists")
# Create new user
hashed_password = get_password_hash(user_data.password)
# Create new user with async password hashing
# Hash password asynchronously to avoid blocking event loop
hashed_password = await get_password_hash_async(user_data.password)
# Create user object from model
user = User(
@@ -202,12 +204,12 @@ class AuthService:
if not user:
raise AuthenticationError("User not found")
# Verify current password
if not verify_password(current_password, user.password_hash):
# Verify current password asynchronously
if not await verify_password_async(current_password, user.password_hash):
raise AuthenticationError("Current password is incorrect")
# Update password
user.password_hash = get_password_hash(new_password)
# Hash new password asynchronously to avoid blocking event loop
user.password_hash = await get_password_hash_async(new_password)
await db.commit()
logger.info(f"Password changed successfully for user {user_id}")

View File

@@ -30,7 +30,7 @@ class TestRegisterEndpoint:
"/api/v1/auth/register",
json={
"email": "newuser@example.com",
"password": "SecurePassword123",
"password": "SecurePassword123!",
"first_name": "New",
"last_name": "User"
}
@@ -49,7 +49,7 @@ class TestRegisterEndpoint:
"/api/v1/auth/register",
json={
"email": async_test_user.email,
"password": "SecurePassword123",
"password": "SecurePassword123!",
"first_name": "Duplicate",
"last_name": "User"
}
@@ -103,7 +103,7 @@ class TestLoginEndpoint:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
@@ -133,7 +133,7 @@ class TestLoginEndpoint:
"/api/v1/auth/login",
json={
"email": "nonexistent@example.com",
"password": "Password123"
"password": "Password123!"
}
)
@@ -154,7 +154,7 @@ class TestLoginEndpoint:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
@@ -170,7 +170,7 @@ class TestLoginEndpoint:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
@@ -187,7 +187,7 @@ class TestOAuthLoginEndpoint:
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
@@ -224,7 +224,7 @@ class TestOAuthLoginEndpoint:
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
@@ -240,7 +240,7 @@ class TestOAuthLoginEndpoint:
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
@@ -258,7 +258,7 @@ class TestRefreshTokenEndpoint:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
refresh_token = login_response.json()["refresh_token"]
@@ -307,7 +307,7 @@ class TestRefreshTokenEndpoint:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
refresh_token = login_response.json()["refresh_token"]
@@ -334,7 +334,7 @@ class TestGetCurrentUserEndpoint:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123"
"password": "TestPassword123!"
}
)
access_token = login_response.json()["access_token"]

View File

@@ -148,7 +148,7 @@ class TestPasswordResetConfirm:
"""Test password reset confirmation with valid token."""
# Generate valid token
token = create_password_reset_token(async_test_user.email)
new_password = "NewSecure123"
new_password = "NewSecure123!"
response = await client.post(
"/api/v1/auth/password-reset/confirm",
@@ -186,7 +186,7 @@ class TestPasswordResetConfirm:
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
"new_password": "NewSecure123!"
}
)
@@ -204,7 +204,7 @@ class TestPasswordResetConfirm:
"/api/v1/auth/password-reset/confirm",
json={
"token": "invalid_token_xyz",
"new_password": "NewSecure123"
"new_password": "NewSecure123!"
}
)
@@ -233,7 +233,7 @@ class TestPasswordResetConfirm:
"/api/v1/auth/password-reset/confirm",
json={
"token": tampered,
"new_password": "NewSecure123"
"new_password": "NewSecure123!"
}
)
@@ -249,7 +249,7 @@ class TestPasswordResetConfirm:
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
"new_password": "NewSecure123!"
}
)
@@ -276,7 +276,7 @@ class TestPasswordResetConfirm:
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
"new_password": "NewSecure123!"
}
)
@@ -315,7 +315,7 @@ class TestPasswordResetConfirm:
# Missing token
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={"new_password": "NewSecure123"}
json={"new_password": "NewSecure123!"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -340,7 +340,7 @@ class TestPasswordResetConfirm:
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
"new_password": "NewSecure123!"
}
)
@@ -354,7 +354,7 @@ class TestPasswordResetConfirm:
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db):
"""Test complete password reset flow."""
original_password = async_test_user.password_hash
new_password = "BrandNew123"
new_password = "BrandNew123!"
# Step 1: Request password reset
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:

View File

@@ -40,7 +40,7 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_as_superuser(self, client, async_test_superuser):
"""Test listing users as superuser."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.get("/api/v1/users", headers=headers)
@@ -53,7 +53,7 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_as_regular_user(self, client, async_test_user):
"""Test that regular users cannot list users."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.get("/api/v1/users", headers=headers)
@@ -77,7 +77,7 @@ class TestListUsers:
session.add(user)
await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
# Get first page
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
@@ -111,7 +111,7 @@ class TestListUsers:
session.add_all([active_user, inactive_user])
await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
# Filter for active users
response = await client.get("/api/v1/users?is_active=true", headers=headers)
@@ -130,7 +130,7 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_sort_by_email(self, client, async_test_superuser):
"""Test sorting users by email."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers)
assert response.status_code == status.HTTP_200_OK
@@ -154,7 +154,7 @@ class TestGetCurrentUserProfile:
@pytest.mark.asyncio
async def test_get_own_profile(self, client, async_test_user):
"""Test getting own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.get("/api/v1/users/me", headers=headers)
@@ -176,7 +176,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_own_profile(self, client, async_test_user):
"""Test updating own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
"/api/v1/users/me",
@@ -192,7 +192,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
"""Test updating phone number with validation."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
"/api/v1/users/me",
@@ -207,7 +207,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_invalid_phone(self, client, async_test_user):
"""Test that invalid phone numbers are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
"/api/v1/users/me",
@@ -220,7 +220,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_cannot_elevate_to_superuser(self, client, async_test_user):
"""Test that users cannot make themselves superuser."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
# Note: is_superuser is not in UserUpdate schema, but the endpoint checks for it
# This tests that even if someone tries to send it, it's rejected
@@ -255,7 +255,7 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_own_profile_by_id(self, client, async_test_user):
"""Test getting own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
@@ -278,7 +278,7 @@ class TestGetUserById:
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
@@ -287,7 +287,7 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user):
"""Test that superusers can view other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
@@ -298,7 +298,7 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_nonexistent_user(self, client, async_test_superuser):
"""Test getting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
fake_id = uuid.uuid4()
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
@@ -308,7 +308,7 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_user_invalid_uuid(self, client, async_test_superuser):
"""Test getting user with invalid UUID format."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
@@ -321,7 +321,7 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
"""Test updating own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
@@ -348,7 +348,7 @@ class TestUpdateUserById:
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
f"/api/v1/users/{other_user.id}",
@@ -365,7 +365,7 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db):
"""Test that superusers can update other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
@@ -380,7 +380,7 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user):
"""Test that regular users cannot change superuser status even if they try."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
# Just verify the user stays the same
@@ -397,7 +397,7 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db):
"""Test that superusers can update other users."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
@@ -413,7 +413,7 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_nonexistent_user(self, client, async_test_superuser):
"""Test updating non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
fake_id = uuid.uuid4()
response = await client.patch(
@@ -433,14 +433,14 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_user, test_db):
"""Test successful password change."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "TestPassword123",
"new_password": "NewPassword123"
"current_password": "TestPassword123!",
"new_password": "NewPassword123!"
}
)
@@ -453,7 +453,7 @@ class TestChangePassword:
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "NewPassword123"
"password": "NewPassword123!"
}
)
assert login_response.status_code == status.HTTP_200_OK
@@ -461,14 +461,14 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_wrong_current(self, client, async_test_user):
"""Test that wrong current password is rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "WrongPassword123",
"new_password": "NewPassword123"
"new_password": "NewPassword123!"
}
)
@@ -477,13 +477,13 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_weak_new_password(self, client, async_test_user):
"""Test that weak new passwords are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "TestPassword123",
"current_password": "TestPassword123!",
"new_password": "weak"
}
)
@@ -496,8 +496,8 @@ class TestChangePassword:
response = await client.patch(
"/api/v1/users/me/password",
json={
"current_password": "TestPassword123",
"new_password": "NewPassword123"
"current_password": "TestPassword123!",
"new_password": "NewPassword123!"
}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -527,7 +527,7 @@ class TestDeleteUser:
await session.refresh(user_to_delete)
user_id = user_to_delete.id
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
@@ -545,7 +545,7 @@ class TestDeleteUser:
@pytest.mark.asyncio
async def test_cannot_delete_self(self, client, async_test_superuser):
"""Test that users cannot delete their own account."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers)
@@ -566,7 +566,7 @@ class TestDeleteUser:
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers)
@@ -575,7 +575,7 @@ class TestDeleteUser:
@pytest.mark.asyncio
async def test_delete_nonexistent_user(self, client, async_test_superuser):
"""Test deleting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
fake_id = uuid.uuid4()
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)

View File

@@ -131,7 +131,7 @@ def test_user(test_db):
user = User(
id=uuid.uuid4(),
email="testuser@example.com",
password_hash=get_password_hash("TestPassword123"),
password_hash=get_password_hash("TestPassword123!"),
first_name="Test",
last_name="User",
phone_number="+1234567890",
@@ -155,7 +155,7 @@ def test_superuser(test_db):
user = User(
id=uuid.uuid4(),
email="superuser@example.com",
password_hash=get_password_hash("SuperPassword123"),
password_hash=get_password_hash("SuperPassword123!"),
first_name="Super",
last_name="User",
phone_number="+9876543210",
@@ -181,7 +181,7 @@ async def async_test_user(async_test_db):
user = User(
id=uuid.uuid4(),
email="testuser@example.com",
password_hash=get_password_hash("TestPassword123"),
password_hash=get_password_hash("TestPassword123!"),
first_name="Test",
last_name="User",
phone_number="+1234567890",
@@ -207,7 +207,7 @@ async def async_test_superuser(async_test_db):
user = User(
id=uuid.uuid4(),
email="superuser@example.com",
password_hash=get_password_hash("SuperPassword123"),
password_hash=get_password_hash("SuperPassword123!"),
first_name="Super",
last_name="User",
phone_number="+9876543210",

View File

@@ -24,26 +24,26 @@ class TestPasswordHandling:
def test_password_hash_different_from_password(self):
"""Test that a password hash is different from the original password"""
password = "TestPassword123"
password = "TestPassword123!"
hashed = get_password_hash(password)
assert hashed != password
def test_verify_correct_password(self):
"""Test that verify_password returns True for the correct password"""
password = "TestPassword123"
password = "TestPassword123!"
hashed = get_password_hash(password)
assert verify_password(password, hashed) is True
def test_verify_incorrect_password(self):
"""Test that verify_password returns False for an incorrect password"""
password = "TestPassword123"
wrong_password = "WrongPassword123"
password = "TestPassword123!"
wrong_password = "WrongPassword123!"
hashed = get_password_hash(password)
assert verify_password(wrong_password, hashed) is False
def test_same_password_different_hash(self):
"""Test that the same password gets a different hash each time"""
password = "TestPassword123"
password = "TestPassword123!"
hash1 = get_password_hash(password)
hash2 = get_password_hash(password)
assert hash1 != hash2

View File

@@ -318,7 +318,7 @@ class TestCRUDCreate:
"""Test basic record creation."""
user_data = UserCreate(
email="create@example.com",
password="Password123",
password="Password123!",
first_name="Create",
last_name="Test"
)
@@ -333,7 +333,7 @@ class TestCRUDCreate:
"""Test that creating duplicate email raises error."""
user_data = UserCreate(
email="duplicate@example.com",
password="Password123",
password="Password123!",
first_name="First"
)

View File

@@ -39,7 +39,7 @@ class TestCRUDErrorPaths:
# Create first user
user_data = UserCreate(
email="unique@example.com",
password="Password123",
password="Password123!",
first_name="First"
)
user_crud.create(db_session, obj_in=user_data)
@@ -52,7 +52,7 @@ class TestCRUDErrorPaths:
"""Test create handles other integrity errors."""
user_data = UserCreate(
email="integrityerror@example.com",
password="Password123",
password="Password123!",
first_name="Integrity"
)
@@ -71,7 +71,7 @@ class TestCRUDErrorPaths:
"""Test create handles unexpected errors."""
user_data = UserCreate(
email="unexpectederror@example.com",
password="Password123",
password="Password123!",
first_name="Unexpected"
)

View File

@@ -111,7 +111,7 @@ class TestPhoneNumberValidation:
email="test@example.com",
first_name="Test",
last_name="User",
password="Password123",
password="Password123!",
phone_number="+41791234567"
)
assert user.phone_number == "+41791234567"
@@ -122,6 +122,6 @@ class TestPhoneNumberValidation:
email="test@example.com",
first_name="Test",
last_name="User",
password="Password123",
password="Password123!",
phone_number="invalid-number"
)

View File

@@ -20,7 +20,7 @@ class TestAuthServiceAuthentication:
test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
password = "TestPassword123"
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
user = result.scalar_one_or_none()
@@ -59,7 +59,7 @@ class TestAuthServiceAuthentication:
test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
password = "TestPassword123"
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
user = result.scalar_one_or_none()
@@ -82,7 +82,7 @@ class TestAuthServiceAuthentication:
test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password and make user inactive
password = "TestPassword123"
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
user = result.scalar_one_or_none()
@@ -110,10 +110,10 @@ class TestAuthServiceUserCreation:
user_data = UserCreate(
email="newuser@example.com",
password="TestPassword123",
password="TestPassword123!",
first_name="New",
last_name="User",
phone_number="1234567890"
phone_number="+1234567890"
)
async with AsyncTestingSessionLocal() as session:
@@ -141,7 +141,7 @@ class TestAuthServiceUserCreation:
user_data = UserCreate(
email=async_test_user.email, # Use existing email
password="TestPassword123",
password="TestPassword123!",
first_name="Duplicate",
last_name="User"
)