Refactor token handling and introduce token revocation logic

Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
2025-02-28 16:57:57 +01:00
parent c3a55b26c7
commit 548880b468
7 changed files with 124 additions and 36 deletions

View File

@@ -2,7 +2,8 @@ from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from jose import jwt, JWTError
from black import timezone
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
@@ -75,30 +76,54 @@ def create_token(
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> TokenPayload:
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
"""
Decode and validate a JWT token.
Decode and validate a JWT token with explicit edge-case handling.
Args:
token: The JWT token to decode
token: The JWT token to decode.
required_type: The expected token type (default: "access").
Returns:
TokenPayload containing the decoded data
TokenPayload containing the decoded data.
Raises:
JWTError: If token is invalid or expired
JWTError: If the token is expired, invalid, or malformed.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Explicitly validate required claims (`exp`, `sub`, `type`)
if "exp" not in payload or "sub" not in payload or "type" not in payload:
raise KeyError("Missing required claim.")
# Verify token expiration (`exp`)
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
raise ExpiredSignatureError("Token has expired.")
# Verify token type (`type`)
if payload["type"] != required_type:
expected_type = required_type
actual_type = payload["type"]
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
# Create TokenPayload object from token claims
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload["iat"]),
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=payload.get("jti")
)
except KeyError as e:
raise JWTError("Malformed token. Missing required claim.") from e
except ExpiredSignatureError as e:
raise JWTError("Token expired. Please refresh your token to continue.") from e
except JWTError as e:
raise JWTError(f"Invalid token: {str(e)}")
raise JWTError(str(e)) from e
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
@@ -108,4 +133,4 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
return create_token(data, expires_delta, "refresh")
return create_token(data, expires_delta, "refresh")