Remove token revocation logic and unused dependencies

Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
2025-03-02 11:04:12 +01:00
parent 453016629f
commit cd92cd9780
24 changed files with 954 additions and 781 deletions

View File

@@ -1,41 +0,0 @@
"""Add RevokedToken model
Revision ID: 37315a5b4021
Revises: 38bf9e7e74b3
Create Date: 2025-02-28 17:11:07.741372
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '37315a5b4021'
down_revision: Union[str, None] = '38bf9e7e74b3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('revoked_tokens',
sa.Column('jti', sa.String(length=50), nullable=False),
sa.Column('token_type', sa.String(length=20), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_revoked_tokens_jti'), 'revoked_tokens', ['jti'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_revoked_tokens_jti'), table_name='revoked_tokens')
op.drop_table('revoked_tokens')
# ### end Alembic commands ###

View File

@@ -0,0 +1,138 @@
# app/api/dependencies/auth.py
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.models.user import User
# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> User:
"""
Get the current authenticated user.
Args:
db: Database session
token: JWT token from request
Returns:
User: The authenticated user
Raises:
HTTPException: If authentication fails
"""
try:
# Decode token and get user ID
token_data = get_token_data(token)
# Get user from database
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return user
except TokenExpiredError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"}
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
)
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is active.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated and active user
Raises:
HTTPException: If user is inactive
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is a superuser.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated superuser
Raises:
HTTPException: If user is not a superuser
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
)
return current_user
def get_optional_current_user(
db: Session = Depends(get_db),
token: Optional[str] = Depends(oauth2_scheme)
) -> Optional[User]:
"""
Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users.
Args:
db: Database session
token: JWT token from request
Returns:
User or None: The authenticated user or None
"""
if not token:
return None
try:
token_data = get_token_data(token)
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user or not user.is_active:
return None
return user
except (TokenExpiredError, TokenInvalidError):
return None

View File

@@ -1,134 +1,3 @@
from typing import Any
from app.auth.utils import revoke_token, is_token_revoked
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token
from app.core.database import get_db
from app.models.user import User
from app.schemas.token import TokenResponse, TokenPayload, RefreshToken
from app.schemas.user import UserResponse
from fastapi import APIRouter
router = APIRouter()
oauth2_scheme = OAuth2PasswordRequestForm
# Existing: User Login Endpoint
@router.post(
"/auth/login",
response_model=TokenResponse,
summary="Authenticate user and provide tokens"
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Authenticate a user with their credentials and return an access and refresh token.
"""
user = await authenticate_user(email=form_data.username, password=form_data.password, db=db)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.")
# Generate access and refresh tokens
access_token = create_access_token({"sub": str(user.id), "type": "access"})
refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=1800, # Example: 30 minutes for access token
user_id=str(user.id),
)
# New: Logout Endpoint (Revoke Token)
@router.post(
"/auth/logout",
summary="Revoke the current token",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
Logout the user by revoking the current token.
"""
# Decode the token and revoke it
payload: TokenPayload = await decode_token(token, db=db)
await revoke_token(payload.jti, payload.type, payload.sub, db)
return {"message": "Successfully logged out."}
# New: Bulk Logout (Revoke All of a User's Tokens)
@router.post(
"/auth/logout-all",
summary="Revoke all active tokens for the user",
response_model=dict,
status_code=status.HTTP_200_OK
)
async def logout_all(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
):
"""
Revoke all tokens for the current user, effectively logging them out across all devices.
"""
await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)})
await db.commit()
return {"message": "Logged out from all devices."}
# Updated: Refresh Token Endpoint
@router.post(
"/auth/refresh-token",
response_model=TokenResponse,
summary="Generate a new access token using a refresh token"
)
async def refresh_token(
refresh_token: RefreshToken,
db: AsyncSession = Depends(get_db)
) -> TokenResponse:
"""
Refresh the user's access token using their refresh token while ensuring it has not been revoked.
"""
payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db)
if await is_token_revoked(payload.jti, db):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.")
# Generate a new access token with the user's info
new_access_token = create_access_token({"sub": payload.sub, "type": "access"})
return TokenResponse(
access_token=new_access_token,
refresh_token=refresh_token.refresh_token, # Reuse existing refresh token
expires_in=1800, # Example: 30 minutes expiry for access token
token_type="bearer",
user_id=payload.sub,
)
# Existing: Get Current User Endpoint
@router.get(
"/auth/me",
response_model=UserResponse,
summary="Get user details from the token"
)
async def read_users_me(
current_user: User = Depends(
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
) -> UserResponse:
"""
Retrieves the details of the currently authenticated user.
"""
return current_user

View File

@@ -1,40 +0,0 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.auth.security import decode_token
from app.models.user import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
):
try:
payload = await decode_token(token) # Use updated decode_token.
user_id: str = payload.sub
token_type: str = payload.type
if user_id is None or token_type != "access":
raise HTTPException(status_code=401, detail="Invalid token type.")
user = await db.get(User, user_id)
if user is None:
raise HTTPException(status_code=401, detail="User not found.")
return user
except JWTError as e:
raise HTTPException(status_code=401, detail=str(e))
async def get_current_active_user(
current_user: User = Depends(get_current_user),
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user

View File

@@ -1,176 +0,0 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4
from fastapi import Depends
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utils import is_token_revoked
# Configuration
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a plain password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate password hash."""
return pwd_context.hash(password)
def create_tokens(user_id: str) -> TokenResponse:
"""
Create both access and refresh tokens for a user.
Args:
user_id: The user's ID
Returns:
TokenResponse containing both tokens and metadata
"""
# Add `jti` during token creation
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_id=user_id,
scope="read write"
)
def create_token(
data: dict,
expires_delta: Optional[timedelta] = None,
token_type: str = "access"
) -> str:
"""Create a JWT token with the specified type and expiration."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + (
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)
to_encode.update({
"exp": expire,
"type": token_type,
"iat": datetime.now(timezone.utc),
})
if "jti" not in to_encode:
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "refresh")
async def decode_token(
token: str,
required_type: str = "access",
db: AsyncSession = Depends(get_db)
) -> TokenPayload:
"""
Decode and validate a JWT token, including revocation checks.
Args:
token (str): The JWT token to decode.
required_type (str): The expected token type (default: "access").
db (AsyncSession): Database session for token revocation checks.
Returns:
TokenPayload: The decoded token data.
Raises:
JWTError: If the token is expired, revoked, malformed, or fails validation.
"""
try:
# Step 1: Decode the JWT token
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "iat", "sub", "type", "jti"]
}
)
except ExpiredSignatureError:
raise JWTError("Token has expired. Please refresh your token or login again.")
except JWTError as e:
if "Signature verification failed" in str(e):
raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.")
raise JWTError(f"Failed to decode the token: {e}")
except JOSEError as e:
if "segments" in str(e).lower():
raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).")
raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
except Exception as e:
# Catch-all for unexpected exceptions during decoding
raise JWTError(f"An unexpected error occurred while decoding the token: {e}")
# Step 2: Validate Required Claims
required_claims = ["exp", "sub", "type", "jti"]
missing_claims = [claim for claim in required_claims if claim not in payload]
if missing_claims:
raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")
# Step 3: Validate Expiry
expiration = datetime.fromtimestamp(payload["exp"])
if datetime.now(timezone.utc) > expiration:
raise JWTError("Token has expired. Please refresh your token or login again.")
# Step 4: Validate Token Type
token_type = payload.get("type")
if token_type != required_type:
raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")
# Step 5: Check Revocation
jti = payload.get("jti")
if await is_token_revoked(jti, db):
raise JWTError("Token has been revoked. Please login again to generate a new token.")
# Step 6: Return Validated Token Payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=expiration,
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=jti
)

View File

@@ -1,45 +0,0 @@
from datetime import datetime, timezone, timedelta
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's JTI is in the revoked_tokens table."""
from sqlalchemy import select
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
revoked = result.scalar_one_or_none()
return revoked is not None
async def cleanup_expired_tokens(db: AsyncSession):
"""Delete revoked tokens that are past their expiration time."""
now = datetime.now(timezone.utc)
# For access tokens (shorter expiry)
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "access") &
(RevokedToken.created_at < expire_before)
)
)
# For refresh tokens (longer expiry)
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "refresh") &
(RevokedToken.created_at < expire_before)
)
)
await db.commit()

