Add comprehensive test suite and utilities for user functionality

This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
2025-03-04 19:10:54 +01:00
parent 481b6d618e
commit 162e586e13
40 changed files with 2948 additions and 11 deletions

View File

@@ -0,0 +1,46 @@
"""Add all initial models
Revision ID: 38bf9e7e74b3
Revises: 7396957cbe80
Create Date: 2025-02-28 09:19:33.212278
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '38bf9e7e74b3'
down_revision: Union[str, None] = '7396957cbe80'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table('users',
sa.Column('email', sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False),
sa.Column('first_name', sa.String(), nullable=False),
sa.Column('last_name', sa.String(), nullable=True),
sa.Column('phone_number', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('preferences', sa.JSON(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
# ### end Alembic commands ###

View File

View File

View File

@@ -0,0 +1,137 @@
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.models.user import User
# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> User:
"""
Get the current authenticated user.
Args:
db: Database session
token: JWT token from request
Returns:
User: The authenticated user
Raises:
HTTPException: If authentication fails
"""
try:
# Decode token and get user ID
token_data = get_token_data(token)
# Get user from database
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return user
except TokenExpiredError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"}
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
)
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is active.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated and active user
Raises:
HTTPException: If user is inactive
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
"""
Check if the current user is a superuser.
Args:
current_user: The current authenticated user
Returns:
User: The authenticated superuser
Raises:
HTTPException: If user is not a superuser
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
)
return current_user
def get_optional_current_user(
db: Session = Depends(get_db),
token: Optional[str] = Depends(oauth2_scheme)
) -> Optional[User]:
"""
Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users.
Args:
db: Database session
token: JWT token from request
Returns:
User or None: The authenticated user or None
"""
if not token:
return None
try:
token_data = get_token_data(token)
user = db.query(User).filter(User.id == token_data.user_id).first()
if not user or not user.is_active:
return None
return user
except (TokenExpiredError, TokenInvalidError):
return None

6
backend/app/api/main.py Normal file
View File

@@ -0,0 +1,6 @@
from fastapi import APIRouter
from app.api.routes import auth
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])

View File

View File

