Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff

- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
2025-11-10 11:55:15 +01:00
parent a5c671c133
commit c589b565f0
86 changed files with 4572 additions and 3956 deletions

View File

@@ -3,4 +3,4 @@ from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = ["user", "session_crud", "organization"]
__all__ = ["organization", "session_crud", "user"]

View File

@@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from datetime import UTC
from typing import Any, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
@@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
class CRUDBase[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
def __init__(self, model: type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
@@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
self.model = model
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
self, db: AsyncSession, id: str, options: list[Load] | None = None
) -> ModelType | None:
"""
Get a single record by ID with UUID validation and optional eager loading.
@@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
logger.warning(f"Invalid UUID format: {id} - {e!s}")
return None
try:
@@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
raise
async def get_multi(
@@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
options: list[Load] | None = None,
) -> list[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
@@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
)
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover
async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType
) -> ModelType: # pragma: no cover
"""Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice.
@@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def update(
@@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType:
"""Update a record with error handling."""
try:
@@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
@@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
return None
try:
@@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
logger.warning(
f"{self.model.__name__} with id {id} not found for deletion"
)
return None
await db.delete(obj)
@@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
raise ValueError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def get_multi_with_total(
@@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_by: str | None = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
filters: dict[str, Any] | None = None,
) -> tuple[list[ModelType], int]:
"""
Get multiple records with total count, filtering, and sorting.
@@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
@@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
)
raise
async def count(self, db: AsyncSession) -> int:
@@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
from datetime import datetime
# Validate UUID format and convert to UUID object if string
try:
@@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
return None
try:
@@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion"
)
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
obj.deleted_at = datetime.now(UTC)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
@@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
if hasattr(self.model, "deleted_at"):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
)
return None
# Clear deleted_at timestamp
@@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise

View File

@@ -1,17 +1,18 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Optional, List, Dict, Any
from typing import Any
from uuid import UUID
from sqlalchemy import func, or_, and_, select, case
from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.models.user_organization import OrganizationRole, UserOrganization
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug."""
try:
result = await db.execute(
@@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
logger.error(f"Error getting organization by slug {slug}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
async def create(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
@@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {}
settings=obj_in.settings or {},
)
db.add(db_obj)
await db.commit()
@@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
raise ValueError(
f"Organization with slug '{obj_in.slug}' already exists"
)
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise
async def get_multi_with_filters(
@@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
is_active: bool | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc"
) -> tuple[List[Organization], int]:
sort_order: str = "desc",
) -> tuple[list[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
@@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
@@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
logger.error(f"Error getting organizations with filters: {e!s}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
UserOrganization.is_active,
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
logger.error(
f"Error getting member count for organization {organization_id}: {e!s}"
)
raise
async def get_multi_with_member_counts(
@@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
@@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
func.count(
func.distinct(
case(
(UserOrganization.is_active == True, UserOrganization.user_id),
else_=None
(
UserOrganization.is_active,
UserOrganization.user_id,
),
else_=None,
)
)
).label('member_count')
).label("member_count"),
)
.outerjoin(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
@@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
@@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
{"organization": org, "member_count": member_count}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
logger.error(
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise
async def add_user(
@@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None
custom_permissions: str | None = None,
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
@@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions
custom_permissions=custom_permissions,
)
db.add(user_org)
await db.commit()
@@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
logger.error(f"Integrity error adding user to organization: {e!s}")
raise ValueError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
raise
async def remove_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
@@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
raise
async def update_user_role(
@@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: Optional[str] = None
) -> Optional[UserOrganization]:
custom_permissions: str | None = None,
) -> UserOrganization | None:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
logger.error(f"Error updating user role: {e!s}", exc_info=True)
raise
async def get_organization_members(
@@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True
) -> tuple[List[Dict[str, Any]], int]:
is_active: bool = True,
) -> tuple[list[dict[str, Any]], int]:
"""
Get members of an organization with user details.
@@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.where(
UserOrganization.is_active == is_active
if is_active is not None
else True
)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(UserOrganization.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
members.append(
{
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at,
}
)
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
logger.error(f"Error getting organization members: {e!s}")
raise
async def get_user_organizations(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.where(UserOrganization.user_id == user_id)
)
@@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
logger.error(f"Error getting user organizations: {e!s}")
raise
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
@@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
func.count(UserOrganization.user_id).label("member_count"),
)
.where(UserOrganization.is_active == True)
.where(UserOrganization.is_active)
.group_by(UserOrganization.organization_id)
.subquery()
)
@@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
func.coalesce(member_count_subq.c.member_count, 0).label(
"member_count"
),
)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(
member_count_subq,
Organization.id == member_count_subq.c.organization_id,
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
@@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
{"organization": org, "role": role, "member_count": member_count}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
logger.error(
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> OrganizationRole | None:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
@@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
UserOrganization.is_active,
)
)
)
@@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
logger.error(f"Error getting user role in org: {e!s}")
raise
async def is_user_org_owner(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]

View File

@@ -1,13 +1,13 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
import uuid
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""
Get session by refresh token JTI.
@@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
logger.error(f"Error getting session by JTI {jti}: {e!s}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""
Get active session by refresh token JTI.
@@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
UserSession.is_active,
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
raise
async def get_user_sessions(
@@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*,
user_id: str,
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
with_user: bool = False,
) -> list[UserSession]:
"""
Get all sessions for a user with optional eager loading.
@@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.where(UserSession.is_active)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
@@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
logger.error(f"Error creating session: {e!s}", exc_info=True)
raise ValueError(f"Failed to create session: {e!s}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""
Deactivate a session (logout from device).
@@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
logger.error(f"Error deactivating session {session_id}: {e!s}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
self, db: AsyncSession, *, user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
@@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
.values(is_active=False)
)
@@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
self, db: AsyncSession, *, session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
@@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
raise
async def update_refresh_token(
@@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
new_expires_at: datetime,
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
@@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
logger.error(
f"Error updating refresh token for session {session.id}: {e!s}"
)
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
@@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
not UserSession.is_active,
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
UserSession.created_at < cutoff_date,
)
)
@@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
logger.error(f"Error cleaning up expired sessions: {e!s}")
raise
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""
Clean up expired and inactive sessions for a specific user.
@@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
not UserSession.is_active,
UserSession.expires_at < now,
)
)
@@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
)
raise
@@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
and_(UserSession.user_id == user_uuid, UserSession.is_active)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
raise
async def get_all_sessions(
@@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
skip: int = 0,
limit: int = 100,
active_only: bool = True,
with_user: bool = True
) -> tuple[List[UserSession], int]:
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""
Get all sessions across all users with pagination (admin only).
@@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.where(UserSession.is_active)
# Get total count
count_query = select(func.count(UserSession.id))
if active_only:
count_query = count_query.where(UserSession.is_active == True)
count_query = count_query.where(UserSession.is_active)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(UserSession.last_used_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
sessions = list(result.scalars().all())
@@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return sessions, total
except Exception as e:
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True)
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
raise

View File

@@ -1,8 +1,9 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import or_, select, update
@@ -20,15 +21,13 @@ logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
logger.error(f"Error getting user by email {email}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
phone_number=obj_in.phone_number
if hasattr(obj_in, "phone_number")
else None,
is_superuser=obj_in.is_superuser
if hasattr(obj_in, "is_superuser")
else False,
preferences={},
)
db.add(db_obj)
await db.commit()
@@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
@@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
@@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
update_data["password_hash"] = await get_password_hash_async(
update_data["password"]
)
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
@@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_by: str | None = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None,
search: Optional[str] = None
) -> Tuple[List[User], int]:
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
@@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
User.last_name.ilike(f"%{search}%"),
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
@@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}")
logger.error(f"Error retrieving paginated users: {e!s}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
@@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
.values(is_active=is_active, updated_at=datetime.now(UTC))
)
result = await db.execute(stmt)
@@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""
Bulk soft delete multiple users.
@@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.where(
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
deleted_at=datetime.now(UTC),
is_active=False,
updated_at=datetime.now(timezone.utc)
updated_at=datetime.now(UTC),
)
)
@@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
raise
def is_active(self, user: User) -> bool: