Add comprehensive test suite and utilities for user functionality
This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
@@ -17,7 +17,7 @@ ENVIRONMENT=development
|
||||
DEBUG=true
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
FIRST_SUPERUSER_PASSWORD=Admin123
|
||||
|
||||
# Frontend settings
|
||||
FRONTEND_PORT=3000
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add all initial models
|
||||
|
||||
Revision ID: 38bf9e7e74b3
|
||||
Revises: 7396957cbe80
|
||||
Create Date: 2025-02-28 09:19:33.212278
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38bf9e7e74b3'
|
||||
down_revision: Union[str, None] = '7396957cbe80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('first_name', sa.String(), nullable=False),
|
||||
sa.Column('last_name', sa.String(), nullable=True),
|
||||
sa.Column('phone_number', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/dependencies/__init__.py
Normal file
0
backend/app/api/dependencies/__init__.py
Normal file
137
backend/app/api/dependencies/auth.py
Normal file
137
backend/app/api/dependencies/auth.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
# OAuth2 configuration
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: JWT token from request
|
||||
|
||||
Returns:
|
||||
User: The authenticated user
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
try:
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except TokenExpiredError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
User: The authenticated and active user
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is inactive
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
User: The authenticated superuser
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not a superuser
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_optional_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: Optional[str] = Depends(oauth2_scheme)
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: JWT token from request
|
||||
|
||||
Returns:
|
||||
User or None: The authenticated user or None
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
6
backend/app/api/main.py
Normal file
6
backend/app/api/main.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
0
backend/app/api/routes/__init__.py
Normal file
0
backend/app/api/routes/__init__.py
Normal file
231
backend/app/api/routes/auth.py
Normal file
231
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import (
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Returns:
|
||||
The created user information.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.create_user(db, user_data)
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during registration: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
logger.info(f"User login successful: {user.email}")
|
||||
return tokens
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token)
|
||||
async def login_oauth(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
|
||||
# Format response for OAuth compatibility
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": tokens.token_type
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
return tokens
|
||||
except TokenExpiredError:
|
||||
logger.warning("Token refresh failed: Token expired")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has expired. Please log in again.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
logger.warning("Token refresh failed: Invalid token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/change-password", status_code=status.HTTP_200_OK)
|
||||
async def change_password(
|
||||
current_password: str = Body(..., embed=True),
|
||||
new_password: str = Body(..., embed=True),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
)
|
||||
|
||||
if success:
|
||||
return {"message": "Password changed successfully"}
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Password change failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during password change: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""
|
||||
Get current user information.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
return current_user
|
||||
185
backend/app/core/auth.py
Normal file
185
backend/app/core/auth.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
claims: Optional additional claims to include in the token
|
||||
|
||||
Returns:
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if claims:
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
|
||||
Returns:
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token to decode
|
||||
verify_type: Optional token type to verify (access or refresh)
|
||||
|
||||
Returns:
|
||||
TokenPayload: The decoded token data
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If the token has expired
|
||||
TokenInvalidError: If the token is invalid
|
||||
TokenMissingClaimError: If a required claim is missing
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check required claims before Pydantic validation
|
||||
if not payload.get("sub"):
|
||||
raise TokenMissingClaimError("Token missing 'sub' claim")
|
||||
|
||||
# Verify token type if specified
|
||||
if verify_type and payload.get("type") != verify_type:
|
||||
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
|
||||
|
||||
# Now validate with Pydantic
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
|
||||
def get_token_data(token: str) -> TokenData:
|
||||
"""
|
||||
Extract the user ID and superuser status from a token.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
TokenData with user_id and is_superuser
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
@@ -3,7 +3,7 @@ from typing import Optional, List
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "App"
|
||||
PROJECT_NAME: str = "EventSpace"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
sql_echo_timing: bool = False # Log query execution times
|
||||
slow_query_threshold: float = 0.5 # Log queries taking longer than this
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
|
||||
# JWT configuration
|
||||
SECRET_KEY: str = "your_secret_key_here"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
|
||||
@@ -1,17 +1,57 @@
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Use the database URL from settings
|
||||
engine = create_engine(settings.database_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# Declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Create engine with optimized settings for PostgreSQL
|
||||
def create_production_engine():
|
||||
return create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_timeout=settings.db_pool_timeout,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
isolation_level="READ COMMITTED",
|
||||
echo=settings.sql_echo,
|
||||
echo_pool=settings.sql_echo_pool,
|
||||
)
|
||||
|
||||
# Dependency to get DB session
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
0
backend/app/crud/__init__.py
Normal file
0
backend/app/crud/__init__.py
Normal file
62
backend/app/crud/base.py
Normal file
62
backend/app/crud/base.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.database import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def remove(self, db: Session, *, id: str) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
56
backend/app/crud/user.py
Normal file
56
backend/app/crud/user.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# app/crud/user.py
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
user = CRUDUser(User)
|
||||
@@ -1,9 +1,18 @@
|
||||
import logging
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.config import settings
|
||||
from app.api.main import api_router
|
||||
from app.core.config import settings
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
@@ -34,3 +43,6 @@ async def root():
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
14
backend/app/models/__init__.py
Normal file
14
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Models package initialization.
|
||||
Imports all models to ensure they're registered with SQLAlchemy.
|
||||
"""
|
||||
# First import Base to avoid circular imports
|
||||
from app.core.database import Base
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# Import user model
|
||||
from .user import User
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User',
|
||||
]
|
||||
20
backend/app/models/base.py
Normal file
20
backend/app/models/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin to add UUID primary keys to models"""
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
19
backend/app/models/user.py
Normal file
19
backend/app/models/user.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy import Column, String, JSON, Boolean
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email = Column(String, unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String, nullable=False)
|
||||
first_name = Column(String, nullable=False, default="user")
|
||||
last_name = Column(String, nullable=True)
|
||||
phone_number = Column(String)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||
preferences = Column(JSON)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
149
backend/app/schemas/users.py
Normal file
149
backend/app/schemas/users.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# app/schemas/users.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@classmethod
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
# Simple regex for phone validation
|
||||
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
|
||||
raise ValueError('Invalid phone number format')
|
||||
return v
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
is_superuser: bool = False
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = True
|
||||
@field_validator('phone_number')
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# Return early for empty strings or whitespace-only strings
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Phone number cannot be empty')
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r'[\s\-\(\)]', '', v)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
# After + must have at least 8 digits
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
if cleaned.count('+') > 1:
|
||||
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]):
|
||||
raise ValueError('Phone number can only contain digits after the prefix')
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
user_id: UUID
|
||||
is_superuser: bool = False
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
193
backend/app/services/auth_service.py
Normal file
193
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.users import Token, UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Exception raised for authentication errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_data: User data
|
||||
|
||||
Returns:
|
||||
Created user
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
"""
|
||||
Create access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user: User to create tokens for
|
||||
|
||||
Returns:
|
||||
Token object with access and refresh tokens
|
||||
"""
|
||||
# Generate claims
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
||||
"""
|
||||
Generate new tokens using a refresh token.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
refresh_token: Valid refresh token
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If refresh token has expired
|
||||
TokenInvalidError: If refresh token is invalid
|
||||
"""
|
||||
from app.core.auth import decode_token, get_token_data
|
||||
|
||||
try:
|
||||
# Verify token is a refresh token
|
||||
decode_token(refresh_token, verify_type="refresh")
|
||||
|
||||
# Get user ID from token
|
||||
token_data = get_token_data(refresh_token)
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
# Generate new tokens
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
current_password: Current password
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
True if password was changed successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If current password is incorrect
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
79
backend/app/utils/test_utils.py
Normal file
79
backend/app/utils/test_utils.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker, clear_mappers
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_test_engine():
|
||||
"""Create an SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
)
|
||||
|
||||
return test_engine
|
||||
|
||||
def setup_test_db():
|
||||
"""Create a test database and session factory"""
|
||||
# Create a new engine for this test run
|
||||
test_engine = get_test_engine()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(test_engine)
|
||||
|
||||
# Create session factory
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
return test_engine, TestingSessionLocal
|
||||
|
||||
def teardown_test_db(engine):
|
||||
"""Clean up after tests"""
|
||||
# Drop all tables
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
|
||||
async def get_async_test_engine():
|
||||
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
)
|
||||
return test_engine
|
||||
|
||||
|
||||
async def setup_async_test_db():
|
||||
"""Create an async test database and session factory"""
|
||||
test_engine = await get_async_test_engine()
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
AsyncTestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession
|
||||
)
|
||||
|
||||
return test_engine, AsyncTestingSessionLocal
|
||||
|
||||
|
||||
async def teardown_async_test_db(engine):
|
||||
"""Clean up after async tests"""
|
||||
await engine.dispose()
|
||||
10
backend/pytest.ini
Normal file
10
backend/pytest.ini
Normal file
@@ -0,0 +1,10 @@
|
||||
[pytest]
|
||||
env =
|
||||
IS_TEST=True
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
addopts = --disable-warnings
|
||||
markers =
|
||||
sqlite: marks tests that should run on SQLite (mocked).
|
||||
postgres: marks tests that require a real PostgreSQL database.
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
@@ -4,13 +4,14 @@ uvicorn>=0.34.0
|
||||
pydantic>=2.10.6
|
||||
pydantic-settings>=2.2.1
|
||||
python-multipart>=0.0.19
|
||||
fastapi-utils==0.8.0
|
||||
|
||||
# Database
|
||||
sqlalchemy>=2.0.29
|
||||
alembic>=1.14.1
|
||||
psycopg2-binary>=2.9.9
|
||||
asyncpg>=0.29.0
|
||||
|
||||
aiosqlite==0.21.0
|
||||
# Security and authentication
|
||||
python-jose>=3.4.0
|
||||
passlib>=1.7.4
|
||||
@@ -30,7 +31,7 @@ httpx>=0.27.0
|
||||
tenacity>=8.2.3
|
||||
pytz>=2024.1
|
||||
pillow>=10.3.0
|
||||
|
||||
apscheduler==3.11.0
|
||||
# Testing
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.5
|
||||
@@ -41,4 +42,11 @@ requests>=2.32.0
|
||||
black>=24.3.0
|
||||
isort>=5.13.2
|
||||
flake8>=7.0.0
|
||||
mypy>=1.8.0
|
||||
mypy>=1.8.0
|
||||
|
||||
# Security
|
||||
python-jose==3.4.0
|
||||
bcrypt==4.2.1
|
||||
cryptography==44.0.1
|
||||
passlib==1.7.4
|
||||
freezegun~=1.5.1
|
||||
0
backend/tests/api/routes/__init__.py
Normal file
0
backend/tests/api/routes/__init__.py
Normal file
369
backend/tests/api/routes/test_auth.py
Normal file
369
backend/tests/api/routes/test_auth.py
Normal file
@@ -0,0 +1,369 @@
|
||||
# tests/api/routes/test_auth.py
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.routes.auth import router as auth_router
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||
|
||||
|
||||
# Mock the get_db dependency
|
||||
@pytest.fixture
|
||||
def override_get_db(db_session):
|
||||
"""Override get_db dependency for testing."""
|
||||
return db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(override_get_db):
|
||||
"""Create a FastAPI test application with overridden dependencies."""
|
||||
app = FastAPI()
|
||||
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
||||
|
||||
# Override the get_db dependency
|
||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a FastAPI test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestRegisterUser:
|
||||
"""Tests for the register_user endpoint."""
|
||||
|
||||
def test_register_user_success(self, client, monkeypatch, db_session):
|
||||
"""Test successful user registration."""
|
||||
# Mock the service method with a valid complete User object
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="newuser@example.com",
|
||||
password_hash="hashed_password",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# Use patch for mocking
|
||||
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "newuser@example.com",
|
||||
"password": "Password123",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["email"] == "newuser@example.com"
|
||||
assert data["first_name"] == "New"
|
||||
assert data["last_name"] == "User"
|
||||
assert "password" not in data
|
||||
|
||||
def test_register_user_duplicate_email(self, client, db_session):
|
||||
"""Test registration with duplicate email."""
|
||||
# Use patch for mocking with a side effect
|
||||
with patch.object(AuthService, 'create_user',
|
||||
side_effect=AuthenticationError("User with this email already exists")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={
|
||||
"email": "existing@example.com",
|
||||
"password": "Password123",
|
||||
"first_name": "Existing",
|
||||
"last_name": "User"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestLogin:
|
||||
"""Tests for the login endpoint."""
|
||||
|
||||
def test_login_success(self, client, mock_user, db_session):
|
||||
"""Test successful login."""
|
||||
# Ensure mock_user has required attributes
|
||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||
mock_user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
# Create mock tokens
|
||||
mock_tokens = MagicMock(
|
||||
access_token="mock_access_token",
|
||||
refresh_token="mock_refresh_token",
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
# Use context managers for patching
|
||||
with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \
|
||||
patch.object(AuthService, 'create_tokens', return_value=mock_tokens):
|
||||
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "user@example.com",
|
||||
"password": "Password123"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
assert data["refresh_token"] == "mock_refresh_token"
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_login_invalid_credentials_debug(self, client, app):
|
||||
"""Improved test for login with invalid credentials."""
|
||||
# Print response for debugging
|
||||
from app.services.auth_service import AuthService
|
||||
|
||||
# Create a complete mock for AuthService
|
||||
class MockAuthService:
|
||||
@staticmethod
|
||||
def authenticate_user(db, email, password):
|
||||
print(f"Mock called with: {email}, {password}")
|
||||
return None
|
||||
|
||||
# Replace the entire class with our mock
|
||||
original_service = AuthService
|
||||
try:
|
||||
# Replace with our mock
|
||||
import sys
|
||||
sys.modules['app.services.auth_service'].AuthService = MockAuthService
|
||||
|
||||
# Make the request
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "user@example.com",
|
||||
"password": "WrongPassword"
|
||||
}
|
||||
)
|
||||
|
||||
# Print response details for debugging
|
||||
print(f"Response status: {response.status_code}")
|
||||
print(f"Response body: {response.text}")
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 401
|
||||
assert "Invalid email or password" in response.json()["detail"]
|
||||
finally:
|
||||
# Restore original service
|
||||
sys.modules['app.services.auth_service'].AuthService = original_service
|
||||
|
||||
|
||||
def test_login_inactive_user(self, client, db_session):
|
||||
"""Test login with inactive user."""
|
||||
# Mock authentication to raise an error
|
||||
with patch.object(AuthService, 'authenticate_user',
|
||||
side_effect=AuthenticationError("User account is inactive")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "inactive@example.com",
|
||||
"password": "Password123"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 401
|
||||
assert "inactive" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestRefreshToken:
|
||||
"""Tests for the refresh_token endpoint."""
|
||||
|
||||
def test_refresh_token_success(self, client, db_session):
|
||||
"""Test successful token refresh."""
|
||||
# Mock refresh to return tokens
|
||||
mock_tokens = MagicMock(
|
||||
access_token="new_access_token",
|
||||
refresh_token="new_refresh_token",
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": "valid_refresh_token"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "new_access_token"
|
||||
assert data["refresh_token"] == "new_refresh_token"
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
def test_refresh_token_expired(self, client, db_session):
|
||||
"""Test refresh with expired token."""
|
||||
# Mock refresh to raise expired token error
|
||||
with patch.object(AuthService, 'refresh_tokens',
|
||||
side_effect=TokenExpiredError("Token expired")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": "expired_refresh_token"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 401
|
||||
assert "expired" in response.json()["detail"]
|
||||
|
||||
def test_refresh_token_invalid(self, client, db_session):
|
||||
"""Test refresh with invalid token."""
|
||||
# Mock refresh to raise invalid token error
|
||||
with patch.object(AuthService, 'refresh_tokens',
|
||||
side_effect=TokenInvalidError("Invalid token")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": "invalid_refresh_token"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 401
|
||||
assert "Invalid" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestChangePassword:
|
||||
"""Tests for the change_password endpoint."""
|
||||
|
||||
def test_change_password_success(self, client, mock_user, db_session, app):
|
||||
"""Test successful password change."""
|
||||
# Ensure mock_user has required attributes
|
||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||
mock_user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
# Override get_current_user dependency
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
|
||||
# Mock password change to return success
|
||||
with patch.object(AuthService, 'change_password', return_value=True):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/change-password",
|
||||
json={
|
||||
"current_password": "OldPassword123",
|
||||
"new_password": "NewPassword123"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
assert "success" in response.json()["message"].lower()
|
||||
|
||||
# Clean up override
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app):
|
||||
"""Test change password with incorrect current password."""
|
||||
# Ensure mock_user has required attributes
|
||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||
mock_user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
# Override get_current_user dependency
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
|
||||
# Mock password change to raise error
|
||||
with patch.object(AuthService, 'change_password',
|
||||
side_effect=AuthenticationError("Current password is incorrect")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/change-password",
|
||||
json={
|
||||
"current_password": "WrongPassword",
|
||||
"new_password": "NewPassword123"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 400
|
||||
assert "incorrect" in response.json()["detail"].lower()
|
||||
|
||||
# Clean up override
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestGetCurrentUserInfo:
|
||||
"""Tests for the get_current_user_info endpoint."""
|
||||
|
||||
def test_get_current_user_info(self, client, mock_user, app):
|
||||
"""Test getting current user info."""
|
||||
# Ensure mock_user has required attributes
|
||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||
mock_user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
# Override get_current_user dependency
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
|
||||
# Test request
|
||||
response = client.get("/auth/me")
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == mock_user.email
|
||||
assert data["first_name"] == mock_user.first_name
|
||||
assert data["last_name"] == mock_user.last_name
|
||||
assert "password" not in data
|
||||
|
||||
# Clean up override
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_current_user_info_unauthorized(self, client):
|
||||
"""Test getting user info without authentication."""
|
||||
# Test request without authentication
|
||||
response = client.get("/auth/me")
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 401
|
||||
211
backend/tests/api/test_auth_dependencies.py
Normal file
211
backend/tests/api/test_auth_dependencies.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# tests/api/dependencies/test_auth_dependencies.py
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user
|
||||
)
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token():
|
||||
return "mock.jwt.token"
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for get_current_user dependency"""
|
||||
|
||||
def test_get_current_user_success(self, db_session, mock_user, mock_token):
|
||||
"""Test successfully getting the current user"""
|
||||
# Mock get_token_data to return user_id that matches our mock_user
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.return_value.user_id = mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
user = get_current_user(db=db_session, token=mock_token)
|
||||
|
||||
# Verify the correct user was returned
|
||||
assert user.id == mock_user.id
|
||||
assert user.email == mock_user.email
|
||||
|
||||
def test_get_current_user_nonexistent(self, db_session, mock_token):
|
||||
"""Test when the token contains a user ID that doesn't exist"""
|
||||
# Mock get_token_data to return a non-existent user ID
|
||||
# Use a real UUID object instead of a string
|
||||
import uuid
|
||||
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.return_value.user_id = nonexistent_id # Using UUID object, not string
|
||||
|
||||
# Should raise HTTPException with 404 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(db=db_session, token=mock_token)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
def test_get_current_user_inactive(self, db_session, mock_user, mock_token):
|
||||
"""Test when the user is inactive"""
|
||||
# Make the user inactive
|
||||
mock_user.is_active = False
|
||||
db_session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.return_value.user_id = mock_user.id
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(db=db_session, token=mock_token)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
def test_get_current_user_expired_token(self, db_session, mock_token):
|
||||
"""Test with an expired token"""
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(db=db_session, token=mock_token)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Token expired" in exc_info.value.detail
|
||||
|
||||
def test_get_current_user_invalid_token(self, db_session, mock_token):
|
||||
"""Test with an invalid token"""
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_user(db=db_session, token=mock_token)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Could not validate credentials" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestGetCurrentActiveUser:
|
||||
"""Tests for get_current_active_user dependency"""
|
||||
|
||||
def test_get_current_active_user(self, mock_user):
|
||||
"""Test getting an active user"""
|
||||
# Ensure user is active
|
||||
mock_user.is_active = True
|
||||
|
||||
# Call the dependency with mocked current_user
|
||||
user = get_current_active_user(current_user=mock_user)
|
||||
|
||||
# Should return the same user
|
||||
assert user == mock_user
|
||||
|
||||
def test_get_current_inactive_user(self, mock_user):
|
||||
"""Test getting an inactive user"""
|
||||
# Make user inactive
|
||||
mock_user.is_active = False
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_active_user(current_user=mock_user)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Inactive user" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestGetCurrentSuperuser:
|
||||
"""Tests for get_current_superuser dependency"""
|
||||
|
||||
def test_get_current_superuser(self, mock_user):
|
||||
"""Test getting a superuser"""
|
||||
# Make user a superuser
|
||||
mock_user.is_superuser = True
|
||||
|
||||
# Call the dependency with mocked current_user
|
||||
user = get_current_superuser(current_user=mock_user)
|
||||
|
||||
# Should return the same user
|
||||
assert user == mock_user
|
||||
|
||||
def test_get_current_non_superuser(self, mock_user):
|
||||
"""Test getting a non-superuser"""
|
||||
# Ensure user is not a superuser
|
||||
mock_user.is_superuser = False
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_current_superuser(current_user=mock_user)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Not enough permissions" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestGetOptionalCurrentUser:
|
||||
"""Tests for get_optional_current_user dependency"""
|
||||
|
||||
def test_get_optional_current_user_with_token(self, db_session, mock_user, mock_token):
|
||||
"""Test getting optional user with a valid token"""
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.return_value.user_id = mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||
|
||||
# Should return the correct user
|
||||
assert user is not None
|
||||
assert user.id == mock_user.id
|
||||
|
||||
def test_get_optional_current_user_no_token(self, db_session):
|
||||
"""Test getting optional user with no token"""
|
||||
# Call the dependency with no token
|
||||
user = get_optional_current_user(db=db_session, token=None)
|
||||
|
||||
# Should return None
|
||||
assert user is None
|
||||
|
||||
def test_get_optional_current_user_invalid_token(self, db_session, mock_token):
|
||||
"""Test getting optional user with an invalid token"""
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Call the dependency
|
||||
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||
|
||||
# Should return None, not raise an exception
|
||||
assert user is None
|
||||
|
||||
def test_get_optional_current_user_expired_token(self, db_session, mock_token):
|
||||
"""Test getting optional user with an expired token"""
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Call the dependency
|
||||
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||
|
||||
# Should return None, not raise an exception
|
||||
assert user is None
|
||||
|
||||
def test_get_optional_current_user_inactive(self, db_session, mock_user, mock_token):
|
||||
"""Test getting optional user when user is inactive"""
|
||||
# Make the user inactive
|
||||
mock_user.is_active = False
|
||||
db_session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
mock_get_data.return_value.user_id = mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
user = get_optional_current_user(db=db_session, token=mock_token)
|
||||
|
||||
# Should return None for inactive users
|
||||
assert user is None
|
||||
66
backend/tests/conftest.py
Normal file
66
backend/tests/conftest.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# tests/conftest.py
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.user import User
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""
|
||||
Creates a fresh SQLite in-memory database for each test function.
|
||||
|
||||
Yields a SQLAlchemy session that can be used for testing.
|
||||
"""
|
||||
# Set up the database
|
||||
test_engine, TestingSessionLocal = setup_test_db()
|
||||
|
||||
# Create a session
|
||||
with TestingSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
# Clean up
|
||||
teardown_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function") # Define a fixture
|
||||
async def async_test_db():
|
||||
"""Fixture provides new testing engine and session for each test run to improve isolation."""
|
||||
|
||||
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
||||
yield test_engine, AsyncTestingSessionLocal
|
||||
await teardown_async_test_db(test_engine)
|
||||
|
||||
@pytest.fixture
|
||||
def user_create_data():
|
||||
return {
|
||||
"email": "newtest@example.com", # Changed to avoid conflict with mock_user
|
||||
"password": "TestPassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
"phone_number": "+1234567890",
|
||||
"is_superuser": False,
|
||||
"preferences": None
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(db_session):
|
||||
"""Fixture to create and return a mock User instance."""
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="mockuser@example.com",
|
||||
password_hash="mockhashedpassword",
|
||||
first_name="Mock",
|
||||
last_name="User",
|
||||
phone_number="1234567890",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
preferences=None,
|
||||
)
|
||||
db_session.add(mock_user)
|
||||
db_session.commit()
|
||||
return mock_user
|
||||
0
backend/tests/core/__init__.py
Normal file
0
backend/tests/core/__init__.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# tests/core/test_auth.py
|
||||
import uuid
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_token_data,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestPasswordHandling:
|
||||
"""Tests for password hashing and verification functions"""
|
||||
|
||||
def test_password_hash_different_from_password(self):
|
||||
"""Test that a password hash is different from the original password"""
|
||||
password = "TestPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert hashed != password
|
||||
|
||||
def test_verify_correct_password(self):
|
||||
"""Test that verify_password returns True for the correct password"""
|
||||
password = "TestPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_incorrect_password(self):
|
||||
"""Test that verify_password returns False for an incorrect password"""
|
||||
password = "TestPassword123"
|
||||
wrong_password = "WrongPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(wrong_password, hashed) is False
|
||||
|
||||
def test_same_password_different_hash(self):
|
||||
"""Test that the same password gets a different hash each time"""
|
||||
password = "TestPassword123"
|
||||
hash1 = get_password_hash(password)
|
||||
hash2 = get_password_hash(password)
|
||||
assert hash1 != hash2
|
||||
|
||||
|
||||
class TestTokenCreation:
|
||||
"""Tests for token creation functions"""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test that an access token is created with the correct claims"""
|
||||
user_id = str(uuid.uuid4())
|
||||
custom_claims = {
|
||||
"email": "test@example.com",
|
||||
"first_name": "Test",
|
||||
"is_superuser": True
|
||||
}
|
||||
token = create_access_token(subject=user_id, claims=custom_claims)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
assert payload["sub"] == user_id
|
||||
assert "jti" in payload
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert payload["type"] == "access"
|
||||
|
||||
# Check custom claims
|
||||
for key, value in custom_claims.items():
|
||||
assert payload[key] == value
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
"""Test that a refresh token is created with the correct claims"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_refresh_token(subject=user_id)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
assert payload["sub"] == user_id
|
||||
assert "jti" in payload
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_token_expiration(self):
|
||||
"""Test that tokens have the correct expiration time"""
|
||||
user_id = str(uuid.uuid4())
|
||||
expires = timedelta(minutes=5)
|
||||
|
||||
# Create token with specific expiration
|
||||
token = create_access_token(
|
||||
subject=user_id,
|
||||
expires_delta=expires
|
||||
)
|
||||
|
||||
# Decode token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Get actual expiration time from token
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
|
||||
# Calculate expected expiration (approximately)
|
||||
now = datetime.now(timezone.utc)
|
||||
expected_expiration = now + expires
|
||||
|
||||
# Difference should be small (less than 1 second)
|
||||
difference = abs((expiration - expected_expiration).total_seconds())
|
||||
assert difference < 1
|
||||
|
||||
|
||||
class TestTokenDecoding:
|
||||
"""Tests for token decoding and validation functions"""
|
||||
|
||||
def test_decode_valid_token(self):
|
||||
"""Test that a valid token can be decoded"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
# Decode token
|
||||
payload = decode_token(token)
|
||||
|
||||
# Check that the subject matches
|
||||
assert payload.sub == user_id
|
||||
|
||||
def test_decode_expired_token(self):
|
||||
"""Test that an expired token raises TokenExpiredError"""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Create a token that's already expired by directly manipulating the payload
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_time = now - timedelta(hours=1) # 1 hour in the past
|
||||
|
||||
# Create the expired token manually
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expired_time.timestamp()), # Set expiration in the past
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Attempting to decode should raise TokenExpiredError
|
||||
with pytest.raises(TokenExpiredError):
|
||||
decode_token(expired_token)
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""Test that an invalid token raises TokenInvalidError"""
|
||||
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
|
||||
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(invalid_token)
|
||||
|
||||
def test_decode_token_with_missing_sub(self):
|
||||
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
|
||||
# Create a token without a subject
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
# No 'sub' claim
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
|
||||
def test_decode_token_with_wrong_type(self):
|
||||
"""Test that verifying a token with wrong type raises TokenInvalidError"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
# Try to verify it as a refresh token
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(token, verify_type="refresh")
|
||||
|
||||
def test_decode_with_invalid_payload(self):
|
||||
"""Test that a token with invalid payload structure raises TokenInvalidError"""
|
||||
# Create a token with an invalid payload structure - missing 'sub' which is required
|
||||
# but including 'exp' to avoid the expiration check
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
# Missing "sub" field which is required
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"invalid_field": "test"
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Should raise TokenMissingClaimError due to missing 'sub'
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
|
||||
# Create another token with invalid type for required field
|
||||
payload = {
|
||||
"sub": 123, # sub should be a string, not an integer
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Should raise TokenInvalidError due to ValidationError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(token)
|
||||
|
||||
def test_get_token_data(self):
|
||||
"""Test extracting TokenData from a token"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
subject=str(user_id),
|
||||
claims={"is_superuser": True}
|
||||
)
|
||||
|
||||
token_data = get_token_data(token)
|
||||
|
||||
assert token_data.user_id == user_id
|
||||
assert token_data.is_superuser is True
|
||||
0
backend/tests/crud/__init__.py
Normal file
0
backend/tests/crud/__init__.py
Normal file
125
backend/tests/crud/test_user.py
Normal file
125
backend/tests/crud/test_user.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import pytest
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
def test_create_user(db_session, user_create_data):
|
||||
user_in = UserCreate(**user_create_data)
|
||||
user_obj = user_crud.create(db_session, obj_in=user_in)
|
||||
|
||||
assert user_obj.email == user_create_data["email"]
|
||||
assert user_obj.first_name == user_create_data["first_name"]
|
||||
assert user_obj.last_name == user_create_data["last_name"]
|
||||
assert user_obj.phone_number == user_create_data["phone_number"]
|
||||
assert user_obj.is_superuser == user_create_data["is_superuser"]
|
||||
assert user_obj.password_hash is not None
|
||||
assert user_obj.id is not None
|
||||
|
||||
|
||||
def test_get_user(db_session, mock_user):
|
||||
# Using mock_user fixture instead of creating new user
|
||||
stored_user = user_crud.get(db_session, id=mock_user.id)
|
||||
assert stored_user
|
||||
assert stored_user.id == mock_user.id
|
||||
assert stored_user.email == mock_user.email
|
||||
|
||||
|
||||
def test_get_user_by_email(db_session, mock_user):
|
||||
stored_user = user_crud.get_by_email(db_session, email=mock_user.email)
|
||||
assert stored_user
|
||||
assert stored_user.id == mock_user.id
|
||||
assert stored_user.email == mock_user.email
|
||||
|
||||
|
||||
def test_update_user(db_session, mock_user):
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
)
|
||||
|
||||
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||
|
||||
assert updated_user.first_name == "Updated"
|
||||
assert updated_user.last_name == "Name"
|
||||
assert updated_user.phone_number == "+9876543210"
|
||||
assert updated_user.email == mock_user.email
|
||||
|
||||
|
||||
def test_delete_user(db_session, mock_user):
|
||||
user_crud.remove(db_session, id=mock_user.id)
|
||||
deleted_user = user_crud.get(db_session, id=mock_user.id)
|
||||
assert deleted_user is None
|
||||
|
||||
|
||||
def test_get_multi_users(db_session, mock_user, user_create_data):
|
||||
# Create additional users (mock_user is already in db)
|
||||
users_data = [
|
||||
{**user_create_data, "email": f"test{i}@example.com"}
|
||||
for i in range(2) # Creating 2 more users + mock_user = 3 total
|
||||
]
|
||||
|
||||
for user_data in users_data:
|
||||
user_in = UserCreate(**user_data)
|
||||
user_crud.create(db_session, obj_in=user_in)
|
||||
|
||||
users = user_crud.get_multi(db_session, skip=0, limit=10)
|
||||
assert len(users) == 3
|
||||
assert all(isinstance(user, User) for user in users)
|
||||
|
||||
|
||||
def test_is_active(db_session, mock_user):
|
||||
assert user_crud.is_active(mock_user) is True
|
||||
|
||||
# Test deactivating user
|
||||
update_data = UserUpdate(is_active=False)
|
||||
deactivated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||
assert user_crud.is_active(deactivated_user) is False
|
||||
|
||||
|
||||
def test_is_superuser(db_session, mock_user, user_create_data):
|
||||
# mock_user is regular user
|
||||
assert user_crud.is_superuser(mock_user) is False
|
||||
|
||||
# Create superuser
|
||||
super_user_data = {**user_create_data, "email": "super@example.com", "is_superuser": True}
|
||||
super_user_in = UserCreate(**super_user_data)
|
||||
super_user = user_crud.create(db_session, obj_in=super_user_in)
|
||||
assert user_crud.is_superuser(super_user) is True
|
||||
|
||||
|
||||
# Additional test cases
|
||||
def test_create_duplicate_email(db_session, mock_user):
|
||||
user_data = UserCreate(
|
||||
email=mock_user.email, # Try to create user with existing email
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
with pytest.raises(Exception): # Should raise an integrity error
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
|
||||
def test_update_user_preferences(db_session, mock_user):
|
||||
preferences = {"theme": "dark", "notifications": True}
|
||||
update_data = UserUpdate(preferences=preferences)
|
||||
|
||||
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||
assert updated_user.preferences == preferences
|
||||
|
||||
|
||||
def test_get_multi_users_pagination(db_session, user_create_data):
|
||||
# Create 5 users
|
||||
for i in range(5):
|
||||
user_in = UserCreate(**{**user_create_data, "email": f"test{i}@example.com"})
|
||||
user_crud.create(db_session, obj_in=user_in)
|
||||
|
||||
# Test pagination
|
||||
first_page = user_crud.get_multi(db_session, skip=0, limit=2)
|
||||
second_page = user_crud.get_multi(db_session, skip=2, limit=2)
|
||||
|
||||
assert len(first_page) == 2
|
||||
assert len(second_page) == 2
|
||||
assert first_page[0].id != second_page[0].id
|
||||
0
backend/tests/models/__init__.py
Normal file
0
backend/tests/models/__init__.py
Normal file
249
backend/tests/models/test_user.py
Normal file
249
backend/tests/models/test_user.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# tests/models/test_user.py
|
||||
import uuid
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def test_create_user(db_session):
|
||||
"""Test creating a basic user."""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
new_user = User(
|
||||
id=user_id,
|
||||
email="test@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="1234567890",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
preferences={"theme": "dark"},
|
||||
)
|
||||
db_session.add(new_user)
|
||||
|
||||
# Act
|
||||
db_session.commit()
|
||||
created_user = db_session.query(User).filter_by(email="test@example.com").first()
|
||||
|
||||
# Assert
|
||||
assert created_user is not None
|
||||
assert created_user.email == "test@example.com"
|
||||
assert created_user.first_name == "Test"
|
||||
assert created_user.last_name == "User"
|
||||
assert created_user.phone_number == "1234567890"
|
||||
assert created_user.is_active is True
|
||||
assert created_user.is_superuser is False
|
||||
assert created_user.preferences == {"theme": "dark"}
|
||||
# UUID should be preserved
|
||||
assert created_user.id == user_id
|
||||
# Timestamps should be set
|
||||
assert isinstance(created_user.created_at, datetime)
|
||||
assert isinstance(created_user.updated_at, datetime)
|
||||
|
||||
|
||||
def test_update_user(db_session):
|
||||
"""Test updating an existing user."""
|
||||
# Arrange - Create a user
|
||||
user_id = uuid.uuid4()
|
||||
user = User(
|
||||
id=user_id,
|
||||
email="update@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Before",
|
||||
last_name="Update",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Record the original creation timestamp
|
||||
original_created_at = user.created_at
|
||||
|
||||
# Act - Update the user
|
||||
user.first_name = "After"
|
||||
user.last_name = "Updated"
|
||||
user.phone_number = "9876543210"
|
||||
user.preferences = {"theme": "light", "notifications": True}
|
||||
db_session.commit()
|
||||
|
||||
# Fetch the updated user to verify changes were persisted
|
||||
updated_user = db_session.query(User).filter_by(id=user_id).first()
|
||||
|
||||
# Assert
|
||||
assert updated_user.first_name == "After"
|
||||
assert updated_user.last_name == "Updated"
|
||||
assert updated_user.phone_number == "9876543210"
|
||||
assert updated_user.preferences == {"theme": "light", "notifications": True}
|
||||
# created_at shouldn't change on update
|
||||
assert updated_user.created_at == original_created_at
|
||||
# updated_at should be newer than created_at
|
||||
assert updated_user.updated_at > original_created_at
|
||||
|
||||
|
||||
def test_delete_user(db_session):
|
||||
"""Test deleting a user."""
|
||||
# Arrange - Create a user
|
||||
user_id = uuid.uuid4()
|
||||
user = User(
|
||||
id=user_id,
|
||||
email="delete@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Delete",
|
||||
last_name="Me",
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Act - Delete the user
|
||||
db_session.delete(user)
|
||||
db_session.commit()
|
||||
|
||||
# Assert
|
||||
deleted_user = db_session.query(User).filter_by(id=user_id).first()
|
||||
assert deleted_user is None
|
||||
|
||||
|
||||
def test_user_unique_email_constraint(db_session):
|
||||
"""Test that users cannot have duplicate emails."""
|
||||
# Arrange - Create a user
|
||||
user1 = User(
|
||||
id=uuid.uuid4(),
|
||||
email="duplicate@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="First",
|
||||
last_name="User",
|
||||
)
|
||||
db_session.add(user1)
|
||||
db_session.commit()
|
||||
|
||||
# Act & Assert - Try to create another user with the same email
|
||||
user2 = User(
|
||||
id=uuid.uuid4(),
|
||||
email="duplicate@example.com", # Same email
|
||||
password_hash="differenthash",
|
||||
first_name="Second",
|
||||
last_name="User",
|
||||
)
|
||||
db_session.add(user2)
|
||||
|
||||
# Should raise IntegrityError due to unique constraint
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
|
||||
# Rollback for cleanup
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_user_required_fields(db_session):
|
||||
"""Test that required fields are enforced."""
|
||||
# Test each required field by creating a user without it
|
||||
|
||||
# Missing email
|
||||
user_no_email = User(
|
||||
id=uuid.uuid4(),
|
||||
# email is missing
|
||||
password_hash="hashedpassword",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
)
|
||||
db_session.add(user_no_email)
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
db_session.rollback()
|
||||
|
||||
# Missing password_hash
|
||||
user_no_password = User(
|
||||
id=uuid.uuid4(),
|
||||
email="nopassword@example.com",
|
||||
# password_hash is missing
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
)
|
||||
db_session.add(user_no_password)
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
|
||||
def test_user_defaults(db_session):
|
||||
"""Test that default values are correctly set."""
|
||||
# Arrange - Create a minimal user with only required fields
|
||||
minimal_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="minimal@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Minimal",
|
||||
last_name="User",
|
||||
)
|
||||
db_session.add(minimal_user)
|
||||
db_session.commit()
|
||||
|
||||
# Act - Retrieve the user
|
||||
created_user = db_session.query(User).filter_by(email="minimal@example.com").first()
|
||||
|
||||
# Assert - Check default values
|
||||
assert created_user.is_active is True # Default should be True
|
||||
assert created_user.is_superuser is False # Default should be False
|
||||
assert created_user.phone_number is None # Optional field
|
||||
assert created_user.preferences is None # Optional field
|
||||
|
||||
|
||||
def test_user_string_representation(db_session):
|
||||
"""Test the string representation of a user."""
|
||||
# Arrange
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="repr@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="String",
|
||||
last_name="Repr",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert str(user) == "<User repr@example.com>"
|
||||
assert repr(user) == "<User repr@example.com>"
|
||||
|
||||
|
||||
def test_user_with_complex_json_preferences(db_session):
|
||||
"""Test storing and retrieving complex JSON preferences."""
|
||||
# Arrange - Create a user with nested JSON preferences
|
||||
complex_preferences = {
|
||||
"theme": {
|
||||
"mode": "dark",
|
||||
"colors": {
|
||||
"primary": "#333",
|
||||
"secondary": "#666"
|
||||
}
|
||||
},
|
||||
"notifications": {
|
||||
"email": True,
|
||||
"sms": False,
|
||||
"push": {
|
||||
"enabled": True,
|
||||
"quiet_hours": [22, 7]
|
||||
}
|
||||
},
|
||||
"tags": ["important", "family", "events"]
|
||||
}
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="complex@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Complex",
|
||||
last_name="JSON",
|
||||
preferences=complex_preferences
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Act - Retrieve the user
|
||||
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
|
||||
|
||||
# Assert - The complex JSON should be preserved
|
||||
assert retrieved_user.preferences == complex_preferences
|
||||
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
|
||||
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
|
||||
assert "important" in retrieved_user.preferences["tags"]
|
||||
0
backend/tests/schemas/__init__.py
Normal file
0
backend/tests/schemas/__init__.py
Normal file
127
backend/tests/schemas/test_user_schemas.py
Normal file
127
backend/tests/schemas/test_user_schemas.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# tests/schemas/test_user_schemas.py
|
||||
import pytest
|
||||
import re
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.users import UserBase, UserCreate
|
||||
|
||||
class TestPhoneNumberValidation:
|
||||
"""Tests for phone number validation in user schemas"""
|
||||
|
||||
def test_valid_swiss_numbers(self):
|
||||
"""Test valid Swiss phone numbers are accepted"""
|
||||
# International format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
|
||||
assert user.phone_number == "+41791234567"
|
||||
|
||||
# Local format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
|
||||
assert user.phone_number == "0791234567"
|
||||
|
||||
# With formatting characters
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
|
||||
def test_valid_italian_numbers(self):
|
||||
"""Test valid Italian phone numbers are accepted"""
|
||||
# International format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
|
||||
assert user.phone_number == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
|
||||
assert user.phone_number == "+39345123456"
|
||||
|
||||
# Local format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
|
||||
assert user.phone_number == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
|
||||
assert user.phone_number == "0345123456789"
|
||||
|
||||
# With formatting characters
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
|
||||
def test_none_phone_number(self):
|
||||
"""Test that None is accepted as a valid value (optional phone number)"""
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
|
||||
assert user.phone_number is None
|
||||
|
||||
def test_invalid_phone_numbers(self):
|
||||
"""Test that invalid phone numbers are rejected"""
|
||||
invalid_numbers = [
|
||||
# Too short
|
||||
"+12",
|
||||
"012",
|
||||
|
||||
# Invalid characters
|
||||
"+41xyz123456",
|
||||
"079abc4567",
|
||||
"123-abc-7890",
|
||||
"+1(800)CALL-NOW",
|
||||
|
||||
# Completely invalid formats
|
||||
"++4412345678", # Double plus
|
||||
"()+41123456", # Misplaced parentheses
|
||||
|
||||
# Empty string
|
||||
"",
|
||||
# Spaces only
|
||||
" ",
|
||||
]
|
||||
|
||||
for number in invalid_numbers:
|
||||
with pytest.raises(ValidationError):
|
||||
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
|
||||
|
||||
def test_phone_validation_in_user_create(self):
|
||||
"""Test that phone validation also works in UserCreate schema"""
|
||||
# Valid phone number
|
||||
user = UserCreate(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123",
|
||||
phone_number="+41791234567"
|
||||
)
|
||||
assert user.phone_number == "+41791234567"
|
||||
|
||||
# Invalid phone number should raise ValidationError
|
||||
with pytest.raises(ValidationError):
|
||||
UserCreate(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123",
|
||||
phone_number="invalid-number"
|
||||
)
|
||||
0
backend/tests/services/__init__.py
Normal file
0
backend/tests/services/__init__.py
Normal file
252
backend/tests/services/test_auth_service.py
Normal file
252
backend/tests/services/test_auth_service.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# tests/services/test_auth_service.py
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, Token
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
|
||||
|
||||
class TestAuthServiceAuthentication:
|
||||
"""Tests for AuthService.authenticate_user method"""
|
||||
|
||||
def test_authenticate_valid_user(self, db_session, mock_user):
|
||||
"""Test authenticating a user with valid credentials"""
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123"
|
||||
mock_user.password_hash = get_password_hash(password)
|
||||
db_session.commit()
|
||||
|
||||
# Authenticate with correct credentials
|
||||
user = AuthService.authenticate_user(
|
||||
db=db_session,
|
||||
email=mock_user.email,
|
||||
password=password
|
||||
)
|
||||
|
||||
assert user is not None
|
||||
assert user.id == mock_user.id
|
||||
assert user.email == mock_user.email
|
||||
|
||||
def test_authenticate_nonexistent_user(self, db_session):
|
||||
"""Test authenticating with an email that doesn't exist"""
|
||||
user = AuthService.authenticate_user(
|
||||
db=db_session,
|
||||
email="nonexistent@example.com",
|
||||
password="password"
|
||||
)
|
||||
|
||||
assert user is None
|
||||
|
||||
def test_authenticate_with_wrong_password(self, db_session, mock_user):
|
||||
"""Test authenticating with the wrong password"""
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123"
|
||||
mock_user.password_hash = get_password_hash(password)
|
||||
db_session.commit()
|
||||
|
||||
# Authenticate with wrong password
|
||||
user = AuthService.authenticate_user(
|
||||
db=db_session,
|
||||
email=mock_user.email,
|
||||
password="WrongPassword123"
|
||||
)
|
||||
|
||||
assert user is None
|
||||
|
||||
def test_authenticate_inactive_user(self, db_session, mock_user):
|
||||
"""Test authenticating an inactive user"""
|
||||
# Set a known password and make user inactive
|
||||
password = "TestPassword123"
|
||||
mock_user.password_hash = get_password_hash(password)
|
||||
mock_user.is_active = False
|
||||
db_session.commit()
|
||||
|
||||
# Should raise AuthenticationError
|
||||
with pytest.raises(AuthenticationError):
|
||||
AuthService.authenticate_user(
|
||||
db=db_session,
|
||||
email=mock_user.email,
|
||||
password=password
|
||||
)
|
||||
|
||||
|
||||
class TestAuthServiceUserCreation:
|
||||
"""Tests for AuthService.create_user method"""
|
||||
|
||||
def test_create_new_user(self, db_session):
|
||||
"""Test creating a new user"""
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="TestPassword123",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="1234567890"
|
||||
)
|
||||
|
||||
user = AuthService.create_user(db=db_session, user_data=user_data)
|
||||
|
||||
# Verify user was created with correct data
|
||||
assert user is not None
|
||||
assert user.email == user_data.email
|
||||
assert user.first_name == user_data.first_name
|
||||
assert user.last_name == user_data.last_name
|
||||
assert user.phone_number == user_data.phone_number
|
||||
|
||||
# Verify password was hashed
|
||||
assert user.password_hash != user_data.password
|
||||
assert verify_password(user_data.password, user.password_hash)
|
||||
|
||||
# Verify default values
|
||||
assert user.is_active is True
|
||||
assert user.is_superuser is False
|
||||
|
||||
def test_create_user_with_existing_email(self, db_session, mock_user):
|
||||
"""Test creating a user with an email that already exists"""
|
||||
user_data = UserCreate(
|
||||
email=mock_user.email, # Use existing email
|
||||
password="TestPassword123",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
# Should raise AuthenticationError
|
||||
with pytest.raises(AuthenticationError):
|
||||
AuthService.create_user(db=db_session, user_data=user_data)
|
||||
|
||||
|
||||
class TestAuthServiceTokens:
|
||||
"""Tests for AuthService token-related methods"""
|
||||
|
||||
def test_create_tokens(self, mock_user):
|
||||
"""Test creating access and refresh tokens for a user"""
|
||||
tokens = AuthService.create_tokens(mock_user)
|
||||
|
||||
# Verify token structure
|
||||
assert isinstance(tokens, Token)
|
||||
assert tokens.access_token is not None
|
||||
assert tokens.refresh_token is not None
|
||||
assert tokens.token_type == "bearer"
|
||||
|
||||
# This is a more in-depth test that would decode the tokens to verify claims
|
||||
# but we'll rely on the auth module tests for token verification
|
||||
|
||||
def test_refresh_tokens(self, db_session, mock_user):
|
||||
"""Test refreshing tokens with a valid refresh token"""
|
||||
# Create initial tokens
|
||||
initial_tokens = AuthService.create_tokens(mock_user)
|
||||
|
||||
# Refresh tokens
|
||||
new_tokens = AuthService.refresh_tokens(
|
||||
db=db_session,
|
||||
refresh_token=initial_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Verify new tokens are different from old ones
|
||||
assert new_tokens.access_token != initial_tokens.access_token
|
||||
assert new_tokens.refresh_token != initial_tokens.refresh_token
|
||||
|
||||
def test_refresh_tokens_with_invalid_token(self, db_session):
|
||||
"""Test refreshing tokens with an invalid token"""
|
||||
# Create an invalid token
|
||||
invalid_token = "invalid.token.string"
|
||||
|
||||
# Should raise TokenInvalidError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
AuthService.refresh_tokens(
|
||||
db=db_session,
|
||||
refresh_token=invalid_token
|
||||
)
|
||||
|
||||
def test_refresh_tokens_with_access_token(self, db_session, mock_user):
|
||||
"""Test refreshing tokens with an access token instead of refresh token"""
|
||||
# Create tokens
|
||||
tokens = AuthService.create_tokens(mock_user)
|
||||
|
||||
# Try to refresh with access token
|
||||
with pytest.raises(TokenInvalidError):
|
||||
AuthService.refresh_tokens(
|
||||
db=db_session,
|
||||
refresh_token=tokens.access_token
|
||||
)
|
||||
|
||||
def test_refresh_tokens_with_nonexistent_user(self, db_session):
|
||||
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
||||
# Create a token for a non-existent user
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
||||
# Mock the token data to return a non-existent user ID
|
||||
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
||||
|
||||
# Should raise TokenInvalidError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
AuthService.refresh_tokens(
|
||||
db=db_session,
|
||||
refresh_token="some.refresh.token"
|
||||
)
|
||||
|
||||
|
||||
class TestAuthServicePasswordChange:
|
||||
"""Tests for AuthService.change_password method"""
|
||||
|
||||
def test_change_password(self, db_session, mock_user):
|
||||
"""Test changing a user's password"""
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
mock_user.password_hash = get_password_hash(current_password)
|
||||
db_session.commit()
|
||||
|
||||
# Change password
|
||||
new_password = "NewPassword456"
|
||||
result = AuthService.change_password(
|
||||
db=db_session,
|
||||
user_id=mock_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
)
|
||||
|
||||
# Verify operation was successful
|
||||
assert result is True
|
||||
|
||||
# Refresh user from DB
|
||||
db_session.refresh(mock_user)
|
||||
|
||||
# Verify old password no longer works
|
||||
assert not verify_password(current_password, mock_user.password_hash)
|
||||
|
||||
# Verify new password works
|
||||
assert verify_password(new_password, mock_user.password_hash)
|
||||
|
||||
def test_change_password_wrong_current_password(self, db_session, mock_user):
|
||||
"""Test changing password with incorrect current password"""
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
mock_user.password_hash = get_password_hash(current_password)
|
||||
db_session.commit()
|
||||
|
||||
# Try to change password with wrong current password
|
||||
wrong_password = "WrongPassword123"
|
||||
with pytest.raises(AuthenticationError):
|
||||
AuthService.change_password(
|
||||
db=db_session,
|
||||
user_id=mock_user.id,
|
||||
current_password=wrong_password,
|
||||
new_password="NewPassword456"
|
||||
)
|
||||
|
||||
# Verify password was not changed
|
||||
assert verify_password(current_password, mock_user.password_hash)
|
||||
|
||||
def test_change_password_nonexistent_user(self, db_session):
|
||||
"""Test changing password for a user that doesn't exist"""
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
AuthService.change_password(
|
||||
db=db_session,
|
||||
user_id=non_existent_id,
|
||||
current_password="CurrentPassword123",
|
||||
new_password="NewPassword456"
|
||||
)
|
||||
Reference in New Issue
Block a user