Remove unused async database and CRUD modules

- Deleted `database_async.py`, `base_async.py`, and `organization_async.py` modules due to deprecation and unused references across the project.
- Improved overall codebase clarity and minimized redundant functionality by removing unused async database logic, CRUD utilities, and organization-related operations.
This commit is contained in:
Felipe Cardoso
2025-11-01 05:47:43 +01:00
parent ee938ce6a6
commit efcf10f9aa
20 changed files with 972 additions and 2283 deletions

View File

@@ -7,7 +7,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database_async import get_async_db from app.core.database import get_db
from app.models.user import User from app.models.user import User
# OAuth2 configuration # OAuth2 configuration
@@ -15,7 +15,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user( async def get_current_user(
db: AsyncSession = Depends(get_async_db), db: AsyncSession = Depends(get_db),
token: str = Depends(oauth2_scheme) token: str = Depends(oauth2_scheme)
) -> User: ) -> User:
""" """
@@ -139,7 +139,7 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
async def get_optional_current_user( async def get_optional_current_user(
db: AsyncSession = Depends(get_async_db), db: AsyncSession = Depends(get_db),
token: Optional[str] = Depends(get_optional_token) token: Optional[str] = Depends(get_optional_token)
) -> Optional[User]: ) -> Optional[User]:
""" """

View File

