Refactor token handling and introduce token revocation logic
Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
@@ -2,7 +2,8 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from black import timezone
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
@@ -75,30 +76,54 @@ def create_token(
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload:
|
||||
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token.
|
||||
Decode and validate a JWT token with explicit edge-case handling.
|
||||
|
||||
Args:
|
||||
token: The JWT token to decode
|
||||
token: The JWT token to decode.
|
||||
required_type: The expected token type (default: "access").
|
||||
|
||||
Returns:
|
||||
TokenPayload containing the decoded data
|
||||
TokenPayload containing the decoded data.
|
||||
|
||||
Raises:
|
||||
JWTError: If token is invalid or expired
|
||||
JWTError: If the token is expired, invalid, or malformed.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
# Explicitly validate required claims (`exp`, `sub`, `type`)
|
||||
if "exp" not in payload or "sub" not in payload or "type" not in payload:
|
||||
raise KeyError("Missing required claim.")
|
||||
|
||||
# Verify token expiration (`exp`)
|
||||
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
|
||||
raise ExpiredSignatureError("Token has expired.")
|
||||
|
||||
# Verify token type (`type`)
|
||||
if payload["type"] != required_type:
|
||||
expected_type = required_type
|
||||
actual_type = payload["type"]
|
||||
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
|
||||
|
||||
# Create TokenPayload object from token claims
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=datetime.fromtimestamp(payload["exp"]),
|
||||
iat=datetime.fromtimestamp(payload["iat"]),
|
||||
iat=datetime.fromtimestamp(payload.get("iat", 0)),
|
||||
jti=payload.get("jti")
|
||||
)
|
||||
|
||||
except KeyError as e:
|
||||
raise JWTError("Malformed token. Missing required claim.") from e
|
||||
except ExpiredSignatureError as e:
|
||||
raise JWTError("Token expired. Please refresh your token to continue.") from e
|
||||
except JWTError as e:
|
||||
raise JWTError(f"Invalid token: {str(e)}")
|
||||
raise JWTError(str(e)) from e
|
||||
|
||||
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
@@ -108,4 +133,4 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
|
||||
Reference in New Issue
Block a user