Add comprehensive test suite and utilities for user functionality
This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
"""Add all initial models
|
||||
|
||||
Revision ID: 38bf9e7e74b3
|
||||
Revises: 7396957cbe80
|
||||
Create Date: 2025-02-28 09:19:33.212278
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38bf9e7e74b3'
|
||||
down_revision: Union[str, None] = '7396957cbe80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('first_name', sa.String(), nullable=False),
|
||||
sa.Column('last_name', sa.String(), nullable=True),
|
||||
sa.Column('phone_number', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/dependencies/__init__.py
Normal file
0
backend/app/api/dependencies/__init__.py
Normal file
137
backend/app/api/dependencies/auth.py
Normal file
137
backend/app/api/dependencies/auth.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
# OAuth2 configuration
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: JWT token from request
|
||||
|
||||
Returns:
|
||||
User: The authenticated user
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
try:
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except TokenExpiredError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
User: The authenticated and active user
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is inactive
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user
|
||||
|
||||
Returns:
|
||||
User: The authenticated superuser
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not a superuser
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_optional_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: Optional[str] = Depends(oauth2_scheme)
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: JWT token from request
|
||||
|
||||
Returns:
|
||||
User or None: The authenticated user or None
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
6
backend/app/api/main.py
Normal file
6
backend/app/api/main.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
0
backend/app/api/routes/__init__.py
Normal file
0
backend/app/api/routes/__init__.py
Normal file
231
backend/app/api/routes/auth.py
Normal file
231
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import (
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Returns:
|
||||
The created user information.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.create_user(db, user_data)
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during registration: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
logger.info(f"User login successful: {user.email}")
|
||||
return tokens
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token)
|
||||
async def login_oauth(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
|
||||
# Format response for OAuth compatibility
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": tokens.token_type
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
return tokens
|
||||
except TokenExpiredError:
|
||||
logger.warning("Token refresh failed: Token expired")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has expired. Please log in again.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
logger.warning("Token refresh failed: Invalid token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/change-password", status_code=status.HTTP_200_OK)
|
||||
async def change_password(
|
||||
current_password: str = Body(..., embed=True),
|
||||
new_password: str = Body(..., embed=True),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
)
|
||||
|
||||
if success:
|
||||
return {"message": "Password changed successfully"}
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Password change failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during password change: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""
|
||||
Get current user information.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
return current_user
|
||||
185
backend/app/core/auth.py
Normal file
185
backend/app/core/auth.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
claims: Optional additional claims to include in the token
|
||||
|
||||
Returns:
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if claims:
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
|
||||
Returns:
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token to decode
|
||||
verify_type: Optional token type to verify (access or refresh)
|
||||
|
||||
Returns:
|
||||
TokenPayload: The decoded token data
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If the token has expired
|
||||
TokenInvalidError: If the token is invalid
|
||||
TokenMissingClaimError: If a required claim is missing
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check required claims before Pydantic validation
|
||||
if not payload.get("sub"):
|
||||
raise TokenMissingClaimError("Token missing 'sub' claim")
|
||||
|
||||
# Verify token type if specified
|
||||
if verify_type and payload.get("type") != verify_type:
|
||||
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
|
||||
|
||||
# Now validate with Pydantic
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
|
||||
def get_token_data(token: str) -> TokenData:
|
||||
"""
|
||||
Extract the user ID and superuser status from a token.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
TokenData with user_id and is_superuser
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
@@ -3,7 +3,7 @@ from typing import Optional, List
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "App"
|
||||
PROJECT_NAME: str = "EventSpace"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
sql_echo_timing: bool = False # Log query execution times
|
||||
slow_query_threshold: float = 0.5 # Log queries taking longer than this
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
|
||||
# JWT configuration
|
||||
SECRET_KEY: str = "your_secret_key_here"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
|
||||
@@ -1,17 +1,57 @@
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Use the database URL from settings
|
||||
engine = create_engine(settings.database_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# Declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Create engine with optimized settings for PostgreSQL
|
||||
def create_production_engine():
|
||||
return create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_timeout=settings.db_pool_timeout,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
isolation_level="READ COMMITTED",
|
||||
echo=settings.sql_echo,
|
||||
echo_pool=settings.sql_echo_pool,
|
||||
)
|
||||
|
||||
# Dependency to get DB session
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
0
backend/app/crud/__init__.py
Normal file
0
backend/app/crud/__init__.py
Normal file
62
backend/app/crud/base.py
Normal file
62
backend/app/crud/base.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.database import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def remove(self, db: Session, *, id: str) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
56
backend/app/crud/user.py
Normal file
56
backend/app/crud/user.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# app/crud/user.py
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
user = CRUDUser(User)
|
||||
@@ -1,9 +1,18 @@
|
||||
import logging
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.config import settings
|
||||
from app.api.main import api_router
|
||||
from app.core.config import settings
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
@@ -34,3 +43,6 @@ async def root():
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
14
backend/app/models/__init__.py
Normal file
14
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Models package initialization.
|
||||
Imports all models to ensure they're registered with SQLAlchemy.
|
||||
"""
|
||||
# First import Base to avoid circular imports
|
||||
from app.core.database import Base
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# Import user model
|
||||
from .user import User
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User',
|
||||
]
|
||||
20
backend/app/models/base.py
Normal file
20
backend/app/models/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin to add UUID primary keys to models"""
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
19
backend/app/models/user.py
Normal file
19
backend/app/models/user.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy import Column, String, JSON, Boolean
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email = Column(String, unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String, nullable=False)
|
||||
first_name = Column(String, nullable=False, default="user")
|
||||
last_name = Column(String, nullable=True)
|
||||
phone_number = Column(String)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||
preferences = Column(JSON)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
149
backend/app/schemas/users.py
Normal file
149
backend/app/schemas/users.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# app/schemas/users.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@classmethod
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
# Simple regex for phone validation
|
||||
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
|
||||
raise ValueError('Invalid phone number format')
|
||||
return v
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
is_superuser: bool = False
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = True
|
||||
@field_validator('phone_number')
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# Return early for empty strings or whitespace-only strings
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Phone number cannot be empty')
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r'[\s\-\(\)]', '', v)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
# After + must have at least 8 digits
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
if cleaned.count('+') > 1:
|
||||
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]):
|
||||
raise ValueError('Phone number can only contain digits after the prefix')
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
user_id: UUID
|
||||
is_superuser: bool = False
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
193
backend/app/services/auth_service.py
Normal file
193
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.users import Token, UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Exception raised for authentication errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_data: User data
|
||||
|
||||
Returns:
|
||||
Created user
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
"""
|
||||
Create access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user: User to create tokens for
|
||||
|
||||
Returns:
|
||||
Token object with access and refresh tokens
|
||||
"""
|
||||
# Generate claims
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
||||
"""
|
||||
Generate new tokens using a refresh token.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
refresh_token: Valid refresh token
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If refresh token has expired
|
||||
TokenInvalidError: If refresh token is invalid
|
||||
"""
|
||||
from app.core.auth import decode_token, get_token_data
|
||||
|
||||
try:
|
||||
# Verify token is a refresh token
|
||||
decode_token(refresh_token, verify_type="refresh")
|
||||
|
||||
# Get user ID from token
|
||||
token_data = get_token_data(refresh_token)
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
# Generate new tokens
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
current_password: Current password
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
True if password was changed successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If current password is incorrect
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
79
backend/app/utils/test_utils.py
Normal file
79
backend/app/utils/test_utils.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker, clear_mappers
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_test_engine():
|
||||
"""Create an SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
)
|
||||
|
||||
return test_engine
|
||||
|
||||
def setup_test_db():
|
||||
"""Create a test database and session factory"""
|
||||
# Create a new engine for this test run
|
||||
test_engine = get_test_engine()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(test_engine)
|
||||
|
||||
# Create session factory
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
return test_engine, TestingSessionLocal
|
||||
|
||||
def teardown_test_db(engine):
|
||||
"""Clean up after tests"""
|
||||
# Drop all tables
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
|
||||
async def get_async_test_engine():
|
||||
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
)
|
||||
return test_engine
|
||||
|
||||
|
||||
async def setup_async_test_db():
|
||||
"""Create an async test database and session factory"""
|
||||
test_engine = await get_async_test_engine()
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
AsyncTestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession
|
||||
)
|
||||
|
||||
return test_engine, AsyncTestingSessionLocal
|
||||
|
||||
|
||||
async def teardown_async_test_db(engine):
|
||||
"""Clean up after async tests"""
|
||||
await engine.dispose()
|
||||
Reference in New Issue
Block a user