Improve error handling, logging, and security in authentication services and utilities

- Refactored `create_user` and `change_password` methods to add transaction rollback on failures and enhanced logging for error contexts.
- Updated security utilities to use constant-time comparison (`hmac.compare_digest`) to mitigate timing attacks.
- Adjusted API responses in registration and password reset flows for better security and user experience.
- Added session invalidation after password resets to enhance account security.
This commit is contained in:
Felipe Cardoso
2025-11-01 01:13:02 +01:00
parent cc98a76e24
commit 4de440ed2d
4 changed files with 82 additions and 44 deletions

View File

@@ -61,10 +61,11 @@ async def register_user(
user = await AuthService.create_user(db, user_data) user = await AuthService.create_user(db, user_data)
return user return user
except AuthenticationError as e: except AuthenticationError as e:
# SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: {str(e)}") logger.warning(f"Registration failed: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e) detail="Registration failed. Please check your information and try again."
) )
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}") logger.error(f"Unexpected error during registration: {str(e)}")
@@ -439,11 +440,22 @@ async def confirm_password_reset(
db.add(user) db.add(user)
await db.commit() await db.commit()
logger.info(f"Password reset successful for {user.email}") # SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change
from app.crud.session_async import session_async as session_crud
try:
deactivated_count = await session_crud.deactivate_all_user_sessions(
db,
user_id=str(user.id)
)
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
except Exception as session_error:
# Log but don't fail password reset if session invalidation fails
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
return MessageResponse( return MessageResponse(
success=True, success=True,
message="Password has been reset successfully. You can now log in with your new password." message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
) )
except HTTPException: except HTTPException:

View File

@@ -62,7 +62,7 @@ async def get_my_organizations(
# Add member count and role to each organization # Add member count and role to each organization
orgs_with_data = [] orgs_with_data = []
for org in orgs: for org in orgs:
role = organization_crud.get_user_role_in_org( role = await organization_crud.get_user_role_in_org(
db, db,
user_id=current_user.id, user_id=current_user.id,
organization_id=org.id organization_id=org.id

View File

@@ -66,7 +66,11 @@ class AuthService:
Returns: Returns:
Created user Created user
Raises:
AuthenticationError: If user already exists or creation fails
""" """
try:
# Check if user already exists # Check if user already exists
result = await db.execute(select(User).where(User.email == user_data.email)) result = await db.execute(select(User).where(User.email == user_data.email))
existing_user = result.scalar_one_or_none() existing_user = result.scalar_one_or_none()
@@ -91,8 +95,18 @@ class AuthService:
await db.commit() await db.commit()
await db.refresh(user) await db.refresh(user)
logger.info(f"User created successfully: {user.email}")
return user return user
except AuthenticationError:
# Re-raise authentication errors without rollback
raise
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error creating user: {str(e)}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {str(e)}")
@staticmethod @staticmethod
def create_tokens(user: User) -> Token: def create_tokens(user: User) -> Token:
""" """
@@ -180,8 +194,9 @@ class AuthService:
True if password was changed successfully True if password was changed successfully
Raises: Raises:
AuthenticationError: If current password is incorrect AuthenticationError: If current password is incorrect or update fails
""" """
try:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
@@ -195,4 +210,14 @@ class AuthService:
user.password_hash = get_password_hash(new_password) user.password_hash = get_password_hash(new_password)
await db.commit() await db.commit()
logger.info(f"Password changed successfully for user {user_id}")
return True return True
except AuthenticationError:
# Re-raise authentication errors without rollback
raise
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
raise AuthenticationError(f"Failed to change password: {str(e)}")

View File

@@ -7,6 +7,7 @@ time-limited, single-use operations.
""" """
import base64 import base64
import hashlib import hashlib
import hmac
import json import json
import secrets import secrets
import time import time
@@ -92,13 +93,13 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
payload = token_data["payload"] payload = token_data["payload"]
signature = token_data["signature"] signature = token_data["signature"]
# Verify signature # Verify signature using constant-time comparison to prevent timing attacks
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode('utf-8')
expected_signature = hashlib.sha256( expected_signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8') payload_bytes + settings.SECRET_KEY.encode('utf-8')
).hexdigest() ).hexdigest()
if signature != expected_signature: if not hmac.compare_digest(signature, expected_signature):
return None return None
# Check expiration # Check expiration
@@ -185,13 +186,13 @@ def verify_password_reset_token(token: str) -> Optional[str]:
if payload.get("purpose") != "password_reset": if payload.get("purpose") != "password_reset":
return None return None
# Verify signature # Verify signature using constant-time comparison to prevent timing attacks
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode('utf-8')
expected_signature = hashlib.sha256( expected_signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8') payload_bytes + settings.SECRET_KEY.encode('utf-8')
).hexdigest() ).hexdigest()
if signature != expected_signature: if not hmac.compare_digest(signature, expected_signature):
return None return None
# Check expiration # Check expiration
@@ -278,13 +279,13 @@ def verify_email_verification_token(token: str) -> Optional[str]:
if payload.get("purpose") != "email_verification": if payload.get("purpose") != "email_verification":
return None return None
# Verify signature # Verify signature using constant-time comparison to prevent timing attacks
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode('utf-8')
expected_signature = hashlib.sha256( expected_signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8') payload_bytes + settings.SECRET_KEY.encode('utf-8')
).hexdigest() ).hexdigest()
if signature != expected_signature: if not hmac.compare_digest(signature, expected_signature):
return None return None
# Check expiration # Check expiration