Remove token revocation logic and unused dependencies

Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
2025-03-02 11:04:12 +01:00
parent 453016629f
commit cd92cd9780
24 changed files with 954 additions and 781 deletions

View File

@@ -1,41 +0,0 @@
"""Add RevokedToken model
Revision ID: 37315a5b4021
Revises: 38bf9e7e74b3
Create Date: 2025-02-28 17:11:07.741372
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '37315a5b4021'
down_revision: Union[str, None] = '38bf9e7e74b3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('revoked_tokens',
sa.Column('jti', sa.String(length=50), nullable=False),
sa.Column('token_type', sa.String(length=20), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_revoked_tokens_jti'), 'revoked_tokens', ['jti'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_revoked_tokens_jti'), table_name='revoked_tokens')
op.drop_table('revoked_tokens')
# ### end Alembic commands ###

View File

@@ -0,0 +1,138 @@
# app/api/dependencies/auth.py
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.models.user import User
# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> User:
"""
Get the current authenticated user.
Args:
db: Database session
token: JWT token from request
Returns:
User: The authenticated user
Raises:
HTTPException: If authentication fails
"""
try:
# Decode token and get user ID
token_data = get_token_data(token)
# Get user from database
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return user
except TokenExpiredError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"}
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
)
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is active.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated and active user
Raises:
HTTPException: If user is inactive
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is a superuser.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated superuser
Raises:
HTTPException: If user is not a superuser
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
)
return current_user
def get_optional_current_user(
db: Session = Depends(get_db),
token: Optional[str] = Depends(oauth2_scheme)
) -> Optional[User]:
"""
Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users.
Args:
db: Database session
token: JWT token from request
Returns:
User or None: The authenticated user or None
"""
if not token:
return None
try:
token_data = get_token_data(token)
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user or not user.is_active:
return None
return user
except (TokenExpiredError, TokenInvalidError):
return None

View File

@@ -1,134 +1,3 @@
from typing import Any
from app.auth.utils import revoke_token, is_token_revoked
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token
from app.core.database import get_db
from app.models.user import User
from app.schemas.token import TokenResponse, TokenPayload, RefreshToken
from app.schemas.user import UserResponse
from fastapi import APIRouter
router = APIRouter()
oauth2_scheme = OAuth2PasswordRequestForm
# Existing: User Login Endpoint
@router.post(
"/auth/login",
response_model=TokenResponse,
summary="Authenticate user and provide tokens"
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Authenticate a user with their credentials and return an access and refresh token.
"""
user = await authenticate_user(email=form_data.username, password=form_data.password, db=db)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.")
# Generate access and refresh tokens
access_token = create_access_token({"sub": str(user.id), "type": "access"})
refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=1800, # Example: 30 minutes for access token
user_id=str(user.id),
)
# New: Logout Endpoint (Revoke Token)
@router.post(
"/auth/logout",
summary="Revoke the current token",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
Logout the user by revoking the current token.
"""
# Decode the token and revoke it
payload: TokenPayload = await decode_token(token, db=db)
await revoke_token(payload.jti, payload.type, payload.sub, db)
return {"message": "Successfully logged out."}
# New: Bulk Logout (Revoke All of a User's Tokens)
@router.post(
"/auth/logout-all",
summary="Revoke all active tokens for the user",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout_all(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
Revoke all tokens for the current user, effectively logging them out across all devices.
"""
await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)})
await db.commit()
return {"message": "Logged out from all devices."}
# Updated: Refresh Token Endpoint
@router.post(
"/auth/refresh-token",
response_model=TokenResponse,
summary="Generate a new access token using a refresh token"
)
async def refresh_token(
refresh_token: RefreshToken,
db: AsyncSession = Depends(get_db)
) -> TokenResponse:
"""
Refresh the user's access token using their refresh token while ensuring it has not been revoked.
"""
payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db)
if await is_token_revoked(payload.jti, db):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.")
# Generate a new access token with the user's info
new_access_token = create_access_token({"sub": payload.sub, "type": "access"})
return TokenResponse(
access_token=new_access_token,
refresh_token=refresh_token.refresh_token, # Reuse existing refresh token
expires_in=1800, # Example: 30 minutes expiry for access token
token_type="bearer",
user_id=payload.sub,
)
# Existing: Get Current User Endpoint
@router.get(
"/auth/me",
response_model=UserResponse,
summary="Get user details from the token"
)
async def read_users_me(
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
) -> UserResponse:
"""
Retrieves the details of the currently authenticated user.
"""
return current_user

