Add token revocation mechanism and support for logout APIs
This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
This commit is contained in:
@@ -2,12 +2,15 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from black import timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
|
||||
from fastapi import Depends
|
||||
from app.core.database import get_db
|
||||
from auth.utlis import is_token_revoked
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
@@ -39,8 +42,9 @@ def create_tokens(user_id: str) -> TokenResponse:
|
||||
Returns:
|
||||
TokenResponse containing both tokens and metadata
|
||||
"""
|
||||
access_token = create_access_token({"sub": user_id})
|
||||
refresh_token = create_refresh_token({"sub": user_id})
|
||||
# Add `jti` during token creation
|
||||
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
|
||||
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
@@ -50,6 +54,7 @@ def create_tokens(user_id: str) -> TokenResponse:
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def create_token(
|
||||
data: dict,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
@@ -70,70 +75,82 @@ def create_token(
|
||||
"exp": expire,
|
||||
"type": token_type,
|
||||
"iat": datetime.now(),
|
||||
"jti": str(uuid4())
|
||||
})
|
||||
if "jti" not in to_encode:
|
||||
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
|
||||
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new access token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "access")
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
|
||||
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
|
||||
async def decode_token(
|
||||
token: str,
|
||||
required_type: str = "access",
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token with explicit edge-case handling.
|
||||
Decode and validate a JWT token, including revocation checks.
|
||||
|
||||
Args:
|
||||
token: The JWT token to decode.
|
||||
required_type: The expected token type (default: "access").
|
||||
token (str): The JWT token to decode.
|
||||
required_type (str): The expected token type (default: "access").
|
||||
db (AsyncSession): Database session for token revocation checks.
|
||||
|
||||
Returns:
|
||||
TokenPayload containing the decoded data.
|
||||
TokenPayload: The decoded token data.
|
||||
|
||||
Raises:
|
||||
JWTError: If the token is expired, invalid, or malformed.
|
||||
JWTError: If the token is expired, revoked, or malformed.
|
||||
"""
|
||||
try:
|
||||
# Decode the JWT token using the secret and algorithm
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
# Explicitly validate required claims (`exp`, `sub`, `type`)
|
||||
if "exp" not in payload or "sub" not in payload or "type" not in payload:
|
||||
raise KeyError("Missing required claim.")
|
||||
# Explicitly validate required claims
|
||||
if "exp" not in payload or "sub" not in payload or "type" not in payload or "jti" not in payload:
|
||||
raise KeyError("Missing required claim(s) in token.")
|
||||
|
||||
# Verify token expiration (`exp`)
|
||||
# Validate token expiration (`exp`)
|
||||
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
|
||||
raise ExpiredSignatureError("Token has expired.")
|
||||
|
||||
# Verify token type (`type`)
|
||||
# Validate the token type (`type`)
|
||||
if payload["type"] != required_type:
|
||||
expected_type = required_type
|
||||
actual_type = payload["type"]
|
||||
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
|
||||
raise JWTError(f"Invalid token type: expected '{required_type}', got '{payload['type']}'.")
|
||||
|
||||
# Create TokenPayload object from token claims
|
||||
# Check the token's revocation status (via `jti`)
|
||||
if await is_token_revoked(payload["jti"], db):
|
||||
raise JWTError("Token has been revoked.")
|
||||
|
||||
# Construct and return the token payload
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=datetime.fromtimestamp(payload["exp"]),
|
||||
iat=datetime.fromtimestamp(payload.get("iat", 0)),
|
||||
jti=payload.get("jti")
|
||||
jti=payload["jti"]
|
||||
)
|
||||
|
||||
except ExpiredSignatureError as e: # Expired token
|
||||
except ExpiredSignatureError as e:
|
||||
# Handle expired token exception
|
||||
raise JWTError("Token expired. Please refresh your token to continue.") from e
|
||||
except KeyError as e:
|
||||
# Handle missing claims in the token
|
||||
raise JWTError("Malformed token. Missing required claim(s).") from e
|
||||
except JWTError as e:
|
||||
# Handle signature verification and malformed token errors
|
||||
if str(e) in ["Signature verification failed.", "Not enough segments"]:
|
||||
raise JWTError("Invalid token.") from e
|
||||
# Propagate other JWTError messages
|
||||
# Handle any other JWT-specific exceptions
|
||||
raise JWTError(str(e)) from e
|
||||
except KeyError as e: # Missing required claims
|
||||
raise JWTError("Malformed token. Missing required claim.") from e
|
||||
except JOSEError as e: # All other JOSE-related errors
|
||||
raise JWTError("Invalid token.") from e
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# auth/utils.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by adding its `jti` to the database."""
|
||||
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession):
|
||||
"""Check if a token with the given `jti` is revoked."""
|
||||
result = await db.get(RevokedToken, jti)
|
||||
return result is not None
|
||||
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
|
||||
"""Check whether the token's `jti` is in the revoked_tokens table."""
|
||||
revoked = await db.get(RevokedToken, jti)
|
||||
return revoked is not None
|
||||
|
||||
Reference in New Issue
Block a user