Add comprehensive test suite and utilities for user functionality
This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
@@ -17,7 +17,7 @@ ENVIRONMENT=development
|
|||||||
DEBUG=true
|
DEBUG=true
|
||||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||||
FIRST_SUPERUSER_PASSWORD=admin123
|
FIRST_SUPERUSER_PASSWORD=Admin123
|
||||||
|
|
||||||
# Frontend settings
|
# Frontend settings
|
||||||
FRONTEND_PORT=3000
|
FRONTEND_PORT=3000
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Add all initial models
|
||||||
|
|
||||||
|
Revision ID: 38bf9e7e74b3
|
||||||
|
Revises: 7396957cbe80
|
||||||
|
Create Date: 2025-02-28 09:19:33.212278
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '38bf9e7e74b3'
|
||||||
|
down_revision: Union[str, None] = '7396957cbe80'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
|
||||||
|
op.create_table('users',
|
||||||
|
sa.Column('email', sa.String(), nullable=False),
|
||||||
|
sa.Column('password_hash', sa.String(), nullable=False),
|
||||||
|
sa.Column('first_name', sa.String(), nullable=False),
|
||||||
|
sa.Column('last_name', sa.String(), nullable=True),
|
||||||
|
sa.Column('phone_number', sa.String(), nullable=True),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||||
|
op.drop_table('users')
|
||||||
|
# ### end Alembic commands ###
|
||||||
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/dependencies/__init__.py
Normal file
0
backend/app/api/dependencies/__init__.py
Normal file
137
backend/app/api/dependencies/auth.py
Normal file
137
backend/app/api/dependencies/auth.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
# OAuth2 configuration
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
token: str = Depends(oauth2_scheme)
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Get the current authenticated user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
token: JWT token from request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: The authenticated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Decode token and get user ID
|
||||||
|
token_data = get_token_data(token)
|
||||||
|
|
||||||
|
# Get user from database
|
||||||
|
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Inactive user"
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
except TokenExpiredError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token expired",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"}
|
||||||
|
)
|
||||||
|
except TokenInvalidError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_active_user(
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Check if the current user is active.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: The current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: The authenticated and active user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If user is inactive
|
||||||
|
"""
|
||||||
|
if not current_user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Inactive user"
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_superuser(
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Check if the current user is a superuser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: The current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: The authenticated superuser
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If user is not a superuser
|
||||||
|
"""
|
||||||
|
if not current_user.is_superuser:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions"
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def get_optional_current_user(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
token: Optional[str] = Depends(oauth2_scheme)
|
||||||
|
) -> Optional[User]:
|
||||||
|
"""
|
||||||
|
Get the current user if authenticated, otherwise return None.
|
||||||
|
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
token: JWT token from request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User or None: The authenticated user or None
|
||||||
|
"""
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
token_data = get_token_data(token)
|
||||||
|
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||||
|
if not user or not user.is_active:
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
except (TokenExpiredError, TokenInvalidError):
|
||||||
|
return None
|
||||||
6
backend/app/api/main.py
Normal file
6
backend/app/api/main.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.api.routes import auth
|
||||||
|
|
||||||
|
api_router = APIRouter()
|
||||||
|
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||||
0
backend/app/api/routes/__init__.py
Normal file
0
backend/app/api/routes/__init__.py
Normal file
231
backend/app/api/routes/auth.py
Normal file
231
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
# app/api/routes/auth.py
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import (
|
||||||
|
UserCreate,
|
||||||
|
UserResponse,
|
||||||
|
Token,
|
||||||
|
LoginRequest,
|
||||||
|
RefreshTokenRequest
|
||||||
|
)
|
||||||
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register_user(
|
||||||
|
user_data: UserCreate,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Register a new user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created user information.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = AuthService.create_user(db, user_data)
|
||||||
|
return user
|
||||||
|
except AuthenticationError as e:
|
||||||
|
logger.warning(f"Registration failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during registration: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=Token)
|
||||||
|
async def login(
|
||||||
|
login_data: LoginRequest,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Login with username and password.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Access and refresh tokens.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Attempt to authenticate the user
|
||||||
|
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||||
|
|
||||||
|
# Explicitly check for None result and raise correct exception
|
||||||
|
if user is None:
|
||||||
|
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is authenticated, generate tokens
|
||||||
|
tokens = AuthService.create_tokens(user)
|
||||||
|
logger.info(f"User login successful: {user.email}")
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
# Re-raise HTTP exceptions without modification
|
||||||
|
raise
|
||||||
|
except AuthenticationError as e:
|
||||||
|
# Handle specific authentication errors like inactive accounts
|
||||||
|
logger.warning(f"Authentication failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Handle unexpected errors
|
||||||
|
logger.error(f"Unexpected error during login: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login/oauth", response_model=Token)
|
||||||
|
async def login_oauth(
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Access and refresh tokens.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate tokens
|
||||||
|
tokens = AuthService.create_tokens(user)
|
||||||
|
|
||||||
|
# Format response for OAuth compatibility
|
||||||
|
return {
|
||||||
|
"access_token": tokens.access_token,
|
||||||
|
"refresh_token": tokens.refresh_token,
|
||||||
|
"token_type": tokens.token_type
|
||||||
|
}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except AuthenticationError as e:
|
||||||
|
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during OAuth login: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=Token)
|
||||||
|
async def refresh_token(
|
||||||
|
refresh_data: RefreshTokenRequest,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Refresh access token using a refresh token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New access and refresh tokens.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||||
|
return tokens
|
||||||
|
except TokenExpiredError:
|
||||||
|
logger.warning("Token refresh failed: Token expired")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Refresh token has expired. Please log in again.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except TokenInvalidError:
|
||||||
|
logger.warning("Token refresh failed: Invalid token")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/change-password", status_code=status.HTTP_200_OK)
|
||||||
|
async def change_password(
|
||||||
|
current_password: str = Body(..., embed=True),
|
||||||
|
new_password: str = Body(..., embed=True),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Change current user's password.
|
||||||
|
|
||||||
|
Requires authentication.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success = AuthService.change_password(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
current_password=current_password,
|
||||||
|
new_password=new_password
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return {"message": "Password changed successfully"}
|
||||||
|
except AuthenticationError as e:
|
||||||
|
logger.warning(f"Password change failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during password change: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_current_user_info(
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get current user information.
|
||||||
|
|
||||||
|
Requires authentication.
|
||||||
|
"""
|
||||||
|
return current_user
|
||||||
185
backend/app/core/auth.py
Normal file
185
backend/app/core/auth.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
import logging
|
||||||
|
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from jose import jwt, JWTError
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.schemas.users import TokenData, TokenPayload
|
||||||
|
|
||||||
|
|
||||||
|
# Password hashing context
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
# Custom exceptions for auth
|
||||||
|
class AuthError(Exception):
|
||||||
|
"""Base authentication error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TokenExpiredError(AuthError):
|
||||||
|
"""Token has expired"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TokenInvalidError(AuthError):
|
||||||
|
"""Token is invalid"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TokenMissingClaimError(AuthError):
|
||||||
|
"""Token is missing a required claim"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""Verify a password against a hash."""
|
||||||
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
|
def get_password_hash(password: str) -> str:
|
||||||
|
"""Generate a password hash."""
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(
|
||||||
|
subject: Union[str, Any],
|
||||||
|
expires_delta: Optional[timedelta] = None,
|
||||||
|
claims: Optional[Dict[str, Any]] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a JWT access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subject: The subject of the token (usually user_id)
|
||||||
|
expires_delta: Optional expiration time delta
|
||||||
|
claims: Optional additional claims to include in the token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded JWT token
|
||||||
|
"""
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
|
||||||
|
# Base token data
|
||||||
|
to_encode = {
|
||||||
|
"sub": str(subject),
|
||||||
|
"exp": expire,
|
||||||
|
"iat": datetime.now(tz=timezone.utc),
|
||||||
|
"jti": str(uuid.uuid4()),
|
||||||
|
"type": "access"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add custom claims
|
||||||
|
if claims:
|
||||||
|
to_encode.update(claims)
|
||||||
|
|
||||||
|
# Create the JWT
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
to_encode,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(
|
||||||
|
subject: Union[str, Any],
|
||||||
|
expires_delta: Optional[timedelta] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a JWT refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subject: The subject of the token (usually user_id)
|
||||||
|
expires_delta: Optional expiration time delta
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded JWT refresh token
|
||||||
|
"""
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
||||||
|
to_encode = {
|
||||||
|
"sub": str(subject),
|
||||||
|
"exp": expire,
|
||||||
|
"iat": datetime.now(timezone.utc),
|
||||||
|
"jti": str(uuid.uuid4()),
|
||||||
|
"type": "refresh"
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
to_encode,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||||
|
"""
|
||||||
|
Decode and verify a JWT token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token to decode
|
||||||
|
verify_type: Optional token type to verify (access or refresh)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenPayload: The decoded token data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenExpiredError: If the token has expired
|
||||||
|
TokenInvalidError: If the token is invalid
|
||||||
|
TokenMissingClaimError: If a required claim is missing
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check required claims before Pydantic validation
|
||||||
|
if not payload.get("sub"):
|
||||||
|
raise TokenMissingClaimError("Token missing 'sub' claim")
|
||||||
|
|
||||||
|
# Verify token type if specified
|
||||||
|
if verify_type and payload.get("type") != verify_type:
|
||||||
|
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
|
||||||
|
|
||||||
|
# Now validate with Pydantic
|
||||||
|
token_data = TokenPayload(**payload)
|
||||||
|
return token_data
|
||||||
|
|
||||||
|
except JWTError as e:
|
||||||
|
# Check if the error is due to an expired token
|
||||||
|
if "expired" in str(e).lower():
|
||||||
|
raise TokenExpiredError("Token has expired")
|
||||||
|
raise TokenInvalidError("Invalid authentication token")
|
||||||
|
except ValidationError:
|
||||||
|
raise TokenInvalidError("Invalid token payload")
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_data(token: str) -> TokenData:
|
||||||
|
"""
|
||||||
|
Extract the user ID and superuser status from a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenData with user_id and is_superuser
|
||||||
|
"""
|
||||||
|
payload = decode_token(token)
|
||||||
|
user_id = payload.sub
|
||||||
|
is_superuser = payload.is_superuser or False
|
||||||
|
|
||||||
|
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||||
@@ -3,7 +3,7 @@ from typing import Optional, List
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME: str = "App"
|
PROJECT_NAME: str = "EventSpace"
|
||||||
VERSION: str = "1.0.0"
|
VERSION: str = "1.0.0"
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
|
|||||||
POSTGRES_PORT: str = "5432"
|
POSTGRES_PORT: str = "5432"
|
||||||
POSTGRES_DB: str = "app"
|
POSTGRES_DB: str = "app"
|
||||||
DATABASE_URL: Optional[str] = None
|
DATABASE_URL: Optional[str] = None
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||||
|
db_pool_size: int = 20 # Default connection pool size
|
||||||
|
db_max_overflow: int = 50 # Maximum overflow connections
|
||||||
|
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||||
|
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||||
|
|
||||||
|
# SQL debugging (disable in production)
|
||||||
|
sql_echo: bool = False # Log SQL statements
|
||||||
|
sql_echo_pool: bool = False # Log connection pool events
|
||||||
|
sql_echo_timing: bool = False # Log query execution times
|
||||||
|
slow_query_threshold: float = 0.5 # Log queries taking longer than this
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database_url(self) -> str:
|
def database_url(self) -> str:
|
||||||
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
|
|||||||
# JWT configuration
|
# JWT configuration
|
||||||
SECRET_KEY: str = "your_secret_key_here"
|
SECRET_KEY: str = "your_secret_key_here"
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
|
||||||
|
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||||
|
|||||||
@@ -1,17 +1,57 @@
|
|||||||
|
# app/core/database.py
|
||||||
|
import logging
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.ext.compiler import compiles
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# Use the database URL from settings
|
# Configure logging
|
||||||
engine = create_engine(settings.database_url)
|
logger = logging.getLogger(__name__)
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|
||||||
|
# SQLite compatibility for testing
|
||||||
|
@compiles(JSONB, 'sqlite')
|
||||||
|
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||||
|
return "TEXT"
|
||||||
|
|
||||||
|
@compiles(UUID, 'sqlite')
|
||||||
|
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||||
|
return "TEXT"
|
||||||
|
|
||||||
|
# Declarative base for models
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
# Create engine with optimized settings for PostgreSQL
|
||||||
|
def create_production_engine():
|
||||||
|
return create_engine(
|
||||||
|
settings.database_url,
|
||||||
|
# Connection pool settings
|
||||||
|
pool_size=settings.db_pool_size,
|
||||||
|
max_overflow=settings.db_max_overflow,
|
||||||
|
pool_timeout=settings.db_pool_timeout,
|
||||||
|
pool_recycle=settings.db_pool_recycle,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
# Query execution settings
|
||||||
|
connect_args={
|
||||||
|
"application_name": "eventspace",
|
||||||
|
"keepalives": 1,
|
||||||
|
"keepalives_idle": 60,
|
||||||
|
"keepalives_interval": 10,
|
||||||
|
"keepalives_count": 5,
|
||||||
|
"options": "-c timezone=UTC",
|
||||||
|
},
|
||||||
|
isolation_level="READ COMMITTED",
|
||||||
|
echo=settings.sql_echo,
|
||||||
|
echo_pool=settings.sql_echo_pool,
|
||||||
|
)
|
||||||
|
|
||||||
# Dependency to get DB session
|
# Default production engine and session factory
|
||||||
|
engine = create_production_engine()
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
# FastAPI dependency
|
||||||
def get_db():
|
def get_db():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
|||||||
0
backend/app/crud/__init__.py
Normal file
0
backend/app/crud/__init__.py
Normal file
62
backend/app/crud/base.py
Normal file
62
backend/app/crud/base.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=Base)
|
||||||
|
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||||
|
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||||
|
def __init__(self, model: Type[ModelType]):
|
||||||
|
"""
|
||||||
|
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model: A SQLAlchemy model class
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||||
|
return db.query(self.model).filter(self.model.id == id).first()
|
||||||
|
|
||||||
|
def get_multi(
|
||||||
|
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||||
|
) -> List[ModelType]:
|
||||||
|
return db.query(self.model).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||||
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
|
db_obj = self.model(**obj_in_data)
|
||||||
|
db.add(db_obj)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
db_obj: ModelType,
|
||||||
|
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||||
|
) -> ModelType:
|
||||||
|
obj_data = jsonable_encoder(db_obj)
|
||||||
|
if isinstance(obj_in, dict):
|
||||||
|
update_data = obj_in
|
||||||
|
else:
|
||||||
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
for field in obj_data:
|
||||||
|
if field in update_data:
|
||||||
|
setattr(db_obj, field, update_data[field])
|
||||||
|
db.add(db_obj)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
def remove(self, db: Session, *, id: str) -> ModelType:
|
||||||
|
obj = db.query(self.model).get(id)
|
||||||
|
db.delete(obj)
|
||||||
|
db.commit()
|
||||||
|
return obj
|
||||||
56
backend/app/crud/user.py
Normal file
56
backend/app/crud/user.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# app/crud/user.py
|
||||||
|
from typing import Optional, Union, Dict, Any
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||||
|
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
||||||
|
return db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
|
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||||
|
db_obj = User(
|
||||||
|
email=obj_in.email,
|
||||||
|
password_hash=get_password_hash(obj_in.password),
|
||||||
|
first_name=obj_in.first_name,
|
||||||
|
last_name=obj_in.last_name,
|
||||||
|
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||||
|
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||||
|
preferences={}
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
db_obj: User,
|
||||||
|
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||||
|
) -> User:
|
||||||
|
if isinstance(obj_in, dict):
|
||||||
|
update_data = obj_in
|
||||||
|
else:
|
||||||
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# Handle password separately if it exists in update data
|
||||||
|
if "password" in update_data:
|
||||||
|
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||||
|
del update_data["password"]
|
||||||
|
|
||||||
|
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||||
|
|
||||||
|
def is_active(self, user: User) -> bool:
|
||||||
|
return user.is_active
|
||||||
|
|
||||||
|
def is_superuser(self, user: User) -> bool:
|
||||||
|
return user.is_superuser
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance for use across the application
|
||||||
|
user = CRUDUser(User)
|
||||||
@@ -1,9 +1,18 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
from app.config import settings
|
from app.api.main import api_router
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
scheduler = AsyncIOScheduler()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
logger.info(f"Starting app!!!")
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.PROJECT_NAME,
|
title=settings.PROJECT_NAME,
|
||||||
version=settings.VERSION,
|
version=settings.VERSION,
|
||||||
@@ -34,3 +43,6 @@ async def root():
|
|||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|||||||
14
backend/app/models/__init__.py
Normal file
14
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Models package initialization.
|
||||||
|
Imports all models to ensure they're registered with SQLAlchemy.
|
||||||
|
"""
|
||||||
|
# First import Base to avoid circular imports
|
||||||
|
from app.core.database import Base
|
||||||
|
from .base import TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
# Import user model
|
||||||
|
from .user import User
|
||||||
|
__all__ = [
|
||||||
|
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||||
|
'User',
|
||||||
|
]
|
||||||
20
backend/app/models/base.py
Normal file
20
backend/app/models/base.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class TimestampMixin:
|
||||||
|
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||||
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||||
|
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||||
|
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDMixin:
|
||||||
|
"""Mixin to add UUID primary keys to models"""
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
19
backend/app/models/user.py
Normal file
19
backend/app/models/user.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from sqlalchemy import Column, String, JSON, Boolean
|
||||||
|
|
||||||
|
from .base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base, UUIDMixin, TimestampMixin):
|
||||||
|
__tablename__ = 'users'
|
||||||
|
|
||||||
|
email = Column(String, unique=True, nullable=False, index=True)
|
||||||
|
password_hash = Column(String, nullable=False)
|
||||||
|
first_name = Column(String, nullable=False, default="user")
|
||||||
|
last_name = Column(String, nullable=True)
|
||||||
|
phone_number = Column(String)
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False)
|
||||||
|
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||||
|
preferences = Column(JSON)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<User {self.email}>"
|
||||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
149
backend/app/schemas/users.py
Normal file
149
backend/app/schemas/users.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# app/schemas/users.py
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
first_name: str
|
||||||
|
last_name: Optional[str] = None
|
||||||
|
phone_number: Optional[str] = None
|
||||||
|
|
||||||
|
@field_validator('phone_number')
|
||||||
|
@classmethod
|
||||||
|
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
# Simple regex for phone validation
|
||||||
|
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
|
||||||
|
raise ValueError('Invalid phone number format')
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
password: str
|
||||||
|
is_superuser: bool = False
|
||||||
|
|
||||||
|
@field_validator('password')
|
||||||
|
@classmethod
|
||||||
|
def password_strength(cls, v: str) -> str:
|
||||||
|
"""Basic password strength validation"""
|
||||||
|
if len(v) < 8:
|
||||||
|
raise ValueError('Password must be at least 8 characters')
|
||||||
|
if not any(char.isdigit() for char in v):
|
||||||
|
raise ValueError('Password must contain at least one digit')
|
||||||
|
if not any(char.isupper() for char in v):
|
||||||
|
raise ValueError('Password must contain at least one uppercase letter')
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class UserUpdate(BaseModel):
|
||||||
|
first_name: Optional[str] = None
|
||||||
|
last_name: Optional[str] = None
|
||||||
|
phone_number: Optional[str] = None
|
||||||
|
preferences: Optional[Dict[str, Any]] = None
|
||||||
|
is_active: Optional[bool] = True
|
||||||
|
@field_validator('phone_number')
|
||||||
|
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Return early for empty strings or whitespace-only strings
|
||||||
|
if not v or v.strip() == "":
|
||||||
|
raise ValueError('Phone number cannot be empty')
|
||||||
|
|
||||||
|
# Remove all spaces and formatting characters
|
||||||
|
cleaned = re.sub(r'[\s\-\(\)]', '', v)
|
||||||
|
|
||||||
|
# Basic pattern:
|
||||||
|
# Must start with + or 0
|
||||||
|
# After + must have at least 8 digits
|
||||||
|
# After 0 must have at least 8 digits
|
||||||
|
# Maximum total length of 15 digits (international standard)
|
||||||
|
# Only allowed characters are + at start and digits
|
||||||
|
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||||
|
|
||||||
|
if not re.match(pattern, cleaned):
|
||||||
|
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||||
|
|
||||||
|
# Additional validation to catch specific invalid cases
|
||||||
|
if cleaned.count('+') > 1:
|
||||||
|
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||||
|
|
||||||
|
# Check for any non-digit characters (except the leading +)
|
||||||
|
if not all(c.isdigit() for c in cleaned[1:]):
|
||||||
|
raise ValueError('Phone number can only contain digits after the prefix')
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
class UserInDB(UserBase):
|
||||||
|
id: UUID
|
||||||
|
is_active: bool
|
||||||
|
is_superuser: bool
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(UserBase):
|
||||||
|
id: UUID
|
||||||
|
is_active: bool
|
||||||
|
is_superuser: bool
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: Optional[str] = None
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
sub: str # User ID
|
||||||
|
exp: int # Expiration time
|
||||||
|
iat: Optional[int] = None # Issued at
|
||||||
|
jti: Optional[str] = None # JWT ID
|
||||||
|
is_superuser: Optional[bool] = False
|
||||||
|
first_name: Optional[str] = None
|
||||||
|
email: Optional[str] = None
|
||||||
|
type: Optional[str] = None # Token type (access/refresh)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
user_id: UUID
|
||||||
|
is_superuser: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordReset(BaseModel):
|
||||||
|
token: str
|
||||||
|
new_password: str
|
||||||
|
|
||||||
|
@field_validator('new_password')
|
||||||
|
@classmethod
|
||||||
|
def password_strength(cls, v: str) -> str:
|
||||||
|
"""Basic password strength validation"""
|
||||||
|
if len(v) < 8:
|
||||||
|
raise ValueError('Password must be at least 8 characters')
|
||||||
|
if not any(char.isdigit() for char in v):
|
||||||
|
raise ValueError('Password must contain at least one digit')
|
||||||
|
if not any(char.isupper() for char in v):
|
||||||
|
raise ValueError('Password must contain at least one uppercase letter')
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
193
backend/app/services/auth_service.py
Normal file
193
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
# app/services/auth_service.py
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.auth import (
|
||||||
|
verify_password,
|
||||||
|
get_password_hash,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
TokenExpiredError,
|
||||||
|
TokenInvalidError
|
||||||
|
)
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import Token, UserCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(Exception):
|
||||||
|
"""Exception raised for authentication errors"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
"""Service for handling authentication operations"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||||
|
"""
|
||||||
|
Authenticate a user with email and password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
email: User email
|
||||||
|
password: User password
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if authenticated, None otherwise
|
||||||
|
"""
|
||||||
|
user = db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not verify_password(password, user.password_hash):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise AuthenticationError("User account is inactive")
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||||
|
"""
|
||||||
|
Create a new user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_data: User data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created user
|
||||||
|
"""
|
||||||
|
# Check if user already exists
|
||||||
|
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
||||||
|
if existing_user:
|
||||||
|
raise AuthenticationError("User with this email already exists")
|
||||||
|
|
||||||
|
# Create new user
|
||||||
|
hashed_password = get_password_hash(user_data.password)
|
||||||
|
|
||||||
|
# Create user object from model
|
||||||
|
user = User(
|
||||||
|
email=user_data.email,
|
||||||
|
password_hash=hashed_password,
|
||||||
|
first_name=user_data.first_name,
|
||||||
|
last_name=user_data.last_name,
|
||||||
|
phone_number=user_data.phone_number,
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_tokens(user: User) -> Token:
|
||||||
|
"""
|
||||||
|
Create access and refresh tokens for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User to create tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token object with access and refresh tokens
|
||||||
|
"""
|
||||||
|
# Generate claims
|
||||||
|
claims = {
|
||||||
|
"is_superuser": user.is_superuser,
|
||||||
|
"email": user.email,
|
||||||
|
"first_name": user.first_name
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create tokens
|
||||||
|
access_token = create_access_token(
|
||||||
|
subject=str(user.id),
|
||||||
|
claims=claims
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_token = create_refresh_token(
|
||||||
|
subject=str(user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
||||||
|
"""
|
||||||
|
Generate new tokens using a refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
refresh_token: Valid refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New access and refresh tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenExpiredError: If refresh token has expired
|
||||||
|
TokenInvalidError: If refresh token is invalid
|
||||||
|
"""
|
||||||
|
from app.core.auth import decode_token, get_token_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify token is a refresh token
|
||||||
|
decode_token(refresh_token, verify_type="refresh")
|
||||||
|
|
||||||
|
# Get user ID from token
|
||||||
|
token_data = get_token_data(refresh_token)
|
||||||
|
user_id = token_data.user_id
|
||||||
|
|
||||||
|
# Get user from database
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise TokenInvalidError("Invalid user or inactive account")
|
||||||
|
|
||||||
|
# Generate new tokens
|
||||||
|
return AuthService.create_tokens(user)
|
||||||
|
|
||||||
|
except (TokenExpiredError, TokenInvalidError) as e:
|
||||||
|
logger.warning(f"Token refresh failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||||
|
"""
|
||||||
|
Change a user's password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_id: User ID
|
||||||
|
current_password: Current password
|
||||||
|
new_password: New password
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if password was changed successfully
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If current password is incorrect
|
||||||
|
"""
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if not user:
|
||||||
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
|
# Verify current password
|
||||||
|
if not verify_password(current_password, user.password_hash):
|
||||||
|
raise AuthenticationError("Current password is incorrect")
|
||||||
|
|
||||||
|
# Update password
|
||||||
|
user.password_hash = get_password_hash(new_password)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
79
backend/app/utils/test_utils.py
Normal file
79
backend/app/utils/test_utils.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import logging
|
||||||
|
from sqlalchemy import create_engine, event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker, clear_mappers
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def get_test_engine():
|
||||||
|
"""Create an SQLite in-memory engine specifically for testing"""
|
||||||
|
test_engine = create_engine(
|
||||||
|
"sqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||||
|
echo=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return test_engine
|
||||||
|
|
||||||
|
def setup_test_db():
|
||||||
|
"""Create a test database and session factory"""
|
||||||
|
# Create a new engine for this test run
|
||||||
|
test_engine = get_test_engine()
|
||||||
|
|
||||||
|
# Create tables
|
||||||
|
Base.metadata.create_all(test_engine)
|
||||||
|
|
||||||
|
# Create session factory
|
||||||
|
TestingSessionLocal = sessionmaker(
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
bind=test_engine,
|
||||||
|
expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return test_engine, TestingSessionLocal
|
||||||
|
|
||||||
|
def teardown_test_db(engine):
|
||||||
|
"""Clean up after tests"""
|
||||||
|
# Drop all tables
|
||||||
|
Base.metadata.drop_all(engine)
|
||||||
|
|
||||||
|
# Dispose of engine
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
async def get_async_test_engine():
|
||||||
|
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||||
|
test_engine = create_async_engine(
|
||||||
|
"sqlite+aiosqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||||
|
echo=False
|
||||||
|
)
|
||||||
|
return test_engine
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_async_test_db():
|
||||||
|
"""Create an async test database and session factory"""
|
||||||
|
test_engine = await get_async_test_engine()
|
||||||
|
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
AsyncTestingSessionLocal = sessionmaker(
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
bind=test_engine,
|
||||||
|
expire_on_commit=False,
|
||||||
|
class_=AsyncSession
|
||||||
|
)
|
||||||
|
|
||||||
|
return test_engine, AsyncTestingSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
async def teardown_async_test_db(engine):
|
||||||
|
"""Clean up after async tests"""
|
||||||
|
await engine.dispose()
|
||||||
10
backend/pytest.ini
Normal file
10
backend/pytest.ini
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
[pytest]
|
||||||
|
env =
|
||||||
|
IS_TEST=True
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
addopts = --disable-warnings
|
||||||
|
markers =
|
||||||
|
sqlite: marks tests that should run on SQLite (mocked).
|
||||||
|
postgres: marks tests that require a real PostgreSQL database.
|
||||||
|
asyncio_default_fixture_loop_scope = function
|
||||||
@@ -4,13 +4,14 @@ uvicorn>=0.34.0
|
|||||||
pydantic>=2.10.6
|
pydantic>=2.10.6
|
||||||
pydantic-settings>=2.2.1
|
pydantic-settings>=2.2.1
|
||||||
python-multipart>=0.0.19
|
python-multipart>=0.0.19
|
||||||
|
fastapi-utils==0.8.0
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
sqlalchemy>=2.0.29
|
sqlalchemy>=2.0.29
|
||||||
alembic>=1.14.1
|
alembic>=1.14.1
|
||||||
psycopg2-binary>=2.9.9
|
psycopg2-binary>=2.9.9
|
||||||
asyncpg>=0.29.0
|
asyncpg>=0.29.0
|
||||||
|
aiosqlite==0.21.0
|
||||||
# Security and authentication
|
# Security and authentication
|
||||||
python-jose>=3.4.0
|
python-jose>=3.4.0
|
||||||
passlib>=1.7.4
|
passlib>=1.7.4
|
||||||
@@ -30,7 +31,7 @@ httpx>=0.27.0
|
|||||||
tenacity>=8.2.3
|
tenacity>=8.2.3
|
||||||
pytz>=2024.1
|
pytz>=2024.1
|
||||||
pillow>=10.3.0
|
pillow>=10.3.0
|
||||||
|
apscheduler==3.11.0
|
||||||
# Testing
|
# Testing
|
||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
pytest-asyncio>=0.23.5
|
pytest-asyncio>=0.23.5
|
||||||
@@ -42,3 +43,10 @@ black>=24.3.0
|
|||||||
isort>=5.13.2
|
isort>=5.13.2
|
||||||
flake8>=7.0.0
|
flake8>=7.0.0
|
||||||
mypy>=1.8.0
|
mypy>=1.8.0
|
||||||
|
|
||||||
|
# Security
|
||||||
|
python-jose==3.4.0
|
||||||
|
bcrypt==4.2.1
|
||||||
|
cryptography==44.0.1
|
||||||
|
passlib==1.7.4
|
||||||
|
freezegun~=1.5.1
|
||||||
0
backend/tests/api/routes/__init__.py
Normal file
0
backend/tests/api/routes/__init__.py
Normal file
369
backend/tests/api/routes/test_auth.py
Normal file
369
backend/tests/api/routes/test_auth.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
# tests/api/routes/test_auth.py
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import patch, MagicMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, Depends
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.routes.auth import router as auth_router
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
|
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the get_db dependency
|
||||||
|
@pytest.fixture
|
||||||
|
def override_get_db(db_session):
|
||||||
|
"""Override get_db dependency for testing."""
|
||||||
|
return db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(override_get_db):
|
||||||
|
"""Create a FastAPI test application with overridden dependencies."""
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# Override the get_db dependency
|
||||||
|
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""Create a FastAPI test client."""
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterUser:
|
||||||
|
"""Tests for the register_user endpoint."""
|
||||||
|
|
||||||
|
def test_register_user_success(self, client, monkeypatch, db_session):
|
||||||
|
"""Test successful user registration."""
|
||||||
|
# Mock the service method with a valid complete User object
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="newuser@example.com",
|
||||||
|
password_hash="hashed_password",
|
||||||
|
first_name="New",
|
||||||
|
last_name="User",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use patch for mocking
|
||||||
|
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "newuser@example.com",
|
||||||
|
"password": "Password123",
|
||||||
|
"first_name": "New",
|
||||||
|
"last_name": "User"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == "newuser@example.com"
|
||||||
|
assert data["first_name"] == "New"
|
||||||
|
assert data["last_name"] == "User"
|
||||||
|
assert "password" not in data
|
||||||
|
|
||||||
|
def test_register_user_duplicate_email(self, client, db_session):
|
||||||
|
"""Test registration with duplicate email."""
|
||||||
|
# Use patch for mocking with a side effect
|
||||||
|
with patch.object(AuthService, 'create_user',
|
||||||
|
side_effect=AuthenticationError("User with this email already exists")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "existing@example.com",
|
||||||
|
"password": "Password123",
|
||||||
|
"first_name": "Existing",
|
||||||
|
"last_name": "User"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert "already exists" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogin:
|
||||||
|
"""Tests for the login endpoint."""
|
||||||
|
|
||||||
|
def test_login_success(self, client, mock_user, db_session):
|
||||||
|
"""Test successful login."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Create mock tokens
|
||||||
|
mock_tokens = MagicMock(
|
||||||
|
access_token="mock_access_token",
|
||||||
|
refresh_token="mock_refresh_token",
|
||||||
|
token_type="bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use context managers for patching
|
||||||
|
with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \
|
||||||
|
patch.object(AuthService, 'create_tokens', return_value=mock_tokens):
|
||||||
|
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "user@example.com",
|
||||||
|
"password": "Password123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["access_token"] == "mock_access_token"
|
||||||
|
assert data["refresh_token"] == "mock_refresh_token"
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_invalid_credentials_debug(self, client, app):
|
||||||
|
"""Improved test for login with invalid credentials."""
|
||||||
|
# Print response for debugging
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
|
# Create a complete mock for AuthService
|
||||||
|
class MockAuthService:
|
||||||
|
@staticmethod
|
||||||
|
def authenticate_user(db, email, password):
|
||||||
|
print(f"Mock called with: {email}, {password}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Replace the entire class with our mock
|
||||||
|
original_service = AuthService
|
||||||
|
try:
|
||||||
|
# Replace with our mock
|
||||||
|
import sys
|
||||||
|
sys.modules['app.services.auth_service'].AuthService = MockAuthService
|
||||||
|
|
||||||
|
# Make the request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "user@example.com",
|
||||||
|
"password": "WrongPassword"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print response details for debugging
|
||||||
|
print(f"Response status: {response.status_code}")
|
||||||
|
print(f"Response body: {response.text}")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid email or password" in response.json()["detail"]
|
||||||
|
finally:
|
||||||
|
# Restore original service
|
||||||
|
sys.modules['app.services.auth_service'].AuthService = original_service
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_inactive_user(self, client, db_session):
|
||||||
|
"""Test login with inactive user."""
|
||||||
|
# Mock authentication to raise an error
|
||||||
|
with patch.object(AuthService, 'authenticate_user',
|
||||||
|
side_effect=AuthenticationError("User account is inactive")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "inactive@example.com",
|
||||||
|
"password": "Password123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "inactive" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshToken:
|
||||||
|
"""Tests for the refresh_token endpoint."""
|
||||||
|
|
||||||
|
def test_refresh_token_success(self, client, db_session):
|
||||||
|
"""Test successful token refresh."""
|
||||||
|
# Mock refresh to return tokens
|
||||||
|
mock_tokens = MagicMock(
|
||||||
|
access_token="new_access_token",
|
||||||
|
refresh_token="new_refresh_token",
|
||||||
|
token_type="bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "valid_refresh_token"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["access_token"] == "new_access_token"
|
||||||
|
assert data["refresh_token"] == "new_refresh_token"
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
def test_refresh_token_expired(self, client, db_session):
|
||||||
|
"""Test refresh with expired token."""
|
||||||
|
# Mock refresh to raise expired token error
|
||||||
|
with patch.object(AuthService, 'refresh_tokens',
|
||||||
|
side_effect=TokenExpiredError("Token expired")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "expired_refresh_token"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "expired" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_refresh_token_invalid(self, client, db_session):
|
||||||
|
"""Test refresh with invalid token."""
|
||||||
|
# Mock refresh to raise invalid token error
|
||||||
|
with patch.object(AuthService, 'refresh_tokens',
|
||||||
|
side_effect=TokenInvalidError("Invalid token")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "invalid_refresh_token"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestChangePassword:
|
||||||
|
"""Tests for the change_password endpoint."""
|
||||||
|
|
||||||
|
def test_change_password_success(self, client, mock_user, db_session, app):
|
||||||
|
"""Test successful password change."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Override get_current_user dependency
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
# Mock password change to return success
|
||||||
|
with patch.object(AuthService, 'change_password', return_value=True):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/change-password",
|
||||||
|
json={
|
||||||
|
"current_password": "OldPassword123",
|
||||||
|
"new_password": "NewPassword123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "success" in response.json()["message"].lower()
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
if get_current_user in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_current_user]
|
||||||
|
|
||||||
|
def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app):
|
||||||
|
"""Test change password with incorrect current password."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Override get_current_user dependency
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
# Mock password change to raise error
|
||||||
|
with patch.object(AuthService, 'change_password',
|
||||||
|
side_effect=AuthenticationError("Current password is incorrect")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/change-password",
|
||||||
|
json={
|
||||||
|
"current_password": "WrongPassword",
|
||||||
|
"new_password": "NewPassword123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "incorrect" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
if get_current_user in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_current_user]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentUserInfo:
|
||||||
|
"""Tests for the get_current_user_info endpoint."""
|
||||||
|
|
||||||
|
def test_get_current_user_info(self, client, mock_user, app):
|
||||||
|
"""Test getting current user info."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Override get_current_user dependency
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
# Test request
|
||||||
|
response = client.get("/auth/me")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == mock_user.email
|
||||||
|
assert data["first_name"] == mock_user.first_name
|
||||||
|
assert data["last_name"] == mock_user.last_name
|
||||||
|
assert "password" not in data
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
if get_current_user in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_current_user]
|
||||||
|
|
||||||
|
def test_get_current_user_info_unauthorized(self, client):
|
||||||
|
"""Test getting user info without authentication."""
|
||||||
|
# Test request without authentication
|
||||||
|
response = client.get("/auth/me")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
211
backend/tests/api/test_auth_dependencies.py
Normal file
211
backend/tests/api/test_auth_dependencies.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
# tests/api/dependencies/test_auth_dependencies.py
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import (
|
||||||
|
get_current_user,
|
||||||
|
get_current_active_user,
|
||||||
|
get_current_superuser,
|
||||||
|
get_optional_current_user
|
||||||
|
)
|
||||||
|
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_token():
|
||||||
|
return "mock.jwt.token"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentUser:
|
||||||
|
"""Tests for get_current_user dependency"""
|
||||||
|
|
||||||
|
def test_get_current_user_success(self, db_session, mock_user, mock_token):
|
||||||
|
"""Test successfully getting the current user"""
|
||||||
|
# Mock get_token_data to return user_id that matches our mock_user
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.return_value.user_id = mock_user.id
|
||||||
|
|
||||||
|
# Call the dependency
|
||||||
|
user = get_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
# Verify the correct user was returned
|
||||||
|
assert user.id == mock_user.id
|
||||||
|
assert user.email == mock_user.email
|
||||||
|
|
||||||
|
def test_get_current_user_nonexistent(self, db_session, mock_token):
|
||||||
|
"""Test when the token contains a user ID that doesn't exist"""
|
||||||
|
# Mock get_token_data to return a non-existent user ID
|
||||||
|
# Use a real UUID object instead of a string
|
||||||
|
import uuid
|
||||||
|
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
|
||||||
|
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.return_value.user_id = nonexistent_id # Using UUID object, not string
|
||||||
|
|
||||||
|
# Should raise HTTPException with 404 status
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 404
|
||||||
|
|
||||||
|
def test_get_current_user_inactive(self, db_session, mock_user, mock_token):
|
||||||
|
"""Test when the user is inactive"""
|
||||||
|
# Make the user inactive
|
||||||
|
mock_user.is_active = False
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Mock get_token_data
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.return_value.user_id = mock_user.id
|
||||||
|
|
||||||
|
# Should raise HTTPException with 403 status
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
def test_get_current_user_expired_token(self, db_session, mock_token):
|
||||||
|
"""Test with an expired token"""
|
||||||
|
# Mock get_token_data to raise TokenExpiredError
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||||
|
|
||||||
|
# Should raise HTTPException with 401 status
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "Token expired" in exc_info.value.detail
|
||||||
|
|
||||||
|
def test_get_current_user_invalid_token(self, db_session, mock_token):
|
||||||
|
"""Test with an invalid token"""
|
||||||
|
# Mock get_token_data to raise TokenInvalidError
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||||
|
|
||||||
|
# Should raise HTTPException with 401 status
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "Could not validate credentials" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentActiveUser:
|
||||||
|
"""Tests for get_current_active_user dependency"""
|
||||||
|
|
||||||
|
def test_get_current_active_user(self, mock_user):
|
||||||
|
"""Test getting an active user"""
|
||||||
|
# Ensure user is active
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
# Call the dependency with mocked current_user
|
||||||
|
user = get_current_active_user(current_user=mock_user)
|
||||||
|
|
||||||
|
# Should return the same user
|
||||||
|
assert user == mock_user
|
||||||
|
|
||||||
|
def test_get_current_inactive_user(self, mock_user):
|
||||||
|
"""Test getting an inactive user"""
|
||||||
|
# Make user inactive
|
||||||
|
mock_user.is_active = False
|
||||||
|
|
||||||
|
# Should raise HTTPException with 403 status
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_current_active_user(current_user=mock_user)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert "Inactive user" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentSuperuser:
|
||||||
|
"""Tests for get_current_superuser dependency"""
|
||||||
|
|
||||||
|
def test_get_current_superuser(self, mock_user):
|
||||||
|
"""Test getting a superuser"""
|
||||||
|
# Make user a superuser
|
||||||
|
mock_user.is_superuser = True
|
||||||
|
|
||||||
|
# Call the dependency with mocked current_user
|
||||||
|
user = get_current_superuser(current_user=mock_user)
|
||||||
|
|
||||||
|
# Should return the same user
|
||||||
|
assert user == mock_user
|
||||||
|
|
||||||
|
def test_get_current_non_superuser(self, mock_user):
|
||||||
|
"""Test getting a non-superuser"""
|
||||||
|
# Ensure user is not a superuser
|
||||||
|
mock_user.is_superuser = False
|
||||||
|
|
||||||
|
# Should raise HTTPException with 403 status
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_current_superuser(current_user=mock_user)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert "Not enough permissions" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOptionalCurrentUser:
|
||||||
|
"""Tests for get_optional_current_user dependency"""
|
||||||
|
|
||||||
|
def test_get_optional_current_user_with_token(self, db_session, mock_user, mock_token):
|
||||||
|
"""Test getting optional user with a valid token"""
|
||||||
|
# Mock get_token_data
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.return_value.user_id = mock_user.id
|
||||||
|
|
||||||
|
# Call the dependency
|
||||||
|
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
# Should return the correct user
|
||||||
|
assert user is not None
|
||||||
|
assert user.id == mock_user.id
|
||||||
|
|
||||||
|
def test_get_optional_current_user_no_token(self, db_session):
|
||||||
|
"""Test getting optional user with no token"""
|
||||||
|
# Call the dependency with no token
|
||||||
|
user = get_optional_current_user(db=db_session, token=None)
|
||||||
|
|
||||||
|
# Should return None
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_get_optional_current_user_invalid_token(self, db_session, mock_token):
|
||||||
|
"""Test getting optional user with an invalid token"""
|
||||||
|
# Mock get_token_data to raise TokenInvalidError
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||||
|
|
||||||
|
# Call the dependency
|
||||||
|
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
# Should return None, not raise an exception
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_get_optional_current_user_expired_token(self, db_session, mock_token):
|
||||||
|
"""Test getting optional user with an expired token"""
|
||||||
|
# Mock get_token_data to raise TokenExpiredError
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||||
|
|
||||||
|
# Call the dependency
|
||||||
|
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
# Should return None, not raise an exception
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_get_optional_current_user_inactive(self, db_session, mock_user, mock_token):
|
||||||
|
"""Test getting optional user when user is inactive"""
|
||||||
|
# Make the user inactive
|
||||||
|
mock_user.is_active = False
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Mock get_token_data
|
||||||
|
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||||
|
mock_get_data.return_value.user_id = mock_user.id
|
||||||
|
|
||||||
|
# Call the dependency
|
||||||
|
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||||
|
|
||||||
|
# Should return None for inactive users
|
||||||
|
assert user is None
|
||||||
66
backend/tests/conftest.py
Normal file
66
backend/tests/conftest.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
# tests/conftest.py
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.user import User
|
||||||
|
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def db_session():
|
||||||
|
"""
|
||||||
|
Creates a fresh SQLite in-memory database for each test function.
|
||||||
|
|
||||||
|
Yields a SQLAlchemy session that can be used for testing.
|
||||||
|
"""
|
||||||
|
# Set up the database
|
||||||
|
test_engine, TestingSessionLocal = setup_test_db()
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
with TestingSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
teardown_test_db(test_engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function") # Define a fixture
|
||||||
|
async def async_test_db():
|
||||||
|
"""Fixture provides new testing engine and session for each test run to improve isolation."""
|
||||||
|
|
||||||
|
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
||||||
|
yield test_engine, AsyncTestingSessionLocal
|
||||||
|
await teardown_async_test_db(test_engine)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_create_data():
|
||||||
|
return {
|
||||||
|
"email": "newtest@example.com", # Changed to avoid conflict with mock_user
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
"first_name": "Test",
|
||||||
|
"last_name": "User",
|
||||||
|
"phone_number": "+1234567890",
|
||||||
|
"is_superuser": False,
|
||||||
|
"preferences": None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user(db_session):
|
||||||
|
"""Fixture to create and return a mock User instance."""
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="mockuser@example.com",
|
||||||
|
password_hash="mockhashedpassword",
|
||||||
|
first_name="Mock",
|
||||||
|
last_name="User",
|
||||||
|
phone_number="1234567890",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
preferences=None,
|
||||||
|
)
|
||||||
|
db_session.add(mock_user)
|
||||||
|
db_session.commit()
|
||||||
|
return mock_user
|
||||||
0
backend/tests/core/__init__.py
Normal file
0
backend/tests/core/__init__.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
# tests/core/test_auth.py
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.core.auth import (
|
||||||
|
verify_password,
|
||||||
|
get_password_hash,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_token,
|
||||||
|
get_token_data,
|
||||||
|
TokenExpiredError,
|
||||||
|
TokenInvalidError,
|
||||||
|
TokenMissingClaimError
|
||||||
|
)
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestPasswordHandling:
|
||||||
|
"""Tests for password hashing and verification functions"""
|
||||||
|
|
||||||
|
def test_password_hash_different_from_password(self):
|
||||||
|
"""Test that a password hash is different from the original password"""
|
||||||
|
password = "TestPassword123"
|
||||||
|
hashed = get_password_hash(password)
|
||||||
|
assert hashed != password
|
||||||
|
|
||||||
|
def test_verify_correct_password(self):
|
||||||
|
"""Test that verify_password returns True for the correct password"""
|
||||||
|
password = "TestPassword123"
|
||||||
|
hashed = get_password_hash(password)
|
||||||
|
assert verify_password(password, hashed) is True
|
||||||
|
|
||||||
|
def test_verify_incorrect_password(self):
|
||||||
|
"""Test that verify_password returns False for an incorrect password"""
|
||||||
|
password = "TestPassword123"
|
||||||
|
wrong_password = "WrongPassword123"
|
||||||
|
hashed = get_password_hash(password)
|
||||||
|
assert verify_password(wrong_password, hashed) is False
|
||||||
|
|
||||||
|
def test_same_password_different_hash(self):
|
||||||
|
"""Test that the same password gets a different hash each time"""
|
||||||
|
password = "TestPassword123"
|
||||||
|
hash1 = get_password_hash(password)
|
||||||
|
hash2 = get_password_hash(password)
|
||||||
|
assert hash1 != hash2
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenCreation:
|
||||||
|
"""Tests for token creation functions"""
|
||||||
|
|
||||||
|
def test_create_access_token(self):
|
||||||
|
"""Test that an access token is created with the correct claims"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
custom_claims = {
|
||||||
|
"email": "test@example.com",
|
||||||
|
"first_name": "Test",
|
||||||
|
"is_superuser": True
|
||||||
|
}
|
||||||
|
token = create_access_token(subject=user_id, claims=custom_claims)
|
||||||
|
|
||||||
|
# Decode token to verify claims
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check standard claims
|
||||||
|
assert payload["sub"] == user_id
|
||||||
|
assert "jti" in payload
|
||||||
|
assert "exp" in payload
|
||||||
|
assert "iat" in payload
|
||||||
|
assert payload["type"] == "access"
|
||||||
|
|
||||||
|
# Check custom claims
|
||||||
|
for key, value in custom_claims.items():
|
||||||
|
assert payload[key] == value
|
||||||
|
|
||||||
|
def test_create_refresh_token(self):
|
||||||
|
"""Test that a refresh token is created with the correct claims"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
token = create_refresh_token(subject=user_id)
|
||||||
|
|
||||||
|
# Decode token to verify claims
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check standard claims
|
||||||
|
assert payload["sub"] == user_id
|
||||||
|
assert "jti" in payload
|
||||||
|
assert "exp" in payload
|
||||||
|
assert "iat" in payload
|
||||||
|
assert payload["type"] == "refresh"
|
||||||
|
|
||||||
|
def test_token_expiration(self):
|
||||||
|
"""Test that tokens have the correct expiration time"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
expires = timedelta(minutes=5)
|
||||||
|
|
||||||
|
# Create token with specific expiration
|
||||||
|
token = create_access_token(
|
||||||
|
subject=user_id,
|
||||||
|
expires_delta=expires
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode token
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get actual expiration time from token
|
||||||
|
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||||
|
|
||||||
|
# Calculate expected expiration (approximately)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expected_expiration = now + expires
|
||||||
|
|
||||||
|
# Difference should be small (less than 1 second)
|
||||||
|
difference = abs((expiration - expected_expiration).total_seconds())
|
||||||
|
assert difference < 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenDecoding:
|
||||||
|
"""Tests for token decoding and validation functions"""
|
||||||
|
|
||||||
|
def test_decode_valid_token(self):
|
||||||
|
"""Test that a valid token can be decoded"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
token = create_access_token(subject=user_id)
|
||||||
|
|
||||||
|
# Decode token
|
||||||
|
payload = decode_token(token)
|
||||||
|
|
||||||
|
# Check that the subject matches
|
||||||
|
assert payload.sub == user_id
|
||||||
|
|
||||||
|
def test_decode_expired_token(self):
|
||||||
|
"""Test that an expired token raises TokenExpiredError"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create a token that's already expired by directly manipulating the payload
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expired_time = now - timedelta(hours=1) # 1 hour in the past
|
||||||
|
|
||||||
|
# Create the expired token manually
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"exp": int(expired_time.timestamp()), # Set expiration in the past
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"jti": str(uuid.uuid4()),
|
||||||
|
"type": "access"
|
||||||
|
}
|
||||||
|
|
||||||
|
expired_token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attempting to decode should raise TokenExpiredError
|
||||||
|
with pytest.raises(TokenExpiredError):
|
||||||
|
decode_token(expired_token)
|
||||||
|
|
||||||
|
def test_decode_invalid_token(self):
|
||||||
|
"""Test that an invalid token raises TokenInvalidError"""
|
||||||
|
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
|
||||||
|
|
||||||
|
with pytest.raises(TokenInvalidError):
|
||||||
|
decode_token(invalid_token)
|
||||||
|
|
||||||
|
def test_decode_token_with_missing_sub(self):
|
||||||
|
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
|
||||||
|
# Create a token without a subject
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"jti": str(uuid.uuid4()),
|
||||||
|
"type": "access"
|
||||||
|
# No 'sub' claim
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TokenMissingClaimError):
|
||||||
|
decode_token(token)
|
||||||
|
|
||||||
|
def test_decode_token_with_wrong_type(self):
|
||||||
|
"""Test that verifying a token with wrong type raises TokenInvalidError"""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
token = create_access_token(subject=user_id)
|
||||||
|
|
||||||
|
# Try to verify it as a refresh token
|
||||||
|
with pytest.raises(TokenInvalidError):
|
||||||
|
decode_token(token, verify_type="refresh")
|
||||||
|
|
||||||
|
def test_decode_with_invalid_payload(self):
|
||||||
|
"""Test that a token with invalid payload structure raises TokenInvalidError"""
|
||||||
|
# Create a token with an invalid payload structure - missing 'sub' which is required
|
||||||
|
# but including 'exp' to avoid the expiration check
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
# Missing "sub" field which is required
|
||||||
|
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"jti": str(uuid.uuid4()),
|
||||||
|
"invalid_field": "test"
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should raise TokenMissingClaimError due to missing 'sub'
|
||||||
|
with pytest.raises(TokenMissingClaimError):
|
||||||
|
decode_token(token)
|
||||||
|
|
||||||
|
# Create another token with invalid type for required field
|
||||||
|
payload = {
|
||||||
|
"sub": 123, # sub should be a string, not an integer
|
||||||
|
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should raise TokenInvalidError due to ValidationError
|
||||||
|
with pytest.raises(TokenInvalidError):
|
||||||
|
decode_token(token)
|
||||||
|
|
||||||
|
def test_get_token_data(self):
|
||||||
|
"""Test extracting TokenData from a token"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token = create_access_token(
|
||||||
|
subject=str(user_id),
|
||||||
|
claims={"is_superuser": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
token_data = get_token_data(token)
|
||||||
|
|
||||||
|
assert token_data.user_id == user_id
|
||||||
|
assert token_data.is_superuser is True
|
||||||
0
backend/tests/crud/__init__.py
Normal file
0
backend/tests/crud/__init__.py
Normal file
125
backend/tests/crud/test_user.py
Normal file
125
backend/tests/crud/test_user.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.crud.user import user as user_crud
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user(db_session, user_create_data):
|
||||||
|
user_in = UserCreate(**user_create_data)
|
||||||
|
user_obj = user_crud.create(db_session, obj_in=user_in)
|
||||||
|
|
||||||
|
assert user_obj.email == user_create_data["email"]
|
||||||
|
assert user_obj.first_name == user_create_data["first_name"]
|
||||||
|
assert user_obj.last_name == user_create_data["last_name"]
|
||||||
|
assert user_obj.phone_number == user_create_data["phone_number"]
|
||||||
|
assert user_obj.is_superuser == user_create_data["is_superuser"]
|
||||||
|
assert user_obj.password_hash is not None
|
||||||
|
assert user_obj.id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user(db_session, mock_user):
|
||||||
|
# Using mock_user fixture instead of creating new user
|
||||||
|
stored_user = user_crud.get(db_session, id=mock_user.id)
|
||||||
|
assert stored_user
|
||||||
|
assert stored_user.id == mock_user.id
|
||||||
|
assert stored_user.email == mock_user.email
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_by_email(db_session, mock_user):
|
||||||
|
stored_user = user_crud.get_by_email(db_session, email=mock_user.email)
|
||||||
|
assert stored_user
|
||||||
|
assert stored_user.id == mock_user.id
|
||||||
|
assert stored_user.email == mock_user.email
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_user(db_session, mock_user):
|
||||||
|
update_data = UserUpdate(
|
||||||
|
first_name="Updated",
|
||||||
|
last_name="Name",
|
||||||
|
phone_number="+9876543210"
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||||
|
|
||||||
|
assert updated_user.first_name == "Updated"
|
||||||
|
assert updated_user.last_name == "Name"
|
||||||
|
assert updated_user.phone_number == "+9876543210"
|
||||||
|
assert updated_user.email == mock_user.email
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_user(db_session, mock_user):
|
||||||
|
user_crud.remove(db_session, id=mock_user.id)
|
||||||
|
deleted_user = user_crud.get(db_session, id=mock_user.id)
|
||||||
|
assert deleted_user is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_multi_users(db_session, mock_user, user_create_data):
|
||||||
|
# Create additional users (mock_user is already in db)
|
||||||
|
users_data = [
|
||||||
|
{**user_create_data, "email": f"test{i}@example.com"}
|
||||||
|
for i in range(2) # Creating 2 more users + mock_user = 3 total
|
||||||
|
]
|
||||||
|
|
||||||
|
for user_data in users_data:
|
||||||
|
user_in = UserCreate(**user_data)
|
||||||
|
user_crud.create(db_session, obj_in=user_in)
|
||||||
|
|
||||||
|
users = user_crud.get_multi(db_session, skip=0, limit=10)
|
||||||
|
assert len(users) == 3
|
||||||
|
assert all(isinstance(user, User) for user in users)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_active(db_session, mock_user):
|
||||||
|
assert user_crud.is_active(mock_user) is True
|
||||||
|
|
||||||
|
# Test deactivating user
|
||||||
|
update_data = UserUpdate(is_active=False)
|
||||||
|
deactivated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||||
|
assert user_crud.is_active(deactivated_user) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_superuser(db_session, mock_user, user_create_data):
|
||||||
|
# mock_user is regular user
|
||||||
|
assert user_crud.is_superuser(mock_user) is False
|
||||||
|
|
||||||
|
# Create superuser
|
||||||
|
super_user_data = {**user_create_data, "email": "super@example.com", "is_superuser": True}
|
||||||
|
super_user_in = UserCreate(**super_user_data)
|
||||||
|
super_user = user_crud.create(db_session, obj_in=super_user_in)
|
||||||
|
assert user_crud.is_superuser(super_user) is True
|
||||||
|
|
||||||
|
|
||||||
|
# Additional test cases
|
||||||
|
def test_create_duplicate_email(db_session, mock_user):
|
||||||
|
user_data = UserCreate(
|
||||||
|
email=mock_user.email, # Try to create user with existing email
|
||||||
|
password="TestPassword123!",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User"
|
||||||
|
)
|
||||||
|
with pytest.raises(Exception): # Should raise an integrity error
|
||||||
|
user_crud.create(db_session, obj_in=user_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_user_preferences(db_session, mock_user):
|
||||||
|
preferences = {"theme": "dark", "notifications": True}
|
||||||
|
update_data = UserUpdate(preferences=preferences)
|
||||||
|
|
||||||
|
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||||
|
assert updated_user.preferences == preferences
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_multi_users_pagination(db_session, user_create_data):
|
||||||
|
# Create 5 users
|
||||||
|
for i in range(5):
|
||||||
|
user_in = UserCreate(**{**user_create_data, "email": f"test{i}@example.com"})
|
||||||
|
user_crud.create(db_session, obj_in=user_in)
|
||||||
|
|
||||||
|
# Test pagination
|
||||||
|
first_page = user_crud.get_multi(db_session, skip=0, limit=2)
|
||||||
|
second_page = user_crud.get_multi(db_session, skip=2, limit=2)
|
||||||
|
|
||||||
|
assert len(first_page) == 2
|
||||||
|
assert len(second_page) == 2
|
||||||
|
assert first_page[0].id != second_page[0].id
|
||||||
0
backend/tests/models/__init__.py
Normal file
0
backend/tests/models/__init__.py
Normal file
249
backend/tests/models/test_user.py
Normal file
249
backend/tests/models/test_user.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
# tests/models/test_user.py
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user(db_session):
|
||||||
|
"""Test creating a basic user."""
|
||||||
|
# Arrange
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
new_user = User(
|
||||||
|
id=user_id,
|
||||||
|
email="test@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
phone_number="1234567890",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
preferences={"theme": "dark"},
|
||||||
|
)
|
||||||
|
db_session.add(new_user)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.commit()
|
||||||
|
created_user = db_session.query(User).filter_by(email="test@example.com").first()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert created_user is not None
|
||||||
|
assert created_user.email == "test@example.com"
|
||||||
|
assert created_user.first_name == "Test"
|
||||||
|
assert created_user.last_name == "User"
|
||||||
|
assert created_user.phone_number == "1234567890"
|
||||||
|
assert created_user.is_active is True
|
||||||
|
assert created_user.is_superuser is False
|
||||||
|
assert created_user.preferences == {"theme": "dark"}
|
||||||
|
# UUID should be preserved
|
||||||
|
assert created_user.id == user_id
|
||||||
|
# Timestamps should be set
|
||||||
|
assert isinstance(created_user.created_at, datetime)
|
||||||
|
assert isinstance(created_user.updated_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_user(db_session):
|
||||||
|
"""Test updating an existing user."""
|
||||||
|
# Arrange - Create a user
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
user = User(
|
||||||
|
id=user_id,
|
||||||
|
email="update@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="Before",
|
||||||
|
last_name="Update",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Record the original creation timestamp
|
||||||
|
original_created_at = user.created_at
|
||||||
|
|
||||||
|
# Act - Update the user
|
||||||
|
user.first_name = "After"
|
||||||
|
user.last_name = "Updated"
|
||||||
|
user.phone_number = "9876543210"
|
||||||
|
user.preferences = {"theme": "light", "notifications": True}
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Fetch the updated user to verify changes were persisted
|
||||||
|
updated_user = db_session.query(User).filter_by(id=user_id).first()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert updated_user.first_name == "After"
|
||||||
|
assert updated_user.last_name == "Updated"
|
||||||
|
assert updated_user.phone_number == "9876543210"
|
||||||
|
assert updated_user.preferences == {"theme": "light", "notifications": True}
|
||||||
|
# created_at shouldn't change on update
|
||||||
|
assert updated_user.created_at == original_created_at
|
||||||
|
# updated_at should be newer than created_at
|
||||||
|
assert updated_user.updated_at > original_created_at
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_user(db_session):
|
||||||
|
"""Test deleting a user."""
|
||||||
|
# Arrange - Create a user
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
user = User(
|
||||||
|
id=user_id,
|
||||||
|
email="delete@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="Delete",
|
||||||
|
last_name="Me",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Act - Delete the user
|
||||||
|
db_session.delete(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
deleted_user = db_session.query(User).filter_by(id=user_id).first()
|
||||||
|
assert deleted_user is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_unique_email_constraint(db_session):
|
||||||
|
"""Test that users cannot have duplicate emails."""
|
||||||
|
# Arrange - Create a user
|
||||||
|
user1 = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="duplicate@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="First",
|
||||||
|
last_name="User",
|
||||||
|
)
|
||||||
|
db_session.add(user1)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Act & Assert - Try to create another user with the same email
|
||||||
|
user2 = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="duplicate@example.com", # Same email
|
||||||
|
password_hash="differenthash",
|
||||||
|
first_name="Second",
|
||||||
|
last_name="User",
|
||||||
|
)
|
||||||
|
db_session.add(user2)
|
||||||
|
|
||||||
|
# Should raise IntegrityError due to unique constraint
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Rollback for cleanup
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_required_fields(db_session):
|
||||||
|
"""Test that required fields are enforced."""
|
||||||
|
# Test each required field by creating a user without it
|
||||||
|
|
||||||
|
# Missing email
|
||||||
|
user_no_email = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
# email is missing
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
)
|
||||||
|
db_session.add(user_no_email)
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
db_session.commit()
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
# Missing password_hash
|
||||||
|
user_no_password = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="nopassword@example.com",
|
||||||
|
# password_hash is missing
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
)
|
||||||
|
db_session.add(user_no_password)
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
db_session.commit()
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_defaults(db_session):
|
||||||
|
"""Test that default values are correctly set."""
|
||||||
|
# Arrange - Create a minimal user with only required fields
|
||||||
|
minimal_user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="minimal@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="Minimal",
|
||||||
|
last_name="User",
|
||||||
|
)
|
||||||
|
db_session.add(minimal_user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Act - Retrieve the user
|
||||||
|
created_user = db_session.query(User).filter_by(email="minimal@example.com").first()
|
||||||
|
|
||||||
|
# Assert - Check default values
|
||||||
|
assert created_user.is_active is True # Default should be True
|
||||||
|
assert created_user.is_superuser is False # Default should be False
|
||||||
|
assert created_user.phone_number is None # Optional field
|
||||||
|
assert created_user.preferences is None # Optional field
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_string_representation(db_session):
|
||||||
|
"""Test the string representation of a user."""
|
||||||
|
# Arrange
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="repr@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="String",
|
||||||
|
last_name="Repr",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
assert str(user) == "<User repr@example.com>"
|
||||||
|
assert repr(user) == "<User repr@example.com>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_with_complex_json_preferences(db_session):
|
||||||
|
"""Test storing and retrieving complex JSON preferences."""
|
||||||
|
# Arrange - Create a user with nested JSON preferences
|
||||||
|
complex_preferences = {
|
||||||
|
"theme": {
|
||||||
|
"mode": "dark",
|
||||||
|
"colors": {
|
||||||
|
"primary": "#333",
|
||||||
|
"secondary": "#666"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"notifications": {
|
||||||
|
"email": True,
|
||||||
|
"sms": False,
|
||||||
|
"push": {
|
||||||
|
"enabled": True,
|
||||||
|
"quiet_hours": [22, 7]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": ["important", "family", "events"]
|
||||||
|
}
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="complex@example.com",
|
||||||
|
password_hash="hashedpassword",
|
||||||
|
first_name="Complex",
|
||||||
|
last_name="JSON",
|
||||||
|
preferences=complex_preferences
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Act - Retrieve the user
|
||||||
|
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
|
||||||
|
|
||||||
|
# Assert - The complex JSON should be preserved
|
||||||
|
assert retrieved_user.preferences == complex_preferences
|
||||||
|
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
|
||||||
|
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
|
||||||
|
assert "important" in retrieved_user.preferences["tags"]
|
||||||
0
backend/tests/schemas/__init__.py
Normal file
0
backend/tests/schemas/__init__.py
Normal file
127
backend/tests/schemas/test_user_schemas.py
Normal file
127
backend/tests/schemas/test_user_schemas.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
# tests/schemas/test_user_schemas.py
|
||||||
|
import pytest
|
||||||
|
import re
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.schemas.users import UserBase, UserCreate
|
||||||
|
|
||||||
|
class TestPhoneNumberValidation:
|
||||||
|
"""Tests for phone number validation in user schemas"""
|
||||||
|
|
||||||
|
def test_valid_swiss_numbers(self):
|
||||||
|
"""Test valid Swiss phone numbers are accepted"""
|
||||||
|
# International format
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
|
||||||
|
assert user.phone_number == "+41791234567"
|
||||||
|
|
||||||
|
# Local format
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
|
||||||
|
assert user.phone_number == "0791234567"
|
||||||
|
|
||||||
|
# With formatting characters
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||||
|
|
||||||
|
def test_valid_italian_numbers(self):
|
||||||
|
"""Test valid Italian phone numbers are accepted"""
|
||||||
|
# International format
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
|
||||||
|
assert user.phone_number == "+393451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
|
||||||
|
assert user.phone_number == "+39345123456"
|
||||||
|
|
||||||
|
# Local format
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
|
||||||
|
assert user.phone_number == "03451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
|
||||||
|
assert user.phone_number == "0345123456789"
|
||||||
|
|
||||||
|
# With formatting characters
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||||
|
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
|
||||||
|
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||||
|
|
||||||
|
def test_none_phone_number(self):
|
||||||
|
"""Test that None is accepted as a valid value (optional phone number)"""
|
||||||
|
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
|
||||||
|
assert user.phone_number is None
|
||||||
|
|
||||||
|
def test_invalid_phone_numbers(self):
|
||||||
|
"""Test that invalid phone numbers are rejected"""
|
||||||
|
invalid_numbers = [
|
||||||
|
# Too short
|
||||||
|
"+12",
|
||||||
|
"012",
|
||||||
|
|
||||||
|
# Invalid characters
|
||||||
|
"+41xyz123456",
|
||||||
|
"079abc4567",
|
||||||
|
"123-abc-7890",
|
||||||
|
"+1(800)CALL-NOW",
|
||||||
|
|
||||||
|
# Completely invalid formats
|
||||||
|
"++4412345678", # Double plus
|
||||||
|
"()+41123456", # Misplaced parentheses
|
||||||
|
|
||||||
|
# Empty string
|
||||||
|
"",
|
||||||
|
# Spaces only
|
||||||
|
" ",
|
||||||
|
]
|
||||||
|
|
||||||
|
for number in invalid_numbers:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
|
||||||
|
|
||||||
|
def test_phone_validation_in_user_create(self):
|
||||||
|
"""Test that phone validation also works in UserCreate schema"""
|
||||||
|
# Valid phone number
|
||||||
|
user = UserCreate(
|
||||||
|
email="test@example.com",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
password="Password123",
|
||||||
|
phone_number="+41791234567"
|
||||||
|
)
|
||||||
|
assert user.phone_number == "+41791234567"
|
||||||
|
|
||||||
|
# Invalid phone number should raise ValidationError
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserCreate(
|
||||||
|
email="test@example.com",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
password="Password123",
|
||||||
|
phone_number="invalid-number"
|
||||||
|
)
|
||||||
0
backend/tests/services/__init__.py
Normal file
0
backend/tests/services/__init__.py
Normal file
252
backend/tests/services/test_auth_service.py
Normal file
252
backend/tests/services/test_auth_service.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
# tests/services/test_auth_service.py
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import UserCreate, Token
|
||||||
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthServiceAuthentication:
|
||||||
|
"""Tests for AuthService.authenticate_user method"""
|
||||||
|
|
||||||
|
def test_authenticate_valid_user(self, db_session, mock_user):
|
||||||
|
"""Test authenticating a user with valid credentials"""
|
||||||
|
# Set a known password for the mock user
|
||||||
|
password = "TestPassword123"
|
||||||
|
mock_user.password_hash = get_password_hash(password)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Authenticate with correct credentials
|
||||||
|
user = AuthService.authenticate_user(
|
||||||
|
db=db_session,
|
||||||
|
email=mock_user.email,
|
||||||
|
password=password
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user is not None
|
||||||
|
assert user.id == mock_user.id
|
||||||
|
assert user.email == mock_user.email
|
||||||
|
|
||||||
|
def test_authenticate_nonexistent_user(self, db_session):
|
||||||
|
"""Test authenticating with an email that doesn't exist"""
|
||||||
|
user = AuthService.authenticate_user(
|
||||||
|
db=db_session,
|
||||||
|
email="nonexistent@example.com",
|
||||||
|
password="password"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_authenticate_with_wrong_password(self, db_session, mock_user):
|
||||||
|
"""Test authenticating with the wrong password"""
|
||||||
|
# Set a known password for the mock user
|
||||||
|
password = "TestPassword123"
|
||||||
|
mock_user.password_hash = get_password_hash(password)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Authenticate with wrong password
|
||||||
|
user = AuthService.authenticate_user(
|
||||||
|
db=db_session,
|
||||||
|
email=mock_user.email,
|
||||||
|
password="WrongPassword123"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
def test_authenticate_inactive_user(self, db_session, mock_user):
|
||||||
|
"""Test authenticating an inactive user"""
|
||||||
|
# Set a known password and make user inactive
|
||||||
|
password = "TestPassword123"
|
||||||
|
mock_user.password_hash = get_password_hash(password)
|
||||||
|
mock_user.is_active = False
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Should raise AuthenticationError
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
AuthService.authenticate_user(
|
||||||
|
db=db_session,
|
||||||
|
email=mock_user.email,
|
||||||
|
password=password
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthServiceUserCreation:
|
||||||
|
"""Tests for AuthService.create_user method"""
|
||||||
|
|
||||||
|
def test_create_new_user(self, db_session):
|
||||||
|
"""Test creating a new user"""
|
||||||
|
user_data = UserCreate(
|
||||||
|
email="newuser@example.com",
|
||||||
|
password="TestPassword123",
|
||||||
|
first_name="New",
|
||||||
|
last_name="User",
|
||||||
|
phone_number="1234567890"
|
||||||
|
)
|
||||||
|
|
||||||
|
user = AuthService.create_user(db=db_session, user_data=user_data)
|
||||||
|
|
||||||
|
# Verify user was created with correct data
|
||||||
|
assert user is not None
|
||||||
|
assert user.email == user_data.email
|
||||||
|
assert user.first_name == user_data.first_name
|
||||||
|
assert user.last_name == user_data.last_name
|
||||||
|
assert user.phone_number == user_data.phone_number
|
||||||
|
|
||||||
|
# Verify password was hashed
|
||||||
|
assert user.password_hash != user_data.password
|
||||||
|
assert verify_password(user_data.password, user.password_hash)
|
||||||
|
|
||||||
|
# Verify default values
|
||||||
|
assert user.is_active is True
|
||||||
|
assert user.is_superuser is False
|
||||||
|
|
||||||
|
def test_create_user_with_existing_email(self, db_session, mock_user):
|
||||||
|
"""Test creating a user with an email that already exists"""
|
||||||
|
user_data = UserCreate(
|
||||||
|
email=mock_user.email, # Use existing email
|
||||||
|
password="TestPassword123",
|
||||||
|
first_name="Duplicate",
|
||||||
|
last_name="User"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should raise AuthenticationError
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
AuthService.create_user(db=db_session, user_data=user_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthServiceTokens:
|
||||||
|
"""Tests for AuthService token-related methods"""
|
||||||
|
|
||||||
|
def test_create_tokens(self, mock_user):
|
||||||
|
"""Test creating access and refresh tokens for a user"""
|
||||||
|
tokens = AuthService.create_tokens(mock_user)
|
||||||
|
|
||||||
|
# Verify token structure
|
||||||
|
assert isinstance(tokens, Token)
|
||||||
|
assert tokens.access_token is not None
|
||||||
|
assert tokens.refresh_token is not None
|
||||||
|
assert tokens.token_type == "bearer"
|
||||||
|
|
||||||
|
# This is a more in-depth test that would decode the tokens to verify claims
|
||||||
|
# but we'll rely on the auth module tests for token verification
|
||||||
|
|
||||||
|
def test_refresh_tokens(self, db_session, mock_user):
|
||||||
|
"""Test refreshing tokens with a valid refresh token"""
|
||||||
|
# Create initial tokens
|
||||||
|
initial_tokens = AuthService.create_tokens(mock_user)
|
||||||
|
|
||||||
|
# Refresh tokens
|
||||||
|
new_tokens = AuthService.refresh_tokens(
|
||||||
|
db=db_session,
|
||||||
|
refresh_token=initial_tokens.refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify new tokens are different from old ones
|
||||||
|
assert new_tokens.access_token != initial_tokens.access_token
|
||||||
|
assert new_tokens.refresh_token != initial_tokens.refresh_token
|
||||||
|
|
||||||
|
def test_refresh_tokens_with_invalid_token(self, db_session):
|
||||||
|
"""Test refreshing tokens with an invalid token"""
|
||||||
|
# Create an invalid token
|
||||||
|
invalid_token = "invalid.token.string"
|
||||||
|
|
||||||
|
# Should raise TokenInvalidError
|
||||||
|
with pytest.raises(TokenInvalidError):
|
||||||
|
AuthService.refresh_tokens(
|
||||||
|
db=db_session,
|
||||||
|
refresh_token=invalid_token
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_refresh_tokens_with_access_token(self, db_session, mock_user):
|
||||||
|
"""Test refreshing tokens with an access token instead of refresh token"""
|
||||||
|
# Create tokens
|
||||||
|
tokens = AuthService.create_tokens(mock_user)
|
||||||
|
|
||||||
|
# Try to refresh with access token
|
||||||
|
with pytest.raises(TokenInvalidError):
|
||||||
|
AuthService.refresh_tokens(
|
||||||
|
db=db_session,
|
||||||
|
refresh_token=tokens.access_token
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_refresh_tokens_with_nonexistent_user(self, db_session):
|
||||||
|
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
||||||
|
# Create a token for a non-existent user
|
||||||
|
non_existent_id = str(uuid.uuid4())
|
||||||
|
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
||||||
|
# Mock the token data to return a non-existent user ID
|
||||||
|
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
||||||
|
|
||||||
|
# Should raise TokenInvalidError
|
||||||
|
with pytest.raises(TokenInvalidError):
|
||||||
|
AuthService.refresh_tokens(
|
||||||
|
db=db_session,
|
||||||
|
refresh_token="some.refresh.token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthServicePasswordChange:
|
||||||
|
"""Tests for AuthService.change_password method"""
|
||||||
|
|
||||||
|
def test_change_password(self, db_session, mock_user):
|
||||||
|
"""Test changing a user's password"""
|
||||||
|
# Set a known password for the mock user
|
||||||
|
current_password = "CurrentPassword123"
|
||||||
|
mock_user.password_hash = get_password_hash(current_password)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Change password
|
||||||
|
new_password = "NewPassword456"
|
||||||
|
result = AuthService.change_password(
|
||||||
|
db=db_session,
|
||||||
|
user_id=mock_user.id,
|
||||||
|
current_password=current_password,
|
||||||
|
new_password=new_password
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify operation was successful
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Refresh user from DB
|
||||||
|
db_session.refresh(mock_user)
|
||||||
|
|
||||||
|
# Verify old password no longer works
|
||||||
|
assert not verify_password(current_password, mock_user.password_hash)
|
||||||
|
|
||||||
|
# Verify new password works
|
||||||
|
assert verify_password(new_password, mock_user.password_hash)
|
||||||
|
|
||||||
|
def test_change_password_wrong_current_password(self, db_session, mock_user):
|
||||||
|
"""Test changing password with incorrect current password"""
|
||||||
|
# Set a known password for the mock user
|
||||||
|
current_password = "CurrentPassword123"
|
||||||
|
mock_user.password_hash = get_password_hash(current_password)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Try to change password with wrong current password
|
||||||
|
wrong_password = "WrongPassword123"
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
AuthService.change_password(
|
||||||
|
db=db_session,
|
||||||
|
user_id=mock_user.id,
|
||||||
|
current_password=wrong_password,
|
||||||
|
new_password="NewPassword456"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify password was not changed
|
||||||
|
assert verify_password(current_password, mock_user.password_hash)
|
||||||
|
|
||||||
|
def test_change_password_nonexistent_user(self, db_session):
|
||||||
|
"""Test changing password for a user that doesn't exist"""
|
||||||
|
non_existent_id = uuid.uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
AuthService.change_password(
|
||||||
|
db=db_session,
|
||||||
|
user_id=non_existent_id,
|
||||||
|
current_password="CurrentPassword123",
|
||||||
|
new_password="NewPassword456"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user