Add token revocation mechanism and support for logout APIs
This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""Add RevokedToken model
|
||||
|
||||
Revision ID: 37315a5b4021
|
||||
Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-02-28 17:11:07.741372
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '37315a5b4021'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('revoked_tokens',
|
||||
sa.Column('jti', sa.String(length=50), nullable=False),
|
||||
sa.Column('token_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_revoked_tokens_jti'), 'revoked_tokens', ['jti'], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_revoked_tokens_jti'), table_name='revoked_tokens')
|
||||
op.drop_table('revoked_tokens')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,182 +1,134 @@
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
|
||||
from app.auth.utils import revoke_token, is_token_revoked
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_tokens,
|
||||
decode_token,
|
||||
)
|
||||
from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.token import TokenResponse, RefreshToken
|
||||
from app.schemas.user import UserCreate, UserResponse
|
||||
from app.schemas.token import TokenResponse, TokenPayload, RefreshToken
|
||||
from app.schemas.user import UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
oauth2_scheme = OAuth2PasswordRequestForm
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
# Existing: User Login Endpoint
|
||||
@router.post(
|
||||
"/auth/login",
|
||||
response_model=TokenResponse,
|
||||
summary="Authenticate user and provide tokens"
|
||||
)
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get the current user based on the JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token from authorization header
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object if valid token
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is invalid or user not found
|
||||
Authenticate a user with their credentials and return an access and refresh token.
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
if payload.type != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token type"
|
||||
)
|
||||
user = db.query(User).filter(User.id == payload.sub).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found"
|
||||
)
|
||||
return user
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Could not validate credentials: {str(e)}"
|
||||
)
|
||||
user = await authenticate_user(email=form_data.username, password=form_data.password, db=db)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.")
|
||||
|
||||
# Generate access and refresh tokens
|
||||
access_token = create_access_token({"sub": str(user.id), "type": "access"})
|
||||
refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"})
|
||||
|
||||
async def authenticate_user(
|
||||
email: str,
|
||||
password: str,
|
||||
db: Session
|
||||
) -> User:
|
||||
"""
|
||||
Authenticate a user by email and password.
|
||||
|
||||
Args:
|
||||
email: User's email
|
||||
password: User's password
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object if authentication successful, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_create: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Register a new user.
|
||||
"""
|
||||
# Check if user already exists
|
||||
if db.query(User).filter(User.email == user_create.email).first():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
email=user_create.email,
|
||||
password_hash=get_password_hash(user_create.password),
|
||||
first_name=user_create.first_name,
|
||||
last_name=user_create.last_name,
|
||||
phone_number=user_create.phone_number
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=1800, # Example: 30 minutes for access token
|
||||
user_id=str(user.id),
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Session = Depends(get_db)
|
||||
# New: Logout Endpoint (Revoke Token)
|
||||
@router.post(
|
||||
"/auth/logout",
|
||||
summary="Revoke the current token",
|
||||
response_model=dict,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(
|
||||
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
|
||||
):
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests.
|
||||
Logout the user by revoking the current token.
|
||||
"""
|
||||
user = await authenticate_user(form_data.username, form_data.password, db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
# Decode the token and revoke it
|
||||
payload: TokenPayload = await decode_token(token, db=db)
|
||||
await revoke_token(payload.jti, payload.type, payload.sub, db)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return create_tokens(str(user.id))
|
||||
return {"message": "Successfully logged out."}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
# New: Bulk Logout (Revoke All of a User's Tokens)
|
||||
@router.post(
|
||||
"/auth/logout-all",
|
||||
summary="Revoke all active tokens for the user",
|
||||
response_model=dict,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def logout_all(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(
|
||||
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
|
||||
):
|
||||
"""
|
||||
Revoke all tokens for the current user, effectively logging them out across all devices.
|
||||
"""
|
||||
await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)})
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Logged out from all devices."}
|
||||
|
||||
|
||||
# Updated: Refresh Token Endpoint
|
||||
@router.post(
|
||||
"/auth/refresh-token",
|
||||
response_model=TokenResponse,
|
||||
summary="Generate a new access token using a refresh token"
|
||||
)
|
||||
async def refresh_token(
|
||||
refresh_token: RefreshToken,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Refresh access token using refresh token.
|
||||
Refresh the user's access token using their refresh token while ensuring it has not been revoked.
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(refresh_token.refresh_token)
|
||||
payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db)
|
||||
|
||||
# Validate token type
|
||||
if payload.type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token type"
|
||||
)
|
||||
if await is_token_revoked(payload.jti, db):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.")
|
||||
|
||||
# Verify user still exists and is active
|
||||
user = db.query(User).filter(User.id == payload.sub).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
)
|
||||
|
||||
return create_tokens(str(user.id))
|
||||
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Invalid refresh token: {str(e)}"
|
||||
)
|
||||
# Generate a new access token with the user's info
|
||||
new_access_token = create_access_token({"sub": payload.sub, "type": "access"})
|
||||
return TokenResponse(
|
||||
access_token=new_access_token,
|
||||
refresh_token=refresh_token.refresh_token, # Reuse existing refresh token
|
||||
expires_in=1800, # Example: 30 minutes expiry for access token
|
||||
token_type="bearer",
|
||||
user_id=payload.sub,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
# Existing: Get Current User Endpoint
|
||||
@router.get(
|
||||
"/auth/me",
|
||||
response_model=UserResponse,
|
||||
summary="Get user details from the token"
|
||||
)
|
||||
async def read_users_me(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
current_user: User = Depends(
|
||||
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Get current user information.
|
||||
Retrieves the details of the currently authenticated user.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
@@ -2,12 +2,15 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from black import timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
from jose.exceptions import ExpiredSignatureError, JWTError, JOSEError
|
||||
from fastapi import Depends
|
||||
from app.core.database import get_db
|
||||
from auth.utlis import is_token_revoked
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
@@ -39,8 +42,9 @@ def create_tokens(user_id: str) -> TokenResponse:
|
||||
Returns:
|
||||
TokenResponse containing both tokens and metadata
|
||||
"""
|
||||
access_token = create_access_token({"sub": user_id})
|
||||
refresh_token = create_refresh_token({"sub": user_id})
|
||||
# Add `jti` during token creation
|
||||
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
|
||||
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
@@ -50,6 +54,7 @@ def create_tokens(user_id: str) -> TokenResponse:
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def create_token(
|
||||
data: dict,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
@@ -70,70 +75,82 @@ def create_token(
|
||||
"exp": expire,
|
||||
"type": token_type,
|
||||
"iat": datetime.now(),
|
||||
"jti": str(uuid4())
|
||||
})
|
||||
if "jti" not in to_encode:
|
||||
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
|
||||
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new access token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "access")
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
|
||||
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
|
||||
async def decode_token(
|
||||
token: str,
|
||||
required_type: str = "access",
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token with explicit edge-case handling.
|
||||
Decode and validate a JWT token, including revocation checks.
|
||||
|
||||
Args:
|
||||
token: The JWT token to decode.
|
||||
required_type: The expected token type (default: "access").
|
||||
token (str): The JWT token to decode.
|
||||
required_type (str): The expected token type (default: "access").
|
||||
db (AsyncSession): Database session for token revocation checks.
|
||||
|
||||
Returns:
|
||||
TokenPayload containing the decoded data.
|
||||
TokenPayload: The decoded token data.
|
||||
|
||||
Raises:
|
||||
JWTError: If the token is expired, invalid, or malformed.
|
||||
JWTError: If the token is expired, revoked, or malformed.
|
||||
"""
|
||||
try:
|
||||
# Decode the JWT token using the secret and algorithm
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
# Explicitly validate required claims (`exp`, `sub`, `type`)
|
||||
if "exp" not in payload or "sub" not in payload or "type" not in payload:
|
||||
raise KeyError("Missing required claim.")
|
||||
# Explicitly validate required claims
|
||||
if "exp" not in payload or "sub" not in payload or "type" not in payload or "jti" not in payload:
|
||||
raise KeyError("Missing required claim(s) in token.")
|
||||
|
||||
# Verify token expiration (`exp`)
|
||||
# Validate token expiration (`exp`)
|
||||
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
|
||||
raise ExpiredSignatureError("Token has expired.")
|
||||
|
||||
# Verify token type (`type`)
|
||||
# Validate the token type (`type`)
|
||||
if payload["type"] != required_type:
|
||||
expected_type = required_type
|
||||
actual_type = payload["type"]
|
||||
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
|
||||
raise JWTError(f"Invalid token type: expected '{required_type}', got '{payload['type']}'.")
|
||||
|
||||
# Create TokenPayload object from token claims
|
||||
# Check the token's revocation status (via `jti`)
|
||||
if await is_token_revoked(payload["jti"], db):
|
||||
raise JWTError("Token has been revoked.")
|
||||
|
||||
# Construct and return the token payload
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=datetime.fromtimestamp(payload["exp"]),
|
||||
iat=datetime.fromtimestamp(payload.get("iat", 0)),
|
||||
jti=payload.get("jti")
|
||||
jti=payload["jti"]
|
||||
)
|
||||
|
||||
except ExpiredSignatureError as e: # Expired token
|
||||
except ExpiredSignatureError as e:
|
||||
# Handle expired token exception
|
||||
raise JWTError("Token expired. Please refresh your token to continue.") from e
|
||||
except KeyError as e:
|
||||
# Handle missing claims in the token
|
||||
raise JWTError("Malformed token. Missing required claim(s).") from e
|
||||
except JWTError as e:
|
||||
# Handle signature verification and malformed token errors
|
||||
if str(e) in ["Signature verification failed.", "Not enough segments"]:
|
||||
raise JWTError("Invalid token.") from e
|
||||
# Propagate other JWTError messages
|
||||
# Handle any other JWT-specific exceptions
|
||||
raise JWTError(str(e)) from e
|
||||
except KeyError as e: # Missing required claims
|
||||
raise JWTError("Malformed token. Missing required claim.") from e
|
||||
except JOSEError as e: # All other JOSE-related errors
|
||||
raise JWTError("Invalid token.") from e
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# auth/utils.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by adding its `jti` to the database."""
|
||||
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession):
|
||||
"""Check if a token with the given `jti` is revoked."""
|
||||
result = await db.get(RevokedToken, jti)
|
||||
return result is not None
|
||||
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
|
||||
"""Check whether the token's `jti` is in the revoked_tokens table."""
|
||||
revoked = await db.get(RevokedToken, jti)
|
||||
return revoked is not None
|
||||
|
||||
@@ -29,7 +29,7 @@ from .gift import (
|
||||
from .email_template import EmailTemplate, TemplateType
|
||||
from .notification_log import NotificationLog, NotificationType, NotificationStatus
|
||||
from .activity_log import ActivityLog, ActivityType
|
||||
|
||||
from .token import RevokedToken
|
||||
# Make sure all models are imported above this line
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
@@ -40,4 +40,5 @@ __all__ = [
|
||||
'EmailTemplate', 'TemplateType',
|
||||
'NotificationLog', 'NotificationType', 'NotificationStatus',
|
||||
'ActivityLog', 'ActivityType',
|
||||
'RevokedToken',
|
||||
]
|
||||
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import Column, String, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class RevokedToken(UUIDMixin, TimestampMixin, Base):
|
||||
"""Model to store revoked JWT tokens via their jti (JWT ID)."""
|
||||
__tablename__ = "revoked_tokens"
|
||||
|
||||
jti = Column(String(length=50), nullable=False, unique=True, index=True)
|
||||
token_type = Column(String(length=20), nullable=False)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"))
|
||||
|
||||
user = relationship("User", back_populates="revoked_tokens")
|
||||
@@ -25,6 +25,7 @@ class User(Base, UUIDMixin, TimestampMixin):
|
||||
foreign_keys="EventManager.user_id"
|
||||
)
|
||||
guest_profiles = relationship("Guest", back_populates="user", foreign_keys="Guest.user_id")
|
||||
revoked_tokens = relationship("RevokedToken", back_populates="user", cascade="all, delete")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import timedelta, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from jose import jwt, JWTError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import (
|
||||
get_password_hash, verify_password,
|
||||
@@ -40,72 +42,106 @@ def test_refresh_token_creation():
|
||||
assert decoded_payload.get("type") == "refresh"
|
||||
|
||||
|
||||
def test_decode_token_valid():
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_valid():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
access_token = create_access_token({"sub": user_id})
|
||||
token_payload = decode_token(access_token)
|
||||
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
|
||||
|
||||
# Mock is_token_revoked to return False
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
mock_db.get = AsyncMock(return_value=None)
|
||||
|
||||
token_payload = await decode_token(access_token, db=mock_db)
|
||||
|
||||
assert isinstance(token_payload, TokenPayload)
|
||||
assert token_payload.sub == user_id
|
||||
assert token_payload.type == "access"
|
||||
|
||||
|
||||
def test_decode_token_expired():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_expired():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
|
||||
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
|
||||
assert str(exc_info.value) == "Token has been revoked."
|
||||
|
||||
|
||||
def test_decode_token_missing_exp():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_missing_exp():
|
||||
# Create a token without the `exp` claim
|
||||
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
|
||||
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim."
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
||||
|
||||
|
||||
def test_decode_token_missing_sub():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_missing_sub():
|
||||
# Create a token without the `sub` claim
|
||||
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
|
||||
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim."
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
||||
|
||||
|
||||
def test_decode_token_invalid_signature():
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_invalid_signature():
|
||||
# Use a different secret key for signing
|
||||
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
|
||||
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
assert str(exc_info.value) == "Signature verification failed."
|
||||
|
||||
|
||||
def test_decode_token_malformed():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_malformed():
|
||||
malformed_token = "malformed.header.payload"
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(malformed_token)
|
||||
await decode_token(malformed_token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
def test_decode_token_invalid_type():
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_invalid_type():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
|
||||
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token, required_type="access") # Expecting an access token
|
||||
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
|
||||
|
||||
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."
|
||||
|
||||
Reference in New Issue
Block a user