From 162e586e1321ef7a5fc60e3ef581e653f025edb8 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Tue, 4 Mar 2025 19:10:54 +0100 Subject: [PATCH] Add comprehensive test suite and utilities for user functionality This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency. --- .env.template | 2 +- .../38bf9e7e74b3_add_all_initial_models.py | 46 +++ backend/app/api/__init__.py | 0 backend/app/api/dependencies/__init__.py | 0 backend/app/api/dependencies/auth.py | 137 +++++++ backend/app/api/main.py | 6 + backend/app/api/routes/__init__.py | 0 backend/app/api/routes/auth.py | 231 +++++++++++ backend/app/core/auth.py | 185 +++++++++ backend/app/core/config.py | 15 +- backend/app/core/database.py | 48 ++- backend/app/crud/__init__.py | 0 backend/app/crud/base.py | 62 +++ backend/app/crud/user.py | 56 +++ backend/app/main.py | 14 +- backend/app/models/__init__.py | 14 + backend/app/models/base.py | 20 + backend/app/models/user.py | 19 + backend/app/schemas/__init__.py | 0 backend/app/schemas/users.py | 149 +++++++ backend/app/services/__init__.py | 0 backend/app/services/auth_service.py | 193 +++++++++ backend/app/utils/__init__.py | 0 backend/app/utils/test_utils.py | 79 ++++ backend/pytest.ini | 10 + backend/requirements.txt | 14 +- backend/tests/api/routes/__init__.py | 0 backend/tests/api/routes/test_auth.py | 369 ++++++++++++++++++ backend/tests/api/test_auth_dependencies.py | 211 ++++++++++ backend/tests/conftest.py | 66 ++++ backend/tests/core/__init__.py | 0 backend/tests/core/test_auth.py | 260 ++++++++++++ backend/tests/crud/__init__.py | 0 backend/tests/crud/test_user.py | 125 ++++++ backend/tests/models/__init__.py | 0 backend/tests/models/test_user.py | 249 ++++++++++++ backend/tests/schemas/__init__.py | 0 backend/tests/schemas/test_user_schemas.py | 127 ++++++ backend/tests/services/__init__.py | 0 backend/tests/services/test_auth_service.py | 252 ++++++++++++ 40 files changed, 2948 insertions(+), 11 deletions(-) create mode 100644 backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py create mode 100644 backend/app/api/__init__.py create mode 100644 backend/app/api/dependencies/__init__.py create mode 100644 backend/app/api/dependencies/auth.py create mode 100644 backend/app/api/main.py create mode 100644 backend/app/api/routes/__init__.py create mode 100644 backend/app/api/routes/auth.py create mode 100644 backend/app/core/auth.py create mode 100644 backend/app/crud/__init__.py create mode 100644 backend/app/crud/base.py create mode 100644 backend/app/crud/user.py create mode 100644 backend/app/models/__init__.py create mode 100644 backend/app/models/base.py create mode 100644 backend/app/models/user.py create mode 100644 backend/app/schemas/__init__.py create mode 100644 backend/app/schemas/users.py create mode 100644 backend/app/services/__init__.py create mode 100644 backend/app/services/auth_service.py create mode 100644 backend/app/utils/__init__.py create mode 100644 backend/app/utils/test_utils.py create mode 100644 backend/pytest.ini create mode 100644 backend/tests/api/routes/__init__.py create mode 100644 backend/tests/api/routes/test_auth.py create mode 100644 backend/tests/api/test_auth_dependencies.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/core/__init__.py create mode 100644 backend/tests/core/test_auth.py create mode 100644 backend/tests/crud/__init__.py create mode 100644 backend/tests/crud/test_user.py create mode 100644 backend/tests/models/__init__.py create mode 100644 backend/tests/models/test_user.py create mode 100644 backend/tests/schemas/__init__.py create mode 100644 backend/tests/schemas/test_user_schemas.py create mode 100644 backend/tests/services/__init__.py create mode 100644 backend/tests/services/test_auth_service.py diff --git a/.env.template b/.env.template index b033c5b..a3bce2a 100644 --- a/.env.template +++ b/.env.template @@ -17,7 +17,7 @@ ENVIRONMENT=development DEBUG=true BACKEND_CORS_ORIGINS=["http://localhost:3000"] FIRST_SUPERUSER_EMAIL=admin@example.com -FIRST_SUPERUSER_PASSWORD=admin123 +FIRST_SUPERUSER_PASSWORD=Admin123 # Frontend settings FRONTEND_PORT=3000 diff --git a/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py b/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py new file mode 100644 index 0000000..f563857 --- /dev/null +++ b/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py @@ -0,0 +1,46 @@ +"""Add all initial models + +Revision ID: 38bf9e7e74b3 +Revises: 7396957cbe80 +Create Date: 2025-02-28 09:19:33.212278 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '38bf9e7e74b3' +down_revision: Union[str, None] = '7396957cbe80' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + + op.create_table('users', + sa.Column('email', sa.String(), nullable=False), + sa.Column('password_hash', sa.String(), nullable=False), + sa.Column('first_name', sa.String(), nullable=False), + sa.Column('last_name', sa.String(), nullable=True), + sa.Column('phone_number', sa.String(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('is_superuser', sa.Boolean(), nullable=False), + sa.Column('preferences', sa.JSON(), nullable=True), + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_email'), table_name='users') + op.drop_table('users') + # ### end Alembic commands ### diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/dependencies/__init__.py b/backend/app/api/dependencies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/dependencies/auth.py b/backend/app/api/dependencies/auth.py new file mode 100644 index 0000000..db24f68 --- /dev/null +++ b/backend/app/api/dependencies/auth.py @@ -0,0 +1,137 @@ +from typing import Optional + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session + +from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError +from app.core.database import get_db +from app.models.user import User + +# OAuth2 configuration +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +def get_current_user( + db: Session = Depends(get_db), + token: str = Depends(oauth2_scheme) +) -> User: + """ + Get the current authenticated user. + + Args: + db: Database session + token: JWT token from request + + Returns: + User: The authenticated user + + Raises: + HTTPException: If authentication fails + """ + try: + # Decode token and get user ID + token_data = get_token_data(token) + + # Get user from database + user = db.query(User).filter(User.id == token_data.user_id).first() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Inactive user" + ) + + return user + + except TokenExpiredError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token expired", + headers={"WWW-Authenticate": "Bearer"} + ) + except TokenInvalidError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"} + ) + + +def get_current_active_user( + current_user: User = Depends(get_current_user) +) -> User: + """ + Check if the current user is active. + + Args: + current_user: The current authenticated user + + Returns: + User: The authenticated and active user + + Raises: + HTTPException: If user is inactive + """ + if not current_user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Inactive user" + ) + return current_user + + +def get_current_superuser( + current_user: User = Depends(get_current_user) +) -> User: + """ + Check if the current user is a superuser. + + Args: + current_user: The current authenticated user + + Returns: + User: The authenticated superuser + + Raises: + HTTPException: If user is not a superuser + """ + if not current_user.is_superuser: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions" + ) + return current_user + + +def get_optional_current_user( + db: Session = Depends(get_db), + token: Optional[str] = Depends(oauth2_scheme) +) -> Optional[User]: + """ + Get the current user if authenticated, otherwise return None. + Useful for endpoints that work with both authenticated and unauthenticated users. + + Args: + db: Database session + token: JWT token from request + + Returns: + User or None: The authenticated user or None + """ + if not token: + return None + + try: + token_data = get_token_data(token) + user = db.query(User).filter(User.id == token_data.user_id).first() + if not user or not user.is_active: + return None + return user + except (TokenExpiredError, TokenInvalidError): + return None \ No newline at end of file diff --git a/backend/app/api/main.py b/backend/app/api/main.py new file mode 100644 index 0000000..6b6f08f --- /dev/null +++ b/backend/app/api/main.py @@ -0,0 +1,6 @@ +from fastapi import APIRouter + +from app.api.routes import auth + +api_router = APIRouter() +api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/app/api/routes/__init__.py b/backend/app/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py new file mode 100644 index 0000000..d0bbb79 --- /dev/null +++ b/backend/app/api/routes/auth.py @@ -0,0 +1,231 @@ +# app/api/routes/auth.py +import logging +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, status, Body +from fastapi.security import OAuth2PasswordRequestForm +from sqlalchemy.orm import Session + +from app.api.dependencies.auth import get_current_user +from app.core.auth import TokenExpiredError, TokenInvalidError +from app.core.database import get_db +from app.models.user import User +from app.schemas.users import ( + UserCreate, + UserResponse, + Token, + LoginRequest, + RefreshTokenRequest +) +from app.services.auth_service import AuthService, AuthenticationError + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +async def register_user( + user_data: UserCreate, + db: Session = Depends(get_db) +) -> Any: + """ + Register a new user. + + Returns: + The created user information. + """ + try: + user = AuthService.create_user(db, user_data) + return user + except AuthenticationError as e: + logger.warning(f"Registration failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e) + ) + except Exception as e: + logger.error(f"Unexpected error during registration: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/login", response_model=Token) +async def login( + login_data: LoginRequest, + db: Session = Depends(get_db) +) -> Any: + """ + Login with username and password. + + Returns: + Access and refresh tokens. + """ + try: + # Attempt to authenticate the user + user = AuthService.authenticate_user(db, login_data.email, login_data.password) + + # Explicitly check for None result and raise correct exception + if user is None: + logger.warning(f"Invalid login attempt for: {login_data.email}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # User is authenticated, generate tokens + tokens = AuthService.create_tokens(user) + logger.info(f"User login successful: {user.email}") + return tokens + + except HTTPException: + # Re-raise HTTP exceptions without modification + raise + except AuthenticationError as e: + # Handle specific authentication errors like inactive accounts + logger.warning(f"Authentication failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + # Handle unexpected errors + logger.error(f"Unexpected error during login: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/login/oauth", response_model=Token) +async def login_oauth( + form_data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(get_db) +) -> Any: + """ + OAuth2-compatible login endpoint, used by the OpenAPI UI. + + Returns: + Access and refresh tokens. + """ + try: + user = AuthService.authenticate_user(db, form_data.username, form_data.password) + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Generate tokens + tokens = AuthService.create_tokens(user) + + # Format response for OAuth compatibility + return { + "access_token": tokens.access_token, + "refresh_token": tokens.refresh_token, + "token_type": tokens.token_type + } + except HTTPException: + raise + except AuthenticationError as e: + logger.warning(f"OAuth authentication failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + logger.error(f"Unexpected error during OAuth login: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/refresh", response_model=Token) +async def refresh_token( + refresh_data: RefreshTokenRequest, + db: Session = Depends(get_db) +) -> Any: + """ + Refresh access token using a refresh token. + + Returns: + New access and refresh tokens. + """ + try: + tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token) + return tokens + except TokenExpiredError: + logger.warning("Token refresh failed: Token expired") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token has expired. Please log in again.", + headers={"WWW-Authenticate": "Bearer"}, + ) + except TokenInvalidError: + logger.warning("Token refresh failed: Invalid token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + logger.error(f"Unexpected error during token refresh: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/change-password", status_code=status.HTTP_200_OK) +async def change_password( + current_password: str = Body(..., embed=True), + new_password: str = Body(..., embed=True), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +) -> Any: + """ + Change current user's password. + + Requires authentication. + """ + try: + success = AuthService.change_password( + db=db, + user_id=current_user.id, + current_password=current_password, + new_password=new_password + ) + + if success: + return {"message": "Password changed successfully"} + except AuthenticationError as e: + logger.warning(f"Password change failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + logger.error(f"Unexpected error during password change: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.get("/me", response_model=UserResponse) +async def get_current_user_info( + current_user: User = Depends(get_current_user) +) -> Any: + """ + Get current user information. + + Requires authentication. + """ + return current_user diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py new file mode 100644 index 0000000..21ddaf1 --- /dev/null +++ b/backend/app/core/auth.py @@ -0,0 +1,185 @@ +import logging +logging.getLogger('passlib').setLevel(logging.ERROR) + +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional, Union +import uuid + +from jose import jwt, JWTError +from passlib.context import CryptContext +from pydantic import ValidationError + +from app.core.config import settings +from app.schemas.users import TokenData, TokenPayload + + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# Custom exceptions for auth +class AuthError(Exception): + """Base authentication error""" + pass + +class TokenExpiredError(AuthError): + """Token has expired""" + pass + +class TokenInvalidError(AuthError): + """Token is invalid""" + pass + +class TokenMissingClaimError(AuthError): + """Token is missing a required claim""" + pass + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against a hash.""" + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + """Generate a password hash.""" + return pwd_context.hash(password) + + +def create_access_token( + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None, + claims: Optional[Dict[str, Any]] = None +) -> str: + """ + Create a JWT access token. + + Args: + subject: The subject of the token (usually user_id) + expires_delta: Optional expiration time delta + claims: Optional additional claims to include in the token + + Returns: + Encoded JWT token + """ + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + + # Base token data + to_encode = { + "sub": str(subject), + "exp": expire, + "iat": datetime.now(tz=timezone.utc), + "jti": str(uuid.uuid4()), + "type": "access" + } + + # Add custom claims + if claims: + to_encode.update(claims) + + # Create the JWT + encoded_jwt = jwt.encode( + to_encode, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + return encoded_jwt + + +def create_refresh_token( + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None +) -> str: + """ + Create a JWT refresh token. + + Args: + subject: The subject of the token (usually user_id) + expires_delta: Optional expiration time delta + + Returns: + Encoded JWT refresh token + """ + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + + to_encode = { + "sub": str(subject), + "exp": expire, + "iat": datetime.now(timezone.utc), + "jti": str(uuid.uuid4()), + "type": "refresh" + } + + encoded_jwt = jwt.encode( + to_encode, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + return encoded_jwt + + +def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload: + """ + Decode and verify a JWT token. + + Args: + token: JWT token to decode + verify_type: Optional token type to verify (access or refresh) + + Returns: + TokenPayload: The decoded token data + + Raises: + TokenExpiredError: If the token has expired + TokenInvalidError: If the token is invalid + TokenMissingClaimError: If a required claim is missing + """ + try: + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Check required claims before Pydantic validation + if not payload.get("sub"): + raise TokenMissingClaimError("Token missing 'sub' claim") + + # Verify token type if specified + if verify_type and payload.get("type") != verify_type: + raise TokenInvalidError(f"Invalid token type, expected {verify_type}") + + # Now validate with Pydantic + token_data = TokenPayload(**payload) + return token_data + + except JWTError as e: + # Check if the error is due to an expired token + if "expired" in str(e).lower(): + raise TokenExpiredError("Token has expired") + raise TokenInvalidError("Invalid authentication token") + except ValidationError: + raise TokenInvalidError("Invalid token payload") + + +def get_token_data(token: str) -> TokenData: + """ + Extract the user ID and superuser status from a token. + + Args: + token: JWT token + + Returns: + TokenData with user_id and is_superuser + """ + payload = decode_token(token) + user_id = payload.sub + is_superuser = payload.is_superuser or False + + return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser) \ No newline at end of file diff --git a/backend/app/core/config.py b/backend/app/core/config.py index d674792..d43b074 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -3,7 +3,7 @@ from typing import Optional, List class Settings(BaseSettings): - PROJECT_NAME: str = "App" + PROJECT_NAME: str = "EventSpace" VERSION: str = "1.0.0" API_V1_STR: str = "/api/v1" @@ -14,6 +14,17 @@ class Settings(BaseSettings): POSTGRES_PORT: str = "5432" POSTGRES_DB: str = "app" DATABASE_URL: Optional[str] = None + REFRESH_TOKEN_EXPIRE_DAYS: int = 60 + db_pool_size: int = 20 # Default connection pool size + db_max_overflow: int = 50 # Maximum overflow connections + db_pool_timeout: int = 30 # Seconds to wait for a connection + db_pool_recycle: int = 3600 # Recycle connections after 1 hour + + # SQL debugging (disable in production) + sql_echo: bool = False # Log SQL statements + sql_echo_pool: bool = False # Log connection pool events + sql_echo_timing: bool = False # Log query execution times + slow_query_threshold: float = 0.5 # Log queries taking longer than this @property def database_url(self) -> str: @@ -30,7 +41,7 @@ class Settings(BaseSettings): # JWT configuration SECRET_KEY: str = "your_secret_key_here" ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day # CORS configuration BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"] diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 22d743a..251d012 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -1,17 +1,57 @@ +# app/core/database.py +import logging from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.dialects.postgresql import JSONB, UUID from app.core.config import settings -# Use the database URL from settings -engine = create_engine(settings.database_url) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# Configure logging +logger = logging.getLogger(__name__) +# SQLite compatibility for testing +@compiles(JSONB, 'sqlite') +def compile_jsonb_sqlite(type_, compiler, **kw): + return "TEXT" + +@compiles(UUID, 'sqlite') +def compile_uuid_sqlite(type_, compiler, **kw): + return "TEXT" + +# Declarative base for models Base = declarative_base() +# Create engine with optimized settings for PostgreSQL +def create_production_engine(): + return create_engine( + settings.database_url, + # Connection pool settings + pool_size=settings.db_pool_size, + max_overflow=settings.db_max_overflow, + pool_timeout=settings.db_pool_timeout, + pool_recycle=settings.db_pool_recycle, + pool_pre_ping=True, + # Query execution settings + connect_args={ + "application_name": "eventspace", + "keepalives": 1, + "keepalives_idle": 60, + "keepalives_interval": 10, + "keepalives_count": 5, + "options": "-c timezone=UTC", + }, + isolation_level="READ COMMITTED", + echo=settings.sql_echo, + echo_pool=settings.sql_echo_pool, + ) -# Dependency to get DB session +# Default production engine and session factory +engine = create_production_engine() +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# FastAPI dependency def get_db(): db = SessionLocal() try: diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py new file mode 100644 index 0000000..18c27c2 --- /dev/null +++ b/backend/app/crud/base.py @@ -0,0 +1,62 @@ +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel +from sqlalchemy.orm import Session +from app.core.database import Base + +ModelType = TypeVar("ModelType", bound=Base) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + + +class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default methods to Create, Read, Update, Delete (CRUD). + + Parameters: + model: A SQLAlchemy model class + """ + self.model = model + + def get(self, db: Session, id: str) -> Optional[ModelType]: + return db.query(self.model).filter(self.model.id == id).first() + + def get_multi( + self, db: Session, *, skip: int = 0, limit: int = 100 + ) -> List[ModelType]: + return db.query(self.model).offset(skip).limit(limit).all() + + def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def update( + self, + db: Session, + *, + db_obj: ModelType, + obj_in: Union[UpdateSchemaType, Dict[str, Any]] + ) -> ModelType: + obj_data = jsonable_encoder(db_obj) + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.model_dump(exclude_unset=True) + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def remove(self, db: Session, *, id: str) -> ModelType: + obj = db.query(self.model).get(id) + db.delete(obj) + db.commit() + return obj \ No newline at end of file diff --git a/backend/app/crud/user.py b/backend/app/crud/user.py new file mode 100644 index 0000000..5c8847e --- /dev/null +++ b/backend/app/crud/user.py @@ -0,0 +1,56 @@ +# app/crud/user.py +from typing import Optional, Union, Dict, Any +from sqlalchemy.orm import Session +from app.crud.base import CRUDBase +from app.models.user import User +from app.schemas.users import UserCreate, UserUpdate +from app.core.auth import get_password_hash + + +class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): + def get_by_email(self, db: Session, *, email: str) -> Optional[User]: + return db.query(User).filter(User.email == email).first() + + def create(self, db: Session, *, obj_in: UserCreate) -> User: + db_obj = User( + email=obj_in.email, + password_hash=get_password_hash(obj_in.password), + first_name=obj_in.first_name, + last_name=obj_in.last_name, + phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, + is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False, + preferences={} + ) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def update( + self, + db: Session, + *, + db_obj: User, + obj_in: Union[UserUpdate, Dict[str, Any]] + ) -> User: + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.model_dump(exclude_unset=True) + + # Handle password separately if it exists in update data + if "password" in update_data: + update_data["password_hash"] = get_password_hash(update_data["password"]) + del update_data["password"] + + return super().update(db, db_obj=db_obj, obj_in=update_data) + + def is_active(self, user: User) -> bool: + return user.is_active + + def is_superuser(self, user: User) -> bool: + return user.is_superuser + + +# Create a singleton instance for use across the application +user = CRUDUser(User) \ No newline at end of file diff --git a/backend/app/main.py b/backend/app/main.py index f305a55..c1ad77f 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,9 +1,18 @@ +import logging + +from apscheduler.schedulers.asyncio import AsyncIOScheduler from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse -from app.config import settings +from app.api.main import api_router +from app.core.config import settings +scheduler = AsyncIOScheduler() + +logger = logging.getLogger(__name__) + +logger.info(f"Starting app!!!") app = FastAPI( title=settings.PROJECT_NAME, version=settings.VERSION, @@ -34,3 +43,6 @@ async def root(): """ + + +app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000..b0998d0 --- /dev/null +++ b/backend/app/models/__init__.py @@ -0,0 +1,14 @@ +""" +Models package initialization. +Imports all models to ensure they're registered with SQLAlchemy. +""" +# First import Base to avoid circular imports +from app.core.database import Base +from .base import TimestampMixin, UUIDMixin + +# Import user model +from .user import User +__all__ = [ + 'Base', 'TimestampMixin', 'UUIDMixin', + 'User', +] \ No newline at end of file diff --git a/backend/app/models/base.py b/backend/app/models/base.py new file mode 100644 index 0000000..5a6f55e --- /dev/null +++ b/backend/app/models/base.py @@ -0,0 +1,20 @@ +import uuid +from datetime import datetime, timezone + +from sqlalchemy import Column, DateTime +from sqlalchemy.dialects.postgresql import UUID + +# noinspection PyUnresolvedReferences +from app.core.database import Base + + +class TimestampMixin: + """Mixin to add created_at and updated_at timestamps to models""" + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False) + updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), nullable=False) + + +class UUIDMixin: + """Mixin to add UUID primary keys to models""" + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) diff --git a/backend/app/models/user.py b/backend/app/models/user.py new file mode 100644 index 0000000..783de85 --- /dev/null +++ b/backend/app/models/user.py @@ -0,0 +1,19 @@ +from sqlalchemy import Column, String, JSON, Boolean + +from .base import Base, TimestampMixin, UUIDMixin + + +class User(Base, UUIDMixin, TimestampMixin): + __tablename__ = 'users' + + email = Column(String, unique=True, nullable=False, index=True) + password_hash = Column(String, nullable=False) + first_name = Column(String, nullable=False, default="user") + last_name = Column(String, nullable=True) + phone_number = Column(String) + is_active = Column(Boolean, default=True, nullable=False) + is_superuser = Column(Boolean, default=False, nullable=False) + preferences = Column(JSON) + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/schemas/users.py b/backend/app/schemas/users.py new file mode 100644 index 0000000..c97244a --- /dev/null +++ b/backend/app/schemas/users.py @@ -0,0 +1,149 @@ +# app/schemas/users.py +import re +from datetime import datetime +from typing import Optional, Dict, Any +from uuid import UUID + +from pydantic import BaseModel, EmailStr, field_validator, ConfigDict + + +class UserBase(BaseModel): + email: EmailStr + first_name: str + last_name: Optional[str] = None + phone_number: Optional[str] = None + + @field_validator('phone_number') + @classmethod + def validate_phone_number(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + # Simple regex for phone validation + if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v): + raise ValueError('Invalid phone number format') + return v + + +class UserCreate(UserBase): + password: str + is_superuser: bool = False + + @field_validator('password') + @classmethod + def password_strength(cls, v: str) -> str: + """Basic password strength validation""" + if len(v) < 8: + raise ValueError('Password must be at least 8 characters') + if not any(char.isdigit() for char in v): + raise ValueError('Password must contain at least one digit') + if not any(char.isupper() for char in v): + raise ValueError('Password must contain at least one uppercase letter') + return v + + +class UserUpdate(BaseModel): + first_name: Optional[str] = None + last_name: Optional[str] = None + phone_number: Optional[str] = None + preferences: Optional[Dict[str, Any]] = None + is_active: Optional[bool] = True + @field_validator('phone_number') + def validate_phone_number(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + + # Return early for empty strings or whitespace-only strings + if not v or v.strip() == "": + raise ValueError('Phone number cannot be empty') + + # Remove all spaces and formatting characters + cleaned = re.sub(r'[\s\-\(\)]', '', v) + + # Basic pattern: + # Must start with + or 0 + # After + must have at least 8 digits + # After 0 must have at least 8 digits + # Maximum total length of 15 digits (international standard) + # Only allowed characters are + at start and digits + pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$' + + if not re.match(pattern, cleaned): + raise ValueError('Phone number must start with + or 0 followed by 8-14 digits') + + # Additional validation to catch specific invalid cases + if cleaned.count('+') > 1: + raise ValueError('Phone number can only contain one + symbol at the start') + + # Check for any non-digit characters (except the leading +) + if not all(c.isdigit() for c in cleaned[1:]): + raise ValueError('Phone number can only contain digits after the prefix') + + return cleaned + + +class UserInDB(UserBase): + id: UUID + is_active: bool + is_superuser: bool + created_at: datetime + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class UserResponse(UserBase): + id: UUID + is_active: bool + is_superuser: bool + created_at: datetime + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class Token(BaseModel): + access_token: str + refresh_token: Optional[str] = None + token_type: str = "bearer" + + +class TokenPayload(BaseModel): + sub: str # User ID + exp: int # Expiration time + iat: Optional[int] = None # Issued at + jti: Optional[str] = None # JWT ID + is_superuser: Optional[bool] = False + first_name: Optional[str] = None + email: Optional[str] = None + type: Optional[str] = None # Token type (access/refresh) + + +class TokenData(BaseModel): + user_id: UUID + is_superuser: bool = False + + +class PasswordReset(BaseModel): + token: str + new_password: str + + @field_validator('new_password') + @classmethod + def password_strength(cls, v: str) -> str: + """Basic password strength validation""" + if len(v) < 8: + raise ValueError('Password must be at least 8 characters') + if not any(char.isdigit() for char in v): + raise ValueError('Password must contain at least one digit') + if not any(char.isupper() for char in v): + raise ValueError('Password must contain at least one uppercase letter') + return v + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class RefreshTokenRequest(BaseModel): + refresh_token: str diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..4941671 --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,193 @@ +# app/services/auth_service.py +import logging +from typing import Optional +from uuid import UUID + +from sqlalchemy.orm import Session + +from app.core.auth import ( + verify_password, + get_password_hash, + create_access_token, + create_refresh_token, + TokenExpiredError, + TokenInvalidError +) +from app.models.user import User +from app.schemas.users import Token, UserCreate + +logger = logging.getLogger(__name__) + + +class AuthenticationError(Exception): + """Exception raised for authentication errors""" + pass + + +class AuthService: + """Service for handling authentication operations""" + + @staticmethod + def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: + """ + Authenticate a user with email and password. + + Args: + db: Database session + email: User email + password: User password + + Returns: + User if authenticated, None otherwise + """ + user = db.query(User).filter(User.email == email).first() + + if not user: + return None + + if not verify_password(password, user.password_hash): + return None + + if not user.is_active: + raise AuthenticationError("User account is inactive") + + return user + + @staticmethod + def create_user(db: Session, user_data: UserCreate) -> User: + """ + Create a new user. + + Args: + db: Database session + user_data: User data + + Returns: + Created user + """ + # Check if user already exists + existing_user = db.query(User).filter(User.email == user_data.email).first() + if existing_user: + raise AuthenticationError("User with this email already exists") + + # Create new user + hashed_password = get_password_hash(user_data.password) + + # Create user object from model + user = User( + email=user_data.email, + password_hash=hashed_password, + first_name=user_data.first_name, + last_name=user_data.last_name, + phone_number=user_data.phone_number, + is_active=True, + is_superuser=False + ) + + db.add(user) + db.commit() + db.refresh(user) + + return user + + @staticmethod + def create_tokens(user: User) -> Token: + """ + Create access and refresh tokens for a user. + + Args: + user: User to create tokens for + + Returns: + Token object with access and refresh tokens + """ + # Generate claims + claims = { + "is_superuser": user.is_superuser, + "email": user.email, + "first_name": user.first_name + } + + # Create tokens + access_token = create_access_token( + subject=str(user.id), + claims=claims + ) + + refresh_token = create_refresh_token( + subject=str(user.id) + ) + + return Token( + access_token=access_token, + refresh_token=refresh_token + ) + + @staticmethod + def refresh_tokens(db: Session, refresh_token: str) -> Token: + """ + Generate new tokens using a refresh token. + + Args: + db: Database session + refresh_token: Valid refresh token + + Returns: + New access and refresh tokens + + Raises: + TokenExpiredError: If refresh token has expired + TokenInvalidError: If refresh token is invalid + """ + from app.core.auth import decode_token, get_token_data + + try: + # Verify token is a refresh token + decode_token(refresh_token, verify_type="refresh") + + # Get user ID from token + token_data = get_token_data(refresh_token) + user_id = token_data.user_id + + # Get user from database + user = db.query(User).filter(User.id == user_id).first() + if not user or not user.is_active: + raise TokenInvalidError("Invalid user or inactive account") + + # Generate new tokens + return AuthService.create_tokens(user) + + except (TokenExpiredError, TokenInvalidError) as e: + logger.warning(f"Token refresh failed: {str(e)}") + raise + + @staticmethod + def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool: + """ + Change a user's password. + + Args: + db: Database session + user_id: User ID + current_password: Current password + new_password: New password + + Returns: + True if password was changed successfully + + Raises: + AuthenticationError: If current password is incorrect + """ + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise AuthenticationError("User not found") + + # Verify current password + if not verify_password(current_password, user.password_hash): + raise AuthenticationError("Current password is incorrect") + + # Update password + user.password_hash = get_password_hash(new_password) + db.commit() + + return True diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/utils/test_utils.py b/backend/app/utils/test_utils.py new file mode 100644 index 0000000..26598b6 --- /dev/null +++ b/backend/app/utils/test_utils.py @@ -0,0 +1,79 @@ +import logging +from sqlalchemy import create_engine, event +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker, clear_mappers +from sqlalchemy.pool import StaticPool + +from app.core.database import Base + +logger = logging.getLogger(__name__) + +def get_test_engine(): + """Create an SQLite in-memory engine specifically for testing""" + test_engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, # Use static pool for in-memory testing + echo=False + ) + + return test_engine + +def setup_test_db(): + """Create a test database and session factory""" + # Create a new engine for this test run + test_engine = get_test_engine() + + # Create tables + Base.metadata.create_all(test_engine) + + # Create session factory + TestingSessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=test_engine, + expire_on_commit=False + ) + + return test_engine, TestingSessionLocal + +def teardown_test_db(engine): + """Clean up after tests""" + # Drop all tables + Base.metadata.drop_all(engine) + + # Dispose of engine + engine.dispose() + +async def get_async_test_engine(): + """Create an async SQLite in-memory engine specifically for testing""" + test_engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, # Use static pool for in-memory testing + echo=False + ) + return test_engine + + +async def setup_async_test_db(): + """Create an async test database and session factory""" + test_engine = await get_async_test_engine() + + async with test_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + AsyncTestingSessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=test_engine, + expire_on_commit=False, + class_=AsyncSession + ) + + return test_engine, AsyncTestingSessionLocal + + +async def teardown_async_test_db(engine): + """Clean up after async tests""" + await engine.dispose() diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 0000000..b46b024 --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,10 @@ +[pytest] +env = + IS_TEST=True +testpaths = tests +python_files = test_*.py +addopts = --disable-warnings +markers = + sqlite: marks tests that should run on SQLite (mocked). + postgres: marks tests that require a real PostgreSQL database. +asyncio_default_fixture_loop_scope = function diff --git a/backend/requirements.txt b/backend/requirements.txt index 94bd973..ecc4682 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,13 +4,14 @@ uvicorn>=0.34.0 pydantic>=2.10.6 pydantic-settings>=2.2.1 python-multipart>=0.0.19 +fastapi-utils==0.8.0 # Database sqlalchemy>=2.0.29 alembic>=1.14.1 psycopg2-binary>=2.9.9 asyncpg>=0.29.0 - +aiosqlite==0.21.0 # Security and authentication python-jose>=3.4.0 passlib>=1.7.4 @@ -30,7 +31,7 @@ httpx>=0.27.0 tenacity>=8.2.3 pytz>=2024.1 pillow>=10.3.0 - +apscheduler==3.11.0 # Testing pytest>=8.0.0 pytest-asyncio>=0.23.5 @@ -41,4 +42,11 @@ requests>=2.32.0 black>=24.3.0 isort>=5.13.2 flake8>=7.0.0 -mypy>=1.8.0 \ No newline at end of file +mypy>=1.8.0 + +# Security +python-jose==3.4.0 +bcrypt==4.2.1 +cryptography==44.0.1 +passlib==1.7.4 +freezegun~=1.5.1 \ No newline at end of file diff --git a/backend/tests/api/routes/__init__.py b/backend/tests/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/api/routes/test_auth.py b/backend/tests/api/routes/test_auth.py new file mode 100644 index 0000000..dc3f99f --- /dev/null +++ b/backend/tests/api/routes/test_auth.py @@ -0,0 +1,369 @@ +# tests/api/routes/test_auth.py +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import patch, MagicMock, Mock + +import pytest +from fastapi import FastAPI, Depends +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from app.api.routes.auth import router as auth_router +from app.core.auth import get_password_hash +from app.core.database import get_db +from app.models.user import User +from app.services.auth_service import AuthService, AuthenticationError +from app.core.auth import TokenExpiredError, TokenInvalidError + + +# Mock the get_db dependency +@pytest.fixture +def override_get_db(db_session): + """Override get_db dependency for testing.""" + return db_session + + +@pytest.fixture +def app(override_get_db): + """Create a FastAPI test application with overridden dependencies.""" + app = FastAPI() + app.include_router(auth_router, prefix="/auth", tags=["auth"]) + + # Override the get_db dependency + app.dependency_overrides[get_db] = lambda: override_get_db + + return app + + +@pytest.fixture +def client(app): + """Create a FastAPI test client.""" + return TestClient(app) + + +class TestRegisterUser: + """Tests for the register_user endpoint.""" + + def test_register_user_success(self, client, monkeypatch, db_session): + """Test successful user registration.""" + # Mock the service method with a valid complete User object + mock_user = User( + id=uuid.uuid4(), + email="newuser@example.com", + password_hash="hashed_password", + first_name="New", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + # Use patch for mocking + with patch.object(AuthService, 'create_user', return_value=mock_user): + # Test request + response = client.post( + "/auth/register", + json={ + "email": "newuser@example.com", + "password": "Password123", + "first_name": "New", + "last_name": "User" + } + ) + + # Assertions + assert response.status_code == 201 + data = response.json() + assert data["email"] == "newuser@example.com" + assert data["first_name"] == "New" + assert data["last_name"] == "User" + assert "password" not in data + + def test_register_user_duplicate_email(self, client, db_session): + """Test registration with duplicate email.""" + # Use patch for mocking with a side effect + with patch.object(AuthService, 'create_user', + side_effect=AuthenticationError("User with this email already exists")): + # Test request + response = client.post( + "/auth/register", + json={ + "email": "existing@example.com", + "password": "Password123", + "first_name": "Existing", + "last_name": "User" + } + ) + + # Assertions + assert response.status_code == 409 + assert "already exists" in response.json()["detail"] + + +class TestLogin: + """Tests for the login endpoint.""" + + def test_login_success(self, client, mock_user, db_session): + """Test successful login.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Create mock tokens + mock_tokens = MagicMock( + access_token="mock_access_token", + refresh_token="mock_refresh_token", + token_type="bearer" + ) + + # Use context managers for patching + with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \ + patch.object(AuthService, 'create_tokens', return_value=mock_tokens): + + # Test request + response = client.post( + "/auth/login", + json={ + "email": "user@example.com", + "password": "Password123" + } + ) + + # Assertions + assert response.status_code == 200 + data = response.json() + assert data["access_token"] == "mock_access_token" + assert data["refresh_token"] == "mock_refresh_token" + assert data["token_type"] == "bearer" + + + def test_login_invalid_credentials_debug(self, client, app): + """Improved test for login with invalid credentials.""" + # Print response for debugging + from app.services.auth_service import AuthService + + # Create a complete mock for AuthService + class MockAuthService: + @staticmethod + def authenticate_user(db, email, password): + print(f"Mock called with: {email}, {password}") + return None + + # Replace the entire class with our mock + original_service = AuthService + try: + # Replace with our mock + import sys + sys.modules['app.services.auth_service'].AuthService = MockAuthService + + # Make the request + response = client.post( + "/auth/login", + json={ + "email": "user@example.com", + "password": "WrongPassword" + } + ) + + # Print response details for debugging + print(f"Response status: {response.status_code}") + print(f"Response body: {response.text}") + + # Assertions + assert response.status_code == 401 + assert "Invalid email or password" in response.json()["detail"] + finally: + # Restore original service + sys.modules['app.services.auth_service'].AuthService = original_service + + + def test_login_inactive_user(self, client, db_session): + """Test login with inactive user.""" + # Mock authentication to raise an error + with patch.object(AuthService, 'authenticate_user', + side_effect=AuthenticationError("User account is inactive")): + # Test request + response = client.post( + "/auth/login", + json={ + "email": "inactive@example.com", + "password": "Password123" + } + ) + + # Assertions + assert response.status_code == 401 + assert "inactive" in response.json()["detail"] + + +class TestRefreshToken: + """Tests for the refresh_token endpoint.""" + + def test_refresh_token_success(self, client, db_session): + """Test successful token refresh.""" + # Mock refresh to return tokens + mock_tokens = MagicMock( + access_token="new_access_token", + refresh_token="new_refresh_token", + token_type="bearer" + ) + + with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens): + # Test request + response = client.post( + "/auth/refresh", + json={ + "refresh_token": "valid_refresh_token" + } + ) + + # Assertions + assert response.status_code == 200 + data = response.json() + assert data["access_token"] == "new_access_token" + assert data["refresh_token"] == "new_refresh_token" + assert data["token_type"] == "bearer" + + def test_refresh_token_expired(self, client, db_session): + """Test refresh with expired token.""" + # Mock refresh to raise expired token error + with patch.object(AuthService, 'refresh_tokens', + side_effect=TokenExpiredError("Token expired")): + # Test request + response = client.post( + "/auth/refresh", + json={ + "refresh_token": "expired_refresh_token" + } + ) + + # Assertions + assert response.status_code == 401 + assert "expired" in response.json()["detail"] + + def test_refresh_token_invalid(self, client, db_session): + """Test refresh with invalid token.""" + # Mock refresh to raise invalid token error + with patch.object(AuthService, 'refresh_tokens', + side_effect=TokenInvalidError("Invalid token")): + # Test request + response = client.post( + "/auth/refresh", + json={ + "refresh_token": "invalid_refresh_token" + } + ) + + # Assertions + assert response.status_code == 401 + assert "Invalid" in response.json()["detail"] + + +class TestChangePassword: + """Tests for the change_password endpoint.""" + + def test_change_password_success(self, client, mock_user, db_session, app): + """Test successful password change.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Override get_current_user dependency + from app.api.dependencies.auth import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Mock password change to return success + with patch.object(AuthService, 'change_password', return_value=True): + # Test request + response = client.post( + "/auth/change-password", + json={ + "current_password": "OldPassword123", + "new_password": "NewPassword123" + } + ) + + # Assertions + assert response.status_code == 200 + assert "success" in response.json()["message"].lower() + + # Clean up override + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app): + """Test change password with incorrect current password.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Override get_current_user dependency + from app.api.dependencies.auth import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Mock password change to raise error + with patch.object(AuthService, 'change_password', + side_effect=AuthenticationError("Current password is incorrect")): + # Test request + response = client.post( + "/auth/change-password", + json={ + "current_password": "WrongPassword", + "new_password": "NewPassword123" + } + ) + + # Assertions + assert response.status_code == 400 + assert "incorrect" in response.json()["detail"].lower() + + # Clean up override + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + +class TestGetCurrentUserInfo: + """Tests for the get_current_user_info endpoint.""" + + def test_get_current_user_info(self, client, mock_user, app): + """Test getting current user info.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Override get_current_user dependency + from app.api.dependencies.auth import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Test request + response = client.get("/auth/me") + + # Assertions + assert response.status_code == 200 + data = response.json() + assert data["email"] == mock_user.email + assert data["first_name"] == mock_user.first_name + assert data["last_name"] == mock_user.last_name + assert "password" not in data + + # Clean up override + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_get_current_user_info_unauthorized(self, client): + """Test getting user info without authentication.""" + # Test request without authentication + response = client.get("/auth/me") + + # Assertions + assert response.status_code == 401 \ No newline at end of file diff --git a/backend/tests/api/test_auth_dependencies.py b/backend/tests/api/test_auth_dependencies.py new file mode 100644 index 0000000..1948d7f --- /dev/null +++ b/backend/tests/api/test_auth_dependencies.py @@ -0,0 +1,211 @@ +# tests/api/dependencies/test_auth_dependencies.py +import pytest +from unittest.mock import patch, MagicMock +from fastapi import HTTPException + +from app.api.dependencies.auth import ( + get_current_user, + get_current_active_user, + get_current_superuser, + get_optional_current_user +) +from app.core.auth import TokenExpiredError, TokenInvalidError + + +@pytest.fixture +def mock_token(): + return "mock.jwt.token" + + +class TestGetCurrentUser: + """Tests for get_current_user dependency""" + + def test_get_current_user_success(self, db_session, mock_user, mock_token): + """Test successfully getting the current user""" + # Mock get_token_data to return user_id that matches our mock_user + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.return_value.user_id = mock_user.id + + # Call the dependency + user = get_current_user(db=db_session, token=mock_token) + + # Verify the correct user was returned + assert user.id == mock_user.id + assert user.email == mock_user.email + + def test_get_current_user_nonexistent(self, db_session, mock_token): + """Test when the token contains a user ID that doesn't exist""" + # Mock get_token_data to return a non-existent user ID + # Use a real UUID object instead of a string + import uuid + nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111") + + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.return_value.user_id = nonexistent_id # Using UUID object, not string + + # Should raise HTTPException with 404 status + with pytest.raises(HTTPException) as exc_info: + get_current_user(db=db_session, token=mock_token) + + assert exc_info.value.status_code == 404 + + def test_get_current_user_inactive(self, db_session, mock_user, mock_token): + """Test when the user is inactive""" + # Make the user inactive + mock_user.is_active = False + db_session.commit() + + # Mock get_token_data + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.return_value.user_id = mock_user.id + + # Should raise HTTPException with 403 status + with pytest.raises(HTTPException) as exc_info: + get_current_user(db=db_session, token=mock_token) + + assert exc_info.value.status_code == 403 + + def test_get_current_user_expired_token(self, db_session, mock_token): + """Test with an expired token""" + # Mock get_token_data to raise TokenExpiredError + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.side_effect = TokenExpiredError("Token expired") + + # Should raise HTTPException with 401 status + with pytest.raises(HTTPException) as exc_info: + get_current_user(db=db_session, token=mock_token) + + assert exc_info.value.status_code == 401 + assert "Token expired" in exc_info.value.detail + + def test_get_current_user_invalid_token(self, db_session, mock_token): + """Test with an invalid token""" + # Mock get_token_data to raise TokenInvalidError + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.side_effect = TokenInvalidError("Invalid token") + + # Should raise HTTPException with 401 status + with pytest.raises(HTTPException) as exc_info: + get_current_user(db=db_session, token=mock_token) + + assert exc_info.value.status_code == 401 + assert "Could not validate credentials" in exc_info.value.detail + + +class TestGetCurrentActiveUser: + """Tests for get_current_active_user dependency""" + + def test_get_current_active_user(self, mock_user): + """Test getting an active user""" + # Ensure user is active + mock_user.is_active = True + + # Call the dependency with mocked current_user + user = get_current_active_user(current_user=mock_user) + + # Should return the same user + assert user == mock_user + + def test_get_current_inactive_user(self, mock_user): + """Test getting an inactive user""" + # Make user inactive + mock_user.is_active = False + + # Should raise HTTPException with 403 status + with pytest.raises(HTTPException) as exc_info: + get_current_active_user(current_user=mock_user) + + assert exc_info.value.status_code == 403 + assert "Inactive user" in exc_info.value.detail + + +class TestGetCurrentSuperuser: + """Tests for get_current_superuser dependency""" + + def test_get_current_superuser(self, mock_user): + """Test getting a superuser""" + # Make user a superuser + mock_user.is_superuser = True + + # Call the dependency with mocked current_user + user = get_current_superuser(current_user=mock_user) + + # Should return the same user + assert user == mock_user + + def test_get_current_non_superuser(self, mock_user): + """Test getting a non-superuser""" + # Ensure user is not a superuser + mock_user.is_superuser = False + + # Should raise HTTPException with 403 status + with pytest.raises(HTTPException) as exc_info: + get_current_superuser(current_user=mock_user) + + assert exc_info.value.status_code == 403 + assert "Not enough permissions" in exc_info.value.detail + + +class TestGetOptionalCurrentUser: + """Tests for get_optional_current_user dependency""" + + def test_get_optional_current_user_with_token(self, db_session, mock_user, mock_token): + """Test getting optional user with a valid token""" + # Mock get_token_data + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.return_value.user_id = mock_user.id + + # Call the dependency + user = get_optional_current_user(db=db_session, token=mock_token) + + # Should return the correct user + assert user is not None + assert user.id == mock_user.id + + def test_get_optional_current_user_no_token(self, db_session): + """Test getting optional user with no token""" + # Call the dependency with no token + user = get_optional_current_user(db=db_session, token=None) + + # Should return None + assert user is None + + def test_get_optional_current_user_invalid_token(self, db_session, mock_token): + """Test getting optional user with an invalid token""" + # Mock get_token_data to raise TokenInvalidError + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.side_effect = TokenInvalidError("Invalid token") + + # Call the dependency + user = get_optional_current_user(db=db_session, token=mock_token) + + # Should return None, not raise an exception + assert user is None + + def test_get_optional_current_user_expired_token(self, db_session, mock_token): + """Test getting optional user with an expired token""" + # Mock get_token_data to raise TokenExpiredError + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.side_effect = TokenExpiredError("Token expired") + + # Call the dependency + user = get_optional_current_user(db=db_session, token=mock_token) + + # Should return None, not raise an exception + assert user is None + + def test_get_optional_current_user_inactive(self, db_session, mock_user, mock_token): + """Test getting optional user when user is inactive""" + # Make the user inactive + mock_user.is_active = False + db_session.commit() + + # Mock get_token_data + with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + mock_get_data.return_value.user_id = mock_user.id + + # Call the dependency + user = get_optional_current_user(db=db_session, token=mock_token) + + # Should return None for inactive users + assert user is None \ No newline at end of file diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..1b83a92 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,66 @@ +# tests/conftest.py +import uuid +from datetime import datetime, timezone + +import pytest + +from app.models.user import User +from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db + + +@pytest.fixture(scope="function") +def db_session(): + """ + Creates a fresh SQLite in-memory database for each test function. + + Yields a SQLAlchemy session that can be used for testing. + """ + # Set up the database + test_engine, TestingSessionLocal = setup_test_db() + + # Create a session + with TestingSessionLocal() as session: + yield session + + # Clean up + teardown_test_db(test_engine) + + +@pytest.fixture(scope="function") # Define a fixture +async def async_test_db(): + """Fixture provides new testing engine and session for each test run to improve isolation.""" + + test_engine, AsyncTestingSessionLocal = await setup_async_test_db() + yield test_engine, AsyncTestingSessionLocal + await teardown_async_test_db(test_engine) + +@pytest.fixture +def user_create_data(): + return { + "email": "newtest@example.com", # Changed to avoid conflict with mock_user + "password": "TestPassword123!", + "first_name": "Test", + "last_name": "User", + "phone_number": "+1234567890", + "is_superuser": False, + "preferences": None + } + + +@pytest.fixture +def mock_user(db_session): + """Fixture to create and return a mock User instance.""" + mock_user = User( + id=uuid.uuid4(), + email="mockuser@example.com", + password_hash="mockhashedpassword", + first_name="Mock", + last_name="User", + phone_number="1234567890", + is_active=True, + is_superuser=False, + preferences=None, + ) + db_session.add(mock_user) + db_session.commit() + return mock_user \ No newline at end of file diff --git a/backend/tests/core/__init__.py b/backend/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/core/test_auth.py b/backend/tests/core/test_auth.py new file mode 100644 index 0000000..7929e11 --- /dev/null +++ b/backend/tests/core/test_auth.py @@ -0,0 +1,260 @@ +# tests/core/test_auth.py +import uuid +import pytest +from datetime import datetime, timedelta, timezone +from jose import jwt +from pydantic import ValidationError + +from app.core.auth import ( + verify_password, + get_password_hash, + create_access_token, + create_refresh_token, + decode_token, + get_token_data, + TokenExpiredError, + TokenInvalidError, + TokenMissingClaimError +) +from app.core.config import settings + + +class TestPasswordHandling: + """Tests for password hashing and verification functions""" + + def test_password_hash_different_from_password(self): + """Test that a password hash is different from the original password""" + password = "TestPassword123" + hashed = get_password_hash(password) + assert hashed != password + + def test_verify_correct_password(self): + """Test that verify_password returns True for the correct password""" + password = "TestPassword123" + hashed = get_password_hash(password) + assert verify_password(password, hashed) is True + + def test_verify_incorrect_password(self): + """Test that verify_password returns False for an incorrect password""" + password = "TestPassword123" + wrong_password = "WrongPassword123" + hashed = get_password_hash(password) + assert verify_password(wrong_password, hashed) is False + + def test_same_password_different_hash(self): + """Test that the same password gets a different hash each time""" + password = "TestPassword123" + hash1 = get_password_hash(password) + hash2 = get_password_hash(password) + assert hash1 != hash2 + + +class TestTokenCreation: + """Tests for token creation functions""" + + def test_create_access_token(self): + """Test that an access token is created with the correct claims""" + user_id = str(uuid.uuid4()) + custom_claims = { + "email": "test@example.com", + "first_name": "Test", + "is_superuser": True + } + token = create_access_token(subject=user_id, claims=custom_claims) + + # Decode token to verify claims + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Check standard claims + assert payload["sub"] == user_id + assert "jti" in payload + assert "exp" in payload + assert "iat" in payload + assert payload["type"] == "access" + + # Check custom claims + for key, value in custom_claims.items(): + assert payload[key] == value + + def test_create_refresh_token(self): + """Test that a refresh token is created with the correct claims""" + user_id = str(uuid.uuid4()) + token = create_refresh_token(subject=user_id) + + # Decode token to verify claims + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Check standard claims + assert payload["sub"] == user_id + assert "jti" in payload + assert "exp" in payload + assert "iat" in payload + assert payload["type"] == "refresh" + + def test_token_expiration(self): + """Test that tokens have the correct expiration time""" + user_id = str(uuid.uuid4()) + expires = timedelta(minutes=5) + + # Create token with specific expiration + token = create_access_token( + subject=user_id, + expires_delta=expires + ) + + # Decode token + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Get actual expiration time from token + expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + + # Calculate expected expiration (approximately) + now = datetime.now(timezone.utc) + expected_expiration = now + expires + + # Difference should be small (less than 1 second) + difference = abs((expiration - expected_expiration).total_seconds()) + assert difference < 1 + + +class TestTokenDecoding: + """Tests for token decoding and validation functions""" + + def test_decode_valid_token(self): + """Test that a valid token can be decoded""" + user_id = str(uuid.uuid4()) + token = create_access_token(subject=user_id) + + # Decode token + payload = decode_token(token) + + # Check that the subject matches + assert payload.sub == user_id + + def test_decode_expired_token(self): + """Test that an expired token raises TokenExpiredError""" + user_id = str(uuid.uuid4()) + + # Create a token that's already expired by directly manipulating the payload + now = datetime.now(timezone.utc) + expired_time = now - timedelta(hours=1) # 1 hour in the past + + # Create the expired token manually + payload = { + "sub": user_id, + "exp": int(expired_time.timestamp()), # Set expiration in the past + "iat": int(now.timestamp()), + "jti": str(uuid.uuid4()), + "type": "access" + } + + expired_token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + # Attempting to decode should raise TokenExpiredError + with pytest.raises(TokenExpiredError): + decode_token(expired_token) + + def test_decode_invalid_token(self): + """Test that an invalid token raises TokenInvalidError""" + invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature" + + with pytest.raises(TokenInvalidError): + decode_token(invalid_token) + + def test_decode_token_with_missing_sub(self): + """Test that a token without 'sub' claim raises TokenMissingClaimError""" + # Create a token without a subject + now = datetime.now(timezone.utc) + payload = { + "exp": int((now + timedelta(minutes=30)).timestamp()), + "iat": int(now.timestamp()), + "jti": str(uuid.uuid4()), + "type": "access" + # No 'sub' claim + } + + token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + with pytest.raises(TokenMissingClaimError): + decode_token(token) + + def test_decode_token_with_wrong_type(self): + """Test that verifying a token with wrong type raises TokenInvalidError""" + user_id = str(uuid.uuid4()) + token = create_access_token(subject=user_id) + + # Try to verify it as a refresh token + with pytest.raises(TokenInvalidError): + decode_token(token, verify_type="refresh") + + def test_decode_with_invalid_payload(self): + """Test that a token with invalid payload structure raises TokenInvalidError""" + # Create a token with an invalid payload structure - missing 'sub' which is required + # but including 'exp' to avoid the expiration check + now = datetime.now(timezone.utc) + payload = { + # Missing "sub" field which is required + "exp": int((now + timedelta(minutes=30)).timestamp()), + "iat": int(now.timestamp()), + "jti": str(uuid.uuid4()), + "invalid_field": "test" + } + + token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + # Should raise TokenMissingClaimError due to missing 'sub' + with pytest.raises(TokenMissingClaimError): + decode_token(token) + + # Create another token with invalid type for required field + payload = { + "sub": 123, # sub should be a string, not an integer + "exp": int((now + timedelta(minutes=30)).timestamp()), + } + + token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + # Should raise TokenInvalidError due to ValidationError + with pytest.raises(TokenInvalidError): + decode_token(token) + + def test_get_token_data(self): + """Test extracting TokenData from a token""" + user_id = uuid.uuid4() + token = create_access_token( + subject=str(user_id), + claims={"is_superuser": True} + ) + + token_data = get_token_data(token) + + assert token_data.user_id == user_id + assert token_data.is_superuser is True \ No newline at end of file diff --git a/backend/tests/crud/__init__.py b/backend/tests/crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/crud/test_user.py b/backend/tests/crud/test_user.py new file mode 100644 index 0000000..26f4005 --- /dev/null +++ b/backend/tests/crud/test_user.py @@ -0,0 +1,125 @@ +import pytest + +from app.crud.user import user as user_crud +from app.models.user import User +from app.schemas.users import UserCreate, UserUpdate + + +def test_create_user(db_session, user_create_data): + user_in = UserCreate(**user_create_data) + user_obj = user_crud.create(db_session, obj_in=user_in) + + assert user_obj.email == user_create_data["email"] + assert user_obj.first_name == user_create_data["first_name"] + assert user_obj.last_name == user_create_data["last_name"] + assert user_obj.phone_number == user_create_data["phone_number"] + assert user_obj.is_superuser == user_create_data["is_superuser"] + assert user_obj.password_hash is not None + assert user_obj.id is not None + + +def test_get_user(db_session, mock_user): + # Using mock_user fixture instead of creating new user + stored_user = user_crud.get(db_session, id=mock_user.id) + assert stored_user + assert stored_user.id == mock_user.id + assert stored_user.email == mock_user.email + + +def test_get_user_by_email(db_session, mock_user): + stored_user = user_crud.get_by_email(db_session, email=mock_user.email) + assert stored_user + assert stored_user.id == mock_user.id + assert stored_user.email == mock_user.email + + +def test_update_user(db_session, mock_user): + update_data = UserUpdate( + first_name="Updated", + last_name="Name", + phone_number="+9876543210" + ) + + updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data) + + assert updated_user.first_name == "Updated" + assert updated_user.last_name == "Name" + assert updated_user.phone_number == "+9876543210" + assert updated_user.email == mock_user.email + + +def test_delete_user(db_session, mock_user): + user_crud.remove(db_session, id=mock_user.id) + deleted_user = user_crud.get(db_session, id=mock_user.id) + assert deleted_user is None + + +def test_get_multi_users(db_session, mock_user, user_create_data): + # Create additional users (mock_user is already in db) + users_data = [ + {**user_create_data, "email": f"test{i}@example.com"} + for i in range(2) # Creating 2 more users + mock_user = 3 total + ] + + for user_data in users_data: + user_in = UserCreate(**user_data) + user_crud.create(db_session, obj_in=user_in) + + users = user_crud.get_multi(db_session, skip=0, limit=10) + assert len(users) == 3 + assert all(isinstance(user, User) for user in users) + + +def test_is_active(db_session, mock_user): + assert user_crud.is_active(mock_user) is True + + # Test deactivating user + update_data = UserUpdate(is_active=False) + deactivated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data) + assert user_crud.is_active(deactivated_user) is False + + +def test_is_superuser(db_session, mock_user, user_create_data): + # mock_user is regular user + assert user_crud.is_superuser(mock_user) is False + + # Create superuser + super_user_data = {**user_create_data, "email": "super@example.com", "is_superuser": True} + super_user_in = UserCreate(**super_user_data) + super_user = user_crud.create(db_session, obj_in=super_user_in) + assert user_crud.is_superuser(super_user) is True + + +# Additional test cases +def test_create_duplicate_email(db_session, mock_user): + user_data = UserCreate( + email=mock_user.email, # Try to create user with existing email + password="TestPassword123!", + first_name="Test", + last_name="User" + ) + with pytest.raises(Exception): # Should raise an integrity error + user_crud.create(db_session, obj_in=user_data) + + +def test_update_user_preferences(db_session, mock_user): + preferences = {"theme": "dark", "notifications": True} + update_data = UserUpdate(preferences=preferences) + + updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data) + assert updated_user.preferences == preferences + + +def test_get_multi_users_pagination(db_session, user_create_data): + # Create 5 users + for i in range(5): + user_in = UserCreate(**{**user_create_data, "email": f"test{i}@example.com"}) + user_crud.create(db_session, obj_in=user_in) + + # Test pagination + first_page = user_crud.get_multi(db_session, skip=0, limit=2) + second_page = user_crud.get_multi(db_session, skip=2, limit=2) + + assert len(first_page) == 2 + assert len(second_page) == 2 + assert first_page[0].id != second_page[0].id diff --git a/backend/tests/models/__init__.py b/backend/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/models/test_user.py b/backend/tests/models/test_user.py new file mode 100644 index 0000000..faaab0b --- /dev/null +++ b/backend/tests/models/test_user.py @@ -0,0 +1,249 @@ +# tests/models/test_user.py +import uuid +import pytest +from datetime import datetime +from sqlalchemy.exc import IntegrityError +from app.models.user import User + + +def test_create_user(db_session): + """Test creating a basic user.""" + # Arrange + user_id = uuid.uuid4() + new_user = User( + id=user_id, + email="test@example.com", + password_hash="hashedpassword", + first_name="Test", + last_name="User", + phone_number="1234567890", + is_active=True, + is_superuser=False, + preferences={"theme": "dark"}, + ) + db_session.add(new_user) + + # Act + db_session.commit() + created_user = db_session.query(User).filter_by(email="test@example.com").first() + + # Assert + assert created_user is not None + assert created_user.email == "test@example.com" + assert created_user.first_name == "Test" + assert created_user.last_name == "User" + assert created_user.phone_number == "1234567890" + assert created_user.is_active is True + assert created_user.is_superuser is False + assert created_user.preferences == {"theme": "dark"} + # UUID should be preserved + assert created_user.id == user_id + # Timestamps should be set + assert isinstance(created_user.created_at, datetime) + assert isinstance(created_user.updated_at, datetime) + + +def test_update_user(db_session): + """Test updating an existing user.""" + # Arrange - Create a user + user_id = uuid.uuid4() + user = User( + id=user_id, + email="update@example.com", + password_hash="hashedpassword", + first_name="Before", + last_name="Update", + ) + db_session.add(user) + db_session.commit() + + # Record the original creation timestamp + original_created_at = user.created_at + + # Act - Update the user + user.first_name = "After" + user.last_name = "Updated" + user.phone_number = "9876543210" + user.preferences = {"theme": "light", "notifications": True} + db_session.commit() + + # Fetch the updated user to verify changes were persisted + updated_user = db_session.query(User).filter_by(id=user_id).first() + + # Assert + assert updated_user.first_name == "After" + assert updated_user.last_name == "Updated" + assert updated_user.phone_number == "9876543210" + assert updated_user.preferences == {"theme": "light", "notifications": True} + # created_at shouldn't change on update + assert updated_user.created_at == original_created_at + # updated_at should be newer than created_at + assert updated_user.updated_at > original_created_at + + +def test_delete_user(db_session): + """Test deleting a user.""" + # Arrange - Create a user + user_id = uuid.uuid4() + user = User( + id=user_id, + email="delete@example.com", + password_hash="hashedpassword", + first_name="Delete", + last_name="Me", + ) + db_session.add(user) + db_session.commit() + + # Act - Delete the user + db_session.delete(user) + db_session.commit() + + # Assert + deleted_user = db_session.query(User).filter_by(id=user_id).first() + assert deleted_user is None + + +def test_user_unique_email_constraint(db_session): + """Test that users cannot have duplicate emails.""" + # Arrange - Create a user + user1 = User( + id=uuid.uuid4(), + email="duplicate@example.com", + password_hash="hashedpassword", + first_name="First", + last_name="User", + ) + db_session.add(user1) + db_session.commit() + + # Act & Assert - Try to create another user with the same email + user2 = User( + id=uuid.uuid4(), + email="duplicate@example.com", # Same email + password_hash="differenthash", + first_name="Second", + last_name="User", + ) + db_session.add(user2) + + # Should raise IntegrityError due to unique constraint + with pytest.raises(IntegrityError): + db_session.commit() + + # Rollback for cleanup + db_session.rollback() + + +def test_user_required_fields(db_session): + """Test that required fields are enforced.""" + # Test each required field by creating a user without it + + # Missing email + user_no_email = User( + id=uuid.uuid4(), + # email is missing + password_hash="hashedpassword", + first_name="Test", + last_name="User", + ) + db_session.add(user_no_email) + with pytest.raises(IntegrityError): + db_session.commit() + db_session.rollback() + + # Missing password_hash + user_no_password = User( + id=uuid.uuid4(), + email="nopassword@example.com", + # password_hash is missing + first_name="Test", + last_name="User", + ) + db_session.add(user_no_password) + with pytest.raises(IntegrityError): + db_session.commit() + db_session.rollback() + + + +def test_user_defaults(db_session): + """Test that default values are correctly set.""" + # Arrange - Create a minimal user with only required fields + minimal_user = User( + id=uuid.uuid4(), + email="minimal@example.com", + password_hash="hashedpassword", + first_name="Minimal", + last_name="User", + ) + db_session.add(minimal_user) + db_session.commit() + + # Act - Retrieve the user + created_user = db_session.query(User).filter_by(email="minimal@example.com").first() + + # Assert - Check default values + assert created_user.is_active is True # Default should be True + assert created_user.is_superuser is False # Default should be False + assert created_user.phone_number is None # Optional field + assert created_user.preferences is None # Optional field + + +def test_user_string_representation(db_session): + """Test the string representation of a user.""" + # Arrange + user = User( + id=uuid.uuid4(), + email="repr@example.com", + password_hash="hashedpassword", + first_name="String", + last_name="Repr", + ) + + # Act & Assert + assert str(user) == "" + assert repr(user) == "" + + +def test_user_with_complex_json_preferences(db_session): + """Test storing and retrieving complex JSON preferences.""" + # Arrange - Create a user with nested JSON preferences + complex_preferences = { + "theme": { + "mode": "dark", + "colors": { + "primary": "#333", + "secondary": "#666" + } + }, + "notifications": { + "email": True, + "sms": False, + "push": { + "enabled": True, + "quiet_hours": [22, 7] + } + }, + "tags": ["important", "family", "events"] + } + + user = User( + id=uuid.uuid4(), + email="complex@example.com", + password_hash="hashedpassword", + first_name="Complex", + last_name="JSON", + preferences=complex_preferences + ) + db_session.add(user) + db_session.commit() + + # Act - Retrieve the user + retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first() + + # Assert - The complex JSON should be preserved + assert retrieved_user.preferences == complex_preferences + assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333" + assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7] + assert "important" in retrieved_user.preferences["tags"] \ No newline at end of file diff --git a/backend/tests/schemas/__init__.py b/backend/tests/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/schemas/test_user_schemas.py b/backend/tests/schemas/test_user_schemas.py new file mode 100644 index 0000000..67bf7b5 --- /dev/null +++ b/backend/tests/schemas/test_user_schemas.py @@ -0,0 +1,127 @@ +# tests/schemas/test_user_schemas.py +import pytest +import re +from pydantic import ValidationError + +from app.schemas.users import UserBase, UserCreate + +class TestPhoneNumberValidation: + """Tests for phone number validation in user schemas""" + + def test_valid_swiss_numbers(self): + """Test valid Swiss phone numbers are accepted""" + # International format + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567") + assert user.phone_number == "+41791234567" + + # Local format + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567") + assert user.phone_number == "0791234567" + + # With formatting characters + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" + + def test_valid_italian_numbers(self): + """Test valid Italian phone numbers are accepted""" + # International format + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567") + assert user.phone_number == "+393451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456") + assert user.phone_number == "+39345123456" + + # Local format + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567") + assert user.phone_number == "03451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789") + assert user.phone_number == "0345123456789" + + # With formatting characters + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" + + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567") + assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" + + def test_none_phone_number(self): + """Test that None is accepted as a valid value (optional phone number)""" + user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None) + assert user.phone_number is None + + def test_invalid_phone_numbers(self): + """Test that invalid phone numbers are rejected""" + invalid_numbers = [ + # Too short + "+12", + "012", + + # Invalid characters + "+41xyz123456", + "079abc4567", + "123-abc-7890", + "+1(800)CALL-NOW", + + # Completely invalid formats + "++4412345678", # Double plus + "()+41123456", # Misplaced parentheses + + # Empty string + "", + # Spaces only + " ", + ] + + for number in invalid_numbers: + with pytest.raises(ValidationError): + UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number) + + def test_phone_validation_in_user_create(self): + """Test that phone validation also works in UserCreate schema""" + # Valid phone number + user = UserCreate( + email="test@example.com", + first_name="Test", + last_name="User", + password="Password123", + phone_number="+41791234567" + ) + assert user.phone_number == "+41791234567" + + # Invalid phone number should raise ValidationError + with pytest.raises(ValidationError): + UserCreate( + email="test@example.com", + first_name="Test", + last_name="User", + password="Password123", + phone_number="invalid-number" + ) \ No newline at end of file diff --git a/backend/tests/services/__init__.py b/backend/tests/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/services/test_auth_service.py b/backend/tests/services/test_auth_service.py new file mode 100644 index 0000000..b043a3c --- /dev/null +++ b/backend/tests/services/test_auth_service.py @@ -0,0 +1,252 @@ +# tests/services/test_auth_service.py +import uuid +import pytest +from unittest.mock import patch + +from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError +from app.models.user import User +from app.schemas.users import UserCreate, Token +from app.services.auth_service import AuthService, AuthenticationError + + +class TestAuthServiceAuthentication: + """Tests for AuthService.authenticate_user method""" + + def test_authenticate_valid_user(self, db_session, mock_user): + """Test authenticating a user with valid credentials""" + # Set a known password for the mock user + password = "TestPassword123" + mock_user.password_hash = get_password_hash(password) + db_session.commit() + + # Authenticate with correct credentials + user = AuthService.authenticate_user( + db=db_session, + email=mock_user.email, + password=password + ) + + assert user is not None + assert user.id == mock_user.id + assert user.email == mock_user.email + + def test_authenticate_nonexistent_user(self, db_session): + """Test authenticating with an email that doesn't exist""" + user = AuthService.authenticate_user( + db=db_session, + email="nonexistent@example.com", + password="password" + ) + + assert user is None + + def test_authenticate_with_wrong_password(self, db_session, mock_user): + """Test authenticating with the wrong password""" + # Set a known password for the mock user + password = "TestPassword123" + mock_user.password_hash = get_password_hash(password) + db_session.commit() + + # Authenticate with wrong password + user = AuthService.authenticate_user( + db=db_session, + email=mock_user.email, + password="WrongPassword123" + ) + + assert user is None + + def test_authenticate_inactive_user(self, db_session, mock_user): + """Test authenticating an inactive user""" + # Set a known password and make user inactive + password = "TestPassword123" + mock_user.password_hash = get_password_hash(password) + mock_user.is_active = False + db_session.commit() + + # Should raise AuthenticationError + with pytest.raises(AuthenticationError): + AuthService.authenticate_user( + db=db_session, + email=mock_user.email, + password=password + ) + + +class TestAuthServiceUserCreation: + """Tests for AuthService.create_user method""" + + def test_create_new_user(self, db_session): + """Test creating a new user""" + user_data = UserCreate( + email="newuser@example.com", + password="TestPassword123", + first_name="New", + last_name="User", + phone_number="1234567890" + ) + + user = AuthService.create_user(db=db_session, user_data=user_data) + + # Verify user was created with correct data + assert user is not None + assert user.email == user_data.email + assert user.first_name == user_data.first_name + assert user.last_name == user_data.last_name + assert user.phone_number == user_data.phone_number + + # Verify password was hashed + assert user.password_hash != user_data.password + assert verify_password(user_data.password, user.password_hash) + + # Verify default values + assert user.is_active is True + assert user.is_superuser is False + + def test_create_user_with_existing_email(self, db_session, mock_user): + """Test creating a user with an email that already exists""" + user_data = UserCreate( + email=mock_user.email, # Use existing email + password="TestPassword123", + first_name="Duplicate", + last_name="User" + ) + + # Should raise AuthenticationError + with pytest.raises(AuthenticationError): + AuthService.create_user(db=db_session, user_data=user_data) + + +class TestAuthServiceTokens: + """Tests for AuthService token-related methods""" + + def test_create_tokens(self, mock_user): + """Test creating access and refresh tokens for a user""" + tokens = AuthService.create_tokens(mock_user) + + # Verify token structure + assert isinstance(tokens, Token) + assert tokens.access_token is not None + assert tokens.refresh_token is not None + assert tokens.token_type == "bearer" + + # This is a more in-depth test that would decode the tokens to verify claims + # but we'll rely on the auth module tests for token verification + + def test_refresh_tokens(self, db_session, mock_user): + """Test refreshing tokens with a valid refresh token""" + # Create initial tokens + initial_tokens = AuthService.create_tokens(mock_user) + + # Refresh tokens + new_tokens = AuthService.refresh_tokens( + db=db_session, + refresh_token=initial_tokens.refresh_token + ) + + # Verify new tokens are different from old ones + assert new_tokens.access_token != initial_tokens.access_token + assert new_tokens.refresh_token != initial_tokens.refresh_token + + def test_refresh_tokens_with_invalid_token(self, db_session): + """Test refreshing tokens with an invalid token""" + # Create an invalid token + invalid_token = "invalid.token.string" + + # Should raise TokenInvalidError + with pytest.raises(TokenInvalidError): + AuthService.refresh_tokens( + db=db_session, + refresh_token=invalid_token + ) + + def test_refresh_tokens_with_access_token(self, db_session, mock_user): + """Test refreshing tokens with an access token instead of refresh token""" + # Create tokens + tokens = AuthService.create_tokens(mock_user) + + # Try to refresh with access token + with pytest.raises(TokenInvalidError): + AuthService.refresh_tokens( + db=db_session, + refresh_token=tokens.access_token + ) + + def test_refresh_tokens_with_nonexistent_user(self, db_session): + """Test refreshing tokens for a user that doesn't exist in the database""" + # Create a token for a non-existent user + non_existent_id = str(uuid.uuid4()) + with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data: + # Mock the token data to return a non-existent user ID + mock_get_data.return_value.user_id = uuid.UUID(non_existent_id) + + # Should raise TokenInvalidError + with pytest.raises(TokenInvalidError): + AuthService.refresh_tokens( + db=db_session, + refresh_token="some.refresh.token" + ) + + +class TestAuthServicePasswordChange: + """Tests for AuthService.change_password method""" + + def test_change_password(self, db_session, mock_user): + """Test changing a user's password""" + # Set a known password for the mock user + current_password = "CurrentPassword123" + mock_user.password_hash = get_password_hash(current_password) + db_session.commit() + + # Change password + new_password = "NewPassword456" + result = AuthService.change_password( + db=db_session, + user_id=mock_user.id, + current_password=current_password, + new_password=new_password + ) + + # Verify operation was successful + assert result is True + + # Refresh user from DB + db_session.refresh(mock_user) + + # Verify old password no longer works + assert not verify_password(current_password, mock_user.password_hash) + + # Verify new password works + assert verify_password(new_password, mock_user.password_hash) + + def test_change_password_wrong_current_password(self, db_session, mock_user): + """Test changing password with incorrect current password""" + # Set a known password for the mock user + current_password = "CurrentPassword123" + mock_user.password_hash = get_password_hash(current_password) + db_session.commit() + + # Try to change password with wrong current password + wrong_password = "WrongPassword123" + with pytest.raises(AuthenticationError): + AuthService.change_password( + db=db_session, + user_id=mock_user.id, + current_password=wrong_password, + new_password="NewPassword456" + ) + + # Verify password was not changed + assert verify_password(current_password, mock_user.password_hash) + + def test_change_password_nonexistent_user(self, db_session): + """Test changing password for a user that doesn't exist""" + non_existent_id = uuid.uuid4() + + with pytest.raises(AuthenticationError): + AuthService.change_password( + db=db_session, + user_id=non_existent_id, + current_password="CurrentPassword123", + new_password="NewPassword456" + ) \ No newline at end of file