Refactor and enhance token decoding error handling

Improved the `decode_token` function to clarify and extend error handling for token validation and decoding. Enhanced error messages for invalid tokens, added checks for missing claims, and ensured clear differentiation of failure scenarios. Updated imports and added a `scope` field to token response for completeness.
This commit is contained in:
2025-02-28 18:12:39 +01:00
parent 8814dc931f
commit 0bc9263d24

View File

@@ -1,15 +1,15 @@
from datetime import datetime, timedelta
from typing import Optional, Tuple
from typing import Optional
from uuid import uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
from fastapi import Depends
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utlis import is_token_revoked
# Configuration
@@ -51,7 +51,8 @@ def create_tokens(user_id: str) -> TokenResponse:
refresh_token=refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_id=user_id
user_id=user_id,
scope="read write"
)
@@ -81,6 +82,7 @@ def create_token(
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
# Ensure `data` includes `jti` for consistency
@@ -96,6 +98,7 @@ def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None)
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "refresh")
async def decode_token(
token: str,
required_type: str = "access",
@@ -113,44 +116,52 @@ async def decode_token(
TokenPayload: The decoded token data.
Raises:
JWTError: If the token is expired, revoked, or malformed.
JWTError: If the token is expired, revoked, malformed, or fails validation.
"""
try:
# Decode the JWT token using the secret and algorithm
# Step 1: Decode the JWT token
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Explicitly validate required claims
if "exp" not in payload or "sub" not in payload or "type" not in payload or "jti" not in payload:
raise KeyError("Missing required claim(s) in token.")
# Validate token expiration (`exp`)
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
raise ExpiredSignatureError("Token has expired.")
# Validate the token type (`type`)
if payload["type"] != required_type:
raise JWTError(f"Invalid token type: expected '{required_type}', got '{payload['type']}'.")
# Check the token's revocation status (via `jti`)
if await is_token_revoked(payload["jti"], db):
raise JWTError("Token has been revoked.")
# Construct and return the token payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=payload["jti"]
)
except ExpiredSignatureError as e:
# Handle expired token exception
raise JWTError("Token expired. Please refresh your token to continue.") from e
except KeyError as e:
# Handle missing claims in the token
raise JWTError("Malformed token. Missing required claim(s).") from e
except ExpiredSignatureError:
raise JWTError("Token has expired. Please refresh your token or login again.")
except JWTError as e:
# Handle any other JWT-specific exceptions
raise JWTError(str(e)) from e
if "Signature verification failed" in str(e):
raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.")
raise JWTError(f"Failed to decode the token: {e}")
except JOSEError as e:
if "segments" in str(e).lower():
raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).")
raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
except Exception as e:
# Catch-all for unexpected exceptions during decoding
raise JWTError(f"An unexpected error occurred while decoding the token: {e}")
# Step 2: Validate Required Claims
required_claims = ["exp", "sub", "type", "jti"]
missing_claims = [claim for claim in required_claims if claim not in payload]
if missing_claims:
raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")
# Step 3: Validate Expiry
expiration = datetime.fromtimestamp(payload["exp"])
if datetime.now() > expiration:
raise JWTError("Token has expired. Please refresh your token or login again.")
# Step 4: Validate Token Type
token_type = payload.get("type")
if token_type != required_type:
raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")
# Step 5: Check Revocation
jti = payload.get("jti")
if await is_token_revoked(jti, db):
raise JWTError("Token has been revoked. Please login again to generate a new token.")
# Step 6: Return Validated Token Payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=expiration,
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=jti
)