Add comprehensive test suite and utilities for user functionality
This commit introduces a suite of tests for user models, schemas, CRUD operations, and authentication services. It also adds utilities for in-memory database setup to support these tests and updates environment settings for consistency.
This commit is contained in:
185
backend/app/core/auth.py
Normal file
185
backend/app/core/auth.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
claims: Optional additional claims to include in the token
|
||||
|
||||
Returns:
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
if claims:
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
|
||||
Args:
|
||||
subject: The subject of the token (usually user_id)
|
||||
expires_delta: Optional expiration time delta
|
||||
|
||||
Returns:
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token to decode
|
||||
verify_type: Optional token type to verify (access or refresh)
|
||||
|
||||
Returns:
|
||||
TokenPayload: The decoded token data
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If the token has expired
|
||||
TokenInvalidError: If the token is invalid
|
||||
TokenMissingClaimError: If a required claim is missing
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check required claims before Pydantic validation
|
||||
if not payload.get("sub"):
|
||||
raise TokenMissingClaimError("Token missing 'sub' claim")
|
||||
|
||||
# Verify token type if specified
|
||||
if verify_type and payload.get("type") != verify_type:
|
||||
raise TokenInvalidError(f"Invalid token type, expected {verify_type}")
|
||||
|
||||
# Now validate with Pydantic
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
|
||||
def get_token_data(token: str) -> TokenData:
|
||||
"""
|
||||
Extract the user ID and superuser status from a token.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
TokenData with user_id and is_superuser
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
@@ -3,7 +3,7 @@ from typing import Optional, List
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "App"
|
||||
PROJECT_NAME: str = "EventSpace"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
@@ -14,6 +14,17 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
sql_echo_timing: bool = False # Log query execution times
|
||||
slow_query_threshold: float = 0.5 # Log queries taking longer than this
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
@@ -30,7 +41,7 @@ class Settings(BaseSettings):
|
||||
# JWT configuration
|
||||
SECRET_KEY: str = "your_secret_key_here"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
|
||||
@@ -1,17 +1,57 @@
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Use the database URL from settings
|
||||
engine = create_engine(settings.database_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# Declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Create engine with optimized settings for PostgreSQL
|
||||
def create_production_engine():
|
||||
return create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_timeout=settings.db_pool_timeout,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
isolation_level="READ COMMITTED",
|
||||
echo=settings.sql_echo,
|
||||
echo_pool=settings.sql_echo_pool,
|
||||
)
|
||||
|
||||
# Dependency to get DB session
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user