Refactor authentication and session management for optimized performance, enhanced security, and improved error handling

- Replaced N+1 deletion pattern with a bulk `DELETE` in session cleanup for better efficiency in `session_async`.
- Updated security utilities to use HMAC-SHA256 signatures to mitigate length extension attacks and added constant-time comparisons to prevent timing attacks.
- Improved exception hierarchy with custom error types `AuthError` and `DatabaseError` for better granularity in error handling.
- Enhanced logging with `exc_info=True` for detailed error contexts across authentication services.
- Removed unused imports and reordered imports for cleaner code structure.
This commit is contained in:
Felipe Cardoso
2025-11-01 04:50:01 +01:00
parent ea544ecbac
commit 61173d0dc1
4 changed files with 144 additions and 98 deletions

View File

@@ -1,10 +1,10 @@
# app/api/routes/auth.py
import logging
import os
from typing import Any
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
@@ -12,8 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.auth import get_password_hash
from app.core.database_async import get_async_db
from app.core.exceptions import (
AuthenticationError as AuthError,
DatabaseError,
ErrorCode
)
from app.crud.session_async import session_async as session_crud
from app.crud.user_async import user_async as user_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest
from app.schemas.users import (
UserCreate,
UserResponse,
@@ -23,15 +33,10 @@ from app.schemas.users import (
PasswordResetRequest,
PasswordResetConfirm
)
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest
from app.services.auth_service import AuthService, AuthenticationError
from app.services.email_service import email_service
from app.utils.security import create_password_reset_token, verify_password_reset_token
from app.utils.device import extract_device_info
from app.crud.user_async import user_async as user_crud
from app.crud.session_async import session_async as session_crud
from app.core.auth import get_password_hash
from app.utils.security import create_password_reset_token, verify_password_reset_token
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -68,10 +73,10 @@ async def register_user(
detail="Registration failed. Please check your information and try again."
)
except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
)
@@ -97,10 +102,9 @@ async def login(
# Explicitly check for None result and raise correct exception
if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
raise AuthError(
message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS
)
# User is authenticated, generate tokens
@@ -139,23 +143,22 @@ async def login(
return tokens
except HTTPException:
# Re-raise HTTP exceptions without modification
raise
except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
raise AuthError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error during login: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
)
@@ -178,10 +181,9 @@ async def login_oauth(
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
raise AuthError(
message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS
)
# Generate tokens
@@ -220,20 +222,20 @@ async def login_oauth(
"refresh_token": tokens.refresh_token,
"token_type": tokens.token_type
}
except HTTPException:
raise
except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
raise AuthError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
logger.error(f"Unexpected error during OAuth login: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
)
@@ -312,20 +314,6 @@ async def refresh_token(
)
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
@limiter.limit("60/minute")
async def get_current_user_info(
request: Request,
current_user: User = Depends(get_current_user)
) -> Any:
"""
Get current user information.
Requires authentication.
"""
return current_user
@router.post(
"/password-reset/request",
response_model=MessageResponse,

View File

@@ -4,7 +4,7 @@ Session management endpoints.
Allows users to view and manage their active sessions across devices.
"""
import logging
from typing import Any, List
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request
@@ -13,13 +13,13 @@ from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database_async import get_async_db
from app.core.auth import decode_token
from app.models.user import User
from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.common import MessageResponse
from app.crud.session_async import session_async as session_crud
from app.core.database_async import get_async_db
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
from app.crud.session_async import session_async as session_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionResponse, SessionListResponse
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -217,24 +217,12 @@ async def cleanup_expired_sessions(
Success message with count of sessions cleaned
"""
try:
from datetime import datetime, timezone
# Get all sessions for user
all_sessions = await session_crud.get_user_sessions(
# Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user(
db,
user_id=str(current_user.id),
active_only=False
user_id=str(current_user.id)
)
# Delete expired and inactive sessions
deleted_count = 0
for s in all_sessions:
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
await db.delete(s)
deleted_count += 1
await db.commit()
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
return MessageResponse(

View File

@@ -1,13 +1,14 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.orm import selectinload, joinedload
import logging
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base_async import CRUDBaseAsync
from app.models.user_session import UserSession
@@ -335,6 +336,61 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try:
# Validate UUID
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
)
return count
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
)
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.

View File

@@ -12,7 +12,6 @@ import json
import secrets
import time
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from app.core.config import settings
@@ -47,9 +46,12 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
# Create a signature using the secret key
signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8')
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
).hexdigest()
# Combine payload and signature
@@ -93,10 +95,12 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
payload = token_data["payload"]
signature = token_data["signature"]
# Verify signature using constant-time comparison to prevent timing attacks
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
expected_signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8')
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
@@ -138,9 +142,12 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
# Create a signature using the secret key
signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8')
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
).hexdigest()
# Combine payload and signature
@@ -186,10 +193,12 @@ def verify_password_reset_token(token: str) -> Optional[str]:
if payload.get("purpose") != "password_reset":
return None
# Verify signature using constant-time comparison to prevent timing attacks
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
expected_signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8')
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
@@ -231,9 +240,12 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
# Create a signature using the secret key
signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8')
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
).hexdigest()
# Combine payload and signature
@@ -279,10 +291,12 @@ def verify_email_verification_token(token: str) -> Optional[str]:
if payload.get("purpose") != "email_verification":
return None
# Verify signature using constant-time comparison to prevent timing attacks
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
expected_signature = hashlib.sha256(
payload_bytes + settings.SECRET_KEY.encode('utf-8')
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):