Files
eventspace/backend/app/auth/security.py
Felipe Cardoso 453016629f Refactor and enhance token decoding error handling
Improved the `decode_token` function to clarify and extend error handling for token validation and decoding. Enhanced error messages for invalid tokens, added checks for missing claims, and ensured clear differentiation of failure scenarios. Updated imports and added a `scope` field to token response for completeness.
2025-02-28 19:05:08 +01:00

177 lines
5.8 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4
from fastapi import Depends
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utils import is_token_revoked
# Configuration
# Key used to sign and verify JWTs; read from the application settings.
SECRET_KEY = settings.SECRET_KEY
# Signing algorithm handed to jwt.encode / jwt.decode in this module.
ALGORITHM = "HS256"
# Default lifetime of an access token, in minutes (also used to compute
# `expires_in` in the token response).
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Default lifetime of a refresh token, in days.
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Password hashing context
# Shared passlib context configured for bcrypt; used by verify_password and
# get_password_hash.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a candidate password against a stored hash.

    Args:
        plain_password: The password as entered by the user.
        hashed_password: The previously stored password hash.

    Returns:
        The result of the module-level hashing context's verification.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """
    Hash a password with the module-level hashing context.

    Args:
        password: The plain-text password to hash.

    Returns:
        The hash string produced by the hashing context.
    """
    hashed = pwd_context.hash(password)
    return hashed
def create_tokens(user_id: str) -> TokenResponse:
    """
    Issue an access/refresh token pair for a user.

    Args:
        user_id: The user's ID; stored as the `sub` claim of both tokens.

    Returns:
        TokenResponse holding both encoded tokens, the token type, the
        access-token lifetime in seconds, the user ID and the scope string.
    """
    # Each token gets its own freshly generated `jti`.
    access_claims = {"sub": user_id, "jti": str(uuid4())}
    access_token = create_access_token(access_claims)

    refresh_claims = {"sub": user_id, "jti": str(uuid4())}
    refresh_token = create_refresh_token(refresh_claims)

    # `expires_in` is expressed in seconds, hence the conversion from minutes.
    lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60

    response = TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        expires_in=lifetime_seconds,
        user_id=user_id,
        scope="read write",
    )
    return response
def create_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
    token_type: str = "access"
) -> str:
    """
    Create a JWT token with the specified type and expiration.

    Args:
        data: Claims to embed in the token. The dict is copied, not modified.
            Any `exp`, `type` or `iat` keys it contains are overwritten.
        expires_delta: Lifetime of the token. When None, the default is
            ACCESS_TOKEN_EXPIRE_MINUTES for `token_type == "access"` and
            REFRESH_TOKEN_EXPIRE_DAYS for any other type.
        token_type: Value stored in the `type` claim (default: "access").

    Returns:
        The encoded JWT string.
    """
    # Read the clock once so `iat` and `exp` are derived from the same instant.
    issued_at = datetime.now(timezone.utc)

    # Compare with None explicitly: a timedelta equal to zero is falsy, so a
    # truthiness test would silently replace an explicit zero lifetime with
    # the default one.
    if expires_delta is None:
        expires_delta = (
            timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
            if token_type == "access"
            else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        )

    to_encode = data.copy()
    to_encode.update({
        "exp": issued_at + expires_delta,
        "type": token_type,
        "iat": issued_at,
    })
    if "jti" not in to_encode:
        to_encode["jti"] = str(uuid4())  # Ensure unique `jti` is always added
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a new access token.

    Args:
        data: Claims to embed in the token. A `jti` is generated when the
            claims do not already contain one.
        expires_delta: Optional token lifetime, forwarded to `create_token`.

    Returns:
        The encoded JWT string with `type` set to "access".
    """
    # Work on a copy so the caller's dict is not modified as a side effect
    # of issuing a token.
    claims = dict(data)
    if "jti" not in claims:
        claims["jti"] = str(uuid4())
    return create_token(claims, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a new refresh token.

    Args:
        data: Claims to embed in the token. A `jti` is generated when the
            claims do not already contain one.
        expires_delta: Optional token lifetime, forwarded to `create_token`.

    Returns:
        The encoded JWT string with `type` set to "refresh".
    """
    # Work on a copy so the caller's dict is not modified as a side effect
    # of issuing a token.
    claims = dict(data)
    if "jti" not in claims:
        claims["jti"] = str(uuid4())
    return create_token(claims, expires_delta, "refresh")
async def decode_token(
    token: str,
    required_type: str = "access",
    db: AsyncSession = Depends(get_db)
) -> TokenPayload:
    """
    Decode and validate a JWT token, including revocation checks.

    Args:
        token (str): The JWT token to decode.
        required_type (str): The expected token type (default: "access").
        db (AsyncSession): Database session for token revocation checks.
            The default value is a FastAPI `Depends(get_db)` marker.

    Returns:
        TokenPayload: The decoded token data. `exp` and `iat` are passed as
        timezone-aware UTC datetimes.

    Raises:
        JWTError: If the token is expired, revoked, malformed, of the wrong
            type, or fails validation. The original decoding error, when
            there is one, is attached as the exception's cause.
    """
    try:
        # Step 1: Decode the JWT token
        # NOTE(review): "require" looks like a PyJWT-style option; python-jose
        # appears to use per-claim `require_<claim>` keys instead, so this
        # entry may be ignored -- confirm. Required claims are re-checked
        # explicitly in Step 2 either way.
        payload = jwt.decode(
            token,
            SECRET_KEY,
            algorithms=[ALGORITHM],
            options={
                "verify_exp": True,
                "verify_iat": True,
                "require": ["exp", "iat", "sub", "type", "jti"]
            }
        )
    except ExpiredSignatureError as e:
        raise JWTError("Token has expired. Please refresh your token or login again.") from e
    except JWTError as e:
        if "Signature verification failed" in str(e):
            raise JWTError(
                "Invalid token signature. The token may have been tampered with or corrupted."
            ) from e
        raise JWTError(f"Failed to decode the token: {e}") from e
    except JOSEError as e:
        if "segments" in str(e).lower():
            raise JWTError(
                "Malformed token. The token format is invalid (e.g., not enough segments)."
            ) from e
        raise JWTError(
            "Failed to decode the token. Ensure the token is valid and correctly formatted."
        ) from e
    except Exception as e:
        # Catch-all for unexpected exceptions during decoding
        raise JWTError(f"An unexpected error occurred while decoding the token: {e}") from e

    # Step 2: Validate Required Claims
    required_claims = ["exp", "sub", "type", "jti"]
    missing_claims = [claim for claim in required_claims if claim not in payload]
    if missing_claims:
        raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")

    # Step 3: Validate Expiry
    # Build an aware UTC datetime: fromtimestamp() without `tz` returns a
    # naive local-time value, and comparing that with the aware "now" below
    # raises TypeError instead of performing the expiry check.
    expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
    if datetime.now(timezone.utc) > expiration:
        raise JWTError("Token has expired. Please refresh your token or login again.")

    # Step 4: Validate Token Type
    token_type = payload.get("type")
    if token_type != required_type:
        raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")

    # Step 5: Check Revocation
    jti = payload.get("jti")
    if await is_token_revoked(jti, db):
        raise JWTError("Token has been revoked. Please login again to generate a new token.")

    # Step 6: Return Validated Token Payload
    # `iat` is converted with the same UTC timezone as `exp` so the two
    # timestamps in the payload are directly comparable; a missing `iat`
    # falls back to the epoch.
    issued_at = datetime.fromtimestamp(payload.get("iat", 0), tz=timezone.utc)
    return TokenPayload(
        sub=payload["sub"],
        type=payload["type"],
        exp=expiration,
        iat=issued_at,
        jti=jti
    )