Add comprehensive test suite and utilities for user functionality

This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
2025-03-04 19:10:54 +01:00
parent 481b6d618e
commit 162e586e13
40 changed files with 2948 additions and 11 deletions

View File

@@ -17,7 +17,7 @@ ENVIRONMENT=development
DEBUG=true
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
FIRST_SUPERUSER_EMAIL=admin@example.com
FIRST_SUPERUSER_PASSWORD=admin123
FIRST_SUPERUSER_PASSWORD=Admin123
# Frontend settings
FRONTEND_PORT=3000

View File

@@ -0,0 +1,46 @@
"""Add all initial models
Revision ID: 38bf9e7e74b3
Revises: 7396957cbe80
Create Date: 2025-02-28 09:19:33.212278
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '38bf9e7e74b3'
down_revision: Union[str, None] = '7396957cbe80'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table('users',
sa.Column('email', sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False),
sa.Column('first_name', sa.String(), nullable=False),
sa.Column('last_name', sa.String(), nullable=True),
sa.Column('phone_number', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('preferences', sa.JSON(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
# ### end Alembic commands ###

View File

View File

View File

@@ -0,0 +1,137 @@
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.models.user import User
# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> User:
"""
Get the current authenticated user.
Args:
db: Database session
token: JWT token from request
Returns:
User: The authenticated user
Raises:
HTTPException: If authentication fails
"""
try:
# Decode token and get user ID
token_data = get_token_data(token)
# Get user from database
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return user
except TokenExpiredError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"}
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
)
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is active.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated and active user
Raises:
HTTPException: If user is inactive
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is a superuser.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated superuser
Raises:
HTTPException: If user is not a superuser
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
)
return current_user
def get_optional_current_user(
db: Session = Depends(get_db),
token: Optional[str] = Depends(oauth2_scheme)
) -> Optional[User]:
"""
Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users.
Args:
db: Database session
token: JWT token from request
Returns:
User or None: The authenticated user or None
"""
if not token:
return None
try:
token_data = get_token_data(token)
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user or not user.is_active:
return None
return user
except (TokenExpiredError, TokenInvalidError):
return None

6
backend/app/api/main.py Normal file
View File

@@ -0,0 +1,6 @@
from fastapi import APIRouter
from app.api.routes import auth
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])

View File

View File

@@ -0,0 +1,231 @@
# app/api/routes/auth.py
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Body
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.models.user import User
from app.schemas.users import (
UserCreate,
UserResponse,
Token,
LoginRequest,
RefreshTokenRequest
)
from app.services.auth_service import AuthService, AuthenticationError
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register_user(
user_data: UserCreate,
db: Session = Depends(get_db)
) -> Any:
"""
Register a new user.
Returns:
The created user information.
"""
try:
user = AuthService.create_user(db, user_data)
return user
except AuthenticationError as e:
logger.warning(f"Registration failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/login", response_model=Token)
async def login(
login_data: LoginRequest,
db: Session = Depends(get_db)
) -> Any:
"""
Login with username and password.
Returns:
Access and refresh tokens.
"""
try:
# Attempt to authenticate the user
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
# Explicitly check for None result and raise correct exception
if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
)
# User is authenticated, generate tokens
tokens = AuthService.create_tokens(user)
logger.info(f"User login successful: {user.email}")
return tokens
except HTTPException:
# Re-raise HTTP exceptions without modification
raise
except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error during login: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/login/oauth", response_model=Token)
async def login_oauth(
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
) -> Any:
"""
OAuth2-compatible login endpoint, used by the OpenAPI UI.
Returns:
Access and refresh tokens.
"""
try:
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Generate tokens
tokens = AuthService.create_tokens(user)
# Format response for OAuth compatibility
return {
"access_token": tokens.access_token,
"refresh_token": tokens.refresh_token,
"token_type": tokens.token_type
}
except HTTPException:
raise
except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
except Exception as e:
logger.error(f"Unexpected error during OAuth login: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/refresh", response_model=Token)
async def refresh_token(
refresh_data: RefreshTokenRequest,
db: Session = Depends(get_db)
) -> Any:
"""
Refresh access token using a refresh token.
Returns:
New access and refresh tokens.
"""
try:
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
return tokens
except TokenExpiredError:
logger.warning("Token refresh failed: Token expired")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has expired. Please log in again.",
headers={"WWW-Authenticate": "Bearer"},
)
except TokenInvalidError:
logger.warning("Token refresh failed: Invalid token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
except Exception as e:
logger.error(f"Unexpected error during token refresh: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/change-password", status_code=status.HTTP_200_OK)
async def change_password(
current_password: str = Body(..., embed=True),
new_password: str = Body(..., embed=True),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Change current user's password.
Requires authentication.
"""
try:
success = AuthService.change_password(
db=db,
user_id=current_user.id,
current_password=current_password,
new_password=new_password
)
if success:
return {"message": "Password changed successfully"}
except AuthenticationError as e:
logger.warning(f"Password change failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected error during password change: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: User = Depends(get_current_user)
) -> Any:
"""
Get current user information.
Requires authentication.
"""
return current_user

185
backend/app/core/auth.py Normal file
View File

@@ -0,0 +1,185 @@
import logging
logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate a password hash."""
return pwd_context.hash(password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
) -> str:
"""
Create a JWT access token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
claims: Optional additional claims to include in the token
Returns:
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"jti": str(uuid.uuid4()),
"type": "access"
}
# Add custom claims
if claims:
to_encode.update(claims)
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT refresh token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
Returns:
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"jti": str(uuid.uuid4()),
"type": "refresh"
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"""
Decode and verify a JWT token.
Args:
token: JWT token to decode
verify_type: Optional token type to verify (access or refresh)
Returns:
TokenPayload: The decoded token data
Raises:
TokenExpiredError: If the token has expired
TokenInvalidError: If the token is invalid
TokenMissingClaimError: If a required claim is missing
"""
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check required claims before Pydantic validation
if not payload.get("sub"):
raise TokenMissingClaimError("Token missing 'sub' claim")
# Verify token type if specified
if verify_type and payload.get("type") != verify_type:
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
# Now validate with Pydantic
token_data = TokenPayload(**payload)
return token_data
except JWTError as e:
# Check if the error is due to an expired token
if "expired" in str(e).lower():
raise TokenExpiredError("Token has expired")
raise TokenInvalidError("Invalid authentication token")
except ValidationError:
raise TokenInvalidError("Invalid token payload")
def get_token_data(token: str) -> TokenData:
"""
Extract the user ID and superuser status from a token.
Args:
token: JWT token
Returns:
TokenData with user_id and is_superuser
"""
payload = decode_token(token)
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -3,7 +3,7 @@ from typing import Optional, List
class Settings(BaseSettings):
PROJECT_NAME: str = "App"
PROJECT_NAME: str = "EventSpace"
VERSION: str = "1.0.0"
API_V1_STR: str = "/api/v1"
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
# SQL debugging (disable in production)
sql_echo: bool = False # Log SQL statements
sql_echo_pool: bool = False # Log connection pool events
sql_echo_timing: bool = False # Log query execution times
slow_query_threshold: float = 0.5 # Log queries taking longer than this
@property
def database_url(self) -> str:
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
# JWT configuration
SECRET_KEY: str = "your_secret_key_here"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
# CORS configuration
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]