@@ -14,8 +14,8 @@ from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.database_async import get_async_db from app.core.database import get_db
from app.crud.organization_async import organization_async as organization_crud from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
@@ -78,7 +78,7 @@ class OrganizationPermission:
self, self,
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> User: ) -> User:
""" """
Check if user has required role in the organization. Check if user has required role in the organization.
@@ -133,7 +133,7 @@ require_org_member = OrganizationPermission([
async def get_current_org_role( async def get_current_org_role(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Optional[OrganizationRole]: ) -> Optional[OrganizationRole]:
""" """
Get the current user's role in an organization. Get the current user's role in an organization.
@@ -164,7 +164,7 @@ async def get_current_org_role(
async def require_org_membership( async def require_org_membership(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> User: ) -> User:
""" """
Ensure user is a member of the organization (any role). Ensure user is a member of the organization (any role).

View File

@@ -15,10 +15,10 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser from app.api.dependencies.permissions import require_superuser
from app.core.database_async import get_async_db from app.core.database import get_db
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
from app.crud.organization_async import organization_async as organization_crud from app.crud.organization import organization as organization_crud
from app.crud.user_async import user_async as user_crud from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.schemas.common import ( from app.schemas.common import (
@@ -80,7 +80,7 @@ async def admin_list_users(
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"), search: Optional[str] = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
List all users with comprehensive filtering and search. List all users with comprehensive filtering and search.
@@ -131,7 +131,7 @@ async def admin_list_users(
async def admin_create_user( async def admin_create_user(
user_in: UserCreate, user_in: UserCreate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Create a new user with admin privileges. Create a new user with admin privileges.
@@ -163,7 +163,7 @@ async def admin_create_user(
async def admin_get_user( async def admin_get_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Get detailed information about a specific user.""" """Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
@@ -186,7 +186,7 @@ async def admin_update_user(
user_id: UUID, user_id: UUID,
user_in: UserUpdate, user_in: UserUpdate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Update user information with admin privileges.""" """Update user information with admin privileges."""
try: try:
@@ -218,7 +218,7 @@ async def admin_update_user(
async def admin_delete_user( async def admin_delete_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Soft delete a user (sets deleted_at timestamp).""" """Soft delete a user (sets deleted_at timestamp)."""
try: try:
@@ -262,7 +262,7 @@ async def admin_delete_user(
async def admin_activate_user( async def admin_activate_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Activate a user account.""" """Activate a user account."""
try: try:
@@ -298,7 +298,7 @@ async def admin_activate_user(
async def admin_deactivate_user( async def admin_deactivate_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Deactivate a user account.""" """Deactivate a user account."""
try: try:
@@ -342,7 +342,7 @@ async def admin_deactivate_user(
async def admin_bulk_user_action( async def admin_bulk_user_action(
bulk_action: BulkUserAction, bulk_action: BulkUserAction,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Perform bulk actions on multiple users using optimized bulk operations. Perform bulk actions on multiple users using optimized bulk operations.
@@ -410,7 +410,7 @@ async def admin_list_organizations(
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: Optional[bool] = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"), search: Optional[str] = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""List all organizations with filtering and search.""" """List all organizations with filtering and search."""
try: try:
@@ -467,7 +467,7 @@ async def admin_list_organizations(
async def admin_create_organization( async def admin_create_organization(
org_in: OrganizationCreate, org_in: OrganizationCreate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Create a new organization.""" """Create a new organization."""
try: try:
@@ -509,7 +509,7 @@ async def admin_create_organization(
async def admin_get_organization( async def admin_get_organization(
org_id: UUID, org_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Get detailed information about a specific organization.""" """Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
@@ -544,7 +544,7 @@ async def admin_update_organization(
org_id: UUID, org_id: UUID,
org_in: OrganizationUpdate, org_in: OrganizationUpdate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Update organization information.""" """Update organization information."""
try: try:
@@ -588,7 +588,7 @@ async def admin_update_organization(
async def admin_delete_organization( async def admin_delete_organization(
org_id: UUID, org_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Delete an organization and all its relationships.""" """Delete an organization and all its relationships."""
try: try:
@@ -626,7 +626,7 @@ async def admin_list_organization_members(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"), is_active: Optional[bool] = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""List all members of an organization.""" """List all members of an organization."""
try: try:
@@ -681,7 +681,7 @@ async def admin_add_organization_member(
org_id: UUID, org_id: UUID,
request: AddMemberRequest, request: AddMemberRequest,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Add a user to an organization.""" """Add a user to an organization."""
try: try:
@@ -742,7 +742,7 @@ async def admin_remove_organization_member(
org_id: UUID, org_id: UUID,
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
"""Remove a user from an organization.""" """Remove a user from an organization."""
try: try:

View File

@@ -13,14 +13,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
from app.core.database_async import get_async_db from app.core.database import get_db
from app.core.exceptions import ( from app.core.exceptions import (
AuthenticationError as AuthError, AuthenticationError as AuthError,
DatabaseError, DatabaseError,
ErrorCode ErrorCode
) )
from app.crud.session_async import session_async as session_crud from app.crud.session import session as session_crud
from app.crud.user_async import user_async as user_crud from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest from app.schemas.sessions import SessionCreate, LogoutRequest
@@ -54,7 +54,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
async def register_user( async def register_user(
request: Request, request: Request,
user_data: UserCreate, user_data: UserCreate,
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Register a new user. Register a new user.
@@ -85,7 +85,7 @@ async def register_user(
async def login( async def login(
request: Request, request: Request,
login_data: LoginRequest, login_data: LoginRequest,
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Login with username and password. Login with username and password.
@@ -167,7 +167,7 @@ async def login(
async def login_oauth( async def login_oauth(
request: Request, request: Request,
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
OAuth2-compatible login endpoint, used by the OpenAPI UI. OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -244,7 +244,7 @@ async def login_oauth(
async def refresh_token( async def refresh_token(
request: Request, request: Request,
refresh_data: RefreshTokenRequest, refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Refresh access token using a refresh token. Refresh access token using a refresh token.
@@ -333,7 +333,7 @@ async def refresh_token(
async def request_password_reset( async def request_password_reset(
request: Request, request: Request,
reset_request: PasswordResetRequest, reset_request: PasswordResetRequest,
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Request a password reset. Request a password reset.
@@ -391,7 +391,7 @@ async def request_password_reset(
async def confirm_password_reset( async def confirm_password_reset(
request: Request, request: Request,
reset_confirm: PasswordResetConfirm, reset_confirm: PasswordResetConfirm,
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Confirm password reset with token. Confirm password reset with token.
@@ -430,7 +430,7 @@ async def confirm_password_reset(
# SECURITY: Invalidate all existing sessions after password reset # SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change # This prevents stolen sessions from being used after password change
from app.crud.session_async import session_async as session_crud from app.crud.session import session as session_crud
try: try:
deactivated_count = await session_crud.deactivate_all_user_sessions( deactivated_count = await session_crud.deactivate_all_user_sessions(
db, db,
@@ -478,7 +478,7 @@ async def logout(
request: Request, request: Request,
logout_request: LogoutRequest, logout_request: LogoutRequest,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Logout from current device by deactivating the session. Logout from current device by deactivating the session.
@@ -566,7 +566,7 @@ async def logout(
async def logout_all( async def logout_all(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Logout from all devices by deactivating all user sessions. Logout from all devices by deactivating all user sessions.

View File

@@ -13,9 +13,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database_async import get_async_db from app.core.database import get_db
from app.core.exceptions import NotFoundError, ErrorCode from app.core.exceptions import NotFoundError, ErrorCode
from app.crud.organization_async import organization_async as organization_crud from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
PaginationParams, PaginationParams,
@@ -43,7 +43,7 @@ router = APIRouter()
async def get_my_organizations( async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"), is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Get all organizations the current user belongs to. Get all organizations the current user belongs to.
@@ -93,7 +93,7 @@ async def get_my_organizations(
async def get_organization( async def get_organization(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(require_org_membership), current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Get details of a specific organization. Get details of a specific organization.
@@ -140,7 +140,7 @@ async def get_organization_members(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"), is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership), current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Get all members of an organization. Get all members of an organization.
@@ -183,7 +183,7 @@ async def update_organization(
organization_id: UUID, organization_id: UUID,
org_in: OrganizationUpdate, org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin), current_user: User = Depends(require_org_admin),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Update organization details. Update organization details.

View File

@@ -14,9 +14,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token from app.core.auth import decode_token
from app.core.database_async import get_async_db from app.core.database import get_db
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
from app.crud.session_async import session_async as session_crud from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionResponse, SessionListResponse from app.schemas.sessions import SessionResponse, SessionListResponse
@@ -45,7 +45,7 @@ limiter = Limiter(key_func=get_remote_address)
async def list_my_sessions( async def list_my_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
List all active sessions for the current user. List all active sessions for the current user.
@@ -129,7 +129,7 @@ async def revoke_session(
request: Request, request: Request,
session_id: UUID, session_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Revoke a specific session by ID. Revoke a specific session by ID.
@@ -204,7 +204,7 @@ async def revoke_session(
async def cleanup_expired_sessions( async def cleanup_expired_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Cleanup expired sessions for the current user. Cleanup expired sessions for the current user.

View File

@@ -11,13 +11,13 @@ from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_current_superuser from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.core.database_async import get_async_db from app.core.database import get_db
from app.core.exceptions import ( from app.core.exceptions import (
NotFoundError, NotFoundError,
AuthorizationError, AuthorizationError,
ErrorCode ErrorCode
) )
from app.crud.user_async import user_async as user_crud from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
PaginationParams, PaginationParams,
@@ -58,7 +58,7 @@ async def list_users(
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
List all users with pagination, filtering, and sorting. List all users with pagination, filtering, and sorting.
@@ -138,7 +138,7 @@ def get_current_user_profile(
async def update_current_user( async def update_current_user(
user_update: UserUpdate, user_update: UserUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Update current user's profile. Update current user's profile.
@@ -188,7 +188,7 @@ async def update_current_user(
async def get_user_by_id( async def get_user_by_id(
user_id: UUID, user_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Get user by ID. Get user by ID.
@@ -236,7 +236,7 @@ async def update_user(
user_id: UUID, user_id: UUID,
user_update: UserUpdate, user_update: UserUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Update user by ID. Update user by ID.
@@ -304,7 +304,7 @@ async def change_current_user_password(
request: Request, request: Request,
password_change: PasswordChange, password_change: PasswordChange,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Change current user's password. Change current user's password.
@@ -356,7 +356,7 @@ async def change_current_user_password(
async def delete_user( async def delete_user(
user_id: UUID, user_id: UUID,
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Delete user by ID (superuser only). Delete user by ID (superuser only).

189
backend/app/core/database.py Normal file → Executable file
View File

@@ -1,113 +1,186 @@
# app/core/database.py # app/core/database.py
import logging """
from contextlib import contextmanager Database configuration using SQLAlchemy 2.0 and asyncpg.
from typing import Generator
from sqlalchemy import create_engine, text This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import sessionmaker, Session
from app.core.config import settings from app.core.config import settings
# Configure logging # Configure logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# SQLite compatibility for testing # SQLite compatibility for testing
@compiles(JSONB, 'sqlite') @compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw): def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT" return "TEXT"
@compiles(UUID, 'sqlite') @compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw): def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT" return "TEXT"
# Declarative base for models
Base = declarative_base()
# Create engine with optimized settings for PostgreSQL # Declarative base for models (SQLAlchemy 2.0 style)
def create_production_engine(): class Base(DeclarativeBase):
return create_engine( """Base class for all database models."""
settings.database_url, pass
# Connection pool settings
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow, def get_async_database_url(url: str) -> str:
pool_timeout=settings.db_pool_timeout, """
pool_recycle=settings.db_pool_recycle, Convert sync database URL to async URL.
pool_pre_ping=True,
# Query execution settings postgresql:// -> postgresql+asyncpg://
connect_args={ sqlite:// -> sqlite+aiosqlite://
"""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://")
elif url.startswith("sqlite://"):
return url.replace("sqlite://", "sqlite+aiosqlite://")
return url
# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
"""Create an async database engine with production settings."""
async_url = get_async_database_url(settings.database_url)
# Base engine config
engine_config = {
"pool_size": settings.db_pool_size,
"max_overflow": settings.db_max_overflow,
"pool_timeout": settings.db_pool_timeout,
"pool_recycle": settings.db_pool_recycle,
"pool_pre_ping": True,
"echo": settings.sql_echo,
"echo_pool": settings.sql_echo_pool,
}
# Add PostgreSQL-specific connect_args
if "postgresql" in async_url:
engine_config["connect_args"] = {
"server_settings": {
"application_name": "eventspace", "application_name": "eventspace",
"keepalives": 1, "timezone": "UTC",
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
"options": "-c timezone=UTC",
}, },
isolation_level="READ COMMITTED", # asyncpg-specific settings
echo=settings.sql_echo, "command_timeout": 60,
echo_pool=settings.sql_echo_pool, "timeout": 10,
) }
# Default production engine and session factory return create_async_engine(async_url, **engine_config)
engine = create_production_engine()
SessionLocal = sessionmaker(
# Create async engine and session factory
engine = create_async_production_engine()
SessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
autocommit=False, autocommit=False,
autoflush=False, autoflush=False,
bind=engine, expire_on_commit=False, # Prevent unnecessary queries after commit
expire_on_commit=False # Prevent unnecessary queries after commit
) )
# FastAPI dependency
def get_db() -> Generator[Session, None, None]: # FastAPI dependency for async database sessions
async def get_db() -> AsyncGenerator[AsyncSession, None]:
""" """
FastAPI dependency that provides a database session. FastAPI dependency that provides an async database session.
Automatically closes the session after the request completes. Automatically closes the session after the request completes.
Usage:
@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User))
return result.scalars().all()
""" """
db = SessionLocal() async with SessionLocal() as session:
try: try:
yield db yield session
finally: finally:
db.close() await session.close()
@contextmanager @asynccontextmanager
def transaction_scope() -> Generator[Session, None, None]: async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
""" """
Provide a transactional scope for database operations. Provide an async transactional scope for database operations.
Automatically commits on success or rolls back on exception. Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction. Useful for grouping multiple operations in a single transaction.
Usage: Usage:
with transaction_scope() as db: async with async_transaction_scope() as db:
user = user_crud.create(db, obj_in=user_create) user = await user_crud.create(db, obj_in=user_create)
profile = profile_crud.create(db, obj_in=profile_create) profile = await profile_crud.create(db, obj_in=profile_create)
# Both operations committed together # Both operations committed together
""" """
db = SessionLocal() async with SessionLocal() as session:
try: try:
yield db yield session
db.commit() await session.commit()
logger.debug("Transaction committed successfully") logger.debug("Async transaction committed successfully")
except Exception as e: except Exception as e:
db.rollback() await session.rollback()
logger.error(f"Transaction failed, rolling back: {str(e)}") logger.error(f"Async transaction failed, rolling back: {str(e)}")
raise raise
finally: finally:
db.close() await session.close()
def check_database_health() -> bool: async def check_async_database_health() -> bool:
""" """
Check if database connection is healthy. Check if async database connection is healthy.
Returns True if connection is successful, False otherwise. Returns True if connection is successful, False otherwise.
""" """
try: try:
with transaction_scope() as db: async with async_transaction_scope() as db:
db.execute(text("SELECT 1")) await db.execute(text("SELECT 1"))
return True return True
except Exception as e: except Exception as e:
logger.error(f"Database health check failed: {str(e)}") logger.error(f"Async database health check failed: {str(e)}")
return False return False
# Alias for consistency with main.py
check_database_health = check_async_database_health
async def init_async_db() -> None:
"""
Initialize async database tables.
This creates all tables defined in the models.
Should only be used in development or testing.
In production, use Alembic migrations.
"""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Async database tables created")
async def close_async_db() -> None:
"""
Close all async database connections.
Should be called during application shutdown.
"""
await engine.dispose()
logger.info("Async database connections closed")

View File

@@ -1,186 +0,0 @@
# app/core/database_async.py
"""
Async database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
def get_async_database_url(url: str) -> str:
"""
Convert sync database URL to async URL.
postgresql:// -> postgresql+asyncpg://
sqlite:// -> sqlite+aiosqlite://
"""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://")
elif url.startswith("sqlite://"):
return url.replace("sqlite://", "sqlite+aiosqlite://")
return url
# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
"""Create an async database engine with production settings."""
async_url = get_async_database_url(settings.database_url)
# Base engine config
engine_config = {
"pool_size": settings.db_pool_size,
"max_overflow": settings.db_max_overflow,
"pool_timeout": settings.db_pool_timeout,
"pool_recycle": settings.db_pool_recycle,
"pool_pre_ping": True,
"echo": settings.sql_echo,
"echo_pool": settings.sql_echo_pool,
}
# Add PostgreSQL-specific connect_args
if "postgresql" in async_url:
engine_config["connect_args"] = {
"server_settings": {
"application_name": "eventspace",
"timezone": "UTC",
},
# asyncpg-specific settings
"command_timeout": 60,
"timeout": 10,
}
return create_async_engine(async_url, **engine_config)
# Create async engine and session factory
async_engine = create_async_production_engine()
AsyncSessionLocal = async_sessionmaker(
async_engine,
class_=AsyncSession,
autocommit=False,
autoflush=False,
expire_on_commit=False, # Prevent unnecessary queries after commit
)
# FastAPI dependency for async database sessions
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""
FastAPI dependency that provides an async database session.
Automatically closes the session after the request completes.
Usage:
@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_async_db)):
result = await db.execute(select(User))
return result.scalars().all()
"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
@asynccontextmanager
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
"""
Provide an async transactional scope for database operations.
Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction.
Usage:
async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create)
# Both operations committed together
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
raise
finally:
await session.close()
async def check_async_database_health() -> bool:
"""
Check if async database connection is healthy.
Returns True if connection is successful, False otherwise.
"""
try:
async with async_transaction_scope() as db:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Async database health check failed: {str(e)}")
return False
# Alias for consistency with main.py
check_database_health = check_async_database_health
async def init_async_db() -> None:
"""
Initialize async database tables.
This creates all tables defined in the models.
Should only be used in development or testing.
In production, use Alembic migrations.
"""
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Async database tables created")
async def close_async_db() -> None:
"""
Close all async database connections.
Should be called during application shutdown.
"""
await async_engine.dispose()
logger.info("Async database connections closed")

193
backend/app/crud/base.py Normal file → Executable file
View File

@@ -1,13 +1,19 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging import logging
import uuid import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import asc, desc from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database import Base from app.core.database import Base
@@ -19,17 +25,40 @@ UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]): def __init__(self, model: Type[ModelType]):
""" """
CRUD object with default methods to Create, Read, Update, Delete (CRUD). CRUD object with default async methods to Create, Read, Update, Delete.
Parameters: Parameters:
model: A SQLAlchemy model class model: A SQLAlchemy model class
""" """
self.model = model self.model = model
def get(self, db: Session, id: str) -> Optional[ModelType]: async def get(
"""Get a single record by ID with UUID validation.""" self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string # Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
@@ -41,15 +70,39 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None return None
try: try:
return db.query(self.model).filter(self.model.id == uuid_obj).first() query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise raise
def get_multi( async def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100 self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]: ) -> List[ModelType]:
"""Get multiple records with pagination validation.""" """
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters # Validate pagination parameters
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise ValueError("skip must be non-negative")
@@ -59,22 +112,30 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000") raise ValueError("Maximum limit is 1000")
try: try:
return db.query(self.model).offset(skip).limit(limit).all() query = select(self.model).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}") logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise raise
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
"""Create a new record with error handling.""" """Create a new record with error handling."""
try: try:
obj_in_data = jsonable_encoder(obj_in) obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data) db_obj = self.model(**obj_in_data)
db.add(db_obj) db.add(db_obj)
db.commit() await db.commit()
db.refresh(db_obj) await db.refresh(db_obj)
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
@@ -82,17 +143,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: except (OperationalError, DataError) as e:
db.rollback() await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}") logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}") raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True) logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise raise
def update( async def update(
self, self,
db: Session, db: AsyncSession,
*, *,
db_obj: ModelType, db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]] obj_in: Union[UpdateSchemaType, Dict[str, Any]]
@@ -104,15 +165,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
update_data = obj_in update_data = obj_in
else: else:
update_data = obj_in.model_dump(exclude_unset=True) update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data: for field in obj_data:
if field in update_data: if field in update_data:
setattr(db_obj, field, update_data[field]) setattr(db_obj, field, update_data[field])
db.add(db_obj) db.add(db_obj)
db.commit() await db.commit()
db.refresh(db_obj) await db.refresh(db_obj)
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
@@ -120,15 +183,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: except (OperationalError, DataError) as e:
db.rollback() await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}") logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}") raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True) logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise raise
def remove(self, db: Session, *, id: str) -> Optional[ModelType]: async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check.""" """Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string # Validate UUID format and convert to UUID object if string
try: try:
@@ -141,27 +204,31 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None return None
try: try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first() result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None: if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion") logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None return None
db.delete(obj) await db.delete(obj)
db.commit() await db.commit()
return obj return obj
except IntegrityError as e: except IntegrityError as e:
db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records") raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise raise
def get_multi_with_total( async def get_multi_with_total(
self, self,
db: Session, db: AsyncSession,
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
@@ -193,43 +260,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
try: try:
# Build base query # Build base query
query = db.query(self.model) query = select(self.model)
# Exclude soft-deleted records by default # Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'): if hasattr(self.model, 'deleted_at'):
query = query.filter(self.model.deleted_at.is_(None)) query = query.where(self.model.deleted_at.is_(None))
# Apply filters # Apply filters
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
if hasattr(self.model, field) and value is not None: if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value) query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination) # Get total count (before pagination)
total = query.count() count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting # Apply sorting
if sort_by and hasattr(self.model, sort_by): if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by) sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc": if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column)) query = query.order_by(sort_column.desc())
else: else:
query = query.order_by(asc(sort_column)) query = query.order_by(sort_column.asc())
# Apply pagination # Apply pagination
items = query.offset(skip).limit(limit).all() query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total return items, total
except Exception as e: except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise raise
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]: async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
""" """
Soft delete a record by setting deleted_at timestamp. Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column. Only works if the model has a 'deleted_at' column.
""" """
from datetime import datetime, timezone
# Validate UUID format and convert to UUID object if string # Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
@@ -241,7 +328,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None return None
try: try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first() result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None: if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
@@ -255,15 +345,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
# Set deleted_at timestamp # Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc) obj.deleted_at = datetime.now(timezone.utc)
db.add(obj) db.add(obj)
db.commit() await db.commit()
db.refresh(obj) await db.refresh(obj)
return obj return obj
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise raise
def restore(self, db: Session, *, id: str) -> Optional[ModelType]: async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
""" """
Restore a soft-deleted record by clearing the deleted_at timestamp. Restore a soft-deleted record by clearing the deleted_at timestamp.
@@ -282,10 +372,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
try: try:
# Find the soft-deleted record # Find the soft-deleted record
if hasattr(self.model, 'deleted_at'): if hasattr(self.model, 'deleted_at'):
obj = db.query(self.model).filter( result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj, self.model.id == uuid_obj,
self.model.deleted_at.isnot(None) self.model.deleted_at.isnot(None)
).first() )
)
obj = result.scalar_one_or_none()
else: else:
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column") raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
@@ -297,10 +390,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
# Clear deleted_at timestamp # Clear deleted_at timestamp
obj.deleted_at = None obj.deleted_at = None
db.add(obj) db.add(obj)
db.commit() await db.commit()
db.refresh(obj) await db.refresh(obj)
return obj return obj
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise raise

View File

@@ -1,399 +0,0 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database_async import Base
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None
try:
query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
async def get_multi(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
query = select(self.model).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
"""Create a new record with error handling."""
try:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Build base query
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise

338
backend/app/crud/organization.py Normal file → Executable file
View File

@@ -1,11 +1,12 @@
# app/crud/organization.py # app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging import logging
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from uuid import UUID from uuid import UUID
from sqlalchemy import func, or_, and_ from sqlalchemy import func, or_, and_, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase from app.crud.base import CRUDBase
from app.models.organization import Organization from app.models.organization import Organization
@@ -13,20 +14,27 @@ from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole from app.models.user_organization import UserOrganization, OrganizationRole
from app.schemas.organizations import ( from app.schemas.organizations import (
OrganizationCreate, OrganizationCreate,
OrganizationUpdate OrganizationUpdate,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""CRUD operations for Organization model.""" """Async CRUD operations for Organization model."""
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Organization]: async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
"""Get organization by slug.""" """Get organization by slug."""
return db.query(Organization).filter(Organization.slug == slug).first() try:
result = await db.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
raise
def create(self, db: Session, *, obj_in: OrganizationCreate) -> Organization: async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
"""Create a new organization with error handling.""" """Create a new organization with error handling."""
try: try:
db_obj = Organization( db_obj = Organization(
@@ -37,11 +45,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
settings=obj_in.settings or {} settings=obj_in.settings or {}
) )
db.add(db_obj) db.add(db_obj)
db.commit() await db.commit()
db.refresh(db_obj) await db.refresh(db_obj)
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "slug" in error_msg.lower(): if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}") logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
@@ -49,13 +57,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
logger.error(f"Integrity error creating organization: {error_msg}") logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True) logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
raise raise
def get_multi_with_filters( async def get_multi_with_filters(
self, self,
db: Session, db: AsyncSession,
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
@@ -70,11 +78,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
Returns: Returns:
Tuple of (organizations list, total count) Tuple of (organizations list, total count)
""" """
query = db.query(Organization) try:
query = select(Organization)
# Apply filters # Apply filters
if is_active is not None: if is_active is not None:
query = query.filter(Organization.is_active == is_active) query = query.where(Organization.is_active == is_active)
if search: if search:
search_filter = or_( search_filter = or_(
@@ -82,10 +91,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
Organization.slug.ilike(f"%{search}%"), Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%") Organization.description.ilike(f"%{search}%")
) )
query = query.filter(search_filter) query = query.where(search_filter)
# Get total count before pagination # Get total count before pagination
total = query.count() count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting # Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at) sort_column = getattr(Organization, sort_by, Organization.created_at)
@@ -95,22 +106,111 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
# Apply pagination # Apply pagination
organizations = query.offset(skip).limit(limit).all() query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
return organizations, total return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
raise
def get_member_count(self, db: Session, *, organization_id: UUID) -> int: async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get the count of active members in an organization.""" """Get the count of active members in an organization."""
return db.query(func.count(UserOrganization.user_id)).filter( try:
result = await db.execute(
select(func.count(UserOrganization.user_id)).where(
and_( and_(
UserOrganization.organization_id == organization_id, UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True UserOrganization.is_active == True
) )
).scalar() or 0 )
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
raise
def add_user( async def get_multi_with_member_counts(
self, self,
db: Session, db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
try:
# Build base query with LEFT JOIN and GROUP BY
query = (
select(
Organization,
func.count(
func.distinct(
and_(
UserOrganization.is_active == True,
UserOrganization.user_id
).self_group()
)
).label('member_count')
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id))
if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active)
if search:
count_query = count_query.where(search_filter)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
raise
async def add_user(
self,
db: AsyncSession,
*, *,
organization_id: UUID, organization_id: UUID,
user_id: UUID, user_id: UUID,
@@ -120,12 +220,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
"""Add a user to an organization with a specific role.""" """Add a user to an organization with a specific role."""
try: try:
# Check if relationship already exists # Check if relationship already exists
existing = db.query(UserOrganization).filter( result = await db.execute(
select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id UserOrganization.organization_id == organization_id
) )
).first() )
)
existing = result.scalar_one_or_none()
if existing: if existing:
# Reactivate if inactive, or raise error if already active # Reactivate if inactive, or raise error if already active
@@ -133,8 +236,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
existing.is_active = True existing.is_active = True
existing.role = role existing.role = role
existing.custom_permissions = custom_permissions existing.custom_permissions = custom_permissions
db.commit() await db.commit()
db.refresh(existing) await db.refresh(existing)
return existing return existing
else: else:
raise ValueError("User is already a member of this organization") raise ValueError("User is already a member of this organization")
@@ -148,48 +251,51 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
custom_permissions=custom_permissions custom_permissions=custom_permissions
) )
db.add(user_org) db.add(user_org)
db.commit() await db.commit()
db.refresh(user_org) await db.refresh(user_org)
return user_org return user_org
except IntegrityError as e: except IntegrityError as e:
db.rollback() await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}") logger.error(f"Integrity error adding user to organization: {str(e)}")
raise ValueError("Failed to add user to organization") raise ValueError("Failed to add user to organization")
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True) logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
raise raise
def remove_user( async def remove_user(
self, self,
db: Session, db: AsyncSession,
*, *,
organization_id: UUID, organization_id: UUID,
user_id: UUID user_id: UUID
) -> bool: ) -> bool:
"""Remove a user from an organization (soft delete).""" """Remove a user from an organization (soft delete)."""
try: try:
user_org = db.query(UserOrganization).filter( result = await db.execute(
select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id UserOrganization.organization_id == organization_id
) )
).first() )
)
user_org = result.scalar_one_or_none()
if not user_org: if not user_org:
return False return False
user_org.is_active = False user_org.is_active = False
db.commit() await db.commit()
return True return True
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True) logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
raise raise
def update_user_role( async def update_user_role(
self, self,
db: Session, db: AsyncSession,
*, *,
organization_id: UUID, organization_id: UUID,
user_id: UUID, user_id: UUID,
@@ -198,12 +304,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) -> Optional[UserOrganization]: ) -> Optional[UserOrganization]:
"""Update a user's role in an organization.""" """Update a user's role in an organization."""
try: try:
user_org = db.query(UserOrganization).filter( result = await db.execute(
select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id UserOrganization.organization_id == organization_id
) )
).first() )
)
user_org = result.scalar_one_or_none()
if not user_org: if not user_org:
return None return None
@@ -211,17 +320,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
user_org.role = role user_org.role = role
if custom_permissions is not None: if custom_permissions is not None:
user_org.custom_permissions = custom_permissions user_org.custom_permissions = custom_permissions
db.commit() await db.commit()
db.refresh(user_org) await db.refresh(user_org)
return user_org return user_org
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True) logger.error(f"Error updating user role: {str(e)}", exc_info=True)
raise raise
def get_organization_members( async def get_organization_members(
self, self,
db: Session, db: AsyncSession,
*, *,
organization_id: UUID, organization_id: UUID,
skip: int = 0, skip: int = 0,
@@ -234,16 +343,31 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
Returns: Returns:
Tuple of (members list with user details, total count) Tuple of (members list with user details, total count)
""" """
query = db.query(UserOrganization, User).join( try:
User, UserOrganization.user_id == User.id # Build query with join
).filter(UserOrganization.organization_id == organization_id) query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
.where(UserOrganization.organization_id == organization_id)
)
if is_active is not None: if is_active is not None:
query = query.filter(UserOrganization.is_active == is_active) query = query.where(UserOrganization.is_active == is_active)
total = query.count() # Get total count
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
results = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit).all() # Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
results = result.all()
members = [] members = []
for user_org, user in results: for user_org, user in results:
@@ -258,62 +382,136 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
}) })
return members, total return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
raise
def get_user_organizations( async def get_user_organizations(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: UUID, user_id: UUID,
is_active: bool = True is_active: bool = True
) -> List[Organization]: ) -> List[Organization]:
"""Get all organizations a user belongs to.""" """Get all organizations a user belongs to."""
query = db.query(Organization).join( try:
UserOrganization, Organization.id == UserOrganization.organization_id query = (
).filter(UserOrganization.user_id == user_id) select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None: if is_active is not None:
query = query.filter(UserOrganization.is_active == is_active) query = query.where(UserOrganization.is_active == is_active)
return query.all() result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
raise
def get_user_role_in_org( async def get_user_organizations_with_details(
self, self,
db: Session, db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
try:
# Subquery to get member counts for each organization
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
)
.where(UserOrganization.is_active == True)
.group_by(UserOrganization.organization_id)
.subquery()
)
# Main query with JOIN to get org, role, and member count
query = (
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*, *,
user_id: UUID, user_id: UUID,
organization_id: UUID organization_id: UUID
) -> Optional[OrganizationRole]: ) -> Optional[OrganizationRole]:
"""Get a user's role in a specific organization.""" """Get a user's role in a specific organization."""
user_org = db.query(UserOrganization).filter( try:
result = await db.execute(
select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id, UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True UserOrganization.is_active == True
) )
).first() )
)
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
raise
def is_user_org_owner( async def is_user_org_owner(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: UUID, user_id: UUID,
organization_id: UUID organization_id: UUID
) -> bool: ) -> bool:
"""Check if a user is an owner of an organization.""" """Check if a user is an owner of an organization."""
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role == OrganizationRole.OWNER return role == OrganizationRole.OWNER
def is_user_org_admin( async def is_user_org_admin(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: UUID, user_id: UUID,
organization_id: UUID organization_id: UUID
) -> bool: ) -> bool:
"""Check if a user is an owner or admin of an organization.""" """Check if a user is an owner or admin of an organization."""
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]

View File

@@ -1,519 +0,0 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Optional, List, Dict, Any
from uuid import UUID
from sqlalchemy import func, or_, and_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base_async import CRUDBaseAsync
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
)
logger = logging.getLogger(__name__)
class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
"""Get organization by slug."""
try:
result = await db.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
name=obj_in.name,
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {}
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
raise
async def get_multi_with_filters(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
sort_by: str = "created_at",
sort_order: str = "desc"
) -> tuple[List[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
Returns:
Tuple of (organizations list, total count)
"""
try:
query = select(Organization)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get the count of active members in an organization."""
try:
result = await db.execute(
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
raise
async def get_multi_with_member_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
try:
# Build base query with LEFT JOIN and GROUP BY
query = (
select(
Organization,
func.count(
func.distinct(
and_(
UserOrganization.is_active == True,
UserOrganization.user_id
).self_group()
)
).label('member_count')
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id))
if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active)
if search:
count_query = count_query.where(search_filter)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
raise
async def add_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
# Check if relationship already exists
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
existing = result.scalar_one_or_none()
if existing:
# Reactivate if inactive, or raise error if already active
if not existing.is_active:
existing.is_active = True
existing.role = role
existing.custom_permissions = custom_permissions
await db.commit()
await db.refresh(existing)
return existing
else:
raise ValueError("User is already a member of this organization")
# Create new relationship
user_org = UserOrganization(
user_id=user_id,
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions
)
db.add(user_org)
await db.commit()
await db.refresh(user_org)
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
raise ValueError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
raise
async def remove_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return False
user_org.is_active = False
await db.commit()
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
raise
async def update_user_role(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: Optional[str] = None
) -> Optional[UserOrganization]:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return None
user_org.role = role
if custom_permissions is not None:
user_org.custom_permissions = custom_permissions
await db.commit()
await db.refresh(user_org)
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
raise
async def get_organization_members(
self,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True
) -> tuple[List[Dict[str, Any]], int]:
"""
Get members of an organization with user details.
Returns:
Tuple of (members list with user details, total count)
"""
try:
# Build query with join
query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
.where(UserOrganization.organization_id == organization_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
# Get total count
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
raise
async def get_user_organizations(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
raise
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
try:
# Subquery to get member counts for each organization
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
)
.where(UserOrganization.is_active == True)
.group_by(UserOrganization.organization_id)
.subquery()
)
# Main query with JOIN to get org, role, and member count
query = (
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
raise
async def is_user_org_owner(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Create a singleton instance for use across the application
organization_async = CRUDOrganizationAsync(Organization)

204
backend/app/crud/session.py Normal file → Executable file
View File

@@ -1,13 +1,14 @@
""" """
CRUD operations for user sessions. Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
""" """
import logging import logging
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from typing import List, Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from sqlalchemy import and_ from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase from app.crud.base import CRUDBase
from app.models.user_session import UserSession from app.models.user_session import UserSession
@@ -17,9 +18,9 @@ logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""CRUD operations for user sessions.""" """Async CRUD operations for user sessions."""
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
""" """
Get session by refresh token JTI. Get session by refresh token JTI.
@@ -31,14 +32,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
UserSession if found, None otherwise UserSession if found, None otherwise
""" """
try: try:
return db.query(UserSession).filter( result = await db.execute(
UserSession.refresh_token_jti == jti select(UserSession).where(UserSession.refresh_token_jti == jti)
).first() )
return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}") logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise raise
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
""" """
Get active session by refresh token JTI. Get active session by refresh token JTI.
@@ -50,30 +52,35 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Active UserSession if found, None otherwise Active UserSession if found, None otherwise
""" """
try: try:
return db.query(UserSession).filter( result = await db.execute(
select(UserSession).where(
and_( and_(
UserSession.refresh_token_jti == jti, UserSession.refresh_token_jti == jti,
UserSession.is_active == True UserSession.is_active == True
) )
).first() )
)
return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}") logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise raise
def get_user_sessions( async def get_user_sessions(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: str, user_id: str,
active_only: bool = True active_only: bool = True,
with_user: bool = False
) -> List[UserSession]: ) -> List[UserSession]:
""" """
Get all sessions for a user. Get all sessions for a user with optional eager loading.
Args: Args:
db: Database session db: Database session
user_id: User ID user_id: User ID
active_only: If True, return only active sessions active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns: Returns:
List of UserSession objects List of UserSession objects
@@ -82,19 +89,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
# Convert user_id string to UUID if needed # Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = db.query(UserSession).filter(UserSession.user_id == user_uuid) query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only: if active_only:
query = query.filter(UserSession.is_active == True) query = query.where(UserSession.is_active == True)
return query.order_by(UserSession.last_used_at.desc()).all() query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}") logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise raise
def create_session( async def create_session(
self, self,
db: Session, db: AsyncSession,
*, *,
obj_in: SessionCreate obj_in: SessionCreate
) -> UserSession: ) -> UserSession:
@@ -126,8 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
location_country=obj_in.location_country, location_country=obj_in.location_country,
) )
db.add(db_obj) db.add(db_obj)
db.commit() await db.commit()
db.refresh(db_obj) await db.refresh(db_obj)
logger.info( logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} " f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
@@ -136,11 +149,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return db_obj return db_obj
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True) logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}") raise ValueError(f"Failed to create session: {str(e)}")
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]: async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
""" """
Deactivate a session (logout from device). Deactivate a session (logout from device).
@@ -152,15 +165,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Deactivated UserSession if found, None otherwise Deactivated UserSession if found, None otherwise
""" """
try: try:
session = self.get(db, id=session_id) session = await self.get(db, id=session_id)
if not session: if not session:
logger.warning(f"Session {session_id} not found for deactivation") logger.warning(f"Session {session_id} not found for deactivation")
return None return None
session.is_active = False session.is_active = False
db.add(session) db.add(session)
db.commit() await db.commit()
db.refresh(session) await db.refresh(session)
logger.info( logger.info(
f"Session {session_id} deactivated for user {session.user_id} " f"Session {session_id} deactivated for user {session.user_id} "
@@ -169,13 +182,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session return session
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}") logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise raise
def deactivate_all_user_sessions( async def deactivate_all_user_sessions(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: str user_id: str
) -> int: ) -> int:
@@ -193,26 +206,33 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
# Convert user_id string to UUID if needed # Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
count = db.query(UserSession).filter( stmt = (
update(UserSession)
.where(
and_( and_(
UserSession.user_id == user_uuid, UserSession.user_id == user_uuid,
UserSession.is_active == True UserSession.is_active == True
) )
).update({"is_active": False}) )
.values(is_active=False)
)
db.commit() result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}") logger.info(f"Deactivated {count} sessions for user {user_id}")
return count return count
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise raise
def update_last_used( async def update_last_used(
self, self,
db: Session, db: AsyncSession,
*, *,
session: UserSession session: UserSession
) -> UserSession: ) -> UserSession:
@@ -229,17 +249,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
try: try:
session.last_used_at = datetime.now(timezone.utc) session.last_used_at = datetime.now(timezone.utc)
db.add(session) db.add(session)
db.commit() await db.commit()
db.refresh(session) await db.refresh(session)
return session return session
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}") logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise raise
def update_refresh_token( async def update_refresh_token(
self, self,
db: Session, db: AsyncSession,
*, *,
session: UserSession, session: UserSession,
new_jti: str, new_jti: str,
@@ -264,22 +284,24 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
session.expires_at = new_expires_at session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc) session.last_used_at = datetime.now(timezone.utc)
db.add(session) db.add(session)
db.commit() await db.commit()
db.refresh(session) await db.refresh(session)
return session return session
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise raise
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int: async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
""" """
Clean up expired sessions. Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are: Deletes sessions that are:
- Expired AND inactive - Expired AND inactive
- Older than keep_days - Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args: Args:
db: Database session db: Database session
keep_days: Keep inactive sessions for this many days (for audit) keep_days: Keep inactive sessions for this many days (for audit)
@@ -289,31 +311,87 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
""" """
try: try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
# Delete sessions that are: # Use bulk DELETE with WHERE clause - single query
# 1. Expired (expires_at < now) AND inactive stmt = delete(UserSession).where(
# AND
# 2. Older than keep_days
count = db.query(UserSession).filter(
and_( and_(
UserSession.is_active == False, UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc), UserSession.expires_at < now,
UserSession.created_at < cutoff_date UserSession.created_at < cutoff_date
) )
).delete() )
db.commit() result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0: if count > 0:
logger.info(f"Cleaned up {count} expired sessions") logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count return count
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}") logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise raise
def get_user_session_count(self, db: Session, *, user_id: str) -> int: async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try:
# Validate UUID
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
)
return count
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
)
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
""" """
Get count of active sessions for a user. Get count of active sessions for a user.
@@ -325,12 +403,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Number of active sessions Number of active sessions
""" """
try: try:
return db.query(UserSession).filter( # Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_( and_(
UserSession.user_id == user_id, UserSession.user_id == user_uuid,
UserSession.is_active == True UserSession.is_active == True
) )
).count() )
)
return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}") logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise raise

