forked from cardosofelipe/fast-next-template
Improve error handling, logging, and security in authentication services and utilities
- Refactored `create_user` and `change_password` methods to add transaction rollback on failures and enhanced logging for error contexts. - Updated security utilities to use constant-time comparison (`hmac.compare_digest`) to mitigate timing attacks. - Adjusted API responses in registration and password reset flows for better security and user experience. - Added session invalidation after password resets to enhance account security.
This commit is contained in:
@@ -61,10 +61,11 @@ async def register_user(
|
|||||||
user = await AuthService.create_user(db, user_data)
|
user = await AuthService.create_user(db, user_data)
|
||||||
return user
|
return user
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
|
# SECURITY: Don't reveal if email exists - generic error message
|
||||||
logger.warning(f"Registration failed: {str(e)}")
|
logger.warning(f"Registration failed: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=str(e)
|
detail="Registration failed. Please check your information and try again."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error during registration: {str(e)}")
|
logger.error(f"Unexpected error during registration: {str(e)}")
|
||||||
@@ -439,11 +440,22 @@ async def confirm_password_reset(
|
|||||||
db.add(user)
|
db.add(user)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
logger.info(f"Password reset successful for {user.email}")
|
# SECURITY: Invalidate all existing sessions after password reset
|
||||||
|
# This prevents stolen sessions from being used after password change
|
||||||
|
from app.crud.session_async import session_async as session_crud
|
||||||
|
try:
|
||||||
|
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||||
|
db,
|
||||||
|
user_id=str(user.id)
|
||||||
|
)
|
||||||
|
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
|
||||||
|
except Exception as session_error:
|
||||||
|
# Log but don't fail password reset if session invalidation fails
|
||||||
|
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True,
|
success=True,
|
||||||
message="Password has been reset successfully. You can now log in with your new password."
|
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ async def get_my_organizations(
|
|||||||
# Add member count and role to each organization
|
# Add member count and role to each organization
|
||||||
orgs_with_data = []
|
orgs_with_data = []
|
||||||
for org in orgs:
|
for org in orgs:
|
||||||
role = organization_crud.get_user_role_in_org(
|
role = await organization_crud.get_user_role_in_org(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
organization_id=org.id
|
organization_id=org.id
|
||||||
|
|||||||
@@ -66,7 +66,11 @@ class AuthService:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Created user
|
Created user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If user already exists or creation fails
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||||
existing_user = result.scalar_one_or_none()
|
existing_user = result.scalar_one_or_none()
|
||||||
@@ -91,8 +95,18 @@ class AuthService:
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
|
logger.info(f"User created successfully: {user.email}")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
except AuthenticationError:
|
||||||
|
# Re-raise authentication errors without rollback
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# Rollback on any database errors
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error creating user: {str(e)}", exc_info=True)
|
||||||
|
raise AuthenticationError(f"Failed to create user: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_tokens(user: User) -> Token:
|
def create_tokens(user: User) -> Token:
|
||||||
"""
|
"""
|
||||||
@@ -180,8 +194,9 @@ class AuthService:
|
|||||||
True if password was changed successfully
|
True if password was changed successfully
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthenticationError: If current password is incorrect
|
AuthenticationError: If current password is incorrect or update fails
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
if not user:
|
if not user:
|
||||||
@@ -195,4 +210,14 @@ class AuthService:
|
|||||||
user.password_hash = get_password_hash(new_password)
|
user.password_hash = get_password_hash(new_password)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Password changed successfully for user {user_id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
except AuthenticationError:
|
||||||
|
# Re-raise authentication errors without rollback
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# Rollback on any database errors
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
|
||||||
|
raise AuthenticationError(f"Failed to change password: {str(e)}")
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ time-limited, single-use operations.
|
|||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
@@ -92,13 +93,13 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
|||||||
payload = token_data["payload"]
|
payload = token_data["payload"]
|
||||||
signature = token_data["signature"]
|
signature = token_data["signature"]
|
||||||
|
|
||||||
# Verify signature
|
# Verify signature using constant-time comparison to prevent timing attacks
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
expected_signature = hashlib.sha256(
|
expected_signature = hashlib.sha256(
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if signature != expected_signature:
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check expiration
|
# Check expiration
|
||||||
@@ -185,13 +186,13 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
|||||||
if payload.get("purpose") != "password_reset":
|
if payload.get("purpose") != "password_reset":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify signature
|
# Verify signature using constant-time comparison to prevent timing attacks
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
expected_signature = hashlib.sha256(
|
expected_signature = hashlib.sha256(
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if signature != expected_signature:
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check expiration
|
# Check expiration
|
||||||
@@ -278,13 +279,13 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
|||||||
if payload.get("purpose") != "email_verification":
|
if payload.get("purpose") != "email_verification":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify signature
|
# Verify signature using constant-time comparison to prevent timing attacks
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
expected_signature = hashlib.sha256(
|
expected_signature = hashlib.sha256(
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if signature != expected_signature:
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check expiration
|
# Check expiration
|
||||||
|
|||||||
Reference in New Issue
Block a user