diff --git a/backend/app/crud/organization_async.py b/backend/app/crud/organization_async.py new file mode 100755 index 0000000..c92f3be --- /dev/null +++ b/backend/app/crud/organization_async.py @@ -0,0 +1,384 @@ +# app/crud/organization_async.py +"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" +from typing import Optional, List, Dict, Any +from uuid import UUID +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.exc import IntegrityError +from sqlalchemy import func, or_, and_, select + +from app.crud.base_async import CRUDBaseAsync +from app.models.organization import Organization +from app.models.user_organization import UserOrganization, OrganizationRole +from app.models.user import User +from app.schemas.organizations import ( + OrganizationCreate, + OrganizationUpdate, +) +import logging + +logger = logging.getLogger(__name__) + + +class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]): + """Async CRUD operations for Organization model.""" + + async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]: + """Get organization by slug.""" + try: + result = await db.execute( + select(Organization).where(Organization.slug == slug) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error getting organization by slug {slug}: {str(e)}") + raise + + async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization: + """Create a new organization with error handling.""" + try: + db_obj = Organization( + name=obj_in.name, + slug=obj_in.slug, + description=obj_in.description, + is_active=obj_in.is_active, + settings=obj_in.settings or {} + ) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + return db_obj + except IntegrityError as e: + await db.rollback() + error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + if "slug" in error_msg.lower(): + logger.warning(f"Duplicate slug attempted: {obj_in.slug}") + raise ValueError(f"Organization with slug '{obj_in.slug}' already exists") + logger.error(f"Integrity error creating organization: {error_msg}") + raise ValueError(f"Database integrity error: {error_msg}") + except Exception as e: + await db.rollback() + logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True) + raise + + async def get_multi_with_filters( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + is_active: Optional[bool] = None, + search: Optional[str] = None, + sort_by: str = "created_at", + sort_order: str = "desc" + ) -> tuple[List[Organization], int]: + """ + Get multiple organizations with filtering, searching, and sorting. + + Returns: + Tuple of (organizations list, total count) + """ + try: + query = select(Organization) + + # Apply filters + if is_active is not None: + query = query.where(Organization.is_active == is_active) + + if search: + search_filter = or_( + Organization.name.ilike(f"%{search}%"), + Organization.slug.ilike(f"%{search}%"), + Organization.description.ilike(f"%{search}%") + ) + query = query.where(search_filter) + + # Get total count before pagination + count_query = select(func.count()).select_from(query.alias()) + count_result = await db.execute(count_query) + total = count_result.scalar_one() + + # Apply sorting + sort_column = getattr(Organization, sort_by, Organization.created_at) + if sort_order == "desc": + query = query.order_by(sort_column.desc()) + else: + query = query.order_by(sort_column.asc()) + + # Apply pagination + query = query.offset(skip).limit(limit) + result = await db.execute(query) + organizations = list(result.scalars().all()) + + return organizations, total + except Exception as e: + logger.error(f"Error getting organizations with filters: {str(e)}") + raise + + async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: + """Get the count of active members in an organization.""" + try: + result = await db.execute( + select(func.count(UserOrganization.user_id)).where( + and_( + UserOrganization.organization_id == organization_id, + UserOrganization.is_active == True + ) + ) + ) + return result.scalar_one() or 0 + except Exception as e: + logger.error(f"Error getting member count for organization {organization_id}: {str(e)}") + raise + + async def add_user( + self, + db: AsyncSession, + *, + organization_id: UUID, + user_id: UUID, + role: OrganizationRole = OrganizationRole.MEMBER, + custom_permissions: Optional[str] = None + ) -> UserOrganization: + """Add a user to an organization with a specific role.""" + try: + # Check if relationship already exists + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id + ) + ) + ) + existing = result.scalar_one_or_none() + + if existing: + # Reactivate if inactive, or raise error if already active + if not existing.is_active: + existing.is_active = True + existing.role = role + existing.custom_permissions = custom_permissions + await db.commit() + await db.refresh(existing) + return existing + else: + raise ValueError("User is already a member of this organization") + + # Create new relationship + user_org = UserOrganization( + user_id=user_id, + organization_id=organization_id, + role=role, + is_active=True, + custom_permissions=custom_permissions + ) + db.add(user_org) + await db.commit() + await db.refresh(user_org) + return user_org + except IntegrityError as e: + await db.rollback() + logger.error(f"Integrity error adding user to organization: {str(e)}") + raise ValueError("Failed to add user to organization") + except Exception as e: + await db.rollback() + logger.error(f"Error adding user to organization: {str(e)}", exc_info=True) + raise + + async def remove_user( + self, + db: AsyncSession, + *, + organization_id: UUID, + user_id: UUID + ) -> bool: + """Remove a user from an organization (soft delete).""" + try: + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id + ) + ) + ) + user_org = result.scalar_one_or_none() + + if not user_org: + return False + + user_org.is_active = False + await db.commit() + return True + except Exception as e: + await db.rollback() + logger.error(f"Error removing user from organization: {str(e)}", exc_info=True) + raise + + async def update_user_role( + self, + db: AsyncSession, + *, + organization_id: UUID, + user_id: UUID, + role: OrganizationRole, + custom_permissions: Optional[str] = None + ) -> Optional[UserOrganization]: + """Update a user's role in an organization.""" + try: + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id + ) + ) + ) + user_org = result.scalar_one_or_none() + + if not user_org: + return None + + user_org.role = role + if custom_permissions is not None: + user_org.custom_permissions = custom_permissions + await db.commit() + await db.refresh(user_org) + return user_org + except Exception as e: + await db.rollback() + logger.error(f"Error updating user role: {str(e)}", exc_info=True) + raise + + async def get_organization_members( + self, + db: AsyncSession, + *, + organization_id: UUID, + skip: int = 0, + limit: int = 100, + is_active: bool = True + ) -> tuple[List[Dict[str, Any]], int]: + """ + Get members of an organization with user details. + + Returns: + Tuple of (members list with user details, total count) + """ + try: + # Build query with join + query = ( + select(UserOrganization, User) + .join(User, UserOrganization.user_id == User.id) + .where(UserOrganization.organization_id == organization_id) + ) + + if is_active is not None: + query = query.where(UserOrganization.is_active == is_active) + + # Get total count + count_query = select(func.count()).select_from( + select(UserOrganization) + .where(UserOrganization.organization_id == organization_id) + .where(UserOrganization.is_active == is_active if is_active is not None else True) + .alias() + ) + count_result = await db.execute(count_query) + total = count_result.scalar_one() + + # Apply ordering and pagination + query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit) + result = await db.execute(query) + results = result.all() + + members = [] + for user_org, user in results: + members.append({ + "user_id": user.id, + "email": user.email, + "first_name": user.first_name, + "last_name": user.last_name, + "role": user_org.role, + "is_active": user_org.is_active, + "joined_at": user_org.created_at + }) + + return members, total + except Exception as e: + logger.error(f"Error getting organization members: {str(e)}") + raise + + async def get_user_organizations( + self, + db: AsyncSession, + *, + user_id: UUID, + is_active: bool = True + ) -> List[Organization]: + """Get all organizations a user belongs to.""" + try: + query = ( + select(Organization) + .join(UserOrganization, Organization.id == UserOrganization.organization_id) + .where(UserOrganization.user_id == user_id) + ) + + if is_active is not None: + query = query.where(UserOrganization.is_active == is_active) + + result = await db.execute(query) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"Error getting user organizations: {str(e)}") + raise + + async def get_user_role_in_org( + self, + db: AsyncSession, + *, + user_id: UUID, + organization_id: UUID + ) -> Optional[OrganizationRole]: + """Get a user's role in a specific organization.""" + try: + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id, + UserOrganization.is_active == True + ) + ) + ) + user_org = result.scalar_one_or_none() + + return user_org.role if user_org else None + except Exception as e: + logger.error(f"Error getting user role in org: {str(e)}") + raise + + async def is_user_org_owner( + self, + db: AsyncSession, + *, + user_id: UUID, + organization_id: UUID + ) -> bool: + """Check if a user is an owner of an organization.""" + role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) + return role == OrganizationRole.OWNER + + async def is_user_org_admin( + self, + db: AsyncSession, + *, + user_id: UUID, + organization_id: UUID + ) -> bool: + """Check if a user is an owner or admin of an organization.""" + role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) + return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] + + +# Create a singleton instance for use across the application +organization_async = CRUDOrganizationAsync(Organization) diff --git a/backend/app/crud/session_async.py b/backend/app/crud/session_async.py new file mode 100755 index 0000000..2bd9ea1 --- /dev/null +++ b/backend/app/crud/session_async.py @@ -0,0 +1,363 @@ +""" +Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. +""" +from datetime import datetime, timezone, timedelta +from typing import List, Optional +from uuid import UUID +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import and_, select, update, func +import logging + +from app.crud.base_async import CRUDBaseAsync +from app.models.user_session import UserSession +from app.schemas.sessions import SessionCreate, SessionUpdate + +logger = logging.getLogger(__name__) + + +class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]): + """Async CRUD operations for user sessions.""" + + async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: + """ + Get session by refresh token JTI. + + Args: + db: Database session + jti: Refresh token JWT ID + + Returns: + UserSession if found, None otherwise + """ + try: + result = await db.execute( + select(UserSession).where(UserSession.refresh_token_jti == jti) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error getting session by JTI {jti}: {str(e)}") + raise + + async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: + """ + Get active session by refresh token JTI. + + Args: + db: Database session + jti: Refresh token JWT ID + + Returns: + Active UserSession if found, None otherwise + """ + try: + result = await db.execute( + select(UserSession).where( + and_( + UserSession.refresh_token_jti == jti, + UserSession.is_active == True + ) + ) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error getting active session by JTI {jti}: {str(e)}") + raise + + async def get_user_sessions( + self, + db: AsyncSession, + *, + user_id: str, + active_only: bool = True + ) -> List[UserSession]: + """ + Get all sessions for a user. + + Args: + db: Database session + user_id: User ID + active_only: If True, return only active sessions + + Returns: + List of UserSession objects + """ + try: + # Convert user_id string to UUID if needed + user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id + + query = select(UserSession).where(UserSession.user_id == user_uuid) + + if active_only: + query = query.where(UserSession.is_active == True) + + query = query.order_by(UserSession.last_used_at.desc()) + result = await db.execute(query) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"Error getting sessions for user {user_id}: {str(e)}") + raise + + async def create_session( + self, + db: AsyncSession, + *, + obj_in: SessionCreate + ) -> UserSession: + """ + Create a new user session. + + Args: + db: Database session + obj_in: SessionCreate schema with session data + + Returns: + Created UserSession + + Raises: + ValueError: If session creation fails + """ + try: + db_obj = UserSession( + user_id=obj_in.user_id, + refresh_token_jti=obj_in.refresh_token_jti, + device_name=obj_in.device_name, + device_id=obj_in.device_id, + ip_address=obj_in.ip_address, + user_agent=obj_in.user_agent, + last_used_at=obj_in.last_used_at, + expires_at=obj_in.expires_at, + is_active=True, + location_city=obj_in.location_city, + location_country=obj_in.location_country, + ) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + + logger.info( + f"Session created for user {obj_in.user_id} from {obj_in.device_name} " + f"(IP: {obj_in.ip_address})" + ) + + return db_obj + except Exception as e: + await db.rollback() + logger.error(f"Error creating session: {str(e)}", exc_info=True) + raise ValueError(f"Failed to create session: {str(e)}") + + async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]: + """ + Deactivate a session (logout from device). + + Args: + db: Database session + session_id: Session UUID + + Returns: + Deactivated UserSession if found, None otherwise + """ + try: + session = await self.get(db, id=session_id) + if not session: + logger.warning(f"Session {session_id} not found for deactivation") + return None + + session.is_active = False + db.add(session) + await db.commit() + await db.refresh(session) + + logger.info( + f"Session {session_id} deactivated for user {session.user_id} " + f"({session.device_name})" + ) + + return session + except Exception as e: + await db.rollback() + logger.error(f"Error deactivating session {session_id}: {str(e)}") + raise + + async def deactivate_all_user_sessions( + self, + db: AsyncSession, + *, + user_id: str + ) -> int: + """ + Deactivate all active sessions for a user (logout from all devices). + + Args: + db: Database session + user_id: User ID + + Returns: + Number of sessions deactivated + """ + try: + # Convert user_id string to UUID if needed + user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id + + stmt = ( + update(UserSession) + .where( + and_( + UserSession.user_id == user_uuid, + UserSession.is_active == True + ) + ) + .values(is_active=False) + ) + + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount + + logger.info(f"Deactivated {count} sessions for user {user_id}") + + return count + except Exception as e: + await db.rollback() + logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") + raise + + async def update_last_used( + self, + db: AsyncSession, + *, + session: UserSession + ) -> UserSession: + """ + Update the last_used_at timestamp for a session. + + Args: + db: Database session + session: UserSession object + + Returns: + Updated UserSession + """ + try: + session.last_used_at = datetime.now(timezone.utc) + db.add(session) + await db.commit() + await db.refresh(session) + return session + except Exception as e: + await db.rollback() + logger.error(f"Error updating last_used for session {session.id}: {str(e)}") + raise + + async def update_refresh_token( + self, + db: AsyncSession, + *, + session: UserSession, + new_jti: str, + new_expires_at: datetime + ) -> UserSession: + """ + Update session with new refresh token JTI and expiration. + + Called during token refresh. + + Args: + db: Database session + session: UserSession object + new_jti: New refresh token JTI + new_expires_at: New expiration datetime + + Returns: + Updated UserSession + """ + try: + session.refresh_token_jti = new_jti + session.expires_at = new_expires_at + session.last_used_at = datetime.now(timezone.utc) + db.add(session) + await db.commit() + await db.refresh(session) + return session + except Exception as e: + await db.rollback() + logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") + raise + + async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: + """ + Clean up expired sessions. + + Deletes sessions that are: + - Expired AND inactive + - Older than keep_days + + Args: + db: Database session + keep_days: Keep inactive sessions for this many days (for audit) + + Returns: + Number of sessions deleted + """ + try: + cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) + + # Get sessions to delete + stmt = select(UserSession).where( + and_( + UserSession.is_active == False, + UserSession.expires_at < datetime.now(timezone.utc), + UserSession.created_at < cutoff_date + ) + ) + result = await db.execute(stmt) + sessions_to_delete = list(result.scalars().all()) + + # Delete them + for session in sessions_to_delete: + await db.delete(session) + + await db.commit() + + count = len(sessions_to_delete) + + if count > 0: + logger.info(f"Cleaned up {count} expired sessions") + + return count + except Exception as e: + await db.rollback() + logger.error(f"Error cleaning up expired sessions: {str(e)}") + raise + + async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: + """ + Get count of active sessions for a user. + + Args: + db: Database session + user_id: User ID + + Returns: + Number of active sessions + """ + try: + # Convert user_id string to UUID if needed + user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id + + result = await db.execute( + select(func.count(UserSession.id)).where( + and_( + UserSession.user_id == user_uuid, + UserSession.is_active == True + ) + ) + ) + return result.scalar_one() + except Exception as e: + logger.error(f"Error counting sessions for user {user_id}: {str(e)}") + raise + + +# Create singleton instance +session_async = CRUDSessionAsync(UserSession) diff --git a/backend/app/crud/user_async.py b/backend/app/crud/user_async.py new file mode 100755 index 0000000..c913806 --- /dev/null +++ b/backend/app/crud/user_async.py @@ -0,0 +1,170 @@ +# app/crud/user_async.py +"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" +from typing import Optional, Union, Dict, Any, List, Tuple +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.exc import IntegrityError +from sqlalchemy import or_, select +from app.crud.base_async import CRUDBaseAsync +from app.models.user import User +from app.schemas.users import UserCreate, UserUpdate +from app.core.auth import get_password_hash +import logging + +logger = logging.getLogger(__name__) + + +class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]): + """Async CRUD operations for User model.""" + + async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]: + """Get user by email address.""" + try: + result = await db.execute( + select(User).where(User.email == email) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error getting user by email {email}: {str(e)}") + raise + + async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: + """Create a new user with password hashing and error handling.""" + try: + db_obj = User( + email=obj_in.email, + password_hash=get_password_hash(obj_in.password), + first_name=obj_in.first_name, + last_name=obj_in.last_name, + phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, + is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False, + preferences={} + ) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + return db_obj + except IntegrityError as e: + await db.rollback() + error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + if "email" in error_msg.lower(): + logger.warning(f"Duplicate email attempted: {obj_in.email}") + raise ValueError(f"User with email {obj_in.email} already exists") + logger.error(f"Integrity error creating user: {error_msg}") + raise ValueError(f"Database integrity error: {error_msg}") + except Exception as e: + await db.rollback() + logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True) + raise + + async def update( + self, + db: AsyncSession, + *, + db_obj: User, + obj_in: Union[UserUpdate, Dict[str, Any]] + ) -> User: + """Update user with password hashing if password is updated.""" + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.model_dump(exclude_unset=True) + + # Handle password separately if it exists in update data + if "password" in update_data: + update_data["password_hash"] = get_password_hash(update_data["password"]) + del update_data["password"] + + return await super().update(db, db_obj=db_obj, obj_in=update_data) + + async def get_multi_with_total( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + sort_by: Optional[str] = None, + sort_order: str = "asc", + filters: Optional[Dict[str, Any]] = None, + search: Optional[str] = None + ) -> Tuple[List[User], int]: + """ + Get multiple users with total count, filtering, sorting, and search. + + Args: + db: Database session + skip: Number of records to skip + limit: Maximum number of records to return + sort_by: Field name to sort by + sort_order: Sort order ("asc" or "desc") + filters: Dictionary of filters (field_name: value) + search: Search term to match against email, first_name, last_name + + Returns: + Tuple of (users list, total count) + """ + # Validate pagination + if skip < 0: + raise ValueError("skip must be non-negative") + if limit < 0: + raise ValueError("limit must be non-negative") + if limit > 1000: + raise ValueError("Maximum limit is 1000") + + try: + # Build base query + query = select(User) + + # Exclude soft-deleted users + query = query.where(User.deleted_at.is_(None)) + + # Apply filters + if filters: + for field, value in filters.items(): + if hasattr(User, field) and value is not None: + query = query.where(getattr(User, field) == value) + + # Apply search + if search: + search_filter = or_( + User.email.ilike(f"%{search}%"), + User.first_name.ilike(f"%{search}%"), + User.last_name.ilike(f"%{search}%") + ) + query = query.where(search_filter) + + # Get total count + from sqlalchemy import func + count_query = select(func.count()).select_from(query.alias()) + count_result = await db.execute(count_query) + total = count_result.scalar_one() + + # Apply sorting + if sort_by and hasattr(User, sort_by): + sort_column = getattr(User, sort_by) + if sort_order.lower() == "desc": + query = query.order_by(sort_column.desc()) + else: + query = query.order_by(sort_column.asc()) + + # Apply pagination + query = query.offset(skip).limit(limit) + result = await db.execute(query) + users = list(result.scalars().all()) + + return users, total + + except Exception as e: + logger.error(f"Error retrieving paginated users: {str(e)}") + raise + + def is_active(self, user: User) -> bool: + """Check if user is active.""" + return user.is_active + + def is_superuser(self, user: User) -> bool: + """Check if user is a superuser.""" + return user.is_superuser + + +# Create a singleton instance for use across the application +user_async = CRUDUserAsync(User)