Improved the `decode_token` function to clarify and extend error handling for token decoding and validation. Enhanced error messages for invalid tokens, added checks for missing claims, and clearly distinguished the different failure scenarios. Updated imports and added a `scope` field to the token response for completeness.
177 lines
5.8 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
from uuid import uuid4
|
|
|
|
from fastapi import Depends
|
|
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.database import get_db
|
|
from app.schemas.token import TokenPayload, TokenResponse
|
|
from auth.utils import is_token_revoked
|
|
|
|
# ---------------------------------------------------------------------------
# JWT configuration
# ---------------------------------------------------------------------------
# Key used both to sign tokens (create_token) and to verify them
# (decode_token); read from the application settings object.
SECRET_KEY = settings.SECRET_KEY
# HMAC-SHA256: a symmetric algorithm, so the same SECRET_KEY signs and verifies.
ALGORITHM = "HS256"
# Default lifetimes applied when no explicit expires_delta is supplied.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
# bcrypt is the only configured (hence default) scheme; deprecated="auto"
# tells passlib to treat any non-default scheme as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a candidate password against a stored hash.

    Args:
        plain_password: The password exactly as supplied by the user.
        hashed_password: A stored hash (presumably one produced by
            ``get_password_hash``).

    Returns:
        True when the password matches the hash, False otherwise.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash a plain-text password for storage.

    Args:
        password: The plain-text password to hash.

    Returns:
        The hash string produced by the module-level ``pwd_context``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def create_tokens(user_id: str) -> TokenResponse:
    """Issue an access/refresh token pair for a user.

    Each token gets its own freshly generated ``jti`` (JWT ID), so the two
    tokens are individually identifiable.

    Args:
        user_id: Identifier stored in the ``sub`` claim of both tokens.

    Returns:
        TokenResponse carrying both tokens plus metadata: ``token_type``
        ("bearer"), ``expires_in`` (access-token lifetime in seconds),
        ``user_id`` and ``scope``.
    """
    access_claims = {"sub": user_id, "jti": str(uuid4())}
    refresh_claims = {"sub": user_id, "jti": str(uuid4())}

    # expires_in is reported in seconds; the access-token default is in minutes.
    access_lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60

    return TokenResponse(
        access_token=create_access_token(access_claims),
        refresh_token=create_refresh_token(refresh_claims),
        token_type="bearer",
        expires_in=access_lifetime_seconds,
        user_id=user_id,
        scope="read write",
    )
|
|
|
|
|
|
def create_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
    token_type: str = "access"
) -> str:
    """Create a signed JWT of the given type.

    The caller's ``data`` is copied (never mutated). The copy is extended with:

    * ``exp``: ``now + expires_delta`` when a delta is given. Otherwise the
      default lifetime for ``token_type``: ACCESS_TOKEN_EXPIRE_MINUTES for
      "access", REFRESH_TOKEN_EXPIRE_DAYS for any other type.
    * ``type``: ``token_type``.
    * ``iat``: the issue time (UTC).
    * ``jti``: a random UUID4, unless ``data`` already provides one.

    Args:
        data: Claims to embed (typically at least ``sub``).
        expires_delta: Explicit lifetime. ``None`` selects the default for
            ``token_type``. Any other value, including ``timedelta(0)``, is
            used as given.
        token_type: Value stored in the ``type`` claim.

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    to_encode = data.copy()

    # Take one timestamp so ``iat`` and ``exp`` are derived from the same instant.
    now = datetime.now(timezone.utc)

    # Compare against None rather than relying on truthiness: timedelta(0) is
    # falsy, and an explicit zero lifetime must not become the default lifetime.
    if expires_delta is None:
        if token_type == "access":
            expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        else:
            expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    to_encode.update({
        "exp": now + expires_delta,
        "type": token_type,
        "iat": now,
    })
    # Ensure every token carries a unique ``jti``; it is the key used for
    # revocation lookups when tokens are decoded.
    to_encode.setdefault("jti", str(uuid4()))

    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a new access token.

    Args:
        data: Claims to embed (typically ``sub``). The dict is not modified.
            If it has no ``jti``, one is generated on a private copy.
        expires_delta: Optional explicit lifetime. ``None`` uses the access
            default (ACCESS_TOKEN_EXPIRE_MINUTES).

    Returns:
        The encoded JWT string with ``type`` set to "access".
    """
    # Work on a copy so the caller's dict is never mutated. Otherwise, reusing
    # the same dict for another token would carry over this token's jti.
    claims = dict(data)
    claims.setdefault("jti", str(uuid4()))
    return create_token(claims, expires_delta, "access")
