Files
eventspace/backend/app/auth/security.py
Felipe Cardoso 8814dc931f Add token revocation mechanism and support for logout APIs
This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
2025-02-28 17:45:33 +01:00

157 lines
5.2 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
from uuid import uuid4

from fastapi import Depends
from jose import jwt, ExpiredSignatureError, JWTError
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utlis import is_token_revoked
# Configuration
# Key used to sign and verify JWTs; read from the application settings.
SECRET_KEY = settings.SECRET_KEY
# Signing algorithm handed to jwt.encode / jwt.decode.
ALGORITHM = "HS256"
# Default lifetime of an access token, in minutes (also reported to clients
# as `expires_in`, converted to seconds, by create_tokens).
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Default lifetime of a refresh token, in days.
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Password hashing context (bcrypt) shared by verify_password and get_password_hash
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The previously stored hash to compare against.

    Returns:
        The result of the module-level password context's verification.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """
    Hash a password with the module-level password context.

    Args:
        password: The plain-text password to hash.

    Returns:
        The hash string produced by the password context.
    """
    hashed = pwd_context.hash(password)
    return hashed
def create_tokens(user_id: str) -> TokenResponse:
    """
    Issue an access/refresh token pair for a user.

    Args:
        user_id: The user's ID; stored as the `sub` claim of both tokens.

    Returns:
        TokenResponse carrying both tokens, the token type ("bearer"), the
        access-token lifetime in seconds and the user ID.
    """
    # Every token carries its own freshly generated `jti`.
    access_claims = {"sub": user_id, "jti": str(uuid4())}
    access_token = create_access_token(access_claims)

    refresh_claims = {"sub": user_id, "jti": str(uuid4())}
    refresh_token = create_refresh_token(refresh_claims)

    # The lifetime is configured in minutes but reported in seconds.
    lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60

    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        expires_in=lifetime_seconds,
        user_id=user_id,
    )
def create_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
    token_type: str = "access"
) -> str:
    """
    Create a signed JWT with the specified type and expiration.

    Args:
        data: Claims to embed in the token (for example `sub`, `jti`). The
            dict is copied, so the caller's object is not modified here.
        expires_delta: Token lifetime. When omitted (or falsy), access tokens
            live for ACCESS_TOKEN_EXPIRE_MINUTES and every other token type
            for REFRESH_TOKEN_EXPIRE_DAYS.
        token_type: Value stored in the `type` claim (default: "access").

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    if expires_delta:
        lifetime = expires_delta
    elif token_type == "access":
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    else:
        lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    # Use a single timezone-aware UTC instant for both `iat` and `exp`.
    # A naive local `datetime.now()` is interpreted as UTC when the claims are
    # converted to numeric timestamps, which shifts both claims by the
    # server's UTC offset on any host that is not running in UTC.
    issued_at = datetime.now(timezone.utc)

    to_encode = data.copy()
    to_encode.update({
        "exp": issued_at + lifetime,
        "type": token_type,
        "iat": issued_at,
    })

    # Ensure a unique `jti` is always present so the token can be revoked.
    if "jti" not in to_encode:
        to_encode["jti"] = str(uuid4())

    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a new access token.

    Args:
        data: Claims to embed in the token. A `jti` is generated when the
            claims do not already contain one. The caller's dict is left
            untouched.
        expires_delta: Optional lifetime override, forwarded to create_token.

    Returns:
        The encoded JWT string with `type` set to "access".
    """
    # Work on a copy: writing the generated `jti` back into the caller's dict
    # would make a later token built from the same dict reuse that identifier,
    # so revoking one token would revoke the other as well.
    claims = data.copy()
    if "jti" not in claims:
        claims["jti"] = str(uuid4())
    return create_token(claims, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a new refresh token.

    Args:
        data: Claims to embed in the token. A `jti` is generated when the
            claims do not already contain one. The caller's dict is left
            untouched.
        expires_delta: Optional lifetime override, forwarded to create_token.

    Returns:
        The encoded JWT string with `type` set to "refresh".
    """
    # Work on a copy: writing the generated `jti` back into the caller's dict
    # would make a later token built from the same dict reuse that identifier,
    # so revoking one token would revoke the other as well.
    claims = data.copy()
    if "jti" not in claims:
        claims["jti"] = str(uuid4())
    return create_token(claims, expires_delta, "refresh")
async def decode_token(
    token: str,
    required_type: str = "access",
    db: AsyncSession = Depends(get_db)
) -> TokenPayload:
    """
    Decode a JWT and validate its claims, type and revocation status.

    Args:
        token (str): The encoded JWT.
        required_type (str): Value the `type` claim must have (default: "access").
        db (AsyncSession): Session handed to the revocation lookup.
            NOTE(review): the `Depends(get_db)` default is presumably only
            resolved when FastAPI injects this function; direct callers
            should pass `db` explicitly - confirm against call sites.

    Returns:
        TokenPayload: `sub`, `type`, `jti` plus `exp` and `iat` converted with
        `datetime.fromtimestamp` (naive datetimes; `iat` falls back to
        timestamp 0 when the claim is absent).

    Raises:
        JWTError: If the token cannot be decoded, lacks a required claim, is
            expired, has the wrong type, or has been revoked.
    """
    mandatory_claims = ("exp", "sub", "type", "jti")

    try:
        # Signature and standard-claim validation is done by the JWT library.
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

        # All of these claims must be present before any of them is read.
        if not all(name in claims for name in mandatory_claims):
            raise KeyError("Missing required claim(s) in token.")

        # Explicit expiry check on top of whatever jwt.decode enforces.
        if datetime.now() > datetime.fromtimestamp(claims["exp"]):
            raise ExpiredSignatureError("Token has expired.")

        # Reject e.g. a refresh token presented where an access token is expected.
        if claims["type"] != required_type:
            raise JWTError(f"Invalid token type: expected '{required_type}', got '{claims['type']}'.")

        # Revocation is tracked per token identifier (`jti`).
        token_id = claims["jti"]
        revoked = await is_token_revoked(token_id, db)
        if revoked:
            raise JWTError("Token has been revoked.")

        return TokenPayload(
            sub=claims["sub"],
            type=claims["type"],
            exp=datetime.fromtimestamp(claims["exp"]),
            iat=datetime.fromtimestamp(claims.get("iat", 0)),
            jti=token_id,
        )
    except ExpiredSignatureError as exc:
        # Expired tokens get a dedicated message (must precede the JWTError handler).
        raise JWTError("Token expired. Please refresh your token to continue.") from exc
    except KeyError as exc:
        # Raised above when a required claim is missing.
        raise JWTError("Malformed token. Missing required claim(s).") from exc
    except JWTError as exc:
        # Every other JWT failure is re-raised as a plain JWTError with the same message.
        raise JWTError(str(exc)) from exc