View File

@@ -1,424 +0,0 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base_async import CRUDBaseAsync
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise
async def get_user_sessions(
self,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
"""
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
try:
db_obj = UserSession(
user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti,
device_name=obj_in.device_name,
device_id=obj_in.device_id,
ip_address=obj_in.ip_address,
user_agent=obj_in.user_agent,
last_used_at=obj_in.last_used_at,
expires_at=obj_in.expires_at,
is_active=True,
location_city=obj_in.location_city,
location_country=obj_in.location_country,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
f"(IP: {obj_in.ip_address})"
)
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try:
session = await self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
await db.commit()
await db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.values(is_active=False)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try:
# Validate UUID
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
)
return count
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
)
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise
# Create singleton instance
session_async = CRUDSessionAsync(UserSession)

173
backend/app/crud/user.py Normal file → Executable file
View File

@@ -1,12 +1,15 @@
# app/crud/user.py # app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple from typing import Optional, Union, Dict, Any, List, Tuple
from uuid import UUID
from sqlalchemy import or_, asc, desc from sqlalchemy import or_, select, update
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash from app.core.auth import get_password_hash_async
from app.crud.base import CRUDBase from app.crud.base import CRUDBase
from app.models.user import User from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
@@ -15,15 +18,28 @@ logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db: Session, *, email: str) -> Optional[User]: """Async CRUD operations for User model."""
return db.query(User).filter(User.email == email).first()
def create(self, db: Session, *, obj_in: UserCreate) -> User: async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
"""Create a new user with password hashing and error handling.""" """Get user by email address."""
try: try:
result = await db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User( db_obj = User(
email=obj_in.email, email=obj_in.email,
password_hash=get_password_hash(obj_in.password), password_hash=password_hash,
first_name=obj_in.first_name, first_name=obj_in.first_name,
last_name=obj_in.last_name, last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
@@ -31,11 +47,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
preferences={} preferences={}
) )
db.add(db_obj) db.add(db_obj)
db.commit() await db.commit()
db.refresh(db_obj) await db.refresh(db_obj)
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "email" in error_msg.lower(): if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}") logger.warning(f"Duplicate email attempted: {obj_in.email}")
@@ -43,32 +59,34 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
logger.error(f"Integrity error creating user: {error_msg}") logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True) logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
raise raise
def update( async def update(
self, self,
db: Session, db: AsyncSession,
*, *,
db_obj: User, db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]] obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User: ) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict): if isinstance(obj_in, dict):
update_data = obj_in update_data = obj_in
else: else:
update_data = obj_in.model_dump(exclude_unset=True) update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data # Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data: if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"]) update_data["password_hash"] = await get_password_hash_async(update_data["password"])
del update_data["password"] del update_data["password"]
return super().update(db, db_obj=db_obj, obj_in=update_data) return await super().update(db, db_obj=db_obj, obj_in=update_data)
def get_multi_with_total( async def get_multi_with_total(
self, self,
db: Session, db: AsyncSession,
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
@@ -102,16 +120,16 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
try: try:
# Build base query # Build base query
query = db.query(User) query = select(User)
# Exclude soft-deleted users # Exclude soft-deleted users
query = query.filter(User.deleted_at.is_(None)) query = query.where(User.deleted_at.is_(None))
# Apply filters # Apply filters
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
if hasattr(User, field) and value is not None: if hasattr(User, field) and value is not None:
query = query.filter(getattr(User, field) == value) query = query.where(getattr(User, field) == value)
# Apply search # Apply search
if search: if search:
@@ -120,21 +138,26 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
User.first_name.ilike(f"%{search}%"), User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%") User.last_name.ilike(f"%{search}%")
) )
query = query.filter(search_filter) query = query.where(search_filter)
# Get total count # Get total count
total = query.count() from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting # Apply sorting
if sort_by and hasattr(User, sort_by): if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by) sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc": if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column)) query = query.order_by(sort_column.desc())
else: else:
query = query.order_by(asc(sort_column)) query = query.order_by(sort_column.asc())
# Apply pagination # Apply pagination
users = query.offset(skip).limit(limit).all() query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
return users, total return users, total
@@ -142,10 +165,106 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
logger.error(f"Error retrieving paginated users: {str(e)}") logger.error(f"Error retrieving paginated users: {str(e)}")
raise raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
is_active=False,
updated_at=datetime.now(timezone.utc)
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
raise
def is_active(self, user: User) -> bool: def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active return user.is_active
def is_superuser(self, user: User) -> bool: def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser return user.is_superuser

View File

@@ -1,272 +0,0 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple
from uuid import UUID
from sqlalchemy import or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash_async
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
email=obj_in.email,
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None,
search: Optional[str] = None
) -> Tuple[List[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
search: Search term to match against email, first_name, last_name
Returns:
Tuple of (users list, total count)
"""
# Validate pagination
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Build base query
query = select(User)
# Exclude soft-deleted users
query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value)
# Apply search
if search:
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
is_active=False,
updated_at=datetime.now(timezone.utc)
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active
def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser
# Create a singleton instance for use across the application
user_async = CRUDUserAsync(User)