183
backend/app/core/auth.py Normal file
View File

@@ -0,0 +1,183 @@
# app/core/auth.py
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate a password hash."""
return pwd_context.hash(password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
) -> str:
"""
Create a JWT access token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
claims: Optional additional claims to include in the token
Returns:
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"jti": str(uuid.uuid4()),
"type": "access"
}
# Add custom claims
if claims:
to_encode.update(claims)
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT refresh token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
Returns:
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"jti": str(uuid.uuid4()),
"type": "refresh"
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"""
Decode and verify a JWT token.
Args:
token: JWT token to decode
verify_type: Optional token type to verify (access or refresh)
Returns:
TokenPayload: The decoded token data
Raises:
TokenExpiredError: If the token has expired
TokenInvalidError: If the token is invalid
TokenMissingClaimError: If a required claim is missing
"""
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check required claims before Pydantic validation
if not payload.get("sub"):
raise TokenMissingClaimError("Token missing 'sub' claim")
# Verify token type if specified
if verify_type and payload.get("type") != verify_type:
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
# Now validate with Pydantic
token_data = TokenPayload(**payload)
return token_data
except JWTError as e:
# Check if the error is due to an expired token
if "expired" in str(e).lower():
raise TokenExpiredError("Token has expired")
raise TokenInvalidError("Invalid authentication token")
except ValidationError:
raise TokenInvalidError("Invalid token payload")
def get_token_data(token: str) -> TokenData:
"""
Extract the user ID and superuser status from a token.
Args:
token: JWT token
Returns:
TokenData with user_id and is_superuser
"""
payload = decode_token(token)
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -14,6 +14,7 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "eventspace"
DATABASE_URL: Optional[str] = None
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
@@ -24,6 +25,7 @@ class Settings(BaseSettings):
sql_echo_pool: bool = False # Log connection pool events
sql_echo_timing: bool = False # Log query execution times
slow_query_threshold: float = 0.5 # Log queries taking longer than this
@property
def database_url(self) -> str:
"""

View File

@@ -1,15 +1,12 @@
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from app.core.config import settings
from app.api.main import api_router
import logging
from auth.utils import cleanup_expired_tokens
from app.core.database import SessionLocal
from app.core.config import settings
scheduler = AsyncIOScheduler()
@@ -32,26 +29,6 @@ app.add_middleware(
)
# Create a function that gets its own database session
async def scheduled_cleanup():
async with SessionLocal() as db:
await cleanup_expired_tokens(db)
@app.on_event("startup")
async def start_scheduler():
# Run every day at 3 AM
scheduler.add_job(
scheduled_cleanup,
CronTrigger(hour=10, minute=0),
id="token_cleanup",
name="Clean up expired revoked tokens"
)
scheduler.start()
@app.on_event("shutdown")
async def stop_scheduler():
scheduler.shutdown()
@app.get("/", response_class=HTMLResponse)
async def root():
return """
@@ -67,4 +44,5 @@ async def root():
</html>
"""
app.include_router(api_router, prefix=settings.API_V1_STR)
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -29,7 +29,6 @@ from .gift import (
from .email_template import EmailTemplate, TemplateType
from .notification_log import NotificationLog, NotificationType, NotificationStatus
from .activity_log import ActivityLog, ActivityType
from .token import RevokedToken
# Make sure all models are imported above this line
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
@@ -40,5 +39,4 @@ __all__ = [
'EmailTemplate', 'TemplateType',
'NotificationLog', 'NotificationType', 'NotificationStatus',
'ActivityLog', 'ActivityType',
'RevokedToken',
]

View File

@@ -1,15 +0,0 @@
from sqlalchemy import Column, String, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
class RevokedToken(UUIDMixin, TimestampMixin, Base):
"""Model to store revoked JWT tokens via their jti (JWT ID)."""
__tablename__ = "revoked_tokens"
jti = Column(String(length=50), nullable=False, unique=True, index=True)
token_type = Column(String(length=20), nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"))
user = relationship("User", back_populates="revoked_tokens")

View File

@@ -25,7 +25,6 @@ class User(Base, UUIDMixin, TimestampMixin):
foreign_keys="EventManager.user_id"
)
guest_profiles = relationship("Guest", back_populates="user", foreign_keys="Guest.user_id")
revoked_tokens = relationship("RevokedToken", back_populates="user", cascade="all, delete")
def __repr__(self):
return f"<User {self.email}>"

View File

@@ -1,66 +0,0 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field, field_validator
from datetime import datetime
from passlib.hash import bcrypt
# Base schema with shared user attributes
class UserBase(BaseModel):
"""Base schema with common user attributes."""
email: EmailStr
first_name: str
last_name: str
# Schema for creating a new user
class UserCreate(UserBase):
"""Schema for user registration."""
password: str = Field(
...,
min_length=8,
description="Password must be at least 8 characters"
)
@field_validator('password')
def password_strength(cls, v):
# Add more complex password validation if needed
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
return v
def hash_password(self) -> str:
"""Hash the password before saving it to the database."""
return bcrypt.hash(self.password)
# Schema for updating user details
class UserUpdate(BaseModel):
"""Schema for updating user information."""
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
is_active: Optional[bool] = None
preferences: Optional[dict] = None # Provide preferences support
# Schema for user responses (read-only fields)
class UserResponse(UserBase):
"""Schema for user responses in API."""
id: UUID
is_active: bool
is_superuser: bool # Include roles or superuser flags if needed
preferences: Optional[dict] = None # Include preferences in response
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
orm_mode = True # Enable mapping SQLAlchemy models to this schema
# Schema for user authentication (e.g., login requests)
class UserAuth(BaseModel):
"""Schema for user authentication."""
email: EmailStr
password: str

View File

@@ -0,0 +1,126 @@
# app/schemas/users.py
import re
from datetime import datetime
from typing import Optional, Dict, Any
from uuid import UUID
import pydantic
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
class UserBase(BaseModel):
email: EmailStr
first_name: str
last_name: str
phone_number: Optional[str] = None
@field_validator('phone_number')
@classmethod
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Simple regex for phone validation
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
raise ValueError('Invalid phone number format')
return v
class UserCreate(UserBase):
password: str
@field_validator('password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
@field_validator('phone_number')
@classmethod
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Simple regex for phone validation
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
raise ValueError('Invalid phone number format')
return v
class UserInDB(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class UserResponse(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class Token(BaseModel):
access_token: str
refresh_token: Optional[str] = None
token_type: str = "bearer"
class TokenPayload(BaseModel):
sub: str # User ID
exp: int # Expiration time
iat: Optional[int] = None # Issued at
jti: Optional[str] = None # JWT ID
is_superuser: Optional[bool] = False
first_name: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None # Token type (access/refresh)
class TokenData(BaseModel):
user_id: UUID
is_superuser: bool = False
class PasswordReset(BaseModel):
token: str
new_password: str
@field_validator('new_password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshTokenRequest(BaseModel):
refresh_token: str

View File

View File

@@ -0,0 +1,193 @@
# app/services/auth_service.py
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
TokenExpiredError,
TokenInvalidError
)
from app.models.user import User
from app.schemas.users import Token, UserCreate
logger = logging.getLogger(__name__)
class AuthenticationError(Exception):
"""Exception raised for authentication errors"""
pass
class AuthService:
"""Service for handling authentication operations"""
@staticmethod
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticate a user with email and password.
Args:
db: Database session
email: User email
password: User password
Returns:
User if authenticated, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user:
return None
if not verify_password(password, user.password_hash):
return None
if not user.is_active:
raise AuthenticationError("User account is inactive")
return user
@staticmethod
def create_user(db: Session, user_data: UserCreate) -> User:
"""
Create a new user.
Args:
db: Database session
user_data: User data
Returns:
Created user
"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == user_data.email).first()
if existing_user:
raise AuthenticationError("User with this email already exists")
# Create new user
hashed_password = get_password_hash(user_data.password)
# Create user object from model
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False
)
db.add(user)
db.commit()
db.refresh(user)
return user
@staticmethod
def create_tokens(user: User) -> Token:
"""
Create access and refresh tokens for a user.
Args:
user: User to create tokens for
Returns:
Token object with access and refresh tokens
"""
# Generate claims
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name
}
# Create tokens
access_token = create_access_token(
subject=str(user.id),
claims=claims
)
refresh_token = create_refresh_token(
subject=str(user.id)
)
return Token(
access_token=access_token,
refresh_token=refresh_token
)
@staticmethod
def refresh_tokens(db: Session, refresh_token: str) -> Token:
"""
Generate new tokens using a refresh token.
Args:
db: Database session
refresh_token: Valid refresh token
Returns:
New access and refresh tokens
Raises:
TokenExpiredError: If refresh token has expired
TokenInvalidError: If refresh token is invalid
"""
from app.core.auth import decode_token, get_token_data
try:
# Verify token is a refresh token
decode_token(refresh_token, verify_type="refresh")
# Get user ID from token
token_data = get_token_data(refresh_token)
user_id = token_data.user_id
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")
# Generate new tokens
return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}")
raise
@staticmethod
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
"""
Change a user's password.
Args:
db: Database session
user_id: User ID
current_password: Current password
new_password: New password
Returns:
True if password was changed successfully
Raises:
AuthenticationError: If current password is incorrect
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise AuthenticationError("User not found")
# Verify current password
if not verify_password(current_password, user.password_hash):
raise AuthenticationError("Current password is incorrect")
# Update password
user.password_hash = get_password_hash(new_password)
db.commit()
return True

View File

@@ -1,5 +1,6 @@
import logging
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, clear_mappers
from sqlalchemy.pool import StaticPool
@@ -42,4 +43,37 @@ def teardown_test_db(engine):
Base.metadata.drop_all(engine)
# Dispose of engine
engine.dispose()
engine.dispose()
async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
)
return test_engine
async def setup_async_test_db():
"""Create an async test database and session factory"""
test_engine = await get_async_test_engine()
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
AsyncTestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession
)
return test_engine, AsyncTestingSessionLocal
async def teardown_async_test_db(engine):
"""Clean up after async tests"""
await engine.dispose()