- **Configurations:** Test minimum `SECRET_KEY` length validation to prevent weak JWT signing keys. Validate proper handling of secure defaults. - **Permissions:** Add tests for inactive user blocking, API access control, and superuser privilege escalation across organizational roles. - **Authentication:** Test logout safety, session revocation, token replay prevention, and defense against JWT algorithm confusion attacks. - Include `# pragma: no cover` for unreachable defensive code in security-sensitive areas.
253 lines
7.1 KiB
Python
253 lines
7.1 KiB
Python
import logging
|
|
logging.getLogger('passlib').setLevel(logging.ERROR)
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Dict, Optional, Union
|
|
import uuid
|
|
import asyncio
|
|
from functools import partial
|
|
|
|
from jose import jwt, JWTError
|
|
from passlib.context import CryptContext
|
|
from pydantic import ValidationError
|
|
|
|
from app.core.config import settings
|
|
from app.schemas.users import TokenData, TokenPayload
|
|
|
|
|
|
# Password hashing context
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
# Custom exceptions for auth
|
|
class AuthError(Exception):
|
|
"""Base authentication error"""
|
|
pass
|
|
|
|
class TokenExpiredError(AuthError):
|
|
"""Token has expired"""
|
|
pass
|
|
|
|
class TokenInvalidError(AuthError):
|
|
"""Token is invalid"""
|
|
pass
|
|
|
|
class TokenMissingClaimError(AuthError):
|
|
"""Token is missing a required claim"""
|
|
pass
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|
"""Verify a password against a hash."""
|
|
return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
|
|
"""Generate a password hash."""
|
|
return pwd_context.hash(password)
|
|
|
|
|
|
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
|
"""
|
|
Verify a password against a hash asynchronously.
|
|
|
|
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
|
blocking the event loop.
|
|
|
|
Args:
|
|
plain_password: Plain text password to verify
|
|
hashed_password: Hashed password to verify against
|
|
|
|
Returns:
|
|
True if password matches, False otherwise
|
|
"""
|
|
loop = asyncio.get_event_loop()
|
|
return await loop.run_in_executor(
|
|
None,
|
|
partial(pwd_context.verify, plain_password, hashed_password)
|
|
)
|
|
|
|
|
|
async def get_password_hash_async(password: str) -> str:
|
|
"""
|
|
Generate a password hash asynchronously.
|
|
|
|
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
|
blocking the event loop. This is especially important during user
|
|
registration and password changes.
|
|
|
|
Args:
|
|
password: Plain text password to hash
|
|
|
|
Returns:
|
|
Hashed password string
|
|
"""
|
|
loop = asyncio.get_event_loop()
|
|
return await loop.run_in_executor(
|
|
None,
|
|
pwd_context.hash,
|
|
password
|
|
)
|
|
|
|
|
|
def create_access_token(
|
|
subject: Union[str, Any],
|
|
expires_delta: Optional[timedelta] = None,
|
|
claims: Optional[Dict[str, Any]] = None
|
|
) -> str:
|
|
"""
|
|
Create a JWT access token.
|
|
|
|
Args:
|
|
subject: The subject of the token (usually user_id)
|
|
expires_delta: Optional expiration time delta
|
|
claims: Optional additional claims to include in the token
|
|
|
|
Returns:
|
|
Encoded JWT token
|
|
"""
|
|
if expires_delta:
|
|
expire = datetime.now(timezone.utc) + expires_delta
|
|
else:
|
|
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
# Base token data
|
|
to_encode = {
|
|
"sub": str(subject),
|
|
"exp": expire,
|
|
"iat": datetime.now(tz=timezone.utc),
|
|
"jti": str(uuid.uuid4()),
|
|
"type": "access"
|
|
}
|
|
|
|
# Add custom claims
|
|
if claims:
|
|
to_encode.update(claims)
|
|
|
|
# Create the JWT
|
|
encoded_jwt = jwt.encode(
|
|
to_encode,
|
|
settings.SECRET_KEY,
|
|
algorithm=settings.ALGORITHM
|
|
)
|
|
|
|
return encoded_jwt
|
|
|
|
|
|
def create_refresh_token(
|
|
subject: Union[str, Any],
|
|
expires_delta: Optional[timedelta] = None
|
|
) -> str:
|
|
"""
|
|
Create a JWT refresh token.
|
|
|
|
Args:
|
|
subject: The subject of the token (usually user_id)
|
|
expires_delta: Optional expiration time delta
|
|
|
|
Returns:
|
|
Encoded JWT refresh token
|
|
"""
|
|
if expires_delta:
|
|
expire = datetime.now(timezone.utc) + expires_delta
|
|
else:
|
|
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
|
|
|
to_encode = {
|
|
"sub": str(subject),
|
|
"exp": expire,
|
|
"iat": datetime.now(timezone.utc),
|
|
"jti": str(uuid.uuid4()),
|
|
"type": "refresh"
|
|
}
|
|
|
|
encoded_jwt = jwt.encode(
|
|
to_encode,
|
|
settings.SECRET_KEY,
|
|
algorithm=settings.ALGORITHM
|
|
)
|
|
|
|
return encoded_jwt
|
|
|
|
|
|
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
|
"""
|
|
Decode and verify a JWT token.
|
|
|
|
Args:
|
|
token: JWT token to decode
|
|
verify_type: Optional token type to verify (access or refresh)
|
|
|
|
Returns:
|
|
TokenPayload: The decoded token data
|
|
|
|
Raises:
|
|
TokenExpiredError: If the token has expired
|
|
TokenInvalidError: If the token is invalid
|
|
TokenMissingClaimError: If a required claim is missing
|
|
"""
|
|
try:
|
|
# Decode token with strict algorithm validation
|
|
payload = jwt.decode(
|
|
token,
|
|
settings.SECRET_KEY,
|
|
algorithms=[settings.ALGORITHM],
|
|
options={
|
|
"verify_signature": True,
|
|
"verify_exp": True,
|
|
"verify_iat": True,
|
|
"require": ["exp", "sub", "iat"]
|
|
}
|
|
)
|
|
|
|
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
|
|
# Decode header to check algorithm (without verification, just to inspect)
|
|
header = jwt.get_unverified_header(token)
|
|
token_algorithm = header.get("alg", "").upper()
|
|
|
|
# Reject weak or unexpected algorithms
|
|
# NOTE: These are defensive checks that provide defense-in-depth.
|
|
# The python-jose library rejects these tokens BEFORE we reach here,
|
|
# but we keep these checks in case the library changes or is misconfigured.
|
|
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
|
if token_algorithm == "NONE": # pragma: no cover
|
|
raise TokenInvalidError("Algorithm 'none' is not allowed")
|
|
|
|
if token_algorithm != settings.ALGORITHM.upper(): # pragma: no cover
|
|
raise TokenInvalidError(f"Invalid algorithm: {token_algorithm}")
|
|
|
|
# Check required claims before Pydantic validation
|
|
if not payload.get("sub"):
|
|
raise TokenMissingClaimError("Token missing 'sub' claim")
|
|
|
|
# Verify token type if specified
|
|
if verify_type and payload.get("type") != verify_type:
|
|
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
|
|
|
|
# Now validate with Pydantic
|
|
token_data = TokenPayload(**payload)
|
|
return token_data
|
|
|
|
except JWTError as e:
|
|
# Check if the error is due to an expired token
|
|
if "expired" in str(e).lower():
|
|
raise TokenExpiredError("Token has expired")
|
|
raise TokenInvalidError("Invalid authentication token")
|
|
except ValidationError:
|
|
raise TokenInvalidError("Invalid token payload")
|
|
|
|
|
|
def get_token_data(token: str) -> TokenData:
|
|
"""
|
|
Extract the user ID and superuser status from a token.
|
|
|
|
Args:
|
|
token: JWT token
|
|
|
|
Returns:
|
|
TokenData with user_id and is_superuser
|
|
"""
|
|
payload = decode_token(token)
|
|
user_id = payload.sub
|
|
is_superuser = payload.is_superuser or False
|
|
|
|
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser) |