diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index fcfe1d0..b0fff81 100755 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -61,10 +61,11 @@ async def register_user( user = await AuthService.create_user(db, user_data) return user except AuthenticationError as e: + # SECURITY: Don't reveal if email exists - generic error message logger.warning(f"Registration failed: {str(e)}") raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=str(e) + status_code=status.HTTP_400_BAD_REQUEST, + detail="Registration failed. Please check your information and try again." ) except Exception as e: logger.error(f"Unexpected error during registration: {str(e)}") @@ -439,11 +440,22 @@ async def confirm_password_reset( db.add(user) await db.commit() - logger.info(f"Password reset successful for {user.email}") + # SECURITY: Invalidate all existing sessions after password reset + # This prevents stolen sessions from being used after password change + from app.crud.session_async import session_async as session_crud + try: + deactivated_count = await session_crud.deactivate_all_user_sessions( + db, + user_id=str(user.id) + ) + logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions") + except Exception as session_error: + # Log but don't fail password reset if session invalidation fails + logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}") return MessageResponse( success=True, - message="Password has been reset successfully. You can now log in with your new password." + message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password." ) except HTTPException: diff --git a/backend/app/api/routes/organizations.py b/backend/app/api/routes/organizations.py index 6a756e3..2e9ed30 100755 --- a/backend/app/api/routes/organizations.py +++ b/backend/app/api/routes/organizations.py @@ -62,7 +62,7 @@ async def get_my_organizations( # Add member count and role to each organization orgs_with_data = [] for org in orgs: - role = organization_crud.get_user_role_in_org( + role = await organization_crud.get_user_role_in_org( db, user_id=current_user.id, organization_id=org.id diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 5d3537c..fe21ae8 100755 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -66,32 +66,46 @@ class AuthService: Returns: Created user + + Raises: + AuthenticationError: If user already exists or creation fails """ - # Check if user already exists - result = await db.execute(select(User).where(User.email == user_data.email)) - existing_user = result.scalar_one_or_none() - if existing_user: - raise AuthenticationError("User with this email already exists") + try: + # Check if user already exists + result = await db.execute(select(User).where(User.email == user_data.email)) + existing_user = result.scalar_one_or_none() + if existing_user: + raise AuthenticationError("User with this email already exists") - # Create new user - hashed_password = get_password_hash(user_data.password) + # Create new user + hashed_password = get_password_hash(user_data.password) - # Create user object from model - user = User( - email=user_data.email, - password_hash=hashed_password, - first_name=user_data.first_name, - last_name=user_data.last_name, - phone_number=user_data.phone_number, - is_active=True, - is_superuser=False - ) + # Create user object from model + user = User( + email=user_data.email, + password_hash=hashed_password, + first_name=user_data.first_name, + last_name=user_data.last_name, + phone_number=user_data.phone_number, + is_active=True, + is_superuser=False + ) - db.add(user) - await db.commit() - await db.refresh(user) + db.add(user) + await db.commit() + await db.refresh(user) - return user + logger.info(f"User created successfully: {user.email}") + return user + + except AuthenticationError: + # Re-raise authentication errors without rollback + raise + except Exception as e: + # Rollback on any database errors + await db.rollback() + logger.error(f"Error creating user: {str(e)}", exc_info=True) + raise AuthenticationError(f"Failed to create user: {str(e)}") @staticmethod def create_tokens(user: User) -> Token: @@ -180,19 +194,30 @@ class AuthService: True if password was changed successfully Raises: - AuthenticationError: If current password is incorrect + AuthenticationError: If current password is incorrect or update fails """ - result = await db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() - if not user: - raise AuthenticationError("User not found") + try: + result = await db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + if not user: + raise AuthenticationError("User not found") - # Verify current password - if not verify_password(current_password, user.password_hash): - raise AuthenticationError("Current password is incorrect") + # Verify current password + if not verify_password(current_password, user.password_hash): + raise AuthenticationError("Current password is incorrect") - # Update password - user.password_hash = get_password_hash(new_password) - await db.commit() + # Update password + user.password_hash = get_password_hash(new_password) + await db.commit() - return True + logger.info(f"Password changed successfully for user {user_id}") + return True + + except AuthenticationError: + # Re-raise authentication errors without rollback + raise + except Exception as e: + # Rollback on any database errors + await db.rollback() + logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True) + raise AuthenticationError(f"Failed to change password: {str(e)}") diff --git a/backend/app/utils/security.py b/backend/app/utils/security.py index 2abe3f1..1f9d975 100644 --- a/backend/app/utils/security.py +++ b/backend/app/utils/security.py @@ -7,6 +7,7 @@ time-limited, single-use operations. """ import base64 import hashlib +import hmac import json import secrets import time @@ -92,13 +93,13 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]: payload = token_data["payload"] signature = token_data["signature"] - # Verify signature + # Verify signature using constant-time comparison to prevent timing attacks payload_bytes = json.dumps(payload).encode('utf-8') expected_signature = hashlib.sha256( payload_bytes + settings.SECRET_KEY.encode('utf-8') ).hexdigest() - if signature != expected_signature: + if not hmac.compare_digest(signature, expected_signature): return None # Check expiration @@ -185,13 +186,13 @@ def verify_password_reset_token(token: str) -> Optional[str]: if payload.get("purpose") != "password_reset": return None - # Verify signature + # Verify signature using constant-time comparison to prevent timing attacks payload_bytes = json.dumps(payload).encode('utf-8') expected_signature = hashlib.sha256( payload_bytes + settings.SECRET_KEY.encode('utf-8') ).hexdigest() - if signature != expected_signature: + if not hmac.compare_digest(signature, expected_signature): return None # Check expiration @@ -278,13 +279,13 @@ def verify_email_verification_token(token: str) -> Optional[str]: if payload.get("purpose") != "email_verification": return None - # Verify signature + # Verify signature using constant-time comparison to prevent timing attacks payload_bytes = json.dumps(payload).encode('utf-8') expected_signature = hashlib.sha256( payload_bytes + settings.SECRET_KEY.encode('utf-8') ).hexdigest() - if signature != expected_signature: + if not hmac.compare_digest(signature, expected_signature): return None # Check expiration