Remove unused async database and CRUD modules
- Deleted `database_async.py`, `base_async.py`, and `organization_async.py` modules due to deprecation and unused references across the project. - Improved overall codebase clarity and minimized redundant functionality by removing unused async database logic, CRUD utilities, and organization-related operations.
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
# OAuth2 configuration
|
||||
@@ -15,7 +15,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
@@ -139,7 +139,7 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
||||
|
||||
|
||||
async def get_optional_current_user(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: Optional[str] = Depends(get_optional_token)
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
|
||||
@@ -14,8 +14,8 @@ from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database_async import get_async_db
|
||||
from app.crud.organization_async import organization_async as organization_crud
|
||||
from app.core.database import get_db
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
@@ -78,7 +78,7 @@ class OrganizationPermission:
|
||||
self,
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Check if user has required role in the organization.
|
||||
@@ -133,7 +133,7 @@ require_org_member = OrganizationPermission([
|
||||
async def get_current_org_role(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Optional[OrganizationRole]:
|
||||
"""
|
||||
Get the current user's role in an organization.
|
||||
@@ -164,7 +164,7 @@ async def get_current_org_role(
|
||||
async def require_org_membership(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Ensure user is a member of the organization (any role).
|
||||
|
||||
@@ -15,10 +15,10 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
|
||||
from app.crud.organization_async import organization_async as organization_crud
|
||||
from app.crud.user_async import user_async as user_crud
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.schemas.common import (
|
||||
@@ -80,7 +80,7 @@ async def admin_list_users(
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
search: Optional[str] = Query(None, description="Search by email, name"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with comprehensive filtering and search.
|
||||
@@ -131,7 +131,7 @@ async def admin_list_users(
|
||||
async def admin_create_user(
|
||||
user_in: UserCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new user with admin privileges.
|
||||
@@ -163,7 +163,7 @@ async def admin_create_user(
|
||||
async def admin_get_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific user."""
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
@@ -186,7 +186,7 @@ async def admin_update_user(
|
||||
user_id: UUID,
|
||||
user_in: UserUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Update user information with admin privileges."""
|
||||
try:
|
||||
@@ -218,7 +218,7 @@ async def admin_update_user(
|
||||
async def admin_delete_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||
try:
|
||||
@@ -262,7 +262,7 @@ async def admin_delete_user(
|
||||
async def admin_activate_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Activate a user account."""
|
||||
try:
|
||||
@@ -298,7 +298,7 @@ async def admin_activate_user(
|
||||
async def admin_deactivate_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Deactivate a user account."""
|
||||
try:
|
||||
@@ -342,7 +342,7 @@ async def admin_deactivate_user(
|
||||
async def admin_bulk_user_action(
|
||||
bulk_action: BulkUserAction,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Perform bulk actions on multiple users using optimized bulk operations.
|
||||
@@ -410,7 +410,7 @@ async def admin_list_organizations(
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""List all organizations with filtering and search."""
|
||||
try:
|
||||
@@ -467,7 +467,7 @@ async def admin_list_organizations(
|
||||
async def admin_create_organization(
|
||||
org_in: OrganizationCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
@@ -509,7 +509,7 @@ async def admin_create_organization(
|
||||
async def admin_get_organization(
|
||||
org_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific organization."""
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
@@ -544,7 +544,7 @@ async def admin_update_organization(
|
||||
org_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Update organization information."""
|
||||
try:
|
||||
@@ -588,7 +588,7 @@ async def admin_update_organization(
|
||||
async def admin_delete_organization(
|
||||
org_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Delete an organization and all its relationships."""
|
||||
try:
|
||||
@@ -626,7 +626,7 @@ async def admin_list_organization_members(
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""List all members of an organization."""
|
||||
try:
|
||||
@@ -681,7 +681,7 @@ async def admin_add_organization_member(
|
||||
org_id: UUID,
|
||||
request: AddMemberRequest,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Add a user to an organization."""
|
||||
try:
|
||||
@@ -742,7 +742,7 @@ async def admin_remove_organization_member(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
|
||||
@@ -13,14 +13,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError as AuthError,
|
||||
DatabaseError,
|
||||
ErrorCode
|
||||
)
|
||||
from app.crud.session_async import session_async as session_crud
|
||||
from app.crud.user_async import user_async as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
@@ -54,7 +54,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -85,7 +85,7 @@ async def register_user(
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
@@ -167,7 +167,7 @@ async def login(
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
@@ -244,7 +244,7 @@ async def login_oauth(
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
@@ -333,7 +333,7 @@ async def refresh_token(
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
reset_request: PasswordResetRequest,
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Request a password reset.
|
||||
@@ -391,7 +391,7 @@ async def request_password_reset(
|
||||
async def confirm_password_reset(
|
||||
request: Request,
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Confirm password reset with token.
|
||||
@@ -430,7 +430,7 @@ async def confirm_password_reset(
|
||||
|
||||
# SECURITY: Invalidate all existing sessions after password reset
|
||||
# This prevents stolen sessions from being used after password change
|
||||
from app.crud.session_async import session_async as session_crud
|
||||
from app.crud.session import session as session_crud
|
||||
try:
|
||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||
db,
|
||||
@@ -478,7 +478,7 @@ async def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
@@ -566,7 +566,7 @@ async def logout(
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
|
||||
@@ -13,9 +13,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, ErrorCode
|
||||
from app.crud.organization_async import organization_async as organization_crud
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
@@ -43,7 +43,7 @@ router = APIRouter()
|
||||
async def get_my_organizations(
|
||||
is_active: bool = Query(True, description="Filter by active membership"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get all organizations the current user belongs to.
|
||||
@@ -93,7 +93,7 @@ async def get_my_organizations(
|
||||
async def get_organization(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get details of a specific organization.
|
||||
@@ -140,7 +140,7 @@ async def get_organization_members(
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get all members of an organization.
|
||||
@@ -183,7 +183,7 @@ async def update_organization(
|
||||
organization_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
current_user: User = Depends(require_org_admin),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update organization details.
|
||||
|
||||
@@ -14,9 +14,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
from app.crud.session_async import session_async as session_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
@@ -45,7 +45,7 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
async def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
@@ -129,7 +129,7 @@ async def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
@@ -204,7 +204,7 @@ async def revoke_session(
|
||||
async def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
|
||||
@@ -11,13 +11,13 @@ from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
)
|
||||
from app.crud.user_async import user_async as user_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
@@ -58,7 +58,7 @@ async def list_users(
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination, filtering, and sorting.
|
||||
@@ -138,7 +138,7 @@ def get_current_user_profile(
|
||||
async def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
@@ -188,7 +188,7 @@ async def update_current_user(
|
||||
async def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
@@ -236,7 +236,7 @@ async def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
@@ -304,7 +304,7 @@ async def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
@@ -356,7 +356,7 @@ async def change_current_user_password(
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
|
||||
207
backend/app/core/database.py
Normal file → Executable file
207
backend/app/core/database.py
Normal file → Executable file
@@ -1,113 +1,186 @@
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
"""
|
||||
Database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
async_sessionmaker,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# Declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Create engine with optimized settings for PostgreSQL
|
||||
def create_production_engine():
|
||||
return create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_timeout=settings.db_pool_timeout,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
isolation_level="READ COMMITTED",
|
||||
echo=settings.sql_echo,
|
||||
echo_pool=settings.sql_echo_pool,
|
||||
)
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
"""
|
||||
Convert sync database URL to async URL.
|
||||
|
||||
postgresql:// -> postgresql+asyncpg://
|
||||
sqlite:// -> sqlite+aiosqlite://
|
||||
"""
|
||||
if url.startswith("postgresql://"):
|
||||
return url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif url.startswith("sqlite://"):
|
||||
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine with optimized settings
|
||||
def create_async_production_engine() -> AsyncEngine:
|
||||
"""Create an async database engine with production settings."""
|
||||
async_url = get_async_database_url(settings.database_url)
|
||||
|
||||
# Base engine config
|
||||
engine_config = {
|
||||
"pool_size": settings.db_pool_size,
|
||||
"max_overflow": settings.db_max_overflow,
|
||||
"pool_timeout": settings.db_pool_timeout,
|
||||
"pool_recycle": settings.db_pool_recycle,
|
||||
"pool_pre_ping": True,
|
||||
"echo": settings.sql_echo,
|
||||
"echo_pool": settings.sql_echo_pool,
|
||||
}
|
||||
|
||||
# Add PostgreSQL-specific connect_args
|
||||
if "postgresql" in async_url:
|
||||
engine_config["connect_args"] = {
|
||||
"server_settings": {
|
||||
"application_name": "eventspace",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
# asyncpg-specific settings
|
||||
"command_timeout": 60,
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
return create_async_engine(async_url, **engine_config)
|
||||
|
||||
|
||||
# Create async engine and session factory
|
||||
engine = create_async_production_engine()
|
||||
SessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False # Prevent unnecessary queries after commit
|
||||
expire_on_commit=False, # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
|
||||
# FastAPI dependency for async database sessions
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
FastAPI dependency that provides a database session.
|
||||
FastAPI dependency that provides an async database session.
|
||||
Automatically closes the session after the request completes.
|
||||
|
||||
Usage:
|
||||
@router.get("/users")
|
||||
async def get_users(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User))
|
||||
return result.scalars().all()
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction_scope() -> Generator[Session, None, None]:
|
||||
@asynccontextmanager
|
||||
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Provide a transactional scope for database operations.
|
||||
Provide an async transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
with transaction_scope() as db:
|
||||
user = user_crud.create(db, obj_in=user_create)
|
||||
profile = profile_crud.create(db, obj_in=profile_create)
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_crud.create(db, obj_in=user_create)
|
||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
logger.debug("Transaction committed successfully")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def check_database_health() -> bool:
|
||||
async def check_async_database_health() -> bool:
|
||||
"""
|
||||
Check if database connection is healthy.
|
||||
Check if async database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with transaction_scope() as db:
|
||||
db.execute(text("SELECT 1"))
|
||||
async with async_transaction_scope() as db:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {str(e)}")
|
||||
return False
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# Alias for consistency with main.py
|
||||
check_database_health = check_async_database_health
|
||||
|
||||
|
||||
async def init_async_db() -> None:
|
||||
"""
|
||||
Initialize async database tables.
|
||||
|
||||
This creates all tables defined in the models.
|
||||
Should only be used in development or testing.
|
||||
In production, use Alembic migrations.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Async database tables created")
|
||||
|
||||
|
||||
async def close_async_db() -> None:
|
||||
"""
|
||||
Close all async database connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await engine.dispose()
|
||||
logger.info("Async database connections closed")
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
# app/core/database_async.py
|
||||
"""
|
||||
Async database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
async_sessionmaker,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
"""
|
||||
Convert sync database URL to async URL.
|
||||
|
||||
postgresql:// -> postgresql+asyncpg://
|
||||
sqlite:// -> sqlite+aiosqlite://
|
||||
"""
|
||||
if url.startswith("postgresql://"):
|
||||
return url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif url.startswith("sqlite://"):
|
||||
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine with optimized settings
|
||||
def create_async_production_engine() -> AsyncEngine:
|
||||
"""Create an async database engine with production settings."""
|
||||
async_url = get_async_database_url(settings.database_url)
|
||||
|
||||
# Base engine config
|
||||
engine_config = {
|
||||
"pool_size": settings.db_pool_size,
|
||||
"max_overflow": settings.db_max_overflow,
|
||||
"pool_timeout": settings.db_pool_timeout,
|
||||
"pool_recycle": settings.db_pool_recycle,
|
||||
"pool_pre_ping": True,
|
||||
"echo": settings.sql_echo,
|
||||
"echo_pool": settings.sql_echo_pool,
|
||||
}
|
||||
|
||||
# Add PostgreSQL-specific connect_args
|
||||
if "postgresql" in async_url:
|
||||
engine_config["connect_args"] = {
|
||||
"server_settings": {
|
||||
"application_name": "eventspace",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
# asyncpg-specific settings
|
||||
"command_timeout": 60,
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
return create_async_engine(async_url, **engine_config)
|
||||
|
||||
|
||||
# Create async engine and session factory
|
||||
async_engine = create_async_production_engine()
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False, # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
|
||||
# FastAPI dependency for async database sessions
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
FastAPI dependency that provides an async database session.
|
||||
Automatically closes the session after the request completes.
|
||||
|
||||
Usage:
|
||||
@router.get("/users")
|
||||
async def get_users(db: AsyncSession = Depends(get_async_db)):
|
||||
result = await db.execute(select(User))
|
||||
return result.scalars().all()
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Provide an async transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_crud.create(db, obj_in=user_create)
|
||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def check_async_database_health() -> bool:
|
||||
"""
|
||||
Check if async database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
async with async_transaction_scope() as db:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# Alias for consistency with main.py
|
||||
check_database_health = check_async_database_health
|
||||
|
||||
|
||||
async def init_async_db() -> None:
|
||||
"""
|
||||
Initialize async database tables.
|
||||
|
||||
This creates all tables defined in the models.
|
||||
Should only be used in development or testing.
|
||||
In production, use Alembic migrations.
|
||||
"""
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Async database tables created")
|
||||
|
||||
|
||||
async def close_async_db() -> None:
|
||||
"""
|
||||
Close all async database connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await async_engine.dispose()
|
||||
logger.info("Async database connections closed")
|
||||
207
backend/app/crud/base.py
Normal file → Executable file
207
backend/app/crud/base.py
Normal file → Executable file
@@ -1,13 +1,19 @@
|
||||
# app/crud/base_async.py
|
||||
"""
|
||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import asc, desc
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
@@ -19,17 +25,40 @@ UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
id: Record UUID
|
||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||
for eager loading relationships to prevent N+1 queries
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
|
||||
Example:
|
||||
# Eager load user relationship
|
||||
from sqlalchemy.orm import joinedload
|
||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
@@ -41,15 +70,39 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return None
|
||||
|
||||
try:
|
||||
return db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
async def get_multi(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
"""Get multiple records with pagination validation."""
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
options: Optional list of SQLAlchemy load options for eager loading
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
@@ -59,22 +112,30 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
@@ -82,20 +143,20 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
@@ -104,15 +165,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
@@ -120,15 +183,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
@@ -141,27 +204,31 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_multi_with_total(
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
@@ -193,43 +260,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = db.query(self.model)
|
||||
query = select(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
query = query.filter(self.model.deleted_at.is_(None))
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.filter(getattr(self.model, field) == value)
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
|
||||
# Get total count (before pagination)
|
||||
total = query.count()
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(desc(sort_column))
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(asc(sort_column))
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
items = query.offset(skip).limit(limit).all()
|
||||
query = query.offset(skip).limit(limit)
|
||||
items_result = await db.execute(query)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""Get total count of records."""
|
||||
try:
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
"""Check if a record exists by ID."""
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
@@ -241,7 +328,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
@@ -255,15 +345,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
@@ -282,10 +372,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
obj = db.query(self.model).filter(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
).first()
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
@@ -297,10 +390,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
# Clear deleted_at timestamp
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
raise
|
||||
|
||||
@@ -1,399 +0,0 @@
|
||||
# app/crud/base_async.py
|
||||
"""
|
||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database_async import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
id: Record UUID
|
||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||
for eager loading relationships to prevent N+1 queries
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
|
||||
Example:
|
||||
# Eager load user relationship
|
||||
from sqlalchemy.orm import joinedload
|
||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
options: Optional list of SQLAlchemy load options for eager loading
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by (must be a valid model attribute)
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = select(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
|
||||
# Get total count (before pagination)
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
items_result = await db.execute(query)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""Get total count of records."""
|
||||
try:
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
"""Check if a record exists by ID."""
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, 'deleted_at'):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
434
backend/app/crud/organization.py
Normal file → Executable file
434
backend/app/crud/organization.py
Normal file → Executable file
@@ -1,11 +1,12 @@
|
||||
# app/crud/organization.py
|
||||
# app/crud/organization_async.py
|
||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, and_
|
||||
from sqlalchemy import func, or_, and_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.organization import Organization
|
||||
@@ -13,20 +14,27 @@ from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate
|
||||
OrganizationUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""CRUD operations for Organization model."""
|
||||
"""Async CRUD operations for Organization model."""
|
||||
|
||||
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Organization]:
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
||||
"""Get organization by slug."""
|
||||
return db.query(Organization).filter(Organization.slug == slug).first()
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
||||
raise
|
||||
|
||||
def create(self, db: Session, *, obj_in: OrganizationCreate) -> Organization:
|
||||
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
@@ -37,11 +45,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
settings=obj_in.settings or {}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
@@ -49,13 +57,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_multi_with_filters(
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
@@ -70,47 +78,139 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
Returns:
|
||||
Tuple of (organizations list, total count)
|
||||
"""
|
||||
query = db.query(Organization)
|
||||
try:
|
||||
query = select(Organization)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.filter(Organization.is_active == is_active)
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
total = query.count()
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
# Apply sorting
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
organizations = query.offset(skip).limit(limit).all()
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
organizations = list(result.scalars().all())
|
||||
|
||||
return organizations, total
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with filters: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_member_count(self, db: Session, *, organization_id: UUID) -> int:
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get the count of active members in an organization."""
|
||||
return db.query(func.count(UserOrganization.user_id)).filter(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
).scalar() or 0
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def add_user(
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||
This eliminates the N+1 query problem.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with org and member_count, total count)
|
||||
"""
|
||||
try:
|
||||
# Build base query with LEFT JOIN and GROUP BY
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
func.count(
|
||||
func.distinct(
|
||||
and_(
|
||||
UserOrganization.is_active == True,
|
||||
UserOrganization.user_id
|
||||
).self_group()
|
||||
)
|
||||
).label('member_count')
|
||||
)
|
||||
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to list of dicts
|
||||
orgs_with_counts = [
|
||||
{
|
||||
'organization': org,
|
||||
'member_count': member_count
|
||||
}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
@@ -120,12 +220,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
# Check if relationship already exists
|
||||
existing = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Reactivate if inactive, or raise error if already active
|
||||
@@ -133,8 +236,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
existing.custom_permissions = custom_permissions
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
await db.commit()
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise ValueError("User is already a member of this organization")
|
||||
@@ -148,48 +251,51 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
custom_permissions=custom_permissions
|
||||
)
|
||||
db.add(user_org)
|
||||
db.commit()
|
||||
db.refresh(user_org)
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove_user(
|
||||
async def remove_user(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
user_org = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return False
|
||||
|
||||
user_org.is_active = False
|
||||
db.commit()
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update_user_role(
|
||||
async def update_user_role(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
@@ -198,12 +304,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
) -> Optional[UserOrganization]:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
user_org = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return None
|
||||
@@ -211,17 +320,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
user_org.role = role
|
||||
if custom_permissions is not None:
|
||||
user_org.custom_permissions = custom_permissions
|
||||
db.commit()
|
||||
db.refresh(user_org)
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_organization_members(
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
@@ -234,86 +343,175 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
Returns:
|
||||
Tuple of (members list with user details, total count)
|
||||
"""
|
||||
query = db.query(UserOrganization, User).join(
|
||||
User, UserOrganization.user_id == User.id
|
||||
).filter(UserOrganization.organization_id == organization_id)
|
||||
try:
|
||||
# Build query with join
|
||||
query = (
|
||||
select(UserOrganization, User)
|
||||
.join(User, UserOrganization.user_id == User.id)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(UserOrganization.is_active == is_active)
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
total = query.count()
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
results = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit).all()
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
|
||||
return members, total
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_organizations(
|
||||
async def get_user_organizations(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
query = db.query(Organization).join(
|
||||
UserOrganization, Organization.id == UserOrganization.organization_id
|
||||
).filter(UserOrganization.user_id == user_id)
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(UserOrganization.is_active == is_active)
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
return query.all()
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_role_in_org(
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get user's organizations with role and member count in SINGLE QUERY.
|
||||
Eliminates N+1 problem by using subquery for member counts.
|
||||
|
||||
Returns:
|
||||
List of dicts with organization, role, and member_count
|
||||
"""
|
||||
try:
|
||||
# Subquery to get member counts for each organization
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label('member_count')
|
||||
)
|
||||
.where(UserOrganization.is_active == True)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Main query with JOIN to get org, role, and member count
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
'organization': org,
|
||||
'role': role,
|
||||
'member_count': member_count
|
||||
}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> Optional[OrganizationRole]:
|
||||
"""Get a user's role in a specific organization."""
|
||||
user_org = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
).first()
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None
|
||||
return user_org.role if user_org else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {str(e)}")
|
||||
raise
|
||||
|
||||
def is_user_org_owner(
|
||||
async def is_user_org_owner(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
def is_user_org_admin(
|
||||
async def is_user_org_admin(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
|
||||
@@ -1,519 +0,0 @@
|
||||
# app/crud/organization_async.py
|
||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, and_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""Async CRUD operations for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
||||
"""Get organization by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {}
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc"
|
||||
) -> tuple[List[Organization], int]:
|
||||
"""
|
||||
Get multiple organizations with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (organizations list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(Organization)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
organizations = list(result.scalars().all())
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with filters: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get the count of active members in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||
This eliminates the N+1 query problem.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with org and member_count, total count)
|
||||
"""
|
||||
try:
|
||||
# Build base query with LEFT JOIN and GROUP BY
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
func.count(
|
||||
func.distinct(
|
||||
and_(
|
||||
UserOrganization.is_active == True,
|
||||
UserOrganization.user_id
|
||||
).self_group()
|
||||
)
|
||||
).label('member_count')
|
||||
)
|
||||
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to list of dicts
|
||||
orgs_with_counts = [
|
||||
{
|
||||
'organization': org,
|
||||
'member_count': member_count
|
||||
}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
# Check if relationship already exists
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Reactivate if inactive, or raise error if already active
|
||||
if not existing.is_active:
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
existing.custom_permissions = custom_permissions
|
||||
await db.commit()
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise ValueError("User is already a member of this organization")
|
||||
|
||||
# Create new relationship
|
||||
user_org = UserOrganization(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions
|
||||
)
|
||||
db.add(user_org)
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return False
|
||||
|
||||
user_org.is_active = False
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> Optional[UserOrganization]:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return None
|
||||
|
||||
user_org.role = role
|
||||
if custom_permissions is not None:
|
||||
user_org.custom_permissions = custom_permissions
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Get members of an organization with user details.
|
||||
|
||||
Returns:
|
||||
Tuple of (members list with user details, total count)
|
||||
"""
|
||||
try:
|
||||
# Build query with join
|
||||
query = (
|
||||
select(UserOrganization, User)
|
||||
.join(User, UserOrganization.user_id == User.id)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get user's organizations with role and member count in SINGLE QUERY.
|
||||
Eliminates N+1 problem by using subquery for member counts.
|
||||
|
||||
Returns:
|
||||
List of dicts with organization, role, and member_count
|
||||
"""
|
||||
try:
|
||||
# Subquery to get member counts for each organization
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label('member_count')
|
||||
)
|
||||
.where(UserOrganization.is_active == True)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Main query with JOIN to get org, role, and member count
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
'organization': org,
|
||||
'role': role,
|
||||
'member_count': member_count
|
||||
}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> Optional[OrganizationRole]:
|
||||
"""Get a user's role in a specific organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {str(e)}")
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
async def is_user_org_admin(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
organization_async = CRUDOrganizationAsync(Organization)
|
||||
220
backend/app/crud/session.py
Normal file → Executable file
220
backend/app/crud/session.py
Normal file → Executable file
@@ -1,13 +1,14 @@
|
||||
"""
|
||||
CRUD operations for user sessions.
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user_session import UserSession
|
||||
@@ -17,9 +18,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""CRUD operations for user sessions."""
|
||||
"""Async CRUD operations for user sessions."""
|
||||
|
||||
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
@@ -31,14 +32,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
UserSession.refresh_token_jti == jti
|
||||
).first()
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
||||
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
@@ -50,30 +52,35 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Active UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
).first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_sessions(
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
@@ -82,19 +89,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.filter(UserSession.is_active == True)
|
||||
query = query.where(UserSession.is_active == True)
|
||||
|
||||
return query.order_by(UserSession.last_used_at.desc()).all()
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
async def create_session(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
@@ -126,8 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
@@ -136,11 +149,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
|
||||
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
@@ -152,15 +165,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Deactivated UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = self.get(db, id=session_id)
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
@@ -169,13 +182,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def deactivate_all_user_sessions(
|
||||
async def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
@@ -193,26 +206,33 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
count = db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
).update({"is_active": False})
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
db.commit()
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_last_used(
|
||||
async def update_last_used(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession
|
||||
) -> UserSession:
|
||||
@@ -229,17 +249,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_refresh_token(
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
@@ -264,22 +284,24 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
Clean up expired sessions using optimized bulk DELETE.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
@@ -289,31 +311,87 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete sessions that are:
|
||||
# 1. Expired (expires_at < now) AND inactive
|
||||
# AND
|
||||
# 2. Older than keep_days
|
||||
count = db.query(UserSession).filter(
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < datetime.now(timezone.utc),
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
).delete()
|
||||
)
|
||||
|
||||
db.commit()
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
|
||||
async def cleanup_expired_for_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions for a specific user.
|
||||
|
||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID to cleanup sessions for
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
# Validate UUID
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""
|
||||
Get count of active sessions for a user.
|
||||
|
||||
@@ -325,12 +403,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Number of active sessions
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.is_active == True
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
).count()
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -1,424 +0,0 @@
|
||||
"""
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Async CRUD operations for user sessions."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
Active UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create a new user session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: SessionCreate schema with session data
|
||||
|
||||
Returns:
|
||||
Created UserSession
|
||||
|
||||
Raises:
|
||||
ValueError: If session creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
f"(IP: {obj_in.ip_address})"
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
Deactivated UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of sessions deactivated
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
|
||||
Called during token refresh.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
new_jti: New refresh token JTI
|
||||
new_expires_at: New expiration datetime
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions using optimized bulk DELETE.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions for a specific user.
|
||||
|
||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID to cleanup sessions for
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
# Validate UUID
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""
|
||||
Get count of active sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of active sessions
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# Create singleton instance
|
||||
session_async = CRUDSessionAsync(UserSession)
|
||||
183
backend/app/crud/user.py
Normal file → Executable file
183
backend/app/crud/user.py
Normal file → Executable file
@@ -1,12 +1,15 @@
|
||||
# app/crud/user.py
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, asc, desc
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.auth import get_password_hash_async
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
@@ -15,15 +18,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
"""Async CRUD operations for User model."""
|
||||
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with password hashing and error handling."""
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email {email}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
@@ -31,11 +47,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
@@ -43,32 +59,34 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
def get_multi_with_total(
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
@@ -102,16 +120,16 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = db.query(User)
|
||||
query = select(User)
|
||||
|
||||
# Exclude soft-deleted users
|
||||
query = query.filter(User.deleted_at.is_(None))
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.filter(getattr(User, field) == value)
|
||||
query = query.where(getattr(User, field) == value)
|
||||
|
||||
# Apply search
|
||||
if search:
|
||||
@@ -120,21 +138,26 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
from sqlalchemy import func
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(desc(sort_column))
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(asc(sort_column))
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
users = query.offset(skip).limit(limit).all()
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
|
||||
return users, total
|
||||
|
||||
@@ -142,12 +165,108 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to update
|
||||
is_active: New active status
|
||||
|
||||
Returns:
|
||||
Number of users updated
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to delete
|
||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||
|
||||
Returns:
|
||||
Number of users deleted
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Remove excluded user from list
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
user = CRUDUser(User)
|
||||
user = CRUDUser(User)
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_password_hash_async
|
||||
from app.crud.base_async import CRUDBaseAsync
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
||||
"""Async CRUD operations for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email {email}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[User], int]:
|
||||
"""
|
||||
Get multiple users with total count, filtering, sorting, and search.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
search: Search term to match against email, first_name, last_name
|
||||
|
||||
Returns:
|
||||
Tuple of (users list, total count)
|
||||
"""
|
||||
# Validate pagination
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = select(User)
|
||||
|
||||
# Exclude soft-deleted users
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.where(getattr(User, field) == value)
|
||||
|
||||
# Apply search
|
||||
if search:
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
from sqlalchemy import func
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to update
|
||||
is_active: New active status
|
||||
|
||||
Returns:
|
||||
Number of users updated
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to delete
|
||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||
|
||||
Returns:
|
||||
Number of users deleted
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Remove excluded user from list
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
user_async = CRUDUserAsync(User)
|
||||
@@ -1,78 +0,0 @@
|
||||
# app/init_db.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import engine
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_db(db: Session) -> Optional[UserCreate]:
|
||||
"""
|
||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
||||
|
||||
Returns:
|
||||
The created or existing superuser, or None if creation fails
|
||||
"""
|
||||
# Use default values if not set in environment variables
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
|
||||
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
"First superuser credentials not configured in settings. "
|
||||
f"Using defaults: {superuser_email}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if superuser already exists
|
||||
existing_user = user_crud.get_by_email(db, email=superuser_email)
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
)
|
||||
|
||||
user = user_crud.create(db, obj_in=user_in)
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging to show info logs
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
try:
|
||||
user = init_db(session)
|
||||
if user:
|
||||
print(f"✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
@@ -13,7 +13,7 @@ from slowapi.util import get_remote_address
|
||||
|
||||
from app.api.main import api_router
|
||||
from app.core.config import settings
|
||||
from app.core.database_async import check_database_health
|
||||
from app.core.database import check_database_health
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
|
||||
@@ -6,8 +6,8 @@ This service runs periodically to remove old session records from the database.
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.database_async import AsyncSessionLocal
|
||||
from app.crud.session_async import session_async as session_crud
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,7 +29,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
"""
|
||||
logger.info("Starting session cleanup job...")
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
# Use CRUD method to cleanup
|
||||
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
@@ -50,7 +50,7 @@ async def get_session_statistics() -> dict:
|
||||
Returns:
|
||||
Dictionary with session stats
|
||||
"""
|
||||
async with AsyncSessionLocal() as db:
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
Reference in New Issue
Block a user