|
|
|
|
|
|
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a new refresh token.

    Args:
        data: Claims to embed (typically ``sub``). The dict is not modified.
            If it has no ``jti``, one is generated on a private copy.
        expires_delta: Optional explicit lifetime. ``None`` uses the refresh
            default (REFRESH_TOKEN_EXPIRE_DAYS).

    Returns:
        The encoded JWT string with ``type`` set to "refresh".
    """
    # Work on a copy so the caller's dict is never mutated. Otherwise, reusing
    # the same dict for another token would carry over this token's jti.
    claims = dict(data)
    claims.setdefault("jti", str(uuid4()))
    return create_token(claims, expires_delta, "refresh")
|
|
|
|
|
|
def _claim_to_datetime(value, claim: str) -> datetime:
    """Convert a numeric JWT time claim (seconds since the epoch) to an aware UTC datetime.

    Raises:
        JWTError: If ``value`` cannot be converted (wrong type or out of range).
    """
    try:
        return datetime.fromtimestamp(value, tz=timezone.utc)
    except (TypeError, ValueError, OverflowError, OSError) as e:
        raise JWTError(f"Malformed token. Invalid '{claim}' claim: {value!r}.") from e


async def decode_token(
    token: str,
    required_type: str = "access",
    db: AsyncSession = Depends(get_db)
) -> TokenPayload:
    """
    Decode and validate a JWT token, including revocation checks.

    Validation steps, in order:
      1. Verify the signature (SECRET_KEY / ALGORITHM) and let the JWT library
         check ``exp`` and ``iat``.
      2. Ensure the ``exp``, ``sub``, ``type`` and ``jti`` claims are present.
      3. Re-check expiry against the current UTC time.
      4. Ensure the ``type`` claim equals ``required_type``.
      5. Reject the token if ``is_token_revoked`` reports its ``jti`` as revoked.

    Args:
        token (str): The JWT token to decode.
        required_type (str): The expected token type (default: "access").
        db (AsyncSession): Database session for token revocation checks.

    Returns:
        TokenPayload: The decoded token data. ``exp`` and ``iat`` are
        timezone-aware UTC datetimes; a token without ``iat`` gets the Unix epoch.

    Raises:
        JWTError: If the token is expired, revoked, malformed, of the wrong
            type, or fails validation. Any underlying error is chained as
            ``__cause__``.
    """
    # Step 1: Decode the JWT token
    try:
        payload = jwt.decode(
            token,
            SECRET_KEY,
            algorithms=[ALGORITHM],
            # NOTE: python-jose only honours boolean verify_* / require_* keys;
            # a PyJWT-style "require": [...] list is silently ignored, so claim
            # presence is enforced explicitly in Step 2.
            options={
                "verify_exp": True,
                "verify_iat": True,
            },
        )
    except ExpiredSignatureError as e:
        raise JWTError("Token has expired. Please refresh your token or login again.") from e
    except JWTError as e:
        # python-jose re-raises JWS-level failures (bad signature, too few
        # segments, bad padding) as JWTError, so they are classified here.
        message = str(e)
        if "Signature verification failed" in message:
            raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.") from e
        if "segments" in message.lower():
            raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).") from e
        raise JWTError(f"Failed to decode the token: {e}") from e
    except JOSEError as e:
        # Remaining JOSE errors that are not JWTError subclasses.
        if "segments" in str(e).lower():
            raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).") from e
        raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
    except Exception as e:
        # Catch-all so callers only ever have to handle JWTError.
        raise JWTError(f"An unexpected error occurred while decoding the token: {e}") from e

    # Step 2: Validate required claims (the library only checks them when present).
    required_claims = ["exp", "sub", "type", "jti"]
    missing_claims = [claim for claim in required_claims if claim not in payload]
    if missing_claims:
        raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")

    # Step 3: Validate expiry. Both operands must be timezone-aware: comparing an
    # aware "now" with a naive fromtimestamp() result raises TypeError.
    expiration = _claim_to_datetime(payload["exp"], "exp")
    if datetime.now(timezone.utc) > expiration:
        raise JWTError("Token has expired. Please refresh your token or login again.")

    # Step 4: Validate token type
    token_type = payload.get("type")
    if token_type != required_type:
        raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")

    # Step 5: Check revocation
    jti = payload.get("jti")
    if await is_token_revoked(jti, db):
        raise JWTError("Token has been revoked. Please login again to generate a new token.")

    # Step 6: Return validated token payload
    return TokenPayload(
        sub=payload["sub"],
        type=payload["type"],
        exp=expiration,
        # A missing ``iat`` falls back to the Unix epoch (timestamp 0).
        iat=_claim_to_datetime(payload.get("iat", 0), "iat"),
        jti=jti,
    )
|