Add token revocation mechanism and support for logout APIs

This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
This commit is contained in:
2025-02-28 17:45:33 +01:00
parent aa77752981
commit 8814dc931f
8 changed files with 270 additions and 208 deletions

View File

@@ -0,0 +1,41 @@
"""Add RevokedToken model
Revision ID: 37315a5b4021
Revises: 38bf9e7e74b3
Create Date: 2025-02-28 17:11:07.741372
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '37315a5b4021'
down_revision: Union[str, None] = '38bf9e7e74b3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('revoked_tokens',
sa.Column('jti', sa.String(length=50), nullable=False),
sa.Column('token_type', sa.String(length=20), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_revoked_tokens_jti'), 'revoked_tokens', ['jti'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_revoked_tokens_jti'), table_name='revoked_tokens')
op.drop_table('revoked_tokens')
# ### end Alembic commands ###

View File

@@ -1,182 +1,134 @@
from typing import Annotated
from typing import Any
from app.auth.utils import revoke_token, is_token_revoked
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import (
verify_password,
get_password_hash,
create_tokens,
decode_token,
)
from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token
from app.core.database import get_db
from app.models.user import User
from app.schemas.token import TokenResponse, RefreshToken
from app.schemas.user import UserCreate, UserResponse
from app.schemas.token import TokenResponse, TokenPayload, RefreshToken
from app.schemas.user import UserResponse
router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
oauth2_scheme = OAuth2PasswordRequestForm
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)],
db: Session = Depends(get_db)
) -> User:
# Existing: User Login Endpoint
@router.post(
"/auth/login",
response_model=TokenResponse,
summary="Authenticate user and provide tokens"
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Get the current user based on the JWT token.
Args:
token: JWT token from authorization header
db: Database session
Returns:
User object if valid token
Raises:
HTTPException: If token is invalid or user not found
Authenticate a user with their credentials and return an access and refresh token.
"""
try:
payload = decode_token(token)
if payload.type != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token type"
)
user = db.query(User).filter(User.id == payload.sub).first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
)
return user
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Could not validate credentials: {str(e)}"
)
user = await authenticate_user(email=form_data.username, password=form_data.password, db=db)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.")
# Generate access and refresh tokens
access_token = create_access_token({"sub": str(user.id), "type": "access"})
refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"})
async def authenticate_user(
email: str,
password: str,
db: Session
) -> User:
"""
Authenticate a user by email and password.
Args:
email: User's email
password: User's password
db: Database session
Returns:
User object if authentication successful, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user or not verify_password(password, user.password_hash):
return None
return user
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_create: UserCreate,
db: Session = Depends(get_db)
):
"""
Register a new user.
"""
# Check if user already exists
if db.query(User).filter(User.email == user_create.email).first():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
)
# Create new user
user = User(
email=user_create.email,
password_hash=get_password_hash(user_create.password),
first_name=user_create.first_name,
last_name=user_create.last_name,
phone_number=user_create.phone_number
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=1800, # Example: 30 minutes for access token
user_id=str(user.id),
)
db.add(user)
db.commit()
db.refresh(user)
return user
@router.post("/login", response_model=TokenResponse)
async def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_db)
# New: Logout Endpoint (Revoke Token)
@router.post(
"/auth/logout",
summary="Revoke the current token",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
OAuth2 compatible token login, get an access token for future requests.
Logout the user by revoking the current token.
"""
user = await authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Decode the token and revoke it
payload: TokenPayload = await decode_token(token, db=db)
await revoke_token(payload.jti, payload.type, payload.sub, db)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Inactive user"
)
return create_tokens(str(user.id))
return {"message": "Successfully logged out."}
@router.post("/refresh", response_model=TokenResponse)
# New: Bulk Logout (Revoke All of a User's Tokens)
@router.post(
"/auth/logout-all",
summary="Revoke all active tokens for the user",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout_all(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
Revoke all tokens for the current user, effectively logging them out across all devices.
"""
await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)})
await db.commit()
return {"message": "Logged out from all devices."}
# Updated: Refresh Token Endpoint
@router.post(
"/auth/refresh-token",
response_model=TokenResponse,
summary="Generate a new access token using a refresh token"
)
async def refresh_token(
refresh_token: RefreshToken,
db: Session = Depends(get_db)
):
db: AsyncSession = Depends(get_db)
) -> TokenResponse:
"""
Refresh access token using refresh token.
Refresh the user's access token using their refresh token while ensuring it has not been revoked.
"""
try:
payload = decode_token(refresh_token.refresh_token)
payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db)
# Validate token type
if payload.type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token type"
)
if await is_token_revoked(payload.jti, db):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.")
# Verify user still exists and is active
user = db.query(User).filter(User.id == payload.sub).first()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
)
return create_tokens(str(user.id))
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid refresh token: {str(e)}"
)
# Generate a new access token with the user's info
new_access_token = create_access_token({"sub": payload.sub, "type": "access"})
return TokenResponse(
access_token=new_access_token,
refresh_token=refresh_token.refresh_token, # Reuse existing refresh token
expires_in=1800, # Example: 30 minutes expiry for access token
token_type="bearer",
user_id=payload.sub,
)
@router.get("/me", response_model=UserResponse)
# Existing: Get Current User Endpoint
@router.get(
"/auth/me",
response_model=UserResponse,
summary="Get user details from the token"
)
async def read_users_me(
current_user: Annotated[User, Depends(get_current_user)]
):
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
) -> UserResponse:
"""
Get current user information.
Retrieves the details of the currently authenticated user.
"""
return current_user

View File

@@ -2,12 +2,15 @@ from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from black import timezone
from sqlalchemy.ext.asyncio import AsyncSession
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
from fastapi import Depends
from app.core.database import get_db
from auth.utlis import is_token_revoked
# Configuration
SECRET_KEY = settings.SECRET_KEY
@@ -39,8 +42,9 @@ def create_tokens(user_id: str) -> TokenResponse:
Returns:
TokenResponse containing both tokens and metadata
"""
access_token = create_access_token({"sub": user_id})
refresh_token = create_refresh_token({"sub": user_id})
# Add `jti` during token creation
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
return TokenResponse(
access_token=access_token,
@@ -50,6 +54,7 @@ def create_tokens(user_id: str) -> TokenResponse:
user_id=user_id
)
def create_token(
data: dict,
expires_delta: Optional[timedelta] = None,
@@ -70,70 +75,82 @@ def create_token(
"exp": expire,
"type": token_type,
"iat": datetime.now(),
"jti": str(uuid4())
})
if "jti" not in to_encode:
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "refresh")
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
async def decode_token(
token: str,
required_type: str = "access",
db: AsyncSession = Depends(get_db)
) -> TokenPayload:
"""
Decode and validate a JWT token with explicit edge-case handling.
Decode and validate a JWT token, including revocation checks.
Args:
token: The JWT token to decode.
required_type: The expected token type (default: "access").
token (str): The JWT token to decode.
required_type (str): The expected token type (default: "access").
db (AsyncSession): Database session for token revocation checks.
Returns:
TokenPayload containing the decoded data.
TokenPayload: The decoded token data.
Raises:
JWTError: If the token is expired, invalid, or malformed.
JWTError: If the token is expired, revoked, or malformed.
"""
try:
# Decode the JWT token using the secret and algorithm
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Explicitly validate required claims (`exp`, `sub`, `type`)
if "exp" not in payload or "sub" not in payload or "type" not in payload:
raise KeyError("Missing required claim.")
# Explicitly validate required claims
if "exp" not in payload or "sub" not in payload or "type" not in payload or "jti" not in payload:
raise KeyError("Missing required claim(s) in token.")
# Verify token expiration (`exp`)
# Validate token expiration (`exp`)
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
raise ExpiredSignatureError("Token has expired.")
# Verify token type (`type`)
# Validate the token type (`type`)
if payload["type"] != required_type:
expected_type = required_type
actual_type = payload["type"]
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
raise JWTError(f"Invalid token type: expected '{required_type}', got '{payload['type']}'.")
# Create TokenPayload object from token claims
# Check the token's revocation status (via `jti`)
if await is_token_revoked(payload["jti"], db):
raise JWTError("Token has been revoked.")
# Construct and return the token payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=payload.get("jti")
jti=payload["jti"]
)
except ExpiredSignatureError as e: # Expired token
except ExpiredSignatureError as e:
# Handle expired token exception
raise JWTError("Token expired. Please refresh your token to continue.") from e
except KeyError as e:
# Handle missing claims in the token
raise JWTError("Malformed token. Missing required claim(s).") from e
except JWTError as e:
# Handle signature verification and malformed token errors
if str(e) in ["Signature verification failed.", "Not enough segments"]:
raise JWTError("Invalid token.") from e
# Propagate other JWTError messages
# Handle any other JWT-specific exceptions
raise JWTError(str(e)) from e
except KeyError as e: # Missing required claims
raise JWTError("Malformed token. Missing required claim.") from e
except JOSEError as e: # All other JOSE-related errors
raise JWTError("Invalid token.") from e

