Add comprehensive test suite and utilities for user functionality

This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
2025-03-04 19:10:54 +01:00
parent 481b6d618e
commit 162e586e13
40 changed files with 2948 additions and 11 deletions

185
backend/app/core/auth.py Normal file
View File

@@ -0,0 +1,185 @@
import logging
logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate a password hash."""
return pwd_context.hash(password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
) -> str:
"""
Create a JWT access token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
claims: Optional additional claims to include in the token
Returns:
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"jti": str(uuid.uuid4()),
"type": "access"
}
# Add custom claims
if claims:
to_encode.update(claims)
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT refresh token.
Args:
subject: The subject of the token (usually user_id)
expires_delta: Optional expiration time delta
Returns:
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"jti": str(uuid.uuid4()),
"type": "refresh"
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"""
Decode and verify a JWT token.
Args:
token: JWT token to decode
verify_type: Optional token type to verify (access or refresh)
Returns:
TokenPayload: The decoded token data
Raises:
TokenExpiredError: If the token has expired
TokenInvalidError: If the token is invalid
TokenMissingClaimError: If a required claim is missing
"""
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check required claims before Pydantic validation
if not payload.get("sub"):
raise TokenMissingClaimError("Token missing 'sub' claim")
# Verify token type if specified
if verify_type and payload.get("type") != verify_type:
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
# Now validate with Pydantic
token_data = TokenPayload(**payload)
return token_data
except JWTError as e:
# Check if the error is due to an expired token
if "expired" in str(e).lower():
raise TokenExpiredError("Token has expired")
raise TokenInvalidError("Invalid authentication token")
except ValidationError:
raise TokenInvalidError("Invalid token payload")
def get_token_data(token: str) -> TokenData:
"""
Extract the user ID and superuser status from a token.
Args:
token: JWT token
Returns:
TokenData with user_id and is_superuser
"""
payload = decode_token(token)
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -3,7 +3,7 @@ from typing import Optional, List
class Settings(BaseSettings):
PROJECT_NAME: str = "App"
PROJECT_NAME: str = "EventSpace"
VERSION: str = "1.0.0"
API_V1_STR: str = "/api/v1"
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
# SQL debugging (disable in production)
sql_echo: bool = False # Log SQL statements
sql_echo_pool: bool = False # Log connection pool events
sql_echo_timing: bool = False # Log query execution times
slow_query_threshold: float = 0.5 # Log queries taking longer than this
@property
def database_url(self) -> str:
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
# JWT configuration
SECRET_KEY: str = "your_secret_key_here"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
# CORS configuration
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]

View File

@@ -1,17 +1,57 @@
# app/core/database.py
import logging
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from app.core.config import settings
# Use the database URL from settings
engine = create_engine(settings.database_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models
Base = declarative_base()
# Create engine with optimized settings for PostgreSQL
def create_production_engine():
return create_engine(
settings.database_url,
# Connection pool settings
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_timeout=settings.db_pool_timeout,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=True,
# Query execution settings
connect_args={
"application_name": "eventspace",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
"options": "-c timezone=UTC",
},
isolation_level="READ COMMITTED",
echo=settings.sql_echo,
echo_pool=settings.sql_echo_pool,
)
# Dependency to get DB session
# Default production engine and session factory
engine = create_production_engine()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# FastAPI dependency
def get_db():
db = SessionLocal()
try: