This commit introduces a system for revoking tokens by storing their `jti` (JWT ID) in a new `RevokedToken` model. It adds APIs for logging out (revoking the current token) and for logging out from all devices (revoking all of the user's tokens). Token validation now also checks revocation status while decoding.
157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
from uuid import uuid4

from fastapi import Depends
from jose import jwt, ExpiredSignatureError, JWTError
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
from app.core.database import get_db
from auth.utlis import is_token_revoked
|
|
|
|
# Configuration
# Signing material shared by every jwt.encode / jwt.decode call in this module.
SECRET_KEY = settings.SECRET_KEY  # taken from application settings
ALGORITHM = "HS256"  # HMAC-SHA256 (symmetric key)

# Default lifetimes: create_token() falls back to these when no expires_delta
# is supplied, and create_tokens() reports the access lifetime as expires_in.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7


# Password hashing context (passlib, bcrypt scheme).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored *hashed_password*.

    The comparison is delegated to the module-level passlib ``pwd_context``.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash *password* with the module-level passlib ``pwd_context``.

    Returns the encoded hash string suitable for later ``verify_password``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def create_tokens(user_id: str) -> TokenResponse:
    """
    Issue an access/refresh token pair for a user.

    Each token is given its own freshly generated ``jti`` so that either one
    can later be revoked independently of the other.

    Args:
        user_id: Value stored in the ``sub`` claim of both tokens.

    Returns:
        TokenResponse holding both tokens, ``token_type="bearer"``, the access
        token lifetime in seconds as ``expires_in``, and the ``user_id``.
    """
    access_claims = {"sub": user_id, "jti": str(uuid4())}
    signed_access = create_access_token(access_claims)

    refresh_claims = {"sub": user_id, "jti": str(uuid4())}
    signed_refresh = create_refresh_token(refresh_claims)

    access_lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60

    return TokenResponse(
        user_id=user_id,
        token_type="bearer",
        access_token=signed_access,
        refresh_token=signed_refresh,
        expires_in=access_lifetime_seconds,
    )
|
|
|
|
|
|
def create_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
    token_type: str = "access",
) -> str:
    """Create a signed JWT with the specified type and expiration.

    Args:
        data: Claims to embed (e.g. ``sub``). The caller's dict is not
            modified; a copy is encoded.
        expires_delta: Token lifetime. When ``None``, the default lifetime is
            ACCESS_TOKEN_EXPIRE_MINUTES for ``"access"`` tokens and
            REFRESH_TOKEN_EXPIRE_DAYS for any other type. An explicit
            ``timedelta(0)`` is honoured rather than replaced by the default.
        token_type: Value written to the ``type`` claim.

    Returns:
        The encoded JWT string, containing ``exp``, ``iat``, ``type`` and a
        ``jti`` (generated when *data* does not supply one).
    """
    to_encode = data.copy()

    # Use a timezone-aware UTC timestamp: python-jose turns datetime claims
    # into Unix time via utctimetuple(), which treats a *naive* datetime as if
    # it were already UTC. A naive local time would shift `exp`/`iat` by the
    # server's UTC offset (tokens could even be issued already expired).
    now = datetime.now(timezone.utc)

    if expires_delta is None:
        expires_delta = (
            timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
            if token_type == "access"
            else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        )

    to_encode.update({
        "exp": now + expires_delta,
        "type": token_type,
        "iat": now,
    })
    # A unique `jti` lets each token be revoked individually.
    to_encode.setdefault("jti", str(uuid4()))

    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a new access token.

    Args:
        data: Claims to embed (typically ``sub``, optionally ``jti``). The
            caller's dict is left untouched; a ``jti`` is generated on a copy
            when missing.
        expires_delta: Optional custom lifetime; ``None`` uses the default
            access-token lifetime applied by ``create_token``.

    Returns:
        The encoded JWT string with ``type`` set to ``"access"``.
    """
    # Copy first so the caller's dict is not mutated as a side effect.
    claims = dict(data)
    # Ensure the token carries a `jti` for revocation tracking.
    claims.setdefault("jti", str(uuid4()))
    return create_token(claims, expires_delta, "access")
|
|
|
|
|
|
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a new refresh token.

    Args:
        data: Claims to embed (typically ``sub``, optionally ``jti``). The
            caller's dict is left untouched; a ``jti`` is generated on a copy
            when missing.
        expires_delta: Optional custom lifetime; ``None`` uses the default
            refresh-token lifetime applied by ``create_token``.

    Returns:
        The encoded JWT string with ``type`` set to ``"refresh"``.
    """
    # Copy first so the caller's dict is not mutated as a side effect.
    claims = dict(data)
    # Ensure the token carries a `jti` for revocation tracking.
    claims.setdefault("jti", str(uuid4()))
    return create_token(claims, expires_delta, "refresh")
|
|
|
|
async def decode_token(
    token: str,
    required_type: str = "access",
    db: AsyncSession = Depends(get_db),
) -> TokenPayload:
    """
    Decode and validate a JWT token, including revocation checks.

    Checks run in this order: signature and standard claims (python-jose),
    presence of ``exp``/``sub``/``type``/``jti``, expiry, token type, and
    finally the revocation status of the ``jti``.

    Args:
        token (str): The JWT token to decode.
        required_type (str): The expected token type (default: "access").
        db (AsyncSession): Database session for token revocation checks.
            NOTE(review): the ``Depends(get_db)`` default is only resolved
            when FastAPI injects this function as a dependency; direct
            callers must pass a real session.

    Returns:
        TokenPayload: The decoded token data. ``exp``/``iat`` are built with
        ``datetime.fromtimestamp`` (naive local time); a missing ``iat``
        becomes the Unix epoch.

    Raises:
        JWTError: If the token is expired, revoked, of the wrong type,
            malformed, or fails signature/claim verification.
    """
    # Keep the try narrow: only decoding is translated. Any other JWTError
    # raised by jose (bad signature, malformed token, invalid claims) already
    # has the right type and propagates unchanged.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except ExpiredSignatureError as e:
        raise JWTError("Token expired. Please refresh your token to continue.") from e

    required_claims = ("exp", "sub", "type", "jti")
    if any(claim not in payload for claim in required_claims):
        raise JWTError("Malformed token. Missing required claim(s).")

    # Defensive re-check of `exp`. Compare absolute instants in UTC so the
    # result does not depend on the server's local timezone or DST.
    expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
    if datetime.now(timezone.utc) > expires_at:
        raise JWTError("Token expired. Please refresh your token to continue.")

    token_type = payload["type"]
    if token_type != required_type:
        raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")

    # Outside any try block on purpose: database errors must not be disguised
    # as "malformed token".
    if await is_token_revoked(payload["jti"], db):
        raise JWTError("Token has been revoked.")

    return TokenPayload(
        sub=payload["sub"],
        type=token_type,
        exp=datetime.fromtimestamp(payload["exp"]),
        iat=datetime.fromtimestamp(payload.get("iat", 0)),
        jti=payload["jti"],
    )
|
|
|