forked from cardosofelipe/pragma-stack
refactor(backend): replace python-jose and passlib with PyJWT and bcrypt for security and simplicity
- Migrated JWT token handling from `python-jose` to `PyJWT`, reducing dependencies and improving error clarity. - Replaced `passlib` bcrypt integration with direct `bcrypt` usage for password hashing. - Updated `Makefile`, removing unused CVE ignore based on the replaced dependencies. - Reflected changes in `ARCHITECTURE.md` and adjusted function headers in `auth.py`. - Cleaned up `uv.lock` and `pyproject.toml` to remove unused dependencies (`ecdsa`, `rsa`, etc.) and add `PyJWT`. - Refactored tests and services to align with the updated libraries (`PyJWT` error handling, decoding, and validation).
This commit is contained in:
@@ -1,23 +1,21 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
import bcrypt
|
||||
import jwt
|
||||
from jwt.exceptions import (
|
||||
ExpiredSignatureError,
|
||||
InvalidTokenError,
|
||||
MissingRequiredClaimError,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
# Suppress passlib bcrypt warnings about ident
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
"""Generate a bcrypt password hash."""
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, partial(pwd_context.verify, plain_password, hashed_password)
|
||||
None, partial(verify_password, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, pwd_context.hash, password)
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, get_password_hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
@@ -121,11 +122,7 @@ def create_access_token(
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
@@ -154,11 +151,7 @@ def create_refresh_token(
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
|
||||
# Reject weak or unexpected algorithms
|
||||
# NOTE: These are defensive checks that provide defense-in-depth.
|
||||
# The python-jose library rejects these tokens BEFORE we reach here,
|
||||
# PyJWT rejects these tokens BEFORE we reach here,
|
||||
# but we keep these checks in case the library changes or is misconfigured.
|
||||
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
||||
if token_algorithm == "NONE": # pragma: no cover
|
||||
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except MissingRequiredClaimError as e:
|
||||
raise TokenMissingClaimError(f"Token missing required claim: {e}")
|
||||
except InvalidTokenError:
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
@@ -25,8 +25,8 @@ from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from jose.exceptions import ExpiredSignatureError
|
||||
import jwt
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -704,7 +704,7 @@ async def revoke_token(
|
||||
"Revoked refresh token via access token JTI %s...", jti[:8]
|
||||
)
|
||||
return True
|
||||
except JWTError:
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
|
||||
pass
|
||||
@@ -827,7 +827,7 @@ async def introspect_token(
|
||||
}
|
||||
except ExpiredSignatureError:
|
||||
return {"active": False}
|
||||
except JWTError:
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
pass
|
||||
|
||||
@@ -537,8 +537,9 @@ class OAuthService:
|
||||
AuthenticationError: If verification fails
|
||||
"""
|
||||
import httpx
|
||||
from jose import jwt as jose_jwt
|
||||
from jose.exceptions import JWTError
|
||||
import jwt as pyjwt
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
try:
|
||||
# Fetch Google's public keys (JWKS)
|
||||
@@ -552,24 +553,27 @@ class OAuthService:
|
||||
jwks = jwks_response.json()
|
||||
|
||||
# Get the key ID from the token header
|
||||
unverified_header = jose_jwt.get_unverified_header(id_token)
|
||||
unverified_header = pyjwt.get_unverified_header(id_token)
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
raise AuthenticationError("ID token missing key ID (kid)")
|
||||
|
||||
# Find the matching public key
|
||||
public_key = None
|
||||
jwk_data = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
public_key = key
|
||||
jwk_data = key
|
||||
break
|
||||
|
||||
if not public_key:
|
||||
if not jwk_data:
|
||||
raise AuthenticationError("ID token signed with unknown key")
|
||||
|
||||
# Convert JWK to a public key object for PyJWT
|
||||
public_key = RSAAlgorithm.from_jwk(jwk_data)
|
||||
|
||||
# Verify the token signature and decode claims
|
||||
# jose library will verify signature against the JWK
|
||||
payload = jose_jwt.decode(
|
||||
# PyJWT will verify signature against the RSA public key
|
||||
payload = pyjwt.decode(
|
||||
id_token,
|
||||
public_key,
|
||||
algorithms=["RS256"], # Google uses RS256
|
||||
@@ -597,7 +601,7 @@ class OAuthService:
|
||||
logger.debug("Google ID token verified successfully")
|
||||
return payload
|
||||
|
||||
except JWTError as e:
|
||||
except InvalidTokenError as e:
|
||||
logger.warning("Google ID token verification failed: %s", e)
|
||||
raise AuthenticationError("Invalid ID token signature")
|
||||
except httpx.HTTPError as e:
|
||||
|
||||
Reference in New Issue
Block a user