View File

@@ -1,16 +1,15 @@
# auth/utils.py
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by adding its `jti` to the database."""
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession):
"""Check if a token with the given `jti` is revoked."""
result = await db.get(RevokedToken, jti)
return result is not None
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's `jti` is in the revoked_tokens table."""
revoked = await db.get(RevokedToken, jti)
return revoked is not None

View File

@@ -29,7 +29,7 @@ from .gift import (
from .email_template import EmailTemplate, TemplateType
from .notification_log import NotificationLog, NotificationType, NotificationStatus
from .activity_log import ActivityLog, ActivityType
from .token import RevokedToken
# Make sure all models are imported above this line
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
@@ -40,4 +40,5 @@ __all__ = [
'EmailTemplate', 'TemplateType',
'NotificationLog', 'NotificationType', 'NotificationStatus',
'ActivityLog', 'ActivityType',
'RevokedToken',
]

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
class RevokedToken(UUIDMixin, TimestampMixin, Base):
"""Model to store revoked JWT tokens via their jti (JWT ID)."""
__tablename__ = "revoked_tokens"
jti = Column(String(length=50), nullable=False, unique=True, index=True)
token_type = Column(String(length=20), nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"))
user = relationship("User", back_populates="revoked_tokens")

View File

@@ -25,6 +25,7 @@ class User(Base, UUIDMixin, TimestampMixin):
foreign_keys="EventManager.user_id"
)
guest_profiles = relationship("Guest", back_populates="user", foreign_keys="Guest.user_id")
revoked_tokens = relationship("RevokedToken", back_populates="user", cascade="all, delete")
def __repr__(self):
return f"<User {self.email}>"

View File

@@ -1,7 +1,9 @@
from datetime import timedelta, datetime
from unittest.mock import AsyncMock
import pytest
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import (
get_password_hash, verify_password,
@@ -40,72 +42,106 @@ def test_refresh_token_creation():
assert decoded_payload.get("type") == "refresh"
def test_decode_token_valid():
@pytest.mark.asyncio
async def test_decode_token_valid():
user_id = "123e4567-e89b-12d3-a456-426614174000"
access_token = create_access_token({"sub": user_id})
token_payload = decode_token(access_token)
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
# Mock is_token_revoked to return False
mock_db = AsyncMock(spec=AsyncSession)
mock_db.get = AsyncMock(return_value=None)
token_payload = await decode_token(access_token, db=mock_db)
assert isinstance(token_payload, TokenPayload)
assert token_payload.sub == user_id
assert token_payload.type == "access"
def test_decode_token_expired():
@pytest.mark.asyncio
async def test_decode_token_expired():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
assert str(exc_info.value) == "Token has been revoked."
def test_decode_token_missing_exp():
@pytest.mark.asyncio
async def test_decode_token_missing_exp():
# Create a token without the `exp` claim
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim."
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
def test_decode_token_missing_sub():
@pytest.mark.asyncio
async def test_decode_token_missing_sub():
# Create a token without the `sub` claim
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim."
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
def test_decode_token_invalid_signature():
@pytest.mark.asyncio
async def test_decode_token_invalid_signature():
# Use a different secret key for signing
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Invalid token."
assert str(exc_info.value) == "Signature verification failed."
def test_decode_token_malformed():
@pytest.mark.asyncio
async def test_decode_token_malformed():
malformed_token = "malformed.header.payload"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(malformed_token)
await decode_token(malformed_token, db=mock_db)
assert str(exc_info.value) == "Invalid token."
def test_decode_token_invalid_type():
@pytest.mark.asyncio
async def test_decode_token_invalid_type():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token, required_type="access") # Expecting an access token
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."