Files
eventspace/backend/app/auth/security.py
Felipe Cardoso af53b52c0c Refactor token creation logic and fix datetime usage
Adjusted `datetime.utcnow` to `datetime.now` for consistency and refactored token creation functions for cleaner structure. Removed duplicated `create_access_token` and `create_refresh_token` definitions by consolidating them into a single location.
2025-02-28 17:32:20 +01:00

140 lines
4.6 KiB
Python

from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from black import timezone
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
# --- JWT configuration -------------------------------------------------------
# Tokens are signed and verified with the same application secret (HS256 is a
# symmetric HMAC algorithm).
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"

# Default lifetimes applied by create_token when no explicit expires_delta is
# passed ("access" tokens use minutes, every other type uses days).
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# --- Password hashing ---------------------------------------------------------
# passlib context backed by bcrypt (see passlib CryptContext docs for the
# meaning of deprecated="auto").
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plaintext password against a stored hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The previously stored hash to compare against.

    Returns:
        The boolean result of ``pwd_context.verify``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """
    Hash a plaintext password for storage.

    Args:
        password: The plaintext password.

    Returns:
        The hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
def create_tokens(user_id: str) -> TokenResponse:
    """
    Issue an access/refresh token pair for a user.

    Both tokens carry ``user_id`` in their ``sub`` claim and use the default
    lifetimes from the module configuration.

    Args:
        user_id: The user's ID, stored as the ``sub`` claim.

    Returns:
        TokenResponse holding both tokens, token_type "bearer", the access
        token lifetime in seconds as ``expires_in``, and the user_id.
    """
    access = create_access_token(dict(sub=user_id))
    refresh = create_refresh_token(dict(sub=user_id))
    access_lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60
    return TokenResponse(
        access_token=access,
        refresh_token=refresh,
        token_type="bearer",
        expires_in=access_lifetime_seconds,
        user_id=user_id,
    )
def create_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
    token_type: str = "access"
) -> str:
    """
    Create a signed JWT with the given type and expiration.

    Args:
        data: Claims to embed (e.g. ``{"sub": user_id}``). The dict is copied,
            never mutated. Any ``exp``, ``type``, ``iat`` or ``jti`` keys in it
            are overwritten.
        expires_delta: Token lifetime. When None, defaults to
            ACCESS_TOKEN_EXPIRE_MINUTES for ``token_type == "access"`` and
            REFRESH_TOKEN_EXPIRE_DAYS for any other type. An explicit
            ``timedelta(0)`` is honoured (the token expires immediately).
        token_type: Value written to the ``type`` claim.

    Returns:
        The encoded JWT string.
    """
    # Imported locally: the module-level `timezone` name comes from `black`
    # (a code formatter), which should not be a runtime dependency.
    from datetime import timezone

    # Use an aware UTC timestamp. python-jose turns datetime claims into epoch
    # seconds via utctimetuple(), which treats a *naive* datetime as if it
    # were already UTC; naive local time (datetime.now()) would therefore skew
    # exp/iat by the server's UTC offset.
    issued_at = datetime.now(timezone.utc)

    # `is None` instead of truthiness: timedelta(0) is falsy but is a valid,
    # explicitly requested lifetime.
    if expires_delta is None:
        expires_delta = (
            timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
            if token_type == "access"
            else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        )

    to_encode = data.copy()
    to_encode.update({
        "exp": issued_at + expires_delta,
        "type": token_type,
        "iat": issued_at,
        "jti": str(uuid4()),
    })
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Encode an access token (``type`` claim set to "access").

    Args:
        data: Claims to embed in the token.
        expires_delta: Optional lifetime; create_token's default applies when None.

    Returns:
        The encoded JWT string.
    """
    return create_token(data, expires_delta=expires_delta, token_type="access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Encode a refresh token (``type`` claim set to "refresh").

    Args:
        data: Claims to embed in the token.
        expires_delta: Optional lifetime; create_token's default applies when None.

    Returns:
        The encoded JWT string.
    """
    return create_token(data, expires_delta=expires_delta, token_type="refresh")
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
    """
    Decode and validate a JWT, mapping every failure to ``JWTError``.

    ``jwt.decode`` verifies the signature (and, with python-jose defaults, the
    ``exp`` claim). This function additionally requires the ``exp``, ``sub``
    and ``type`` claims. It also requires ``type`` to equal ``required_type``.

    Args:
        token: The JWT token to decode.
        required_type: The expected token type (default: "access").

    Returns:
        TokenPayload built from the claims. ``exp`` and ``iat`` are naive
        datetimes in the server's local time (``datetime.fromtimestamp``).
        A missing ``iat`` becomes timestamp 0.

    Raises:
        JWTError: If the token is expired, invalid, malformed, missing a
            required claim, or of the wrong type.
    """
    # Imported locally: the module-level `timezone` name comes from `black`
    # (a code formatter), which should not be a runtime dependency.
    from datetime import timezone

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

        # These claims are not mandatory for jwt.decode, so check them here.
        if not all(claim in payload for claim in ("exp", "sub", "type")):
            raise KeyError("Missing required claim.")

        # Defensive expiry re-check on aware UTC instants. Comparing naive
        # local datetimes is ambiguous when wall-clock time repeats at a DST
        # fall-back and could reject a still-valid token.
        expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        if datetime.now(timezone.utc) > expires_at:
            raise ExpiredSignatureError("Token has expired.")

        actual_type = payload["type"]
        if actual_type != required_type:
            raise JWTError(
                f"Invalid token type: expected '{required_type}', got '{actual_type}'."
            )

        # Returned datetimes stay naive local time (unchanged contract) so
        # callers comparing them with naive datetimes keep working.
        return TokenPayload(
            sub=payload["sub"],
            type=actual_type,
            exp=datetime.fromtimestamp(payload["exp"]),
            iat=datetime.fromtimestamp(payload.get("iat", 0)),
            jti=payload.get("jti"),
        )
    except ExpiredSignatureError as e:  # Must precede JWTError (it subclasses it)
        raise JWTError("Token expired. Please refresh your token to continue.") from e
    except JWTError as e:
        # Collapse signature and segment-count failures into a generic message.
        if str(e) in ("Signature verification failed.", "Not enough segments"):
            raise JWTError("Invalid token.") from e
        # Propagate other JWTError messages (including the type mismatch above).
        raise JWTError(str(e)) from e
    except KeyError as e:  # Missing required claims
        raise JWTError("Malformed token. Missing required claim.") from e
    except JOSEError as e:  # All other JOSE-related errors
        raise JWTError("Invalid token.") from e