View File

@@ -1,40 +0,0 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.auth.security import decode_token
from app.models.user import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
):
try:
payload = await decode_token(token) # Use updated decode_token.
user_id: str = payload.sub
token_type: str = payload.type
if user_id is None or token_type != "access":
raise HTTPException(status_code=401, detail="Invalid token type.")
user = await db.get(User, user_id)
if user is None:
raise HTTPException(status_code=401, detail="User not found.")
return user
except JWTError as e:
raise HTTPException(status_code=401, detail=str(e))
async def get_current_active_user(
current_user: User = Depends(get_current_user),
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user

View File

@@ -1,176 +0,0 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4
from fastapi import Depends
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utils import is_token_revoked
# Configuration
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a plain password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate password hash."""
return pwd_context.hash(password)
def create_tokens(user_id: str) -> TokenResponse:
"""
Create both access and refresh tokens for a user.
Args:
user_id: The user's ID
Returns:
TokenResponse containing both tokens and metadata
"""
# Add `jti` during token creation
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_id=user_id,
scope="read write"
)
def create_token(
data: dict,
expires_delta: Optional[timedelta] = None,
token_type: str = "access"
) -> str:
"""Create a JWT token with the specified type and expiration."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + (
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)
to_encode.update({
"exp": expire,
"type": token_type,
"iat": datetime.now(timezone.utc),
})
if "jti" not in to_encode:
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "refresh")
async def decode_token(
token: str,
required_type: str = "access",
db: AsyncSession = Depends(get_db)
) -> TokenPayload:
"""
Decode and validate a JWT token, including revocation checks.
Args:
token (str): The JWT token to decode.
required_type (str): The expected token type (default: "access").
db (AsyncSession): Database session for token revocation checks.
Returns:
TokenPayload: The decoded token data.
Raises:
JWTError: If the token is expired, revoked, malformed, or fails validation.
"""
try:
# Step 1: Decode the JWT token
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "iat", "sub", "type", "jti"]
}
)
except ExpiredSignatureError:
raise JWTError("Token has expired. Please refresh your token or login again.")
except JWTError as e:
if "Signature verification failed" in str(e):
raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.")
raise JWTError(f"Failed to decode the token: {e}")
except JOSEError as e:
if "segments" in str(e).lower():
raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).")
raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
except Exception as e:
# Catch-all for unexpected exceptions during decoding
raise JWTError(f"An unexpected error occurred while decoding the token: {e}")
# Step 2: Validate Required Claims
required_claims = ["exp", "sub", "type", "jti"]
missing_claims = [claim for claim in required_claims if claim not in payload]
if missing_claims:
raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")
# Step 3: Validate Expiry
expiration = datetime.fromtimestamp(payload["exp"])
if datetime.now(timezone.utc) > expiration:
raise JWTError("Token has expired. Please refresh your token or login again.")
# Step 4: Validate Token Type
token_type = payload.get("type")
if token_type != required_type:
raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")
# Step 5: Check Revocation
jti = payload.get("jti")
if await is_token_revoked(jti, db):
raise JWTError("Token has been revoked. Please login again to generate a new token.")
# Step 6: Return Validated Token Payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=expiration,
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=jti
)

View File

@@ -1,45 +0,0 @@
from datetime import datetime, timezone, timedelta
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's JTI is in the revoked_tokens table."""
from sqlalchemy import select
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
revoked = result.scalar_one_or_none()
return revoked is not None
async def cleanup_expired_tokens(db: AsyncSession):
"""Delete revoked tokens that are past their expiration time."""
now = datetime.now(timezone.utc)
# For access tokens (shorter expiry)
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "access") &
(RevokedToken.created_at < expire_before)
)
)
# For refresh tokens (longer expiry)
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "refresh") &
(RevokedToken.created_at < expire_before)
)
)
await db.commit()

