Files
Felipe Cardoso 1a36907f10 refactor(backend): replace python-jose and passlib with PyJWT and bcrypt for security and simplicity
- Migrated JWT token handling from `python-jose` to `PyJWT`, reducing dependencies and improving error clarity.
- Replaced `passlib` bcrypt integration with direct `bcrypt` usage for password hashing.
- Updated `Makefile`, removing unused CVE ignore based on the replaced dependencies.
- Reflected changes in `ARCHITECTURE.md` and adjusted function headers in `auth.py`.
- Cleaned up `uv.lock` and `pyproject.toml` to remove unused dependencies (`ecdsa`, `rsa`, etc.) and add `PyJWT`.
- Refactored tests and services to align with the updated libraries (`PyJWT` error handling, decoding, and validation).
2026-03-01 14:02:04 +01:00

240 lines
6.9 KiB
Python

import asyncio
import uuid
from datetime import UTC, datetime, timedelta
from functools import partial
from typing import Any
import bcrypt
import jwt
from jwt.exceptions import (
ExpiredSignatureError,
InvalidTokenError,
MissingRequiredClaimError,
)
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
class TokenExpiredError(AuthError):
"""Token has expired"""
class TokenInvalidError(AuthError):
"""Token is invalid"""
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a bcrypt hash."""
return bcrypt.checkpw(
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
)
def get_password_hash(password: str) -> str:
"""Generate a bcrypt password hash."""
salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""
Verify a password against a hash asynchronously.
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
blocking the event loop.
Args:
plain_password: Plain text password to verify
hashed_password: Hashed password to verify against
Returns:
True if password matches, False otherwise
"""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None, partial(verify_password, plain_password, hashed_password)
)
async def get_password_hash_async(password: str) -> str:
"""
Generate a password hash asynchronously.
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
blocking the event loop. This is especially important during user
registration and password changes.
Args:
password: Plain text password to hash
Returns:
Hashed password string
"""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_password_hash, password)
def create_access_token(
subject: str | Any,
expires_delta: timedelta | None = None,
claims: dict[str, Any] | None = None,
) -> str:
"""
Create a JWT access token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
claims: Optional additional claims to include in the token
Returns:
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(UTC) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=UTC),
"jti": str(uuid.uuid4()),
"type": "access",
}
# Add custom claims
if claims:
to_encode.update(claims)
# Create the JWT
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_refresh_token(
subject: str | Any, expires_delta: timedelta | None = None
) -> str:
"""
Create a JWT refresh token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
Returns:
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(UTC),
"jti": str(uuid.uuid4()),
"type": "refresh",
}
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
"""
Decode and verify a JWT token.
Args:
token: JWT token to decode
verify_type: Optional token type to verify (access or refresh)
Returns:
TokenPayload: The decoded token data
Raises:
TokenExpiredError: If the token has expired
TokenInvalidError: If the token is invalid
TokenMissingClaimError: If a required claim is missing
"""
try:
# Decode token with strict algorithm validation
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
options={
"verify_signature": True,
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "sub", "iat"],
},
)
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
# Decode header to check algorithm (without verification, just to inspect)
header = jwt.get_unverified_header(token)
token_algorithm = header.get("alg", "").upper()
# Reject weak or unexpected algorithms
# NOTE: These are defensive checks that provide defense-in-depth.
# PyJWT rejects these tokens BEFORE we reach here,
# but we keep these checks in case the library changes or is misconfigured.
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
if token_algorithm == "NONE": # pragma: no cover
raise TokenInvalidError("Algorithm 'none' is not allowed")
if token_algorithm != settings.ALGORITHM.upper(): # pragma: no cover
raise TokenInvalidError(f"Invalid algorithm: {token_algorithm}")
# Check required claims before Pydantic validation
if not payload.get("sub"):
raise TokenMissingClaimError("Token missing 'sub' claim")
# Verify token type if specified
if verify_type and payload.get("type") != verify_type:
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
# Now validate with Pydantic
token_data = TokenPayload(**payload)
return token_data
except ExpiredSignatureError:
raise TokenExpiredError("Token has expired")
except MissingRequiredClaimError as e:
raise TokenMissingClaimError(f"Token missing required claim: {e}")
except InvalidTokenError:
raise TokenInvalidError("Invalid authentication token")
except ValidationError:
raise TokenInvalidError("Invalid token payload")
def get_token_data(token: str) -> TokenData:
"""
Extract the user ID and superuser status from a token.
Args:
token: JWT token
Returns:
TokenData with user_id and is_superuser
"""
payload = decode_token(token)
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)