View File

@@ -1,78 +0,0 @@
# app/init_db.py
import logging
from typing import Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.database import engine
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate
logger = logging.getLogger(__name__)
def init_db(db: Session) -> Optional[UserCreate]:
"""
Initialize database with first superuser if settings are configured and user doesn't exist.
Returns:
The created or existing superuser, or None if creation fails
"""
# Use default values if not set in environment variables
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
logger.warning(
"First superuser credentials not configured in settings. "
f"Using defaults: {superuser_email}"
)
try:
# Check if superuser already exists
existing_user = user_crud.get_by_email(db, email=superuser_email)
if existing_user:
logger.info(f"Superuser already exists: {existing_user.email}")
return existing_user
# Create superuser if doesn't exist
user_in = UserCreate(
email=superuser_email,
password=superuser_password,
first_name="Admin",
last_name="User",
is_superuser=True
)
user = user_crud.create(db, obj_in=user_in)
logger.info(f"Created first superuser: {user.email}")
return user
except Exception as e:
logger.error(f"Error initializing database: {e}")
raise
if __name__ == "__main__":
# Configure logging to show info logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
with Session(engine) as session:
try:
user = init_db(session)
if user:
print(f"✓ Database initialized successfully")
print(f"✓ Superuser: {user.email}")
else:
print("✗ Failed to initialize database")
except Exception as e:
print(f"✗ Error initializing database: {e}")
raise
finally:
session.close()

