Remove token revocation logic and unused dependencies
Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
@@ -1,41 +0,0 @@
|
||||
"""Add RevokedToken model
|
||||
|
||||
Revision ID: 37315a5b4021
|
||||
Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-02-28 17:11:07.741372
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '37315a5b4021'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('revoked_tokens',
|
||||
sa.Column('jti', sa.String(length=50), nullable=False),
|
||||
sa.Column('token_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_revoked_tokens_jti'), 'revoked_tokens', ['jti'], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_revoked_tokens_jti'), table_name='revoked_tokens')
|
||||
op.drop_table('revoked_tokens')
|
||||
# ### end Alembic commands ###
|
||||
138
backend/app/api/dependencies.py
Normal file
138
backend/app/api/dependencies.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# app/api/dependencies/auth.py
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
# OAuth2 configuration
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: JWT token from request
|
||||
|
||||
Returns:
|
||||
User: The authenticated user
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
try:
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except TokenExpiredError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
User: The authenticated and active user
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is inactive
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
User: The authenticated superuser
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not a superuser
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_optional_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: Optional[str] = Depends(oauth2_scheme)
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: JWT token from request
|
||||
|
||||
Returns:
|
||||
User or None: The authenticated user or None
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
@@ -1,134 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from app.auth.utils import revoke_token, is_token_revoked
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.token import TokenResponse, TokenPayload, RefreshToken
|
||||
from app.schemas.user import UserResponse
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
oauth2_scheme = OAuth2PasswordRequestForm
|
||||
|
||||
|
||||
# Existing: User Login Endpoint
|
||||
@router.post(
|
||||
"/auth/login",
|
||||
response_model=TokenResponse,
|
||||
summary="Authenticate user and provide tokens"
|
||||
)
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Authenticate a user with their credentials and return an access and refresh token.
|
||||
"""
|
||||
user = await authenticate_user(email=form_data.username, password=form_data.password, db=db)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.")
|
||||
|
||||
# Generate access and refresh tokens
|
||||
access_token = create_access_token({"sub": str(user.id), "type": "access"})
|
||||
refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=1800, # Example: 30 minutes for access token
|
||||
user_id=str(user.id),
|
||||
)
|
||||
|
||||
|
||||
# New: Logout Endpoint (Revoke Token)
|
||||
@router.post(
|
||||
"/auth/logout",
|
||||
summary="Revoke the current token",
|
||||
response_model=dict,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(
|
||||
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
|
||||
):
|
||||
"""
|
||||
Logout the user by revoking the current token.
|
||||
"""
|
||||
# Decode the token and revoke it
|
||||
payload: TokenPayload = await decode_token(token, db=db)
|
||||
await revoke_token(payload.jti, payload.type, payload.sub, db)
|
||||
|
||||
return {"message": "Successfully logged out."}
|
||||
|
||||
|
||||
# New: Bulk Logout (Revoke All of a User's Tokens)
|
||||
@router.post(
|
||||
"/auth/logout-all",
|
||||
summary="Revoke all active tokens for the user",
|
||||
response_model=dict,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def logout_all(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(
|
||||
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
|
||||
):
|
||||
"""
|
||||
Revoke all tokens for the current user, effectively logging them out across all devices.
|
||||
"""
|
||||
await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)})
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Logged out from all devices."}
|
||||
|
||||
|
||||
# Updated: Refresh Token Endpoint
|
||||
@router.post(
|
||||
"/auth/refresh-token",
|
||||
response_model=TokenResponse,
|
||||
summary="Generate a new access token using a refresh token"
|
||||
)
|
||||
async def refresh_token(
|
||||
refresh_token: RefreshToken,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
Refresh the user's access token using their refresh token while ensuring it has not been revoked.
|
||||
"""
|
||||
payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db)
|
||||
|
||||
if await is_token_revoked(payload.jti, db):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.")
|
||||
|
||||
# Generate a new access token with the user's info
|
||||
new_access_token = create_access_token({"sub": payload.sub, "type": "access"})
|
||||
return TokenResponse(
|
||||
access_token=new_access_token,
|
||||
refresh_token=refresh_token.refresh_token, # Reuse existing refresh token
|
||||
expires_in=1800, # Example: 30 minutes expiry for access token
|
||||
token_type="bearer",
|
||||
user_id=payload.sub,
|
||||
)
|
||||
|
||||
|
||||
# Existing: Get Current User Endpoint
|
||||
@router.get(
|
||||
"/auth/me",
|
||||
response_model=UserResponse,
|
||||
summary="Get user details from the token"
|
||||
)
|
||||
async def read_users_me(
|
||||
current_user: User = Depends(
|
||||
lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db))
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Retrieves the details of the currently authenticated user.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.auth.security import decode_token
|
||||
from app.models.user import User
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
payload = await decode_token(token) # Use updated decode_token.
|
||||
user_id: str = payload.sub
|
||||
token_type: str = payload.type
|
||||
|
||||
if user_id is None or token_type != "access":
|
||||
raise HTTPException(status_code=401, detail="Invalid token type.")
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not found.")
|
||||
|
||||
return user
|
||||
except JWTError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
@@ -1,176 +0,0 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Depends
|
||||
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
from auth.utils import is_token_revoked
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against its hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_tokens(user_id: str) -> TokenResponse:
|
||||
"""
|
||||
Create both access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
TokenResponse containing both tokens and metadata
|
||||
"""
|
||||
# Add `jti` during token creation
|
||||
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
|
||||
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
user_id=user_id,
|
||||
scope="read write"
|
||||
)
|
||||
|
||||
|
||||
def create_token(
|
||||
data: dict,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
token_type: str = "access"
|
||||
) -> str:
|
||||
"""Create a JWT token with the specified type and expiration."""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
|
||||
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"type": token_type,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
})
|
||||
if "jti" not in to_encode:
|
||||
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
|
||||
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new access token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "access")
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
|
||||
|
||||
async def decode_token(
|
||||
token: str,
|
||||
required_type: str = "access",
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token, including revocation checks.
|
||||
|
||||
Args:
|
||||
token (str): The JWT token to decode.
|
||||
required_type (str): The expected token type (default: "access").
|
||||
db (AsyncSession): Database session for token revocation checks.
|
||||
|
||||
Returns:
|
||||
TokenPayload: The decoded token data.
|
||||
|
||||
Raises:
|
||||
JWTError: If the token is expired, revoked, malformed, or fails validation.
|
||||
"""
|
||||
try:
|
||||
# Step 1: Decode the JWT token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
SECRET_KEY,
|
||||
algorithms=[ALGORITHM],
|
||||
options={
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "iat", "sub", "type", "jti"]
|
||||
}
|
||||
)
|
||||
|
||||
except ExpiredSignatureError:
|
||||
raise JWTError("Token has expired. Please refresh your token or login again.")
|
||||
except JWTError as e:
|
||||
if "Signature verification failed" in str(e):
|
||||
raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.")
|
||||
raise JWTError(f"Failed to decode the token: {e}")
|
||||
except JOSEError as e:
|
||||
if "segments" in str(e).lower():
|
||||
raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).")
|
||||
raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
|
||||
except Exception as e:
|
||||
# Catch-all for unexpected exceptions during decoding
|
||||
raise JWTError(f"An unexpected error occurred while decoding the token: {e}")
|
||||
|
||||
# Step 2: Validate Required Claims
|
||||
required_claims = ["exp", "sub", "type", "jti"]
|
||||
missing_claims = [claim for claim in required_claims if claim not in payload]
|
||||
if missing_claims:
|
||||
raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")
|
||||
|
||||
# Step 3: Validate Expiry
|
||||
expiration = datetime.fromtimestamp(payload["exp"])
|
||||
if datetime.now(timezone.utc) > expiration:
|
||||
raise JWTError("Token has expired. Please refresh your token or login again.")
|
||||
|
||||
# Step 4: Validate Token Type
|
||||
token_type = payload.get("type")
|
||||
if token_type != required_type:
|
||||
raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")
|
||||
|
||||
# Step 5: Check Revocation
|
||||
jti = payload.get("jti")
|
||||
if await is_token_revoked(jti, db):
|
||||
raise JWTError("Token has been revoked. Please login again to generate a new token.")
|
||||
|
||||
# Step 6: Return Validated Token Payload
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=expiration,
|
||||
iat=datetime.fromtimestamp(payload.get("iat", 0)),
|
||||
jti=jti
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
|
||||
"""Check whether the token's JTI is in the revoked_tokens table."""
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
|
||||
revoked = result.scalar_one_or_none()
|
||||
return revoked is not None
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession):
|
||||
"""Delete revoked tokens that are past their expiration time."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# For access tokens (shorter expiry)
|
||||
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
|
||||
await db.execute(
|
||||
delete(RevokedToken).where(
|
||||
(RevokedToken.token_type == "access") &
|
||||
(RevokedToken.created_at < expire_before)
|
||||
)
|
||||
)
|
||||
|
||||
# For refresh tokens (longer expiry)
|
||||
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
|
||||
await db.execute(
|
||||
delete(RevokedToken).where(
|
||||
(RevokedToken.token_type == "refresh") &
|
||||
(RevokedToken.created_at < expire_before)
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
183
backend/app/core/auth.py
Normal file
183
backend/app/core/auth.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# app/core/auth.py
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
claims: Optional additional claims to include in the token
|
||||
|
||||
Returns:
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if claims:
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
|
||||
Returns:
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token to decode
|
||||
verify_type: Optional token type to verify (access or refresh)
|
||||
|
||||
Returns:
|
||||
TokenPayload: The decoded token data
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If the token has expired
|
||||
TokenInvalidError: If the token is invalid
|
||||
TokenMissingClaimError: If a required claim is missing
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check required claims before Pydantic validation
|
||||
if not payload.get("sub"):
|
||||
raise TokenMissingClaimError("Token missing 'sub' claim")
|
||||
|
||||
# Verify token type if specified
|
||||
if verify_type and payload.get("type") != verify_type:
|
||||
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
|
||||
|
||||
# Now validate with Pydantic
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
|
||||
def get_token_data(token: str) -> TokenData:
|
||||
"""
|
||||
Extract the user ID and superuser status from a token.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
TokenData with user_id and is_superuser
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
@@ -14,6 +14,7 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "eventspace"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
@@ -24,6 +25,7 @@ class Settings(BaseSettings):
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
sql_echo_timing: bool = False # Log query execution times
|
||||
slow_query_threshold: float = 0.5 # Log queries taking longer than this
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import logging
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.api.main import api_router
|
||||
import logging
|
||||
|
||||
from auth.utils import cleanup_expired_tokens
|
||||
from app.core.database import SessionLocal
|
||||
from app.core.config import settings
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
@@ -32,26 +29,6 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
# Create a function that gets its own database session
|
||||
async def scheduled_cleanup():
|
||||
async with SessionLocal() as db:
|
||||
await cleanup_expired_tokens(db)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def start_scheduler():
|
||||
# Run every day at 3 AM
|
||||
scheduler.add_job(
|
||||
scheduled_cleanup,
|
||||
CronTrigger(hour=10, minute=0),
|
||||
id="token_cleanup",
|
||||
name="Clean up expired revoked tokens"
|
||||
)
|
||||
scheduler.start()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def stop_scheduler():
|
||||
scheduler.shutdown()
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root():
|
||||
return """
|
||||
@@ -67,4 +44,5 @@ async def root():
|
||||
</html>
|
||||
"""
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@@ -29,7 +29,6 @@ from .gift import (
|
||||
from .email_template import EmailTemplate, TemplateType
|
||||
from .notification_log import NotificationLog, NotificationType, NotificationStatus
|
||||
from .activity_log import ActivityLog, ActivityType
|
||||
from .token import RevokedToken
|
||||
# Make sure all models are imported above this line
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
@@ -40,5 +39,4 @@ __all__ = [
|
||||
'EmailTemplate', 'TemplateType',
|
||||
'NotificationLog', 'NotificationType', 'NotificationStatus',
|
||||
'ActivityLog', 'ActivityType',
|
||||
'RevokedToken',
|
||||
]
|
||||
@@ -1,15 +0,0 @@
|
||||
from sqlalchemy import Column, String, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class RevokedToken(UUIDMixin, TimestampMixin, Base):
|
||||
"""Model to store revoked JWT tokens via their jti (JWT ID)."""
|
||||
__tablename__ = "revoked_tokens"
|
||||
|
||||
jti = Column(String(length=50), nullable=False, unique=True, index=True)
|
||||
token_type = Column(String(length=20), nullable=False)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"))
|
||||
|
||||
user = relationship("User", back_populates="revoked_tokens")
|
||||
@@ -25,7 +25,6 @@ class User(Base, UUIDMixin, TimestampMixin):
|
||||
foreign_keys="EventManager.user_id"
|
||||
)
|
||||
guest_profiles = relationship("Guest", back_populates="user", foreign_keys="Guest.user_id")
|
||||
revoked_tokens = relationship("RevokedToken", back_populates="user", cascade="all, delete")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
@@ -1,66 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
from datetime import datetime
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
|
||||
# Base schema with shared user attributes
|
||||
class UserBase(BaseModel):
|
||||
"""Base schema with common user attributes."""
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
||||
|
||||
# Schema for creating a new user
|
||||
class UserCreate(UserBase):
|
||||
"""Schema for user registration."""
|
||||
password: str = Field(
|
||||
...,
|
||||
min_length=8,
|
||||
description="Password must be at least 8 characters"
|
||||
)
|
||||
|
||||
@field_validator('password')
|
||||
def password_strength(cls, v):
|
||||
# Add more complex password validation if needed
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
return v
|
||||
|
||||
def hash_password(self) -> str:
|
||||
"""Hash the password before saving it to the database."""
|
||||
return bcrypt.hash(self.password)
|
||||
|
||||
|
||||
# Schema for updating user details
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating user information."""
|
||||
email: Optional[EmailStr] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
preferences: Optional[dict] = None # Provide preferences support
|
||||
|
||||
|
||||
# Schema for user responses (read-only fields)
|
||||
class UserResponse(UserBase):
|
||||
"""Schema for user responses in API."""
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool # Include roles or superuser flags if needed
|
||||
preferences: Optional[dict] = None # Include preferences in response
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True # Enable mapping SQLAlchemy models to this schema
|
||||
|
||||
|
||||
# Schema for user authentication (e.g., login requests)
|
||||
class UserAuth(BaseModel):
|
||||
"""Schema for user authentication."""
|
||||
email: EmailStr
|
||||
password: str
|
||||
126
backend/app/schemas/users.py
Normal file
126
backend/app/schemas/users.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# app/schemas/users.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: str
|
||||
phone_number: Optional[str] = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@classmethod
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
# Simple regex for phone validation
|
||||
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
|
||||
raise ValueError('Invalid phone number format')
|
||||
return v
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@classmethod
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
# Simple regex for phone validation
|
||||
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
|
||||
raise ValueError('Invalid phone number format')
|
||||
return v
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
user_id: UUID
|
||||
is_superuser: bool = False
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
193
backend/app/services/auth_service.py
Normal file
193
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.users import Token, UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Exception raised for authentication errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_data: User data
|
||||
|
||||
Returns:
|
||||
Created user
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
"""
|
||||
Create access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user: User to create tokens for
|
||||
|
||||
Returns:
|
||||
Token object with access and refresh tokens
|
||||
"""
|
||||
# Generate claims
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
||||
"""
|
||||
Generate new tokens using a refresh token.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
refresh_token: Valid refresh token
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If refresh token has expired
|
||||
TokenInvalidError: If refresh token is invalid
|
||||
"""
|
||||
from app.core.auth import decode_token, get_token_data
|
||||
|
||||
try:
|
||||
# Verify token is a refresh token
|
||||
decode_token(refresh_token, verify_type="refresh")
|
||||
|
||||
# Get user ID from token
|
||||
token_data = get_token_data(refresh_token)
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
# Generate new tokens
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
current_password: Current password
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
True if password was changed successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If current password is incorrect
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker, clear_mappers
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
@@ -42,4 +43,37 @@ def teardown_test_db(engine):
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
engine.dispose()
|
||||
|
||||
async def get_async_test_engine():
|
||||
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
)
|
||||
return test_engine
|
||||
|
||||
|
||||
async def setup_async_test_db():
|
||||
"""Create an async test database and session factory"""
|
||||
test_engine = await get_async_test_engine()
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
AsyncTestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession
|
||||
)
|
||||
|
||||
return test_engine, AsyncTestingSessionLocal
|
||||
|
||||
|
||||
async def teardown_async_test_db(engine):
|
||||
"""Clean up after async tests"""
|
||||
await engine.dispose()
|
||||
|
||||
@@ -11,7 +11,7 @@ sqlalchemy>=2.0.29
|
||||
alembic>=1.14.1
|
||||
psycopg2-binary>=2.9.9
|
||||
asyncpg>=0.29.0
|
||||
|
||||
aiosqlite==0.21.0
|
||||
# Security and authentication
|
||||
python-jose>=3.4.0
|
||||
passlib>=1.7.4
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.auth.dependencies import get_current_user, get_current_active_user
|
||||
from app.auth.security import SECRET_KEY, ALGORITHM
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
return User(
|
||||
id="123e4567-e89b-12d3-a456-426614174000",
|
||||
email="test@example.com",
|
||||
password_hash="hashedpassword",
|
||||
is_active=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(mock_user):
|
||||
# Create a valid access token with required claims
|
||||
valid_token = jwt.encode(
|
||||
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
|
||||
|
||||
# Call the dependency
|
||||
user = await get_current_user(token=valid_token, db=mock_db)
|
||||
|
||||
# Assertions
|
||||
assert user == mock_user, "Returned user does not match the mocked user."
|
||||
mock_db.get.assert_called_once_with(User, mock_user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token():
|
||||
invalid_token = "invalid.token.payload"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=invalid_token, db=AsyncMock())
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Could not validate credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_wrong_token_type():
|
||||
token = jwt.encode({"sub": "123", "type": "refresh"}, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=token, db=AsyncMock())
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Could not validate credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_active_user_success(mock_user):
|
||||
result = await get_current_active_user(mock_user)
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_active_user_inactive():
|
||||
inactive_user = User(
|
||||
id="123e4567-e89b-12d3-a456-426614174000",
|
||||
email="inactive@example.com",
|
||||
password_hash="hashedpassword",
|
||||
is_active=False
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_active_user(inactive_user)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == "Inactive user"
|
||||
@@ -1,147 +0,0 @@
|
||||
from datetime import timedelta, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from jose import jwt, JWTError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import (
|
||||
get_password_hash, verify_password,
|
||||
create_access_token, create_refresh_token,
|
||||
decode_token, SECRET_KEY, ALGORITHM
|
||||
)
|
||||
from app.schemas.token import TokenPayload
|
||||
|
||||
|
||||
def test_password_hashing():
|
||||
plain_password = "securepassword123"
|
||||
hashed_password = get_password_hash(plain_password)
|
||||
|
||||
# Ensure hashed passwords are not the same
|
||||
assert hashed_password != plain_password
|
||||
# Test password verification
|
||||
assert verify_password(plain_password, hashed_password)
|
||||
assert not verify_password("wrongpassword", hashed_password)
|
||||
|
||||
|
||||
def test_access_token_creation():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id})
|
||||
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
assert decoded_payload.get("sub") == user_id
|
||||
assert decoded_payload.get("type") == "access"
|
||||
|
||||
|
||||
def test_refresh_token_creation():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id})
|
||||
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
assert decoded_payload.get("sub") == user_id
|
||||
assert decoded_payload.get("type") == "refresh"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_valid():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
|
||||
|
||||
# Mock is_token_revoked to return False
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
mock_db.get = AsyncMock(return_value=None)
|
||||
|
||||
token_payload = await decode_token(access_token, db=mock_db)
|
||||
|
||||
assert isinstance(token_payload, TokenPayload)
|
||||
assert token_payload.sub == user_id
|
||||
assert token_payload.type == "access"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_expired():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Token has been revoked."
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_missing_exp():
|
||||
# Create a token without the `exp` claim
|
||||
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_missing_sub():
|
||||
# Create a token without the `sub` claim
|
||||
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_invalid_signature():
|
||||
# Use a different secret key for signing
|
||||
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Signature verification failed."
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_malformed():
|
||||
malformed_token = "malformed.header.payload"
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(malformed_token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_invalid_type():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
|
||||
|
||||
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."
|
||||
@@ -9,7 +9,7 @@ from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory,
|
||||
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
|
||||
NotificationType, NotificationStatus
|
||||
from app.models.user import User
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -30,6 +30,15 @@ def db_session():
|
||||
teardown_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function") # Define a fixture
|
||||
async def async_test_db():
|
||||
"""Fixture provides new testing engine and session for each test run to improve isolation."""
|
||||
|
||||
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
||||
yield test_engine, AsyncTestingSessionLocal
|
||||
await teardown_async_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(db_session):
|
||||
"""Fixture to create and return a mock User instance."""
|
||||
@@ -72,7 +81,6 @@ def event_fixture(db_session, mock_user):
|
||||
return event
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gift_item_fixture(db_session, mock_user):
|
||||
"""
|
||||
|
||||
0
backend/tests/core/__init__.py
Normal file
0
backend/tests/core/__init__.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# tests/core/test_auth.py
|
||||
import uuid
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_token_data,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestPasswordHandling:
|
||||
"""Tests for password hashing and verification functions"""
|
||||
|
||||
def test_password_hash_different_from_password(self):
|
||||
"""Test that a password hash is different from the original password"""
|
||||
password = "TestPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert hashed != password
|
||||
|
||||
def test_verify_correct_password(self):
|
||||
"""Test that verify_password returns True for the correct password"""
|
||||
password = "TestPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_incorrect_password(self):
|
||||
"""Test that verify_password returns False for an incorrect password"""
|
||||
password = "TestPassword123"
|
||||
wrong_password = "WrongPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(wrong_password, hashed) is False
|
||||
|
||||
def test_same_password_different_hash(self):
|
||||
"""Test that the same password gets a different hash each time"""
|
||||
password = "TestPassword123"
|
||||
hash1 = get_password_hash(password)
|
||||
hash2 = get_password_hash(password)
|
||||
assert hash1 != hash2
|
||||
|
||||
|
||||
class TestTokenCreation:
|
||||
"""Tests for token creation functions"""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test that an access token is created with the correct claims"""
|
||||
user_id = str(uuid.uuid4())
|
||||
custom_claims = {
|
||||
"email": "test@example.com",
|
||||
"first_name": "Test",
|
||||
"is_superuser": True
|
||||
}
|
||||
token = create_access_token(subject=user_id, claims=custom_claims)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
assert payload["sub"] == user_id
|
||||
assert "jti" in payload
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert payload["type"] == "access"
|
||||
|
||||
# Check custom claims
|
||||
for key, value in custom_claims.items():
|
||||
assert payload[key] == value
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
"""Test that a refresh token is created with the correct claims"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_refresh_token(subject=user_id)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
assert payload["sub"] == user_id
|
||||
assert "jti" in payload
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_token_expiration(self):
|
||||
"""Test that tokens have the correct expiration time"""
|
||||
user_id = str(uuid.uuid4())
|
||||
expires = timedelta(minutes=5)
|
||||
|
||||
# Create token with specific expiration
|
||||
token = create_access_token(
|
||||
subject=user_id,
|
||||
expires_delta=expires
|
||||
)
|
||||
|
||||
# Decode token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Get actual expiration time from token
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
|
||||
# Calculate expected expiration (approximately)
|
||||
now = datetime.now(timezone.utc)
|
||||
expected_expiration = now + expires
|
||||
|
||||
# Difference should be small (less than 1 second)
|
||||
difference = abs((expiration - expected_expiration).total_seconds())
|
||||
assert difference < 1
|
||||
|
||||
|
||||
class TestTokenDecoding:
|
||||
"""Tests for token decoding and validation functions"""
|
||||
|
||||
def test_decode_valid_token(self):
|
||||
"""Test that a valid token can be decoded"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
# Decode token
|
||||
payload = decode_token(token)
|
||||
|
||||
# Check that the subject matches
|
||||
assert payload.sub == user_id
|
||||
|
||||
def test_decode_expired_token(self):
|
||||
"""Test that an expired token raises TokenExpiredError"""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Create a token that's already expired by directly manipulating the payload
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_time = now - timedelta(hours=1) # 1 hour in the past
|
||||
|
||||
# Create the expired token manually
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expired_time.timestamp()), # Set expiration in the past
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Attempting to decode should raise TokenExpiredError
|
||||
with pytest.raises(TokenExpiredError):
|
||||
decode_token(expired_token)
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""Test that an invalid token raises TokenInvalidError"""
|
||||
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
|
||||
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(invalid_token)
|
||||
|
||||
def test_decode_token_with_missing_sub(self):
|
||||
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
|
||||
# Create a token without a subject
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
# No 'sub' claim
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
|
||||
def test_decode_token_with_wrong_type(self):
|
||||
"""Test that verifying a token with wrong type raises TokenInvalidError"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
# Try to verify it as a refresh token
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(token, verify_type="refresh")
|
||||
|
||||
def test_decode_with_invalid_payload(self):
|
||||
"""Test that a token with invalid payload structure raises TokenInvalidError"""
|
||||
# Create a token with an invalid payload structure - missing 'sub' which is required
|
||||
# but including 'exp' to avoid the expiration check
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
# Missing "sub" field which is required
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"invalid_field": "test"
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Should raise TokenMissingClaimError due to missing 'sub'
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
|
||||
# Create another token with invalid type for required field
|
||||
payload = {
|
||||
"sub": 123, # sub should be a string, not an integer
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Should raise TokenInvalidError due to ValidationError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(token)
|
||||
|
||||
def test_get_token_data(self):
|
||||
"""Test extracting TokenData from a token"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
subject=str(user_id),
|
||||
claims={"is_superuser": True}
|
||||
)
|
||||
|
||||
token_data = get_token_data(token)
|
||||
|
||||
assert token_data.user_id == user_id
|
||||
assert token_data.is_superuser is True
|
||||
0
backend/tests/services/__init__.py
Normal file
0
backend/tests/services/__init__.py
Normal file
Reference in New Issue
Block a user