Add async CRUD base, async database configuration, soft delete for users, and composite indexes

- Introduced `CRUDBaseAsync` for reusable async operations.
- Configured async database connection using SQLAlchemy 2.0 patterns with `asyncpg`.
- Added `deleted_at` column and soft delete functionality to the `User` model, including the related Alembic migration.
- Optimized queries by adding composite indexes for common user filtering scenarios.
- Extended tests: added cases for token-based security utilities and user management endpoints.
Felipe Cardoso
2025-10-30 16:45:01 +01:00
parent c684f2ba95
commit 313e6691b5
9 changed files with 1251 additions and 10 deletions

View File

@@ -0,0 +1,34 @@
"""add_soft_delete_to_users
Revision ID: 2d0fcec3b06d
Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '2d0fcec3b06d'
down_revision: Union[str, None] = '9e4f2a1b8c7d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add deleted_at column for soft deletes
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
# Add index on deleted_at for efficient queries
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
def downgrade() -> None:
# Remove index
op.drop_index('ix_users_deleted_at', table_name='users')
# Remove column
op.drop_column('users', 'deleted_at')
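
The new column only marks rows; every read path must filter it out explicitly. A minimal sketch of the query shape this enables (illustrative only — the actual filtering lives in the CRUD layer):

```python
# Sketch: excluding soft-deleted rows on reads. Assumes the User
# model shown later in this commit; ix_users_deleted_at (added above)
# supports this predicate.
from sqlalchemy import select

from app.models.user import User


def select_active_users():
    # Soft-deleted rows stay in the table; queries must opt out of them
    return select(User).where(User.deleted_at.is_(None))
```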

View File

@@ -0,0 +1,52 @@
"""add_composite_indexes
Revision ID: b76c725fc3cf
Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b76c725fc3cf'
down_revision: Union[str, None] = '2d0fcec3b06d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add composite indexes for common query patterns
# Composite index for filtering active users by role
op.create_index(
'ix_users_active_superuser',
'users',
['is_active', 'is_superuser'],
postgresql_where=sa.text('deleted_at IS NULL')
)
# Composite index for sorting active users by creation date
op.create_index(
'ix_users_active_created',
'users',
['is_active', 'created_at'],
postgresql_where=sa.text('deleted_at IS NULL')
)
# Composite index for email lookup of non-deleted users
op.create_index(
'ix_users_email_not_deleted',
'users',
['email', 'deleted_at']
)
def downgrade() -> None:
# Remove composite indexes
op.drop_index('ix_users_email_not_deleted', table_name='users')
op.drop_index('ix_users_active_created', table_name='users')
op.drop_index('ix_users_active_superuser', table_name='users')
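
These indexes target specific predicates. `postgresql_where` makes the first two partial indexes (covering only rows where `deleted_at IS NULL`); on non-PostgreSQL dialects the argument is ignored and a plain composite index is built instead. A sketch of the query patterns they are meant to serve (assumed usage, not part of this commit):

```python
# Sketch: queries matching the composite indexes above.
from sqlalchemy import select

from app.models.user import User

# ix_users_active_superuser: listing active admins among non-deleted rows
admins = select(User).where(
    User.is_active.is_(True),
    User.is_superuser.is_(True),
    User.deleted_at.is_(None),
)

# ix_users_email_not_deleted: email lookup that tolerates soft deletes
by_email = select(User).where(
    User.email == "user@example.com",
    User.deleted_at.is_(None),
)
```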

View File

@@ -2,7 +2,7 @@
 User management endpoints for CRUD operations.
 """
 import logging
-from typing import Any
+from typing import Any, Optional
 from uuid import UUID
 from fastapi import APIRouter, Depends, Query, status, Request
@@ -19,9 +19,10 @@ from app.schemas.common import (
     PaginationParams,
     PaginatedResponse,
     MessageResponse,
+    SortParams,
     create_pagination_meta
 )
-from app.services.auth_service import AuthService
+from app.services.auth_service import AuthService, AuthenticationError
 from app.core.exceptions import (
     NotFoundError,
     AuthorizationError,
@@ -39,31 +40,47 @@ limiter = Limiter(key_func=get_remote_address)
     response_model=PaginatedResponse[UserResponse],
     summary="List Users",
     description="""
-    List all users with pagination (admin only).
+    List all users with pagination, filtering, and sorting (admin only).
     **Authentication**: Required (Bearer token)
     **Authorization**: Superuser only
+    **Filtering**: is_active, is_superuser
+    **Sorting**: Any user field (email, first_name, last_name, created_at, etc.)
     **Rate Limit**: 60 requests/minute
     """,
     operation_id="list_users"
 )
 def list_users(
     pagination: PaginationParams = Depends(),
+    sort: SortParams = Depends(),
+    is_active: Optional[bool] = Query(None, description="Filter by active status"),
+    is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
     current_user: User = Depends(get_current_superuser),
     db: Session = Depends(get_db)
 ) -> Any:
     """
-    List all users with pagination.
+    List all users with pagination, filtering, and sorting.
     Only accessible by superusers.
     """
     try:
+        # Build filters
+        filters = {}
+        if is_active is not None:
+            filters["is_active"] = is_active
+        if is_superuser is not None:
+            filters["is_superuser"] = is_superuser
         # Get paginated users with total count
         users, total = user_crud.get_multi_with_total(
             db,
             skip=pagination.offset,
-            limit=pagination.limit
+            limit=pagination.limit,
+            sort_by=sort.sort_by,
+            sort_order=sort.sort_order.value if sort.sort_order else "asc",
+            filters=filters if filters else None
         )
         # Create pagination metadata
@@ -129,7 +146,7 @@ def update_current_user(
     Users cannot elevate their own permissions (is_superuser).
     """
     # Prevent users from making themselves superuser
-    if user_update.is_superuser is not None:
+    if getattr(user_update, 'is_superuser', None) is not None:
         logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
         raise AuthorizationError(
             message="Cannot modify superuser status",
@@ -248,7 +265,7 @@ def update_user(
     )
     # Prevent non-superusers from modifying superuser status
-    if user_update.is_superuser is not None and not current_user.is_superuser:
+    if getattr(user_update, 'is_superuser', None) is not None and not current_user.is_superuser:
         logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
         raise AuthorizationError(
             message="Cannot modify superuser status",
@@ -308,6 +325,12 @@ def change_current_user_password(
             success=True,
             message="Password changed successfully"
         )
+    except AuthenticationError as e:
+        logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
+        raise AuthorizationError(
+            message=str(e),
+            error_code=ErrorCode.INVALID_CREDENTIALS
+        )
     except Exception as e:
         logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
         raise
@@ -356,8 +379,9 @@ def delete_user(
         )
     try:
-        user_crud.remove(db, id=str(user_id))
-        logger.info(f"User {user_id} deleted by {current_user.id}")
+        # Use soft delete instead of hard delete
+        user_crud.soft_delete(db, id=str(user_id))
+        logger.info(f"User {user_id} soft-deleted by {current_user.id}")
         return MessageResponse(
             success=True,
             message=f"User {user_id} deleted successfully"

View File

@@ -0,0 +1,182 @@
# app/core/database_async.py
"""
Async database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
def get_async_database_url(url: str) -> str:
"""
Convert sync database URL to async URL.
postgresql:// -> postgresql+asyncpg://
sqlite:// -> sqlite+aiosqlite://
"""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://")
elif url.startswith("sqlite://"):
return url.replace("sqlite://", "sqlite+aiosqlite://")
return url
# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
"""Create an async database engine with production settings."""
async_url = get_async_database_url(settings.database_url)
# Base engine config
engine_config = {
"pool_size": settings.db_pool_size,
"max_overflow": settings.db_max_overflow,
"pool_timeout": settings.db_pool_timeout,
"pool_recycle": settings.db_pool_recycle,
"pool_pre_ping": True,
"echo": settings.sql_echo,
"echo_pool": settings.sql_echo_pool,
}
# Add PostgreSQL-specific connect_args
if "postgresql" in async_url:
engine_config["connect_args"] = {
"server_settings": {
"application_name": "eventspace",
"timezone": "UTC",
},
# asyncpg-specific settings
"command_timeout": 60,
"timeout": 10,
}
return create_async_engine(async_url, **engine_config)
# Create async engine and session factory
async_engine = create_async_production_engine()
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    autoflush=False,
    expire_on_commit=False,  # Prevent unnecessary queries after commit
)
# FastAPI dependency for async database sessions
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""
FastAPI dependency that provides an async database session.
Automatically closes the session after the request completes.
Usage:
@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_async_db)):
result = await db.execute(select(User))
return result.scalars().all()
"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
@asynccontextmanager
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
"""
Provide an async transactional scope for database operations.
Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction.
Usage:
async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create)
# Both operations committed together
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
raise
finally:
await session.close()
async def check_async_database_health() -> bool:
"""
Check if async database connection is healthy.
Returns True if connection is successful, False otherwise.
"""
try:
async with async_transaction_scope() as db:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Async database health check failed: {str(e)}")
return False
async def init_async_db() -> None:
"""
Initialize async database tables.
This creates all tables defined in the models.
Should only be used in development or testing.
In production, use Alembic migrations.
"""
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Async database tables created")
async def close_async_db() -> None:
"""
Close all async database connections.
Should be called during application shutdown.
"""
await async_engine.dispose()
logger.info("Async database connections closed")

View File

@@ -0,0 +1,228 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
import logging
import uuid
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from app.core.database_async import Base
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
async def get_multi(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
"""Get multiple records with pagination validation."""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
"""Create a new record with error handling."""
try:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def get_multi_with_total(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count for pagination.
Returns:
Tuple of (items, total_count)
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Get total count
count_result = await db.execute(
select(func.count(self.model.id))
)
total = count_result.scalar_one()
# Get paginated items
items_result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
items = list(items_result.scalars().all())
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
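
Wiring the async base to a concrete model mirrors the sync layer: instantiate it once per model and inject `get_async_db`. A sketch, with the schema import path assumed from the sync layer's layout:

```python
# Sketch: using CRUDBaseAsync for the User model in an async endpoint.
# The app.schemas.user path is an assumption.
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database_async import get_async_db
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate

user_crud_async = CRUDBaseAsync[User, UserCreate, UserUpdate](User)

router = APIRouter()


@router.get("/users")
async def list_users(db: AsyncSession = Depends(get_async_db)):
    users, total = await user_crud_async.get_multi_with_total(db, skip=0, limit=50)
    return {"total": total, "items": users}
```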

View File

@@ -1,4 +1,4 @@
-from sqlalchemy import Column, String, Boolean
+from sqlalchemy import Column, String, Boolean, DateTime
 from sqlalchemy.dialects.postgresql import JSONB
 from .base import Base, TimestampMixin, UUIDMixin
@@ -15,6 +15,7 @@ class User(Base, UUIDMixin, TimestampMixin):
     is_active = Column(Boolean, default=True, nullable=False, index=True)
     is_superuser = Column(Boolean, default=False, nullable=False, index=True)
     preferences = Column(JSONB)
+    deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
     def __repr__(self):
         return f"<User {self.email}>"