View File

@@ -1,17 +1,57 @@
# app/core/database.py
import logging
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from app.core.config import settings
# Use the database URL from settings
engine = create_engine(settings.database_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models
Base = declarative_base()
# Create engine with optimized settings for PostgreSQL
def create_production_engine():
return create_engine(
settings.database_url,
# Connection pool settings
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_timeout=settings.db_pool_timeout,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=True,
# Query execution settings
connect_args={
"application_name": "eventspace",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
"options": "-c timezone=UTC",
},
isolation_level="READ COMMITTED",
echo=settings.sql_echo,
echo_pool=settings.sql_echo_pool,
)
# Dependency to get DB session
# Default production engine and session factory
engine = create_production_engine()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# FastAPI dependency
def get_db():
db = SessionLocal()
try:

View File

62
backend/app/crud/base.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.database import Base
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
def get(self, db: Session, id: str) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first()
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def remove(self, db: Session, *, id: str) -> ModelType:
obj = db.query(self.model).get(id)
db.delete(obj)
db.commit()
return obj

56
backend/app/crud/user.py Normal file
View File

@@ -0,0 +1,56 @@
# app/crud/user.py
from typing import Optional, Union, Dict, Any
from sqlalchemy.orm import Session
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.core.auth import get_password_hash
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
return db.query(User).filter(User.email == email).first()
def create(self, db: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self,
db: Session,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"])
del update_data["password"]
return super().update(db, db_obj=db_obj, obj_in=update_data)
def is_active(self, user: User) -> bool:
return user.is_active
def is_superuser(self, user: User) -> bool:
return user.is_superuser
# Create a singleton instance for use across the application
user = CRUDUser(User)

View File

@@ -1,9 +1,18 @@
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from app.config import settings
from app.api.main import api_router
from app.core.config import settings
scheduler = AsyncIOScheduler()
logger = logging.getLogger(__name__)
logger.info(f"Starting app!!!")
app = FastAPI(
title=settings.PROJECT_NAME,
version=settings.VERSION,
@@ -34,3 +43,6 @@ async def root():
</body>
</html>
"""
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -0,0 +1,14 @@
"""
Models package initialization.
Imports all models to ensure they're registered with SQLAlchemy.
"""
# First import Base to avoid circular imports
from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
# Import user model
from .user import User
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
'User',
]

View File

@@ -0,0 +1,20 @@
import uuid
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime
from sqlalchemy.dialects.postgresql import UUID
# noinspection PyUnresolvedReferences
from app.core.database import Base
class TimestampMixin:
"""Mixin to add created_at and updated_at timestamps to models"""
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
class UUIDMixin:
"""Mixin to add UUID primary keys to models"""
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

View File

@@ -0,0 +1,19 @@
from sqlalchemy import Column, String, JSON, Boolean
from .base import Base, TimestampMixin, UUIDMixin
class User(Base, UUIDMixin, TimestampMixin):
__tablename__ = 'users'
email = Column(String, unique=True, nullable=False, index=True)
password_hash = Column(String, nullable=False)
first_name = Column(String, nullable=False, default="user")
last_name = Column(String, nullable=True)
phone_number = Column(String)
is_active = Column(Boolean, default=True, nullable=False)
is_superuser = Column(Boolean, default=False, nullable=False)
preferences = Column(JSON)
def __repr__(self):
return f"<User {self.email}>"

View File

View File

@@ -0,0 +1,149 @@
# app/schemas/users.py
import re
from datetime import datetime
from typing import Optional, Dict, Any
from uuid import UUID
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
class UserBase(BaseModel):
email: EmailStr
first_name: str
last_name: Optional[str] = None
phone_number: Optional[str] = None
@field_validator('phone_number')
@classmethod
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Simple regex for phone validation
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
raise ValueError('Invalid phone number format')
return v
class UserCreate(UserBase):
password: str
is_superuser: bool = False
@field_validator('password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = True
@field_validator('phone_number')
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Return early for empty strings or whitespace-only strings
if not v or v.strip() == "":
raise ValueError('Phone number cannot be empty')
# Remove all spaces and formatting characters
cleaned = re.sub(r'[\s\-\(\)]', '', v)
# Basic pattern:
# Must start with + or 0
# After + must have at least 8 digits
# After 0 must have at least 8 digits
# Maximum total length of 15 digits (international standard)
# Only allowed characters are + at start and digits
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
if not re.match(pattern, cleaned):
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
# Additional validation to catch specific invalid cases
if cleaned.count('+') > 1:
raise ValueError('Phone number can only contain one + symbol at the start')
# Check for any non-digit characters (except the leading +)
if not all(c.isdigit() for c in cleaned[1:]):
raise ValueError('Phone number can only contain digits after the prefix')
return cleaned
class UserInDB(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class UserResponse(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class Token(BaseModel):
access_token: str
refresh_token: Optional[str] = None
token_type: str = "bearer"
class TokenPayload(BaseModel):
sub: str # User ID
exp: int # Expiration time
iat: Optional[int] = None # Issued at
jti: Optional[str] = None # JWT ID
is_superuser: Optional[bool] = False
first_name: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None # Token type (access/refresh)
class TokenData(BaseModel):
user_id: UUID
is_superuser: bool = False
class PasswordReset(BaseModel):
token: str
new_password: str
@field_validator('new_password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshTokenRequest(BaseModel):
refresh_token: str

View File

View File

@@ -0,0 +1,193 @@
# app/services/auth_service.py
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
TokenExpiredError,
TokenInvalidError
)
from app.models.user import User
from app.schemas.users import Token, UserCreate
logger = logging.getLogger(__name__)
class AuthenticationError(Exception):
"""Exception raised for authentication errors"""
pass
class AuthService:
"""Service for handling authentication operations"""
@staticmethod
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticate a user with email and password.
Args:
db: Database session
email: User email
password: User password
Returns:
User if authenticated, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user:
return None
if not verify_password(password, user.password_hash):
return None
if not user.is_active:
raise AuthenticationError("User account is inactive")
return user
@staticmethod
def create_user(db: Session, user_data: UserCreate) -> User:
"""
Create a new user.
Args:
db: Database session
user_data: User data
Returns:
Created user
"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == user_data.email).first()
if existing_user:
raise AuthenticationError("User with this email already exists")
# Create new user
hashed_password = get_password_hash(user_data.password)
# Create user object from model
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False
)
db.add(user)
db.commit()
db.refresh(user)
return user
@staticmethod
def create_tokens(user: User) -> Token:
"""
Create access and refresh tokens for a user.
Args:
user: User to create tokens for
Returns:
Token object with access and refresh tokens
"""
# Generate claims
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name
}
# Create tokens
access_token = create_access_token(
subject=str(user.id),
claims=claims
)
refresh_token = create_refresh_token(
subject=str(user.id)
)
return Token(
access_token=access_token,
refresh_token=refresh_token
)
@staticmethod
def refresh_tokens(db: Session, refresh_token: str) -> Token:
"""
Generate new tokens using a refresh token.
Args:
db: Database session
refresh_token: Valid refresh token
Returns:
New access and refresh tokens
Raises:
TokenExpiredError: If refresh token has expired
TokenInvalidError: If refresh token is invalid
"""
from app.core.auth import decode_token, get_token_data
try:
# Verify token is a refresh token
decode_token(refresh_token, verify_type="refresh")
# Get user ID from token
token_data = get_token_data(refresh_token)
user_id = token_data.user_id
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")
# Generate new tokens
return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}")
raise
@staticmethod
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
"""
Change a user's password.
Args:
db: Database session
user_id: User ID
current_password: Current password
new_password: New password
Returns:
True if password was changed successfully
Raises:
AuthenticationError: If current password is incorrect
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise AuthenticationError("User not found")
# Verify current password
if not verify_password(current_password, user.password_hash):
raise AuthenticationError("Current password is incorrect")
# Update password
user.password_hash = get_password_hash(new_password)
db.commit()
return True

View File

View File

@@ -0,0 +1,79 @@
import logging
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, clear_mappers
from sqlalchemy.pool import StaticPool
from app.core.database import Base
logger = logging.getLogger(__name__)
def get_test_engine():
"""Create an SQLite in-memory engine specifically for testing"""
test_engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
)
return test_engine
def setup_test_db():
"""Create a test database and session factory"""
# Create a new engine for this test run
test_engine = get_test_engine()
# Create tables
Base.metadata.create_all(test_engine)
# Create session factory
TestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False
)
return test_engine, TestingSessionLocal
def teardown_test_db(engine):
"""Clean up after tests"""
# Drop all tables
Base.metadata.drop_all(engine)
# Dispose of engine
engine.dispose()
async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
)
return test_engine
async def setup_async_test_db():
"""Create an async test database and session factory"""
test_engine = await get_async_test_engine()
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
AsyncTestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession
)
return test_engine, AsyncTestingSessionLocal
async def teardown_async_test_db(engine):
"""Clean up after async tests"""
await engine.dispose()

10
backend/pytest.ini Normal file
View File

@@ -0,0 +1,10 @@
[pytest]
env =
IS_TEST=True
testpaths = tests
python_files = test_*.py
addopts = --disable-warnings
markers =
sqlite: marks tests that should run on SQLite (mocked).
postgres: marks tests that require a real PostgreSQL database.
asyncio_default_fixture_loop_scope = function

View File

@@ -4,13 +4,14 @@ uvicorn>=0.34.0
pydantic>=2.10.6
pydantic-settings>=2.2.1
python-multipart>=0.0.19
fastapi-utils==0.8.0
# Database
sqlalchemy>=2.0.29
alembic>=1.14.1
psycopg2-binary>=2.9.9
asyncpg>=0.29.0
aiosqlite==0.21.0
# Security and authentication
python-jose>=3.4.0
passlib>=1.7.4
@@ -30,7 +31,7 @@ httpx>=0.27.0
tenacity>=8.2.3
pytz>=2024.1
pillow>=10.3.0
apscheduler==3.11.0
# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.5
@@ -41,4 +42,11 @@ requests>=2.32.0
black>=24.3.0
isort>=5.13.2
flake8>=7.0.0
mypy>=1.8.0
mypy>=1.8.0
# Security
python-jose==3.4.0
bcrypt==4.2.1
cryptography==44.0.1
passlib==1.7.4
freezegun~=1.5.1

View File

View File

@@ -0,0 +1,369 @@
# tests/api/routes/test_auth.py
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock, Mock
import pytest
from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.api.routes.auth import router as auth_router
from app.core.auth import get_password_hash
from app.core.database import get_db
from app.models.user import User
from app.services.auth_service import AuthService, AuthenticationError
from app.core.auth import TokenExpiredError, TokenInvalidError
# Mock the get_db dependency
@pytest.fixture
def override_get_db(db_session):
"""Override get_db dependency for testing."""
return db_session
@pytest.fixture
def app(override_get_db):
"""Create a FastAPI test application with overridden dependencies."""
app = FastAPI()
app.include_router(auth_router, prefix="/auth", tags=["auth"])
# Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db
return app
@pytest.fixture
def client(app):
"""Create a FastAPI test client."""
return TestClient(app)
class TestRegisterUser:
"""Tests for the register_user endpoint."""
def test_register_user_success(self, client, monkeypatch, db_session):
"""Test successful user registration."""
# Mock the service method with a valid complete User object
mock_user = User(
id=uuid.uuid4(),
email="newuser@example.com",
password_hash="hashed_password",
first_name="New",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# Use patch for mocking
with patch.object(AuthService, 'create_user', return_value=mock_user):
# Test request
response = client.post(
"/auth/register",
json={
"email": "newuser@example.com",
"password": "Password123",
"first_name": "New",
"last_name": "User"
}
)
# Assertions
assert response.status_code == 201
data = response.json()
assert data["email"] == "newuser@example.com"
assert data["first_name"] == "New"
assert data["last_name"] == "User"
assert "password" not in data
def test_register_user_duplicate_email(self, client, db_session):
"""Test registration with duplicate email."""
# Use patch for mocking with a side effect
with patch.object(AuthService, 'create_user',
side_effect=AuthenticationError("User with this email already exists")):
# Test request
response = client.post(
"/auth/register",
json={
"email": "existing@example.com",
"password": "Password123",
"first_name": "Existing",
"last_name": "User"
}
)
# Assertions
assert response.status_code == 409
assert "already exists" in response.json()["detail"]
class TestLogin:
"""Tests for the login endpoint."""
def test_login_success(self, client, mock_user, db_session):
"""Test successful login."""
# Ensure mock_user has required attributes
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
mock_user.created_at = datetime.now(timezone.utc)
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
mock_user.updated_at = datetime.now(timezone.utc)
# Create mock tokens
mock_tokens = MagicMock(
access_token="mock_access_token",
refresh_token="mock_refresh_token",
token_type="bearer"
)
# Use context managers for patching
with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \
patch.object(AuthService, 'create_tokens', return_value=mock_tokens):
# Test request
response = client.post(
"/auth/login",
json={
"email": "user@example.com",
"password": "Password123"
}
)
# Assertions
assert response.status_code == 200
data = response.json()
assert data["access_token"] == "mock_access_token"
assert data["refresh_token"] == "mock_refresh_token"
assert data["token_type"] == "bearer"
def test_login_invalid_credentials_debug(self, client, app):
"""Improved test for login with invalid credentials."""
# Print response for debugging
from app.services.auth_service import AuthService
# Create a complete mock for AuthService
class MockAuthService:
@staticmethod
def authenticate_user(db, email, password):
print(f"Mock called with: {email}, {password}")
return None
# Replace the entire class with our mock
original_service = AuthService
try:
# Replace with our mock
import sys
sys.modules['app.services.auth_service'].AuthService = MockAuthService
# Make the request
response = client.post(
"/auth/login",
json={
"email": "user@example.com",
"password": "WrongPassword"
}
)
# Print response details for debugging
print(f"Response status: {response.status_code}")
print(f"Response body: {response.text}")
# Assertions
assert response.status_code == 401
assert "Invalid email or password" in response.json()["detail"]
finally:
# Restore original service
sys.modules['app.services.auth_service'].AuthService = original_service
def test_login_inactive_user(self, client, db_session):
"""Test login with inactive user."""
# Mock authentication to raise an error
with patch.object(AuthService, 'authenticate_user',
side_effect=AuthenticationError("User account is inactive")):
# Test request
response = client.post(
"/auth/login",
json={
"email": "inactive@example.com",
"password": "Password123"
}
)
# Assertions
assert response.status_code == 401
assert "inactive" in response.json()["detail"]
class TestRefreshToken:
"""Tests for the refresh_token endpoint."""
def test_refresh_token_success(self, client, db_session):
"""Test successful token refresh."""
# Mock refresh to return tokens
mock_tokens = MagicMock(
access_token="new_access_token",
refresh_token="new_refresh_token",
token_type="bearer"
)
with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens):
# Test request
response = client.post(
"/auth/refresh",
json={
"refresh_token": "valid_refresh_token"
}
)
# Assertions
assert response.status_code == 200
data = response.json()
assert data["access_token"] == "new_access_token"
assert data["refresh_token"] == "new_refresh_token"
assert data["token_type"] == "bearer"
def test_refresh_token_expired(self, client, db_session):
"""Test refresh with expired token."""
# Mock refresh to raise expired token error
with patch.object(AuthService, 'refresh_tokens',
side_effect=TokenExpiredError("Token expired")):
# Test request
response = client.post(
"/auth/refresh",
json={
"refresh_token": "expired_refresh_token"
}
)
# Assertions
assert response.status_code == 401
assert "expired" in response.json()["detail"]
def test_refresh_token_invalid(self, client, db_session):
"""Test refresh with invalid token."""
# Mock refresh to raise invalid token error
with patch.object(AuthService, 'refresh_tokens',
side_effect=TokenInvalidError("Invalid token")):
# Test request
response = client.post(
"/auth/refresh",
json={
"refresh_token": "invalid_refresh_token"
}
)
# Assertions
assert response.status_code == 401
assert "Invalid" in response.json()["detail"]
class TestChangePassword:
"""Tests for the change_password endpoint."""
def test_change_password_success(self, client, mock_user, db_session, app):
"""Test successful password change."""
# Ensure mock_user has required attributes
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
mock_user.created_at = datetime.now(timezone.utc)
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
mock_user.updated_at = datetime.now(timezone.utc)
# Override get_current_user dependency
from app.api.dependencies.auth import get_current_user
app.dependency_overrides[get_current_user] = lambda: mock_user
# Mock password change to return success
with patch.object(AuthService, 'change_password', return_value=True):
# Test request
response = client.post(
"/auth/change-password",
json={
"current_password": "OldPassword123",
"new_password": "NewPassword123"
}
)
# Assertions
assert response.status_code == 200
assert "success" in response.json()["message"].lower()
# Clean up override
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app):
"""Test change password with incorrect current password."""
# Ensure mock_user has required attributes
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
mock_user.created_at = datetime.now(timezone.utc)
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
mock_user.updated_at = datetime.now(timezone.utc)
# Override get_current_user dependency
from app.api.dependencies.auth import get_current_user
app.dependency_overrides[get_current_user] = lambda: mock_user
# Mock password change to raise error
with patch.object(AuthService, 'change_password',
side_effect=AuthenticationError("Current password is incorrect")):
# Test request
response = client.post(
"/auth/change-password",
json={
"current_password": "WrongPassword",
"new_password": "NewPassword123"
}
)
# Assertions
assert response.status_code == 400
assert "incorrect" in response.json()["detail"].lower()
# Clean up override
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
class TestGetCurrentUserInfo:
"""Tests for the get_current_user_info endpoint."""
def test_get_current_user_info(self, client, mock_user, app):
"""Test getting current user info."""
# Ensure mock_user has required attributes
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
mock_user.created_at = datetime.now(timezone.utc)
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
mock_user.updated_at = datetime.now(timezone.utc)
# Override get_current_user dependency
from app.api.dependencies.auth import get_current_user
app.dependency_overrides[get_current_user] = lambda: mock_user
# Test request
response = client.get("/auth/me")
# Assertions
assert response.status_code == 200
data = response.json()
assert data["email"] == mock_user.email
assert data["first_name"] == mock_user.first_name
assert data["last_name"] == mock_user.last_name
assert "password" not in data
# Clean up override
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_get_current_user_info_unauthorized(self, client):
"""Test getting user info without authentication."""
# Test request without authentication
response = client.get("/auth/me")
# Assertions
assert response.status_code == 401

View File

@@ -0,0 +1,211 @@
# tests/api/dependencies/test_auth_dependencies.py
import pytest
from unittest.mock import patch, MagicMock
from fastapi import HTTPException
from app.api.dependencies.auth import (
get_current_user,
get_current_active_user,
get_current_superuser,
get_optional_current_user
)
from app.core.auth import TokenExpiredError, TokenInvalidError
@pytest.fixture
def mock_token():
return "mock.jwt.token"
class TestGetCurrentUser:
"""Tests for get_current_user dependency"""
def test_get_current_user_success(self, db_session, mock_user, mock_token):
"""Test successfully getting the current user"""
# Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.return_value.user_id = mock_user.id
# Call the dependency
user = get_current_user(db=db_session, token=mock_token)
# Verify the correct user was returned
assert user.id == mock_user.id
assert user.email == mock_user.email
def test_get_current_user_nonexistent(self, db_session, mock_token):
"""Test when the token contains a user ID that doesn't exist"""
# Mock get_token_data to return a non-existent user ID
# Use a real UUID object instead of a string
import uuid
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id # Using UUID object, not string
# Should raise HTTPException with 404 status
with pytest.raises(HTTPException) as exc_info:
get_current_user(db=db_session, token=mock_token)
assert exc_info.value.status_code == 404
def test_get_current_user_inactive(self, db_session, mock_user, mock_token):
"""Test when the user is inactive"""
# Make the user inactive
mock_user.is_active = False
db_session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.return_value.user_id = mock_user.id
# Should raise HTTPException with 403 status
with pytest.raises(HTTPException) as exc_info:
get_current_user(db=db_session, token=mock_token)
assert exc_info.value.status_code == 403
def test_get_current_user_expired_token(self, db_session, mock_token):
"""Test with an expired token"""
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status
with pytest.raises(HTTPException) as exc_info:
get_current_user(db=db_session, token=mock_token)
assert exc_info.value.status_code == 401
assert "Token expired" in exc_info.value.detail
def test_get_current_user_invalid_token(self, db_session, mock_token):
"""Test with an invalid token"""
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status
with pytest.raises(HTTPException) as exc_info:
get_current_user(db=db_session, token=mock_token)
assert exc_info.value.status_code == 401
assert "Could not validate credentials" in exc_info.value.detail
class TestGetCurrentActiveUser:
"""Tests for get_current_active_user dependency"""
def test_get_current_active_user(self, mock_user):
"""Test getting an active user"""
# Ensure user is active
mock_user.is_active = True
# Call the dependency with mocked current_user
user = get_current_active_user(current_user=mock_user)
# Should return the same user
assert user == mock_user
def test_get_current_inactive_user(self, mock_user):
"""Test getting an inactive user"""
# Make user inactive
mock_user.is_active = False
# Should raise HTTPException with 403 status
with pytest.raises(HTTPException) as exc_info:
get_current_active_user(current_user=mock_user)
assert exc_info.value.status_code == 403
assert "Inactive user" in exc_info.value.detail
class TestGetCurrentSuperuser:
"""Tests for get_current_superuser dependency"""
def test_get_current_superuser(self, mock_user):
"""Test getting a superuser"""
# Make user a superuser
mock_user.is_superuser = True
# Call the dependency with mocked current_user
user = get_current_superuser(current_user=mock_user)
# Should return the same user
assert user == mock_user
def test_get_current_non_superuser(self, mock_user):
"""Test getting a non-superuser"""
# Ensure user is not a superuser
mock_user.is_superuser = False
# Should raise HTTPException with 403 status
with pytest.raises(HTTPException) as exc_info:
get_current_superuser(current_user=mock_user)
assert exc_info.value.status_code == 403
assert "Not enough permissions" in exc_info.value.detail
class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency"""
def test_get_optional_current_user_with_token(self, db_session, mock_user, mock_token):
"""Test getting optional user with a valid token"""
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.return_value.user_id = mock_user.id
# Call the dependency
user = get_optional_current_user(db=db_session, token=mock_token)
# Should return the correct user
assert user is not None
assert user.id == mock_user.id
def test_get_optional_current_user_no_token(self, db_session):
"""Test getting optional user with no token"""
# Call the dependency with no token
user = get_optional_current_user(db=db_session, token=None)
# Should return None
assert user is None
def test_get_optional_current_user_invalid_token(self, db_session, mock_token):
"""Test getting optional user with an invalid token"""
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency
user = get_optional_current_user(db=db_session, token=mock_token)
# Should return None, not raise an exception
assert user is None
def test_get_optional_current_user_expired_token(self, db_session, mock_token):
"""Test getting optional user with an expired token"""
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency
user = get_optional_current_user(db=db_session, token=mock_token)
# Should return None, not raise an exception
assert user is None
def test_get_optional_current_user_inactive(self, db_session, mock_user, mock_token):
"""Test getting optional user when user is inactive"""
# Make the user inactive
mock_user.is_active = False
db_session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
mock_get_data.return_value.user_id = mock_user.id
# Call the dependency
user = get_optional_current_user(db=db_session, token=mock_token)
# Should return None for inactive users
assert user is None

66
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,66 @@
# tests/conftest.py
import uuid
from datetime import datetime, timezone
import pytest
from app.models.user import User
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
@pytest.fixture(scope="function")
def db_session():
"""
Creates a fresh SQLite in-memory database for each test function.
Yields a SQLAlchemy session that can be used for testing.
"""
# Set up the database
test_engine, TestingSessionLocal = setup_test_db()
# Create a session
with TestingSessionLocal() as session:
yield session
# Clean up
teardown_test_db(test_engine)
@pytest.fixture(scope="function") # Define a fixture
async def async_test_db():
"""Fixture provides new testing engine and session for each test run to improve isolation."""
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine)
@pytest.fixture
def user_create_data():
return {
"email": "newtest@example.com", # Changed to avoid conflict with mock_user
"password": "TestPassword123!",
"first_name": "Test",
"last_name": "User",
"phone_number": "+1234567890",
"is_superuser": False,
"preferences": None
}
@pytest.fixture
def mock_user(db_session):
"""Fixture to create and return a mock User instance."""
mock_user = User(
id=uuid.uuid4(),
email="mockuser@example.com",
password_hash="mockhashedpassword",
first_name="Mock",
last_name="User",
phone_number="1234567890",
is_active=True,
is_superuser=False,
preferences=None,
)
db_session.add(mock_user)
db_session.commit()
return mock_user

View File

View File

@@ -0,0 +1,260 @@
# tests/core/test_auth.py
import uuid
import pytest
from datetime import datetime, timedelta, timezone
from jose import jwt
from pydantic import ValidationError
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
decode_token,
get_token_data,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError
)
from app.core.config import settings
class TestPasswordHandling:
"""Tests for password hashing and verification functions"""
def test_password_hash_different_from_password(self):
"""Test that a password hash is different from the original password"""
password = "TestPassword123"
hashed = get_password_hash(password)
assert hashed != password
def test_verify_correct_password(self):
"""Test that verify_password returns True for the correct password"""
password = "TestPassword123"
hashed = get_password_hash(password)
assert verify_password(password, hashed) is True
def test_verify_incorrect_password(self):
"""Test that verify_password returns False for an incorrect password"""
password = "TestPassword123"
wrong_password = "WrongPassword123"
hashed = get_password_hash(password)
assert verify_password(wrong_password, hashed) is False
def test_same_password_different_hash(self):
"""Test that the same password gets a different hash each time"""
password = "TestPassword123"
hash1 = get_password_hash(password)
hash2 = get_password_hash(password)
assert hash1 != hash2
class TestTokenCreation:
"""Tests for token creation functions"""
def test_create_access_token(self):
"""Test that an access token is created with the correct claims"""
user_id = str(uuid.uuid4())
custom_claims = {
"email": "test@example.com",
"first_name": "Test",
"is_superuser": True
}
token = create_access_token(subject=user_id, claims=custom_claims)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check standard claims
assert payload["sub"] == user_id
assert "jti" in payload
assert "exp" in payload
assert "iat" in payload
assert payload["type"] == "access"
# Check custom claims
for key, value in custom_claims.items():
assert payload[key] == value
def test_create_refresh_token(self):
"""Test that a refresh token is created with the correct claims"""
user_id = str(uuid.uuid4())
token = create_refresh_token(subject=user_id)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check standard claims
assert payload["sub"] == user_id
assert "jti" in payload
assert "exp" in payload
assert "iat" in payload
assert payload["type"] == "refresh"
def test_token_expiration(self):
"""Test that tokens have the correct expiration time"""
user_id = str(uuid.uuid4())
expires = timedelta(minutes=5)
# Create token with specific expiration
token = create_access_token(
subject=user_id,
expires_delta=expires
)
# Decode token
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Get actual expiration time from token
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
# Calculate expected expiration (approximately)
now = datetime.now(timezone.utc)
expected_expiration = now + expires
# Difference should be small (less than 1 second)
difference = abs((expiration - expected_expiration).total_seconds())
assert difference < 1
class TestTokenDecoding:
"""Tests for token decoding and validation functions"""
def test_decode_valid_token(self):
"""Test that a valid token can be decoded"""
user_id = str(uuid.uuid4())
token = create_access_token(subject=user_id)
# Decode token
payload = decode_token(token)
# Check that the subject matches
assert payload.sub == user_id
def test_decode_expired_token(self):
"""Test that an expired token raises TokenExpiredError"""
user_id = str(uuid.uuid4())
# Create a token that's already expired by directly manipulating the payload
now = datetime.now(timezone.utc)
expired_time = now - timedelta(hours=1) # 1 hour in the past
# Create the expired token manually
payload = {
"sub": user_id,
"exp": int(expired_time.timestamp()), # Set expiration in the past
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
}
expired_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Attempting to decode should raise TokenExpiredError
with pytest.raises(TokenExpiredError):
decode_token(expired_token)
def test_decode_invalid_token(self):
"""Test that an invalid token raises TokenInvalidError"""
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
with pytest.raises(TokenInvalidError):
decode_token(invalid_token)
def test_decode_token_with_missing_sub(self):
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
# Create a token without a subject
now = datetime.now(timezone.utc)
payload = {
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
# No 'sub' claim
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
with pytest.raises(TokenMissingClaimError):
decode_token(token)
def test_decode_token_with_wrong_type(self):
"""Test that verifying a token with wrong type raises TokenInvalidError"""
user_id = str(uuid.uuid4())
token = create_access_token(subject=user_id)
# Try to verify it as a refresh token
with pytest.raises(TokenInvalidError):
decode_token(token, verify_type="refresh")
def test_decode_with_invalid_payload(self):
"""Test that a token with invalid payload structure raises TokenInvalidError"""
# Create a token with an invalid payload structure - missing 'sub' which is required
# but including 'exp' to avoid the expiration check
now = datetime.now(timezone.utc)
payload = {
# Missing "sub" field which is required
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"invalid_field": "test"
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenMissingClaimError due to missing 'sub'
with pytest.raises(TokenMissingClaimError):
decode_token(token)
# Create another token with invalid type for required field
payload = {
"sub": 123, # sub should be a string, not an integer
"exp": int((now + timedelta(minutes=30)).timestamp()),
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenInvalidError due to ValidationError
with pytest.raises(TokenInvalidError):
decode_token(token)
def test_get_token_data(self):
"""Test extracting TokenData from a token"""
user_id = uuid.uuid4()
token = create_access_token(
subject=str(user_id),
claims={"is_superuser": True}
)
token_data = get_token_data(token)
assert token_data.user_id == user_id
assert token_data.is_superuser is True

View File

View File

@@ -0,0 +1,125 @@
import pytest
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
def test_create_user(db_session, user_create_data):
user_in = UserCreate(**user_create_data)
user_obj = user_crud.create(db_session, obj_in=user_in)
assert user_obj.email == user_create_data["email"]
assert user_obj.first_name == user_create_data["first_name"]
assert user_obj.last_name == user_create_data["last_name"]
assert user_obj.phone_number == user_create_data["phone_number"]
assert user_obj.is_superuser == user_create_data["is_superuser"]
assert user_obj.password_hash is not None
assert user_obj.id is not None
def test_get_user(db_session, mock_user):
# Using mock_user fixture instead of creating new user
stored_user = user_crud.get(db_session, id=mock_user.id)
assert stored_user
assert stored_user.id == mock_user.id
assert stored_user.email == mock_user.email
def test_get_user_by_email(db_session, mock_user):
stored_user = user_crud.get_by_email(db_session, email=mock_user.email)
assert stored_user
assert stored_user.id == mock_user.id
assert stored_user.email == mock_user.email
def test_update_user(db_session, mock_user):
update_data = UserUpdate(
first_name="Updated",
last_name="Name",
phone_number="+9876543210"
)
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
assert updated_user.first_name == "Updated"
assert updated_user.last_name == "Name"
assert updated_user.phone_number == "+9876543210"
assert updated_user.email == mock_user.email
def test_delete_user(db_session, mock_user):
user_crud.remove(db_session, id=mock_user.id)
deleted_user = user_crud.get(db_session, id=mock_user.id)
assert deleted_user is None
def test_get_multi_users(db_session, mock_user, user_create_data):
# Create additional users (mock_user is already in db)
users_data = [
{**user_create_data, "email": f"test{i}@example.com"}
for i in range(2) # Creating 2 more users + mock_user = 3 total
]
for user_data in users_data:
user_in = UserCreate(**user_data)
user_crud.create(db_session, obj_in=user_in)
users = user_crud.get_multi(db_session, skip=0, limit=10)
assert len(users) == 3
assert all(isinstance(user, User) for user in users)
def test_is_active(db_session, mock_user):
assert user_crud.is_active(mock_user) is True
# Test deactivating user
update_data = UserUpdate(is_active=False)
deactivated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
assert user_crud.is_active(deactivated_user) is False
def test_is_superuser(db_session, mock_user, user_create_data):
# mock_user is regular user
assert user_crud.is_superuser(mock_user) is False
# Create superuser
super_user_data = {**user_create_data, "email": "super@example.com", "is_superuser": True}
super_user_in = UserCreate(**super_user_data)
super_user = user_crud.create(db_session, obj_in=super_user_in)
assert user_crud.is_superuser(super_user) is True
# Additional test cases
def test_create_duplicate_email(db_session, mock_user):
user_data = UserCreate(
email=mock_user.email, # Try to create user with existing email
password="TestPassword123!",
first_name="Test",
last_name="User"
)
with pytest.raises(Exception): # Should raise an integrity error
user_crud.create(db_session, obj_in=user_data)
def test_update_user_preferences(db_session, mock_user):
preferences = {"theme": "dark", "notifications": True}
update_data = UserUpdate(preferences=preferences)
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
assert updated_user.preferences == preferences
def test_get_multi_users_pagination(db_session, user_create_data):
# Create 5 users
for i in range(5):
user_in = UserCreate(**{**user_create_data, "email": f"test{i}@example.com"})
user_crud.create(db_session, obj_in=user_in)
# Test pagination
first_page = user_crud.get_multi(db_session, skip=0, limit=2)
second_page = user_crud.get_multi(db_session, skip=2, limit=2)
assert len(first_page) == 2
assert len(second_page) == 2
assert first_page[0].id != second_page[0].id

View File

View File

@@ -0,0 +1,249 @@
# tests/models/test_user.py
import uuid
import pytest
from datetime import datetime
from sqlalchemy.exc import IntegrityError
from app.models.user import User
def test_create_user(db_session):
"""Test creating a basic user."""
# Arrange
user_id = uuid.uuid4()
new_user = User(
id=user_id,
email="test@example.com",
password_hash="hashedpassword",
first_name="Test",
last_name="User",
phone_number="1234567890",
is_active=True,
is_superuser=False,
preferences={"theme": "dark"},
)
db_session.add(new_user)
# Act
db_session.commit()
created_user = db_session.query(User).filter_by(email="test@example.com").first()
# Assert
assert created_user is not None
assert created_user.email == "test@example.com"
assert created_user.first_name == "Test"
assert created_user.last_name == "User"
assert created_user.phone_number == "1234567890"
assert created_user.is_active is True
assert created_user.is_superuser is False
assert created_user.preferences == {"theme": "dark"}
# UUID should be preserved
assert created_user.id == user_id
# Timestamps should be set
assert isinstance(created_user.created_at, datetime)
assert isinstance(created_user.updated_at, datetime)
def test_update_user(db_session):
"""Test updating an existing user."""
# Arrange - Create a user
user_id = uuid.uuid4()
user = User(
id=user_id,
email="update@example.com",
password_hash="hashedpassword",
first_name="Before",
last_name="Update",
)
db_session.add(user)
db_session.commit()
# Record the original creation timestamp
original_created_at = user.created_at
# Act - Update the user
user.first_name = "After"
user.last_name = "Updated"
user.phone_number = "9876543210"
user.preferences = {"theme": "light", "notifications": True}
db_session.commit()
# Fetch the updated user to verify changes were persisted
updated_user = db_session.query(User).filter_by(id=user_id).first()
# Assert
assert updated_user.first_name == "After"
assert updated_user.last_name == "Updated"
assert updated_user.phone_number == "9876543210"
assert updated_user.preferences == {"theme": "light", "notifications": True}
# created_at shouldn't change on update
assert updated_user.created_at == original_created_at
# updated_at should be newer than created_at
assert updated_user.updated_at > original_created_at
def test_delete_user(db_session):
"""Test deleting a user."""
# Arrange - Create a user
user_id = uuid.uuid4()
user = User(
id=user_id,
email="delete@example.com",
password_hash="hashedpassword",
first_name="Delete",
last_name="Me",
)
db_session.add(user)
db_session.commit()
# Act - Delete the user
db_session.delete(user)
db_session.commit()
# Assert
deleted_user = db_session.query(User).filter_by(id=user_id).first()
assert deleted_user is None
def test_user_unique_email_constraint(db_session):
"""Test that users cannot have duplicate emails."""
# Arrange - Create a user
user1 = User(
id=uuid.uuid4(),
email="duplicate@example.com",
password_hash="hashedpassword",
first_name="First",
last_name="User",
)
db_session.add(user1)
db_session.commit()
# Act & Assert - Try to create another user with the same email
user2 = User(
id=uuid.uuid4(),
email="duplicate@example.com", # Same email
password_hash="differenthash",
first_name="Second",
last_name="User",
)
db_session.add(user2)
# Should raise IntegrityError due to unique constraint
with pytest.raises(IntegrityError):
db_session.commit()
# Rollback for cleanup
db_session.rollback()
def test_user_required_fields(db_session):
"""Test that required fields are enforced."""
# Test each required field by creating a user without it
# Missing email
user_no_email = User(
id=uuid.uuid4(),
# email is missing
password_hash="hashedpassword",
first_name="Test",
last_name="User",
)
db_session.add(user_no_email)
with pytest.raises(IntegrityError):
db_session.commit()
db_session.rollback()
# Missing password_hash
user_no_password = User(
id=uuid.uuid4(),
email="nopassword@example.com",
# password_hash is missing
first_name="Test",
last_name="User",
)
db_session.add(user_no_password)
with pytest.raises(IntegrityError):
db_session.commit()
db_session.rollback()
def test_user_defaults(db_session):
"""Test that default values are correctly set."""
# Arrange - Create a minimal user with only required fields
minimal_user = User(
id=uuid.uuid4(),
email="minimal@example.com",
password_hash="hashedpassword",
first_name="Minimal",
last_name="User",
)
db_session.add(minimal_user)
db_session.commit()
# Act - Retrieve the user
created_user = db_session.query(User).filter_by(email="minimal@example.com").first()
# Assert - Check default values
assert created_user.is_active is True # Default should be True
assert created_user.is_superuser is False # Default should be False
assert created_user.phone_number is None # Optional field
assert created_user.preferences is None # Optional field
def test_user_string_representation(db_session):
"""Test the string representation of a user."""
# Arrange
user = User(
id=uuid.uuid4(),
email="repr@example.com",
password_hash="hashedpassword",
first_name="String",
last_name="Repr",
)
# Act & Assert
assert str(user) == "<User repr@example.com>"
assert repr(user) == "<User repr@example.com>"
def test_user_with_complex_json_preferences(db_session):
"""Test storing and retrieving complex JSON preferences."""
# Arrange - Create a user with nested JSON preferences
complex_preferences = {
"theme": {
"mode": "dark",
"colors": {
"primary": "#333",
"secondary": "#666"
}
},
"notifications": {
"email": True,
"sms": False,
"push": {
"enabled": True,
"quiet_hours": [22, 7]
}
},
"tags": ["important", "family", "events"]
}
user = User(
id=uuid.uuid4(),
email="complex@example.com",
password_hash="hashedpassword",
first_name="Complex",
last_name="JSON",
preferences=complex_preferences
)
db_session.add(user)
db_session.commit()
# Act - Retrieve the user
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
# Assert - The complex JSON should be preserved
assert retrieved_user.preferences == complex_preferences
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
assert "important" in retrieved_user.preferences["tags"]

View File

View File

@@ -0,0 +1,127 @@
# tests/schemas/test_user_schemas.py
import pytest
import re
from pydantic import ValidationError
from app.schemas.users import UserBase, UserCreate
class TestPhoneNumberValidation:
"""Tests for phone number validation in user schemas"""
def test_valid_swiss_numbers(self):
"""Test valid Swiss phone numbers are accepted"""
# International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
assert user.phone_number == "+41791234567"
# Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
assert user.phone_number == "0791234567"
# With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
def test_valid_italian_numbers(self):
"""Test valid Italian phone numbers are accepted"""
# International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
assert user.phone_number == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
assert user.phone_number == "+39345123456"
# Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
assert user.phone_number == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
assert user.phone_number == "0345123456789"
# With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
def test_none_phone_number(self):
"""Test that None is accepted as a valid value (optional phone number)"""
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
assert user.phone_number is None
def test_invalid_phone_numbers(self):
"""Test that invalid phone numbers are rejected"""
invalid_numbers = [
# Too short
"+12",
"012",
# Invalid characters
"+41xyz123456",
"079abc4567",
"123-abc-7890",
"+1(800)CALL-NOW",
# Completely invalid formats
"++4412345678", # Double plus
"()+41123456", # Misplaced parentheses
# Empty string
"",
# Spaces only
" ",
]
for number in invalid_numbers:
with pytest.raises(ValidationError):
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
def test_phone_validation_in_user_create(self):
"""Test that phone validation also works in UserCreate schema"""
# Valid phone number
user = UserCreate(
email="test@example.com",
first_name="Test",
last_name="User",
password="Password123",
phone_number="+41791234567"
)
assert user.phone_number == "+41791234567"
# Invalid phone number should raise ValidationError
with pytest.raises(ValidationError):
UserCreate(
email="test@example.com",
first_name="Test",
last_name="User",
password="Password123",
phone_number="invalid-number"
)

View File

View File

@@ -0,0 +1,252 @@
# tests/services/test_auth_service.py
import uuid
import pytest
from unittest.mock import patch
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
from app.models.user import User
from app.schemas.users import UserCreate, Token
from app.services.auth_service import AuthService, AuthenticationError
class TestAuthServiceAuthentication:
"""Tests for AuthService.authenticate_user method"""
def test_authenticate_valid_user(self, db_session, mock_user):
"""Test authenticating a user with valid credentials"""
# Set a known password for the mock user
password = "TestPassword123"
mock_user.password_hash = get_password_hash(password)
db_session.commit()
# Authenticate with correct credentials
user = AuthService.authenticate_user(
db=db_session,
email=mock_user.email,
password=password
)
assert user is not None
assert user.id == mock_user.id
assert user.email == mock_user.email
def test_authenticate_nonexistent_user(self, db_session):
"""Test authenticating with an email that doesn't exist"""
user = AuthService.authenticate_user(
db=db_session,
email="nonexistent@example.com",
password="password"
)
assert user is None
def test_authenticate_with_wrong_password(self, db_session, mock_user):
"""Test authenticating with the wrong password"""
# Set a known password for the mock user
password = "TestPassword123"
mock_user.password_hash = get_password_hash(password)
db_session.commit()
# Authenticate with wrong password
user = AuthService.authenticate_user(
db=db_session,
email=mock_user.email,
password="WrongPassword123"
)
assert user is None
def test_authenticate_inactive_user(self, db_session, mock_user):
"""Test authenticating an inactive user"""
# Set a known password and make user inactive
password = "TestPassword123"
mock_user.password_hash = get_password_hash(password)
mock_user.is_active = False
db_session.commit()
# Should raise AuthenticationError
with pytest.raises(AuthenticationError):
AuthService.authenticate_user(
db=db_session,
email=mock_user.email,
password=password
)
class TestAuthServiceUserCreation:
"""Tests for AuthService.create_user method"""
def test_create_new_user(self, db_session):
"""Test creating a new user"""
user_data = UserCreate(
email="newuser@example.com",
password="TestPassword123",
first_name="New",
last_name="User",
phone_number="1234567890"
)
user = AuthService.create_user(db=db_session, user_data=user_data)
# Verify user was created with correct data
assert user is not None
assert user.email == user_data.email
assert user.first_name == user_data.first_name
assert user.last_name == user_data.last_name
assert user.phone_number == user_data.phone_number
# Verify password was hashed
assert user.password_hash != user_data.password
assert verify_password(user_data.password, user.password_hash)
# Verify default values
assert user.is_active is True
assert user.is_superuser is False
def test_create_user_with_existing_email(self, db_session, mock_user):
"""Test creating a user with an email that already exists"""
user_data = UserCreate(
email=mock_user.email, # Use existing email
password="TestPassword123",
first_name="Duplicate",
last_name="User"
)
# Should raise AuthenticationError
with pytest.raises(AuthenticationError):
AuthService.create_user(db=db_session, user_data=user_data)
class TestAuthServiceTokens:
"""Tests for AuthService token-related methods"""
def test_create_tokens(self, mock_user):
"""Test creating access and refresh tokens for a user"""
tokens = AuthService.create_tokens(mock_user)
# Verify token structure
assert isinstance(tokens, Token)
assert tokens.access_token is not None
assert tokens.refresh_token is not None
assert tokens.token_type == "bearer"
# This is a more in-depth test that would decode the tokens to verify claims
# but we'll rely on the auth module tests for token verification
def test_refresh_tokens(self, db_session, mock_user):
"""Test refreshing tokens with a valid refresh token"""
# Create initial tokens
initial_tokens = AuthService.create_tokens(mock_user)
# Refresh tokens
new_tokens = AuthService.refresh_tokens(
db=db_session,
refresh_token=initial_tokens.refresh_token
)
# Verify new tokens are different from old ones
assert new_tokens.access_token != initial_tokens.access_token
assert new_tokens.refresh_token != initial_tokens.refresh_token
def test_refresh_tokens_with_invalid_token(self, db_session):
"""Test refreshing tokens with an invalid token"""
# Create an invalid token
invalid_token = "invalid.token.string"
# Should raise TokenInvalidError
with pytest.raises(TokenInvalidError):
AuthService.refresh_tokens(
db=db_session,
refresh_token=invalid_token
)
def test_refresh_tokens_with_access_token(self, db_session, mock_user):
"""Test refreshing tokens with an access token instead of refresh token"""
# Create tokens
tokens = AuthService.create_tokens(mock_user)
# Try to refresh with access token
with pytest.raises(TokenInvalidError):
AuthService.refresh_tokens(
db=db_session,
refresh_token=tokens.access_token
)
def test_refresh_tokens_with_nonexistent_user(self, db_session):
"""Test refreshing tokens for a user that doesn't exist in the database"""
# Create a token for a non-existent user
non_existent_id = str(uuid.uuid4())
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
# Mock the token data to return a non-existent user ID
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
# Should raise TokenInvalidError
with pytest.raises(TokenInvalidError):
AuthService.refresh_tokens(
db=db_session,
refresh_token="some.refresh.token"
)
class TestAuthServicePasswordChange:
"""Tests for AuthService.change_password method"""
def test_change_password(self, db_session, mock_user):
"""Test changing a user's password"""
# Set a known password for the mock user
current_password = "CurrentPassword123"
mock_user.password_hash = get_password_hash(current_password)
db_session.commit()
# Change password
new_password = "NewPassword456"
result = AuthService.change_password(
db=db_session,
user_id=mock_user.id,
current_password=current_password,
new_password=new_password
)
# Verify operation was successful
assert result is True
# Refresh user from DB
db_session.refresh(mock_user)
# Verify old password no longer works
assert not verify_password(current_password, mock_user.password_hash)
# Verify new password works
assert verify_password(new_password, mock_user.password_hash)
def test_change_password_wrong_current_password(self, db_session, mock_user):
"""Test changing password with incorrect current password"""
# Set a known password for the mock user
current_password = "CurrentPassword123"
mock_user.password_hash = get_password_hash(current_password)
db_session.commit()
# Try to change password with wrong current password
wrong_password = "WrongPassword123"
with pytest.raises(AuthenticationError):
AuthService.change_password(
db=db_session,
user_id=mock_user.id,
current_password=wrong_password,
new_password="NewPassword456"
)
# Verify password was not changed
assert verify_password(current_password, mock_user.password_hash)
def test_change_password_nonexistent_user(self, db_session):
"""Test changing password for a user that doesn't exist"""
non_existent_id = uuid.uuid4()
with pytest.raises(AuthenticationError):
AuthService.change_password(
db=db_session,
user_id=non_existent_id,
current_password="CurrentPassword123",
new_password="NewPassword456"
)