Refactor backend to adopt async patterns across services, API routes, and CRUD operations
- Migrated database sessions and operations to `AsyncSession` for full async support. - Updated all service methods and dependencies (`get_db` to `get_async_db`) to support async logic. - Refactored admin, user, organization, session-related CRUD methods, and routes with await syntax. - Improved consistency and performance with async SQLAlchemy patterns. - Enhanced logging and error handling for async context.
This commit is contained in:
62
backend/app/api/routes/auth.py
Normal file → Executable file
62
backend/app/api/routes/auth.py
Normal file → Executable file
@@ -8,11 +8,11 @@ from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.database import get_db
|
||||
from app.core.database_async import get_async_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import (
|
||||
UserCreate,
|
||||
@@ -29,8 +29,8 @@ from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.email_service import email_service
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
from app.utils.device import extract_device_info
|
||||
from app.crud.user import user as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user_async import user_async as user_crud
|
||||
from app.crud.session_async import session_async as session_crud
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
router = APIRouter()
|
||||
@@ -49,7 +49,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -58,7 +58,7 @@ async def register_user(
|
||||
The created user information.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.create_user(db, user_data)
|
||||
user = await AuthService.create_user(db, user_data)
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
@@ -79,7 +79,7 @@ async def register_user(
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
@@ -91,7 +91,7 @@ async def login(
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
@@ -126,7 +126,7 @@ async def login(
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
await session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
f"User login successful: {user.email} from {device_info.device_name} "
|
||||
@@ -163,7 +163,7 @@ async def login(
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
@@ -174,7 +174,7 @@ async def login_oauth(
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
@@ -207,7 +207,7 @@ async def login_oauth(
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
await session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
||||
except Exception as session_err:
|
||||
@@ -241,7 +241,7 @@ async def login_oauth(
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
@@ -256,7 +256,7 @@ async def refresh_token(
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
|
||||
# Check if session exists and is active
|
||||
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
@@ -267,14 +267,14 @@ async def refresh_token(
|
||||
)
|
||||
|
||||
# Generate new tokens
|
||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
tokens = await AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
|
||||
# Decode new refresh token to get new JTI
|
||||
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
# Update session with new refresh token JTI and expiration
|
||||
try:
|
||||
session_crud.update_refresh_token(
|
||||
await session_crud.update_refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
@@ -344,7 +344,7 @@ async def get_current_user_info(
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
reset_request: PasswordResetRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Request a password reset.
|
||||
@@ -354,7 +354,7 @@ async def request_password_reset(
|
||||
"""
|
||||
try:
|
||||
# Look up user by email
|
||||
user = user_crud.get_by_email(db, email=reset_request.email)
|
||||
user = await user_crud.get_by_email(db, email=reset_request.email)
|
||||
|
||||
# Only send email if user exists and is active
|
||||
if user and user.is_active:
|
||||
@@ -399,10 +399,10 @@ async def request_password_reset(
|
||||
operation_id="confirm_password_reset"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def confirm_password_reset(
|
||||
async def confirm_password_reset(
|
||||
request: Request,
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Confirm password reset with token.
|
||||
@@ -420,7 +420,7 @@ def confirm_password_reset(
|
||||
)
|
||||
|
||||
# Look up user
|
||||
user = user_crud.get_by_email(db, email=email)
|
||||
user = await user_crud.get_by_email(db, email=email)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@@ -437,7 +437,7 @@ def confirm_password_reset(
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Password reset successful for {user.email}")
|
||||
|
||||
@@ -450,7 +450,7 @@ def confirm_password_reset(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
@@ -474,11 +474,11 @@ def confirm_password_reset(
|
||||
operation_id="logout"
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def logout(
|
||||
async def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
@@ -505,7 +505,7 @@ def logout(
|
||||
)
|
||||
|
||||
# Find the session by JTI
|
||||
session = session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if session:
|
||||
# Verify session belongs to current user (security check)
|
||||
@@ -520,7 +520,7 @@ def logout(
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session.id))
|
||||
await session_crud.deactivate(db, session_id=str(session.id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from {session.device_name} "
|
||||
@@ -563,10 +563,10 @@ def logout(
|
||||
operation_id="logout_all"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def logout_all(
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
@@ -580,7 +580,7 @@ def logout_all(
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
|
||||
@@ -591,7 +591,7 @@ def logout_all(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
|
||||
Reference in New Issue
Block a user