183
backend/app/core/auth.py Normal file
View File

@@ -0,0 +1,183 @@
# app/core/auth.py
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate a password hash."""
return pwd_context.hash(password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
) -> str:
"""
Create a JWT access token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
claims: Optional additional claims to include in the token
Returns:
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"jti": str(uuid.uuid4()),
"type": "access"
}
# Add custom claims
if claims:
to_encode.update(claims)
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT refresh token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
Returns:
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"jti": str(uuid.uuid4()),
"type": "refresh"
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"""
Decode and verify a JWT token.
Args:
token: JWT token to decode
verify_type: Optional token type to verify (access or refresh)
Returns:
TokenPayload: The decoded token data
Raises:
TokenExpiredError: If the token has expired
TokenInvalidError: If the token is invalid
TokenMissingClaimError: If a required claim is missing
"""
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check required claims before Pydantic validation
if not payload.get("sub"):
raise TokenMissingClaimError("Token missing 'sub' claim")
# Verify token type if specified
if verify_type and payload.get("type") != verify_type:
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
# Now validate with Pydantic
token_data = TokenPayload(**payload)
return token_data
except JWTError as e:
# Check if the error is due to an expired token
if "expired" in str(e).lower():
raise TokenExpiredError("Token has expired")
raise TokenInvalidError("Invalid authentication token")
except ValidationError:
raise TokenInvalidError("Invalid token payload")
def get_token_data(token: str) -> TokenData:
"""
Extract the user ID and superuser status from a token.
Args:
token: JWT token
Returns:
TokenData with user_id and is_superuser
"""
payload = decode_token(token)
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -14,6 +14,7 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "eventspace"
DATABASE_URL: Optional[str] = None
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
@@ -24,6 +25,7 @@ class Settings(BaseSettings):
sql_echo_pool: bool = False # Log connection pool events
sql_echo_timing: bool = False # Log query execution times
slow_query_threshold: float = 0.5 # Log queries taking longer than this
@property
def database_url(self) -> str:
"""

View File

@@ -1,15 +1,12 @@
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from app.core.config import settings
from app.api.main import api_router
import logging
from auth.utils import cleanup_expired_tokens
from app.core.database import SessionLocal
from app.core.config import settings
scheduler = AsyncIOScheduler()
@@ -32,26 +29,6 @@ app.add_middleware(
)
# Create a function that gets its own database session
async def scheduled_cleanup():
async with SessionLocal() as db:
await cleanup_expired_tokens(db)
@app.on_event("startup")
async def start_scheduler():
# Run every day at 3 AM
scheduler.add_job(
scheduled_cleanup,
CronTrigger(hour=10, minute=0),
id="token_cleanup",
name="Clean up expired revoked tokens"
)
scheduler.start()
@app.on_event("shutdown")
async def stop_scheduler():
scheduler.shutdown()
@app.get("/", response_class=HTMLResponse)
async def root():
return """
@@ -67,4 +44,5 @@ async def root():
</html>
"""
app.include_router(api_router, prefix=settings.API_V1_STR)
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -29,7 +29,6 @@ from .gift import (
from .email_template import EmailTemplate, TemplateType
from .notification_log import NotificationLog, NotificationType, NotificationStatus
from .activity_log import ActivityLog, ActivityType
from .token import RevokedToken
# Make sure all models are imported above this line
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
@@ -40,5 +39,4 @@ __all__ = [
'EmailTemplate', 'TemplateType',
'NotificationLog', 'NotificationType', 'NotificationStatus',
'ActivityLog', 'ActivityType',
'RevokedToken',
]

View File

@@ -1,15 +0,0 @@
from sqlalchemy import Column, String, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
class RevokedToken(UUIDMixin, TimestampMixin, Base):
"""Model to store revoked JWT tokens via their jti (JWT ID)."""
__tablename__ = "revoked_tokens"
jti = Column(String(length=50), nullable=False, unique=True, index=True)
token_type = Column(String(length=20), nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"))
user = relationship("User", back_populates="revoked_tokens")

View File

@@ -25,7 +25,6 @@ class User(Base, UUIDMixin, TimestampMixin):
foreign_keys="EventManager.user_id"
)
guest_profiles = relationship("Guest", back_populates="user", foreign_keys="Guest.user_id")
revoked_tokens = relationship("RevokedToken", back_populates="user", cascade="all, delete")
def __repr__(self):
return f"<User {self.email}>"

View File

@@ -1,66 +0,0 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field, field_validator
from datetime import datetime
from passlib.hash import bcrypt
# Base schema with shared user attributes
class UserBase(BaseModel):
"""Base schema with common user attributes."""
email: EmailStr
first_name: str
last_name: str
# Schema for creating a new user
class UserCreate(UserBase):
"""Schema for user registration."""
password: str = Field(
...,
min_length=8,
description="Password must be at least 8 characters"
)
@field_validator('password')
def password_strength(cls, v):
# Add more complex password validation if needed
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
return v
def hash_password(self) -> str:
"""Hash the password before saving it to the database."""
return bcrypt.hash(self.password)
# Schema for updating user details
class UserUpdate(BaseModel):
"""Schema for updating user information."""
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
is_active: Optional[bool] = None
preferences: Optional[dict] = None # Provide preferences support
# Schema for user responses (read-only fields)
class UserResponse(UserBase):
"""Schema for user responses in API."""
id: UUID
is_active: bool
is_superuser: bool # Include roles or superuser flags if needed
preferences: Optional[dict] = None # Include preferences in response
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
orm_mode = True # Enable mapping SQLAlchemy models to this schema
# Schema for user authentication (e.g., login requests)
class UserAuth(BaseModel):
"""Schema for user authentication."""
email: EmailStr
password: str

View File

@@ -0,0 +1,126 @@
# app/schemas/users.py
import re
from datetime import datetime
from typing import Optional, Dict, Any
from uuid import UUID
import pydantic
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
class UserBase(BaseModel):
email: EmailStr
first_name: str
last_name: str
phone_number: Optional[str] = None
@field_validator('phone_number')
@classmethod
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Simple regex for phone validation
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
raise ValueError('Invalid phone number format')
return v
class UserCreate(UserBase):
password: str
@field_validator('password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
@field_validator('phone_number')
@classmethod
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Simple regex for phone validation
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
raise ValueError('Invalid phone number format')
return v
class UserInDB(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class UserResponse(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class Token(BaseModel):
access_token: str
refresh_token: Optional[str] = None
token_type: str = "bearer"
class TokenPayload(BaseModel):
sub: str # User ID
exp: int # Expiration time
iat: Optional[int] = None # Issued at
jti: Optional[str] = None # JWT ID
is_superuser: Optional[bool] = False
first_name: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None # Token type (access/refresh)
class TokenData(BaseModel):
user_id: UUID
is_superuser: bool = False
class PasswordReset(BaseModel):
token: str
new_password: str
@field_validator('new_password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshTokenRequest(BaseModel):
refresh_token: str

View File

View File

@@ -0,0 +1,193 @@
# app/services/auth_service.py
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
TokenExpiredError,
TokenInvalidError
)
from app.models.user import User
from app.schemas.users import Token, UserCreate
logger = logging.getLogger(__name__)
class AuthenticationError(Exception):
"""Exception raised for authentication errors"""
pass
class AuthService:
"""Service for handling authentication operations"""
@staticmethod
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticate a user with email and password.
Args:
db: Database session
email: User email
password: User password
Returns:
User if authenticated, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user:
return None
if not verify_password(password, user.password_hash):
return None
if not user.is_active:
raise AuthenticationError("User account is inactive")
return user
@staticmethod
def create_user(db: Session, user_data: UserCreate) -> User:
"""
Create a new user.
Args:
db: Database session
user_data: User data
Returns:
Created user
"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == user_data.email).first()
if existing_user:
raise AuthenticationError("User with this email already exists")
# Create new user
hashed_password = get_password_hash(user_data.password)
# Create user object from model
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False
)
db.add(user)
db.commit()
db.refresh(user)
return user
@staticmethod
def create_tokens(user: User) -> Token:
"""
Create access and refresh tokens for a user.
Args:
user: User to create tokens for
Returns:
Token object with access and refresh tokens
"""
# Generate claims
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name
}
# Create tokens
access_token = create_access_token(
subject=str(user.id),
claims=claims
)
refresh_token = create_refresh_token(
subject=str(user.id)
)
return Token(
access_token=access_token,
refresh_token=refresh_token
)
@staticmethod
def refresh_tokens(db: Session, refresh_token: str) -> Token:
"""
Generate new tokens using a refresh token.
Args:
db: Database session
refresh_token: Valid refresh token
Returns:
New access and refresh tokens
Raises:
TokenExpiredError: If refresh token has expired
TokenInvalidError: If refresh token is invalid
"""
from app.core.auth import decode_token, get_token_data
try:
# Verify token is a refresh token
decode_token(refresh_token, verify_type="refresh")
# Get user ID from token
token_data = get_token_data(refresh_token)
user_id = token_data.user_id
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")
# Generate new tokens
return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}")
raise
@staticmethod
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
"""
Change a user's password.
Args:
db: Database session
user_id: User ID
current_password: Current password
new_password: New password
Returns:
True if password was changed successfully
Raises:
AuthenticationError: If current password is incorrect
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise AuthenticationError("User not found")
# Verify current password
if not verify_password(current_password, user.password_hash):
raise AuthenticationError("Current password is incorrect")
# Update password
user.password_hash = get_password_hash(new_password)
db.commit()
return True

View File

@@ -1,5 +1,6 @@
import logging
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, clear_mappers
from sqlalchemy.pool import StaticPool
@@ -42,4 +43,37 @@ def teardown_test_db(engine):
Base.metadata.drop_all(engine)
# Dispose of engine
engine.dispose()
engine.dispose()
async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
)
return test_engine
async def setup_async_test_db():
"""Create an async test database and session factory"""
test_engine = await get_async_test_engine()
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
AsyncTestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession
)
return test_engine, AsyncTestingSessionLocal
async def teardown_async_test_db(engine):
"""Clean up after async tests"""
await engine.dispose()

View File

@@ -11,7 +11,7 @@ sqlalchemy>=2.0.29
alembic>=1.14.1
psycopg2-binary>=2.9.9
asyncpg>=0.29.0
aiosqlite==0.21.0
# Security and authentication
python-jose>=3.4.0
passlib>=1.7.4

View File

@@ -1,85 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from jose import jwt
from app.auth.dependencies import get_current_user, get_current_active_user
from app.auth.security import SECRET_KEY, ALGORITHM
from app.models.user import User
@pytest.fixture
def mock_user():
return User(
id="123e4567-e89b-12d3-a456-426614174000",
email="test@example.com",
password_hash="hashedpassword",
is_active=True
)
@pytest.mark.asyncio
async def test_get_current_user_success(mock_user):
# Create a valid access token with required claims
valid_token = jwt.encode(
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
SECRET_KEY,
algorithm=ALGORITHM
)
# Mock database session
mock_db = AsyncMock()
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
# Call the dependency
user = await get_current_user(token=valid_token, db=mock_db)
# Assertions
assert user == mock_user, "Returned user does not match the mocked user."
mock_db.get.assert_called_once_with(User, mock_user.id)
@pytest.mark.asyncio
async def test_get_current_user_invalid_token():
invalid_token = "invalid.token.payload"
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=invalid_token, db=AsyncMock())
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Could not validate credentials"
@pytest.mark.asyncio
async def test_get_current_user_wrong_token_type():
token = jwt.encode({"sub": "123", "type": "refresh"}, SECRET_KEY, algorithm=ALGORITHM)
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=token, db=AsyncMock())
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Could not validate credentials"
@pytest.mark.asyncio
async def test_get_current_active_user_success(mock_user):
result = await get_current_active_user(mock_user)
assert result == mock_user
@pytest.mark.asyncio
async def test_get_current_active_user_inactive():
inactive_user = User(
id="123e4567-e89b-12d3-a456-426614174000",
email="inactive@example.com",
password_hash="hashedpassword",
is_active=False
)
with pytest.raises(HTTPException) as exc_info:
await get_current_active_user(inactive_user)
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Inactive user"

View File

@@ -1,147 +0,0 @@
from datetime import timedelta, datetime
from unittest.mock import AsyncMock
import pytest
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import (
get_password_hash, verify_password,
create_access_token, create_refresh_token,
decode_token, SECRET_KEY, ALGORITHM
)
from app.schemas.token import TokenPayload
def test_password_hashing():
plain_password = "securepassword123"
hashed_password = get_password_hash(plain_password)
# Ensure hashed passwords are not the same
assert hashed_password != plain_password
# Test password verification
assert verify_password(plain_password, hashed_password)
assert not verify_password("wrongpassword", hashed_password)
def test_access_token_creation():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id})
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert decoded_payload.get("sub") == user_id
assert decoded_payload.get("type") == "access"
def test_refresh_token_creation():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id})
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert decoded_payload.get("sub") == user_id
assert decoded_payload.get("type") == "refresh"
@pytest.mark.asyncio
async def test_decode_token_valid():
user_id = "123e4567-e89b-12d3-a456-426614174000"
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
# Mock is_token_revoked to return False
mock_db = AsyncMock(spec=AsyncSession)
mock_db.get = AsyncMock(return_value=None)
token_payload = await decode_token(access_token, db=mock_db)
assert isinstance(token_payload, TokenPayload)
assert token_payload.sub == user_id
assert token_payload.type == "access"
@pytest.mark.asyncio
async def test_decode_token_expired():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Token has been revoked."
@pytest.mark.asyncio
async def test_decode_token_missing_exp():
# Create a token without the `exp` claim
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
@pytest.mark.asyncio
async def test_decode_token_missing_sub():
# Create a token without the `sub` claim
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
@pytest.mark.asyncio
async def test_decode_token_invalid_signature():
# Use a different secret key for signing
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Signature verification failed."
@pytest.mark.asyncio
async def test_decode_token_malformed():
malformed_token = "malformed.header.payload"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(malformed_token, db=mock_db)
assert str(exc_info.value) == "Invalid token."
@pytest.mark.asyncio
async def test_decode_token_invalid_type():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."

View File

@@ -9,7 +9,7 @@ from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory,
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
NotificationType, NotificationStatus
from app.models.user import User
from app.utils.test_utils import setup_test_db, teardown_test_db
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
@pytest.fixture(scope="function")
@@ -30,6 +30,15 @@ def db_session():
teardown_test_db(test_engine)
@pytest.fixture(scope="function") # Define a fixture
async def async_test_db():
"""Fixture provides new testing engine and session for each test run to improve isolation."""
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine)
@pytest.fixture
def mock_user(db_session):
"""Fixture to create and return a mock User instance."""
@@ -72,7 +81,6 @@ def event_fixture(db_session, mock_user):
return event
@pytest.fixture
def gift_item_fixture(db_session, mock_user):
"""

