Expanded exception handling to cover more specific JWT- and JOSE-related errors, including signature-verification failures and malformed tokens. This improves error messages and makes token validation more robust.
145 lines
4.6 KiB
Python
145 lines
4.6 KiB
Python
from datetime import datetime, timedelta
|
|
from typing import Optional, Tuple
|
|
from uuid import uuid4
|
|
|
|
from black import timezone
|
|
from jose import jwt, ExpiredSignatureError, JWTError
|
|
from passlib.context import CryptContext
|
|
from app.core.config import settings
|
|
from app.schemas.token import TokenPayload, TokenResponse
|
|
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
|
|
|
|
# Configuration
# Token-signing settings. The signing secret is read from application settings.
SECRET_KEY = settings.SECRET_KEY
ALGORITHM: str = "HS256"  # HMAC-SHA256, a symmetric (shared-secret) algorithm
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30  # default lifetime of access tokens
REFRESH_TOKEN_EXPIRE_DAYS: int = 7  # default lifetime of refresh tokens

# Password hashing context: bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a candidate password against a previously stored hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The stored hash to compare against.

    Returns:
        The result of the module-level passlib context's ``verify`` call.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash a plain-text password for storage.

    Args:
        password: The plain-text password to hash.

    Returns:
        The hash string produced by the module-level passlib context.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def create_tokens(user_id: str) -> TokenResponse:
    """Issue a fresh access/refresh token pair for a user.

    Args:
        user_id: Identifier placed in the ``sub`` claim of both tokens.

    Returns:
        TokenResponse holding both tokens, the ``"bearer"`` token type,
        the access-token lifetime in seconds (``expires_in``) and the user id.
    """
    access_lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60
    return TokenResponse(
        access_token=create_access_token({"sub": user_id}),
        refresh_token=create_refresh_token({"sub": user_id}),
        token_type="bearer",
        expires_in=access_lifetime_seconds,
        user_id=user_id,
    )
|
|
|
|
|
|
def create_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
    token_type: str = "access"
) -> str:
    """Create a signed JWT carrying *data* plus standard claims.

    Args:
        data: Claims to embed (e.g. ``{"sub": user_id}``). The dict is
            copied, never mutated; its ``exp``, ``type``, ``iat`` and ``jti``
            keys, if present, are overwritten.
        expires_delta: Token lifetime. ``None`` selects the default for
            *token_type*; any explicit value, including ``timedelta(0)``,
            is honored as given.
        token_type: Value of the ``type`` claim. ``"access"`` defaults to
            ACCESS_TOKEN_EXPIRE_MINUTES; any other type defaults to
            REFRESH_TOKEN_EXPIRE_DAYS.

    Returns:
        The encoded JWT, signed with SECRET_KEY using ALGORITHM.
    """
    # Import UTC from the stdlib explicitly instead of relying on the
    # module-level `timezone` name.
    from datetime import timezone

    # A single aware UTC instant feeds both `iat` and `exp`, so they cannot
    # drift apart (the old code called utcnow() twice; utcnow() is also
    # deprecated since Python 3.12).
    issued_at = datetime.now(timezone.utc)

    # `is None` rather than truthiness: timedelta(0) is falsy and used to be
    # silently replaced by the default lifetime.
    if expires_delta is None:
        if token_type == "access":
            expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        else:
            expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    claims = data.copy()
    claims.update({
        "exp": issued_at + expires_delta,
        "type": token_type,
        "iat": issued_at,
        "jti": str(uuid4()),  # unique identifier for this token
    })

    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
# python-jose error messages (bad signature / structurally broken token) that
# are collapsed into a generic "Invalid token." so library internals are not
# echoed back to callers.
_OPAQUE_JWT_ERROR_MESSAGES = frozenset({
    "Signature verification failed.",
    "Not enough segments",
})

# Claims every token produced by this module must carry.
_REQUIRED_CLAIMS = ("exp", "sub", "type")


