Refactor authentication and session management for optimized performance, enhanced security, and improved error handling
- Replaced N+1 deletion pattern with a bulk `DELETE` in session cleanup for better efficiency in `session_async`. - Updated security utilities to use HMAC-SHA256 signatures to mitigate length extension attacks and added constant-time comparisons to prevent timing attacks. - Improved exception hierarchy with custom error types `AuthError` and `DatabaseError` for better granularity in error handling. - Enhanced logging with `exc_info=True` for detailed error contexts across authentication services. - Removed unused imports and reordered imports for cleaner code structure.
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
# app/api/routes/auth.py
|
# app/api/routes/auth.py
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
@@ -12,8 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database_async import get_async_db
|
||||||
|
from app.core.exceptions import (
|
||||||
|
AuthenticationError as AuthError,
|
||||||
|
DatabaseError,
|
||||||
|
ErrorCode
|
||||||
|
)
|
||||||
|
from app.crud.session_async import session_async as session_crud
|
||||||
|
from app.crud.user_async import user_async as user_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.schemas.common import MessageResponse
|
||||||
|
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||||
from app.schemas.users import (
|
from app.schemas.users import (
|
||||||
UserCreate,
|
UserCreate,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
@@ -23,15 +33,10 @@ from app.schemas.users import (
|
|||||||
PasswordResetRequest,
|
PasswordResetRequest,
|
||||||
PasswordResetConfirm
|
PasswordResetConfirm
|
||||||
)
|
)
|
||||||
from app.schemas.common import MessageResponse
|
|
||||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
|
||||||
from app.services.auth_service import AuthService, AuthenticationError
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
from app.services.email_service import email_service
|
from app.services.email_service import email_service
|
||||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
|
||||||
from app.utils.device import extract_device_info
|
from app.utils.device import extract_device_info
|
||||||
from app.crud.user_async import user_async as user_crud
|
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||||
from app.crud.session_async import session_async as session_crud
|
|
||||||
from app.core.auth import get_password_hash
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -68,10 +73,10 @@ async def register_user(
|
|||||||
detail="Registration failed. Please check your information and try again."
|
detail="Registration failed. Please check your information and try again."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error during registration: {str(e)}")
|
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(
|
raise DatabaseError(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
message="An unexpected error occurred. Please try again later.",
|
||||||
detail="An unexpected error occurred. Please try again later."
|
error_code=ErrorCode.INTERNAL_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,10 +102,9 @@ async def login(
|
|||||||
# Explicitly check for None result and raise correct exception
|
# Explicitly check for None result and raise correct exception
|
||||||
if user is None:
|
if user is None:
|
||||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||||
raise HTTPException(
|
raise AuthError(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
message="Invalid email or password",
|
||||||
detail="Invalid email or password",
|
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# User is authenticated, generate tokens
|
# User is authenticated, generate tokens
|
||||||
@@ -139,23 +143,22 @@ async def login(
|
|||||||
|
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
# Re-raise HTTP exceptions without modification
|
|
||||||
raise
|
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
# Handle specific authentication errors like inactive accounts
|
# Handle specific authentication errors like inactive accounts
|
||||||
logger.warning(f"Authentication failed: {str(e)}")
|
logger.warning(f"Authentication failed: {str(e)}")
|
||||||
raise HTTPException(
|
raise AuthError(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
message=str(e),
|
||||||
detail=str(e),
|
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
)
|
||||||
|
except AuthError:
|
||||||
|
# Re-raise custom auth exceptions without modification
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle unexpected errors
|
# Handle unexpected errors
|
||||||
logger.error(f"Unexpected error during login: {str(e)}")
|
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(
|
raise DatabaseError(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
message="An unexpected error occurred. Please try again later.",
|
||||||
detail="An unexpected error occurred. Please try again later."
|
error_code=ErrorCode.INTERNAL_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -178,10 +181,9 @@ async def login_oauth(
|
|||||||
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise AuthError(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
message="Invalid email or password",
|
||||||
detail="Invalid email or password",
|
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate tokens
|
# Generate tokens
|
||||||
@@ -220,20 +222,20 @@ async def login_oauth(
|
|||||||
"refresh_token": tokens.refresh_token,
|
"refresh_token": tokens.refresh_token,
|
||||||
"token_type": tokens.token_type
|
"token_type": tokens.token_type
|
||||||
}
|
}
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||||
raise HTTPException(
|
raise AuthError(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
message=str(e),
|
||||||
detail=str(e),
|
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
)
|
||||||
|
except AuthError:
|
||||||
|
# Re-raise custom auth exceptions without modification
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error during OAuth login: {str(e)}")
|
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(
|
raise DatabaseError(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
message="An unexpected error occurred. Please try again later.",
|
||||||
detail="An unexpected error occurred. Please try again later."
|
error_code=ErrorCode.INTERNAL_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -312,20 +314,6 @@ async def refresh_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
|
|
||||||
@limiter.limit("60/minute")
|
|
||||||
async def get_current_user_info(
|
|
||||||
request: Request,
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Get current user information.
|
|
||||||
|
|
||||||
Requires authentication.
|
|
||||||
"""
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/password-reset/request",
|
"/password-reset/request",
|
||||||
response_model=MessageResponse,
|
response_model=MessageResponse,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Session management endpoints.
|
|||||||
Allows users to view and manage their active sessions across devices.
|
Allows users to view and manage their active sessions across devices.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
@@ -13,13 +13,13 @@ from slowapi.util import get_remote_address
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.database_async import get_async_db
|
|
||||||
from app.core.auth import decode_token
|
from app.core.auth import decode_token
|
||||||
from app.models.user import User
|
from app.core.database_async import get_async_db
|
||||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
|
||||||
from app.schemas.common import MessageResponse
|
|
||||||
from app.crud.session_async import session_async as session_crud
|
|
||||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||||
|
from app.crud.session_async import session_async as session_crud
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.common import MessageResponse
|
||||||
|
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -217,24 +217,12 @@ async def cleanup_expired_sessions(
|
|||||||
Success message with count of sessions cleaned
|
Success message with count of sessions cleaned
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from datetime import datetime, timezone
|
# Use optimized bulk DELETE instead of N individual deletes
|
||||||
|
deleted_count = await session_crud.cleanup_expired_for_user(
|
||||||
# Get all sessions for user
|
|
||||||
all_sessions = await session_crud.get_user_sessions(
|
|
||||||
db,
|
db,
|
||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id)
|
||||||
active_only=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete expired and inactive sessions
|
|
||||||
deleted_count = 0
|
|
||||||
for s in all_sessions:
|
|
||||||
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
|
||||||
await db.delete(s)
|
|
||||||
deleted_count += 1
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy import and_, select, update, delete, func
|
from sqlalchemy import and_, select, update, delete, func
|
||||||
from sqlalchemy.orm import selectinload, joinedload
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
import logging
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.crud.base_async import CRUDBaseAsync
|
from app.crud.base_async import CRUDBaseAsync
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
@@ -335,6 +336,61 @@ class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate])
|
|||||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def cleanup_expired_for_user(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Clean up expired and inactive sessions for a specific user.
|
||||||
|
|
||||||
|
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_id: User ID to cleanup sessions for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of sessions deleted
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate UUID
|
||||||
|
try:
|
||||||
|
uuid_obj = uuid.UUID(user_id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
logger.error(f"Invalid UUID format: {user_id}")
|
||||||
|
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Use bulk DELETE with WHERE clause - single query
|
||||||
|
stmt = delete(UserSession).where(
|
||||||
|
and_(
|
||||||
|
UserSession.user_id == uuid_obj,
|
||||||
|
UserSession.is_active == False,
|
||||||
|
UserSession.expires_at < now
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
count = result.rowcount
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||||
|
)
|
||||||
|
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
Get count of active sessions for a user.
|
Get count of active sessions for a user.
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import json
|
|||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
@@ -47,9 +46,12 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
|
|||||||
# Convert to JSON and encode
|
# Convert to JSON and encode
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
|
|
||||||
# Create a signature using the secret key
|
# Create a signature using HMAC-SHA256 for security
|
||||||
signature = hashlib.sha256(
|
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
signature = hmac.new(
|
||||||
|
settings.SECRET_KEY.encode('utf-8'),
|
||||||
|
payload_bytes,
|
||||||
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
# Combine payload and signature
|
# Combine payload and signature
|
||||||
@@ -93,10 +95,12 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
|||||||
payload = token_data["payload"]
|
payload = token_data["payload"]
|
||||||
signature = token_data["signature"]
|
signature = token_data["signature"]
|
||||||
|
|
||||||
# Verify signature using constant-time comparison to prevent timing attacks
|
# Verify signature using HMAC and constant-time comparison
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
expected_signature = hashlib.sha256(
|
expected_signature = hmac.new(
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
settings.SECRET_KEY.encode('utf-8'),
|
||||||
|
payload_bytes,
|
||||||
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(signature, expected_signature):
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
@@ -138,9 +142,12 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
|
|||||||
# Convert to JSON and encode
|
# Convert to JSON and encode
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
|
|
||||||
# Create a signature using the secret key
|
# Create a signature using HMAC-SHA256 for security
|
||||||
signature = hashlib.sha256(
|
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
signature = hmac.new(
|
||||||
|
settings.SECRET_KEY.encode('utf-8'),
|
||||||
|
payload_bytes,
|
||||||
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
# Combine payload and signature
|
# Combine payload and signature
|
||||||
@@ -186,10 +193,12 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
|||||||
if payload.get("purpose") != "password_reset":
|
if payload.get("purpose") != "password_reset":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify signature using constant-time comparison to prevent timing attacks
|
# Verify signature using HMAC and constant-time comparison
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
expected_signature = hashlib.sha256(
|
expected_signature = hmac.new(
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
settings.SECRET_KEY.encode('utf-8'),
|
||||||
|
payload_bytes,
|
||||||
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(signature, expected_signature):
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
@@ -231,9 +240,12 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
|
|||||||
# Convert to JSON and encode
|
# Convert to JSON and encode
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
|
|
||||||
# Create a signature using the secret key
|
# Create a signature using HMAC-SHA256 for security
|
||||||
signature = hashlib.sha256(
|
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
signature = hmac.new(
|
||||||
|
settings.SECRET_KEY.encode('utf-8'),
|
||||||
|
payload_bytes,
|
||||||
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
# Combine payload and signature
|
# Combine payload and signature
|
||||||
@@ -279,10 +291,12 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
|||||||
if payload.get("purpose") != "email_verification":
|
if payload.get("purpose") != "email_verification":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify signature using constant-time comparison to prevent timing attacks
|
# Verify signature using HMAC and constant-time comparison
|
||||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||||
expected_signature = hashlib.sha256(
|
expected_signature = hmac.new(
|
||||||
payload_bytes + settings.SECRET_KEY.encode('utf-8')
|
settings.SECRET_KEY.encode('utf-8'),
|
||||||
|
payload_bytes,
|
||||||
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(signature, expected_signature):
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
|
|||||||
Reference in New Issue
Block a user