Add async CRUD classes for organizations, sessions, and users

- Implemented `CRUDOrganizationAsync`, `CRUDSessionAsync`, and `CRUDUserAsync` with full async support for database operations.
- Added filtering, sorting, pagination, and advanced methods for organization management.
- Developed session-specific logic, including cleanup, per-device management, and security enhancements.
- Enhanced user CRUD with password hashing and comprehensive update handling.
This commit is contained in:
Felipe Cardoso
2025-10-31 21:59:40 +01:00
parent 26ff08d9f9
commit 1f15ee6db3
3 changed files with 917 additions and 0 deletions

View File

@@ -0,0 +1,384 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
from typing import Optional, List, Dict, Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, or_, and_, select
from app.crud.base_async import CRUDBaseAsync
from app.models.organization import Organization
from app.models.user_organization import UserOrganization, OrganizationRole
from app.models.user import User
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
)
import logging
logger = logging.getLogger(__name__)
class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
"""Get organization by slug."""
try:
result = await db.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
name=obj_in.name,
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {}
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
raise
async def get_multi_with_filters(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
sort_by: str = "created_at",
sort_order: str = "desc"
) -> tuple[List[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
Returns:
Tuple of (organizations list, total count)
"""
try:
query = select(Organization)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get the count of active members in an organization."""
try:
result = await db.execute(
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
raise
async def add_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
# Check if relationship already exists
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
existing = result.scalar_one_or_none()
if existing:
# Reactivate if inactive, or raise error if already active
if not existing.is_active:
existing.is_active = True
existing.role = role
existing.custom_permissions = custom_permissions
await db.commit()
await db.refresh(existing)
return existing
else:
raise ValueError("User is already a member of this organization")
# Create new relationship
user_org = UserOrganization(
user_id=user_id,
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions
)
db.add(user_org)
await db.commit()
await db.refresh(user_org)
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
raise ValueError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
raise
async def remove_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return False
user_org.is_active = False
await db.commit()
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
raise
async def update_user_role(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: Optional[str] = None
) -> Optional[UserOrganization]:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return None
user_org.role = role
if custom_permissions is not None:
user_org.custom_permissions = custom_permissions
await db.commit()
await db.refresh(user_org)
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
raise
async def get_organization_members(
self,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True
) -> tuple[List[Dict[str, Any]], int]:
"""
Get members of an organization with user details.
Returns:
Tuple of (members list with user details, total count)
"""
try:
# Build query with join
query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
.where(UserOrganization.organization_id == organization_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
# Get total count
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
raise
async def get_user_organizations(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
raise
async def is_user_org_owner(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Create a singleton instance for use across the application
organization_async = CRUDOrganizationAsync(Organization)

363
backend/app/crud/session_async.py Executable file
View File

@@ -0,0 +1,363 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select, update, func
import logging
from app.crud.base_async import CRUDBaseAsync
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise
async def get_user_sessions(
self,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True
) -> List[UserSession]:
"""
Get all sessions for a user.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
Returns:
List of UserSession objects
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid)
if active_only:
query = query.where(UserSession.is_active == True)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
try:
db_obj = UserSession(
user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti,
device_name=obj_in.device_name,
device_id=obj_in.device_id,
ip_address=obj_in.ip_address,
user_agent=obj_in.user_agent,
last_used_at=obj_in.last_used_at,
expires_at=obj_in.expires_at,
is_active=True,
location_city=obj_in.location_city,
location_country=obj_in.location_country,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
f"(IP: {obj_in.ip_address})"
)
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try:
session = await self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
await db.commit()
await db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.values(is_active=False)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
# Get sessions to delete
stmt = select(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.created_at < cutoff_date
)
)
result = await db.execute(stmt)
sessions_to_delete = list(result.scalars().all())
# Delete them
for session in sessions_to_delete:
await db.delete(session)
await db.commit()
count = len(sessions_to_delete)
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise
# Create singleton instance
session_async = CRUDSessionAsync(UserSession)

170
backend/app/crud/user_async.py Executable file
View File

@@ -0,0 +1,170 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
from typing import Optional, Union, Dict, Any, List, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from sqlalchemy import or_, select
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.core.auth import get_password_hash
import logging
logger = logging.getLogger(__name__)
class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with password hashing and error handling."""
try:
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
"""Update user with password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"])
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None,
search: Optional[str] = None
) -> Tuple[List[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
search: Search term to match against email, first_name, last_name
Returns:
Tuple of (users list, total count)
"""
# Validate pagination
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Build base query
query = select(User)
# Exclude soft-deleted users
query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value)
# Apply search
if search:
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}")
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active
def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser
# Create a singleton instance for use across the application
user_async = CRUDUserAsync(User)