def _claim_to_utc(value: object, claim: str) -> datetime:
    """Convert a NumericDate claim value into an aware UTC datetime.

    Args:
        value: The raw claim value taken from the decoded payload.
        claim: Claim name, used only in the error message.

    Returns:
        A timezone-aware datetime in UTC.

    Raises:
        JWTError: If *value* is not a number ``datetime`` can represent
            (wrong type, NaN, or out of the platform's timestamp range).
    """
    from datetime import timezone

    try:
        return datetime.fromtimestamp(value, tz=timezone.utc)
    except (TypeError, ValueError, OverflowError, OSError) as e:
        raise JWTError(f"Malformed token. Invalid '{claim}' claim.") from e


def decode_token(token: str, required_type: str = "access") -> TokenPayload:
    """
    Decode and validate a JWT token with explicit edge-case handling.

    Args:
        token: The JWT token to decode.
        required_type: The expected token type (default: "access").

    Returns:
        TokenPayload containing the decoded data. ``exp`` and ``iat`` are
        naive datetimes expressed in UTC; a missing ``iat`` becomes the Unix
        epoch and a missing ``jti`` becomes ``None``.

    Raises:
        JWTError: If the token is expired, has an invalid signature, is
            malformed, lacks a required claim (``exp``, ``sub``, ``type``),
            carries an unusable timestamp, or is not of *required_type*.
            Every failure is surfaced as a plain ``JWTError``.
    """
    from datetime import timezone

    # Only the library call is guarded; the checks below raise JWTError with
    # their final message directly.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except ExpiredSignatureError as e:  # Expired token
        raise JWTError("Token expired. Please refresh your token to continue.") from e
    except JWTError as e:
        # Handle signature verification and malformed token errors
        if str(e) in _OPAQUE_JWT_ERROR_MESSAGES:
            raise JWTError("Invalid token.") from e
        # Propagate other JWTError messages
        raise JWTError(str(e)) from e
    except JOSEError as e:  # All other JOSE-related errors
        raise JWTError("Invalid token.") from e

    if any(claim not in payload for claim in _REQUIRED_CLAIMS):
        raise JWTError("Malformed token. Missing required claim.")

    expires_at = _claim_to_utc(payload["exp"], "exp")
    # python-jose already verifies `exp` by default; this explicit re-check
    # compares aware UTC instants. The previous naive local-time comparison
    # could be off by an hour around DST transitions.
    if datetime.now(timezone.utc) > expires_at:
        raise JWTError("Token expired. Please refresh your token to continue.")

    actual_type = payload["type"]
    if actual_type != required_type:
        raise JWTError(
            f"Invalid token type: expected '{required_type}', got '{actual_type}'."
        )

    issued_at = _claim_to_utc(payload.get("iat", 0), "iat")

    # Datetimes stay naive for backward compatibility with callers, but are
    # now expressed in UTC rather than the server's local time.
    return TokenPayload(
        sub=payload["sub"],
        type=actual_type,
        exp=expires_at.replace(tzinfo=None),
        iat=issued_at.replace(tzinfo=None),
        jti=payload.get("jti"),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Issue a JWT whose ``type`` claim is ``"access"``.

    Args:
        data: Claims to embed in the token.
        expires_delta: Optional custom lifetime, forwarded unchanged to
            create_token, which decides the default when it is omitted.

    Returns:
        The encoded access token.
    """
    return create_token(data, expires_delta=expires_delta, token_type="access")
|
|
|
|
|
|
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Issue a JWT whose ``type`` claim is ``"refresh"``.

    Args:
        data: Claims to embed in the token.
        expires_delta: Optional custom lifetime, forwarded unchanged to
            create_token, which decides the default when it is omitted.

    Returns:
        The encoded refresh token.
    """
    return create_token(data, expires_delta=expires_delta, token_type="refresh")
|