@@ -0,0 +1,231 @@
# app/api/routes/auth.py
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Body
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.models.user import User
from app.schemas.users import (
UserCreate,
UserResponse,
Token,
LoginRequest,
RefreshTokenRequest
)
from app.services.auth_service import AuthService, AuthenticationError
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register_user(
user_data: UserCreate,
db: Session = Depends(get_db)
) -> Any:
"""
Register a new user.
Returns:
The created user information.
"""
try:
user = AuthService.create_user(db, user_data)
return user
except AuthenticationError as e:
logger.warning(f"Registration failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/login", response_model=Token)
async def login(
login_data: LoginRequest,
db: Session = Depends(get_db)
) -> Any:
"""
Login with username and password.
Returns:
Access and refresh tokens.
"""
try:
# Attempt to authenticate the user
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
# Explicitly check for None result and raise correct exception
if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
)
# User is authenticated, generate tokens
tokens = AuthService.create_tokens(user)
logger.info(f"User login successful: {user.email}")
return tokens
except HTTPException:
# Re-raise HTTP exceptions without modification
raise
except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error during login: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/login/oauth", response_model=Token)
async def login_oauth(
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
) -> Any:
"""
OAuth2-compatible login endpoint, used by the OpenAPI UI.
Returns:
Access and refresh tokens.
"""
try:
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Generate tokens
tokens = AuthService.create_tokens(user)
# Format response for OAuth compatibility
return {
"access_token": tokens.access_token,
"refresh_token": tokens.refresh_token,
"token_type": tokens.token_type
}
except HTTPException:
raise
except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
except Exception as e:
logger.error(f"Unexpected error during OAuth login: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/refresh", response_model=Token)
async def refresh_token(
refresh_data: RefreshTokenRequest,
db: Session = Depends(get_db)
) -> Any:
"""
Refresh access token using a refresh token.
Returns:
New access and refresh tokens.
"""
try:
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
return tokens
except TokenExpiredError:
logger.warning("Token refresh failed: Token expired")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has expired. Please log in again.",
headers={"WWW-Authenticate": "Bearer"},
)
except TokenInvalidError:
logger.warning("Token refresh failed: Invalid token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
except Exception as e:
logger.error(f"Unexpected error during token refresh: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.post("/change-password", status_code=status.HTTP_200_OK)
async def change_password(
current_password: str = Body(..., embed=True),
new_password: str = Body(..., embed=True),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Change current user's password.
Requires authentication.
"""
try:
success = AuthService.change_password(
db=db,
user_id=current_user.id,
current_password=current_password,
new_password=new_password
)
if success:
return {"message": "Password changed successfully"}
except AuthenticationError as e:
logger.warning(f"Password change failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected error during password change: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: User = Depends(get_current_user)
) -> Any:
"""
Get current user information.
Requires authentication.
"""
return current_user

185
backend/app/core/auth.py Normal file
View File

@@ -0,0 +1,185 @@
import logging
logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate a password hash."""
return pwd_context.hash(password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
) -> str:
"""
Create a JWT access token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
claims: Optional additional claims to include in the token
Returns:
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"jti": str(uuid.uuid4()),
"type": "access"
}
# Add custom claims
if claims:
to_encode.update(claims)
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT refresh token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
Returns:
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"jti": str(uuid.uuid4()),
"type": "refresh"
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"""
Decode and verify a JWT token.
Args:
token: JWT token to decode
verify_type: Optional token type to verify (access or refresh)
Returns:
TokenPayload: The decoded token data
Raises:
TokenExpiredError: If the token has expired
TokenInvalidError: If the token is invalid
TokenMissingClaimError: If a required claim is missing
"""
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check required claims before Pydantic validation
if not payload.get("sub"):
raise TokenMissingClaimError("Token missing 'sub' claim")
# Verify token type if specified
if verify_type and payload.get("type") != verify_type:
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
# Now validate with Pydantic
token_data = TokenPayload(**payload)
return token_data
except JWTError as e:
# Check if the error is due to an expired token
if "expired" in str(e).lower():
raise TokenExpiredError("Token has expired")
raise TokenInvalidError("Invalid authentication token")
except ValidationError:
raise TokenInvalidError("Invalid token payload")
def get_token_data(token: str) -> TokenData:
"""
Extract the user ID and superuser status from a token.
Args:
token: JWT token
Returns:
TokenData with user_id and is_superuser
"""
payload = decode_token(token)
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -3,7 +3,7 @@ from typing import Optional, List
class Settings(BaseSettings):
PROJECT_NAME: str = "App"
PROJECT_NAME: str = "EventSpace"
VERSION: str = "1.0.0"
API_V1_STR: str = "/api/v1"
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
# SQL debugging (disable in production)
sql_echo: bool = False # Log SQL statements
sql_echo_pool: bool = False # Log connection pool events
sql_echo_timing: bool = False # Log query execution times
slow_query_threshold: float = 0.5 # Log queries taking longer than this
@property
def database_url(self) -> str:
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
# JWT configuration
SECRET_KEY: str = "your_secret_key_here"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
# CORS configuration
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]

View File

@@ -1,17 +1,57 @@
# app/core/database.py
import logging
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from app.core.config import settings
# Use the database URL from settings
engine = create_engine(settings.database_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models
Base = declarative_base()
# Create engine with optimized settings for PostgreSQL
def create_production_engine():
return create_engine(
settings.database_url,
# Connection pool settings
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_timeout=settings.db_pool_timeout,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=True,
# Query execution settings
connect_args={
"application_name": "eventspace",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
"options": "-c timezone=UTC",
},
isolation_level="READ COMMITTED",
echo=settings.sql_echo,
echo_pool=settings.sql_echo_pool,
)
# Dependency to get DB session
# Default production engine and session factory
engine = create_production_engine()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# FastAPI dependency
def get_db():
db = SessionLocal()
try:

View File

62
backend/app/crud/base.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.database import Base
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
def get(self, db: Session, id: str) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first()
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def remove(self, db: Session, *, id: str) -> ModelType:
obj = db.query(self.model).get(id)
db.delete(obj)
db.commit()
return obj

56
backend/app/crud/user.py Normal file
View File

@@ -0,0 +1,56 @@
# app/crud/user.py
from typing import Optional, Union, Dict, Any
from sqlalchemy.orm import Session
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.core.auth import get_password_hash
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
return db.query(User).filter(User.email == email).first()
def create(self, db: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
def update(
self,
db: Session,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"])
del update_data["password"]
return super().update(db, db_obj=db_obj, obj_in=update_data)
def is_active(self, user: User) -> bool:
return user.is_active
def is_superuser(self, user: User) -> bool:
return user.is_superuser
# Create a singleton instance for use across the application
user = CRUDUser(User)

View File

@@ -1,9 +1,18 @@
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from app.config import settings
from app.api.main import api_router
from app.core.config import settings
scheduler = AsyncIOScheduler()
logger = logging.getLogger(__name__)
logger.info(f"Starting app!!!")
app = FastAPI(
title=settings.PROJECT_NAME,
version=settings.VERSION,
@@ -34,3 +43,6 @@ async def root():
</body>
</html>
"""
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -0,0 +1,14 @@
"""
Models package initialization.
Imports all models to ensure they're registered with SQLAlchemy.
"""
# First import Base to avoid circular imports
from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
# Import user model
from .user import User
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
'User',
]

View File

@@ -0,0 +1,20 @@
import uuid
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime
from sqlalchemy.dialects.postgresql import UUID
# noinspection PyUnresolvedReferences
from app.core.database import Base
class TimestampMixin:
"""Mixin to add created_at and updated_at timestamps to models"""
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
class UUIDMixin:
"""Mixin to add UUID primary keys to models"""
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

View File

@@ -0,0 +1,19 @@
from sqlalchemy import Column, String, JSON, Boolean
from .base import Base, TimestampMixin, UUIDMixin
class User(Base, UUIDMixin, TimestampMixin):
__tablename__ = 'users'
email = Column(String, unique=True, nullable=False, index=True)
password_hash = Column(String, nullable=False)
first_name = Column(String, nullable=False, default="user")
last_name = Column(String, nullable=True)
phone_number = Column(String)
is_active = Column(Boolean, default=True, nullable=False)
is_superuser = Column(Boolean, default=False, nullable=False)
preferences = Column(JSON)
def __repr__(self):
return f"<User {self.email}>"

View File

View File

@@ -0,0 +1,149 @@
# app/schemas/users.py
import re
from datetime import datetime
from typing import Optional, Dict, Any
from uuid import UUID
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
class UserBase(BaseModel):
email: EmailStr
first_name: str
last_name: Optional[str] = None
phone_number: Optional[str] = None
@field_validator('phone_number')
@classmethod
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Simple regex for phone validation
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
raise ValueError('Invalid phone number format')
return v
class UserCreate(UserBase):
password: str
is_superuser: bool = False
@field_validator('password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = True
@field_validator('phone_number')
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
# Return early for empty strings or whitespace-only strings
if not v or v.strip() == "":
raise ValueError('Phone number cannot be empty')
# Remove all spaces and formatting characters
cleaned = re.sub(r'[\s\-\(\)]', '', v)
# Basic pattern:
# Must start with + or 0
# After + must have at least 8 digits
# After 0 must have at least 8 digits
# Maximum total length of 15 digits (international standard)
# Only allowed characters are + at start and digits
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
if not re.match(pattern, cleaned):
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
# Additional validation to catch specific invalid cases
if cleaned.count('+') > 1:
raise ValueError('Phone number can only contain one + symbol at the start')
# Check for any non-digit characters (except the leading +)
if not all(c.isdigit() for c in cleaned[1:]):
raise ValueError('Phone number can only contain digits after the prefix')
return cleaned
class UserInDB(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class UserResponse(UserBase):
id: UUID
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class Token(BaseModel):
access_token: str
refresh_token: Optional[str] = None
token_type: str = "bearer"
class TokenPayload(BaseModel):
sub: str # User ID
exp: int # Expiration time
iat: Optional[int] = None # Issued at
jti: Optional[str] = None # JWT ID
is_superuser: Optional[bool] = False
first_name: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None # Token type (access/refresh)
class TokenData(BaseModel):
user_id: UUID
is_superuser: bool = False
class PasswordReset(BaseModel):
token: str
new_password: str
@field_validator('new_password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshTokenRequest(BaseModel):
refresh_token: str

View File

View File

@@ -0,0 +1,193 @@
# app/services/auth_service.py
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
TokenExpiredError,
TokenInvalidError
)
from app.models.user import User
from app.schemas.users import Token, UserCreate
logger = logging.getLogger(__name__)
class AuthenticationError(Exception):
"""Exception raised for authentication errors"""
pass
class AuthService:
"""Service for handling authentication operations"""
@staticmethod
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticate a user with email and password.
Args:
db: Database session
email: User email
password: User password
Returns:
User if authenticated, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user:
return None
if not verify_password(password, user.password_hash):
return None
if not user.is_active:
raise AuthenticationError("User account is inactive")
return user
@staticmethod
def create_user(db: Session, user_data: UserCreate) -> User:
"""
Create a new user.
Args:
db: Database session
user_data: User data
Returns:
Created user
"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == user_data.email).first()
if existing_user:
raise AuthenticationError("User with this email already exists")
# Create new user
hashed_password = get_password_hash(user_data.password)
# Create user object from model
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False
)
db.add(user)
db.commit()
db.refresh(user)
return user
@staticmethod
def create_tokens(user: User) -> Token:
"""
Create access and refresh tokens for a user.
Args:
user: User to create tokens for
Returns:
Token object with access and refresh tokens
"""
# Generate claims
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name
}
# Create tokens
access_token = create_access_token(
subject=str(user.id),
claims=claims
)
refresh_token = create_refresh_token(
subject=str(user.id)
)
return Token(
access_token=access_token,
refresh_token=refresh_token
)
@staticmethod
def refresh_tokens(db: Session, refresh_token: str) -> Token:
"""
Generate new tokens using a refresh token.
Args:
db: Database session
refresh_token: Valid refresh token
Returns:
New access and refresh tokens
Raises:
TokenExpiredError: If refresh token has expired
TokenInvalidError: If refresh token is invalid
"""
from app.core.auth import decode_token, get_token_data
try:
# Verify token is a refresh token
decode_token(refresh_token, verify_type="refresh")
# Get user ID from token
token_data = get_token_data(refresh_token)
user_id = token_data.user_id
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")
# Generate new tokens
return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}")
raise
@staticmethod
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
"""
Change a user's password.
Args:
db: Database session
user_id: User ID
current_password: Current password
new_password: New password
Returns:
True if password was changed successfully
Raises:
AuthenticationError: If current password is incorrect
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise AuthenticationError("User not found")
# Verify current password
if not verify_password(current_password, user.password_hash):
raise AuthenticationError("Current password is incorrect")
# Update password
user.password_hash = get_password_hash(new_password)
db.commit()
return True

View File

View File

@@ -0,0 +1,79 @@
import logging
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, clear_mappers
from sqlalchemy.pool import StaticPool
from app.core.database import Base
logger = logging.getLogger(__name__)
def get_test_engine():
"""Create an SQLite in-memory engine specifically for testing"""
test_engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
)
return test_engine
def setup_test_db():
"""Create a test database and session factory"""
# Create a new engine for this test run
test_engine = get_test_engine()
# Create tables
Base.metadata.create_all(test_engine)
# Create session factory
TestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False
)
return test_engine, TestingSessionLocal
def teardown_test_db(engine):
"""Clean up after tests"""
# Drop all tables
Base.metadata.drop_all(engine)
# Dispose of engine
engine.dispose()
async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
)
return test_engine
async def setup_async_test_db():
"""Create an async test database and session factory"""
test_engine = await get_async_test_engine()
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
AsyncTestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession
)
return test_engine, AsyncTestingSessionLocal
async def teardown_async_test_db(engine):
"""Clean up after async tests"""
await engine.dispose()