From efcf10f9aad9ded3893aa63a45e2cce43e5c9115 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sat, 1 Nov 2025 05:47:43 +0100 Subject: [PATCH] Remove unused async database and CRUD modules - Deleted `database_async.py`, `base_async.py`, and `organization_async.py` modules due to deprecation and unused references across the project. - Improved overall codebase clarity and minimized redundant functionality by removing unused async database logic, CRUD utilities, and organization-related operations. --- backend/app/api/dependencies/auth.py | 6 +- backend/app/api/dependencies/permissions.py | 10 +- backend/app/api/routes/admin.py | 38 +- backend/app/api/routes/auth.py | 24 +- backend/app/api/routes/organizations.py | 12 +- backend/app/api/routes/sessions.py | 10 +- backend/app/api/routes/users.py | 16 +- backend/app/core/database.py | 207 +++++--- backend/app/core/database_async.py | 186 ------- backend/app/crud/base.py | 207 +++++--- backend/app/crud/base_async.py | 399 --------------- backend/app/crud/organization.py | 434 +++++++++++----- backend/app/crud/organization_async.py | 519 -------------------- backend/app/crud/session.py | 220 ++++++--- backend/app/crud/session_async.py | 424 ---------------- backend/app/crud/user.py | 183 +++++-- backend/app/crud/user_async.py | 272 ---------- backend/app/init_db.py | 78 --- backend/app/main.py | 2 +- backend/app/services/session_cleanup.py | 8 +- 20 files changed, 972 insertions(+), 2283 deletions(-) mode change 100644 => 100755 backend/app/core/database.py delete mode 100755 backend/app/core/database_async.py mode change 100644 => 100755 backend/app/crud/base.py delete mode 100755 backend/app/crud/base_async.py mode change 100644 => 100755 backend/app/crud/organization.py delete mode 100755 backend/app/crud/organization_async.py mode change 100644 => 100755 backend/app/crud/session.py delete mode 100755 backend/app/crud/session_async.py mode change 100644 => 100755 backend/app/crud/user.py delete mode 100755 backend/app/crud/user_async.py delete mode 100755 backend/app/init_db.py diff --git a/backend/app/api/dependencies/auth.py b/backend/app/api/dependencies/auth.py index 6502f9d..93b7411 100755 --- a/backend/app/api/dependencies/auth.py +++ b/backend/app/api/dependencies/auth.py @@ -7,7 +7,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError -from app.core.database_async import get_async_db +from app.core.database import get_db from app.models.user import User # OAuth2 configuration @@ -15,7 +15,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") async def get_current_user( - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme) ) -> User: """ @@ -139,7 +139,7 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str] async def get_optional_current_user( - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_db), token: Optional[str] = Depends(get_optional_token) ) -> Optional[User]: """ diff --git a/backend/app/api/dependencies/permissions.py b/backend/app/api/dependencies/permissions.py index 76e4545..d3c575a 100755 --- a/backend/app/api/dependencies/permissions.py +++ b/backend/app/api/dependencies/permissions.py @@ -14,8 +14,8 @@ from fastapi import Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user -from app.core.database_async import get_async_db -from app.crud.organization_async import organization_async as organization_crud +from app.core.database import get_db +from app.crud.organization import organization as organization_crud from app.models.user import User from app.models.user_organization import OrganizationRole @@ -78,7 +78,7 @@ class OrganizationPermission: self, organization_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> User: """ Check if user has required role in the organization. @@ -133,7 +133,7 @@ require_org_member = OrganizationPermission([ async def get_current_org_role( organization_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Optional[OrganizationRole]: """ Get the current user's role in an organization. @@ -164,7 +164,7 @@ async def get_current_org_role( async def require_org_membership( organization_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> User: """ Ensure user is a member of the organization (any role). diff --git a/backend/app/api/routes/admin.py b/backend/app/api/routes/admin.py index 6dafe2f..75e1491 100755 --- a/backend/app/api/routes/admin.py +++ b/backend/app/api/routes/admin.py @@ -15,10 +15,10 @@ from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.permissions import require_superuser -from app.core.database_async import get_async_db +from app.core.database import get_db from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode -from app.crud.organization_async import organization_async as organization_crud -from app.crud.user_async import user_async as user_crud +from app.crud.organization import organization as organization_crud +from app.crud.user import user as user_crud from app.models.user import User from app.models.user_organization import OrganizationRole from app.schemas.common import ( @@ -80,7 +80,7 @@ async def admin_list_users( is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), search: Optional[str] = Query(None, description="Search by email, name"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ List all users with comprehensive filtering and search. @@ -131,7 +131,7 @@ async def admin_list_users( async def admin_create_user( user_in: UserCreate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Create a new user with admin privileges. @@ -163,7 +163,7 @@ async def admin_create_user( async def admin_get_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Get detailed information about a specific user.""" user = await user_crud.get(db, id=user_id) @@ -186,7 +186,7 @@ async def admin_update_user( user_id: UUID, user_in: UserUpdate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Update user information with admin privileges.""" try: @@ -218,7 +218,7 @@ async def admin_update_user( async def admin_delete_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Soft delete a user (sets deleted_at timestamp).""" try: @@ -262,7 +262,7 @@ async def admin_delete_user( async def admin_activate_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Activate a user account.""" try: @@ -298,7 +298,7 @@ async def admin_activate_user( async def admin_deactivate_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Deactivate a user account.""" try: @@ -342,7 +342,7 @@ async def admin_deactivate_user( async def admin_bulk_user_action( bulk_action: BulkUserAction, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Perform bulk actions on multiple users using optimized bulk operations. @@ -410,7 +410,7 @@ async def admin_list_organizations( is_active: Optional[bool] = Query(None, description="Filter by active status"), search: Optional[str] = Query(None, description="Search by name, slug, description"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """List all organizations with filtering and search.""" try: @@ -467,7 +467,7 @@ async def admin_list_organizations( async def admin_create_organization( org_in: OrganizationCreate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Create a new organization.""" try: @@ -509,7 +509,7 @@ async def admin_create_organization( async def admin_get_organization( org_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Get detailed information about a specific organization.""" org = await organization_crud.get(db, id=org_id) @@ -544,7 +544,7 @@ async def admin_update_organization( org_id: UUID, org_in: OrganizationUpdate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Update organization information.""" try: @@ -588,7 +588,7 @@ async def admin_update_organization( async def admin_delete_organization( org_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Delete an organization and all its relationships.""" try: @@ -626,7 +626,7 @@ async def admin_list_organization_members( pagination: PaginationParams = Depends(), is_active: Optional[bool] = Query(True, description="Filter by active status"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """List all members of an organization.""" try: @@ -681,7 +681,7 @@ async def admin_add_organization_member( org_id: UUID, request: AddMemberRequest, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Add a user to an organization.""" try: @@ -742,7 +742,7 @@ async def admin_remove_organization_member( org_id: UUID, user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """Remove a user from an organization.""" try: diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index 1335e92..725a6cd 100755 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -13,14 +13,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token from app.core.auth import get_password_hash -from app.core.database_async import get_async_db +from app.core.database import get_db from app.core.exceptions import ( AuthenticationError as AuthError, DatabaseError, ErrorCode ) -from app.crud.session_async import session_async as session_crud -from app.crud.user_async import user_async as user_crud +from app.crud.session import session as session_crud +from app.crud.user import user as user_crud from app.models.user import User from app.schemas.common import MessageResponse from app.schemas.sessions import SessionCreate, LogoutRequest @@ -54,7 +54,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1 async def register_user( request: Request, user_data: UserCreate, - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Register a new user. @@ -85,7 +85,7 @@ async def register_user( async def login( request: Request, login_data: LoginRequest, - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Login with username and password. @@ -167,7 +167,7 @@ async def login( async def login_oauth( request: Request, form_data: OAuth2PasswordRequestForm = Depends(), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ OAuth2-compatible login endpoint, used by the OpenAPI UI. @@ -244,7 +244,7 @@ async def login_oauth( async def refresh_token( request: Request, refresh_data: RefreshTokenRequest, - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Refresh access token using a refresh token. @@ -333,7 +333,7 @@ async def refresh_token( async def request_password_reset( request: Request, reset_request: PasswordResetRequest, - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Request a password reset. @@ -391,7 +391,7 @@ async def request_password_reset( async def confirm_password_reset( request: Request, reset_confirm: PasswordResetConfirm, - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Confirm password reset with token. @@ -430,7 +430,7 @@ async def confirm_password_reset( # SECURITY: Invalidate all existing sessions after password reset # This prevents stolen sessions from being used after password change - from app.crud.session_async import session_async as session_crud + from app.crud.session import session as session_crud try: deactivated_count = await session_crud.deactivate_all_user_sessions( db, @@ -478,7 +478,7 @@ async def logout( request: Request, logout_request: LogoutRequest, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Logout from current device by deactivating the session. @@ -566,7 +566,7 @@ async def logout( async def logout_all( request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Logout from all devices by deactivating all user sessions. diff --git a/backend/app/api/routes/organizations.py b/backend/app/api/routes/organizations.py index 9b1c953..a6b4394 100755 --- a/backend/app/api/routes/organizations.py +++ b/backend/app/api/routes/organizations.py @@ -13,9 +13,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.api.dependencies.permissions import require_org_admin, require_org_membership -from app.core.database_async import get_async_db +from app.core.database import get_db from app.core.exceptions import NotFoundError, ErrorCode -from app.crud.organization_async import organization_async as organization_crud +from app.crud.organization import organization as organization_crud from app.models.user import User from app.schemas.common import ( PaginationParams, @@ -43,7 +43,7 @@ router = APIRouter() async def get_my_organizations( is_active: bool = Query(True, description="Filter by active membership"), current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Get all organizations the current user belongs to. @@ -93,7 +93,7 @@ async def get_my_organizations( async def get_organization( organization_id: UUID, current_user: User = Depends(require_org_membership), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Get details of a specific organization. @@ -140,7 +140,7 @@ async def get_organization_members( pagination: PaginationParams = Depends(), is_active: bool = Query(True, description="Filter by active status"), current_user: User = Depends(require_org_membership), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Get all members of an organization. @@ -183,7 +183,7 @@ async def update_organization( organization_id: UUID, org_in: OrganizationUpdate, current_user: User = Depends(require_org_admin), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Update organization details. diff --git a/backend/app/api/routes/sessions.py b/backend/app/api/routes/sessions.py index 64b0e04..f39056b 100755 --- a/backend/app/api/routes/sessions.py +++ b/backend/app/api/routes/sessions.py @@ -14,9 +14,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.core.auth import decode_token -from app.core.database_async import get_async_db +from app.core.database import get_db from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode -from app.crud.session_async import session_async as session_crud +from app.crud.session import session as session_crud from app.models.user import User from app.schemas.common import MessageResponse from app.schemas.sessions import SessionResponse, SessionListResponse @@ -45,7 +45,7 @@ limiter = Limiter(key_func=get_remote_address) async def list_my_sessions( request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ List all active sessions for the current user. @@ -129,7 +129,7 @@ async def revoke_session( request: Request, session_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Revoke a specific session by ID. @@ -204,7 +204,7 @@ async def revoke_session( async def cleanup_expired_sessions( request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Cleanup expired sessions for the current user. diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index 050f833..d94a754 100755 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -11,13 +11,13 @@ from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user, get_current_superuser -from app.core.database_async import get_async_db +from app.core.database import get_db from app.core.exceptions import ( NotFoundError, AuthorizationError, ErrorCode ) -from app.crud.user_async import user_async as user_crud +from app.crud.user import user as user_crud from app.models.user import User from app.schemas.common import ( PaginationParams, @@ -58,7 +58,7 @@ async def list_users( is_active: Optional[bool] = Query(None, description="Filter by active status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), current_user: User = Depends(get_current_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ List all users with pagination, filtering, and sorting. @@ -138,7 +138,7 @@ def get_current_user_profile( async def update_current_user( user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Update current user's profile. @@ -188,7 +188,7 @@ async def update_current_user( async def get_user_by_id( user_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Get user by ID. @@ -236,7 +236,7 @@ async def update_user( user_id: UUID, user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Update user by ID. @@ -304,7 +304,7 @@ async def change_current_user_password( request: Request, password_change: PasswordChange, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Change current user's password. @@ -356,7 +356,7 @@ async def change_current_user_password( async def delete_user( user_id: UUID, current_user: User = Depends(get_current_superuser), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_db) ) -> Any: """ Delete user by ID (superuser only). diff --git a/backend/app/core/database.py b/backend/app/core/database.py old mode 100644 new mode 100755 index 1e51a43..42f857d --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -1,113 +1,186 @@ # app/core/database.py -import logging -from contextlib import contextmanager -from typing import Generator +""" +Database configuration using SQLAlchemy 2.0 and asyncpg. -from sqlalchemy import create_engine, text +This module provides async database connectivity with proper connection pooling +and session management for FastAPI endpoints. +""" +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from sqlalchemy import text from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.ext.asyncio import ( + AsyncSession, + AsyncEngine, + create_async_engine, + async_sessionmaker, +) from sqlalchemy.ext.compiler import compiles -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import DeclarativeBase from app.core.config import settings # Configure logging logger = logging.getLogger(__name__) + # SQLite compatibility for testing @compiles(JSONB, 'sqlite') def compile_jsonb_sqlite(type_, compiler, **kw): return "TEXT" + @compiles(UUID, 'sqlite') def compile_uuid_sqlite(type_, compiler, **kw): return "TEXT" -# Declarative base for models -Base = declarative_base() -# Create engine with optimized settings for PostgreSQL -def create_production_engine(): - return create_engine( - settings.database_url, - # Connection pool settings - pool_size=settings.db_pool_size, - max_overflow=settings.db_max_overflow, - pool_timeout=settings.db_pool_timeout, - pool_recycle=settings.db_pool_recycle, - pool_pre_ping=True, - # Query execution settings - connect_args={ - "application_name": "eventspace", - "keepalives": 1, - "keepalives_idle": 60, - "keepalives_interval": 10, - "keepalives_count": 5, - "options": "-c timezone=UTC", - }, - isolation_level="READ COMMITTED", - echo=settings.sql_echo, - echo_pool=settings.sql_echo_pool, - ) +# Declarative base for models (SQLAlchemy 2.0 style) +class Base(DeclarativeBase): + """Base class for all database models.""" + pass -# Default production engine and session factory -engine = create_production_engine() -SessionLocal = sessionmaker( + +def get_async_database_url(url: str) -> str: + """ + Convert sync database URL to async URL. + + postgresql:// -> postgresql+asyncpg:// + sqlite:// -> sqlite+aiosqlite:// + """ + if url.startswith("postgresql://"): + return url.replace("postgresql://", "postgresql+asyncpg://") + elif url.startswith("sqlite://"): + return url.replace("sqlite://", "sqlite+aiosqlite://") + return url + + +# Create async engine with optimized settings +def create_async_production_engine() -> AsyncEngine: + """Create an async database engine with production settings.""" + async_url = get_async_database_url(settings.database_url) + + # Base engine config + engine_config = { + "pool_size": settings.db_pool_size, + "max_overflow": settings.db_max_overflow, + "pool_timeout": settings.db_pool_timeout, + "pool_recycle": settings.db_pool_recycle, + "pool_pre_ping": True, + "echo": settings.sql_echo, + "echo_pool": settings.sql_echo_pool, + } + + # Add PostgreSQL-specific connect_args + if "postgresql" in async_url: + engine_config["connect_args"] = { + "server_settings": { + "application_name": "eventspace", + "timezone": "UTC", + }, + # asyncpg-specific settings + "command_timeout": 60, + "timeout": 10, + } + + return create_async_engine(async_url, **engine_config) + + +# Create async engine and session factory +engine = create_async_production_engine() +SessionLocal = async_sessionmaker( + engine, + class_=AsyncSession, autocommit=False, autoflush=False, - bind=engine, - expire_on_commit=False # Prevent unnecessary queries after commit + expire_on_commit=False, # Prevent unnecessary queries after commit ) -# FastAPI dependency -def get_db() -> Generator[Session, None, None]: + +# FastAPI dependency for async database sessions +async def get_db() -> AsyncGenerator[AsyncSession, None]: """ - FastAPI dependency that provides a database session. + FastAPI dependency that provides an async database session. Automatically closes the session after the request completes. + + Usage: + @router.get("/users") + async def get_users(db: AsyncSession = Depends(get_db)): + result = await db.execute(select(User)) + return result.scalars().all() """ - db = SessionLocal() - try: - yield db - finally: - db.close() + async with SessionLocal() as session: + try: + yield session + finally: + await session.close() -@contextmanager -def transaction_scope() -> Generator[Session, None, None]: +@asynccontextmanager +async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]: """ - Provide a transactional scope for database operations. + Provide an async transactional scope for database operations. Automatically commits on success or rolls back on exception. Useful for grouping multiple operations in a single transaction. Usage: - with transaction_scope() as db: - user = user_crud.create(db, obj_in=user_create) - profile = profile_crud.create(db, obj_in=profile_create) + async with async_transaction_scope() as db: + user = await user_crud.create(db, obj_in=user_create) + profile = await profile_crud.create(db, obj_in=profile_create) # Both operations committed together """ - db = SessionLocal() - try: - yield db - db.commit() - logger.debug("Transaction committed successfully") - except Exception as e: - db.rollback() - logger.error(f"Transaction failed, rolling back: {str(e)}") - raise - finally: - db.close() + async with SessionLocal() as session: + try: + yield session + await session.commit() + logger.debug("Async transaction committed successfully") + except Exception as e: + await session.rollback() + logger.error(f"Async transaction failed, rolling back: {str(e)}") + raise + finally: + await session.close() -def check_database_health() -> bool: +async def check_async_database_health() -> bool: """ - Check if database connection is healthy. + Check if async database connection is healthy. Returns True if connection is successful, False otherwise. """ try: - with transaction_scope() as db: - db.execute(text("SELECT 1")) + async with async_transaction_scope() as db: + await db.execute(text("SELECT 1")) return True except Exception as e: - logger.error(f"Database health check failed: {str(e)}") - return False \ No newline at end of file + logger.error(f"Async database health check failed: {str(e)}") + return False + + +# Alias for consistency with main.py +check_database_health = check_async_database_health + + +async def init_async_db() -> None: + """ + Initialize async database tables. + + This creates all tables defined in the models. + Should only be used in development or testing. + In production, use Alembic migrations. + """ + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("Async database tables created") + + +async def close_async_db() -> None: + """ + Close all async database connections. + + Should be called during application shutdown. + """ + await engine.dispose() + logger.info("Async database connections closed") diff --git a/backend/app/core/database_async.py b/backend/app/core/database_async.py deleted file mode 100755 index bf0dfbd..0000000 --- a/backend/app/core/database_async.py +++ /dev/null @@ -1,186 +0,0 @@ -# app/core/database_async.py -""" -Async database configuration using SQLAlchemy 2.0 and asyncpg. - -This module provides async database connectivity with proper connection pooling -and session management for FastAPI endpoints. -""" -import logging -from contextlib import asynccontextmanager -from typing import AsyncGenerator - -from sqlalchemy import text -from sqlalchemy.dialects.postgresql import JSONB, UUID -from sqlalchemy.ext.asyncio import ( - AsyncSession, - AsyncEngine, - create_async_engine, - async_sessionmaker, -) -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm import DeclarativeBase - -from app.core.config import settings - -# Configure logging -logger = logging.getLogger(__name__) - - -# SQLite compatibility for testing -@compiles(JSONB, 'sqlite') -def compile_jsonb_sqlite(type_, compiler, **kw): - return "TEXT" - - -@compiles(UUID, 'sqlite') -def compile_uuid_sqlite(type_, compiler, **kw): - return "TEXT" - - -# Declarative base for models (SQLAlchemy 2.0 style) -class Base(DeclarativeBase): - """Base class for all database models.""" - pass - - -def get_async_database_url(url: str) -> str: - """ - Convert sync database URL to async URL. - - postgresql:// -> postgresql+asyncpg:// - sqlite:// -> sqlite+aiosqlite:// - """ - if url.startswith("postgresql://"): - return url.replace("postgresql://", "postgresql+asyncpg://") - elif url.startswith("sqlite://"): - return url.replace("sqlite://", "sqlite+aiosqlite://") - return url - - -# Create async engine with optimized settings -def create_async_production_engine() -> AsyncEngine: - """Create an async database engine with production settings.""" - async_url = get_async_database_url(settings.database_url) - - # Base engine config - engine_config = { - "pool_size": settings.db_pool_size, - "max_overflow": settings.db_max_overflow, - "pool_timeout": settings.db_pool_timeout, - "pool_recycle": settings.db_pool_recycle, - "pool_pre_ping": True, - "echo": settings.sql_echo, - "echo_pool": settings.sql_echo_pool, - } - - # Add PostgreSQL-specific connect_args - if "postgresql" in async_url: - engine_config["connect_args"] = { - "server_settings": { - "application_name": "eventspace", - "timezone": "UTC", - }, - # asyncpg-specific settings - "command_timeout": 60, - "timeout": 10, - } - - return create_async_engine(async_url, **engine_config) - - -# Create async engine and session factory -async_engine = create_async_production_engine() -AsyncSessionLocal = async_sessionmaker( - async_engine, - class_=AsyncSession, - autocommit=False, - autoflush=False, - expire_on_commit=False, # Prevent unnecessary queries after commit -) - - -# FastAPI dependency for async database sessions -async def get_async_db() -> AsyncGenerator[AsyncSession, None]: - """ - FastAPI dependency that provides an async database session. - Automatically closes the session after the request completes. - - Usage: - @router.get("/users") - async def get_users(db: AsyncSession = Depends(get_async_db)): - result = await db.execute(select(User)) - return result.scalars().all() - """ - async with AsyncSessionLocal() as session: - try: - yield session - finally: - await session.close() - - -@asynccontextmanager -async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]: - """ - Provide an async transactional scope for database operations. - - Automatically commits on success or rolls back on exception. - Useful for grouping multiple operations in a single transaction. - - Usage: - async with async_transaction_scope() as db: - user = await user_crud.create(db, obj_in=user_create) - profile = await profile_crud.create(db, obj_in=profile_create) - # Both operations committed together - """ - async with AsyncSessionLocal() as session: - try: - yield session - await session.commit() - logger.debug("Async transaction committed successfully") - except Exception as e: - await session.rollback() - logger.error(f"Async transaction failed, rolling back: {str(e)}") - raise - finally: - await session.close() - - -async def check_async_database_health() -> bool: - """ - Check if async database connection is healthy. - Returns True if connection is successful, False otherwise. - """ - try: - async with async_transaction_scope() as db: - await db.execute(text("SELECT 1")) - return True - except Exception as e: - logger.error(f"Async database health check failed: {str(e)}") - return False - - -# Alias for consistency with main.py -check_database_health = check_async_database_health - - -async def init_async_db() -> None: - """ - Initialize async database tables. - - This creates all tables defined in the models. - Should only be used in development or testing. - In production, use Alembic migrations. - """ - async with async_engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - logger.info("Async database tables created") - - -async def close_async_db() -> None: - """ - Close all async database connections. - - Should be called during application shutdown. - """ - await async_engine.dispose() - logger.info("Async database connections closed") diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py old mode 100644 new mode 100755 index bf8728e..eb9d0ee --- a/backend/app/crud/base.py +++ b/backend/app/crud/base.py @@ -1,13 +1,19 @@ +# app/crud/base_async.py +""" +Async CRUD operations base class using SQLAlchemy 2.0 async patterns. + +Provides reusable create, read, update, and delete operations for all models. +""" import logging import uuid -from datetime import datetime, timezone from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple from fastapi.encoders import jsonable_encoder from pydantic import BaseModel -from sqlalchemy import asc, desc +from sqlalchemy import func, select from sqlalchemy.exc import IntegrityError, OperationalError, DataError -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Load from app.core.database import Base @@ -19,17 +25,40 @@ UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + """Async CRUD operations for a model.""" + def __init__(self, model: Type[ModelType]): """ - CRUD object with default methods to Create, Read, Update, Delete (CRUD). + CRUD object with default async methods to Create, Read, Update, Delete. Parameters: model: A SQLAlchemy model class """ self.model = model - def get(self, db: Session, id: str) -> Optional[ModelType]: - """Get a single record by ID with UUID validation.""" + async def get( + self, + db: AsyncSession, + id: str, + options: Optional[List[Load]] = None + ) -> Optional[ModelType]: + """ + Get a single record by ID with UUID validation and optional eager loading. + + Args: + db: Database session + id: Record UUID + options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload) + for eager loading relationships to prevent N+1 queries + + Returns: + Model instance or None if not found + + Example: + # Eager load user relationship + from sqlalchemy.orm import joinedload + session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)]) + """ # Validate UUID format and convert to UUID object if string try: if isinstance(id, uuid.UUID): @@ -41,15 +70,39 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return None try: - return db.query(self.model).filter(self.model.id == uuid_obj).first() + query = select(self.model).where(self.model.id == uuid_obj) + + # Apply eager loading options if provided + if options: + for option in options: + query = query.options(option) + + result = await db.execute(query) + return result.scalar_one_or_none() except Exception as e: logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") raise - def get_multi( - self, db: Session, *, skip: int = 0, limit: int = 100 + async def get_multi( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + options: Optional[List[Load]] = None ) -> List[ModelType]: - """Get multiple records with pagination validation.""" + """ + Get multiple records with pagination validation and optional eager loading. + + Args: + db: Database session + skip: Number of records to skip + limit: Maximum number of records to return + options: Optional list of SQLAlchemy load options for eager loading + + Returns: + List of model instances + """ # Validate pagination parameters if skip < 0: raise ValueError("skip must be non-negative") @@ -59,22 +112,30 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): raise ValueError("Maximum limit is 1000") try: - return db.query(self.model).offset(skip).limit(limit).all() + query = select(self.model).offset(skip).limit(limit) + + # Apply eager loading options if provided + if options: + for option in options: + query = query.options(option) + + result = await db.execute(query) + return list(result.scalars().all()) except Exception as e: logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}") raise - def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: + async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: """Create a new record with error handling.""" try: obj_in_data = jsonable_encoder(obj_in) db_obj = self.model(**obj_in_data) db.add(db_obj) - db.commit() - db.refresh(db_obj) + await db.commit() + await db.refresh(db_obj) return db_obj except IntegrityError as e: - db.rollback() + await db.rollback() error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") @@ -82,20 +143,20 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except (OperationalError, DataError) as e: - db.rollback() + await db.rollback() logger.error(f"Database error creating {self.model.__name__}: {str(e)}") raise ValueError(f"Database operation failed: {str(e)}") except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True) raise - def update( - self, - db: Session, - *, - db_obj: ModelType, - obj_in: Union[UpdateSchemaType, Dict[str, Any]] + async def update( + self, + db: AsyncSession, + *, + db_obj: ModelType, + obj_in: Union[UpdateSchemaType, Dict[str, Any]] ) -> ModelType: """Update a record with error handling.""" try: @@ -104,15 +165,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): update_data = obj_in else: update_data = obj_in.model_dump(exclude_unset=True) + for field in obj_data: if field in update_data: setattr(db_obj, field, update_data[field]) + db.add(db_obj) - db.commit() - db.refresh(db_obj) + await db.commit() + await db.refresh(db_obj) return db_obj except IntegrityError as e: - db.rollback() + await db.rollback() error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") @@ -120,15 +183,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except (OperationalError, DataError) as e: - db.rollback() + await db.rollback() logger.error(f"Database error updating {self.model.__name__}: {str(e)}") raise ValueError(f"Database operation failed: {str(e)}") except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True) raise - def remove(self, db: Session, *, id: str) -> Optional[ModelType]: + async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: """Delete a record with error handling and null check.""" # Validate UUID format and convert to UUID object if string try: @@ -141,27 +204,31 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return None try: - obj = db.query(self.model).filter(self.model.id == uuid_obj).first() + result = await db.execute( + select(self.model).where(self.model.id == uuid_obj) + ) + obj = result.scalar_one_or_none() + if obj is None: logger.warning(f"{self.model.__name__} with id {id} not found for deletion") return None - db.delete(obj) - db.commit() + await db.delete(obj) + await db.commit() return obj except IntegrityError as e: - db.rollback() + await db.rollback() error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records") except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) raise - def get_multi_with_total( + async def get_multi_with_total( self, - db: Session, + db: AsyncSession, *, skip: int = 0, limit: int = 100, @@ -193,43 +260,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): try: # Build base query - query = db.query(self.model) + query = select(self.model) # Exclude soft-deleted records by default if hasattr(self.model, 'deleted_at'): - query = query.filter(self.model.deleted_at.is_(None)) + query = query.where(self.model.deleted_at.is_(None)) # Apply filters if filters: for field, value in filters.items(): if hasattr(self.model, field) and value is not None: - query = query.filter(getattr(self.model, field) == value) + query = query.where(getattr(self.model, field) == value) # Get total count (before pagination) - total = query.count() + count_query = select(func.count()).select_from(query.alias()) + count_result = await db.execute(count_query) + total = count_result.scalar_one() # Apply sorting if sort_by and hasattr(self.model, sort_by): sort_column = getattr(self.model, sort_by) if sort_order.lower() == "desc": - query = query.order_by(desc(sort_column)) + query = query.order_by(sort_column.desc()) else: - query = query.order_by(asc(sort_column)) + query = query.order_by(sort_column.asc()) # Apply pagination - items = query.offset(skip).limit(limit).all() + query = query.offset(skip).limit(limit) + items_result = await db.execute(query) + items = list(items_result.scalars().all()) return items, total except Exception as e: logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") raise - def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]: + async def count(self, db: AsyncSession) -> int: + """Get total count of records.""" + try: + result = await db.execute(select(func.count(self.model.id))) + return result.scalar_one() + except Exception as e: + logger.error(f"Error counting {self.model.__name__} records: {str(e)}") + raise + + async def exists(self, db: AsyncSession, id: str) -> bool: + """Check if a record exists by ID.""" + obj = await self.get(db, id=id) + return obj is not None + + async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: """ Soft delete a record by setting deleted_at timestamp. Only works if the model has a 'deleted_at' column. """ + from datetime import datetime, timezone + # Validate UUID format and convert to UUID object if string try: if isinstance(id, uuid.UUID): @@ -241,7 +328,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return None try: - obj = db.query(self.model).filter(self.model.id == uuid_obj).first() + result = await db.execute( + select(self.model).where(self.model.id == uuid_obj) + ) + obj = result.scalar_one_or_none() if obj is None: logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") @@ -255,15 +345,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): # Set deleted_at timestamp obj.deleted_at = datetime.now(timezone.utc) db.add(obj) - db.commit() - db.refresh(obj) + await db.commit() + await db.refresh(obj) return obj except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) raise - def restore(self, db: Session, *, id: str) -> Optional[ModelType]: + async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: """ Restore a soft-deleted record by clearing the deleted_at timestamp. @@ -282,10 +372,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): try: # Find the soft-deleted record if hasattr(self.model, 'deleted_at'): - obj = db.query(self.model).filter( - self.model.id == uuid_obj, - self.model.deleted_at.isnot(None) - ).first() + result = await db.execute( + select(self.model).where( + self.model.id == uuid_obj, + self.model.deleted_at.isnot(None) + ) + ) + obj = result.scalar_one_or_none() else: logger.error(f"{self.model.__name__} does not support soft deletes") raise ValueError(f"{self.model.__name__} does not have a deleted_at column") @@ -297,10 +390,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): # Clear deleted_at timestamp obj.deleted_at = None db.add(obj) - db.commit() - db.refresh(obj) + await db.commit() + await db.refresh(obj) return obj except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) - raise \ No newline at end of file + raise diff --git a/backend/app/crud/base_async.py b/backend/app/crud/base_async.py deleted file mode 100755 index c10715f..0000000 --- a/backend/app/crud/base_async.py +++ /dev/null @@ -1,399 +0,0 @@ -# app/crud/base_async.py -""" -Async CRUD operations base class using SQLAlchemy 2.0 async patterns. - -Provides reusable create, read, update, and delete operations for all models. -""" -import logging -import uuid -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple - -from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel -from sqlalchemy import func, select -from sqlalchemy.exc import IntegrityError, OperationalError, DataError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Load - -from app.core.database_async import Base - -logger = logging.getLogger(__name__) - -ModelType = TypeVar("ModelType", bound=Base) -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) - - -class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): - """Async CRUD operations for a model.""" - - def __init__(self, model: Type[ModelType]): - """ - CRUD object with default async methods to Create, Read, Update, Delete. - - Parameters: - model: A SQLAlchemy model class - """ - self.model = model - - async def get( - self, - db: AsyncSession, - id: str, - options: Optional[List[Load]] = None - ) -> Optional[ModelType]: - """ - Get a single record by ID with UUID validation and optional eager loading. - - Args: - db: Database session - id: Record UUID - options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload) - for eager loading relationships to prevent N+1 queries - - Returns: - Model instance or None if not found - - Example: - # Eager load user relationship - from sqlalchemy.orm import joinedload - session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)]) - """ - # Validate UUID format and convert to UUID object if string - try: - if isinstance(id, uuid.UUID): - uuid_obj = id - else: - uuid_obj = uuid.UUID(str(id)) - except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format: {id} - {str(e)}") - return None - - try: - query = select(self.model).where(self.model.id == uuid_obj) - - # Apply eager loading options if provided - if options: - for option in options: - query = query.options(option) - - result = await db.execute(query) - return result.scalar_one_or_none() - except Exception as e: - logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") - raise - - async def get_multi( - self, - db: AsyncSession, - *, - skip: int = 0, - limit: int = 100, - options: Optional[List[Load]] = None - ) -> List[ModelType]: - """ - Get multiple records with pagination validation and optional eager loading. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - options: Optional list of SQLAlchemy load options for eager loading - - Returns: - List of model instances - """ - # Validate pagination parameters - if skip < 0: - raise ValueError("skip must be non-negative") - if limit < 0: - raise ValueError("limit must be non-negative") - if limit > 1000: - raise ValueError("Maximum limit is 1000") - - try: - query = select(self.model).offset(skip).limit(limit) - - # Apply eager loading options if provided - if options: - for option in options: - query = query.options(option) - - result = await db.execute(query) - return list(result.scalars().all()) - except Exception as e: - logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}") - raise - - async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: - """Create a new record with error handling.""" - try: - obj_in_data = jsonable_encoder(obj_in) - db_obj = self.model(**obj_in_data) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - return db_obj - except IntegrityError as e: - await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) - if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): - logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") - raise ValueError(f"A {self.model.__name__} with this data already exists") - logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") - except (OperationalError, DataError) as e: - await db.rollback() - logger.error(f"Database error creating {self.model.__name__}: {str(e)}") - raise ValueError(f"Database operation failed: {str(e)}") - except Exception as e: - await db.rollback() - logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True) - raise - - async def update( - self, - db: AsyncSession, - *, - db_obj: ModelType, - obj_in: Union[UpdateSchemaType, Dict[str, Any]] - ) -> ModelType: - """Update a record with error handling.""" - try: - obj_data = jsonable_encoder(db_obj) - if isinstance(obj_in, dict): - update_data = obj_in - else: - update_data = obj_in.model_dump(exclude_unset=True) - - for field in obj_data: - if field in update_data: - setattr(db_obj, field, update_data[field]) - - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - return db_obj - except IntegrityError as e: - await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) - if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): - logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") - raise ValueError(f"A {self.model.__name__} with this data already exists") - logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") - except (OperationalError, DataError) as e: - await db.rollback() - logger.error(f"Database error updating {self.model.__name__}: {str(e)}") - raise ValueError(f"Database operation failed: {str(e)}") - except Exception as e: - await db.rollback() - logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True) - raise - - async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: - """Delete a record with error handling and null check.""" - # Validate UUID format and convert to UUID object if string - try: - if isinstance(id, uuid.UUID): - uuid_obj = id - else: - uuid_obj = uuid.UUID(str(id)) - except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}") - return None - - try: - result = await db.execute( - select(self.model).where(self.model.id == uuid_obj) - ) - obj = result.scalar_one_or_none() - - if obj is None: - logger.warning(f"{self.model.__name__} with id {id} not found for deletion") - return None - - await db.delete(obj) - await db.commit() - return obj - except IntegrityError as e: - await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) - logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") - raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records") - except Exception as e: - await db.rollback() - logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) - raise - - async def get_multi_with_total( - self, - db: AsyncSession, - *, - skip: int = 0, - limit: int = 100, - sort_by: Optional[str] = None, - sort_order: str = "asc", - filters: Optional[Dict[str, Any]] = None - ) -> Tuple[List[ModelType], int]: - """ - Get multiple records with total count, filtering, and sorting. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - sort_by: Field name to sort by (must be a valid model attribute) - sort_order: Sort order ("asc" or "desc") - filters: Dictionary of filters (field_name: value) - - Returns: - Tuple of (items, total_count) - """ - # Validate pagination parameters - if skip < 0: - raise ValueError("skip must be non-negative") - if limit < 0: - raise ValueError("limit must be non-negative") - if limit > 1000: - raise ValueError("Maximum limit is 1000") - - try: - # Build base query - query = select(self.model) - - # Exclude soft-deleted records by default - if hasattr(self.model, 'deleted_at'): - query = query.where(self.model.deleted_at.is_(None)) - - # Apply filters - if filters: - for field, value in filters.items(): - if hasattr(self.model, field) and value is not None: - query = query.where(getattr(self.model, field) == value) - - # Get total count (before pagination) - count_query = select(func.count()).select_from(query.alias()) - count_result = await db.execute(count_query) - total = count_result.scalar_one() - - # Apply sorting - if sort_by and hasattr(self.model, sort_by): - sort_column = getattr(self.model, sort_by) - if sort_order.lower() == "desc": - query = query.order_by(sort_column.desc()) - else: - query = query.order_by(sort_column.asc()) - - # Apply pagination - query = query.offset(skip).limit(limit) - items_result = await db.execute(query) - items = list(items_result.scalars().all()) - - return items, total - except Exception as e: - logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") - raise - - async def count(self, db: AsyncSession) -> int: - """Get total count of records.""" - try: - result = await db.execute(select(func.count(self.model.id))) - return result.scalar_one() - except Exception as e: - logger.error(f"Error counting {self.model.__name__} records: {str(e)}") - raise - - async def exists(self, db: AsyncSession, id: str) -> bool: - """Check if a record exists by ID.""" - obj = await self.get(db, id=id) - return obj is not None - - async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: - """ - Soft delete a record by setting deleted_at timestamp. - - Only works if the model has a 'deleted_at' column. - """ - from datetime import datetime, timezone - - # Validate UUID format and convert to UUID object if string - try: - if isinstance(id, uuid.UUID): - uuid_obj = id - else: - uuid_obj = uuid.UUID(str(id)) - except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}") - return None - - try: - result = await db.execute( - select(self.model).where(self.model.id == uuid_obj) - ) - obj = result.scalar_one_or_none() - - if obj is None: - logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") - return None - - # Check if model supports soft deletes - if not hasattr(self.model, 'deleted_at'): - logger.error(f"{self.model.__name__} does not support soft deletes") - raise ValueError(f"{self.model.__name__} does not have a deleted_at column") - - # Set deleted_at timestamp - obj.deleted_at = datetime.now(timezone.utc) - db.add(obj) - await db.commit() - await db.refresh(obj) - return obj - except Exception as e: - await db.rollback() - logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) - raise - - async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: - """ - Restore a soft-deleted record by clearing the deleted_at timestamp. - - Only works if the model has a 'deleted_at' column. - """ - # Validate UUID format - try: - if isinstance(id, uuid.UUID): - uuid_obj = id - else: - uuid_obj = uuid.UUID(str(id)) - except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}") - return None - - try: - # Find the soft-deleted record - if hasattr(self.model, 'deleted_at'): - result = await db.execute( - select(self.model).where( - self.model.id == uuid_obj, - self.model.deleted_at.isnot(None) - ) - ) - obj = result.scalar_one_or_none() - else: - logger.error(f"{self.model.__name__} does not support soft deletes") - raise ValueError(f"{self.model.__name__} does not have a deleted_at column") - - if obj is None: - logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration") - return None - - # Clear deleted_at timestamp - obj.deleted_at = None - db.add(obj) - await db.commit() - await db.refresh(obj) - return obj - except Exception as e: - await db.rollback() - logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) - raise diff --git a/backend/app/crud/organization.py b/backend/app/crud/organization.py old mode 100644 new mode 100755 index 0a86b31..8e7304e --- a/backend/app/crud/organization.py +++ b/backend/app/crud/organization.py @@ -1,11 +1,12 @@ -# app/crud/organization.py +# app/crud/organization_async.py +"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" import logging from typing import Optional, List, Dict, Any from uuid import UUID -from sqlalchemy import func, or_, and_ +from sqlalchemy import func, or_, and_, select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.crud.base import CRUDBase from app.models.organization import Organization @@ -13,20 +14,27 @@ from app.models.user import User from app.models.user_organization import UserOrganization, OrganizationRole from app.schemas.organizations import ( OrganizationCreate, - OrganizationUpdate + OrganizationUpdate, ) logger = logging.getLogger(__name__) class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): - """CRUD operations for Organization model.""" + """Async CRUD operations for Organization model.""" - def get_by_slug(self, db: Session, *, slug: str) -> Optional[Organization]: + async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]: """Get organization by slug.""" - return db.query(Organization).filter(Organization.slug == slug).first() + try: + result = await db.execute( + select(Organization).where(Organization.slug == slug) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error getting organization by slug {slug}: {str(e)}") + raise - def create(self, db: Session, *, obj_in: OrganizationCreate) -> Organization: + async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization: """Create a new organization with error handling.""" try: db_obj = Organization( @@ -37,11 +45,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp settings=obj_in.settings or {} ) db.add(db_obj) - db.commit() - db.refresh(db_obj) + await db.commit() + await db.refresh(db_obj) return db_obj except IntegrityError as e: - db.rollback() + await db.rollback() error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) if "slug" in error_msg.lower(): logger.warning(f"Duplicate slug attempted: {obj_in.slug}") @@ -49,13 +57,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp logger.error(f"Integrity error creating organization: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True) raise - def get_multi_with_filters( + async def get_multi_with_filters( self, - db: Session, + db: AsyncSession, *, skip: int = 0, limit: int = 100, @@ -70,47 +78,139 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp Returns: Tuple of (organizations list, total count) """ - query = db.query(Organization) + try: + query = select(Organization) - # Apply filters - if is_active is not None: - query = query.filter(Organization.is_active == is_active) + # Apply filters + if is_active is not None: + query = query.where(Organization.is_active == is_active) - if search: - search_filter = or_( - Organization.name.ilike(f"%{search}%"), - Organization.slug.ilike(f"%{search}%"), - Organization.description.ilike(f"%{search}%") - ) - query = query.filter(search_filter) + if search: + search_filter = or_( + Organization.name.ilike(f"%{search}%"), + Organization.slug.ilike(f"%{search}%"), + Organization.description.ilike(f"%{search}%") + ) + query = query.where(search_filter) - # Get total count before pagination - total = query.count() + # Get total count before pagination + count_query = select(func.count()).select_from(query.alias()) + count_result = await db.execute(count_query) + total = count_result.scalar_one() - # Apply sorting - sort_column = getattr(Organization, sort_by, Organization.created_at) - if sort_order == "desc": - query = query.order_by(sort_column.desc()) - else: - query = query.order_by(sort_column.asc()) + # Apply sorting + sort_column = getattr(Organization, sort_by, Organization.created_at) + if sort_order == "desc": + query = query.order_by(sort_column.desc()) + else: + query = query.order_by(sort_column.asc()) - # Apply pagination - organizations = query.offset(skip).limit(limit).all() + # Apply pagination + query = query.offset(skip).limit(limit) + result = await db.execute(query) + organizations = list(result.scalars().all()) - return organizations, total + return organizations, total + except Exception as e: + logger.error(f"Error getting organizations with filters: {str(e)}") + raise - def get_member_count(self, db: Session, *, organization_id: UUID) -> int: + async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: """Get the count of active members in an organization.""" - return db.query(func.count(UserOrganization.user_id)).filter( - and_( - UserOrganization.organization_id == organization_id, - UserOrganization.is_active == True + try: + result = await db.execute( + select(func.count(UserOrganization.user_id)).where( + and_( + UserOrganization.organization_id == organization_id, + UserOrganization.is_active == True + ) + ) ) - ).scalar() or 0 + return result.scalar_one() or 0 + except Exception as e: + logger.error(f"Error getting member count for organization {organization_id}: {str(e)}") + raise - def add_user( + async def get_multi_with_member_counts( self, - db: Session, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + is_active: Optional[bool] = None, + search: Optional[str] = None + ) -> tuple[List[Dict[str, Any]], int]: + """ + Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY. + This eliminates the N+1 query problem. + + Returns: + Tuple of (list of dicts with org and member_count, total count) + """ + try: + # Build base query with LEFT JOIN and GROUP BY + query = ( + select( + Organization, + func.count( + func.distinct( + and_( + UserOrganization.is_active == True, + UserOrganization.user_id + ).self_group() + ) + ).label('member_count') + ) + .outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id) + .group_by(Organization.id) + ) + + # Apply filters + if is_active is not None: + query = query.where(Organization.is_active == is_active) + + if search: + search_filter = or_( + Organization.name.ilike(f"%{search}%"), + Organization.slug.ilike(f"%{search}%"), + Organization.description.ilike(f"%{search}%") + ) + query = query.where(search_filter) + + # Get total count + count_query = select(func.count(Organization.id)) + if is_active is not None: + count_query = count_query.where(Organization.is_active == is_active) + if search: + count_query = count_query.where(search_filter) + + count_result = await db.execute(count_query) + total = count_result.scalar_one() + + # Apply pagination and ordering + query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) + + result = await db.execute(query) + rows = result.all() + + # Convert to list of dicts + orgs_with_counts = [ + { + 'organization': org, + 'member_count': member_count + } + for org, member_count in rows + ] + + return orgs_with_counts, total + + except Exception as e: + logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True) + raise + + async def add_user( + self, + db: AsyncSession, *, organization_id: UUID, user_id: UUID, @@ -120,12 +220,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp """Add a user to an organization with a specific role.""" try: # Check if relationship already exists - existing = db.query(UserOrganization).filter( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id + ) ) - ).first() + ) + existing = result.scalar_one_or_none() if existing: # Reactivate if inactive, or raise error if already active @@ -133,8 +236,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp existing.is_active = True existing.role = role existing.custom_permissions = custom_permissions - db.commit() - db.refresh(existing) + await db.commit() + await db.refresh(existing) return existing else: raise ValueError("User is already a member of this organization") @@ -148,48 +251,51 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp custom_permissions=custom_permissions ) db.add(user_org) - db.commit() - db.refresh(user_org) + await db.commit() + await db.refresh(user_org) return user_org except IntegrityError as e: - db.rollback() + await db.rollback() logger.error(f"Integrity error adding user to organization: {str(e)}") raise ValueError("Failed to add user to organization") except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error adding user to organization: {str(e)}", exc_info=True) raise - def remove_user( + async def remove_user( self, - db: Session, + db: AsyncSession, *, organization_id: UUID, user_id: UUID ) -> bool: """Remove a user from an organization (soft delete).""" try: - user_org = db.query(UserOrganization).filter( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id + ) ) - ).first() + ) + user_org = result.scalar_one_or_none() if not user_org: return False user_org.is_active = False - db.commit() + await db.commit() return True except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error removing user from organization: {str(e)}", exc_info=True) raise - def update_user_role( + async def update_user_role( self, - db: Session, + db: AsyncSession, *, organization_id: UUID, user_id: UUID, @@ -198,12 +304,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp ) -> Optional[UserOrganization]: """Update a user's role in an organization.""" try: - user_org = db.query(UserOrganization).filter( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id + ) ) - ).first() + ) + user_org = result.scalar_one_or_none() if not user_org: return None @@ -211,17 +320,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp user_org.role = role if custom_permissions is not None: user_org.custom_permissions = custom_permissions - db.commit() - db.refresh(user_org) + await db.commit() + await db.refresh(user_org) return user_org except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error updating user role: {str(e)}", exc_info=True) raise - def get_organization_members( + async def get_organization_members( self, - db: Session, + db: AsyncSession, *, organization_id: UUID, skip: int = 0, @@ -234,86 +343,175 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp Returns: Tuple of (members list with user details, total count) """ - query = db.query(UserOrganization, User).join( - User, UserOrganization.user_id == User.id - ).filter(UserOrganization.organization_id == organization_id) + try: + # Build query with join + query = ( + select(UserOrganization, User) + .join(User, UserOrganization.user_id == User.id) + .where(UserOrganization.organization_id == organization_id) + ) - if is_active is not None: - query = query.filter(UserOrganization.is_active == is_active) + if is_active is not None: + query = query.where(UserOrganization.is_active == is_active) - total = query.count() + # Get total count + count_query = select(func.count()).select_from( + select(UserOrganization) + .where(UserOrganization.organization_id == organization_id) + .where(UserOrganization.is_active == is_active if is_active is not None else True) + .alias() + ) + count_result = await db.execute(count_query) + total = count_result.scalar_one() - results = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit).all() + # Apply ordering and pagination + query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit) + result = await db.execute(query) + results = result.all() - members = [] - for user_org, user in results: - members.append({ - "user_id": user.id, - "email": user.email, - "first_name": user.first_name, - "last_name": user.last_name, - "role": user_org.role, - "is_active": user_org.is_active, - "joined_at": user_org.created_at - }) + members = [] + for user_org, user in results: + members.append({ + "user_id": user.id, + "email": user.email, + "first_name": user.first_name, + "last_name": user.last_name, + "role": user_org.role, + "is_active": user_org.is_active, + "joined_at": user_org.created_at + }) - return members, total + return members, total + except Exception as e: + logger.error(f"Error getting organization members: {str(e)}") + raise - def get_user_organizations( + async def get_user_organizations( self, - db: Session, + db: AsyncSession, *, user_id: UUID, is_active: bool = True ) -> List[Organization]: """Get all organizations a user belongs to.""" - query = db.query(Organization).join( - UserOrganization, Organization.id == UserOrganization.organization_id - ).filter(UserOrganization.user_id == user_id) + try: + query = ( + select(Organization) + .join(UserOrganization, Organization.id == UserOrganization.organization_id) + .where(UserOrganization.user_id == user_id) + ) - if is_active is not None: - query = query.filter(UserOrganization.is_active == is_active) + if is_active is not None: + query = query.where(UserOrganization.is_active == is_active) - return query.all() + result = await db.execute(query) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"Error getting user organizations: {str(e)}") + raise - def get_user_role_in_org( + async def get_user_organizations_with_details( self, - db: Session, + db: AsyncSession, + *, + user_id: UUID, + is_active: bool = True + ) -> List[Dict[str, Any]]: + """ + Get user's organizations with role and member count in SINGLE QUERY. + Eliminates N+1 problem by using subquery for member counts. + + Returns: + List of dicts with organization, role, and member_count + """ + try: + # Subquery to get member counts for each organization + member_count_subq = ( + select( + UserOrganization.organization_id, + func.count(UserOrganization.user_id).label('member_count') + ) + .where(UserOrganization.is_active == True) + .group_by(UserOrganization.organization_id) + .subquery() + ) + + # Main query with JOIN to get org, role, and member count + query = ( + select( + Organization, + UserOrganization.role, + func.coalesce(member_count_subq.c.member_count, 0).label('member_count') + ) + .join(UserOrganization, Organization.id == UserOrganization.organization_id) + .outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id) + .where(UserOrganization.user_id == user_id) + ) + + if is_active is not None: + query = query.where(UserOrganization.is_active == is_active) + + result = await db.execute(query) + rows = result.all() + + return [ + { + 'organization': org, + 'role': role, + 'member_count': member_count + } + for org, role, member_count in rows + ] + + except Exception as e: + logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True) + raise + + async def get_user_role_in_org( + self, + db: AsyncSession, *, user_id: UUID, organization_id: UUID ) -> Optional[OrganizationRole]: """Get a user's role in a specific organization.""" - user_org = db.query(UserOrganization).filter( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id, - UserOrganization.is_active == True + try: + result = await db.execute( + select(UserOrganization).where( + and_( + UserOrganization.user_id == user_id, + UserOrganization.organization_id == organization_id, + UserOrganization.is_active == True + ) + ) ) - ).first() + user_org = result.scalar_one_or_none() - return user_org.role if user_org else None + return user_org.role if user_org else None + except Exception as e: + logger.error(f"Error getting user role in org: {str(e)}") + raise - def is_user_org_owner( + async def is_user_org_owner( self, - db: Session, + db: AsyncSession, *, user_id: UUID, organization_id: UUID ) -> bool: """Check if a user is an owner of an organization.""" - role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) + role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) return role == OrganizationRole.OWNER - def is_user_org_admin( + async def is_user_org_admin( self, - db: Session, + db: AsyncSession, *, user_id: UUID, organization_id: UUID ) -> bool: """Check if a user is an owner or admin of an organization.""" - role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) + role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] diff --git a/backend/app/crud/organization_async.py b/backend/app/crud/organization_async.py deleted file mode 100755 index fe8f010..0000000 --- a/backend/app/crud/organization_async.py +++ /dev/null @@ -1,519 +0,0 @@ -# app/crud/organization_async.py -"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" -import logging -from typing import Optional, List, Dict, Any -from uuid import UUID - -from sqlalchemy import func, or_, and_, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession - -from app.crud.base_async import CRUDBaseAsync -from app.models.organization import Organization -from app.models.user import User -from app.models.user_organization import UserOrganization, OrganizationRole -from app.schemas.organizations import ( - OrganizationCreate, - OrganizationUpdate, -) - -logger = logging.getLogger(__name__) - - -class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]): - """Async CRUD operations for Organization model.""" - - async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]: - """Get organization by slug.""" - try: - result = await db.execute( - select(Organization).where(Organization.slug == slug) - ) - return result.scalar_one_or_none() - except Exception as e: - logger.error(f"Error getting organization by slug {slug}: {str(e)}") - raise - - async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization: - """Create a new organization with error handling.""" - try: - db_obj = Organization( - name=obj_in.name, - slug=obj_in.slug, - description=obj_in.description, - is_active=obj_in.is_active, - settings=obj_in.settings or {} - ) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - return db_obj - except IntegrityError as e: - await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) - if "slug" in error_msg.lower(): - logger.warning(f"Duplicate slug attempted: {obj_in.slug}") - raise ValueError(f"Organization with slug '{obj_in.slug}' already exists") - logger.error(f"Integrity error creating organization: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") - except Exception as e: - await db.rollback() - logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True) - raise - - async def get_multi_with_filters( - self, - db: AsyncSession, - *, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None, - sort_by: str = "created_at", - sort_order: str = "desc" - ) -> tuple[List[Organization], int]: - """ - Get multiple organizations with filtering, searching, and sorting. - - Returns: - Tuple of (organizations list, total count) - """ - try: - query = select(Organization) - - # Apply filters - if is_active is not None: - query = query.where(Organization.is_active == is_active) - - if search: - search_filter = or_( - Organization.name.ilike(f"%{search}%"), - Organization.slug.ilike(f"%{search}%"), - Organization.description.ilike(f"%{search}%") - ) - query = query.where(search_filter) - - # Get total count before pagination - count_query = select(func.count()).select_from(query.alias()) - count_result = await db.execute(count_query) - total = count_result.scalar_one() - - # Apply sorting - sort_column = getattr(Organization, sort_by, Organization.created_at) - if sort_order == "desc": - query = query.order_by(sort_column.desc()) - else: - query = query.order_by(sort_column.asc()) - - # Apply pagination - query = query.offset(skip).limit(limit) - result = await db.execute(query) - organizations = list(result.scalars().all()) - - return organizations, total - except Exception as e: - logger.error(f"Error getting organizations with filters: {str(e)}") - raise - - async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: - """Get the count of active members in an organization.""" - try: - result = await db.execute( - select(func.count(UserOrganization.user_id)).where( - and_( - UserOrganization.organization_id == organization_id, - UserOrganization.is_active == True - ) - ) - ) - return result.scalar_one() or 0 - except Exception as e: - logger.error(f"Error getting member count for organization {organization_id}: {str(e)}") - raise - - async def get_multi_with_member_counts( - self, - db: AsyncSession, - *, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> tuple[List[Dict[str, Any]], int]: - """ - Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY. - This eliminates the N+1 query problem. - - Returns: - Tuple of (list of dicts with org and member_count, total count) - """ - try: - # Build base query with LEFT JOIN and GROUP BY - query = ( - select( - Organization, - func.count( - func.distinct( - and_( - UserOrganization.is_active == True, - UserOrganization.user_id - ).self_group() - ) - ).label('member_count') - ) - .outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id) - .group_by(Organization.id) - ) - - # Apply filters - if is_active is not None: - query = query.where(Organization.is_active == is_active) - - if search: - search_filter = or_( - Organization.name.ilike(f"%{search}%"), - Organization.slug.ilike(f"%{search}%"), - Organization.description.ilike(f"%{search}%") - ) - query = query.where(search_filter) - - # Get total count - count_query = select(func.count(Organization.id)) - if is_active is not None: - count_query = count_query.where(Organization.is_active == is_active) - if search: - count_query = count_query.where(search_filter) - - count_result = await db.execute(count_query) - total = count_result.scalar_one() - - # Apply pagination and ordering - query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) - - result = await db.execute(query) - rows = result.all() - - # Convert to list of dicts - orgs_with_counts = [ - { - 'organization': org, - 'member_count': member_count - } - for org, member_count in rows - ] - - return orgs_with_counts, total - - except Exception as e: - logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True) - raise - - async def add_user( - self, - db: AsyncSession, - *, - organization_id: UUID, - user_id: UUID, - role: OrganizationRole = OrganizationRole.MEMBER, - custom_permissions: Optional[str] = None - ) -> UserOrganization: - """Add a user to an organization with a specific role.""" - try: - # Check if relationship already exists - result = await db.execute( - select(UserOrganization).where( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id - ) - ) - ) - existing = result.scalar_one_or_none() - - if existing: - # Reactivate if inactive, or raise error if already active - if not existing.is_active: - existing.is_active = True - existing.role = role - existing.custom_permissions = custom_permissions - await db.commit() - await db.refresh(existing) - return existing - else: - raise ValueError("User is already a member of this organization") - - # Create new relationship - user_org = UserOrganization( - user_id=user_id, - organization_id=organization_id, - role=role, - is_active=True, - custom_permissions=custom_permissions - ) - db.add(user_org) - await db.commit() - await db.refresh(user_org) - return user_org - except IntegrityError as e: - await db.rollback() - logger.error(f"Integrity error adding user to organization: {str(e)}") - raise ValueError("Failed to add user to organization") - except Exception as e: - await db.rollback() - logger.error(f"Error adding user to organization: {str(e)}", exc_info=True) - raise - - async def remove_user( - self, - db: AsyncSession, - *, - organization_id: UUID, - user_id: UUID - ) -> bool: - """Remove a user from an organization (soft delete).""" - try: - result = await db.execute( - select(UserOrganization).where( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id - ) - ) - ) - user_org = result.scalar_one_or_none() - - if not user_org: - return False - - user_org.is_active = False - await db.commit() - return True - except Exception as e: - await db.rollback() - logger.error(f"Error removing user from organization: {str(e)}", exc_info=True) - raise - - async def update_user_role( - self, - db: AsyncSession, - *, - organization_id: UUID, - user_id: UUID, - role: OrganizationRole, - custom_permissions: Optional[str] = None - ) -> Optional[UserOrganization]: - """Update a user's role in an organization.""" - try: - result = await db.execute( - select(UserOrganization).where( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id - ) - ) - ) - user_org = result.scalar_one_or_none() - - if not user_org: - return None - - user_org.role = role - if custom_permissions is not None: - user_org.custom_permissions = custom_permissions - await db.commit() - await db.refresh(user_org) - return user_org - except Exception as e: - await db.rollback() - logger.error(f"Error updating user role: {str(e)}", exc_info=True) - raise - - async def get_organization_members( - self, - db: AsyncSession, - *, - organization_id: UUID, - skip: int = 0, - limit: int = 100, - is_active: bool = True - ) -> tuple[List[Dict[str, Any]], int]: - """ - Get members of an organization with user details. - - Returns: - Tuple of (members list with user details, total count) - """ - try: - # Build query with join - query = ( - select(UserOrganization, User) - .join(User, UserOrganization.user_id == User.id) - .where(UserOrganization.organization_id == organization_id) - ) - - if is_active is not None: - query = query.where(UserOrganization.is_active == is_active) - - # Get total count - count_query = select(func.count()).select_from( - select(UserOrganization) - .where(UserOrganization.organization_id == organization_id) - .where(UserOrganization.is_active == is_active if is_active is not None else True) - .alias() - ) - count_result = await db.execute(count_query) - total = count_result.scalar_one() - - # Apply ordering and pagination - query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit) - result = await db.execute(query) - results = result.all() - - members = [] - for user_org, user in results: - members.append({ - "user_id": user.id, - "email": user.email, - "first_name": user.first_name, - "last_name": user.last_name, - "role": user_org.role, - "is_active": user_org.is_active, - "joined_at": user_org.created_at - }) - - return members, total - except Exception as e: - logger.error(f"Error getting organization members: {str(e)}") - raise - - async def get_user_organizations( - self, - db: AsyncSession, - *, - user_id: UUID, - is_active: bool = True - ) -> List[Organization]: - """Get all organizations a user belongs to.""" - try: - query = ( - select(Organization) - .join(UserOrganization, Organization.id == UserOrganization.organization_id) - .where(UserOrganization.user_id == user_id) - ) - - if is_active is not None: - query = query.where(UserOrganization.is_active == is_active) - - result = await db.execute(query) - return list(result.scalars().all()) - except Exception as e: - logger.error(f"Error getting user organizations: {str(e)}") - raise - - async def get_user_organizations_with_details( - self, - db: AsyncSession, - *, - user_id: UUID, - is_active: bool = True - ) -> List[Dict[str, Any]]: - """ - Get user's organizations with role and member count in SINGLE QUERY. - Eliminates N+1 problem by using subquery for member counts. - - Returns: - List of dicts with organization, role, and member_count - """ - try: - # Subquery to get member counts for each organization - member_count_subq = ( - select( - UserOrganization.organization_id, - func.count(UserOrganization.user_id).label('member_count') - ) - .where(UserOrganization.is_active == True) - .group_by(UserOrganization.organization_id) - .subquery() - ) - - # Main query with JOIN to get org, role, and member count - query = ( - select( - Organization, - UserOrganization.role, - func.coalesce(member_count_subq.c.member_count, 0).label('member_count') - ) - .join(UserOrganization, Organization.id == UserOrganization.organization_id) - .outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id) - .where(UserOrganization.user_id == user_id) - ) - - if is_active is not None: - query = query.where(UserOrganization.is_active == is_active) - - result = await db.execute(query) - rows = result.all() - - return [ - { - 'organization': org, - 'role': role, - 'member_count': member_count - } - for org, role, member_count in rows - ] - - except Exception as e: - logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True) - raise - - async def get_user_role_in_org( - self, - db: AsyncSession, - *, - user_id: UUID, - organization_id: UUID - ) -> Optional[OrganizationRole]: - """Get a user's role in a specific organization.""" - try: - result = await db.execute( - select(UserOrganization).where( - and_( - UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id, - UserOrganization.is_active == True - ) - ) - ) - user_org = result.scalar_one_or_none() - - return user_org.role if user_org else None - except Exception as e: - logger.error(f"Error getting user role in org: {str(e)}") - raise - - async def is_user_org_owner( - self, - db: AsyncSession, - *, - user_id: UUID, - organization_id: UUID - ) -> bool: - """Check if a user is an owner of an organization.""" - role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) - return role == OrganizationRole.OWNER - - async def is_user_org_admin( - self, - db: AsyncSession, - *, - user_id: UUID, - organization_id: UUID - ) -> bool: - """Check if a user is an owner or admin of an organization.""" - role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) - return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] - - -# Create a singleton instance for use across the application -organization_async = CRUDOrganizationAsync(Organization) diff --git a/backend/app/crud/session.py b/backend/app/crud/session.py old mode 100644 new mode 100755 index 2b23a04..aa87c25 --- a/backend/app/crud/session.py +++ b/backend/app/crud/session.py @@ -1,13 +1,14 @@ """ -CRUD operations for user sessions. +Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. """ import logging from datetime import datetime, timezone, timedelta from typing import List, Optional from uuid import UUID -from sqlalchemy import and_ -from sqlalchemy.orm import Session +from sqlalchemy import and_, select, update, delete, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from app.crud.base import CRUDBase from app.models.user_session import UserSession @@ -17,9 +18,9 @@ logger = logging.getLogger(__name__) class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): - """CRUD operations for user sessions.""" + """Async CRUD operations for user sessions.""" - def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: + async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: """ Get session by refresh token JTI. @@ -31,14 +32,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): UserSession if found, None otherwise """ try: - return db.query(UserSession).filter( - UserSession.refresh_token_jti == jti - ).first() + result = await db.execute( + select(UserSession).where(UserSession.refresh_token_jti == jti) + ) + return result.scalar_one_or_none() except Exception as e: logger.error(f"Error getting session by JTI {jti}: {str(e)}") raise - def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: + async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: """ Get active session by refresh token JTI. @@ -50,30 +52,35 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): Active UserSession if found, None otherwise """ try: - return db.query(UserSession).filter( - and_( - UserSession.refresh_token_jti == jti, - UserSession.is_active == True + result = await db.execute( + select(UserSession).where( + and_( + UserSession.refresh_token_jti == jti, + UserSession.is_active == True + ) ) - ).first() + ) + return result.scalar_one_or_none() except Exception as e: logger.error(f"Error getting active session by JTI {jti}: {str(e)}") raise - def get_user_sessions( + async def get_user_sessions( self, - db: Session, + db: AsyncSession, *, user_id: str, - active_only: bool = True + active_only: bool = True, + with_user: bool = False ) -> List[UserSession]: """ - Get all sessions for a user. + Get all sessions for a user with optional eager loading. Args: db: Database session user_id: User ID active_only: If True, return only active sessions + with_user: If True, eager load user relationship to prevent N+1 Returns: List of UserSession objects @@ -82,19 +89,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): # Convert user_id string to UUID if needed user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id - query = db.query(UserSession).filter(UserSession.user_id == user_uuid) + query = select(UserSession).where(UserSession.user_id == user_uuid) + + # Add eager loading if requested to prevent N+1 queries + if with_user: + query = query.options(joinedload(UserSession.user)) if active_only: - query = query.filter(UserSession.is_active == True) + query = query.where(UserSession.is_active == True) - return query.order_by(UserSession.last_used_at.desc()).all() + query = query.order_by(UserSession.last_used_at.desc()) + result = await db.execute(query) + return list(result.scalars().all()) except Exception as e: logger.error(f"Error getting sessions for user {user_id}: {str(e)}") raise - def create_session( + async def create_session( self, - db: Session, + db: AsyncSession, *, obj_in: SessionCreate ) -> UserSession: @@ -126,8 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): location_country=obj_in.location_country, ) db.add(db_obj) - db.commit() - db.refresh(db_obj) + await db.commit() + await db.refresh(db_obj) logger.info( f"Session created for user {obj_in.user_id} from {obj_in.device_name} " @@ -136,11 +149,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return db_obj except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error creating session: {str(e)}", exc_info=True) raise ValueError(f"Failed to create session: {str(e)}") - def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]: + async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]: """ Deactivate a session (logout from device). @@ -152,15 +165,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): Deactivated UserSession if found, None otherwise """ try: - session = self.get(db, id=session_id) + session = await self.get(db, id=session_id) if not session: logger.warning(f"Session {session_id} not found for deactivation") return None session.is_active = False db.add(session) - db.commit() - db.refresh(session) + await db.commit() + await db.refresh(session) logger.info( f"Session {session_id} deactivated for user {session.user_id} " @@ -169,13 +182,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return session except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error deactivating session {session_id}: {str(e)}") raise - def deactivate_all_user_sessions( + async def deactivate_all_user_sessions( self, - db: Session, + db: AsyncSession, *, user_id: str ) -> int: @@ -193,26 +206,33 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): # Convert user_id string to UUID if needed user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id - count = db.query(UserSession).filter( - and_( - UserSession.user_id == user_uuid, - UserSession.is_active == True + stmt = ( + update(UserSession) + .where( + and_( + UserSession.user_id == user_uuid, + UserSession.is_active == True + ) ) - ).update({"is_active": False}) + .values(is_active=False) + ) - db.commit() + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount logger.info(f"Deactivated {count} sessions for user {user_id}") return count except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") raise - def update_last_used( + async def update_last_used( self, - db: Session, + db: AsyncSession, *, session: UserSession ) -> UserSession: @@ -229,17 +249,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): try: session.last_used_at = datetime.now(timezone.utc) db.add(session) - db.commit() - db.refresh(session) + await db.commit() + await db.refresh(session) return session except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error updating last_used for session {session.id}: {str(e)}") raise - def update_refresh_token( + async def update_refresh_token( self, - db: Session, + db: AsyncSession, *, session: UserSession, new_jti: str, @@ -264,22 +284,24 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): session.expires_at = new_expires_at session.last_used_at = datetime.now(timezone.utc) db.add(session) - db.commit() - db.refresh(session) + await db.commit() + await db.refresh(session) return session except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") raise - def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int: + async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: """ - Clean up expired sessions. + Clean up expired sessions using optimized bulk DELETE. Deletes sessions that are: - Expired AND inactive - Older than keep_days + Uses single DELETE query instead of N individual deletes for efficiency. + Args: db: Database session keep_days: Keep inactive sessions for this many days (for audit) @@ -289,31 +311,87 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): """ try: cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) + now = datetime.now(timezone.utc) - # Delete sessions that are: - # 1. Expired (expires_at < now) AND inactive - # AND - # 2. Older than keep_days - count = db.query(UserSession).filter( + # Use bulk DELETE with WHERE clause - single query + stmt = delete(UserSession).where( and_( UserSession.is_active == False, - UserSession.expires_at < datetime.now(timezone.utc), + UserSession.expires_at < now, UserSession.created_at < cutoff_date ) - ).delete() + ) - db.commit() + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount if count > 0: - logger.info(f"Cleaned up {count} expired sessions") + logger.info(f"Cleaned up {count} expired sessions using bulk DELETE") return count except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error cleaning up expired sessions: {str(e)}") raise - def get_user_session_count(self, db: Session, *, user_id: str) -> int: + async def cleanup_expired_for_user( + self, + db: AsyncSession, + *, + user_id: str + ) -> int: + """ + Clean up expired and inactive sessions for a specific user. + + Uses single bulk DELETE query for efficiency instead of N individual deletes. + + Args: + db: Database session + user_id: User ID to cleanup sessions for + + Returns: + Number of sessions deleted + """ + try: + # Validate UUID + try: + uuid_obj = uuid.UUID(user_id) + except (ValueError, AttributeError): + logger.error(f"Invalid UUID format: {user_id}") + raise ValueError(f"Invalid user ID format: {user_id}") + + now = datetime.now(timezone.utc) + + # Use bulk DELETE with WHERE clause - single query + stmt = delete(UserSession).where( + and_( + UserSession.user_id == uuid_obj, + UserSession.is_active == False, + UserSession.expires_at < now + ) + ) + + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount + + if count > 0: + logger.info( + f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE" + ) + + return count + except Exception as e: + await db.rollback() + logger.error( + f"Error cleaning up expired sessions for user {user_id}: {str(e)}" + ) + raise + + async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: """ Get count of active sessions for a user. @@ -325,12 +403,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): Number of active sessions """ try: - return db.query(UserSession).filter( - and_( - UserSession.user_id == user_id, - UserSession.is_active == True + # Convert user_id string to UUID if needed + user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id + + result = await db.execute( + select(func.count(UserSession.id)).where( + and_( + UserSession.user_id == user_uuid, + UserSession.is_active == True + ) ) - ).count() + ) + return result.scalar_one() except Exception as e: logger.error(f"Error counting sessions for user {user_id}: {str(e)}") raise diff --git a/backend/app/crud/session_async.py b/backend/app/crud/session_async.py deleted file mode 100755 index 53eb58c..0000000 --- a/backend/app/crud/session_async.py +++ /dev/null @@ -1,424 +0,0 @@ -""" -Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. -""" -import logging -from datetime import datetime, timezone, timedelta -from typing import List, Optional -from uuid import UUID - -from sqlalchemy import and_, select, update, delete, func -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload - -from app.crud.base_async import CRUDBaseAsync -from app.models.user_session import UserSession -from app.schemas.sessions import SessionCreate, SessionUpdate - -logger = logging.getLogger(__name__) - - -class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]): - """Async CRUD operations for user sessions.""" - - async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: - """ - Get session by refresh token JTI. - - Args: - db: Database session - jti: Refresh token JWT ID - - Returns: - UserSession if found, None otherwise - """ - try: - result = await db.execute( - select(UserSession).where(UserSession.refresh_token_jti == jti) - ) - return result.scalar_one_or_none() - except Exception as e: - logger.error(f"Error getting session by JTI {jti}: {str(e)}") - raise - - async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: - """ - Get active session by refresh token JTI. - - Args: - db: Database session - jti: Refresh token JWT ID - - Returns: - Active UserSession if found, None otherwise - """ - try: - result = await db.execute( - select(UserSession).where( - and_( - UserSession.refresh_token_jti == jti, - UserSession.is_active == True - ) - ) - ) - return result.scalar_one_or_none() - except Exception as e: - logger.error(f"Error getting active session by JTI {jti}: {str(e)}") - raise - - async def get_user_sessions( - self, - db: AsyncSession, - *, - user_id: str, - active_only: bool = True, - with_user: bool = False - ) -> List[UserSession]: - """ - Get all sessions for a user with optional eager loading. - - Args: - db: Database session - user_id: User ID - active_only: If True, return only active sessions - with_user: If True, eager load user relationship to prevent N+1 - - Returns: - List of UserSession objects - """ - try: - # Convert user_id string to UUID if needed - user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id - - query = select(UserSession).where(UserSession.user_id == user_uuid) - - # Add eager loading if requested to prevent N+1 queries - if with_user: - query = query.options(joinedload(UserSession.user)) - - if active_only: - query = query.where(UserSession.is_active == True) - - query = query.order_by(UserSession.last_used_at.desc()) - result = await db.execute(query) - return list(result.scalars().all()) - except Exception as e: - logger.error(f"Error getting sessions for user {user_id}: {str(e)}") - raise - - async def create_session( - self, - db: AsyncSession, - *, - obj_in: SessionCreate - ) -> UserSession: - """ - Create a new user session. - - Args: - db: Database session - obj_in: SessionCreate schema with session data - - Returns: - Created UserSession - - Raises: - ValueError: If session creation fails - """ - try: - db_obj = UserSession( - user_id=obj_in.user_id, - refresh_token_jti=obj_in.refresh_token_jti, - device_name=obj_in.device_name, - device_id=obj_in.device_id, - ip_address=obj_in.ip_address, - user_agent=obj_in.user_agent, - last_used_at=obj_in.last_used_at, - expires_at=obj_in.expires_at, - is_active=True, - location_city=obj_in.location_city, - location_country=obj_in.location_country, - ) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - - logger.info( - f"Session created for user {obj_in.user_id} from {obj_in.device_name} " - f"(IP: {obj_in.ip_address})" - ) - - return db_obj - except Exception as e: - await db.rollback() - logger.error(f"Error creating session: {str(e)}", exc_info=True) - raise ValueError(f"Failed to create session: {str(e)}") - - async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]: - """ - Deactivate a session (logout from device). - - Args: - db: Database session - session_id: Session UUID - - Returns: - Deactivated UserSession if found, None otherwise - """ - try: - session = await self.get(db, id=session_id) - if not session: - logger.warning(f"Session {session_id} not found for deactivation") - return None - - session.is_active = False - db.add(session) - await db.commit() - await db.refresh(session) - - logger.info( - f"Session {session_id} deactivated for user {session.user_id} " - f"({session.device_name})" - ) - - return session - except Exception as e: - await db.rollback() - logger.error(f"Error deactivating session {session_id}: {str(e)}") - raise - - async def deactivate_all_user_sessions( - self, - db: AsyncSession, - *, - user_id: str - ) -> int: - """ - Deactivate all active sessions for a user (logout from all devices). - - Args: - db: Database session - user_id: User ID - - Returns: - Number of sessions deactivated - """ - try: - # Convert user_id string to UUID if needed - user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id - - stmt = ( - update(UserSession) - .where( - and_( - UserSession.user_id == user_uuid, - UserSession.is_active == True - ) - ) - .values(is_active=False) - ) - - result = await db.execute(stmt) - await db.commit() - - count = result.rowcount - - logger.info(f"Deactivated {count} sessions for user {user_id}") - - return count - except Exception as e: - await db.rollback() - logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") - raise - - async def update_last_used( - self, - db: AsyncSession, - *, - session: UserSession - ) -> UserSession: - """ - Update the last_used_at timestamp for a session. - - Args: - db: Database session - session: UserSession object - - Returns: - Updated UserSession - """ - try: - session.last_used_at = datetime.now(timezone.utc) - db.add(session) - await db.commit() - await db.refresh(session) - return session - except Exception as e: - await db.rollback() - logger.error(f"Error updating last_used for session {session.id}: {str(e)}") - raise - - async def update_refresh_token( - self, - db: AsyncSession, - *, - session: UserSession, - new_jti: str, - new_expires_at: datetime - ) -> UserSession: - """ - Update session with new refresh token JTI and expiration. - - Called during token refresh. - - Args: - db: Database session - session: UserSession object - new_jti: New refresh token JTI - new_expires_at: New expiration datetime - - Returns: - Updated UserSession - """ - try: - session.refresh_token_jti = new_jti - session.expires_at = new_expires_at - session.last_used_at = datetime.now(timezone.utc) - db.add(session) - await db.commit() - await db.refresh(session) - return session - except Exception as e: - await db.rollback() - logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") - raise - - async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: - """ - Clean up expired sessions using optimized bulk DELETE. - - Deletes sessions that are: - - Expired AND inactive - - Older than keep_days - - Uses single DELETE query instead of N individual deletes for efficiency. - - Args: - db: Database session - keep_days: Keep inactive sessions for this many days (for audit) - - Returns: - Number of sessions deleted - """ - try: - cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) - now = datetime.now(timezone.utc) - - # Use bulk DELETE with WHERE clause - single query - stmt = delete(UserSession).where( - and_( - UserSession.is_active == False, - UserSession.expires_at < now, - UserSession.created_at < cutoff_date - ) - ) - - result = await db.execute(stmt) - await db.commit() - - count = result.rowcount - - if count > 0: - logger.info(f"Cleaned up {count} expired sessions using bulk DELETE") - - return count - except Exception as e: - await db.rollback() - logger.error(f"Error cleaning up expired sessions: {str(e)}") - raise - - async def cleanup_expired_for_user( - self, - db: AsyncSession, - *, - user_id: str - ) -> int: - """ - Clean up expired and inactive sessions for a specific user. - - Uses single bulk DELETE query for efficiency instead of N individual deletes. - - Args: - db: Database session - user_id: User ID to cleanup sessions for - - Returns: - Number of sessions deleted - """ - try: - # Validate UUID - try: - uuid_obj = uuid.UUID(user_id) - except (ValueError, AttributeError): - logger.error(f"Invalid UUID format: {user_id}") - raise ValueError(f"Invalid user ID format: {user_id}") - - now = datetime.now(timezone.utc) - - # Use bulk DELETE with WHERE clause - single query - stmt = delete(UserSession).where( - and_( - UserSession.user_id == uuid_obj, - UserSession.is_active == False, - UserSession.expires_at < now - ) - ) - - result = await db.execute(stmt) - await db.commit() - - count = result.rowcount - - if count > 0: - logger.info( - f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE" - ) - - return count - except Exception as e: - await db.rollback() - logger.error( - f"Error cleaning up expired sessions for user {user_id}: {str(e)}" - ) - raise - - async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: - """ - Get count of active sessions for a user. - - Args: - db: Database session - user_id: User ID - - Returns: - Number of active sessions - """ - try: - # Convert user_id string to UUID if needed - user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id - - result = await db.execute( - select(func.count(UserSession.id)).where( - and_( - UserSession.user_id == user_uuid, - UserSession.is_active == True - ) - ) - ) - return result.scalar_one() - except Exception as e: - logger.error(f"Error counting sessions for user {user_id}: {str(e)}") - raise - - -# Create singleton instance -session_async = CRUDSessionAsync(UserSession) diff --git a/backend/app/crud/user.py b/backend/app/crud/user.py old mode 100644 new mode 100755 index d1fafd6..3efe634 --- a/backend/app/crud/user.py +++ b/backend/app/crud/user.py @@ -1,12 +1,15 @@ -# app/crud/user.py +# app/crud/user_async.py +"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" import logging +from datetime import datetime, timezone from typing import Optional, Union, Dict, Any, List, Tuple +from uuid import UUID -from sqlalchemy import or_, asc, desc +from sqlalchemy import or_, select, update from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession -from app.core.auth import get_password_hash +from app.core.auth import get_password_hash_async from app.crud.base import CRUDBase from app.models.user import User from app.schemas.users import UserCreate, UserUpdate @@ -15,15 +18,28 @@ logger = logging.getLogger(__name__) class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): - def get_by_email(self, db: Session, *, email: str) -> Optional[User]: - return db.query(User).filter(User.email == email).first() + """Async CRUD operations for User model.""" - def create(self, db: Session, *, obj_in: UserCreate) -> User: - """Create a new user with password hashing and error handling.""" + async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]: + """Get user by email address.""" try: + result = await db.execute( + select(User).where(User.email == email) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error getting user by email {email}: {str(e)}") + raise + + async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: + """Create a new user with async password hashing and error handling.""" + try: + # Hash password asynchronously to avoid blocking event loop + password_hash = await get_password_hash_async(obj_in.password) + db_obj = User( email=obj_in.email, - password_hash=get_password_hash(obj_in.password), + password_hash=password_hash, first_name=obj_in.first_name, last_name=obj_in.last_name, phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, @@ -31,11 +47,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): preferences={} ) db.add(db_obj) - db.commit() - db.refresh(db_obj) + await db.commit() + await db.refresh(db_obj) return db_obj except IntegrityError as e: - db.rollback() + await db.rollback() error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) if "email" in error_msg.lower(): logger.warning(f"Duplicate email attempted: {obj_in.email}") @@ -43,32 +59,34 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): logger.error(f"Integrity error creating user: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True) raise - def update( - self, - db: Session, - *, - db_obj: User, - obj_in: Union[UserUpdate, Dict[str, Any]] + async def update( + self, + db: AsyncSession, + *, + db_obj: User, + obj_in: Union[UserUpdate, Dict[str, Any]] ) -> User: + """Update user with async password hashing if password is updated.""" if isinstance(obj_in, dict): update_data = obj_in else: update_data = obj_in.model_dump(exclude_unset=True) # Handle password separately if it exists in update data + # Hash password asynchronously to avoid blocking event loop if "password" in update_data: - update_data["password_hash"] = get_password_hash(update_data["password"]) + update_data["password_hash"] = await get_password_hash_async(update_data["password"]) del update_data["password"] - return super().update(db, db_obj=db_obj, obj_in=update_data) + return await super().update(db, db_obj=db_obj, obj_in=update_data) - def get_multi_with_total( + async def get_multi_with_total( self, - db: Session, + db: AsyncSession, *, skip: int = 0, limit: int = 100, @@ -102,16 +120,16 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): try: # Build base query - query = db.query(User) + query = select(User) # Exclude soft-deleted users - query = query.filter(User.deleted_at.is_(None)) + query = query.where(User.deleted_at.is_(None)) # Apply filters if filters: for field, value in filters.items(): if hasattr(User, field) and value is not None: - query = query.filter(getattr(User, field) == value) + query = query.where(getattr(User, field) == value) # Apply search if search: @@ -120,21 +138,26 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): User.first_name.ilike(f"%{search}%"), User.last_name.ilike(f"%{search}%") ) - query = query.filter(search_filter) + query = query.where(search_filter) # Get total count - total = query.count() + from sqlalchemy import func + count_query = select(func.count()).select_from(query.alias()) + count_result = await db.execute(count_query) + total = count_result.scalar_one() # Apply sorting if sort_by and hasattr(User, sort_by): sort_column = getattr(User, sort_by) if sort_order.lower() == "desc": - query = query.order_by(desc(sort_column)) + query = query.order_by(sort_column.desc()) else: - query = query.order_by(asc(sort_column)) + query = query.order_by(sort_column.asc()) # Apply pagination - users = query.offset(skip).limit(limit).all() + query = query.offset(skip).limit(limit) + result = await db.execute(query) + users = list(result.scalars().all()) return users, total @@ -142,12 +165,108 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): logger.error(f"Error retrieving paginated users: {str(e)}") raise + async def bulk_update_status( + self, + db: AsyncSession, + *, + user_ids: List[UUID], + is_active: bool + ) -> int: + """ + Bulk update is_active status for multiple users. + + Args: + db: Database session + user_ids: List of user IDs to update + is_active: New active status + + Returns: + Number of users updated + """ + try: + if not user_ids: + return 0 + + # Use UPDATE with WHERE IN for efficiency + stmt = ( + update(User) + .where(User.id.in_(user_ids)) + .where(User.deleted_at.is_(None)) # Don't update deleted users + .values(is_active=is_active, updated_at=datetime.now(timezone.utc)) + ) + + result = await db.execute(stmt) + await db.commit() + + updated_count = result.rowcount + logger.info(f"Bulk updated {updated_count} users to is_active={is_active}") + return updated_count + + except Exception as e: + await db.rollback() + logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True) + raise + + async def bulk_soft_delete( + self, + db: AsyncSession, + *, + user_ids: List[UUID], + exclude_user_id: Optional[UUID] = None + ) -> int: + """ + Bulk soft delete multiple users. + + Args: + db: Database session + user_ids: List of user IDs to delete + exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action) + + Returns: + Number of users deleted + """ + try: + if not user_ids: + return 0 + + # Remove excluded user from list + filtered_ids = [uid for uid in user_ids if uid != exclude_user_id] + + if not filtered_ids: + return 0 + + # Use UPDATE with WHERE IN for efficiency + stmt = ( + update(User) + .where(User.id.in_(filtered_ids)) + .where(User.deleted_at.is_(None)) # Don't re-delete already deleted users + .values( + deleted_at=datetime.now(timezone.utc), + is_active=False, + updated_at=datetime.now(timezone.utc) + ) + ) + + result = await db.execute(stmt) + await db.commit() + + deleted_count = result.rowcount + logger.info(f"Bulk soft deleted {deleted_count} users") + return deleted_count + + except Exception as e: + await db.rollback() + logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True) + raise + def is_active(self, user: User) -> bool: + """Check if user is active.""" return user.is_active def is_superuser(self, user: User) -> bool: + """Check if user is a superuser.""" return user.is_superuser # Create a singleton instance for use across the application -user = CRUDUser(User) \ No newline at end of file +user = CRUDUser(User) diff --git a/backend/app/crud/user_async.py b/backend/app/crud/user_async.py deleted file mode 100755 index 63fdecd..0000000 --- a/backend/app/crud/user_async.py +++ /dev/null @@ -1,272 +0,0 @@ -# app/crud/user_async.py -"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" -import logging -from datetime import datetime, timezone -from typing import Optional, Union, Dict, Any, List, Tuple -from uuid import UUID - -from sqlalchemy import or_, select, update -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.auth import get_password_hash_async -from app.crud.base_async import CRUDBaseAsync -from app.models.user import User -from app.schemas.users import UserCreate, UserUpdate - -logger = logging.getLogger(__name__) - - -class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]): - """Async CRUD operations for User model.""" - - async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]: - """Get user by email address.""" - try: - result = await db.execute( - select(User).where(User.email == email) - ) - return result.scalar_one_or_none() - except Exception as e: - logger.error(f"Error getting user by email {email}: {str(e)}") - raise - - async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: - """Create a new user with async password hashing and error handling.""" - try: - # Hash password asynchronously to avoid blocking event loop - password_hash = await get_password_hash_async(obj_in.password) - - db_obj = User( - email=obj_in.email, - password_hash=password_hash, - first_name=obj_in.first_name, - last_name=obj_in.last_name, - phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, - is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False, - preferences={} - ) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - return db_obj - except IntegrityError as e: - await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) - if "email" in error_msg.lower(): - logger.warning(f"Duplicate email attempted: {obj_in.email}") - raise ValueError(f"User with email {obj_in.email} already exists") - logger.error(f"Integrity error creating user: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") - except Exception as e: - await db.rollback() - logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True) - raise - - async def update( - self, - db: AsyncSession, - *, - db_obj: User, - obj_in: Union[UserUpdate, Dict[str, Any]] - ) -> User: - """Update user with async password hashing if password is updated.""" - if isinstance(obj_in, dict): - update_data = obj_in - else: - update_data = obj_in.model_dump(exclude_unset=True) - - # Handle password separately if it exists in update data - # Hash password asynchronously to avoid blocking event loop - if "password" in update_data: - update_data["password_hash"] = await get_password_hash_async(update_data["password"]) - del update_data["password"] - - return await super().update(db, db_obj=db_obj, obj_in=update_data) - - async def get_multi_with_total( - self, - db: AsyncSession, - *, - skip: int = 0, - limit: int = 100, - sort_by: Optional[str] = None, - sort_order: str = "asc", - filters: Optional[Dict[str, Any]] = None, - search: Optional[str] = None - ) -> Tuple[List[User], int]: - """ - Get multiple users with total count, filtering, sorting, and search. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - sort_by: Field name to sort by - sort_order: Sort order ("asc" or "desc") - filters: Dictionary of filters (field_name: value) - search: Search term to match against email, first_name, last_name - - Returns: - Tuple of (users list, total count) - """ - # Validate pagination - if skip < 0: - raise ValueError("skip must be non-negative") - if limit < 0: - raise ValueError("limit must be non-negative") - if limit > 1000: - raise ValueError("Maximum limit is 1000") - - try: - # Build base query - query = select(User) - - # Exclude soft-deleted users - query = query.where(User.deleted_at.is_(None)) - - # Apply filters - if filters: - for field, value in filters.items(): - if hasattr(User, field) and value is not None: - query = query.where(getattr(User, field) == value) - - # Apply search - if search: - search_filter = or_( - User.email.ilike(f"%{search}%"), - User.first_name.ilike(f"%{search}%"), - User.last_name.ilike(f"%{search}%") - ) - query = query.where(search_filter) - - # Get total count - from sqlalchemy import func - count_query = select(func.count()).select_from(query.alias()) - count_result = await db.execute(count_query) - total = count_result.scalar_one() - - # Apply sorting - if sort_by and hasattr(User, sort_by): - sort_column = getattr(User, sort_by) - if sort_order.lower() == "desc": - query = query.order_by(sort_column.desc()) - else: - query = query.order_by(sort_column.asc()) - - # Apply pagination - query = query.offset(skip).limit(limit) - result = await db.execute(query) - users = list(result.scalars().all()) - - return users, total - - except Exception as e: - logger.error(f"Error retrieving paginated users: {str(e)}") - raise - - async def bulk_update_status( - self, - db: AsyncSession, - *, - user_ids: List[UUID], - is_active: bool - ) -> int: - """ - Bulk update is_active status for multiple users. - - Args: - db: Database session - user_ids: List of user IDs to update - is_active: New active status - - Returns: - Number of users updated - """ - try: - if not user_ids: - return 0 - - # Use UPDATE with WHERE IN for efficiency - stmt = ( - update(User) - .where(User.id.in_(user_ids)) - .where(User.deleted_at.is_(None)) # Don't update deleted users - .values(is_active=is_active, updated_at=datetime.now(timezone.utc)) - ) - - result = await db.execute(stmt) - await db.commit() - - updated_count = result.rowcount - logger.info(f"Bulk updated {updated_count} users to is_active={is_active}") - return updated_count - - except Exception as e: - await db.rollback() - logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True) - raise - - async def bulk_soft_delete( - self, - db: AsyncSession, - *, - user_ids: List[UUID], - exclude_user_id: Optional[UUID] = None - ) -> int: - """ - Bulk soft delete multiple users. - - Args: - db: Database session - user_ids: List of user IDs to delete - exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action) - - Returns: - Number of users deleted - """ - try: - if not user_ids: - return 0 - - # Remove excluded user from list - filtered_ids = [uid for uid in user_ids if uid != exclude_user_id] - - if not filtered_ids: - return 0 - - # Use UPDATE with WHERE IN for efficiency - stmt = ( - update(User) - .where(User.id.in_(filtered_ids)) - .where(User.deleted_at.is_(None)) # Don't re-delete already deleted users - .values( - deleted_at=datetime.now(timezone.utc), - is_active=False, - updated_at=datetime.now(timezone.utc) - ) - ) - - result = await db.execute(stmt) - await db.commit() - - deleted_count = result.rowcount - logger.info(f"Bulk soft deleted {deleted_count} users") - return deleted_count - - except Exception as e: - await db.rollback() - logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True) - raise - - def is_active(self, user: User) -> bool: - """Check if user is active.""" - return user.is_active - - def is_superuser(self, user: User) -> bool: - """Check if user is a superuser.""" - return user.is_superuser - - -# Create a singleton instance for use across the application -user_async = CRUDUserAsync(User) diff --git a/backend/app/init_db.py b/backend/app/init_db.py deleted file mode 100755 index c67e2f8..0000000 --- a/backend/app/init_db.py +++ /dev/null @@ -1,78 +0,0 @@ -# app/init_db.py -import logging -from typing import Optional - -from sqlalchemy.orm import Session - -from app.core.config import settings -from app.core.database import engine -from app.crud.user import user as user_crud -from app.schemas.users import UserCreate - -logger = logging.getLogger(__name__) - - -def init_db(db: Session) -> Optional[UserCreate]: - """ - Initialize database with first superuser if settings are configured and user doesn't exist. - - Returns: - The created or existing superuser, or None if creation fails - """ - # Use default values if not set in environment variables - superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com" - superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change" - - if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD: - logger.warning( - "First superuser credentials not configured in settings. " - f"Using defaults: {superuser_email}" - ) - - try: - # Check if superuser already exists - existing_user = user_crud.get_by_email(db, email=superuser_email) - - if existing_user: - logger.info(f"Superuser already exists: {existing_user.email}") - return existing_user - - # Create superuser if doesn't exist - user_in = UserCreate( - email=superuser_email, - password=superuser_password, - first_name="Admin", - last_name="User", - is_superuser=True - ) - - user = user_crud.create(db, obj_in=user_in) - logger.info(f"Created first superuser: {user.email}") - - return user - - except Exception as e: - logger.error(f"Error initializing database: {e}") - raise - - -if __name__ == "__main__": - # Configure logging to show info logs - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - - with Session(engine) as session: - try: - user = init_db(session) - if user: - print(f"✓ Database initialized successfully") - print(f"✓ Superuser: {user.email}") - else: - print("✗ Failed to initialize database") - except Exception as e: - print(f"✗ Error initializing database: {e}") - raise - finally: - session.close() diff --git a/backend/app/main.py b/backend/app/main.py index 8f7d7f3..5d5d7f6 100755 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -13,7 +13,7 @@ from slowapi.util import get_remote_address from app.api.main import api_router from app.core.config import settings -from app.core.database_async import check_database_health +from app.core.database import check_database_health from app.core.exceptions import ( APIException, api_exception_handler, diff --git a/backend/app/services/session_cleanup.py b/backend/app/services/session_cleanup.py index ff7fa04..230eeda 100755 --- a/backend/app/services/session_cleanup.py +++ b/backend/app/services/session_cleanup.py @@ -6,8 +6,8 @@ This service runs periodically to remove old session records from the database. import logging from datetime import datetime, timezone -from app.core.database_async import AsyncSessionLocal -from app.crud.session_async import session_async as session_crud +from app.core.database import SessionLocal +from app.crud.session import session as session_crud logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int: """ logger.info("Starting session cleanup job...") - async with AsyncSessionLocal() as db: + async with SessionLocal() as db: try: # Use CRUD method to cleanup count = await session_crud.cleanup_expired(db, keep_days=keep_days) @@ -50,7 +50,7 @@ async def get_session_statistics() -> dict: Returns: Dictionary with session stats """ - async with AsyncSessionLocal() as db: + async with SessionLocal() as db: try: from app.models.user_session import UserSession from sqlalchemy import select, func