diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index b0fff81..1335e92 100755 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -1,10 +1,10 @@ # app/api/routes/auth.py import logging import os -from typing import Any from datetime import datetime, timezone +from typing import Any -from fastapi import APIRouter, Depends, HTTPException, status, Body, Request +from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordRequestForm from slowapi import Limiter from slowapi.util import get_remote_address @@ -12,8 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token +from app.core.auth import get_password_hash from app.core.database_async import get_async_db +from app.core.exceptions import ( + AuthenticationError as AuthError, + DatabaseError, + ErrorCode +) +from app.crud.session_async import session_async as session_crud +from app.crud.user_async import user_async as user_crud from app.models.user import User +from app.schemas.common import MessageResponse +from app.schemas.sessions import SessionCreate, LogoutRequest from app.schemas.users import ( UserCreate, UserResponse, @@ -23,15 +33,10 @@ from app.schemas.users import ( PasswordResetRequest, PasswordResetConfirm ) -from app.schemas.common import MessageResponse -from app.schemas.sessions import SessionCreate, LogoutRequest from app.services.auth_service import AuthService, AuthenticationError from app.services.email_service import email_service -from app.utils.security import create_password_reset_token, verify_password_reset_token from app.utils.device import extract_device_info -from app.crud.user_async import user_async as user_crud -from app.crud.session_async import session_async as session_crud -from app.core.auth import get_password_hash +from app.utils.security import create_password_reset_token, verify_password_reset_token router = APIRouter() logger = logging.getLogger(__name__) @@ -68,10 +73,10 @@ async def register_user( detail="Registration failed. Please check your information and try again." ) except Exception as e: - logger.error(f"Unexpected error during registration: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An unexpected error occurred. Please try again later." + logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True) + raise DatabaseError( + message="An unexpected error occurred. Please try again later.", + error_code=ErrorCode.INTERNAL_ERROR ) @@ -97,10 +102,9 @@ async def login( # Explicitly check for None result and raise correct exception if user is None: logger.warning(f"Invalid login attempt for: {login_data.email}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid email or password", - headers={"WWW-Authenticate": "Bearer"}, + raise AuthError( + message="Invalid email or password", + error_code=ErrorCode.INVALID_CREDENTIALS ) # User is authenticated, generate tokens @@ -139,23 +143,22 @@ async def login( return tokens - except HTTPException: - # Re-raise HTTP exceptions without modification - raise except AuthenticationError as e: # Handle specific authentication errors like inactive accounts logger.warning(f"Authentication failed: {str(e)}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=str(e), - headers={"WWW-Authenticate": "Bearer"}, + raise AuthError( + message=str(e), + error_code=ErrorCode.INVALID_CREDENTIALS ) + except AuthError: + # Re-raise custom auth exceptions without modification + raise except Exception as e: # Handle unexpected errors - logger.error(f"Unexpected error during login: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An unexpected error occurred. Please try again later." + logger.error(f"Unexpected error during login: {str(e)}", exc_info=True) + raise DatabaseError( + message="An unexpected error occurred. Please try again later.", + error_code=ErrorCode.INTERNAL_ERROR ) @@ -178,10 +181,9 @@ async def login_oauth( user = await AuthService.authenticate_user(db, form_data.username, form_data.password) if user is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid email or password", - headers={"WWW-Authenticate": "Bearer"}, + raise AuthError( + message="Invalid email or password", + error_code=ErrorCode.INVALID_CREDENTIALS ) # Generate tokens @@ -220,20 +222,20 @@ async def login_oauth( "refresh_token": tokens.refresh_token, "token_type": tokens.token_type } - except HTTPException: - raise except AuthenticationError as e: logger.warning(f"OAuth authentication failed: {str(e)}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=str(e), - headers={"WWW-Authenticate": "Bearer"}, + raise AuthError( + message=str(e), + error_code=ErrorCode.INVALID_CREDENTIALS ) + except AuthError: + # Re-raise custom auth exceptions without modification + raise except Exception as e: - logger.error(f"Unexpected error during OAuth login: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An unexpected error occurred. Please try again later." + logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True) + raise DatabaseError( + message="An unexpected error occurred. Please try again later.", + error_code=ErrorCode.INTERNAL_ERROR ) @@ -312,20 +314,6 @@ async def refresh_token( ) -@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info") -@limiter.limit("60/minute") -async def get_current_user_info( - request: Request, - current_user: User = Depends(get_current_user) -) -> Any: - """ - Get current user information. - - Requires authentication. - """ - return current_user - - @router.post( "/password-reset/request", response_model=MessageResponse, diff --git a/backend/app/api/routes/sessions.py b/backend/app/api/routes/sessions.py index b9dc3b5..64b0e04 100755 --- a/backend/app/api/routes/sessions.py +++ b/backend/app/api/routes/sessions.py @@ -4,7 +4,7 @@ Session management endpoints. Allows users to view and manage their active sessions across devices. """ import logging -from typing import Any, List +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status, Request @@ -13,13 +13,13 @@ from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user -from app.core.database_async import get_async_db from app.core.auth import decode_token -from app.models.user import User -from app.schemas.sessions import SessionResponse, SessionListResponse -from app.schemas.common import MessageResponse -from app.crud.session_async import session_async as session_crud +from app.core.database_async import get_async_db from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode +from app.crud.session_async import session_async as session_crud +from app.models.user import User +from app.schemas.common import MessageResponse +from app.schemas.sessions import SessionResponse, SessionListResponse router = APIRouter() logger = logging.getLogger(__name__) @@ -217,24 +217,12 @@ async def cleanup_expired_sessions( Success message with count of sessions cleaned """ try: - from datetime import datetime, timezone - - # Get all sessions for user - all_sessions = await session_crud.get_user_sessions( + # Use optimized bulk DELETE instead of N individual deletes + deleted_count = await session_crud.cleanup_expired_for_user( db, - user_id=str(current_user.id), - active_only=False + user_id=str(current_user.id) ) - # Delete expired and inactive sessions - deleted_count = 0 - for s in all_sessions: - if not s.is_active and s.expires_at < datetime.now(timezone.utc): - await db.delete(s) - deleted_count += 1 - - await db.commit() - logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions") return MessageResponse( diff --git a/backend/app/crud/session_async.py b/backend/app/crud/session_async.py index e9990ba..53eb58c 100755 --- a/backend/app/crud/session_async.py +++ b/backend/app/crud/session_async.py @@ -1,13 +1,14 @@ """ Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. """ +import logging from datetime import datetime, timezone, timedelta from typing import List, Optional from uuid import UUID -from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy import and_, select, update, delete, func -from sqlalchemy.orm import selectinload, joinedload -import logging +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from app.crud.base_async import CRUDBaseAsync from app.models.user_session import UserSession @@ -335,6 +336,61 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]) logger.error(f"Error cleaning up expired sessions: {str(e)}") raise + async def cleanup_expired_for_user( + self, + db: AsyncSession, + *, + user_id: str + ) -> int: + """ + Clean up expired and inactive sessions for a specific user. + + Uses single bulk DELETE query for efficiency instead of N individual deletes. + + Args: + db: Database session + user_id: User ID to cleanup sessions for + + Returns: + Number of sessions deleted + """ + try: + # Validate UUID + try: + uuid_obj = uuid.UUID(user_id) + except (ValueError, AttributeError): + logger.error(f"Invalid UUID format: {user_id}") + raise ValueError(f"Invalid user ID format: {user_id}") + + now = datetime.now(timezone.utc) + + # Use bulk DELETE with WHERE clause - single query + stmt = delete(UserSession).where( + and_( + UserSession.user_id == uuid_obj, + UserSession.is_active == False, + UserSession.expires_at < now + ) + ) + + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount + + if count > 0: + logger.info( + f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE" + ) + + return count + except Exception as e: + await db.rollback() + logger.error( + f"Error cleaning up expired sessions for user {user_id}: {str(e)}" + ) + raise + async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: """ Get count of active sessions for a user. diff --git a/backend/app/utils/security.py b/backend/app/utils/security.py index 1f9d975..303d339 100644 --- a/backend/app/utils/security.py +++ b/backend/app/utils/security.py @@ -12,7 +12,6 @@ import json import secrets import time from typing import Dict, Any, Optional -from datetime import datetime, timedelta from app.core.config import settings @@ -47,9 +46,12 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300 # Convert to JSON and encode payload_bytes = json.dumps(payload).encode('utf-8') - # Create a signature using the secret key - signature = hashlib.sha256( - payload_bytes + settings.SECRET_KEY.encode('utf-8') + # Create a signature using HMAC-SHA256 for security + # This prevents length extension attacks that plain SHA-256 is vulnerable to + signature = hmac.new( + settings.SECRET_KEY.encode('utf-8'), + payload_bytes, + hashlib.sha256 ).hexdigest() # Combine payload and signature @@ -93,10 +95,12 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]: payload = token_data["payload"] signature = token_data["signature"] - # Verify signature using constant-time comparison to prevent timing attacks + # Verify signature using HMAC and constant-time comparison payload_bytes = json.dumps(payload).encode('utf-8') - expected_signature = hashlib.sha256( - payload_bytes + settings.SECRET_KEY.encode('utf-8') + expected_signature = hmac.new( + settings.SECRET_KEY.encode('utf-8'), + payload_bytes, + hashlib.sha256 ).hexdigest() if not hmac.compare_digest(signature, expected_signature): @@ -138,9 +142,12 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str: # Convert to JSON and encode payload_bytes = json.dumps(payload).encode('utf-8') - # Create a signature using the secret key - signature = hashlib.sha256( - payload_bytes + settings.SECRET_KEY.encode('utf-8') + # Create a signature using HMAC-SHA256 for security + # This prevents length extension attacks that plain SHA-256 is vulnerable to + signature = hmac.new( + settings.SECRET_KEY.encode('utf-8'), + payload_bytes, + hashlib.sha256 ).hexdigest() # Combine payload and signature @@ -186,10 +193,12 @@ def verify_password_reset_token(token: str) -> Optional[str]: if payload.get("purpose") != "password_reset": return None - # Verify signature using constant-time comparison to prevent timing attacks + # Verify signature using HMAC and constant-time comparison payload_bytes = json.dumps(payload).encode('utf-8') - expected_signature = hashlib.sha256( - payload_bytes + settings.SECRET_KEY.encode('utf-8') + expected_signature = hmac.new( + settings.SECRET_KEY.encode('utf-8'), + payload_bytes, + hashlib.sha256 ).hexdigest() if not hmac.compare_digest(signature, expected_signature): @@ -231,9 +240,12 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str: # Convert to JSON and encode payload_bytes = json.dumps(payload).encode('utf-8') - # Create a signature using the secret key - signature = hashlib.sha256( - payload_bytes + settings.SECRET_KEY.encode('utf-8') + # Create a signature using HMAC-SHA256 for security + # This prevents length extension attacks that plain SHA-256 is vulnerable to + signature = hmac.new( + settings.SECRET_KEY.encode('utf-8'), + payload_bytes, + hashlib.sha256 ).hexdigest() # Combine payload and signature @@ -279,10 +291,12 @@ def verify_email_verification_token(token: str) -> Optional[str]: if payload.get("purpose") != "email_verification": return None - # Verify signature using constant-time comparison to prevent timing attacks + # Verify signature using HMAC and constant-time comparison payload_bytes = json.dumps(payload).encode('utf-8') - expected_signature = hashlib.sha256( - payload_bytes + settings.SECRET_KEY.encode('utf-8') + expected_signature = hmac.new( + settings.SECRET_KEY.encode('utf-8'), + payload_bytes, + hashlib.sha256 ).hexdigest() if not hmac.compare_digest(signature, expected_signature):