refactor(backend): replace python-jose and passlib with PyJWT and bcrypt for security and simplicity
- Migrated JWT token handling from `python-jose` to `PyJWT`, reducing dependencies and improving error clarity. - Replaced `passlib` bcrypt integration with direct `bcrypt` usage for password hashing. - Updated `Makefile`, removing unused CVE ignore based on the replaced dependencies. - Reflected changes in `ARCHITECTURE.md` and adjusted function headers in `auth.py`. - Cleaned up `uv.lock` and `pyproject.toml` to remove unused dependencies (`ecdsa`, `rsa`, etc.) and add `PyJWT`. - Refactored tests and services to align with the updated libraries (`PyJWT` error handling, decoding, and validation).
This commit is contained in:
@@ -1,23 +1,21 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
import bcrypt
|
||||
import jwt
|
||||
from jwt.exceptions import (
|
||||
ExpiredSignatureError,
|
||||
InvalidTokenError,
|
||||
MissingRequiredClaimError,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
# Suppress passlib bcrypt warnings about ident
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
"""Generate a bcrypt password hash."""
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, partial(pwd_context.verify, plain_password, hashed_password)
|
||||
None, partial(verify_password, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, pwd_context.hash, password)
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, get_password_hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
@@ -121,11 +122,7 @@ def create_access_token(
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
@@ -154,11 +151,7 @@ def create_refresh_token(
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
|
||||
# Reject weak or unexpected algorithms
|
||||
# NOTE: These are defensive checks that provide defense-in-depth.
|
||||
# The python-jose library rejects these tokens BEFORE we reach here,
|
||||
# PyJWT rejects these tokens BEFORE we reach here,
|
||||
# but we keep these checks in case the library changes or is misconfigured.
|
||||
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
||||
if token_algorithm == "NONE": # pragma: no cover
|
||||
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except MissingRequiredClaimError as e:
|
||||
raise TokenMissingClaimError(f"Token missing required claim: {e}")
|
||||
except InvalidTokenError:
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
Reference in New Issue
Block a user