View File

View File

@@ -0,0 +1,260 @@
# tests/core/test_auth.py
import uuid
import pytest
from datetime import datetime, timedelta, timezone
from jose import jwt
from pydantic import ValidationError
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
decode_token,
get_token_data,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError
)
from app.core.config import settings
class TestPasswordHandling:
"""Tests for password hashing and verification functions"""
def test_password_hash_different_from_password(self):
"""Test that a password hash is different from the original password"""
password = "TestPassword123"
hashed = get_password_hash(password)
assert hashed != password
def test_verify_correct_password(self):
"""Test that verify_password returns True for the correct password"""
password = "TestPassword123"
hashed = get_password_hash(password)
assert verify_password(password, hashed) is True
def test_verify_incorrect_password(self):
"""Test that verify_password returns False for an incorrect password"""
password = "TestPassword123"
wrong_password = "WrongPassword123"
hashed = get_password_hash(password)
assert verify_password(wrong_password, hashed) is False
def test_same_password_different_hash(self):
"""Test that the same password gets a different hash each time"""
password = "TestPassword123"
hash1 = get_password_hash(password)
hash2 = get_password_hash(password)
assert hash1 != hash2
class TestTokenCreation:
"""Tests for token creation functions"""
def test_create_access_token(self):
"""Test that an access token is created with the correct claims"""
user_id = str(uuid.uuid4())
custom_claims = {
"email": "test@example.com",
"first_name": "Test",
"is_superuser": True
}
token = create_access_token(subject=user_id, claims=custom_claims)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check standard claims
assert payload["sub"] == user_id
assert "jti" in payload
assert "exp" in payload
assert "iat" in payload
assert payload["type"] == "access"
# Check custom claims
for key, value in custom_claims.items():
assert payload[key] == value
def test_create_refresh_token(self):
"""Test that a refresh token is created with the correct claims"""
user_id = str(uuid.uuid4())
token = create_refresh_token(subject=user_id)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check standard claims
assert payload["sub"] == user_id
assert "jti" in payload
assert "exp" in payload
assert "iat" in payload
assert payload["type"] == "refresh"
def test_token_expiration(self):
"""Test that tokens have the correct expiration time"""
user_id = str(uuid.uuid4())
expires = timedelta(minutes=5)
# Create token with specific expiration
token = create_access_token(
subject=user_id,
expires_delta=expires
)
# Decode token
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Get actual expiration time from token
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
# Calculate expected expiration (approximately)
now = datetime.now(timezone.utc)
expected_expiration = now + expires
# Difference should be small (less than 1 second)
difference = abs((expiration - expected_expiration).total_seconds())
assert difference < 1
class TestTokenDecoding:
"""Tests for token decoding and validation functions"""
def test_decode_valid_token(self):
"""Test that a valid token can be decoded"""
user_id = str(uuid.uuid4())
token = create_access_token(subject=user_id)
# Decode token
payload = decode_token(token)
# Check that the subject matches
assert payload.sub == user_id
def test_decode_expired_token(self):
"""Test that an expired token raises TokenExpiredError"""
user_id = str(uuid.uuid4())
# Create a token that's already expired by directly manipulating the payload
now = datetime.now(timezone.utc)
expired_time = now - timedelta(hours=1) # 1 hour in the past
# Create the expired token manually
payload = {
"sub": user_id,
"exp": int(expired_time.timestamp()), # Set expiration in the past
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
}
expired_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Attempting to decode should raise TokenExpiredError
with pytest.raises(TokenExpiredError):
decode_token(expired_token)
def test_decode_invalid_token(self):
"""Test that an invalid token raises TokenInvalidError"""
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
with pytest.raises(TokenInvalidError):
decode_token(invalid_token)
def test_decode_token_with_missing_sub(self):
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
# Create a token without a subject
now = datetime.now(timezone.utc)
payload = {
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
# No 'sub' claim
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
with pytest.raises(TokenMissingClaimError):
decode_token(token)
def test_decode_token_with_wrong_type(self):
"""Test that verifying a token with wrong type raises TokenInvalidError"""
user_id = str(uuid.uuid4())
token = create_access_token(subject=user_id)
# Try to verify it as a refresh token
with pytest.raises(TokenInvalidError):
decode_token(token, verify_type="refresh")
def test_decode_with_invalid_payload(self):
"""Test that a token with invalid payload structure raises TokenInvalidError"""
# Create a token with an invalid payload structure - missing 'sub' which is required
# but including 'exp' to avoid the expiration check
now = datetime.now(timezone.utc)
payload = {
# Missing "sub" field which is required
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"invalid_field": "test"
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenMissingClaimError due to missing 'sub'
with pytest.raises(TokenMissingClaimError):
decode_token(token)
# Create another token with invalid type for required field
payload = {
"sub": 123, # sub should be a string, not an integer
"exp": int((now + timedelta(minutes=30)).timestamp()),
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenInvalidError due to ValidationError
with pytest.raises(TokenInvalidError):
decode_token(token)
def test_get_token_data(self):
"""Test extracting TokenData from a token"""
user_id = uuid.uuid4()
token = create_access_token(
subject=str(user_id),
claims={"is_superuser": True}
)
token_data = get_token_data(token)
assert token_data.user_id == user_id
assert token_data.is_superuser is True

View File