Add token revocation mechanism and support for logout APIs

This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
This commit is contained in:
2025-02-28 17:45:33 +01:00
parent aa77752981
commit 8814dc931f
8 changed files with 270 additions and 208 deletions

View File

@@ -2,12 +2,15 @@ from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from black import timezone
from sqlalchemy.ext.asyncio import AsyncSession
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
from fastapi import Depends
from app.core.database import get_db
from auth.utlis import is_token_revoked
# Configuration
SECRET_KEY = settings.SECRET_KEY
@@ -39,8 +42,9 @@ def create_tokens(user_id: str) -> TokenResponse:
Returns:
TokenResponse containing both tokens and metadata
"""
access_token = create_access_token({"sub": user_id})
refresh_token = create_refresh_token({"sub": user_id})
# Add `jti` during token creation
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
return TokenResponse(
access_token=access_token,
@@ -50,6 +54,7 @@ def create_tokens(user_id: str) -> TokenResponse:
user_id=user_id
)
def create_token(
data: dict,
expires_delta: Optional[timedelta] = None,
@@ -70,70 +75,82 @@ def create_token(
"exp": expire,
"type": token_type,
"iat": datetime.now(),
"jti": str(uuid4())
})
if "jti" not in to_encode:
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "refresh")
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
async def decode_token(
token: str,
required_type: str = "access",
db: AsyncSession = Depends(get_db)
) -> TokenPayload:
"""
Decode and validate a JWT token with explicit edge-case handling.
Decode and validate a JWT token, including revocation checks.
Args:
token: The JWT token to decode.
required_type: The expected token type (default: "access").
token (str): The JWT token to decode.
required_type (str): The expected token type (default: "access").
db (AsyncSession): Database session for token revocation checks.
Returns:
TokenPayload containing the decoded data.
TokenPayload: The decoded token data.
Raises:
JWTError: If the token is expired, invalid, or malformed.
JWTError: If the token is expired, revoked, or malformed.
"""
try:
# Decode the JWT token using the secret and algorithm
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Explicitly validate required claims (`exp`, `sub`, `type`)
if "exp" not in payload or "sub" not in payload or "type" not in payload:
raise KeyError("Missing required claim.")
# Explicitly validate required claims
if "exp" not in payload or "sub" not in payload or "type" not in payload or "jti" not in payload:
raise KeyError("Missing required claim(s) in token.")
# Verify token expiration (`exp`)
# Validate token expiration (`exp`)
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
raise ExpiredSignatureError("Token has expired.")
# Verify token type (`type`)
# Validate the token type (`type`)
if payload["type"] != required_type:
expected_type = required_type
actual_type = payload["type"]
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
raise JWTError(f"Invalid token type: expected '{required_type}', got '{payload['type']}'.")
# Create TokenPayload object from token claims
# Check the token's revocation status (via `jti`)
if await is_token_revoked(payload["jti"], db):
raise JWTError("Token has been revoked.")
# Construct and return the token payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=payload.get("jti")
jti=payload["jti"]
)
except ExpiredSignatureError as e: # Expired token
except ExpiredSignatureError as e:
# Handle expired token exception
raise JWTError("Token expired. Please refresh your token to continue.") from e
except KeyError as e:
# Handle missing claims in the token
raise JWTError("Malformed token. Missing required claim(s).") from e
except JWTError as e:
# Handle signature verification and malformed token errors
if str(e) in ["Signature verification failed.", "Not enough segments"]:
raise JWTError("Invalid token.") from e
# Propagate other JWTError messages
# Handle any other JWT-specific exceptions
raise JWTError(str(e)) from e
except KeyError as e: # Missing required claims
raise JWTError("Malformed token. Missing required claim.") from e
except JOSEError as e: # All other JOSE-related errors
raise JWTError("Invalid token.") from e

View File

@@ -1,16 +1,15 @@
# auth/utils.py
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by adding its `jti` to the database."""
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession):
"""Check if a token with the given `jti` is revoked."""
result = await db.get(RevokedToken, jti)
return result is not None
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's `jti` is in the revoked_tokens table."""
revoked = await db.get(RevokedToken, jti)
return revoked is not None