View File

@@ -13,7 +13,7 @@ from slowapi.util import get_remote_address
from app.api.main import api_router from app.api.main import api_router
from app.core.config import settings from app.core.config import settings
from app.core.database_async import check_database_health from app.core.database import check_database_health
from app.core.exceptions import ( from app.core.exceptions import (
APIException, APIException,
api_exception_handler, api_exception_handler,

View File

@@ -6,8 +6,8 @@ This service runs periodically to remove old session records from the database.
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from app.core.database_async import AsyncSessionLocal from app.core.database import SessionLocal
from app.crud.session_async import session_async as session_crud from app.crud.session import session as session_crud
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
""" """
logger.info("Starting session cleanup job...") logger.info("Starting session cleanup job...")
async with AsyncSessionLocal() as db: async with SessionLocal() as db:
try: try:
# Use CRUD method to cleanup # Use CRUD method to cleanup
count = await session_crud.cleanup_expired(db, keep_days=keep_days) count = await session_crud.cleanup_expired(db, keep_days=keep_days)
@@ -50,7 +50,7 @@ async def get_session_statistics() -> dict:
Returns: Returns:
Dictionary with session stats Dictionary with session stats
""" """
async with AsyncSessionLocal() as db: async with SessionLocal() as db:
try: try:
from app.models.user_session import UserSession from app.models.user_session import UserSession
from sqlalchemy import select, func from sqlalchemy import select, func