From 0bc9263d241b519f9275e2e49f92371768a8849c Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Fri, 28 Feb 2025 18:12:39 +0100 Subject: [PATCH] Refactor and enhance token decoding error handling Improved the `decode_token` function to clarify and extend error handling for token validation and decoding. Enhanced error messages for invalid tokens, added checks for missing claims, and ensured clear differentiation of failure scenarios. Updated imports and added a `scope` field to token response for completeness. --- backend/app/auth/security.py | 97 ++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/backend/app/auth/security.py b/backend/app/auth/security.py index f7f2785..f9cc59a 100644 --- a/backend/app/auth/security.py +++ b/backend/app/auth/security.py @@ -1,15 +1,15 @@ from datetime import datetime, timedelta -from typing import Optional, Tuple +from typing import Optional from uuid import uuid4 -from sqlalchemy.ext.asyncio import AsyncSession -from jose import jwt, ExpiredSignatureError, JWTError -from passlib.context import CryptContext -from app.core.config import settings -from app.schemas.token import TokenPayload, TokenResponse -from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError from fastapi import Depends +from jose import jwt, JWTError, ExpiredSignatureError, JOSEError +from passlib.context import CryptContext +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings from app.core.database import get_db +from app.schemas.token import TokenPayload, TokenResponse from auth.utlis import is_token_revoked # Configuration @@ -51,7 +51,8 @@ def create_tokens(user_id: str) -> TokenResponse: refresh_token=refresh_token, token_type="bearer", expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60, - user_id=user_id + user_id=user_id, + scope="read write" ) @@ -81,6 +82,7 @@ def create_token( return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """Create a new access token.""" # Ensure `data` includes `jti` for consistency @@ -96,6 +98,7 @@ def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) data["jti"] = str(uuid4()) return create_token(data, expires_delta, "refresh") + async def decode_token( token: str, required_type: str = "access", @@ -113,44 +116,52 @@ async def decode_token( TokenPayload: The decoded token data. Raises: - JWTError: If the token is expired, revoked, or malformed. + JWTError: If the token is expired, revoked, malformed, or fails validation. """ try: - # Decode the JWT token using the secret and algorithm + # Step 1: Decode the JWT token payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - # Explicitly validate required claims - if "exp" not in payload or "sub" not in payload or "type" not in payload or "jti" not in payload: - raise KeyError("Missing required claim(s) in token.") - - # Validate token expiration (`exp`) - if datetime.now() > datetime.fromtimestamp(payload["exp"]): - raise ExpiredSignatureError("Token has expired.") - - # Validate the token type (`type`) - if payload["type"] != required_type: - raise JWTError(f"Invalid token type: expected '{required_type}', got '{payload['type']}'.") - - # Check the token's revocation status (via `jti`) - if await is_token_revoked(payload["jti"], db): - raise JWTError("Token has been revoked.") - - # Construct and return the token payload - return TokenPayload( - sub=payload["sub"], - type=payload["type"], - exp=datetime.fromtimestamp(payload["exp"]), - iat=datetime.fromtimestamp(payload.get("iat", 0)), - jti=payload["jti"] - ) - - except ExpiredSignatureError as e: - # Handle expired token exception - raise JWTError("Token expired. Please refresh your token to continue.") from e - except KeyError as e: - # Handle missing claims in the token - raise JWTError("Malformed token. Missing required claim(s).") from e + except ExpiredSignatureError: + raise JWTError("Token has expired. Please refresh your token or login again.") except JWTError as e: - # Handle any other JWT-specific exceptions - raise JWTError(str(e)) from e + if "Signature verification failed" in str(e): + raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.") + raise JWTError(f"Failed to decode the token: {e}") + except JOSEError as e: + if "segments" in str(e).lower(): + raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).") + raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e + except Exception as e: + # Catch-all for unexpected exceptions during decoding + raise JWTError(f"An unexpected error occurred while decoding the token: {e}") + # Step 2: Validate Required Claims + required_claims = ["exp", "sub", "type", "jti"] + missing_claims = [claim for claim in required_claims if claim not in payload] + if missing_claims: + raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.") + + # Step 3: Validate Expiry + expiration = datetime.fromtimestamp(payload["exp"]) + if datetime.now() > expiration: + raise JWTError("Token has expired. Please refresh your token or login again.") + + # Step 4: Validate Token Type + token_type = payload.get("type") + if token_type != required_type: + raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.") + + # Step 5: Check Revocation + jti = payload.get("jti") + if await is_token_revoked(jti, db): + raise JWTError("Token has been revoked. Please login again to generate a new token.") + + # Step 6: Return Validated Token Payload + return TokenPayload( + sub=payload["sub"], + type=payload["type"], + exp=expiration, + iat=datetime.fromtimestamp(payload.get("iat", 0)), + jti=jti + )