Add token revocation mechanism and support for logout APIs

This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
This commit is contained in:
2025-02-28 17:45:33 +01:00
parent aa77752981
commit 8814dc931f
8 changed files with 270 additions and 208 deletions

View File

@@ -1,182 +1,134 @@
from typing import Annotated
from typing import Any
from app.auth.utils import revoke_token, is_token_revoked
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import (
verify_password,
get_password_hash,
create_tokens,
decode_token,
)
from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token
from app.core.database import get_db
from app.models.user import User
from app.schemas.token import TokenResponse, RefreshToken
from app.schemas.user import UserCreate, UserResponse
from app.schemas.token import TokenResponse, TokenPayload, RefreshToken
from app.schemas.user import UserResponse
router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
oauth2_scheme = OAuth2PasswordRequestForm
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)],
db: Session = Depends(get_db)
) -> User:
# Existing: User Login Endpoint
@router.post(
"/auth/login",
response_model=TokenResponse,
summary="Authenticate user and provide tokens"
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Get the current user based on the JWT token.
Args:
token: JWT token from authorization header
db: Database session
Returns:
User object if valid token
Raises:
HTTPException: If token is invalid or user not found
Authenticate a user with their credentials and return an access and refresh token.
"""
try:
payload = decode_token(token)
if payload.type != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token type"
)
user = db.query(User).filter(User.id == payload.sub).first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
)
return user
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Could not validate credentials: {str(e)}"
)
user = await authenticate_user(email=form_data.username, password=form_data.password, db=db)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.")
# Generate access and refresh tokens
access_token = create_access_token({"sub": str(user.id), "type": "access"})
refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"})
async def authenticate_user(
email: str,
password: str,
db: Session
) -> User:
"""
Authenticate a user by email and password.
Args:
email: User's email
password: User's password
db: Database session
Returns:
User object if authentication successful, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user or not verify_password(password, user.password_hash):
return None
return user
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_create: UserCreate,
db: Session = Depends(get_db)
):
"""
Register a new user.
"""
# Check if user already exists
if db.query(User).filter(User.email == user_create.email).first():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
)
# Create new user
user = User(
email=user_create.email,
password_hash=get_password_hash(user_create.password),
first_name=user_create.first_name,
last_name=user_create.last_name,
phone_number=user_create.phone_number
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=1800, # Example: 30 minutes for access token
user_id=str(user.id),
)
db.add(user)
db.commit()
db.refresh(user)
return user
@router.post("/login", response_model=TokenResponse)
async def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_db)
# New: Logout Endpoint (Revoke Token)
@router.post(
"/auth/logout",
summary="Revoke the current token",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
OAuth2 compatible token login, get an access token for future requests.
Logout the user by revoking the current token.
"""
user = await authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Decode the token and revoke it
payload: TokenPayload = await decode_token(token, db=db)
await revoke_token(payload.jti, payload.type, payload.sub, db)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Inactive user"
)
return create_tokens(str(user.id))
return {"message": "Successfully logged out."}
@router.post("/refresh", response_model=TokenResponse)
# New: Bulk Logout (Revoke All of a User's Tokens)
@router.post(
"/auth/logout-all",
summary="Revoke all active tokens for the user",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout_all(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
Revoke all tokens for the current user, effectively logging them out across all devices.
"""
await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)})
await db.commit()
return {"message": "Logged out from all devices."}
# Updated: Refresh Token Endpoint
@router.post(
"/auth/refresh-token",
response_model=TokenResponse,
summary="Generate a new access token using a refresh token"
)
async def refresh_token(
refresh_token: RefreshToken,
db: Session = Depends(get_db)
):
db: AsyncSession = Depends(get_db)
) -> TokenResponse:
"""
Refresh access token using refresh token.
Refresh the user's access token using their refresh token while ensuring it has not been revoked.
"""
try:
payload = decode_token(refresh_token.refresh_token)
payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db)
# Validate token type
if payload.type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token type"
)
if await is_token_revoked(payload.jti, db):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.")
# Verify user still exists and is active
user = db.query(User).filter(User.id == payload.sub).first()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
)
return create_tokens(str(user.id))
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid refresh token: {str(e)}"
)
# Generate a new access token with the user's info
new_access_token = create_access_token({"sub": payload.sub, "type": "access"})
return TokenResponse(
access_token=new_access_token,
refresh_token=refresh_token.refresh_token, # Reuse existing refresh token
expires_in=1800, # Example: 30 minutes expiry for access token
token_type="bearer",
user_id=payload.sub,
)
@router.get("/me", response_model=UserResponse)
# Existing: Get Current User Endpoint
@router.get(
"/auth/me",
response_model=UserResponse,
summary="Get user details from the token"
)
async def read_users_me(
current_user: Annotated[User, Depends(get_current_user)]
):
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
) -> UserResponse:
"""
Get current user information.
Retrieves the details of the currently authenticated user.
"""
return current_user